""" Chat completions API router - OpenAI compatible chat endpoint. """ import asyncio import logging import os import uuid from typing import AsyncIterator, Union from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from oci.exceptions import ServiceError from api.auth import get_api_key from api.schemas import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ErrorDetail from api.error_handler import OCIErrorHandler from api.exceptions import ModelNotFoundException, InvalidModelTypeException from api.adapters.request_adapter import adapt_chat_messages, extract_chat_params from api.adapters.response_adapter import ( adapt_chat_response, adapt_streaming_chunk, adapt_streaming_done, ) from core.config import get_settings from core.client_manager import get_client_manager from core.models import get_model_config logger = logging.getLogger(__name__) router = APIRouter( prefix="/chat", tags=["chat"], dependencies=[Depends(get_api_key)] ) def extract_delta_from_chunk(chunk) -> str: """ Extract delta text content from OCI streaming chunk. Args: chunk: OCI streaming response chunk (can be SSE Event, parsed object, etc.) Returns: Delta text content or empty string """ try: # Handle SSE Event objects (from SSEClient) if hasattr(chunk, 'data'): import json # Parse JSON data from SSE event try: parsed = json.loads(chunk.data) # Recursively extract from parsed object if isinstance(parsed, dict): # OCI Streaming format: message.content[].text if 'message' in parsed and 'content' in parsed['message']: content_array = parsed['message']['content'] if isinstance(content_array, list) and len(content_array) > 0: # Extract text from all TEXT type content items text_parts = [] for item in content_array: if isinstance(item, dict) and item.get('type') == 'TEXT' and 'text' in item: text_parts.append(item['text']) if text_parts: return ''.join(text_parts) # Try to get text from various possible locations if 'text' in parsed: return parsed['text'] if 'chatResponse' in parsed and 'text' in parsed['chatResponse']: return parsed['chatResponse']['text'] if 'choices' in parsed and len(parsed['choices']) > 0: choice = parsed['choices'][0] if 'delta' in choice and 'content' in choice['delta']: return choice['delta']['content'] except (json.JSONDecodeError, KeyError, TypeError): # Return raw data if not JSON return str(chunk.data) if chunk.data else "" # Try to extract from chat_response.text (Cohere format) if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'text'): return chunk.chat_response.text # Try to extract from choices[0].delta.content (Generic format) if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'choices'): if len(chunk.chat_response.choices) > 0: choice = chunk.chat_response.choices[0] if hasattr(choice, 'delta') and hasattr(choice.delta, 'content'): content = choice.delta.content if isinstance(content, str): return content elif isinstance(content, list): # Handle TextContent list text_parts = [] for item in content: if isinstance(item, dict) and 'text' in item: text_parts.append(item['text']) elif hasattr(item, 'text'): text_parts.append(item.text) return "".join(text_parts) # Try direct text attribute if hasattr(chunk, 'text'): return chunk.text except Exception as e: logger.warning(f"Failed to extract delta from chunk: {e}") return "" def extract_content_from_response(chat_response) -> str: """ Extract full content from non-streaming OCI response. Args: chat_response: OCI chat response object Returns: Full text content """ if hasattr(chat_response, 'text'): raw_text = chat_response.text # Try to parse as JSON if it's a string (OCI format) try: import json parsed = json.loads(raw_text) if isinstance(parsed, dict) and 'text' in parsed: return parsed['text'] return raw_text except (json.JSONDecodeError, ValueError, TypeError): return raw_text elif hasattr(chat_response, 'choices') and len(chat_response.choices) > 0: choice = chat_response.choices[0] if hasattr(choice, 'message'): raw_content = choice.message.content # Handle list format if isinstance(raw_content, list): text_parts = [] for item in raw_content: if isinstance(item, dict): text_parts.append(item.get('text', '')) elif hasattr(item, 'text'): text_parts.append(item.text) else: text_parts.append(str(item)) return "".join(text_parts) elif isinstance(raw_content, str): try: import json parsed = json.loads(raw_content) if isinstance(parsed, dict) and 'text' in parsed: return parsed['text'] return raw_content except (json.JSONDecodeError, ValueError): return raw_content else: return str(raw_content) return str(choice) return str(chat_response) @router.post("/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest): """ Create a chat completion using OCI Generative AI. Args: request: Chat completion request Returns: Chat completion response """ logger.info(f"Chat completion request for model: {request.model}") settings = get_settings() # Validate model exists model_config = get_model_config(request.model) if not model_config: raise ModelNotFoundException(request.model) # Validate model type is chat (ondemand or dedicated) if model_config.type not in ("ondemand", "dedicated"): raise InvalidModelTypeException( model_id=request.model, expected_type="chat", actual_type=model_config.type ) # Note: Multimodal capability validation is handled by the model itself # If a model doesn't support certain content types, it will raise an error # For example, Cohere models will raise ValueError for non-text content # Get OCI client from manager (轮询负载均衡) client_manager = get_client_manager() oci_client = client_manager.get_client() # Adapt messages messages = adapt_chat_messages([msg.dict() for msg in request.messages]) # Extract parameters params = extract_chat_params(request) # Check global streaming setting # If streaming is globally disabled, override client request enable_stream = request.stream and settings.enable_streaming if not settings.enable_streaming and request.stream: logger.info("Streaming requested but globally disabled via ENABLE_STREAMING=false") # Handle streaming if enable_stream: request_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" async def generate_stream() -> AsyncIterator[str]: """Generate streaming response with true non-blocking streaming.""" try: # Run OCI SDK call in executor to prevent blocking # This is critical for achieving true streaming (msToFirstChunk < 1s) loop = asyncio.get_event_loop() response = await loop.run_in_executor( None, lambda: oci_client.chat( model_id=request.model, messages=messages, temperature=params["temperature"], max_tokens=params["max_tokens"], top_p=params["top_p"], stream=True, # Enable real streaming tools=params.get("tools"), ) ) # Process real streaming response accumulated_usage = None # Check if response.data is an SSE stream (iterable) # When stream=True, OCI SDK returns response.data as SSEClient try: # Try to iterate over the stream stream_data = response.data if hasattr(response, 'data') else response # Check if it's SSEClient or any iterable type stream_type_name = type(stream_data).__name__ is_sse_client = 'SSEClient' in stream_type_name is_iterable = hasattr(stream_data, '__iter__') or hasattr(stream_data, '__next__') # SSEClient is always treated as streaming, even if hasattr check fails if is_sse_client or is_iterable: # Real streaming: iterate over chunks # SSEClient requires calling .events() method to iterate if is_sse_client and hasattr(stream_data, 'events'): iterator = stream_data.events() else: iterator = stream_data # Send first chunk with role and empty content (OpenAI format) yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True) # Use queue for thread-safe chunk forwarding import queue import threading chunk_queue = queue.Queue() def read_chunks(): """Read chunks in background thread and put in queue.""" try: for chunk in iterator: chunk_queue.put(("chunk", chunk)) chunk_queue.put(("done", None)) except Exception as e: chunk_queue.put(("error", e)) # Start background thread to read chunks reader_thread = threading.Thread(target=read_chunks, daemon=True) reader_thread.start() # Yield chunks as they arrive from queue while True: # Non-blocking queue get with timeout try: msg_type, data = await loop.run_in_executor( None, lambda: chunk_queue.get(timeout=0.01) ) except queue.Empty: # Allow other async tasks to run await asyncio.sleep(0) continue if msg_type == "done": break elif msg_type == "error": raise data elif msg_type == "chunk": chunk = data # Extract delta content from chunk delta_text = extract_delta_from_chunk(chunk) if delta_text: yield adapt_streaming_chunk(delta_text, request.model, request_id, 0, is_first=False) # Try to extract usage from chunk (typically in final chunk) # Handle both SSE Event format and object format if hasattr(chunk, 'data'): # SSE Event - parse JSON to extract usage try: import json parsed = json.loads(chunk.data) if isinstance(parsed, dict) and 'usage' in parsed: usage_data = parsed['usage'] accumulated_usage = { "prompt_tokens": usage_data.get('promptTokens', 0) or 0, "completion_tokens": usage_data.get('completionTokens', 0) or 0, "total_tokens": usage_data.get('totalTokens', 0) or 0 } except (json.JSONDecodeError, KeyError, TypeError): pass elif hasattr(chunk, 'usage') and chunk.usage: # Object format accumulated_usage = { "prompt_tokens": getattr(chunk.usage, 'prompt_tokens', 0) or 0, "completion_tokens": getattr(chunk.usage, 'completion_tokens', 0) or 0, "total_tokens": getattr(chunk.usage, 'total_tokens', 0) or 0 } # Send done message with usage yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage) else: # Fallback: non-streaming response, simulate streaming logger.warning(f"OCI SDK returned non-iterable response (type: {type(stream_data).__name__}), falling back to simulated streaming") # Extract text from non-streaming response chat_response = stream_data.chat_response if hasattr(stream_data, 'chat_response') else stream_data content = extract_content_from_response(chat_response) # Extract usage information if hasattr(stream_data, 'usage'): oci_usage = stream_data.usage accumulated_usage = { "prompt_tokens": getattr(oci_usage, 'prompt_tokens', 0) or 0, "completion_tokens": getattr(oci_usage, 'completion_tokens', 0) or 0, "total_tokens": getattr(oci_usage, 'total_tokens', 0) or 0 } # Simulate streaming by chunking # First send empty chunk with role (OpenAI format) yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True) chunk_size = settings.stream_chunk_size for i in range(0, len(content), chunk_size): chunk = content[i:i + chunk_size] yield adapt_streaming_chunk(chunk, request.model, request_id, 0, is_first=False) yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage) except TypeError as te: # Handle case where response is not iterable at all logger.error(f"Response is not iterable: {te}", exc_info=True) raise except Exception as e: logger.error(f"Error in streaming: {str(e)}", exc_info=True) import json # 根据异常类型处理并过滤敏感信息 if isinstance(e, ServiceError): error_response = OCIErrorHandler.sanitize_oci_error(e) else: # 通用错误也要过滤可能包含的敏感信息 filtered_msg = OCIErrorHandler.filter_sensitive_info(str(e)) error_response = ErrorResponse( error=ErrorDetail( message="An error occurred during streaming", type="server_error", code="streaming_error" ) ) yield f"data: {json.dumps(error_response.dict(), ensure_ascii=False)}\n\n" return StreamingResponse( generate_stream(), media_type="text/event-stream" ) # Non-streaming response try: response = oci_client.chat( model_id=request.model, messages=messages, temperature=params["temperature"], max_tokens=params["max_tokens"], top_p=params["top_p"], stream=False, tools=params.get("tools"), ) # Adapt response to OpenAI format openai_response = adapt_chat_response(response, request.model) if settings.log_responses: logger.debug(f"Response: {openai_response}") return openai_response except Exception as e: logger.error(f"Error in chat completion: {str(e)}", exc_info=True) # 直接 raise,让全局异常处理器统一过滤敏感信息 raise