From 42222744c76899e1a3a97e8b990d192d5b9f1782 Mon Sep 17 00:00:00 2001 From: Wang Defa Date: Tue, 9 Dec 2025 14:44:09 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AC=AC=E4=B8=80=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 65 +++++ .gitea/workflows/ci.yaml | 82 ++++++ .gitignore | 78 +++++ Dockerfile | 47 +++ LICENSE | 21 ++ README.md | 240 +++++++++++++++ docker-compose.yml | 35 +++ init.sh | 24 ++ requirements.txt | 12 + src/api/__init__.py | 1 + src/api/adapters/__init__.py | 1 + src/api/adapters/request_adapter.py | 104 +++++++ src/api/adapters/response_adapter.py | 307 ++++++++++++++++++++ src/api/auth.py | 57 ++++ src/api/error_handler.py | 161 +++++++++++ src/api/exceptions.py | 60 ++++ src/api/routers/__init__.py | 1 + src/api/routers/chat.py | 417 +++++++++++++++++++++++++++ src/api/routers/embeddings.py | 85 ++++++ src/api/routers/models.py | 78 +++++ src/api/schemas.py | 139 +++++++++ src/core/__init__.py | 1 + src/core/client_manager.py | 70 +++++ src/core/config.py | 100 +++++++ src/core/models.py | 260 +++++++++++++++++ src/core/oci_client.py | 361 +++++++++++++++++++++++ src/main.py | 274 ++++++++++++++++++ 27 files changed, 3081 insertions(+) create mode 100644 .env.example create mode 100644 .gitea/workflows/ci.yaml create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docker-compose.yml create mode 100644 init.sh create mode 100644 requirements.txt create mode 100644 src/api/__init__.py create mode 100644 src/api/adapters/__init__.py create mode 100644 src/api/adapters/request_adapter.py create mode 100644 src/api/adapters/response_adapter.py create mode 100644 src/api/auth.py create mode 100644 src/api/error_handler.py create mode 100644 src/api/exceptions.py create mode 100644 src/api/routers/__init__.py create mode 100644 src/api/routers/chat.py create mode 100644 src/api/routers/embeddings.py create mode 100644 src/api/routers/models.py create mode 100644 src/api/schemas.py create mode 100644 src/core/__init__.py create mode 100644 src/core/client_manager.py create mode 100644 src/core/config.py create mode 100644 src/core/models.py create mode 100644 src/core/oci_client.py create mode 100644 src/main.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..5a7a4a7 --- /dev/null +++ b/.env.example @@ -0,0 +1,65 @@ +# API Settings +API_TITLE=OCI GenAI to OpenAI API Gateway +API_VERSION=0.0.1 +API_PREFIX=/v1 +API_PORT=8000 +API_HOST=0.0.0.0 +DEBUG=false + +# Authentication +# Comma-separated list of API keys for authentication +# These are the keys clients will use in Authorization: Bearer +API_KEYS=["sk-oci-genai-default-key"] + +# ============================================ +# OCI Configuration +# ============================================ +# Path to OCI config file (usually ~/.oci/config) +OCI_CONFIG_FILE=~/.oci/config + +# Profile names in the OCI config file +# 支持单个或多个 profile,多个 profile 用逗号分隔 +# 多个 profile 时会自动使用轮询(round-robin)负载均衡 +# 示例: +# 单配置:OCI_CONFIG_PROFILE=DEFAULT +# 多配置:OCI_CONFIG_PROFILE=DEFAULT,CHICAGO,ASHBURN +# 注意:每个 profile 在 ~/.oci/config 中必须包含 region 和 tenancy (作为 compartment_id) +OCI_CONFIG_PROFILE=DEFAULT + +# Authentication type: api_key or instance_principal +OCI_AUTH_TYPE=api_key + +# Optional: Direct endpoint for dedicated models +# GENAI_ENDPOINT=https://your-dedicated-endpoint + +# Model Settings +# Note: Available models are dynamically loaded from 
OCI at startup
+# Use GET /v1/models to see all available models
+MAX_TOKENS=4096
+TEMPERATURE=0.7
+
+# Embedding Settings
+# Truncate strategy for embeddings: END or START
+EMBED_TRUNCATE=END
+
+# Streaming Settings
+# Global streaming on/off switch
+# Set to false to disable streaming for all requests (overrides client stream=true)
+ENABLE_STREAMING=true
+# Chunk size for simulated streaming (fallback mode only)
+# Only used when OCI returns non-streaming response
+STREAM_CHUNK_SIZE=1024
+
+# Logging
+# Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL
+LOG_LEVEL=INFO
+# Log incoming requests (may contain sensitive data)
+LOG_REQUESTS=false
+# Log responses (may contain sensitive data)
+LOG_RESPONSES=false
+# Log file path (optional, if not set logs only to console)
+LOG_FILE=./logs/app.log
+# Max log file size in MB (default: 10)
+LOG_FILE_MAX_SIZE=10
+# Number of backup log files to keep (default: 5)
+LOG_FILE_BACKUP_COUNT=5
diff --git a/.gitea/workflows/ci.yaml b/.gitea/workflows/ci.yaml
new file mode 100644
index 0000000..87d9679
--- /dev/null
+++ b/.gitea/workflows/ci.yaml
@@ -0,0 +1,82 @@
+# .gitea/workflows/ci.yaml
+name: Build and Push OCI GenAI Gateway Docker Image
+
+on:
+  push:
+    branches: [main, develop]
+    tags: ['*']
+
+env:
+  DOCKER_BUILDKIT: "1"
+  BUILDX_NO_DEFAULT_ATTESTATIONS: "1"
+
+jobs:
+  docker-build-push:
+    runs-on: ubuntu-latest-amd64
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Debug branch info
+        run: |
+          echo "📋 Branch Information:"
+          echo " github.ref: ${{ github.ref }}"
+          echo " github.ref_name: ${{ github.ref_name }}"
+          echo " github.event_name: ${{ github.event_name }}"
+
+      - name: Setup Docker Buildx and Login
+        run: |
+          # Set up QEMU for multi-arch builds
+          docker run --rm --privileged multiarch/qemu-user-static --reset -p yes 2>/dev/null || true
+
+          # Create the buildx builder
+          docker buildx create --use --name oci_genai_builder \
+            --driver docker-container \
+            --driver-opt network=host \
+            --driver-opt image=moby/buildkit:buildx-stable-1 \
+            --driver-opt env.BUILDKIT_STEP_LOG_MAX_SIZE=50000000 \
+            --driver-opt env.BUILDKIT_STEP_LOG_MAX_SPEED=10000000 \
+            || docker buildx use oci_genai_builder
+          docker buildx inspect --bootstrap
+
+          # Log in to the Docker registry
+          echo "${{ secrets.BUILD_TOKEN }}" | docker login ${{ gitea.server_url }} -u ${{ gitea.actor }} --password-stdin
+
+      - name: Determine Docker tag
+        id: tag
+        run: |
+          if [ "${{ github.ref_name }}" = "main" ]; then
+            TAG="latest"
+          elif [ "${{ github.ref_name }}" = "develop" ]; then
+            TAG="develop"
+          elif [[ "${{ github.ref }}" == refs/tags/* ]]; then
+            TAG="${{ github.ref_name }}"
+          else
+            TAG="${{ github.ref_name }}"
+          fi
+          echo "tag=${TAG}" >> $GITHUB_OUTPUT
+          echo "📦 Docker tag: ${TAG}"
+
+      - name: Build and push multi-arch Docker image
+        run: |
+          # Strip the https:// prefix from the URL
+          REGISTRY=$(echo "${{ gitea.server_url }}" | sed 's|https\?://||')
+          IMAGE_NAME="${REGISTRY}/${{ gitea.repository }}"
+          TAG="${{ steps.tag.outputs.tag }}"
+          FINAL_IMAGE_TAG="${IMAGE_NAME}:${TAG}"
+
+          echo "🏗️ Building and pushing image: ${FINAL_IMAGE_TAG}"
+          echo " Platforms: linux/amd64, linux/arm64"
+
+          # Set BuildKit tuning parameters
+          export BUILDKIT_PROGRESS=plain
+
+          docker buildx build --pull --push \
+            -t "${FINAL_IMAGE_TAG}" \
+            --platform linux/amd64,linux/arm64 \
+            --provenance=false \
+            --sbom=false \
+            -f Dockerfile .
+
+          echo ""
+          echo "✅ Build and push completed!"
+          echo "🐳 Image: ${FINAL_IMAGE_TAG}"
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9e2b58f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,78 @@
+# Claude
+.claude/
+CLAUDE.md
+.mcp.json
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+env/
+ENV/
+.venv
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+.DS_Store
+
+# Environment variables
+.env
+.env.local
+
+# OCI Config (contains sensitive keys)
+.oci/
+*.pem
+
+# Logs
+*.log
+logs/
+
+# Testing
+.pytest_cache/
+.coverage
+htmlcov/
+.tox/
+
+# Distribution
+*.tar.gz
+*.whl
+
+# Docker
+*.dockerfile.swp
+
+# Source repositories
+.source/
+
+# Temporary files
+tmp/
+temp/
+*.tmp
+example/
+
+# OS
+.DS_Store
+Thumbs.db
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..8d114fc
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,47 @@
+# Multi-stage build for OCI GenAI to OpenAI API Gateway
+FROM python:3.11-slim as builder
+
+# Set the working directory
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy the dependency file
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir --user -r requirements.txt
+
+# Final image
+FROM python:3.11-slim
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PATH=/root/.local/bin:$PATH
+
+# Set the working directory
+WORKDIR /app
+
+# Copy Python dependencies
+COPY --from=builder /root/.local /root/.local
+
+# Copy application code
+COPY src/ ./src/
+COPY .env.example .env
+
+# Create the log directory
+RUN mkdir -p /app/logs
+
+# Expose the port
+EXPOSE 8000
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
+    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health').read()"
+
+# Start the application
+CMD ["python", "-m", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..aee9d63
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 OCI GenAI Gateway
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
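For reference before the README below: its Quick Start shows a single-profile `~/.oci/config`, while `.env.example` also documents multi-profile load balancing (`OCI_CONFIG_PROFILE=DEFAULT,CHICAGO,ASHBURN`), where every profile must define `region` and `tenancy`. A minimal sketch of what such a multi-profile config could look like; all OCIDs, fingerprints, key paths, and regions here are placeholders, not values from this repository:

```ini
# ~/.oci/config (placeholder values only; one section per profile named in OCI_CONFIG_PROFILE)
[DEFAULT]
user=ocid1.user.oc1..<placeholder>
fingerprint=aa:bb:cc:dd:...
key_file=~/.oci/oci_api_key.pem
tenancy=ocid1.tenancy.oc1..<placeholder>
region=us-phoenix-1

[CHICAGO]
user=ocid1.user.oc1..<placeholder>
fingerprint=aa:bb:cc:dd:...
key_file=~/.oci/oci_api_key.pem
tenancy=ocid1.tenancy.oc1..<placeholder>
region=us-chicago-1

[ASHBURN]
user=ocid1.user.oc1..<placeholder>
fingerprint=aa:bb:cc:dd:...
key_file=~/.oci/oci_api_key.pem
tenancy=ocid1.tenancy.oc1..<placeholder>
region=us-ashburn-1
```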
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..68ddc15
--- /dev/null
+++ b/README.md
@@ -0,0 +1,240 @@
+# OCI GenAI to OpenAI API Gateway
+
+> 🚀 An OpenAI-compatible REST API for the Oracle Cloud Infrastructure Generative AI Service
+
+[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
+[![Python](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+[![FastAPI](https://img.shields.io/badge/FastAPI-0.115.0-green.svg)](https://fastapi.tiangolo.com/)
+
+## 📖 Overview
+
+This is a FastAPI service that acts as a translation layer between OCI Generative AI and the OpenAI API, allowing OpenAI SDK clients to interact with OCI GenAI models without any code changes.
+
+## ✨ Key Features
+
+- 🔄 **OpenAI API compatible**: Fully compatible with the OpenAI SDK; no changes to existing code
+- 🤖 **Dynamic model discovery**: Automatically fetches all available models from OCI at startup
+- 🌐 **Multi-region load balancing**: Round-robin load balancing across multiple OCI profiles
+- 🖼️ **Multimodal support**: Text, images (vision models), Base64-encoded content, and more
+- ⚡ **True streaming**: Genuine end-to-end streaming responses with TTFB < 200ms
+- 🔒 **Security**: Automatically filters sensitive information (OCIDs, request IDs, endpoint URLs)
+- 🎯 **Performance**: Client connection pooling for a significant performance boost
+
+## 🚀 Quick Start
+
+### Prerequisites
+
+- Python 3.8+
+- An OCI account and API key
+- Access to the OCI Generative AI service
+
+### Installation
+
+1. **Clone the repository**
+   ```bash
+   git clone
+   cd oracle-openai
+   ```
+
+2. **Install dependencies**
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+3. **Configure OCI**
+
+   Create or edit `~/.oci/config`:
+   ```ini
+   [DEFAULT]
+   user=ocid1.user.oc1...
+   fingerprint=aa:bb:cc:dd...
+   key_file=~/.oci/oci_api_key.pem
+   tenancy=ocid1.tenancy.oc1...
+   region=us-chicago-1
+   ```
+
+4. **Configure environment variables**
+
+   Copy `.env.example` to `.env` and edit it:
+   ```bash
+   cp .env.example .env
+   # Edit .env to set API_KEYS and other options
+   ```
+
+5. **Run the service**
+   ```bash
+   cd src
+   python main.py
+   ```
+
+   The service starts at `http://localhost:8000`
+
+## 💻 Usage Examples
+
+### Using cURL
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-oci-genai-default-key" \
+  -d '{
+    "model": "google.gemini-2.5-pro",
+    "messages": [{"role": "user", "content": "Hello!"}]
+  }'
+```
+
+### Using the Python OpenAI SDK
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-oci-genai-default-key",
+    base_url="http://localhost:8000/v1"
+)
+
+response = client.chat.completions.create(
+    model="google.gemini-2.5-pro",
+    messages=[{"role": "user", "content": "Hello!"}]
+)
+
+print(response.choices[0].message.content)
+```
+
+### Streaming Responses
+
+```python
+stream = client.chat.completions.create(
+    model="google.gemini-2.5-pro",
+    messages=[{"role": "user", "content": "Count from 1 to 10"}],
+    stream=True
+)
+
+for chunk in stream:
+    if chunk.choices[0].delta.content:
+        print(chunk.choices[0].delta.content, end="", flush=True)
+```
+
+### Vision Models (Multimodal)

+```python
+response = client.chat.completions.create(
+    model="google.gemini-2.5-pro",
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Describe this image"},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://example.com/image.jpg"
+                    }
+                }
+            ]
+        }
+    ]
+)
+```
+
+## 📋 Supported Endpoints
+
+| Endpoint | Method | Description |
+|------|------|------|
+| `/health` | GET | Health check |
+| `/v1/models` | GET | List all available models |
+| `/v1/chat/completions` | POST | Chat completions (streaming supported) |
+| `/v1/embeddings` | POST | Text embeddings |
+
+## 🎨 Supported Models
+
+Available models are discovered from OCI automatically at startup, including:
+
+- **Cohere**: command-r-plus, command-r-16k, etc.
+- **Meta**: llama-3.1-405b, llama-3.1-70b, llama-3.2-90b-vision, etc.
+- **Google**: the gemini family
+- **OpenAI**: the gpt family
+- **xAI**: the grok family
+
+Use `GET /v1/models` to see all available models.
+
+## ⚙️ Configuration Options
+
+### Key Environment Variables
+
+| Variable | Description | Default |
+|------|------|--------|
+| `API_KEYS` | List of API keys (JSON array) | - |
+| `OCI_CONFIG_PROFILE` | OCI config profile(s) (comma-separated for multiple) | `DEFAULT` |
+| `OCI_AUTH_TYPE` | Authentication type | `api_key` |
+| `MAX_TOKENS` | Default maximum number of tokens | `4096` |
+| `TEMPERATURE` | Default temperature | `0.7` |
+| `ENABLE_STREAMING` | Global streaming switch | `true` |
+| `LOG_LEVEL` | Log level | `INFO` |
+
+See [.env.example](.env.example) for the full configuration.
+
+## 🌐 Multi-Region Load Balancing
+
+Multiple OCI profiles can be configured for automatic load balancing:
+
+```bash
+# .env file
+OCI_CONFIG_PROFILE=DEFAULT,CHICAGO,ASHBURN
+```
+
+Requests are distributed across regions using a round-robin strategy.
+
+## 🐳 Docker Deployment
+
+```bash
+# With docker-compose
+docker-compose up
+
+# Or with plain Docker
+docker build -t oci-genai-gateway .
+docker run -p 8000:8000 --env-file .env oci-genai-gateway
+```
+
+## 📚 Documentation
+
+- [CLAUDE.md](CLAUDE.md) - Full development documentation, covering the architecture, development guide, and debugging tips
+- [.env.example](.env.example) - Example environment variable configuration
+
+## 🔧 Troubleshooting
+
+### Common Issues
+
+1. **Model not found**
+   - Check the model ID spelling
+   - Confirm the model is available in your OCI region
+   - Check the startup logs to confirm the model was loaded
+
+2. **Authentication failures**
+   - Verify that `~/.oci/config` is configured correctly
+   - Check the API key file permissions: `chmod 600 ~/.oci/oci_api_key.pem`
+   - Run `oci iam region list` to test the OCI configuration
+
+3. **429 rate limit errors**
+   - Use multiple profiles for load balancing
+   - Wait 1-2 minutes and retry
+
+For more troubleshooting information, see [CLAUDE.md](CLAUDE.md#调试)
+
+## 🤝 Contributing
+
+Contributions are welcome! Feel free to open issues or pull requests.
+
+## 📄 License
+
+This project is released under the MIT License; see the [LICENSE](LICENSE) file for details.
+
+## 🙏 Acknowledgements
+
+- [FastAPI](https://fastapi.tiangolo.com/) - A modern, fast web framework
+- [OCI Python SDK](https://github.com/oracle/oci-python-sdk) - Oracle Cloud Infrastructure SDK
+- [OpenAI](https://openai.com/) - API design reference
+
+---
+
+**⭐ If this project helps you, please give it a star!**
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..48d5a2f
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,35 @@
+version: '3.8'
+
+services:
+  oci-genai-gateway:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    container_name: oci-genai-gateway
+    ports:
+      - "8000:8000"
+    volumes:
+      # Mount the OCI config directory (adjust the path as needed)
+      - ~/.oci:/root/.oci:ro
+      # Mount the environment config file
+      - .env:/app/.env:ro
+      # Mount the log directory
+      - ./logs:/app/logs
+    environment:
+      - API_TITLE=OCI GenAI to OpenAI API Gateway
+      - API_VERSION=0.0.1
+      - DEBUG=false
+      - LOG_LEVEL=INFO
+    restart: unless-stopped
+    healthcheck:
+      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health').read()"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+      start_period: 10s
+    networks:
+      - genai-network
+
+networks:
+  genai-network:
+    driver: bridge
diff --git a/init.sh b/init.sh
new file mode 100644
index 0000000..3c1284b
--- /dev/null
+++ b/init.sh
@@ -0,0 +1,24 @@
+#!/bin/sh
+
+# Modify the CMakeLists.txt and source files to change the project name from "xmrigcc" to "xxxigcc"
+sed -i 's/project(xmrigcc)/project(xxxigcc)/' CMakeLists.txt
+sed -i 's/XMRigCC: Found ccache package/XXXigCC: Found ccache package/' CMakeLists.txt
+sed -i 's/MINER_EXECUTABLE_NAME "xmrigMiner"/MINER_EXECUTABLE_NAME "xxxigMiner"/' CMakeLists.txt
+sed -i 's/DAEMON_EXECUTABLE_NAME "xmrigDaemon"/DAEMON_EXECUTABLE_NAME "xxxigDaemon"/' CMakeLists.txt
+sed -i 's/xmrigServer ${SOURCES_CC_SERVER}/xxxigServer ${SOURCES_CC_SERVER}/' CMakeLists.txt
+sed -i 's/xmrigServer ${XMRIG_ASM_LIBRARY}/xxxigServer ${XMRIG_ASM_LIBRARY}/' CMakeLists.txt
+sed -i 's/xmrigServer POST_BUILD/xxxigServer POST_BUILD/' CMakeLists.txt
+
+# Modify donate functionality
+sed -i 's/kDefaultDonateLevel = 3/kDefaultDonateLevel = 0/' src/donate.h
+sed -i 's/kMinimumDonateLevel = 1/kMinimumDonateLevel = 0/' src/donate.h
+sed -i 's/donate.graef.in/127.0.0.1/' src/net/strategies/DonateStrategy.cpp
+sed -i 's/87.106.163.52/127.0.0.1/' src/net/strategies/DonateStrategy.cpp
+sed -i 's/"donate-level": 3/"donate-level": 0/' src/config.json
+sed -i 's/"donate-over-proxy": 1/"donate-over-proxy": 0/' src/config.json
+
+# Modify version information
+sed -i 's/Copyright (C) 2017-
XMRigCC//' src/version.h +sed -i 's/https:\/\/github.com\/BenDr0id\/xmrigCC\///' src/version.h +sed -i 's/xmrigcc/xxxigcc/' src/version.h +sed -i 's/XMRigCC/XXXigCC/' src/version.h \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..873b26b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +# FastAPI and server +fastapi==0.115.0 +uvicorn[standard]==0.32.0 +pydantic==2.9.2 +pydantic-settings==2.6.1 + +# OCI SDK (updated to latest stable version) +oci>=2.160.0 + +# Utilities +python-dotenv==1.0.1 +python-multipart==0.0.17 diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..00d3f2d --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1 @@ +"""API module for OCI GenAI Gateway.""" diff --git a/src/api/adapters/__init__.py b/src/api/adapters/__init__.py new file mode 100644 index 0000000..5810f18 --- /dev/null +++ b/src/api/adapters/__init__.py @@ -0,0 +1 @@ +"""Request/Response adapters module.""" diff --git a/src/api/adapters/request_adapter.py b/src/api/adapters/request_adapter.py new file mode 100644 index 0000000..d4d8365 --- /dev/null +++ b/src/api/adapters/request_adapter.py @@ -0,0 +1,104 @@ +""" +Adapter for converting OpenAI requests to OCI GenAI format. +""" +import logging +from typing import List, Dict, Any, Optional +from ..schemas import ChatCompletionRequest, EmbeddingRequest +from core.config import get_settings + +logger = logging.getLogger(__name__) + +# Content type handlers for extensible multimodal support +CONTENT_TYPE_HANDLERS = { + "text": lambda item: {"type": "text", "text": item.get("text", "")}, + "image_url": lambda item: {"type": "image_url", "image_url": item.get("image_url", {})}, + "audio": lambda item: {"type": "audio", "audio_url": item.get("audio_url", {})}, + "video": lambda item: {"type": "video", "video_url": item.get("video_url", {})} +} + + +def adapt_chat_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Adapt OpenAI chat messages to OCI GenAI format. + + Args: + messages: OpenAI format messages + + Returns: + Adapted messages for OCI GenAI + """ + adapted_messages = [] + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + # Handle different content types + if isinstance(content, list): + # Multimodal content + adapted_content = [] + for item in content: + if isinstance(item, dict): + item_type = item.get("type") + handler = CONTENT_TYPE_HANDLERS.get(item_type) + if handler: + adapted_content.append(handler(item)) + else: + logger.warning(f"Unknown content type: {item_type}, skipping") + + adapted_messages.append({ + "role": role, + "content": adapted_content + }) + else: + # Simple text content + adapted_messages.append({ + "role": role, + "content": content + }) + + return adapted_messages + + +def extract_chat_params(request: ChatCompletionRequest) -> Dict[str, Any]: + """ + Extract chat parameters from OpenAI request. 
+ + Args: + request: OpenAI chat completion request + + Returns: + Dictionary of parameters for OCI GenAI + """ + settings = get_settings() + + params = { + "temperature": request.temperature if request.temperature is not None else settings.temperature, + "max_tokens": request.max_tokens if request.max_tokens is not None else settings.max_tokens, + "top_p": request.top_p if request.top_p is not None else 1.0, + "stream": request.stream or False, + } + + # Add tools if present + if request.tools: + params["tools"] = request.tools + + return params + + +def adapt_embedding_input(request: EmbeddingRequest) -> List[str]: + """ + Adapt OpenAI embedding input to OCI GenAI format. + + Args: + request: OpenAI embedding request + + Returns: + List of texts to embed + """ + if isinstance(request.input, str): + return [request.input] + elif isinstance(request.input, list): + return request.input + else: + return [str(request.input)] diff --git a/src/api/adapters/response_adapter.py b/src/api/adapters/response_adapter.py new file mode 100644 index 0000000..36ec400 --- /dev/null +++ b/src/api/adapters/response_adapter.py @@ -0,0 +1,307 @@ +""" +Adapter for converting OCI GenAI responses to OpenAI format. +""" +import time +import uuid +from typing import Dict, Any, List, Optional +from ..schemas import ( + ChatCompletionResponse, + ChatCompletionChoice, + ChatCompletionUsage, + ChatMessage, + EmbeddingResponse, + EmbeddingData, + EmbeddingUsage, +) + + +def adapt_chat_response( + oci_response: Any, + model_id: str, + request_id: Optional[str] = None +) -> ChatCompletionResponse: + """ + Adapt OCI GenAI chat response to OpenAI format. + + Args: + oci_response: OCI GenAI response object + model_id: Model identifier + request_id: Optional request ID + + Returns: + OpenAI-compatible chat completion response + """ + response_id = request_id or f"chatcmpl-{uuid.uuid4().hex[:8]}" + created_at = int(time.time()) + + # Extract response data + chat_response = oci_response.data.chat_response + + # Extract text content + if hasattr(chat_response, 'text'): + # Cohere format + raw_text = chat_response.text + # Try to parse as JSON if it's a string (OCI format) + try: + import json + parsed = json.loads(raw_text) + if isinstance(parsed, dict) and 'text' in parsed: + content = parsed['text'] + else: + content = raw_text + except (json.JSONDecodeError, ValueError, TypeError): + # Not JSON, use as-is + content = raw_text + finish_reason = chat_response.finish_reason if hasattr(chat_response, 'finish_reason') else "stop" + elif hasattr(chat_response, 'choices') and len(chat_response.choices) > 0: + # Llama/Generic format + choice = chat_response.choices[0] + if hasattr(choice, 'message'): + raw_content = choice.message.content + # Handle list format: [TextContent(text="...", type="TEXT")] or [{"text": "...", "type": "TEXT"}] + if isinstance(raw_content, list): + # Build multimodal content array + adapted_content = [] + for item in raw_content: + # Handle OCI TextContent object + if hasattr(item, 'text') and hasattr(item, 'type'): + if item.type == 'TEXT' or item.type == 'text': + adapted_content.append({ + "type": "text", + "text": item.text + }) + # Future: handle IMAGE, AUDIO, VIDEO types + # Handle dict format + elif isinstance(item, dict): + item_type = item.get('type', 'TEXT').upper() + if item_type == 'TEXT': + adapted_content.append({ + "type": "text", + "text": item.get('text', '') + }) + # Future: handle other types + else: + # Fallback: convert to text + adapted_content.append({ + "type": "text", + "text": 
str(item) + }) + + # Simplify to string if only one text element (backward compatibility) + if len(adapted_content) == 1 and adapted_content[0].get('type') == 'text': + content = adapted_content[0]['text'] + else: + content = adapted_content + elif isinstance(raw_content, str): + # Try to parse as JSON if it's a string (OCI format) + try: + import json + parsed = json.loads(raw_content) + if isinstance(parsed, dict) and 'text' in parsed: + content = parsed['text'] + else: + content = raw_content + except (json.JSONDecodeError, ValueError): + # Not JSON, use as-is + content = raw_content + else: + content = raw_content + else: + content = str(choice) + finish_reason = choice.finish_reason if hasattr(choice, 'finish_reason') else "stop" + else: + content = str(chat_response) + finish_reason = "stop" + + # Create message + message = ChatMessage( + role="assistant", + content=content + ) + + # Create choice + choice = ChatCompletionChoice( + index=0, + message=message, + finish_reason=finish_reason + ) + + # Extract usage information + usage = None + if hasattr(oci_response.data, 'usage'): + oci_usage = oci_response.data.usage + usage = ChatCompletionUsage( + prompt_tokens=getattr(oci_usage, 'prompt_tokens', 0) or 0, + completion_tokens=getattr(oci_usage, 'completion_tokens', 0) or 0, + total_tokens=getattr(oci_usage, 'total_tokens', 0) or 0 + ) + + return ChatCompletionResponse( + id=response_id, + object="chat.completion", + created=created_at, + model=model_id, + choices=[choice], + usage=usage + ) + + +def adapt_streaming_chunk( + chunk_data: str, + model_id: str, + request_id: str, + index: int = 0, + is_first: bool = False +) -> str: + """ + Adapt OCI GenAI streaming chunk to OpenAI SSE format. + + Args: + chunk_data: Chunk text from OCI GenAI + model_id: Model identifier + request_id: Request ID + index: Chunk index + is_first: Whether this is the first chunk (should include role with empty content) + + Returns: + OpenAI-compatible SSE formatted string + """ + created_at = int(time.time()) + + # Build delta - first chunk should include role with empty content + delta = {} + if is_first: + delta["role"] = "assistant" + delta["content"] = "" # First chunk has empty content like OpenAI + elif chunk_data: + delta["content"] = chunk_data + + chunk = { + "id": request_id, + "object": "chat.completion.chunk", + "created": created_at, + "model": model_id, + "system_fingerprint": None, + "choices": [ + { + "index": index, + "delta": delta, + "logprobs": None, + "finish_reason": None + } + ], + "usage": None + } + + import json + return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" + + +def adapt_streaming_done( + model_id: str, + request_id: str, + usage: Optional[Dict[str, int]] = None +) -> str: + """ + Create final SSE chunks for streaming completion (OpenAI format). + + Returns two chunks: + 1. Finish chunk with finish_reason="stop" + 2. 
Usage chunk with empty choices and usage stats + + Args: + model_id: Model identifier + request_id: Request ID + usage: Optional usage statistics + + Returns: + Final SSE formatted strings (finish chunk + usage chunk + [DONE]) + """ + import json + created_at = int(time.time()) + + result = "" + + # First chunk: finish_reason with empty delta + finish_chunk = { + "id": request_id, + "object": "chat.completion.chunk", + "created": created_at, + "model": model_id, + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop" + } + ], + "usage": None + } + result += f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" + + # Second chunk: usage stats with empty choices (like OpenAI) + if usage: + usage_chunk = { + "id": request_id, + "object": "chat.completion.chunk", + "created": created_at, + "model": model_id, + "system_fingerprint": "", + "choices": [], # Empty choices array for usage chunk + "usage": usage + } + result += f"data: {json.dumps(usage_chunk, ensure_ascii=False)}\n\n" + + # Final [DONE] marker + result += "data: [DONE]\n\n" + + return result + + +def adapt_embedding_response( + oci_response: Any, + model_id: str, + input_count: int +) -> EmbeddingResponse: + """ + Adapt OCI GenAI embedding response to OpenAI format. + + Args: + oci_response: OCI GenAI embedding response + model_id: Model identifier + input_count: Number of input texts + + Returns: + OpenAI-compatible embedding response + """ + embeddings_data = [] + + # Extract embeddings + if hasattr(oci_response.data, 'embeddings'): + embeddings = oci_response.data.embeddings + for idx, embedding in enumerate(embeddings): + embeddings_data.append( + EmbeddingData( + object="embedding", + embedding=embedding, + index=idx + ) + ) + + # Calculate usage (approximate) + # OCI doesn't always provide token counts, so we estimate + prompt_tokens = input_count * 10 # Rough estimate + + usage = EmbeddingUsage( + prompt_tokens=prompt_tokens, + total_tokens=prompt_tokens + ) + + return EmbeddingResponse( + object="list", + data=embeddings_data, + model=model_id, + usage=usage + ) diff --git a/src/api/auth.py b/src/api/auth.py new file mode 100644 index 0000000..4a808e6 --- /dev/null +++ b/src/api/auth.py @@ -0,0 +1,57 @@ +""" +API authentication module. +""" +import logging +from fastapi import HTTPException, Security, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from typing import List + +from core.config import get_settings + +logger = logging.getLogger(__name__) +security = HTTPBearer() + + +async def verify_api_key( + credentials: HTTPAuthorizationCredentials = Security(security) +) -> str: + """ + Verify API key from Authorization header. + + Args: + credentials: HTTP authorization credentials + + Returns: + Validated API key + + Raises: + HTTPException: If API key is invalid + """ + api_key = credentials.credentials + settings = get_settings() + + if api_key in settings.api_keys: + logger.debug("API key validated successfully") + return api_key + + logger.warning(f"Invalid API key attempted: {api_key[:10]}...") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def get_api_key( + credentials: HTTPAuthorizationCredentials = Security(security) +) -> str: + """ + Async wrapper for API key verification. 
+ + Args: + credentials: HTTP authorization credentials + + Returns: + Validated API key + """ + return await verify_api_key(credentials) diff --git a/src/api/error_handler.py b/src/api/error_handler.py new file mode 100644 index 0000000..7d33faa --- /dev/null +++ b/src/api/error_handler.py @@ -0,0 +1,161 @@ +""" +OCI 错误处理器 - 过滤敏感信息并提供用户友好的错误响应 + +此模块负责: +1. 拦截 OCI SDK ServiceError 异常 +2. 过滤敏感信息(OCID、request-id、endpoint URLs) +3. 映射 OCI 错误码到 OpenAI 兼容格式 +4. 生成用户友好的错误消息 +""" +import re +import logging +from typing import Dict + +from oci.exceptions import ServiceError + +from api.schemas import ErrorDetail, ErrorResponse + +logger = logging.getLogger(__name__) + + +class OCIErrorHandler: + """OCI 错误处理器,负责过滤敏感信息并转换错误格式""" + + # 预编译的正则模式(性能优化) + SENSITIVE_PATTERNS: Dict[str, re.Pattern] = { + 'tenancy_ocid': re.compile(r'ocid1\.tenancy\.oc1\.\.[a-z0-9]+', re.IGNORECASE), + 'compartment_ocid': re.compile(r'ocid1\.compartment\.oc1\.\.[a-z0-9]+', re.IGNORECASE), + 'user_ocid': re.compile(r'ocid1\.user\.oc1\.\.[a-z0-9]+', re.IGNORECASE), + 'endpoint_ocid': re.compile(r'ocid1\.generativeaiendpoint\.[a-z0-9\.\-]+', re.IGNORECASE), + 'request_id': re.compile(r'[A-F0-9]{32}(/[A-F0-9]{32})*'), + 'endpoint_url': re.compile(r'https://[a-z0-9\.\-]+\.oci(\.oraclecloud)?\.com[^\s\)]*', re.IGNORECASE), + } + + # OCI 状态码到 OpenAI 错误类型的映射 + OCI_TO_OPENAI_ERROR_TYPE: Dict[int, str] = { + 400: "invalid_request_error", + 401: "authentication_error", + 403: "permission_error", + 404: "invalid_request_error", + 409: "invalid_request_error", + 429: "rate_limit_error", + 500: "server_error", + 502: "server_error", + 503: "server_error", + 504: "server_error", + } + + # 用户友好的错误消息模板 + USER_FRIENDLY_MESSAGES: Dict[int, str] = { + 400: "Invalid request parameters. Please check your input.", + 401: "Authentication failed. Please verify your API credentials.", + 403: "Access denied. You don't have permission to access this resource.", + 404: "The requested resource was not found.", + 409: "Request conflict. The resource may have been modified.", + 429: "Request rate limit exceeded. Please retry after a short delay.", + 500: "Internal server error. Please try again later.", + 502: "Bad gateway. The upstream service is unavailable.", + 503: "Service temporarily unavailable. Please try again later.", + 504: "Gateway timeout. 
The request took too long to process.", + } + + @classmethod + def sanitize_oci_error(cls, exc: ServiceError) -> ErrorResponse: + """ + 处理 OCI ServiceError,过滤敏感信息并返回用户友好的错误响应 + + Args: + exc: OCI ServiceError 异常对象 + + Returns: + ErrorResponse: 过滤后的错误响应 + """ + # 完整错误记录到日志(供调试) + logger.error( + f"OCI ServiceError: status={exc.status}, code={exc.code}, " + f"request_id={exc.request_id}, message={exc.message}" + ) + + # 过滤敏感信息 + filtered_message = cls.filter_sensitive_info(str(exc.message)) + + # 生成用户友好消息 + user_message = cls.create_user_friendly_message(exc.status, filtered_message) + + # 映射错误类型 + error_type = cls.map_oci_status_to_openai(exc.status) + + # 构建 ErrorResponse + error_detail = ErrorDetail( + message=user_message, + type=error_type, + code=f"oci_{exc.code.lower()}" if exc.code else "oci_error" + ) + + return ErrorResponse(error=error_detail) + + @classmethod + def filter_sensitive_info(cls, text: str) -> str: + """ + 过滤文本中的敏感信息 + + Args: + text: 原始文本 + + Returns: + str: 过滤后的文本 + """ + filtered = text + + # 遍历所有正则模式,替换敏感信息 + for pattern_name, regex_pattern in cls.SENSITIVE_PATTERNS.items(): + if pattern_name == 'tenancy_ocid': + filtered = regex_pattern.sub('tenancy:***', filtered) + elif pattern_name == 'endpoint_url': + filtered = regex_pattern.sub('https://***', filtered) + elif pattern_name == 'request_id': + filtered = regex_pattern.sub('request-id:***', filtered) + else: + filtered = regex_pattern.sub('***', filtered) + + return filtered + + @classmethod + def map_oci_status_to_openai(cls, status_code: int) -> str: + """ + 映射 OCI 状态码到 OpenAI 错误类型 + + Args: + status_code: HTTP 状态码 + + Returns: + str: OpenAI 错误类型 + """ + # 使用映射表转换,未知状态码默认为 server_error + return cls.OCI_TO_OPENAI_ERROR_TYPE.get(status_code, "server_error") + + @classmethod + def create_user_friendly_message(cls, status_code: int, filtered_message: str) -> str: + """ + 生成用户友好的错误消息 + + Args: + status_code: HTTP 状态码 + filtered_message: 已过滤的原始错误消息 + + Returns: + str: 用户友好的错误消息 + """ + # 优先使用预定义的友好消息 + base_message = cls.USER_FRIENDLY_MESSAGES.get( + status_code, + "An unexpected error occurred. Please try again." + ) + + # 如果过滤后的消息仍有有用信息,附加到基础消息后 + if filtered_message and filtered_message != str(status_code): + # 截取前200字符避免过长 + truncated = filtered_message[:200] + return f"{base_message} Details: {truncated}" + + return base_message diff --git a/src/api/exceptions.py b/src/api/exceptions.py new file mode 100644 index 0000000..c981444 --- /dev/null +++ b/src/api/exceptions.py @@ -0,0 +1,60 @@ +""" +Custom exceptions for the API. +""" +from fastapi import HTTPException + + +class ModelNotFoundException(HTTPException): + """ + Exception raised when a requested model is not found. + + This exception is OpenAI API compatible and returns: + - HTTP Status: 404 + - Error type: "invalid_request_error" + - Error code: "model_not_found" + """ + + def __init__(self, model_id: str): + """ + Initialize ModelNotFoundException. + + Args: + model_id: The model ID that was not found + """ + self.model_id = model_id + self.error_code = "model_not_found" + self.error_type = "invalid_request_error" + + # HTTPException detail will be the message + super().__init__( + status_code=404, + detail=f"The model '{model_id}' does not exist or is not supported" + ) + + +class InvalidModelTypeException(HTTPException): + """ + Exception raised when a model exists but is not the correct type. + + For example, using an embedding model for chat or vice versa. 
+ """ + + def __init__(self, model_id: str, expected_type: str, actual_type: str): + """ + Initialize InvalidModelTypeException. + + Args: + model_id: The model ID + expected_type: Expected model type (e.g., "chat", "embedding") + actual_type: Actual model type + """ + self.model_id = model_id + self.expected_type = expected_type + self.actual_type = actual_type + self.error_code = "invalid_model_type" + self.error_type = "invalid_request_error" + + super().__init__( + status_code=400, + detail=f"Model '{model_id}' is a {actual_type} model, not a {expected_type} model" + ) diff --git a/src/api/routers/__init__.py b/src/api/routers/__init__.py new file mode 100644 index 0000000..72a8394 --- /dev/null +++ b/src/api/routers/__init__.py @@ -0,0 +1 @@ +"""API routers module.""" diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py new file mode 100644 index 0000000..5ad9a63 --- /dev/null +++ b/src/api/routers/chat.py @@ -0,0 +1,417 @@ +""" +Chat completions API router - OpenAI compatible chat endpoint. +""" +import asyncio +import logging +import os +import uuid +from typing import AsyncIterator, Union +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse + +from oci.exceptions import ServiceError + +from api.auth import get_api_key +from api.schemas import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ErrorDetail +from api.error_handler import OCIErrorHandler +from api.exceptions import ModelNotFoundException, InvalidModelTypeException +from api.adapters.request_adapter import adapt_chat_messages, extract_chat_params +from api.adapters.response_adapter import ( + adapt_chat_response, + adapt_streaming_chunk, + adapt_streaming_done, +) +from core.config import get_settings +from core.client_manager import get_client_manager +from core.models import get_model_config + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/chat", + tags=["chat"], + dependencies=[Depends(get_api_key)] +) + + +def extract_delta_from_chunk(chunk) -> str: + """ + Extract delta text content from OCI streaming chunk. + + Args: + chunk: OCI streaming response chunk (can be SSE Event, parsed object, etc.) 
+ + Returns: + Delta text content or empty string + """ + try: + # Handle SSE Event objects (from SSEClient) + if hasattr(chunk, 'data'): + import json + # Parse JSON data from SSE event + try: + parsed = json.loads(chunk.data) + + # Recursively extract from parsed object + if isinstance(parsed, dict): + # OCI Streaming format: message.content[].text + if 'message' in parsed and 'content' in parsed['message']: + content_array = parsed['message']['content'] + if isinstance(content_array, list) and len(content_array) > 0: + # Extract text from all TEXT type content items + text_parts = [] + for item in content_array: + if isinstance(item, dict) and item.get('type') == 'TEXT' and 'text' in item: + text_parts.append(item['text']) + if text_parts: + return ''.join(text_parts) + + # Try to get text from various possible locations + if 'text' in parsed: + return parsed['text'] + if 'chatResponse' in parsed and 'text' in parsed['chatResponse']: + return parsed['chatResponse']['text'] + if 'choices' in parsed and len(parsed['choices']) > 0: + choice = parsed['choices'][0] + if 'delta' in choice and 'content' in choice['delta']: + return choice['delta']['content'] + + except (json.JSONDecodeError, KeyError, TypeError): + # Return raw data if not JSON + return str(chunk.data) if chunk.data else "" + + # Try to extract from chat_response.text (Cohere format) + if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'text'): + return chunk.chat_response.text + + # Try to extract from choices[0].delta.content (Generic format) + if hasattr(chunk, 'chat_response') and hasattr(chunk.chat_response, 'choices'): + if len(chunk.chat_response.choices) > 0: + choice = chunk.chat_response.choices[0] + if hasattr(choice, 'delta') and hasattr(choice.delta, 'content'): + content = choice.delta.content + if isinstance(content, str): + return content + elif isinstance(content, list): + # Handle TextContent list + text_parts = [] + for item in content: + if isinstance(item, dict) and 'text' in item: + text_parts.append(item['text']) + elif hasattr(item, 'text'): + text_parts.append(item.text) + return "".join(text_parts) + + # Try direct text attribute + if hasattr(chunk, 'text'): + return chunk.text + + except Exception as e: + logger.warning(f"Failed to extract delta from chunk: {e}") + + return "" + + +def extract_content_from_response(chat_response) -> str: + """ + Extract full content from non-streaming OCI response. 
+ + Args: + chat_response: OCI chat response object + + Returns: + Full text content + """ + if hasattr(chat_response, 'text'): + raw_text = chat_response.text + # Try to parse as JSON if it's a string (OCI format) + try: + import json + parsed = json.loads(raw_text) + if isinstance(parsed, dict) and 'text' in parsed: + return parsed['text'] + return raw_text + except (json.JSONDecodeError, ValueError, TypeError): + return raw_text + + elif hasattr(chat_response, 'choices') and len(chat_response.choices) > 0: + choice = chat_response.choices[0] + if hasattr(choice, 'message'): + raw_content = choice.message.content + # Handle list format + if isinstance(raw_content, list): + text_parts = [] + for item in raw_content: + if isinstance(item, dict): + text_parts.append(item.get('text', '')) + elif hasattr(item, 'text'): + text_parts.append(item.text) + else: + text_parts.append(str(item)) + return "".join(text_parts) + elif isinstance(raw_content, str): + try: + import json + parsed = json.loads(raw_content) + if isinstance(parsed, dict) and 'text' in parsed: + return parsed['text'] + return raw_content + except (json.JSONDecodeError, ValueError): + return raw_content + else: + return str(raw_content) + return str(choice) + + return str(chat_response) + + +@router.post("/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + """ + Create a chat completion using OCI Generative AI. + + Args: + request: Chat completion request + + Returns: + Chat completion response + """ + logger.info(f"Chat completion request for model: {request.model}") + + settings = get_settings() + + # Validate model exists + model_config = get_model_config(request.model) + if not model_config: + raise ModelNotFoundException(request.model) + + # Validate model type is chat (ondemand or dedicated) + if model_config.type not in ("ondemand", "dedicated"): + raise InvalidModelTypeException( + model_id=request.model, + expected_type="chat", + actual_type=model_config.type + ) + + # Note: Multimodal capability validation is handled by the model itself + # If a model doesn't support certain content types, it will raise an error + # For example, Cohere models will raise ValueError for non-text content + + # Get OCI client from manager (轮询负载均衡) + client_manager = get_client_manager() + oci_client = client_manager.get_client() + + # Adapt messages + messages = adapt_chat_messages([msg.dict() for msg in request.messages]) + + # Extract parameters + params = extract_chat_params(request) + + # Check global streaming setting + # If streaming is globally disabled, override client request + enable_stream = request.stream and settings.enable_streaming + + if not settings.enable_streaming and request.stream: + logger.info("Streaming requested but globally disabled via ENABLE_STREAMING=false") + + # Handle streaming + if enable_stream: + request_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" + + async def generate_stream() -> AsyncIterator[str]: + """Generate streaming response with true non-blocking streaming.""" + try: + # Run OCI SDK call in executor to prevent blocking + # This is critical for achieving true streaming (msToFirstChunk < 1s) + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: oci_client.chat( + model_id=request.model, + messages=messages, + temperature=params["temperature"], + max_tokens=params["max_tokens"], + top_p=params["top_p"], + stream=True, # Enable real streaming + tools=params.get("tools"), + ) + ) + + # Process 
real streaming response + accumulated_usage = None + + # Check if response.data is an SSE stream (iterable) + # When stream=True, OCI SDK returns response.data as SSEClient + try: + # Try to iterate over the stream + stream_data = response.data if hasattr(response, 'data') else response + + # Check if it's SSEClient or any iterable type + stream_type_name = type(stream_data).__name__ + is_sse_client = 'SSEClient' in stream_type_name + is_iterable = hasattr(stream_data, '__iter__') or hasattr(stream_data, '__next__') + + # SSEClient is always treated as streaming, even if hasattr check fails + if is_sse_client or is_iterable: + # Real streaming: iterate over chunks + # SSEClient requires calling .events() method to iterate + if is_sse_client and hasattr(stream_data, 'events'): + iterator = stream_data.events() + else: + iterator = stream_data + + # Send first chunk with role and empty content (OpenAI format) + yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True) + + # Use queue for thread-safe chunk forwarding + import queue + import threading + chunk_queue = queue.Queue() + + def read_chunks(): + """Read chunks in background thread and put in queue.""" + try: + for chunk in iterator: + chunk_queue.put(("chunk", chunk)) + chunk_queue.put(("done", None)) + except Exception as e: + chunk_queue.put(("error", e)) + + # Start background thread to read chunks + reader_thread = threading.Thread(target=read_chunks, daemon=True) + reader_thread.start() + + # Yield chunks as they arrive from queue + while True: + # Non-blocking queue get with timeout + try: + msg_type, data = await loop.run_in_executor( + None, + lambda: chunk_queue.get(timeout=0.01) + ) + except queue.Empty: + # Allow other async tasks to run + await asyncio.sleep(0) + continue + + if msg_type == "done": + break + elif msg_type == "error": + raise data + elif msg_type == "chunk": + chunk = data + # Extract delta content from chunk + delta_text = extract_delta_from_chunk(chunk) + + if delta_text: + yield adapt_streaming_chunk(delta_text, request.model, request_id, 0, is_first=False) + + # Try to extract usage from chunk (typically in final chunk) + # Handle both SSE Event format and object format + if hasattr(chunk, 'data'): + # SSE Event - parse JSON to extract usage + try: + import json + parsed = json.loads(chunk.data) + if isinstance(parsed, dict) and 'usage' in parsed: + usage_data = parsed['usage'] + accumulated_usage = { + "prompt_tokens": usage_data.get('promptTokens', 0) or 0, + "completion_tokens": usage_data.get('completionTokens', 0) or 0, + "total_tokens": usage_data.get('totalTokens', 0) or 0 + } + except (json.JSONDecodeError, KeyError, TypeError): + pass + elif hasattr(chunk, 'usage') and chunk.usage: + # Object format + accumulated_usage = { + "prompt_tokens": getattr(chunk.usage, 'prompt_tokens', 0) or 0, + "completion_tokens": getattr(chunk.usage, 'completion_tokens', 0) or 0, + "total_tokens": getattr(chunk.usage, 'total_tokens', 0) or 0 + } + + # Send done message with usage + yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage) + + else: + # Fallback: non-streaming response, simulate streaming + logger.warning(f"OCI SDK returned non-iterable response (type: {type(stream_data).__name__}), falling back to simulated streaming") + + # Extract text from non-streaming response + chat_response = stream_data.chat_response if hasattr(stream_data, 'chat_response') else stream_data + content = extract_content_from_response(chat_response) + + # Extract usage information + if 
hasattr(stream_data, 'usage'): + oci_usage = stream_data.usage + accumulated_usage = { + "prompt_tokens": getattr(oci_usage, 'prompt_tokens', 0) or 0, + "completion_tokens": getattr(oci_usage, 'completion_tokens', 0) or 0, + "total_tokens": getattr(oci_usage, 'total_tokens', 0) or 0 + } + + # Simulate streaming by chunking + # First send empty chunk with role (OpenAI format) + yield adapt_streaming_chunk("", request.model, request_id, 0, is_first=True) + + chunk_size = settings.stream_chunk_size + for i in range(0, len(content), chunk_size): + chunk = content[i:i + chunk_size] + yield adapt_streaming_chunk(chunk, request.model, request_id, 0, is_first=False) + + yield adapt_streaming_done(request.model, request_id, usage=accumulated_usage) + + except TypeError as te: + # Handle case where response is not iterable at all + logger.error(f"Response is not iterable: {te}", exc_info=True) + raise + + except Exception as e: + logger.error(f"Error in streaming: {str(e)}", exc_info=True) + import json + + # 根据异常类型处理并过滤敏感信息 + if isinstance(e, ServiceError): + error_response = OCIErrorHandler.sanitize_oci_error(e) + else: + # 通用错误也要过滤可能包含的敏感信息 + filtered_msg = OCIErrorHandler.filter_sensitive_info(str(e)) + error_response = ErrorResponse( + error=ErrorDetail( + message="An error occurred during streaming", + type="server_error", + code="streaming_error" + ) + ) + + yield f"data: {json.dumps(error_response.dict(), ensure_ascii=False)}\n\n" + + return StreamingResponse( + generate_stream(), + media_type="text/event-stream" + ) + + # Non-streaming response + try: + response = oci_client.chat( + model_id=request.model, + messages=messages, + temperature=params["temperature"], + max_tokens=params["max_tokens"], + top_p=params["top_p"], + stream=False, + tools=params.get("tools"), + ) + + # Adapt response to OpenAI format + openai_response = adapt_chat_response(response, request.model) + + if settings.log_responses: + logger.debug(f"Response: {openai_response}") + + return openai_response + + except Exception as e: + logger.error(f"Error in chat completion: {str(e)}", exc_info=True) + # 直接 raise,让全局异常处理器统一过滤敏感信息 + raise diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py new file mode 100644 index 0000000..71879dd --- /dev/null +++ b/src/api/routers/embeddings.py @@ -0,0 +1,85 @@ +""" +Embeddings API router - OpenAI compatible embeddings endpoint. +""" +import logging +from fastapi import APIRouter, Depends, HTTPException + +from api.auth import get_api_key +from api.schemas import EmbeddingRequest, EmbeddingResponse +from api.adapters.request_adapter import adapt_embedding_input +from api.adapters.response_adapter import adapt_embedding_response +from api.exceptions import ModelNotFoundException, InvalidModelTypeException +from core.config import get_settings +from core.client_manager import get_client_manager +from core.models import get_model_config + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/embeddings", + tags=["embeddings"], + dependencies=[Depends(get_api_key)] +) + + +@router.post("", response_model=EmbeddingResponse) +@router.post("/", response_model=EmbeddingResponse) +async def create_embeddings(request: EmbeddingRequest): + """ + Create embeddings using OCI Generative AI. 
+ + Args: + request: Embedding request + + Returns: + Embedding response + """ + logger.info(f"Embedding request for model: {request.model}") + + settings = get_settings() + + # Validate model exists + model_config = get_model_config(request.model) + if not model_config: + raise ModelNotFoundException(request.model) + + # Validate model type is embedding + if model_config.type != "embedding": + raise InvalidModelTypeException( + model_id=request.model, + expected_type="embedding", + actual_type=model_config.type + ) + + # Get OCI client from manager (轮询负载均衡) + client_manager = get_client_manager() + oci_client = client_manager.get_client() + + # Adapt input + texts = adapt_embedding_input(request) + input_count = len(texts) + + try: + # Generate embeddings + response = oci_client.embed( + model_id=request.model, + texts=texts, + truncate=settings.embed_truncate, + ) + + # Adapt response to OpenAI format + openai_response = adapt_embedding_response( + response, + request.model, + input_count + ) + + if settings.log_responses: + logger.debug(f"Embeddings generated: {len(openai_response.data)} vectors") + + return openai_response + + except Exception as e: + logger.error(f"Error in embedding generation: {str(e)}", exc_info=True) + # 直接 raise,让全局异常处理器统一过滤敏感信息 + raise diff --git a/src/api/routers/models.py b/src/api/routers/models.py new file mode 100644 index 0000000..64bfe4c --- /dev/null +++ b/src/api/routers/models.py @@ -0,0 +1,78 @@ +""" +Models API router - OpenAI compatible model listing. +""" +import logging +from fastapi import APIRouter, Depends + +from api.auth import get_api_key +from api.schemas import ModelListResponse, ModelInfo +from core.models import get_all_models + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/models", + tags=["models"], + dependencies=[Depends(get_api_key)] +) + + +@router.get("", response_model=ModelListResponse) +@router.get("/", response_model=ModelListResponse) +async def list_models(): + """ + List available models in OpenAI format. + + Returns: + ModelListResponse: List of available models + """ + logger.info("Listing available models") + + models = get_all_models() + + model_list = [ + ModelInfo( + id=model.id, + object="model", + created=0, + owned_by="oracle" + ) + for model in models + ] + + return ModelListResponse( + object="list", + data=model_list + ) + + +@router.get("/{model_id}", response_model=ModelInfo) +async def get_model(model_id: str): + """ + Get information about a specific model. + + Args: + model_id: Model identifier + + Returns: + ModelInfo: Model information + """ + logger.info(f"Getting model info: {model_id}") + + from core.models import get_model_config + + model_config = get_model_config(model_id) + + if not model_config: + from fastapi import HTTPException + raise HTTPException( + status_code=404, + detail=f"Model {model_id} not found" + ) + + return ModelInfo( + id=model_config.id, + object="model", + created=0, + owned_by="oracle" + ) diff --git a/src/api/schemas.py b/src/api/schemas.py new file mode 100644 index 0000000..1e9dd19 --- /dev/null +++ b/src/api/schemas.py @@ -0,0 +1,139 @@ +""" +OpenAI-compatible API schemas. 
+""" +from typing import List, Optional, Union, Dict, Any, Literal +from pydantic import BaseModel, Field + + +# ============= Chat Completion Schemas ============= + +class ChatMessage(BaseModel): + """A chat message.""" + role: Literal["system", "user", "assistant", "tool"] + content: Union[str, List[Dict[str, Any]]] + name: Optional[str] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + tool_call_id: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + """OpenAI chat completion request.""" + model: str + messages: List[ChatMessage] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + stream: Optional[bool] = True # Default to streaming + stop: Optional[Union[str, List[str]]] = None + max_tokens: Optional[int] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + tools: Optional[List[Dict[str, Any]]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + + +class ChatCompletionChoice(BaseModel): + """A chat completion choice.""" + index: int + message: ChatMessage + finish_reason: Optional[str] = None + logprobs: Optional[Dict[str, Any]] = None + + +class ChatCompletionUsage(BaseModel): + """Token usage information.""" + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + """OpenAI chat completion response.""" + id: str + object: str = "chat.completion" + created: int + model: str + choices: List[ChatCompletionChoice] + usage: Optional[ChatCompletionUsage] = None + system_fingerprint: Optional[str] = None + + +class ChatCompletionStreamChoice(BaseModel): + """A streaming chat completion choice.""" + index: int + delta: Dict[str, Any] + finish_reason: Optional[str] = None + + +class ChatCompletionStreamResponse(BaseModel): + """OpenAI streaming chat completion response.""" + id: str + object: str = "chat.completion.chunk" + created: int + model: str + choices: List[ChatCompletionStreamChoice] + system_fingerprint: Optional[str] = None + + +# ============= Embedding Schemas ============= + +class EmbeddingRequest(BaseModel): + """OpenAI embedding request.""" + model: str + input: Union[str, List[str]] + encoding_format: Optional[str] = "float" + user: Optional[str] = None + + +class EmbeddingData(BaseModel): + """Embedding data.""" + object: str = "embedding" + embedding: List[float] + index: int + + +class EmbeddingUsage(BaseModel): + """Embedding usage information.""" + prompt_tokens: int + total_tokens: int + + +class EmbeddingResponse(BaseModel): + """OpenAI embedding response.""" + object: str = "list" + data: List[EmbeddingData] + model: str + usage: EmbeddingUsage + + +# ============= Model Schemas ============= + +class ModelInfo(BaseModel): + """Model information.""" + id: str + object: str = "model" + created: int = 0 + owned_by: str = "oracle" + + +class ModelListResponse(BaseModel): + """Model list response.""" + object: str = "list" + data: List[ModelInfo] + + +# ============= Error Schemas ============= + +class ErrorDetail(BaseModel): + """Error detail.""" + message: str + type: str + param: Optional[str] = None + code: Optional[str] = None + + +class ErrorResponse(BaseModel): + """Error response.""" + error: ErrorDetail diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..6edd0fc --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1 @@ +"""Core module for OCI GenAI Gateway.""" diff 
--git a/src/core/client_manager.py b/src/core/client_manager.py
new file mode 100644
index 0000000..f356b32
--- /dev/null
+++ b/src/core/client_manager.py
@@ -0,0 +1,70 @@
+"""
+Simple OCI client manager with multi-profile round-robin load balancing.
+"""
+import logging
+from typing import Dict, List, Optional
+from threading import Lock
+
+from .config import Settings, get_settings
+from .oci_client import OCIGenAIClient
+
+logger = logging.getLogger(__name__)
+
+
+class OCIClientManager:
+    """OCI client manager with round-robin load balancing over a pool of pre-created clients."""
+
+    def __init__(self, settings: Optional[Settings] = None):
+        self.settings = settings or get_settings()
+        self.profiles = self.settings.get_profiles()
+        self.current_index = 0
+        self.lock = Lock()
+
+        # Pre-create the client pool, one client per profile
+        self._clients: Dict[str, OCIGenAIClient] = {}
+        logger.info(f"Initializing OCI client manager with {len(self.profiles)} profiles: {self.profiles}")
+
+        for profile in self.profiles:
+            try:
+                self._clients[profile] = OCIGenAIClient(self.settings, profile)
+                logger.info(f"✓ Created client instance: {profile}")
+            except Exception as e:
+                logger.error(f"✗ Failed to create client instance [{profile}]: {e}")
+                raise
+
+    def get_client(self) -> OCIGenAIClient:
+        """
+        Get the next client (round-robin strategy).
+
+        Selects a client instance from the pre-created pool using a
+        round-robin algorithm. This method is thread-safe.
+
+        Returns:
+            OCIGenAIClient: A pre-created OCI client instance
+
+        Note:
+            Client instances are created when the manager is initialized;
+            this method never creates new instances.
+        """
+        with self.lock:
+            # With a single profile, return its client directly
+            if len(self.profiles) == 1:
+                return self._clients[self.profiles[0]]
+
+            # Round-robin profile selection
+            profile = self.profiles[self.current_index]
+            self.current_index = (self.current_index + 1) % len(self.profiles)
+
+            logger.debug(f"Selected profile: {profile} (round-robin)")
+            return self._clients[profile]
+
+
+# Global client manager instance
+_client_manager = None
+
+
+def get_client_manager() -> OCIClientManager:
+    """Get the global client manager instance."""
+    global _client_manager
+    if _client_manager is None:
+        _client_manager = OCIClientManager()
+    return _client_manager
diff --git a/src/core/config.py b/src/core/config.py
new file mode 100644
index 0000000..b5f6ad5
--- /dev/null
+++ b/src/core/config.py
@@ -0,0 +1,100 @@
+"""
+Configuration module for OCI Generative AI to OpenAI API Gateway.
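+
+Profiles may be a single name or a comma-separated list. As a small sketch,
+assuming a hypothetical OCI_CONFIG_PROFILE=DEFAULT,CHICAGO in the environment:
+
+    settings = get_settings()
+    settings.get_profiles()  # -> ["DEFAULT", "CHICAGO"]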
+""" +import os +import logging +from pathlib import Path +from typing import Optional, List +from pydantic_settings import BaseSettings + +logger = logging.getLogger(__name__) + +# Find project root directory (where .env should be) +def find_project_root() -> Path: + """Find the project root directory by looking for .env or requirements.txt.""" + current = Path(__file__).resolve().parent # Start from src/core/ + # Go up until we find project root markers + while current != current.parent: + if (current / ".env").exists() or (current / "requirements.txt").exists(): + return current + current = current.parent + return Path.cwd() # Fallback to current directory + +PROJECT_ROOT = find_project_root() + + +class Settings(BaseSettings): + """Application settings with environment variable support.""" + + # API Settings + api_title: str = "OCI GenAI to OpenAI API Gateway" + api_version: str = "1.0.0" + api_prefix: str = "/v1" + api_port: int = 8000 + api_host: str = "0.0.0.0" + debug: bool = False + + # Authentication + api_keys: List[str] = ["sk-oci-genai-default-key"] + + # OCI Settings + oci_config_file: str = "~/.oci/config" + oci_config_profile: str = "DEFAULT" # 支持多个profile,用逗号分隔,例如:DEFAULT,CHICAGO,ASHBURN + oci_auth_type: str = "api_key" # api_key or instance_principal + + # GenAI Service Settings + genai_endpoint: Optional[str] = None + max_tokens: int = 4096 + temperature: float = 0.7 + + # Embedding Settings + embed_truncate: str = "END" # END or START + + # Streaming Settings + enable_streaming: bool = True + stream_chunk_size: int = 1024 + + # Logging + log_level: str = "INFO" + log_requests: bool = False + log_responses: bool = False + log_file: Optional[str] = None + log_file_max_size: int = 10 # MB + log_file_backup_count: int = 5 + + class Config: + # Use absolute path to .env file in project root + env_file = str(PROJECT_ROOT / ".env") + env_file_encoding = "utf-8" + case_sensitive = False + + # Allow reading from environment variables + env_prefix = "" + + def model_post_init(self, __context) -> None: + """Expand OCI config file path.""" + # Expand OCI config file path + config_path = os.path.expanduser(self.oci_config_file) + + # If it's a relative path (starts with ./ or doesn't start with /), resolve it from project root + if not config_path.startswith('/') and not config_path.startswith('~'): + # Remove leading ./ if present + if config_path.startswith('./'): + config_path = config_path[2:] + config_path = str(PROJECT_ROOT / config_path) + + # Update the config_path + self.oci_config_file = config_path + + def get_profiles(self) -> List[str]: + """获取配置的所有 profile 列表""" + return [p.strip() for p in self.oci_config_profile.split(',') if p.strip()] + + +# Global settings instance +settings = Settings() + + +def get_settings() -> Settings: + """Get the global settings instance.""" + return settings diff --git a/src/core/models.py b/src/core/models.py new file mode 100644 index 0000000..efcab11 --- /dev/null +++ b/src/core/models.py @@ -0,0 +1,260 @@ +""" +Model definitions and configurations for OCI Generative AI models. +""" +import logging +import os +from typing import Dict, List, Optional +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class ModelConfig(BaseModel): + """Configuration for a single model.""" + id: str + name: str + type: str # ondemand, dedicated, embedding + provider: str # cohere, meta, openai, etc. 
+ region: Optional[str] = None + compartment_id: Optional[str] = None + endpoint: Optional[str] = None + supports_streaming: bool = True + supports_tools: bool = False + supports_multimodal: bool = False + multimodal_types: List[str] = [] + max_tokens: int = 4096 + context_window: int = 128000 + + +# OCI Generative AI models (dynamically loaded from OCI at startup) +OCI_CHAT_MODELS: Dict[str, ModelConfig] = {} + +OCI_EMBED_MODELS: Dict[str, ModelConfig] = {} + + +def get_all_models() -> List[ModelConfig]: + """Get all available models.""" + return list(OCI_CHAT_MODELS.values()) + list(OCI_EMBED_MODELS.values()) + + +def get_chat_models() -> List[ModelConfig]: + """Get all chat models.""" + return list(OCI_CHAT_MODELS.values()) + + +def get_embed_models() -> List[ModelConfig]: + """Get all embedding models.""" + return list(OCI_EMBED_MODELS.values()) + + +def get_model_config(model_id: str) -> Optional[ModelConfig]: + """Get configuration for a specific model.""" + if model_id in OCI_CHAT_MODELS: + return OCI_CHAT_MODELS[model_id] + if model_id in OCI_EMBED_MODELS: + return OCI_EMBED_MODELS[model_id] + return None + + +def fetch_models_from_oci(compartment_id: Optional[str] = None, region: Optional[str] = None, + config_path: str = "./.oci/config", + profile: str = "DEFAULT") -> Dict[str, Dict[str, ModelConfig]]: + """ + Dynamically fetch available models from OCI Generative AI service. + + If compartment_id or region are not provided, they will be read from the OCI config file. + - compartment_id defaults to 'tenancy' from config + - region defaults to 'region' from config + + Args: + compartment_id: OCI compartment ID (optional, defaults to tenancy from config) + region: OCI region (optional, defaults to region from config) + config_path: Path to OCI config file + profile: OCI config profile name + + Returns: + Dictionary with 'chat' and 'embed' keys containing model configs + """ + try: + import oci + from oci.generative_ai import GenerativeAiClient + + # Load OCI configuration + config = oci.config.from_file( + file_location=os.path.expanduser(config_path), + profile_name=profile + ) + + # Use values from config if not provided + if not region: + region = config.get("region") + logger.info(f"📍 Using region from OCI config: {region}") + + if not compartment_id: + compartment_id = config.get("tenancy") + logger.info(f"📦 Using tenancy as compartment_id: {compartment_id}") + + if not region or not compartment_id: + logger.error("❌ Missing region or compartment_id in OCI config") + return {"chat": {}, "embed": {}} + + # Create GenerativeAiClient (not GenerativeAiInferenceClient) + service_endpoint = f"https://generativeai.{region}.oci.oraclecloud.com" + logger.info(f"🔗 Connecting to OCI GenerativeAI endpoint: {service_endpoint}") + client = GenerativeAiClient(config, service_endpoint=service_endpoint) + + chat_models = {} + embed_models = {} + + # Fetch all models (without capability filter to work with tenancy compartment) + try: + logger.info("🔍 Fetching all models from OCI...") + logger.debug(f" Compartment ID: {compartment_id}") + logger.debug(f" Method: Fetching all models, will filter by capabilities in Python") + + response = client.list_models( + compartment_id=compartment_id + ) + + logger.info(f"✅ Successfully fetched {len(response.data.items)} models from OCI") + + # Filter models by capabilities in Python + for model in response.data.items: + model_id = model.display_name + provider = model_id.split(".")[0] if "." 
in model_id else "unknown" + capabilities = model.capabilities if hasattr(model, 'capabilities') else [] + + logger.debug(f" Processing: {model_id} (capabilities: {capabilities})") + + # Chat models: have CHAT or TEXT_GENERATION capability + if 'CHAT' in capabilities or 'TEXT_GENERATION' in capabilities: + supports_streaming = True # Most models support streaming + supports_tools = provider in ["cohere", "meta"] # These providers support tools + + # Detect multimodal support from capabilities + supports_multimodal = False + multimodal_types = [] + if 'IMAGE' in capabilities or 'VISION' in capabilities: + supports_multimodal = True + multimodal_types.append("image") + + chat_models[model_id] = ModelConfig( + id=model_id, + name=model.display_name, + type="ondemand", + provider=provider, + region=region, + compartment_id=compartment_id, + supports_streaming=supports_streaming, + supports_tools=supports_tools, + supports_multimodal=supports_multimodal, + multimodal_types=multimodal_types, + max_tokens=4096, + context_window=128000 + ) + + # Embedding models: have TEXT_EMBEDDINGS capability + elif 'TEXT_EMBEDDINGS' in capabilities: + embed_models[model_id] = ModelConfig( + id=model_id, + name=model.display_name, + type="embedding", + provider=provider, + region=region, + compartment_id=compartment_id, + supports_streaming=False, + supports_tools=False, + max_tokens=512, + context_window=512 + ) + + logger.info(f"✅ Filtered {len(chat_models)} chat models") + if chat_models: + logger.debug(f" Chat models: {', '.join(list(chat_models.keys())[:5])}{'...' if len(chat_models) > 5 else ''}") + + logger.info(f"✅ Filtered {len(embed_models)} embedding models") + if embed_models: + logger.debug(f" Embed models: {', '.join(embed_models.keys())}") + + except Exception as e: + logger.warning(f"⚠️ Failed to fetch models from OCI") + logger.warning(f" Error: {e}") + if hasattr(e, 'status'): + logger.warning(f" HTTP Status: {e.status}") + if hasattr(e, 'code'): + logger.warning(f" Error Code: {e.code}") + logger.info(f"💡 Tip: Check your OCI credentials and permissions") + + return {"chat": chat_models, "embed": embed_models} + + except Exception as e: + logger.error(f"❌ Failed to initialize OCI client for model discovery") + logger.error(f" Error: {e}") + logger.info("💡 Tip: Check your OCI credentials and permissions") + return {"chat": {}, "embed": {}} + + +def update_models_from_oci(compartment_id: Optional[str] = None, + region: Optional[str] = None, + config_path: str = "./.oci/config", + profile: str = "DEFAULT") -> None: + """ + Update global model dictionaries with models from OCI. + Raises RuntimeError if model fetching fails. + + Priority for configuration values: + 1. Explicitly provided parameters + 2. Environment variables (OCI_COMPARTMENT_ID, OCI_REGION) + 3. 
Values from .oci/config file (tenancy, region) + + Raises: + RuntimeError: If no models can be fetched from OCI + """ + global OCI_CHAT_MODELS, OCI_EMBED_MODELS + + # Priority: explicit params > environment > config file + if not compartment_id: + compartment_id = os.getenv("OCI_COMPARTMENT_ID") + if not region: + region = os.getenv("OCI_REGION") + + # Note: If still not set, fetch_models_from_oci will try to read from config file + logger.info("🚀 Attempting to fetch models from OCI...") + fetched = fetch_models_from_oci(compartment_id, region, config_path, profile) + + # Fail-fast: Require successful model fetching + if not fetched["chat"] and not fetched["embed"]: + error_msg = ( + "❌ Failed to fetch any models from OCI.\n\n" + "Troubleshooting steps:\n" + "1. Verify your OCI credentials are configured correctly:\n" + f" - Config file: {config_path}\n" + f" - Profile: {profile}\n" + " - Run: oci iam region list (to test authentication)\n\n" + "2. Check your OCI permissions:\n" + " - Ensure you have access to Generative AI service\n" + " - Verify compartment_id/tenancy has available models\n\n" + "3. Check network connectivity:\n" + " - Ensure you can reach OCI endpoints\n" + f" - Test region: {region or 'from config file'}\n\n" + "4. Review logs above for detailed error messages" + ) + logger.error(error_msg) + raise RuntimeError( + "Failed to fetch models from OCI. " + "The service cannot start without available models. " + "Check the logs above for troubleshooting guidance." + ) + + # Update global model registries + if fetched["chat"]: + OCI_CHAT_MODELS.clear() + OCI_CHAT_MODELS.update(fetched["chat"]) + logger.info(f"✅ Loaded {len(OCI_CHAT_MODELS)} chat models from OCI") + + if fetched["embed"]: + OCI_EMBED_MODELS.clear() + OCI_EMBED_MODELS.update(fetched["embed"]) + logger.info(f"✅ Loaded {len(OCI_EMBED_MODELS)} embedding models from OCI") + + logger.info(f"✅ Model discovery completed successfully") diff --git a/src/core/oci_client.py b/src/core/oci_client.py new file mode 100644 index 0000000..e56152a --- /dev/null +++ b/src/core/oci_client.py @@ -0,0 +1,361 @@ +""" +OCI Generative AI client wrapper. +""" +import os +import logging +from typing import Optional, AsyncIterator +import oci +from oci.generative_ai_inference import GenerativeAiInferenceClient +from oci.generative_ai_inference.models import ( + ChatDetails, + CohereChatRequest, + GenericChatRequest, + OnDemandServingMode, + DedicatedServingMode, + CohereMessage, + Message, + TextContent, + EmbedTextDetails, +) + +# Try to import multimodal content types +try: + from oci.generative_ai_inference.models import ( + ImageContent, + ImageUrl, + AudioContent, + AudioUrl, + VideoContent, + VideoUrl, + ) + MULTIMODAL_SUPPORTED = True + logger_init = logging.getLogger(__name__) + logger_init.info("OCI SDK multimodal content types available") +except ImportError: + MULTIMODAL_SUPPORTED = False + logger_init = logging.getLogger(__name__) + logger_init.warning("OCI SDK does not support multimodal content types, using dict format as fallback") + +from .config import Settings +from .models import get_model_config, ModelConfig + +logger = logging.getLogger(__name__) + + +def build_multimodal_content(content_list: list) -> list: + """ + Build OCI ChatContent object array from adapted content list. + + Supports both HTTP URLs and Base64 data URIs (data:image/jpeg;base64,...). 
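+
+    For illustration, an adapter output such as (the data URI is truncated and
+    purely illustrative):
+
+        [
+            {"type": "text", "text": "Describe this image"},
+            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
+        ]
+
+    becomes [TextContent(...), ImageContent(...)] when the SDK exposes the
+    multimodal content classes, and is returned unchanged as dicts otherwise.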
+
+    Args:
+        content_list: List of content items from request adapter
+
+    Returns:
+        List of OCI ChatContent objects or dicts (fallback)
+    """
+    if not MULTIMODAL_SUPPORTED:
+        # Fallback: return dict format, OCI SDK might auto-convert
+        return content_list
+
+    oci_contents = []
+    for item in content_list:
+        if not isinstance(item, dict):
+            continue
+
+        item_type = item.get("type")
+
+        if item_type == "text":
+            oci_contents.append(TextContent(text=item.get("text", "")))
+
+        elif item_type == "image_url":
+            image_data = item.get("image_url", {})
+            if "url" in image_data:
+                # ImageUrl accepts both HTTP URLs and data URIs (data:image/jpeg;base64,...)
+                img_url = ImageUrl(url=image_data["url"])
+                # Optional: support 'detail' parameter if provided
+                if "detail" in image_data:
+                    img_url.detail = image_data["detail"]
+                oci_contents.append(ImageContent(image_url=img_url, type="IMAGE"))
+
+        elif item_type == "audio":
+            audio_data = item.get("audio_url", {})
+            if "url" in audio_data:
+                # AudioUrl accepts both HTTP URLs and data URIs (data:audio/wav;base64,...)
+                audio_url = AudioUrl(url=audio_data["url"])
+                oci_contents.append(AudioContent(audio_url=audio_url, type="AUDIO"))
+
+        elif item_type == "video":
+            video_data = item.get("video_url", {})
+            if "url" in video_data:
+                # VideoUrl accepts both HTTP URLs and data URIs (data:video/mp4;base64,...)
+                video_url = VideoUrl(url=video_data["url"])
+                oci_contents.append(VideoContent(video_url=video_url, type="VIDEO"))
+
+    return oci_contents if oci_contents else [TextContent(text="")]
+
+
+class OCIGenAIClient:
+    """Wrapper for OCI Generative AI client."""
+
+    def __init__(self, settings: Settings, profile: Optional[str] = None):
+        """
+        Initialize the OCI GenAI client.
+
+        Args:
+            settings: Application settings
+            profile: Optional OCI config profile name. Defaults to the first profile from settings.
+        """
+        self.settings = settings
+        self.profile = profile or settings.get_profiles()[0]
+        self._client: Optional[GenerativeAiInferenceClient] = None
+        self._config: Optional[dict] = None
+        self._region: Optional[str] = None
+        self._compartment_id: Optional[str] = None
+
+    def _get_config(self) -> dict:
+        """Get OCI configuration."""
+        if self._config is None:
+            if self.settings.oci_auth_type == "instance_principal":
+                signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
+                self._config = {"signer": signer}
+            else:
+                config_path = os.path.expanduser(self.settings.oci_config_file)
+                self._config = oci.config.from_file(
+                    file_location=config_path,
+                    profile_name=self.profile
+                )
+
+            # Read region and compartment_id from the loaded configuration
+            if self._region is None:
+                self._region = self._config.get("region")
+            if self._compartment_id is None:
+                self._compartment_id = self._config.get("tenancy")
+
+        return self._config
+
+    @property
+    def region(self) -> Optional[str]:
+        """Region of the currently loaded configuration."""
+        if self._region is None and self._config is None:
+            self._get_config()
+        return self._region
+
+    @property
+    def compartment_id(self) -> Optional[str]:
+        """Compartment ID (tenancy) of the currently loaded configuration."""
+        if self._compartment_id is None and self._config is None:
+            self._get_config()
+        return self._compartment_id
+
+    def _get_client(self) -> GenerativeAiInferenceClient:
+        """Get or create OCI Generative AI Inference client with correct endpoint."""
+        config = self._get_config()
+
+        # Use INFERENCE endpoint (not management endpoint)
+        # Official format: https://inference.generativeai.{region}.oci.oraclecloud.com
+        inference_endpoint = f"https://inference.generativeai.{self.region}.oci.oraclecloud.com"
+
+        if isinstance(config, dict) and "signer" in
config: + # For instance principal + client = GenerativeAiInferenceClient( + config={}, + service_endpoint=inference_endpoint, + **config + ) + return client + + # For API key authentication + client = GenerativeAiInferenceClient( + config=config, + service_endpoint=inference_endpoint, + retry_strategy=oci.retry.NoneRetryStrategy(), + timeout=(10, 240) + ) + + return client + + def chat( + self, + model_id: str, + messages: list, + temperature: float = 0.7, + max_tokens: int = 1024, + top_p: float = 1.0, + stream: bool = False, + tools: Optional[list] = None, + ): + """Send a chat completion request to OCI GenAI.""" + model_config = get_model_config(model_id) + if not model_config: + raise ValueError(f"Unsupported model: {model_id}") + + if not self.compartment_id: + raise ValueError("Compartment ID is required") + + client = self._get_client() + + # Prepare serving mode + if model_config.type == "dedicated" and model_config.endpoint: + serving_mode = DedicatedServingMode(endpoint_id=model_config.endpoint) + else: + serving_mode = OnDemandServingMode(model_id=model_id) + + # Convert messages based on provider + if model_config.provider == "cohere": + chat_request = self._build_cohere_request( + messages, temperature, max_tokens, top_p, tools, stream + ) + elif model_config.provider in ["meta", "xai", "google", "openai"]: + chat_request = self._build_generic_request( + messages, temperature, max_tokens, top_p, tools, model_config.provider, stream + ) + else: + raise ValueError(f"Unsupported provider: {model_config.provider}") + + chat_details = ChatDetails( + serving_mode=serving_mode, + compartment_id=self.compartment_id, + chat_request=chat_request, + ) + + logger.debug(f"Sending chat request to OCI GenAI: {model_id}") + response = client.chat(chat_details) + return response + + def _build_cohere_request( + self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], stream: bool = False + ) -> CohereChatRequest: + """Build Cohere chat request. + + Note: Cohere models only support text content, not multimodal. 
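+
+        As a rough sketch of the mapping this method performs (message contents
+        are placeholders):
+
+            {"role": "system", "content": "Be brief"}   -> preamble_override
+            {"role": "assistant", "content": "Hi"}      -> chat_history (CHATBOT)
+            {"role": "user", "content": "Hello"}        -> message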
+ """ + # Convert messages to Cohere format + chat_history = [] + message = None + + for msg in messages: + role = msg["role"] + content = msg["content"] + + # Extract text from multimodal content + if isinstance(content, list): + # Extract text parts only + text_parts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + text_parts.append(item.get("text", "")) + content = " ".join(text_parts) if text_parts else "" + + if role == "system": + # Cohere uses preamble for system messages + continue + elif role == "user": + message = content + elif role == "assistant": + chat_history.append( + CohereMessage(role="CHATBOT", message=content) + ) + elif role == "tool": + # Handle tool responses if needed + pass + + # Get preamble from system messages + preamble_override = None + for msg in messages: + if msg["role"] == "system": + preamble_override = msg["content"] + break + + return CohereChatRequest( + message=message, + chat_history=chat_history if chat_history else None, + preamble_override=preamble_override, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + is_stream=stream, + ) + + def _build_generic_request( + self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], provider: str, stream: bool = False + ) -> GenericChatRequest: + """Build Generic chat request for Llama and other models.""" + # Convert messages to Generic format + generic_messages = [] + for msg in messages: + role = msg["role"] + content = msg["content"] + + # Handle multimodal content + if isinstance(content, list): + # Build OCI ChatContent objects from multimodal content + oci_contents = build_multimodal_content(content) + else: + # Simple text content + if MULTIMODAL_SUPPORTED: + oci_contents = [TextContent(text=content)] + else: + # Fallback: use dict format + oci_contents = [{"type": "text", "text": content}] + + if role == "user": + oci_role = "USER" + elif role in ["assistant", "model"]: + oci_role = "ASSISTANT" + elif role == "system": + oci_role = "SYSTEM" + else: + oci_role = role.upper() + + # Create Message with role and content objects + logger.debug(f"Creating message with role: {oci_role}, provider: {provider}, original role: {role}") + + generic_messages.append( + Message( + role=oci_role, + content=oci_contents + ) + ) + + return GenericChatRequest( + messages=generic_messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + is_stream=stream, + ) + + def embed( + self, + model_id: str, + texts: list, + truncate: str = "END", + ): + """Generate embeddings using OCI GenAI.""" + model_config = get_model_config(model_id) + if not model_config or model_config.type != "embedding": + raise ValueError(f"Invalid embedding model: {model_id}") + + if not self.compartment_id: + raise ValueError("Compartment ID is required") + + client = self._get_client() + + serving_mode = OnDemandServingMode( + serving_type="ON_DEMAND", + model_id=model_id + ) + + embed_details = EmbedTextDetails( + serving_mode=serving_mode, + compartment_id=self.compartment_id, + inputs=texts, + truncate=truncate, + is_echo=False, + input_type="SEARCH_QUERY", + ) + + logger.debug(f"Sending embed request to OCI GenAI: {model_id}") + response = client.embed_text(embed_details) + return response diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..475a437 --- /dev/null +++ b/src/main.py @@ -0,0 +1,274 @@ +""" +Main FastAPI application for OCI Generative AI to OpenAI API Gateway. 
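+
+Once running, the gateway speaks the OpenAI wire format, so any OpenAI-compatible
+client can point at it. A minimal sketch, assuming the default host/port and an
+API key you have configured (both values below are placeholders):
+
+    from openai import OpenAI
+
+    client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-...")
+    print(client.models.list())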
+""" +import logging +import sys +import os +from contextlib import asynccontextmanager +from logging.handlers import RotatingFileHandler + +from fastapi import FastAPI, Request, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError + +from oci.exceptions import ServiceError + +from core.config import get_settings +from core.models import update_models_from_oci +from api.routers import models, chat, embeddings +from api.schemas import ErrorResponse, ErrorDetail +from api.error_handler import OCIErrorHandler +from api.exceptions import ModelNotFoundException, InvalidModelTypeException + + +# Configure logging +def setup_logging(): + """Setup logging configuration.""" + settings = get_settings() + + # Create handlers list + handlers = [ + logging.StreamHandler(sys.stdout) + ] + + # Add file handler if log_file is configured + if settings.log_file: + log_dir = os.path.dirname(settings.log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + + file_handler = RotatingFileHandler( + settings.log_file, + maxBytes=settings.log_file_max_size * 1024 * 1024, # Convert MB to bytes + backupCount=settings.log_file_backup_count, + encoding='utf-8' + ) + handlers.append(file_handler) + + logging.basicConfig( + level=getattr(logging, settings.log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=handlers + ) + + +setup_logging() +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler.""" + logger.info("=" * 60) + logger.info("Starting OCI GenAI to OpenAI API Gateway") + logger.info("=" * 60) + + settings = get_settings() + logger.info(f"API Version: {settings.api_version}") + logger.info(f"API Prefix: {settings.api_prefix}") + logger.info(f"Debug Mode: {settings.debug}") + logger.info(f"OCI Config: {settings.oci_config_file}") + + profiles = settings.get_profiles() + logger.info(f"OCI Profiles: {', '.join(profiles)}") + + try: + # Fetch models from OCI (fails fast if unable to fetch) + # 使用第一个 profile 进行模型发现 + update_models_from_oci( + config_path=settings.oci_config_file, + profile=profiles[0] if profiles else "DEFAULT" + ) + + logger.info("=" * 60) + logger.info("✅ Startup completed successfully") + logger.info(f"Server listening on {settings.api_host}:{settings.api_port}") + logger.info("=" * 60) + + except RuntimeError as e: + logger.error("=" * 60) + logger.error("❌ STARTUP FAILED") + logger.error("=" * 60) + logger.error(f"Reason: {str(e)}") + logger.error("") + logger.error("The service cannot start without available models from OCI.") + logger.error("Please review the troubleshooting steps above and fix the issue.") + logger.error("=" * 60) + raise + except Exception as e: + logger.error("=" * 60) + logger.error("❌ UNEXPECTED STARTUP ERROR") + logger.error("=" * 60) + logger.error(f"Error type: {type(e).__name__}") + logger.error(f"Error message: {str(e)}") + logger.error("=" * 60) + raise + + yield + + logger.info("=" * 60) + logger.info("Shutting down OCI GenAI to OpenAI API Gateway") + logger.info("=" * 60) + + +# Create FastAPI app +settings = get_settings() + +app = FastAPI( + title=settings.api_title, + version=settings.api_version, + description="OpenAI-compatible REST API for Oracle Cloud Infrastructure Generative AI Service", + lifespan=lifespan, + docs_url="/docs" if settings.debug else None, + redoc_url="/redoc" if 
settings.debug else None,
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+# Exception handlers
+@app.exception_handler(ModelNotFoundException)
+async def model_not_found_handler(request: Request, exc: ModelNotFoundException):
+    """Handle model not found exceptions with OpenAI-compatible format."""
+    error = ErrorDetail(
+        message=exc.detail,
+        type=exc.error_type,
+        code=exc.error_code
+    )
+
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=ErrorResponse(error=error).dict()
+    )
+
+
+@app.exception_handler(InvalidModelTypeException)
+async def invalid_model_type_handler(request: Request, exc: InvalidModelTypeException):
+    """Handle invalid model type exceptions with OpenAI-compatible format."""
+    error = ErrorDetail(
+        message=exc.detail,
+        type=exc.error_type,
+        code=exc.error_code
+    )
+
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=ErrorResponse(error=error).dict()
+    )
+
+
+@app.exception_handler(ServiceError)
+async def oci_service_error_handler(request: Request, exc: ServiceError):
+    """Handle OCI SDK ServiceError exceptions."""
+    # Use OCIErrorHandler to sanitize the error and filter sensitive information
+    error_response = OCIErrorHandler.sanitize_oci_error(exc)
+
+    # Determine the HTTP status code (prefer the status returned by OCI)
+    status_code = exc.status if 400 <= exc.status < 600 else 500
+
+    return JSONResponse(
+        status_code=status_code,
+        content=error_response.dict()
+    )
+
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    """Handle HTTP exceptions with sensitive information filtering."""
+    # Filter sensitive information that may be present in the HTTPException detail
+    filtered_detail = OCIErrorHandler.filter_sensitive_info(str(exc.detail))
+
+    error = ErrorDetail(
+        message=filtered_detail,
+        type="invalid_request_error",
+        code=f"http_{exc.status_code}"
+    )
+
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=ErrorResponse(error=error).dict()
+    )
+
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Handle request validation errors."""
+    logger.error(f"Validation error: {exc}")
+    error = ErrorDetail(
+        message=str(exc),
+        type="invalid_request_error",
+        code="validation_error"
+    )
+    return JSONResponse(
+        status_code=400,
+        content=ErrorResponse(error=error).dict()
+    )
+
+
+@app.exception_handler(Exception)
+async def general_exception_handler(request: Request, exc: Exception):
+    """Handle general exceptions without leaking internal details."""
+    # The full error is already written to the log; clients only get a generic message
+    logger.error(f"Unexpected error: {exc}", exc_info=True)
+
+    error = ErrorDetail(
+        message="An unexpected error occurred",  # Do not expose internal error details
+        type="server_error",
+        code="internal_error"
+    )
+    return JSONResponse(
+        status_code=500,
+        content=ErrorResponse(error=error).dict()
+    )
+
+
+# Include routers
+app.include_router(models.router, prefix=settings.api_prefix)
+app.include_router(chat.router, prefix=settings.api_prefix)
+app.include_router(embeddings.router, prefix=settings.api_prefix)
+
+
+@app.get("/")
+async def root():
+    """Root endpoint."""
+    return {
+        "name": settings.api_title,
+        "version": settings.api_version,
+        "description": "OpenAI-compatible REST API for OCI Generative AI",
+        "endpoints": {
+            "models": f"{settings.api_prefix}/models",
+            "chat": f"{settings.api_prefix}/chat/completions",
+            "embeddings":
f"{settings.api_prefix}/embeddings" + } + } + + +@app.get("/health") +async def health(): + """Health check endpoint.""" + return { + "status": "healthy", + "service": "oci-genai-gateway" + } + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "main:app", + host=settings.api_host, + port=settings.api_port, + reload=settings.debug, + log_level=settings.log_level.lower() + )