First commit
All checks were successful
Build and Push OCI GenAI Gateway Docker Image / docker-build-push (push) Successful in 32m3s

2025-12-09 14:44:09 +08:00
commit 42222744c7
27 changed files with 3081 additions and 0 deletions

1
src/core/__init__.py Normal file

@@ -0,0 +1 @@
"""Core module for OCI GenAI Gateway."""


@@ -0,0 +1,70 @@
"""
简单的 OCI 客户端管理器,支持多 profile 轮询负载均衡
"""
import logging
from typing import List, Dict
from threading import Lock
from .config import Settings, get_settings
from .oci_client import OCIGenAIClient
logger = logging.getLogger(__name__)
class OCIClientManager:
"""OCI 客户端管理器,支持轮询负载均衡和客户端连接池"""
def __init__(self, settings: Settings = None):
self.settings = settings or get_settings()
self.profiles = self.settings.get_profiles()
self.current_index = 0
self.lock = Lock()
# 预创建客户端连接池
self._clients: Dict[str, OCIGenAIClient] = {}
logger.info(f"初始化 OCI 客户端管理器,共 {len(self.profiles)} 个 profiles: {self.profiles}")
for profile in self.profiles:
try:
self._clients[profile] = OCIGenAIClient(self.settings, profile)
logger.info(f"✓ 已创建客户端实例: {profile}")
except Exception as e:
logger.error(f"✗ 创建客户端实例失败 [{profile}]: {e}")
raise
def get_client(self) -> OCIGenAIClient:
"""
获取下一个客户端(轮询策略)
采用 round-robin 算法从预创建的客户端连接池中选择客户端实例。
此方法是线程安全的。
Returns:
OCIGenAIClient: 预创建的 OCI 客户端实例
Note:
客户端实例在管理器初始化时预创建,此方法不会创建新实例。
"""
with self.lock:
# 如果只有一个 profile直接返回
if len(self.profiles) == 1:
return self._clients[self.profiles[0]]
# 轮询选择 profile
profile = self.profiles[self.current_index]
self.current_index = (self.current_index + 1) % len(self.profiles)
logger.debug(f"选择 profile: {profile} (round-robin)")
return self._clients[profile]
# 全局客户端管理器实例
_client_manager = None
def get_client_manager() -> OCIClientManager:
"""获取全局客户端管理器实例"""
global _client_manager
if _client_manager is None:
_client_manager = OCIClientManager()
return _client_manager
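For illustration, a minimal usage sketch of the round-robin behavior. The import path src.core.client_manager is an assumption (this file's name is truncated in the diff above), the profile names are hypothetical, and the project root is assumed to be on PYTHONPATH:

# Usage sketch (assumed module path; profile names are hypothetical).
from src.core.client_manager import get_client_manager

manager = get_client_manager()
# With OCI_CONFIG_PROFILE=DEFAULT,CHICAGO the pool holds two clients and
# successive calls alternate between them:
for _ in range(4):
    client = manager.get_client()
    print(client.profile)  # DEFAULT, CHICAGO, DEFAULT, CHICAGO

No OCI call is made here: OCIGenAIClient defers loading its config until first use, so the pool can be built without credentials being exercised.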

100
src/core/config.py Normal file

@@ -0,0 +1,100 @@
"""
Configuration module for OCI Generative AI to OpenAI API Gateway.
"""
import os
import logging
from pathlib import Path
from typing import Optional, List
from pydantic_settings import BaseSettings
logger = logging.getLogger(__name__)
# Find project root directory (where .env should be)
def find_project_root() -> Path:
"""Find the project root directory by looking for .env or requirements.txt."""
current = Path(__file__).resolve().parent # Start from src/core/
# Go up until we find project root markers
while current != current.parent:
if (current / ".env").exists() or (current / "requirements.txt").exists():
return current
current = current.parent
return Path.cwd() # Fallback to current directory
PROJECT_ROOT = find_project_root()
class Settings(BaseSettings):
"""Application settings with environment variable support."""
# API Settings
api_title: str = "OCI GenAI to OpenAI API Gateway"
api_version: str = "1.0.0"
api_prefix: str = "/v1"
api_port: int = 8000
api_host: str = "0.0.0.0"
debug: bool = False
# Authentication
api_keys: List[str] = ["sk-oci-genai-default-key"]
# OCI Settings
oci_config_file: str = "~/.oci/config"
oci_config_profile: str = "DEFAULT" # 支持多个profile用逗号分隔例如DEFAULT,CHICAGO,ASHBURN
oci_auth_type: str = "api_key" # api_key or instance_principal
# GenAI Service Settings
genai_endpoint: Optional[str] = None
max_tokens: int = 4096
temperature: float = 0.7
# Embedding Settings
embed_truncate: str = "END" # END or START
# Streaming Settings
enable_streaming: bool = True
stream_chunk_size: int = 1024
# Logging
log_level: str = "INFO"
log_requests: bool = False
log_responses: bool = False
log_file: Optional[str] = None
log_file_max_size: int = 10 # MB
log_file_backup_count: int = 5
class Config:
# Use absolute path to .env file in project root
env_file = str(PROJECT_ROOT / ".env")
env_file_encoding = "utf-8"
case_sensitive = False
# Allow reading from environment variables
env_prefix = ""
def model_post_init(self, __context) -> None:
"""Expand OCI config file path."""
# Expand OCI config file path
config_path = os.path.expanduser(self.oci_config_file)
# If it's a relative path (starts with ./ or doesn't start with /), resolve it from project root
if not config_path.startswith('/') and not config_path.startswith('~'):
# Remove leading ./ if present
if config_path.startswith('./'):
config_path = config_path[2:]
config_path = str(PROJECT_ROOT / config_path)
# Update the config_path
self.oci_config_file = config_path
def get_profiles(self) -> List[str]:
"""获取配置的所有 profile 列表"""
return [p.strip() for p in self.oci_config_profile.split(',') if p.strip()]
# Global settings instance
settings = Settings()
def get_settings() -> Settings:
"""Get the global settings instance."""
return settings
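A quick sketch of how the comma-separated profile setting resolves. The environment value below is a hypothetical example; pydantic-settings matches OCI_CONFIG_PROFILE to oci_config_profile because case_sensitive is False:

# Sketch: multi-profile resolution (hypothetical env value, root on PYTHONPATH).
import os

os.environ["OCI_CONFIG_PROFILE"] = "DEFAULT, CHICAGO, ASHBURN"  # hypothetical

from src.core.config import Settings

print(Settings().get_profiles())  # ['DEFAULT', 'CHICAGO', 'ASHBURN'] -- whitespace stripped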

260
src/core/models.py Normal file

@@ -0,0 +1,260 @@
"""
Model definitions and configurations for OCI Generative AI models.
"""
import logging
import os
from typing import Dict, List, Optional
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class ModelConfig(BaseModel):
"""Configuration for a single model."""
id: str
name: str
type: str # ondemand, dedicated, embedding
provider: str # cohere, meta, openai, etc.
region: Optional[str] = None
compartment_id: Optional[str] = None
endpoint: Optional[str] = None
supports_streaming: bool = True
supports_tools: bool = False
supports_multimodal: bool = False
multimodal_types: List[str] = []
max_tokens: int = 4096
context_window: int = 128000
# OCI Generative AI models (dynamically loaded from OCI at startup)
OCI_CHAT_MODELS: Dict[str, ModelConfig] = {}
OCI_EMBED_MODELS: Dict[str, ModelConfig] = {}
def get_all_models() -> List[ModelConfig]:
"""Get all available models."""
return list(OCI_CHAT_MODELS.values()) + list(OCI_EMBED_MODELS.values())
def get_chat_models() -> List[ModelConfig]:
"""Get all chat models."""
return list(OCI_CHAT_MODELS.values())
def get_embed_models() -> List[ModelConfig]:
"""Get all embedding models."""
return list(OCI_EMBED_MODELS.values())
def get_model_config(model_id: str) -> Optional[ModelConfig]:
"""Get configuration for a specific model."""
if model_id in OCI_CHAT_MODELS:
return OCI_CHAT_MODELS[model_id]
if model_id in OCI_EMBED_MODELS:
return OCI_EMBED_MODELS[model_id]
return None
def fetch_models_from_oci(compartment_id: Optional[str] = None, region: Optional[str] = None,
config_path: str = "./.oci/config",
profile: str = "DEFAULT") -> Dict[str, Dict[str, ModelConfig]]:
"""
Dynamically fetch available models from OCI Generative AI service.
If compartment_id or region are not provided, they will be read from the OCI config file.
- compartment_id defaults to 'tenancy' from config
- region defaults to 'region' from config
Args:
compartment_id: OCI compartment ID (optional, defaults to tenancy from config)
region: OCI region (optional, defaults to region from config)
config_path: Path to OCI config file
profile: OCI config profile name
Returns:
Dictionary with 'chat' and 'embed' keys containing model configs
"""
try:
import oci
from oci.generative_ai import GenerativeAiClient
# Load OCI configuration
config = oci.config.from_file(
file_location=os.path.expanduser(config_path),
profile_name=profile
)
# Use values from config if not provided
if not region:
region = config.get("region")
logger.info(f"📍 Using region from OCI config: {region}")
if not compartment_id:
compartment_id = config.get("tenancy")
logger.info(f"📦 Using tenancy as compartment_id: {compartment_id}")
if not region or not compartment_id:
logger.error("❌ Missing region or compartment_id in OCI config")
return {"chat": {}, "embed": {}}
# Create GenerativeAiClient (not GenerativeAiInferenceClient)
service_endpoint = f"https://generativeai.{region}.oci.oraclecloud.com"
logger.info(f"🔗 Connecting to OCI GenerativeAI endpoint: {service_endpoint}")
client = GenerativeAiClient(config, service_endpoint=service_endpoint)
chat_models = {}
embed_models = {}
# Fetch all models (without capability filter to work with tenancy compartment)
try:
logger.info("🔍 Fetching all models from OCI...")
logger.debug(f" Compartment ID: {compartment_id}")
logger.debug(f" Method: Fetching all models, will filter by capabilities in Python")
response = client.list_models(
compartment_id=compartment_id
)
logger.info(f"✅ Successfully fetched {len(response.data.items)} models from OCI")
# Filter models by capabilities in Python
for model in response.data.items:
model_id = model.display_name
provider = model_id.split(".")[0] if "." in model_id else "unknown"
capabilities = model.capabilities if hasattr(model, 'capabilities') else []
logger.debug(f" Processing: {model_id} (capabilities: {capabilities})")
# Chat models: have CHAT or TEXT_GENERATION capability
if 'CHAT' in capabilities or 'TEXT_GENERATION' in capabilities:
supports_streaming = True # Most models support streaming
supports_tools = provider in ["cohere", "meta"] # These providers support tools
# Detect multimodal support from capabilities
supports_multimodal = False
multimodal_types = []
if 'IMAGE' in capabilities or 'VISION' in capabilities:
supports_multimodal = True
multimodal_types.append("image")
chat_models[model_id] = ModelConfig(
id=model_id,
name=model.display_name,
type="ondemand",
provider=provider,
region=region,
compartment_id=compartment_id,
supports_streaming=supports_streaming,
supports_tools=supports_tools,
supports_multimodal=supports_multimodal,
multimodal_types=multimodal_types,
max_tokens=4096,
context_window=128000
)
# Embedding models: have TEXT_EMBEDDINGS capability
elif 'TEXT_EMBEDDINGS' in capabilities:
embed_models[model_id] = ModelConfig(
id=model_id,
name=model.display_name,
type="embedding",
provider=provider,
region=region,
compartment_id=compartment_id,
supports_streaming=False,
supports_tools=False,
max_tokens=512,
context_window=512
)
logger.info(f"✅ Filtered {len(chat_models)} chat models")
if chat_models:
logger.debug(f" Chat models: {', '.join(list(chat_models.keys())[:5])}{'...' if len(chat_models) > 5 else ''}")
logger.info(f"✅ Filtered {len(embed_models)} embedding models")
if embed_models:
logger.debug(f" Embed models: {', '.join(embed_models.keys())}")
except Exception as e:
logger.warning(f"⚠️ Failed to fetch models from OCI")
logger.warning(f" Error: {e}")
if hasattr(e, 'status'):
logger.warning(f" HTTP Status: {e.status}")
if hasattr(e, 'code'):
logger.warning(f" Error Code: {e.code}")
logger.info(f"💡 Tip: Check your OCI credentials and permissions")
return {"chat": chat_models, "embed": embed_models}
except Exception as e:
logger.error(f"❌ Failed to initialize OCI client for model discovery")
logger.error(f" Error: {e}")
logger.info("💡 Tip: Check your OCI credentials and permissions")
return {"chat": {}, "embed": {}}
def update_models_from_oci(compartment_id: Optional[str] = None,
region: Optional[str] = None,
config_path: str = "./.oci/config",
profile: str = "DEFAULT") -> None:
"""
Update global model dictionaries with models from OCI.
Raises RuntimeError if model fetching fails.
Priority for configuration values:
1. Explicitly provided parameters
2. Environment variables (OCI_COMPARTMENT_ID, OCI_REGION)
3. Values from .oci/config file (tenancy, region)
Raises:
RuntimeError: If no models can be fetched from OCI
"""
global OCI_CHAT_MODELS, OCI_EMBED_MODELS
# Priority: explicit params > environment > config file
if not compartment_id:
compartment_id = os.getenv("OCI_COMPARTMENT_ID")
if not region:
region = os.getenv("OCI_REGION")
# Note: If still not set, fetch_models_from_oci will try to read from config file
logger.info("🚀 Attempting to fetch models from OCI...")
fetched = fetch_models_from_oci(compartment_id, region, config_path, profile)
# Fail-fast: Require successful model fetching
if not fetched["chat"] and not fetched["embed"]:
error_msg = (
"❌ Failed to fetch any models from OCI.\n\n"
"Troubleshooting steps:\n"
"1. Verify your OCI credentials are configured correctly:\n"
f" - Config file: {config_path}\n"
f" - Profile: {profile}\n"
" - Run: oci iam region list (to test authentication)\n\n"
"2. Check your OCI permissions:\n"
" - Ensure you have access to Generative AI service\n"
" - Verify compartment_id/tenancy has available models\n\n"
"3. Check network connectivity:\n"
" - Ensure you can reach OCI endpoints\n"
f" - Test region: {region or 'from config file'}\n\n"
"4. Review logs above for detailed error messages"
)
logger.error(error_msg)
raise RuntimeError(
"Failed to fetch models from OCI. "
"The service cannot start without available models. "
"Check the logs above for troubleshooting guidance."
)
# Update global model registries
if fetched["chat"]:
OCI_CHAT_MODELS.clear()
OCI_CHAT_MODELS.update(fetched["chat"])
logger.info(f"✅ Loaded {len(OCI_CHAT_MODELS)} chat models from OCI")
if fetched["embed"]:
OCI_EMBED_MODELS.clear()
OCI_EMBED_MODELS.update(fetched["embed"])
logger.info(f"✅ Loaded {len(OCI_EMBED_MODELS)} embedding models from OCI")
logger.info(f"✅ Model discovery completed successfully")
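A sketch of the intended startup flow, wiring discovery into the global registries. It assumes valid credentials in ~/.oci/config; the looked-up model id is illustrative and may not exist in a given tenancy:

# Startup sketch (requires working OCI credentials; model id is hypothetical).
from src.core.models import update_models_from_oci, get_chat_models, get_model_config

# Fails fast with RuntimeError if no models can be fetched.
update_models_from_oci(config_path="~/.oci/config", profile="DEFAULT")

for m in get_chat_models():
    print(m.id, m.provider, "tools" if m.supports_tools else "")

cfg = get_model_config("meta.llama-3.1-70b-instruct")  # hypothetical model id
if cfg:
    print(cfg.context_window)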

361
src/core/oci_client.py Normal file

@@ -0,0 +1,361 @@
"""
OCI Generative AI client wrapper.
"""
import os
import logging
from typing import Optional, AsyncIterator
import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import (
ChatDetails,
CohereChatRequest,
GenericChatRequest,
OnDemandServingMode,
DedicatedServingMode,
CohereMessage,
Message,
TextContent,
EmbedTextDetails,
)
# Try to import multimodal content types
try:
from oci.generative_ai_inference.models import (
ImageContent,
ImageUrl,
AudioContent,
AudioUrl,
VideoContent,
VideoUrl,
)
MULTIMODAL_SUPPORTED = True
logger_init = logging.getLogger(__name__)
logger_init.info("OCI SDK multimodal content types available")
except ImportError:
MULTIMODAL_SUPPORTED = False
logger_init = logging.getLogger(__name__)
logger_init.warning("OCI SDK does not support multimodal content types, using dict format as fallback")
from .config import Settings
from .models import get_model_config, ModelConfig
logger = logging.getLogger(__name__)
def build_multimodal_content(content_list: list) -> list:
"""
Build OCI ChatContent object array from adapted content list.
Supports both HTTP URLs and Base64 data URIs (data:image/jpeg;base64,...).
Args:
content_list: List of content items from request adapter
Returns:
List of OCI ChatContent objects or dicts (fallback)
"""
if not MULTIMODAL_SUPPORTED:
# Fallback: return dict format, OCI SDK might auto-convert
return content_list
oci_contents = []
for item in content_list:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
oci_contents.append(TextContent(text=item.get("text", "")))
elif item_type == "image_url":
image_data = item.get("image_url", {})
if "url" in image_data:
# ImageUrl accepts both HTTP URLs and data URIs (data:image/jpeg;base64,...)
img_url = ImageUrl(url=image_data["url"])
# Optional: support 'detail' parameter if provided
if "detail" in image_data:
img_url.detail = image_data["detail"]
oci_contents.append(ImageContent(image_url=img_url, type="IMAGE"))
elif item_type == "audio":
audio_data = item.get("audio_url", {})
if "url" in audio_data:
# AudioUrl accepts both HTTP URLs and data URIs (data:audio/wav;base64,...)
audio_url = AudioUrl(url=audio_data["url"])
oci_contents.append(AudioContent(audio_url=audio_url, type="AUDIO"))
elif item_type == "video":
video_data = item.get("video_url", {})
if "url" in video_data:
# VideoUrl accepts both HTTP URLs and data URIs (data:video/mp4;base64,...)
video_url = VideoUrl(url=video_data["url"])
oci_contents.append(VideoContent(video_url=video_url, type="VIDEO"))
return oci_contents if oci_contents else [TextContent(text="")]
class OCIGenAIClient:
"""Wrapper for OCI Generative AI client."""
def __init__(self, settings: Settings, profile: Optional[str] = None):
"""
初始化 OCI GenAI 客户端
Args:
settings: 应用设置
profile: 可选的 OCI 配置 profile 名称。如果未提供,使用 settings 中的第一个 profile
"""
self.settings = settings
self.profile = profile or settings.get_profiles()[0]
self._client: Optional[GenerativeAiInferenceClient] = None
self._config: Optional[oci.config.Config] = None
self._region: Optional[str] = None
self._compartment_id: Optional[str] = None
def _get_config(self) -> dict:
"""Get OCI configuration."""
if self._config is None:
if self.settings.oci_auth_type == "instance_principal":
signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
self._config = {"signer": signer}
else:
config_path = os.path.expanduser(self.settings.oci_config_file)
self._config = oci.config.from_file(
file_location=config_path,
profile_name=self.profile
)
# 从配置中读取 region 和 compartment_id
if self._region is None:
self._region = self._config.get("region")
if self._compartment_id is None:
self._compartment_id = self._config.get("tenancy")
return self._config
@property
def region(self) -> Optional[str]:
"""获取当前配置的区域"""
if self._region is None and self._config is None:
self._get_config()
return self._region
@property
def compartment_id(self) -> Optional[str]:
"""获取当前配置的 compartment ID"""
if self._compartment_id is None and self._config is None:
self._get_config()
return self._compartment_id
def _get_client(self) -> GenerativeAiInferenceClient:
"""Get or create OCI Generative AI Inference client with correct endpoint."""
config = self._get_config()
# Use INFERENCE endpoint (not management endpoint)
# Official format: https://inference.generativeai.{region}.oci.oraclecloud.com
inference_endpoint = f"https://inference.generativeai.{self.region}.oci.oraclecloud.com"
if isinstance(config, dict) and "signer" in config:
# For instance principal
client = GenerativeAiInferenceClient(
config={},
service_endpoint=inference_endpoint,
**config
)
return client
# For API key authentication
client = GenerativeAiInferenceClient(
config=config,
service_endpoint=inference_endpoint,
retry_strategy=oci.retry.NoneRetryStrategy(),
timeout=(10, 240)
)
return client
def chat(
self,
model_id: str,
messages: list,
temperature: float = 0.7,
max_tokens: int = 1024,
top_p: float = 1.0,
stream: bool = False,
tools: Optional[list] = None,
):
"""Send a chat completion request to OCI GenAI."""
model_config = get_model_config(model_id)
if not model_config:
raise ValueError(f"Unsupported model: {model_id}")
if not self.compartment_id:
raise ValueError("Compartment ID is required")
client = self._get_client()
# Prepare serving mode
if model_config.type == "dedicated" and model_config.endpoint:
serving_mode = DedicatedServingMode(endpoint_id=model_config.endpoint)
else:
serving_mode = OnDemandServingMode(model_id=model_id)
# Convert messages based on provider
if model_config.provider == "cohere":
chat_request = self._build_cohere_request(
messages, temperature, max_tokens, top_p, tools, stream
)
elif model_config.provider in ["meta", "xai", "google", "openai"]:
chat_request = self._build_generic_request(
messages, temperature, max_tokens, top_p, tools, model_config.provider, stream
)
else:
raise ValueError(f"Unsupported provider: {model_config.provider}")
chat_details = ChatDetails(
serving_mode=serving_mode,
compartment_id=self.compartment_id,
chat_request=chat_request,
)
logger.debug(f"Sending chat request to OCI GenAI: {model_id}")
response = client.chat(chat_details)
return response
def _build_cohere_request(
self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], stream: bool = False
) -> CohereChatRequest:
"""Build Cohere chat request.
Note: Cohere models only support text content, not multimodal.
"""
# Convert messages to Cohere format
chat_history = []
message = None
for msg in messages:
role = msg["role"]
content = msg["content"]
# Extract text from multimodal content
if isinstance(content, list):
# Extract text parts only
text_parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
content = " ".join(text_parts) if text_parts else ""
if role == "system":
# Cohere uses preamble for system messages
continue
elif role == "user":
message = content
elif role == "assistant":
chat_history.append(
CohereMessage(role="CHATBOT", message=content)
)
elif role == "tool":
# Handle tool responses if needed
pass
# Get preamble from system messages
preamble_override = None
for msg in messages:
if msg["role"] == "system":
preamble_override = msg["content"]
break
return CohereChatRequest(
message=message,
chat_history=chat_history if chat_history else None,
preamble_override=preamble_override,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
is_stream=stream,
)
def _build_generic_request(
self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], provider: str, stream: bool = False
) -> GenericChatRequest:
"""Build Generic chat request for Llama and other models."""
# Convert messages to Generic format
generic_messages = []
for msg in messages:
role = msg["role"]
content = msg["content"]
# Handle multimodal content
if isinstance(content, list):
# Build OCI ChatContent objects from multimodal content
oci_contents = build_multimodal_content(content)
else:
# Simple text content
if MULTIMODAL_SUPPORTED:
oci_contents = [TextContent(text=content)]
else:
# Fallback: use dict format
oci_contents = [{"type": "text", "text": content}]
if role == "user":
oci_role = "USER"
elif role in ["assistant", "model"]:
oci_role = "ASSISTANT"
elif role == "system":
oci_role = "SYSTEM"
else:
oci_role = role.upper()
# Create Message with role and content objects
logger.debug(f"Creating message with role: {oci_role}, provider: {provider}, original role: {role}")
generic_messages.append(
Message(
role=oci_role,
content=oci_contents
)
)
return GenericChatRequest(
messages=generic_messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
is_stream=stream,
)
def embed(
self,
model_id: str,
texts: list,
truncate: str = "END",
):
"""Generate embeddings using OCI GenAI."""
model_config = get_model_config(model_id)
if not model_config or model_config.type != "embedding":
raise ValueError(f"Invalid embedding model: {model_id}")
if not self.compartment_id:
raise ValueError("Compartment ID is required")
client = self._get_client()
serving_mode = OnDemandServingMode(
serving_type="ON_DEMAND",
model_id=model_id
)
embed_details = EmbedTextDetails(
serving_mode=serving_mode,
compartment_id=self.compartment_id,
inputs=texts,
truncate=truncate,
is_echo=False,
input_type="SEARCH_QUERY",
)
logger.debug(f"Sending embed request to OCI GenAI: {model_id}")
response = client.embed_text(embed_details)
return response
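To close, a hedged end-to-end sketch of using the wrapper directly, outside the gateway. It assumes a valid ~/.oci/config and network access to OCI; the model id is illustrative and must already be present in the registries populated by update_models_from_oci:

# End-to-end sketch (requires real OCI credentials; model id is hypothetical).
from src.core.config import get_settings
from src.core.models import update_models_from_oci
from src.core.oci_client import OCIGenAIClient

update_models_from_oci()  # populate OCI_CHAT_MODELS so get_model_config() succeeds
client = OCIGenAIClient(get_settings(), profile="DEFAULT")
response = client.chat(
    model_id="cohere.command-r-plus",  # hypothetical; must exist in OCI_CHAT_MODELS
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    temperature=0.2,
    max_tokens=256,
)
print(response.data)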