oracle-openai/src/main.py

"""
Main FastAPI application for OCI Generative AI to OpenAI API Gateway.
"""
import logging
import sys
import os
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from oci.exceptions import ServiceError
from core.config import get_settings
from core.models import update_models_from_oci
from api.routers import models, chat, embeddings
from api.schemas import ErrorResponse, ErrorDetail
from api.error_handler import OCIErrorHandler
from api.exceptions import ModelNotFoundException, InvalidModelTypeException
from api.middleware import setup_logging_middleware

# Configure logging
def setup_logging():
    """Set up logging configuration."""
    settings = get_settings()

    # Create handlers list
    handlers = [
        logging.StreamHandler(sys.stdout)
    ]

    # Add file handler if log_file is configured
    if settings.log_file:
        log_dir = os.path.dirname(settings.log_file)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)

        file_handler = RotatingFileHandler(
            settings.log_file,
            maxBytes=settings.log_file_max_size * 1024 * 1024,  # Convert MB to bytes
            backupCount=settings.log_file_backup_count,
            encoding='utf-8'
        )
        handlers.append(file_handler)

    logging.basicConfig(
        level=getattr(logging, settings.log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=handlers
    )
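
# Illustrative example of the resulting output (actual values depend on the
# configured settings; the paths below are assumptions, not taken from this repo):
#   2025-12-09 17:46:07,123 - main - INFO - Starting OCI GenAI to OpenAI API Gateway
# When settings.log_file is set (e.g. "logs/app.log"), RotatingFileHandler rotates
# the file to app.log.1, app.log.2, ... keeping at most log_file_backup_count
# backups once the size limit is exceeded.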
setup_logging()
logger = logging.getLogger(__name__)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler."""
    logger.info("=" * 60)
    logger.info("Starting OCI GenAI to OpenAI API Gateway")
    logger.info("=" * 60)

    settings = get_settings()
    logger.info(f"API Version: {settings.api_version}")
    logger.info(f"API Prefix: {settings.api_prefix}")
    logger.info(f"Debug Mode: {settings.debug}")
    logger.info(f"OCI Config: {settings.oci_config_file}")

    profiles = settings.get_profiles()
    logger.info(f"OCI Profiles: {', '.join(profiles)}")

    try:
        # Fetch models from OCI (fails fast if unable to fetch).
        # Use the first profile for model discovery.
        update_models_from_oci(
            config_path=settings.oci_config_file,
            profile=profiles[0] if profiles else "DEFAULT"
        )
        logger.info("=" * 60)
        logger.info("✅ Startup completed successfully")
        logger.info(f"Server listening on {settings.api_host}:{settings.api_port}")
        logger.info("=" * 60)
    except RuntimeError as e:
        logger.error("=" * 60)
        logger.error("❌ STARTUP FAILED")
        logger.error("=" * 60)
        logger.error(f"Reason: {str(e)}")
        logger.error("")
        logger.error("The service cannot start without available models from OCI.")
        logger.error("Please review the troubleshooting steps above and fix the issue.")
        logger.error("=" * 60)
        raise
    except Exception as e:
        logger.error("=" * 60)
        logger.error("❌ UNEXPECTED STARTUP ERROR")
        logger.error("=" * 60)
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Error message: {str(e)}")
        logger.error("=" * 60)
        raise

    yield

    logger.info("=" * 60)
    logger.info("Shutting down OCI GenAI to OpenAI API Gateway")
    logger.info("=" * 60)
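
# With FastAPI's lifespan protocol, everything before the `yield` above runs once
# at startup and everything after it runs at shutdown; raising before the yield
# aborts startup, which is what makes the model discovery above fail fast.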

# Create FastAPI app
settings = get_settings()

app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    description="OpenAI-compatible REST API for Oracle Cloud Infrastructure Generative AI Service",
    lifespan=lifespan,
    docs_url="/docs" if settings.debug else None,
    redoc_url="/redoc" if settings.debug else None,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
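
# Note: CORS is wide open here (any origin, method, and header). Per the CORS
# spec, browsers do not accept credentialed responses when the allowed origin is
# the wildcard "*", so consider listing explicit origins in production.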
# Add logging middleware (for request/response debugging)
setup_logging_middleware(app)

# Exception handlers
@app.exception_handler(ModelNotFoundException)
async def model_not_found_handler(request: Request, exc: ModelNotFoundException):
    """Handle model not found exceptions with OpenAI-compatible format."""
    error = ErrorDetail(
        message=exc.detail,
        type=exc.error_type,
        code=exc.error_code
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(error=error).dict()
    )

@app.exception_handler(InvalidModelTypeException)
async def invalid_model_type_handler(request: Request, exc: InvalidModelTypeException):
    """Handle invalid model type exceptions with OpenAI-compatible format."""
    error = ErrorDetail(
        message=exc.detail,
        type=exc.error_type,
        code=exc.error_code
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(error=error).dict()
    )

@app.exception_handler(ServiceError)
async def oci_service_error_handler(request: Request, exc: ServiceError):
    """Handle OCI SDK ServiceError exceptions."""
    # Use OCIErrorHandler to sanitize the error and filter out sensitive information
    error_response = OCIErrorHandler.sanitize_oci_error(exc)

    # Determine the HTTP status code (use the status code returned by OCI)
    status_code = exc.status if 400 <= exc.status < 600 else 500

    return JSONResponse(
        status_code=status_code,
        content=error_response.dict()
    )

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Handle HTTP exceptions with sensitive information filtering."""
    # Filter any sensitive information that may be present in the HTTPException detail
    filtered_detail = OCIErrorHandler.filter_sensitive_info(str(exc.detail))

    error = ErrorDetail(
        message=filtered_detail,
        type="invalid_request_error",
        code=f"http_{exc.status_code}"
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(error=error).dict()
    )

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Handle request validation errors."""
    logger.error(f"Validation error: {exc}")
    error = ErrorDetail(
        message=str(exc),
        type="invalid_request_error",
        code="validation_error"
    )
    return JSONResponse(
        status_code=400,
        content=ErrorResponse(error=error).dict()
    )

@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Handle general exceptions without leaking sensitive information."""
    logger.error(f"Unexpected error: {exc}", exc_info=True)

    # The full error has already been logged above; do not expose the specific
    # error (or any potentially sensitive details it contains) to the client.
    error = ErrorDetail(
        message="An unexpected error occurred",
        type="server_error",
        code="internal_error"
    )
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(error=error).dict()
    )
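
# Illustrative shape of the error payload produced by the handlers above, assuming
# ErrorResponse/ErrorDetail serialize to their field names (not verified against
# api/schemas.py):
#   {"error": {"message": "...", "type": "invalid_request_error", "code": "http_404"}}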
# Include routers
app.include_router(models.router, prefix=settings.api_prefix)
app.include_router(chat.router, prefix=settings.api_prefix)
app.include_router(embeddings.router, prefix=settings.api_prefix)
@app.get("/")
async def root():
"""Root endpoint."""
return {
"name": settings.api_title,
"version": settings.api_version,
"description": "OpenAI-compatible REST API for OCI Generative AI",
"endpoints": {
"models": f"{settings.api_prefix}/models",
"chat": f"{settings.api_prefix}/chat/completions",
"embeddings": f"{settings.api_prefix}/embeddings"
}
}
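
# Illustrative response, assuming api_prefix is "/v1" (the real prefix comes from
# Settings):
#   GET /  ->  {"name": ..., "version": ..., "description": ...,
#               "endpoints": {"models": "/v1/models",
#                             "chat": "/v1/chat/completions",
#                             "embeddings": "/v1/embeddings"}}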
@app.get("/health")
async def health():
"""Health check endpoint."""
return {
"status": "healthy",
"service": "oci-genai-gateway"
}
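
# Example probe (host/port are assumptions; the real values come from Settings):
#   curl -sf http://localhost:8000/health
#   -> {"status": "healthy", "service": "oci-genai-gateway"}
# Suitable as a Docker HEALTHCHECK or Kubernetes liveness probe target.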

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host=settings.api_host,
        port=settings.api_port,
        reload=settings.debug,
        log_level=settings.log_level.lower()
    )
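
# Illustrative alternative to the __main__ block above (not taken from this repo's
# Dockerfile): serve the app with the uvicorn CLI, e.g.
#   uvicorn main:app --host 0.0.0.0 --port 8000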