First commit
All checks were successful
Build and Push OCI GenAI Gateway Docker Image / docker-build-push (push) Successful in 32m3s
417
src/api/routers/chat.py
Normal file
@@ -0,0 +1,417 @@
"""
Chat completions API router - OpenAI compatible chat endpoint.
"""
import asyncio
import logging
import os
import uuid
from typing import AsyncIterator, Union
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse

from oci.exceptions import ServiceError

from api.auth import get_api_key
from api.schemas import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ErrorDetail
from api.error_handler import OCIErrorHandler
from api.exceptions import ModelNotFoundException, InvalidModelTypeException
from api.adapters.request_adapter import adapt_chat_messages, extract_chat_params
from api.adapters.response_adapter import (
    adapt_chat_response,
    adapt_streaming_chunk,
    adapt_streaming_done,
)
from core.config import get_settings
from core.client_manager import get_client_manager
from core.models import get_model_config

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/chat",
    tags=["chat"],
    dependencies=[Depends(get_api_key)]
)


def extract_delta_from_chunk(chunk) -> str:
    """
    Extract delta text content from OCI streaming chunk.

    Args:
        chunk: OCI streaming response chunk (can be SSE Event, parsed object, etc.)

    Returns:
        Delta text content or empty string
    """
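    # Illustrative SSE payload shapes handled below (assumed examples inferred
    # from the parsing branches, not an exhaustive list of OCI wire formats):
    #   {"message": {"content": [{"type": "TEXT", "text": "Hi"}]}}   # OCI streaming format
    #   {"text": "Hi"}  or  {"chatResponse": {"text": "Hi"}}         # Cohere-style
    #   {"choices": [{"delta": {"content": "Hi"}}]}                  # generic/OpenAI-style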
    try:
        # Handle SSE Event objects (from SSEClient)
        if hasattr(chunk, 'data'):
            import json
            # Parse JSON data from SSE event
            try:
                parsed = json.loads(chunk.data)

                # Recursively extract from parsed object
                if isinstance(parsed, dict):
                    # OCI Streaming format: message.content[].text
                    if 'message' in parsed and 'content' in parsed['message']:
                        content_array = parsed['message']['content']
                        if isinstance(content_array, list) and len(content_array) > 0:
                            # Extract text from all TEXT type content items
                            text_parts = []
                            for item in content_array:
                                if isinstance(item, dict) and item.get('type') == 'TEXT' and 'text' in item:
                                    text_parts.append(item['text'])
                            if text_parts:
                                return ''.join(text_parts)

                    # Try to get text from various possible locations
                    if 'text' in parsed:
                        return parsed['text']
                    if 'chatResponse' in parsed and 'text' in parsed['chatResponse']:
                        return parsed['chatResponse']['text']
                    if 'choices' in parsed and len(parsed['choices']) > 0:
                        choice = parsed['choices'][0]
                        if 'delta' in choice and 'content' in choice['delta']:
                            return choice['delta']['content']

            except (json.JSONDecodeError, KeyError, TypeError):
                # Return raw data if not JSON
                return str(chunk.data) if chunk.data else ""

        # Try to extract from chat_response.text (Cohere format)
        if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'text'):
            return chunk.chat_response.text

        # Try to extract from choices[0].delta.content (Generic format)
        if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'choices'):
            if len(chunk.chat_response.choices) > 0:
                choice = chunk.chat_response.choices[0]
                if hasattr(choice, 'delta') and hasattr(choice.delta, 'content'):
                    content = choice.delta.content
                    if isinstance(content, str):
                        return content
                    elif isinstance(content, list):
                        # Handle TextContent list
                        text_parts = []
                        for item in content:
                            if isinstance(item, dict) and 'text' in item:
                                text_parts.append(item['text'])
                            elif hasattr(item, 'text'):
                                text_parts.append(item.text)
                        return "".join(text_parts)

        # Try direct text attribute
        if hasattr(chunk, 'text'):
            return chunk.text

    except Exception as e:
        logger.warning(f"Failed to extract delta from chunk: {e}")

    return ""


def extract_content_from_response(chat_response) -> str:
    """
    Extract full content from non-streaming OCI response.

    Args:
        chat_response: OCI chat response object

    Returns:
        Full text content
    """
    if hasattr(chat_response, 'text'):
        raw_text = chat_response.text
        # Try to parse as JSON if it's a string (OCI format)
        try:
            import json
            parsed = json.loads(raw_text)
            if isinstance(parsed, dict) and 'text' in parsed:
                return parsed['text']
            return raw_text
        except (json.JSONDecodeError, ValueError, TypeError):
            return raw_text

    elif hasattr(chat_response, 'choices') and len(chat_response.choices) > 0:
        choice = chat_response.choices[0]
        if hasattr(choice, 'message'):
            raw_content = choice.message.content
            # Handle list format
            if isinstance(raw_content, list):
                text_parts = []
                for item in raw_content:
                    if isinstance(item, dict):
                        text_parts.append(item.get('text', ''))
                    elif hasattr(item, 'text'):
                        text_parts.append(item.text)
                    else:
                        text_parts.append(str(item))
                return "".join(text_parts)
            elif isinstance(raw_content, str):
                try:
                    import json
                    parsed = json.loads(raw_content)
                    if isinstance(parsed, dict) and 'text' in parsed:
                        return parsed['text']
                    return raw_content
                except (json.JSONDecodeError, ValueError):
                    return raw_content
            else:
                return str(raw_content)
        return str(choice)

    return str(chat_response)


@router.post("/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """
    Create a chat completion using OCI Generative AI.

    Args:
        request: Chat completion request

    Returns:
        Chat completion response
    """
    logger.info(f"Chat completion request for model: {request.model}")

    settings = get_settings()

    # Validate model exists
    model_config = get_model_config(request.model)
    if not model_config:
        raise ModelNotFoundException(request.model)

    # Validate model type is chat (ondemand or dedicated)
    if model_config.type not in ("ondemand", "dedicated"):
        raise InvalidModelTypeException(
            model_id=request.model,
            expected_type="chat",
            actual_type=model_config.type
        )

    # Note: Multimodal capability validation is handled by the model itself
    # If a model doesn't support certain content types, it will raise an error
    # For example, Cohere models will raise ValueError for non-text content

    # Get OCI client from manager (round-robin load balancing)
    client_manager = get_client_manager()
    oci_client = client_manager.get_client()

    # Adapt messages
    messages = adapt_chat_messages([msg.dict() for msg in request.messages])

    # Extract parameters
    params = extract_chat_params(request)

    # Check global streaming setting
    # If streaming is globally disabled, override client request
    enable_stream = request.stream and settings.enable_streaming

    if not settings.enable_streaming and request.stream:
        logger.info("Streaming requested but globally disabled via ENABLE_STREAMING=false")

    # Handle streaming
    if enable_stream:
        request_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"

        async def generate_stream() -> AsyncIterator[str]:
            """Generate streaming response with true non-blocking streaming."""
            try:
                # Run OCI SDK call in executor to prevent blocking
                # This is critical for achieving true streaming (msToFirstChunk < 1s)
                loop = asyncio.get_event_loop()
                response = await loop.run_in_executor(
                    None,
                    lambda: oci_client.chat(
                        model_id=request.model,
                        messages=messages,
                        temperature=params["temperature"],
                        max_tokens=params["max_tokens"],
                        top_p=params["top_p"],
                        stream=True,  # Enable real streaming
                        tools=params.get("tools"),
                    )
                )

                # Process real streaming response
                accumulated_usage = None

                # Check if response.data is an SSE stream (iterable)
                # When stream=True, OCI SDK returns response.data as SSEClient
                try:
                    # Try to iterate over the stream
                    stream_data = response.data if hasattr(response, 'data') else response

                    # Check if it's SSEClient or any iterable type
                    stream_type_name = type(stream_data).__name__
                    is_sse_client = 'SSEClient' in stream_type_name
                    is_iterable = hasattr(stream_data, '__iter__') or hasattr(stream_data, '__next__')

                    # SSEClient is always treated as streaming, even if hasattr check fails
                    if is_sse_client or is_iterable:
                        # Real streaming: iterate over chunks
                        # SSEClient requires calling .events() method to iterate
                        if is_sse_client and hasattr(stream_data, 'events'):
                            iterator = stream_data.events()
                        else:
                            iterator = stream_data

                        # Send first chunk with role and empty content (OpenAI format)
                        yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True)

                        # Use queue for thread-safe chunk forwarding
                        import queue
                        import threading
                        chunk_queue = queue.Queue()

                        def read_chunks():
                            """Read chunks in background thread and put in queue."""
                            try:
                                for chunk in iterator:
                                    chunk_queue.put(("chunk", chunk))
                                chunk_queue.put(("done", None))
                            except Exception as e:
                                chunk_queue.put(("error", e))

                        # Start background thread to read chunks
                        reader_thread = threading.Thread(target=read_chunks, daemon=True)
                        reader_thread.start()

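                        # Design note: SSEClient iteration is blocking, so a daemon thread
                        # feeds chunks into a thread-safe queue while the loop below polls it
                        # via the default executor with a ~10 ms timeout; this keeps the event
                        # loop free to serve other requests between chunks.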
                        # Yield chunks as they arrive from queue
                        while True:
                            # Non-blocking queue get with timeout
                            try:
                                msg_type, data = await loop.run_in_executor(
                                    None,
                                    lambda: chunk_queue.get(timeout=0.01)
                                )
                            except queue.Empty:
                                # Allow other async tasks to run
                                await asyncio.sleep(0)
                                continue

                            if msg_type == "done":
                                break
                            elif msg_type == "error":
                                raise data
                            elif msg_type == "chunk":
                                chunk = data
                                # Extract delta content from chunk
                                delta_text = extract_delta_from_chunk(chunk)

                                if delta_text:
                                    yield adapt_streaming_chunk(delta_text, request.model, request_id, 0, is_first=False)

                                # Try to extract usage from chunk (typically in final chunk)
                                # Handle both SSE Event format and object format
                                if hasattr(chunk, 'data'):
                                    # SSE Event - parse JSON to extract usage
                                    try:
                                        import json
                                        parsed = json.loads(chunk.data)
                                        if isinstance(parsed, dict) and 'usage' in parsed:
                                            usage_data = parsed['usage']
                                            accumulated_usage = {
                                                "prompt_tokens": usage_data.get('promptTokens', 0) or 0,
                                                "completion_tokens": usage_data.get('completionTokens', 0) or 0,
                                                "total_tokens": usage_data.get('totalTokens', 0) or 0
                                            }
                                    except (json.JSONDecodeError, KeyError, TypeError):
                                        pass
                                elif hasattr(chunk, 'usage') and chunk.usage:
                                    # Object format
                                    accumulated_usage = {
                                        "prompt_tokens": getattr(chunk.usage, 'prompt_tokens', 0) or 0,
                                        "completion_tokens": getattr(chunk.usage, 'completion_tokens', 0) or 0,
                                        "total_tokens": getattr(chunk.usage, 'total_tokens', 0) or 0
                                    }

                        # Send done message with usage
                        yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage)

                    else:
                        # Fallback: non-streaming response, simulate streaming
                        logger.warning(f"OCI SDK returned non-iterable response (type: {type(stream_data).__name__}), falling back to simulated streaming")

                        # Extract text from non-streaming response
                        chat_response = stream_data.chat_response if hasattr(stream_data, 'chat_response') else stream_data
                        content = extract_content_from_response(chat_response)

                        # Extract usage information
                        if hasattr(stream_data, 'usage'):
                            oci_usage = stream_data.usage
                            accumulated_usage = {
                                "prompt_tokens": getattr(oci_usage, 'prompt_tokens', 0) or 0,
                                "completion_tokens": getattr(oci_usage, 'completion_tokens', 0) or 0,
                                "total_tokens": getattr(oci_usage, 'total_tokens', 0) or 0
                            }

                        # Simulate streaming by chunking
                        # First send empty chunk with role (OpenAI format)
                        yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True)

                        chunk_size = settings.stream_chunk_size
                        for i in range(0, len(content), chunk_size):
                            chunk = content[i:i + chunk_size]
                            yield adapt_streaming_chunk(chunk, request.model, request_id, 0, is_first=False)

                        yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage)

                except TypeError as te:
                    # Handle case where response is not iterable at all
                    logger.error(f"Response is not iterable: {te}", exc_info=True)
                    raise

            except Exception as e:
                logger.error(f"Error in streaming: {str(e)}", exc_info=True)
                import json

                # Handle the error by exception type and filter out sensitive information
                if isinstance(e, ServiceError):
                    error_response = OCIErrorHandler.sanitize_oci_error(e)
                else:
                    # Generic errors may also contain sensitive information and must be filtered
                    filtered_msg = OCIErrorHandler.filter_sensitive_info(str(e))
                    error_response = ErrorResponse(
                        error=ErrorDetail(
                            message="An error occurred during streaming",
                            type="server_error",
                            code="streaming_error"
                        )
                    )

                yield f"data: {json.dumps(error_response.dict(), ensure_ascii=False)}\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream"
        )

    # Non-streaming response
    try:
        response = oci_client.chat(
            model_id=request.model,
            messages=messages,
            temperature=params["temperature"],
            max_tokens=params["max_tokens"],
            top_p=params["top_p"],
            stream=False,
            tools=params.get("tools"),
        )

        # Adapt response to OpenAI format
        openai_response = adapt_chat_response(response, request.model)

        if settings.log_responses:
            logger.debug(f"Response: {openai_response}")

        return openai_response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        # Re-raise and let the global exception handler filter sensitive information
        raise
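For reference, a minimal client call against this endpoint might look like the sketch below. It assumes the gateway mounts this router under a /v1 prefix on localhost:8000, that api.auth.get_api_key accepts a Bearer token, and that the model id is one registered in core.models; none of these are defined in this file, so adjust the base URL, path, auth header, and model to match the actual deployment.

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed host and path prefix
    headers={"Authorization": "Bearer sk-your-gateway-key"},  # assumed auth scheme for get_api_key
    json={
        "model": "cohere.command-r-plus",  # assumed example model id
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])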