oracle-openai/src/api/routers/chat.py
"""
Chat completions API router - OpenAI compatible chat endpoint.
"""
import asyncio
import json
import logging
import queue
import threading
import uuid
from typing import AsyncIterator
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from oci.exceptions import ServiceError
from api.auth import get_api_key
from api.schemas import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ErrorDetail
from api.error_handler import OCIErrorHandler
from api.exceptions import ModelNotFoundException, InvalidModelTypeException
from api.adapters.request_adapter import adapt_chat_messages, extract_chat_params
from api.adapters.response_adapter import (
adapt_chat_response,
adapt_streaming_chunk,
adapt_streaming_done,
)
from core.config import get_settings
from core.client_manager import get_client_manager
from core.models import get_model_config
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/chat",
tags=["chat"],
dependencies=[Depends(get_api_key)]
)
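# Example call (illustrative only; the full path depends on where this router
# is mounted, e.g. under "/v1", and on the auth header expected by get_api_key):
#
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'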
def extract_delta_from_chunk(chunk) -> str:
"""
Extract delta text content from OCI streaming chunk.
Args:
chunk: OCI streaming response chunk (can be SSE Event, parsed object, etc.)
Returns:
Delta text content or empty string
"""
try:
# Handle SSE Event objects (from SSEClient)
if hasattr(chunk, 'data'):
# Parse JSON data from SSE event
try:
parsed = json.loads(chunk.data)
# Recursively extract from parsed object
if isinstance(parsed, dict):
# OCI Streaming format: message.content[].text
if 'message' in parsed and 'content' in parsed['message']:
content_array = parsed['message']['content']
if isinstance(content_array, list) and len(content_array) > 0:
# Extract text from all TEXT type content items
text_parts = []
for item in content_array:
if isinstance(item, dict) and item.get('type') == 'TEXT' and 'text' in item:
text_parts.append(item['text'])
if text_parts:
return ''.join(text_parts)
# Try to get text from various possible locations
if 'text' in parsed:
return parsed['text']
if 'chatResponse' in parsed and 'text' in parsed['chatResponse']:
return parsed['chatResponse']['text']
if 'choices' in parsed and len(parsed['choices']) > 0:
choice = parsed['choices'][0]
if 'delta' in choice and 'content' in choice['delta']:
return choice['delta']['content']
except (json.JSONDecodeError, KeyError, TypeError):
# Return raw data if not JSON
return str(chunk.data) if chunk.data else ""
# Try to extract from chat_response.text (Cohere format)
if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'text'):
return chunk.chat_response.text
# Try to extract from choices[0].delta.content (Generic format)
if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'choices'):
if len(chunk.chat_response.choices) > 0:
choice = chunk.chat_response.choices[0]
if hasattr(choice, 'delta') and hasattr(choice.delta, 'content'):
content = choice.delta.content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Handle TextContent list
text_parts = []
for item in content:
if isinstance(item, dict) and 'text' in item:
text_parts.append(item['text'])
elif hasattr(item, 'text'):
text_parts.append(item.text)
return "".join(text_parts)
# Try direct text attribute
if hasattr(chunk, 'text'):
return chunk.text
except Exception as e:
logger.warning(f"Failed to extract delta from chunk: {e}")
return ""
def extract_content_from_response(chat_response) -> str:
"""
Extract full content from non-streaming OCI response.
Args:
chat_response: OCI chat response object
Returns:
Full text content
"""
if hasattr(chat_response, 'text'):
raw_text = chat_response.text
# Try to parse as JSON if it's a string (OCI format)
try:
parsed = json.loads(raw_text)
if isinstance(parsed, dict) and 'text' in parsed:
return parsed['text']
return raw_text
except (json.JSONDecodeError, ValueError, TypeError):
return raw_text
elif hasattr(chat_response, 'choices') and len(chat_response.choices) > 0:
choice = chat_response.choices[0]
if hasattr(choice, 'message'):
raw_content = choice.message.content
# Handle list format
if isinstance(raw_content, list):
text_parts = []
for item in raw_content:
if isinstance(item, dict):
text_parts.append(item.get('text', ''))
elif hasattr(item, 'text'):
text_parts.append(item.text)
else:
text_parts.append(str(item))
return "".join(text_parts)
elif isinstance(raw_content, str):
try:
parsed = json.loads(raw_content)
if isinstance(parsed, dict) and 'text' in parsed:
return parsed['text']
return raw_content
except (json.JSONDecodeError, ValueError):
return raw_content
else:
return str(raw_content)
return str(choice)
return str(chat_response)
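# Minimal usage sketch for extract_content_from_response (hypothetical
# stand-in objects, for illustration only):
#
#   class _FakeChoice:
#       message = type("Msg", (), {"content": [{"text": "Hello"}]})()
#   class _FakeResponse:
#       choices = [_FakeChoice()]
#   extract_content_from_response(_FakeResponse())  # -> "Hello"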
@router.post("/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
"""
Create a chat completion using OCI Generative AI.
Args:
request: Chat completion request
Returns:
Chat completion response
"""
logger.info(f"Chat completion request for model: {request.model}")
settings = get_settings()
# Validate model exists
model_config = get_model_config(request.model)
if not model_config:
raise ModelNotFoundException(request.model)
# Validate model type is chat (ondemand or dedicated)
if model_config.type not in ("ondemand", "dedicated"):
raise InvalidModelTypeException(
model_id=request.model,
expected_type="chat",
actual_type=model_config.type
)
# Note: Multimodal capability validation is handled by the model itself
# If a model doesn't support certain content types, it will raise an error
# For example, Cohere models will raise ValueError for non-text content
    # Get OCI client from manager (round-robin load balancing)
client_manager = get_client_manager()
oci_client = client_manager.get_client()
# Adapt messages
messages = adapt_chat_messages([msg.dict() for msg in request.messages])
# Extract parameters
params = extract_chat_params(request)
# Check global streaming setting
# If streaming is globally disabled, override client request
enable_stream = request.stream and settings.enable_streaming
if not settings.enable_streaming and request.stream:
logger.info("Streaming requested but globally disabled via ENABLE_STREAMING=false")
# Handle streaming
if enable_stream:
request_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
async def generate_stream() -> AsyncIterator[str]:
"""Generate streaming response with true non-blocking streaming."""
try:
# Run OCI SDK call in executor to prevent blocking
# This is critical for achieving true streaming (msToFirstChunk < 1s)
                loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None,
lambda: oci_client.chat(
model_id=request.model,
messages=messages,
temperature=params["temperature"],
max_tokens=params["max_tokens"],
top_p=params["top_p"],
stream=True, # Enable real streaming
tools=params.get("tools"),
)
)
# Process real streaming response
accumulated_usage = None
# Check if response.data is an SSE stream (iterable)
# When stream=True, OCI SDK returns response.data as SSEClient
try:
# Try to iterate over the stream
stream_data = response.data if hasattr(response, 'data') else response
# Check if it's SSEClient or any iterable type
stream_type_name = type(stream_data).__name__
is_sse_client = 'SSEClient' in stream_type_name
is_iterable = hasattr(stream_data, '__iter__') or hasattr(stream_data, '__next__')
# SSEClient is always treated as streaming, even if hasattr check fails
if is_sse_client or is_iterable:
# Real streaming: iterate over chunks
# SSEClient requires calling .events() method to iterate
if is_sse_client and hasattr(stream_data, 'events'):
iterator = stream_data.events()
else:
iterator = stream_data
# Send first chunk with role and empty content (OpenAI format)
yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True)
# Use queue for thread-safe chunk forwarding
chunk_queue = queue.Queue()
def read_chunks():
"""Read chunks in background thread and put in queue."""
try:
for chunk in iterator:
chunk_queue.put(("chunk", chunk))
chunk_queue.put(("done", None))
except Exception as e:
chunk_queue.put(("error", e))
# Start background thread to read chunks
reader_thread = threading.Thread(target=read_chunks, daemon=True)
reader_thread.start()
# Yield chunks as they arrive from queue
while True:
                            # Short blocking get (run in the executor) so the event loop stays responsive
try:
msg_type, data = await loop.run_in_executor(
None,
lambda: chunk_queue.get(timeout=0.01)
)
except queue.Empty:
# Allow other async tasks to run
await asyncio.sleep(0)
continue
if msg_type == "done":
break
elif msg_type == "error":
raise data
elif msg_type == "chunk":
chunk = data
# Extract delta content from chunk
delta_text = extract_delta_from_chunk(chunk)
if delta_text:
yield adapt_streaming_chunk(delta_text, request.model, request_id, 0, is_first=False)
# Try to extract usage from chunk (typically in final chunk)
# Handle both SSE Event format and object format
if hasattr(chunk, 'data'):
# SSE Event - parse JSON to extract usage
try:
parsed = json.loads(chunk.data)
if isinstance(parsed, dict) and 'usage' in parsed:
usage_data = parsed['usage']
accumulated_usage = {
"prompt_tokens": usage_data.get('promptTokens', 0) or 0,
"completion_tokens": usage_data.get('completionTokens', 0) or 0,
"total_tokens": usage_data.get('totalTokens', 0) or 0
}
except (json.JSONDecodeError, KeyError, TypeError):
pass
elif hasattr(chunk, 'usage') and chunk.usage:
# Object format
accumulated_usage = {
"prompt_tokens": getattr(chunk.usage, 'prompt_tokens', 0) or 0,
"completion_tokens": getattr(chunk.usage, 'completion_tokens', 0) or 0,
"total_tokens": getattr(chunk.usage, 'total_tokens', 0) or 0
}
# Send done message with usage
yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage)
else:
# Fallback: non-streaming response, simulate streaming
logger.warning(f"OCI SDK returned non-iterable response (type: {type(stream_data).__name__}), falling back to simulated streaming")
# Extract text from non-streaming response
chat_response = stream_data.chat_response if hasattr(stream_data, 'chat_response') else stream_data
content = extract_content_from_response(chat_response)
# Extract usage information
if hasattr(stream_data, 'usage'):
oci_usage = stream_data.usage
accumulated_usage = {
"prompt_tokens": getattr(oci_usage, 'prompt_tokens', 0) or 0,
"completion_tokens": getattr(oci_usage, 'completion_tokens', 0) or 0,
"total_tokens": getattr(oci_usage, 'total_tokens', 0) or 0
}
# Simulate streaming by chunking
# First send empty chunk with role (OpenAI format)
yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True)
chunk_size = settings.stream_chunk_size
for i in range(0, len(content), chunk_size):
chunk = content[i:i + chunk_size]
yield adapt_streaming_chunk(chunk, request.model, request_id, 0, is_first=False)
yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage)
except TypeError as te:
# Handle case where response is not iterable at all
logger.error(f"Response is not iterable: {te}", exc_info=True)
raise
except Exception as e:
logger.error(f"Error in streaming: {str(e)}", exc_info=True)
                # Handle the exception by type and filter out sensitive information
                if isinstance(e, ServiceError):
                    error_response = OCIErrorHandler.sanitize_oci_error(e)
                else:
                    # Generic errors must also have potentially sensitive details filtered
                    filtered_msg = OCIErrorHandler.filter_sensitive_info(str(e))
                    error_response = ErrorResponse(
                        error=ErrorDetail(
                            message=filtered_msg or "An error occurred during streaming",
                            type="server_error",
                            code="streaming_error"
                        )
                    )
yield f"data: {json.dumps(error_response.dict(), ensure_ascii=False)}\n\n"
return StreamingResponse(
generate_stream(),
media_type="text/event-stream"
)
# Non-streaming response
try:
response = oci_client.chat(
model_id=request.model,
messages=messages,
temperature=params["temperature"],
max_tokens=params["max_tokens"],
top_p=params["top_p"],
stream=False,
tools=params.get("tools"),
)
# Adapt response to OpenAI format
openai_response = adapt_chat_response(response, request.model)
if settings.log_responses:
logger.debug(f"Response: {openai_response}")
return openai_response
except Exception as e:
logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        # Re-raise so the global exception handler can filter sensitive information uniformly
raise
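# Illustrative streaming wire format (a sketch assuming adapt_streaming_chunk
# and adapt_streaming_done emit OpenAI-compatible SSE frames; the exact fields
# live in api.adapters.response_adapter):
#
#   data: {"id": "chatcmpl-1a2b3c4d", "object": "chat.completion.chunk",
#          "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}}]}
#   data: {"id": "chatcmpl-1a2b3c4d", "object": "chat.completion.chunk",
#          "choices": [{"index": 0, "delta": {"content": "Hello"}}]}
#   data: [DONE]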