"""
|
||
Chat completions API router - OpenAI compatible chat endpoint.
|
||
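
Example request (illustrative only; the host, port, "/v1" mount prefix, model
id, and auth header below are assumptions, not defined in this module):

    curl -N http://localhost:8000/v1/chat/completions \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer $GATEWAY_API_KEY" \
        -d '{"model": "<configured-model-id>", "messages": [{"role": "user", "content": "Hello"}], "stream": true}'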
"""
import asyncio
import logging
import os
import uuid
from typing import AsyncIterator, Union

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse

from oci.exceptions import ServiceError

from api.auth import get_api_key
from api.schemas import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ErrorDetail
from api.error_handler import OCIErrorHandler
from api.exceptions import ModelNotFoundException, InvalidModelTypeException
from api.adapters.request_adapter import adapt_chat_messages, extract_chat_params
from api.adapters.response_adapter import (
    adapt_chat_response,
    adapt_streaming_chunk,
    adapt_streaming_done,
)
from core.config import get_settings
from core.client_manager import get_client_manager
from core.models import get_model_config

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/chat",
    tags=["chat"],
    dependencies=[Depends(get_api_key)]
)


def extract_delta_from_chunk(chunk) -> str:
    """
    Extract delta text content from OCI streaming chunk.

    Args:
        chunk: OCI streaming response chunk (can be SSE Event, parsed object, etc.)

    Returns:
        Delta text content or empty string
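
    Example (illustrative only; the exact payload shape depends on the OCI
    model family and SDK version): an SSE event whose data field contains
    {"message": {"content": [{"type": "TEXT", "text": "Hi"}]}} yields "Hi",
    while a Cohere-style chunk exposing chat_response.text yields that text.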
"""
|
||
try:
|
||
# Handle SSE Event objects (from SSEClient)
|
||
if hasattr(chunk, 'data'):
|
||
import json
|
||
# Parse JSON data from SSE event
|
||
try:
|
||
parsed = json.loads(chunk.data)
|
||
|
||
# Recursively extract from parsed object
|
||
if isinstance(parsed, dict):
|
||
# OCI Streaming format: message.content[].text
|
||
if 'message' in parsed and 'content' in parsed['message']:
|
||
content_array = parsed['message']['content']
|
||
if isinstance(content_array, list) and len(content_array) > 0:
|
||
# Extract text from all TEXT type content items
|
||
text_parts = []
|
||
for item in content_array:
|
||
if isinstance(item, dict) and item.get('type') == 'TEXT' and 'text' in item:
|
||
text_parts.append(item['text'])
|
||
if text_parts:
|
||
return ''.join(text_parts)
|
||
|
||
# Try to get text from various possible locations
|
||
if 'text' in parsed:
|
||
return parsed['text']
|
||
if 'chatResponse' in parsed and 'text' in parsed['chatResponse']:
|
||
return parsed['chatResponse']['text']
|
||
if 'choices' in parsed and len(parsed['choices']) > 0:
|
||
choice = parsed['choices'][0]
|
||
if 'delta' in choice and 'content' in choice['delta']:
|
||
return choice['delta']['content']
|
||
|
||
except (json.JSONDecodeError, KeyError, TypeError):
|
||
# Return raw data if not JSON
|
||
return str(chunk.data) if chunk.data else ""
|
||
|
||
# Try to extract from chat_response.text (Cohere format)
|
||
if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'text'):
|
||
return chunk.chat_response.text
|
||
|
||
# Try to extract from choices[0].delta.content (Generic format)
|
||
if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'choices'):
|
||
if len(chunk.chat_response.choices) > 0:
|
||
choice = chunk.chat_response.choices[0]
|
||
if hasattr(choice, 'delta') and hasattr(choice.delta, 'content'):
|
||
content = choice.delta.content
|
||
if isinstance(content, str):
|
||
return content
|
||
elif isinstance(content, list):
|
||
# Handle TextContent list
|
||
text_parts = []
|
||
for item in content:
|
||
if isinstance(item, dict) and 'text' in item:
|
||
text_parts.append(item['text'])
|
||
elif hasattr(item, 'text'):
|
||
text_parts.append(item.text)
|
||
return "".join(text_parts)
|
||
|
||
# Try direct text attribute
|
||
if hasattr(chunk, 'text'):
|
||
return chunk.text
|
||
|
||
except Exception as e:
|
||
logger.warning(f"Failed to extract delta from chunk: {e}")
|
||
|
||
return ""
|
||
|
||
|
||
def extract_content_from_response(chat_response) -> str:
    """
    Extract full content from a non-streaming OCI response.

    Args:
        chat_response: OCI chat response object

    Returns:
        Full text content
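
    Example (illustrative only; actual objects come from the OCI SDK): a
    Cohere-style response with chat_response.text == "Hello" returns "Hello";
    a generic-style response whose choices[0].message.content is
    [{"type": "TEXT", "text": "Hello"}] also returns "Hello".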
"""
|
||
if hasattr(chat_response, 'text'):
|
||
raw_text = chat_response.text
|
||
# Try to parse as JSON if it's a string (OCI format)
|
||
try:
|
||
import json
|
||
parsed = json.loads(raw_text)
|
||
if isinstance(parsed, dict) and 'text' in parsed:
|
||
return parsed['text']
|
||
return raw_text
|
||
except (json.JSONDecodeError, ValueError, TypeError):
|
||
return raw_text
|
||
|
||
elif hasattr(chat_response, 'choices') and len(chat_response.choices) > 0:
|
||
choice = chat_response.choices[0]
|
||
if hasattr(choice, 'message'):
|
||
raw_content = choice.message.content
|
||
# Handle list format
|
||
if isinstance(raw_content, list):
|
||
text_parts = []
|
||
for item in raw_content:
|
||
if isinstance(item, dict):
|
||
text_parts.append(item.get('text', ''))
|
||
elif hasattr(item, 'text'):
|
||
text_parts.append(item.text)
|
||
else:
|
||
text_parts.append(str(item))
|
||
return "".join(text_parts)
|
||
elif isinstance(raw_content, str):
|
||
try:
|
||
import json
|
||
parsed = json.loads(raw_content)
|
||
if isinstance(parsed, dict) and 'text' in parsed:
|
||
return parsed['text']
|
||
return raw_content
|
||
except (json.JSONDecodeError, ValueError):
|
||
return raw_content
|
||
else:
|
||
return str(raw_content)
|
||
return str(choice)
|
||
|
||
return str(chat_response)
|
||
|
||
|
||
@router.post("/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """
    Create a chat completion using OCI Generative AI.

    Args:
        request: Chat completion request

    Returns:
        Chat completion response
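
    The result is an OpenAI-style chat completion object, or a text/event-stream
    of SSE chunks when stream=true is requested and streaming is enabled
    globally (ENABLE_STREAMING).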
"""
|
||
logger.info(f"Chat completion request for model: {request.model}")
|
||
|
||
settings = get_settings()
|
||
|
||
# Validate model exists
|
||
model_config = get_model_config(request.model)
|
||
if not model_config:
|
||
raise ModelNotFoundException(request.model)
|
||
|
||
# Validate model type is chat (ondemand or dedicated)
|
||
if model_config.type not in ("ondemand", "dedicated"):
|
||
raise InvalidModelTypeException(
|
||
model_id=request.model,
|
||
expected_type="chat",
|
||
actual_type=model_config.type
|
||
)
|
||
|
||
# Note: Multimodal capability validation is handled by the model itself
|
||
# If a model doesn't support certain content types, it will raise an error
|
||
# For example, Cohere models will raise ValueError for non-text content
|
||
|
||
# Get OCI client from manager (轮询负载均衡)
|
||
client_manager = get_client_manager()
|
||
oci_client = client_manager.get_client()
|
||
|
||
# Adapt messages
|
||
messages = adapt_chat_messages([msg.dict() for msg in request.messages])
|
||
|
||
# Extract parameters
|
||
params = extract_chat_params(request)
|
||
|
||
# Check global streaming setting
|
||
# If streaming is globally disabled, override client request
|
||
enable_stream = request.stream and settings.enable_streaming
|
||
|
||
if not settings.enable_streaming and request.stream:
|
||
logger.info("Streaming requested but globally disabled via ENABLE_STREAMING=false")
|
||
|
||
# Handle streaming
|
||
if enable_stream:
|
||
request_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
|
||
|
||
async def generate_stream() -> AsyncIterator[str]:
|
||
"""Generate streaming response with true non-blocking streaming."""
|
||
try:
|
||
# Run OCI SDK call in executor to prevent blocking
|
||
# This is critical for achieving true streaming (msToFirstChunk < 1s)
|
||
loop = asyncio.get_event_loop()
|
||
response = await loop.run_in_executor(
|
||
None,
|
||
lambda: oci_client.chat(
|
||
model_id=request.model,
|
||
messages=messages,
|
||
temperature=params["temperature"],
|
||
max_tokens=params["max_tokens"],
|
||
top_p=params["top_p"],
|
||
stream=True, # Enable real streaming
|
||
tools=params.get("tools"),
|
||
)
|
||
)
|
||
|
||
# Process real streaming response
|
||
accumulated_usage = None
|
||
|
||
# Check if response.data is an SSE stream (iterable)
|
||
# When stream=True, OCI SDK returns response.data as SSEClient
|
||
try:
|
||
# Try to iterate over the stream
|
||
stream_data = response.data if hasattr(response, 'data') else response
|
||
|
||
# Check if it's SSEClient or any iterable type
|
||
stream_type_name = type(stream_data).__name__
|
||
is_sse_client = 'SSEClient' in stream_type_name
|
||
is_iterable = hasattr(stream_data, '__iter__') or hasattr(stream_data, '__next__')
|
||
|
||
# SSEClient is always treated as streaming, even if hasattr check fails
|
||
if is_sse_client or is_iterable:
|
||
# Real streaming: iterate over chunks
|
||
# SSEClient requires calling .events() method to iterate
|
||
if is_sse_client and hasattr(stream_data, 'events'):
|
||
iterator = stream_data.events()
|
||
else:
|
||
iterator = stream_data
|
||
|
||
# Send first chunk with role and empty content (OpenAI format)
|
||
yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True)
|
||
|
||
# Use queue for thread-safe chunk forwarding
|
||
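                        # Why a thread + queue: the SSE iterator is a blocking, synchronous
                        # iterable, so iterating it directly here would stall the event loop
                        # between chunks. A daemon thread drains it into a queue that this
                        # coroutine polls via run_in_executor, keeping the stream non-blocking.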
                        import queue
                        import threading
                        chunk_queue = queue.Queue()

                        def read_chunks():
                            """Read chunks in a background thread and put them in the queue."""
                            try:
                                for chunk in iterator:
                                    chunk_queue.put(("chunk", chunk))
                                chunk_queue.put(("done", None))
                            except Exception as e:
                                chunk_queue.put(("error", e))

                        # Start the background thread to read chunks
                        reader_thread = threading.Thread(target=read_chunks, daemon=True)
                        reader_thread.start()

                        # Yield chunks as they arrive from the queue
                        while True:
                            # Non-blocking queue get with a short timeout
                            try:
                                msg_type, data = await loop.run_in_executor(
                                    None,
                                    lambda: chunk_queue.get(timeout=0.01)
                                )
                            except queue.Empty:
                                # Allow other async tasks to run
                                await asyncio.sleep(0)
                                continue

                            if msg_type == "done":
                                break
                            elif msg_type == "error":
                                raise data
                            elif msg_type == "chunk":
                                chunk = data
                                # Extract delta content from the chunk
                                delta_text = extract_delta_from_chunk(chunk)

                                if delta_text:
                                    yield adapt_streaming_chunk(delta_text, request.model, request_id, 0, is_first=False)

                                # Try to extract usage from the chunk (typically in the final chunk)
                                # Handle both SSE Event format and object format
                                if hasattr(chunk, 'data'):
                                    # SSE Event - parse JSON to extract usage
                                    try:
                                        import json
                                        parsed = json.loads(chunk.data)
                                        if isinstance(parsed, dict) and 'usage' in parsed:
                                            usage_data = parsed['usage']
                                            accumulated_usage = {
                                                "prompt_tokens": usage_data.get('promptTokens', 0) or 0,
                                                "completion_tokens": usage_data.get('completionTokens', 0) or 0,
                                                "total_tokens": usage_data.get('totalTokens', 0) or 0
                                            }
                                    except (json.JSONDecodeError, KeyError, TypeError):
                                        pass
                                elif hasattr(chunk, 'usage') and chunk.usage:
                                    # Object format
                                    accumulated_usage = {
                                        "prompt_tokens": getattr(chunk.usage, 'prompt_tokens', 0) or 0,
                                        "completion_tokens": getattr(chunk.usage, 'completion_tokens', 0) or 0,
                                        "total_tokens": getattr(chunk.usage, 'total_tokens', 0) or 0
                                    }

                        # Send the done message with usage
                        yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage)

                    else:
                        # Fallback: non-streaming response, simulate streaming
                        logger.warning(f"OCI SDK returned non-iterable response (type: {type(stream_data).__name__}), falling back to simulated streaming")

                        # Extract text from the non-streaming response
                        chat_response = stream_data.chat_response if hasattr(stream_data, 'chat_response') else stream_data
                        content = extract_content_from_response(chat_response)

                        # Extract usage information
                        if hasattr(stream_data, 'usage'):
                            oci_usage = stream_data.usage
                            accumulated_usage = {
                                "prompt_tokens": getattr(oci_usage, 'prompt_tokens', 0) or 0,
                                "completion_tokens": getattr(oci_usage, 'completion_tokens', 0) or 0,
                                "total_tokens": getattr(oci_usage, 'total_tokens', 0) or 0
                            }

                        # Simulate streaming by chunking the full content
                        # First send an empty chunk with the role (OpenAI format)
                        yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True)

                        chunk_size = settings.stream_chunk_size
                        for i in range(0, len(content), chunk_size):
                            chunk = content[i:i + chunk_size]
                            yield adapt_streaming_chunk(chunk, request.model, request_id, 0, is_first=False)

                        yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage)

                except TypeError as te:
                    # Handle the case where the response is not iterable at all
                    logger.error(f"Response is not iterable: {te}", exc_info=True)
                    raise

            except Exception as e:
                logger.error(f"Error in streaming: {str(e)}", exc_info=True)
                import json

                # Handle by exception type and filter out sensitive information
                if isinstance(e, ServiceError):
                    error_response = OCIErrorHandler.sanitize_oci_error(e)
                else:
                    # Generic errors may also contain sensitive information;
                    # return only a generic, non-revealing message to the client
                    filtered_msg = OCIErrorHandler.filter_sensitive_info(str(e))
                    error_response = ErrorResponse(
                        error=ErrorDetail(
                            message="An error occurred during streaming",
                            type="server_error",
                            code="streaming_error"
                        )
                    )

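                # Deliver the error as a final SSE data frame so streaming clients
                # receive a structured, OpenAI-style error payload instead of a
                # dropped connection.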
yield f"data: {json.dumps(error_response.dict(), ensure_ascii=False)}\n\n"
|
||
|
||
return StreamingResponse(
|
||
generate_stream(),
|
||
media_type="text/event-stream"
|
||
)
|
||
|
||
# Non-streaming response
|
||
try:
|
||
response = oci_client.chat(
|
||
model_id=request.model,
|
||
messages=messages,
|
||
temperature=params["temperature"],
|
||
max_tokens=params["max_tokens"],
|
||
top_p=params["top_p"],
|
||
stream=False,
|
||
tools=params.get("tools"),
|
||
)
|
||
|
||
# Adapt response to OpenAI format
|
||
openai_response = adapt_chat_response(response, request.model)
|
||
|
||
if settings.log_responses:
|
||
logger.debug(f"Response: {openai_response}")
|
||
|
||
return openai_response
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
|
||
# 直接 raise,让全局异常处理器统一过滤敏感信息
|
||
raise
|