First commit
All checks were successful
Build and Push OCI GenAI Gateway Docker Image / docker-build-push (push) Successful in 32m3s
1
src/core/__init__.py
Normal file
@@ -0,0 +1 @@
"""Core module for OCI GenAI Gateway."""
70
src/core/client_manager.py
Normal file
@@ -0,0 +1,70 @@
"""
A simple OCI client manager with multi-profile round-robin load balancing.
"""
import logging
from typing import Optional, List, Dict
from threading import Lock

from .config import Settings, get_settings
from .oci_client import OCIGenAIClient

logger = logging.getLogger(__name__)


class OCIClientManager:
    """OCI client manager with round-robin load balancing and a client connection pool."""

    def __init__(self, settings: Optional[Settings] = None):
        self.settings = settings or get_settings()
        self.profiles = self.settings.get_profiles()
        self.current_index = 0
        self.lock = Lock()

        # Pre-create the client connection pool
        self._clients: Dict[str, OCIGenAIClient] = {}
        logger.info(f"Initializing OCI client manager with {len(self.profiles)} profiles: {self.profiles}")

        for profile in self.profiles:
            try:
                self._clients[profile] = OCIGenAIClient(self.settings, profile)
                logger.info(f"✓ Created client instance: {profile}")
            except Exception as e:
                logger.error(f"✗ Failed to create client instance [{profile}]: {e}")
                raise

    def get_client(self) -> OCIGenAIClient:
        """
        Get the next client (round-robin strategy).

        Selects a client instance from the pre-created connection pool using a
        round-robin algorithm. This method is thread-safe.

        Returns:
            OCIGenAIClient: A pre-created OCI client instance.

        Note:
            Client instances are pre-created when the manager is initialized;
            this method never creates new instances.
        """
        with self.lock:
            # With a single profile, return its client directly
            if len(self.profiles) == 1:
                return self._clients[self.profiles[0]]

            # Round-robin profile selection
            profile = self.profiles[self.current_index]
            self.current_index = (self.current_index + 1) % len(self.profiles)

            logger.debug(f"Selected profile: {profile} (round-robin)")
            return self._clients[profile]


# Global client manager instance
_client_manager = None


def get_client_manager() -> OCIClientManager:
    """Get the global client manager instance."""
    global _client_manager
    if _client_manager is None:
        _client_manager = OCIClientManager()
    return _client_manager
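For orientation, a minimal usage sketch of the manager above (the import path and the two-profile configuration are illustrative assumptions, not part of this commit):

# Hypothetical sketch, assuming OCI_CONFIG_PROFILE=DEFAULT,CHICAGO in .env
from src.core.client_manager import get_client_manager

manager = get_client_manager()   # lazily created once, then shared globally
c1 = manager.get_client()        # DEFAULT  (round-robin index 0)
c2 = manager.get_client()        # CHICAGO  (index 1)
c3 = manager.get_client()        # DEFAULT  (index wraps around)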
100
src/core/config.py
Normal file
@@ -0,0 +1,100 @@
"""
Configuration module for OCI Generative AI to OpenAI API Gateway.
"""
import os
import logging
from pathlib import Path
from typing import Optional, List
from pydantic_settings import BaseSettings

logger = logging.getLogger(__name__)


# Find project root directory (where .env should be)
def find_project_root() -> Path:
    """Find the project root directory by looking for .env or requirements.txt."""
    current = Path(__file__).resolve().parent  # Start from src/core/
    # Go up until we find project root markers
    while current != current.parent:
        if (current / ".env").exists() or (current / "requirements.txt").exists():
            return current
        current = current.parent
    return Path.cwd()  # Fallback to current directory


PROJECT_ROOT = find_project_root()


class Settings(BaseSettings):
    """Application settings with environment variable support."""

    # API Settings
    api_title: str = "OCI GenAI to OpenAI API Gateway"
    api_version: str = "1.0.0"
    api_prefix: str = "/v1"
    api_port: int = 8000
    api_host: str = "0.0.0.0"
    debug: bool = False

    # Authentication
    api_keys: List[str] = ["sk-oci-genai-default-key"]

    # OCI Settings
    oci_config_file: str = "~/.oci/config"
    oci_config_profile: str = "DEFAULT"  # Multiple profiles may be given, comma-separated, e.g. DEFAULT,CHICAGO,ASHBURN
    oci_auth_type: str = "api_key"  # api_key or instance_principal

    # GenAI Service Settings
    genai_endpoint: Optional[str] = None
    max_tokens: int = 4096
    temperature: float = 0.7

    # Embedding Settings
    embed_truncate: str = "END"  # END or START

    # Streaming Settings
    enable_streaming: bool = True
    stream_chunk_size: int = 1024

    # Logging
    log_level: str = "INFO"
    log_requests: bool = False
    log_responses: bool = False
    log_file: Optional[str] = None
    log_file_max_size: int = 10  # MB
    log_file_backup_count: int = 5

    class Config:
        # Use absolute path to .env file in project root
        env_file = str(PROJECT_ROOT / ".env")
        env_file_encoding = "utf-8"
        case_sensitive = False

        # Allow reading from environment variables
        env_prefix = ""

    def model_post_init(self, __context) -> None:
        """Expand OCI config file path."""
        # Expand OCI config file path
        config_path = os.path.expanduser(self.oci_config_file)

        # If it's a relative path (starts with ./ or doesn't start with /), resolve it from project root
        if not config_path.startswith('/') and not config_path.startswith('~'):
            # Remove leading ./ if present
            if config_path.startswith('./'):
                config_path = config_path[2:]
            config_path = str(PROJECT_ROOT / config_path)

        # Update the config path
        self.oci_config_file = config_path

    def get_profiles(self) -> List[str]:
        """Return the list of all configured profiles."""
        return [p.strip() for p in self.oci_config_profile.split(',') if p.strip()]


# Global settings instance
settings = Settings()


def get_settings() -> Settings:
    """Get the global settings instance."""
    return settings
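As a worked example of how these settings are consumed, a minimal .env in the project root might look like this (all values are placeholders; pydantic-settings matches keys case-insensitively and parses the List[str] field from JSON):

API_PORT=8000
API_KEYS=["sk-my-gateway-key"]
OCI_CONFIG_FILE=./.oci/config
OCI_CONFIG_PROFILE=DEFAULT,CHICAGO
OCI_AUTH_TYPE=api_key
LOG_LEVEL=INFO

With a comma-separated OCI_CONFIG_PROFILE, get_profiles() returns ['DEFAULT', 'CHICAGO'], which is exactly the list the client manager above round-robins over.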
260
src/core/models.py
Normal file
@@ -0,0 +1,260 @@
"""
Model definitions and configurations for OCI Generative AI models.
"""
import logging
import os
from typing import Dict, List, Optional
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class ModelConfig(BaseModel):
    """Configuration for a single model."""
    id: str
    name: str
    type: str  # ondemand, dedicated, embedding
    provider: str  # cohere, meta, openai, etc.
    region: Optional[str] = None
    compartment_id: Optional[str] = None
    endpoint: Optional[str] = None
    supports_streaming: bool = True
    supports_tools: bool = False
    supports_multimodal: bool = False
    multimodal_types: List[str] = []
    max_tokens: int = 4096
    context_window: int = 128000


# OCI Generative AI models (dynamically loaded from OCI at startup)
OCI_CHAT_MODELS: Dict[str, ModelConfig] = {}

OCI_EMBED_MODELS: Dict[str, ModelConfig] = {}


def get_all_models() -> List[ModelConfig]:
    """Get all available models."""
    return list(OCI_CHAT_MODELS.values()) + list(OCI_EMBED_MODELS.values())


def get_chat_models() -> List[ModelConfig]:
    """Get all chat models."""
    return list(OCI_CHAT_MODELS.values())


def get_embed_models() -> List[ModelConfig]:
    """Get all embedding models."""
    return list(OCI_EMBED_MODELS.values())


def get_model_config(model_id: str) -> Optional[ModelConfig]:
    """Get configuration for a specific model."""
    if model_id in OCI_CHAT_MODELS:
        return OCI_CHAT_MODELS[model_id]
    if model_id in OCI_EMBED_MODELS:
        return OCI_EMBED_MODELS[model_id]
    return None


def fetch_models_from_oci(compartment_id: Optional[str] = None, region: Optional[str] = None,
                          config_path: str = "./.oci/config",
                          profile: str = "DEFAULT") -> Dict[str, Dict[str, ModelConfig]]:
    """
    Dynamically fetch available models from OCI Generative AI service.

    If compartment_id or region are not provided, they will be read from the OCI config file.
    - compartment_id defaults to 'tenancy' from config
    - region defaults to 'region' from config

    Args:
        compartment_id: OCI compartment ID (optional, defaults to tenancy from config)
        region: OCI region (optional, defaults to region from config)
        config_path: Path to OCI config file
        profile: OCI config profile name

    Returns:
        Dictionary with 'chat' and 'embed' keys containing model configs
    """
    try:
        import oci
        from oci.generative_ai import GenerativeAiClient

        # Load OCI configuration
        config = oci.config.from_file(
            file_location=os.path.expanduser(config_path),
            profile_name=profile
        )

        # Use values from config if not provided
        if not region:
            region = config.get("region")
            logger.info(f"📍 Using region from OCI config: {region}")

        if not compartment_id:
            compartment_id = config.get("tenancy")
            logger.info(f"📦 Using tenancy as compartment_id: {compartment_id}")

        if not region or not compartment_id:
            logger.error("❌ Missing region or compartment_id in OCI config")
            return {"chat": {}, "embed": {}}

        # Create GenerativeAiClient (not GenerativeAiInferenceClient)
        service_endpoint = f"https://generativeai.{region}.oci.oraclecloud.com"
        logger.info(f"🔗 Connecting to OCI GenerativeAI endpoint: {service_endpoint}")
        client = GenerativeAiClient(config, service_endpoint=service_endpoint)

        chat_models = {}
        embed_models = {}

        # Fetch all models (without capability filter to work with tenancy compartment)
        try:
            logger.info("🔍 Fetching all models from OCI...")
            logger.debug(f"  Compartment ID: {compartment_id}")
            logger.debug("  Method: fetching all models, then filtering by capabilities in Python")

            response = client.list_models(
                compartment_id=compartment_id
            )

            logger.info(f"✅ Successfully fetched {len(response.data.items)} models from OCI")

            # Filter models by capabilities in Python
            for model in response.data.items:
                model_id = model.display_name
                provider = model_id.split(".")[0] if "." in model_id else "unknown"
                capabilities = model.capabilities if hasattr(model, 'capabilities') else []

                logger.debug(f"  Processing: {model_id} (capabilities: {capabilities})")

                # Chat models: have CHAT or TEXT_GENERATION capability
                if 'CHAT' in capabilities or 'TEXT_GENERATION' in capabilities:
                    supports_streaming = True  # Most models support streaming
                    supports_tools = provider in ["cohere", "meta"]  # These providers support tools

                    # Detect multimodal support from capabilities
                    supports_multimodal = False
                    multimodal_types = []
                    if 'IMAGE' in capabilities or 'VISION' in capabilities:
                        supports_multimodal = True
                        multimodal_types.append("image")

                    chat_models[model_id] = ModelConfig(
                        id=model_id,
                        name=model.display_name,
                        type="ondemand",
                        provider=provider,
                        region=region,
                        compartment_id=compartment_id,
                        supports_streaming=supports_streaming,
                        supports_tools=supports_tools,
                        supports_multimodal=supports_multimodal,
                        multimodal_types=multimodal_types,
                        max_tokens=4096,
                        context_window=128000
                    )

                # Embedding models: have TEXT_EMBEDDINGS capability
                elif 'TEXT_EMBEDDINGS' in capabilities:
                    embed_models[model_id] = ModelConfig(
                        id=model_id,
                        name=model.display_name,
                        type="embedding",
                        provider=provider,
                        region=region,
                        compartment_id=compartment_id,
                        supports_streaming=False,
                        supports_tools=False,
                        max_tokens=512,
                        context_window=512
                    )

            logger.info(f"✅ Filtered {len(chat_models)} chat models")
            if chat_models:
                logger.debug(f"  Chat models: {', '.join(list(chat_models.keys())[:5])}{'...' if len(chat_models) > 5 else ''}")

            logger.info(f"✅ Filtered {len(embed_models)} embedding models")
            if embed_models:
                logger.debug(f"  Embed models: {', '.join(embed_models.keys())}")

        except Exception as e:
            logger.warning("⚠️ Failed to fetch models from OCI")
            logger.warning(f"  Error: {e}")
            if hasattr(e, 'status'):
                logger.warning(f"  HTTP Status: {e.status}")
            if hasattr(e, 'code'):
                logger.warning(f"  Error Code: {e.code}")
            logger.info("💡 Tip: Check your OCI credentials and permissions")

        return {"chat": chat_models, "embed": embed_models}

    except Exception as e:
        logger.error("❌ Failed to initialize OCI client for model discovery")
        logger.error(f"  Error: {e}")
        logger.info("💡 Tip: Check your OCI credentials and permissions")
        return {"chat": {}, "embed": {}}


def update_models_from_oci(compartment_id: Optional[str] = None,
                           region: Optional[str] = None,
                           config_path: str = "./.oci/config",
                           profile: str = "DEFAULT") -> None:
    """
    Update global model dictionaries with models from OCI.
    Raises RuntimeError if model fetching fails.

    Priority for configuration values:
    1. Explicitly provided parameters
    2. Environment variables (OCI_COMPARTMENT_ID, OCI_REGION)
    3. Values from .oci/config file (tenancy, region)

    Raises:
        RuntimeError: If no models can be fetched from OCI
    """
    global OCI_CHAT_MODELS, OCI_EMBED_MODELS

    # Priority: explicit params > environment > config file
    if not compartment_id:
        compartment_id = os.getenv("OCI_COMPARTMENT_ID")
    if not region:
        region = os.getenv("OCI_REGION")

    # Note: if still not set, fetch_models_from_oci will try to read from the config file
    logger.info("🚀 Attempting to fetch models from OCI...")
    fetched = fetch_models_from_oci(compartment_id, region, config_path, profile)

    # Fail fast: require successful model fetching
    if not fetched["chat"] and not fetched["embed"]:
        error_msg = (
            "❌ Failed to fetch any models from OCI.\n\n"
            "Troubleshooting steps:\n"
            "1. Verify your OCI credentials are configured correctly:\n"
            f"   - Config file: {config_path}\n"
            f"   - Profile: {profile}\n"
            "   - Run: oci iam region list (to test authentication)\n\n"
            "2. Check your OCI permissions:\n"
            "   - Ensure you have access to the Generative AI service\n"
            "   - Verify the compartment_id/tenancy has available models\n\n"
            "3. Check network connectivity:\n"
            "   - Ensure you can reach OCI endpoints\n"
            f"   - Test region: {region or 'from config file'}\n\n"
            "4. Review the logs above for detailed error messages"
        )
        logger.error(error_msg)
        raise RuntimeError(
            "Failed to fetch models from OCI. "
            "The service cannot start without available models. "
            "Check the logs above for troubleshooting guidance."
        )

    # Update global model registries
    if fetched["chat"]:
        OCI_CHAT_MODELS.clear()
        OCI_CHAT_MODELS.update(fetched["chat"])
        logger.info(f"✅ Loaded {len(OCI_CHAT_MODELS)} chat models from OCI")

    if fetched["embed"]:
        OCI_EMBED_MODELS.clear()
        OCI_EMBED_MODELS.update(fetched["embed"])
        logger.info(f"✅ Loaded {len(OCI_EMBED_MODELS)} embedding models from OCI")

    logger.info("✅ Model discovery completed successfully")
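A sketch of the intended startup flow for this module (the call site and the model id are illustrative; this commit does not include the application entry point):

# Hypothetical startup sketch: discover models once at boot, then resolve by id
from src.core.models import update_models_from_oci, get_model_config, get_all_models

update_models_from_oci(profile="DEFAULT")   # raises RuntimeError if nothing can be fetched
print([m.id for m in get_all_models()])     # registry now holds chat + embedding models
cfg = get_model_config("meta.llama-3.3-70b-instruct")  # illustrative id
if cfg is not None and cfg.supports_streaming:
    pass  # route the request to the streaming code path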
361
src/core/oci_client.py
Normal file
@@ -0,0 +1,361 @@
"""
OCI Generative AI client wrapper.
"""
import os
import logging
from typing import Optional, AsyncIterator
import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import (
    ChatDetails,
    CohereChatRequest,
    GenericChatRequest,
    OnDemandServingMode,
    DedicatedServingMode,
    CohereMessage,
    Message,
    TextContent,
    EmbedTextDetails,
)

# Try to import multimodal content types
try:
    from oci.generative_ai_inference.models import (
        ImageContent,
        ImageUrl,
        AudioContent,
        AudioUrl,
        VideoContent,
        VideoUrl,
    )
    MULTIMODAL_SUPPORTED = True
    logger_init = logging.getLogger(__name__)
    logger_init.info("OCI SDK multimodal content types available")
except ImportError:
    MULTIMODAL_SUPPORTED = False
    logger_init = logging.getLogger(__name__)
    logger_init.warning("OCI SDK does not support multimodal content types, using dict format as fallback")

from .config import Settings
from .models import get_model_config, ModelConfig

logger = logging.getLogger(__name__)


def build_multimodal_content(content_list: list) -> list:
    """
    Build OCI ChatContent object array from adapted content list.

    Supports both HTTP URLs and Base64 data URIs (data:image/jpeg;base64,...).

    Args:
        content_list: List of content items from request adapter

    Returns:
        List of OCI ChatContent objects or dicts (fallback)
    """
    if not MULTIMODAL_SUPPORTED:
        # Fallback: return dict format, OCI SDK might auto-convert
        return content_list

    oci_contents = []
    for item in content_list:
        if not isinstance(item, dict):
            continue

        item_type = item.get("type")

        if item_type == "text":
            oci_contents.append(TextContent(text=item.get("text", "")))

        elif item_type == "image_url":
            image_data = item.get("image_url", {})
            if "url" in image_data:
                # ImageUrl accepts both HTTP URLs and data URIs (data:image/jpeg;base64,...)
                img_url = ImageUrl(url=image_data["url"])
                # Optional: support 'detail' parameter if provided
                if "detail" in image_data:
                    img_url.detail = image_data["detail"]
                oci_contents.append(ImageContent(image_url=img_url, type="IMAGE"))

        elif item_type == "audio":
            audio_data = item.get("audio_url", {})
            if "url" in audio_data:
                # AudioUrl accepts both HTTP URLs and data URIs (data:audio/wav;base64,...)
                audio_url = AudioUrl(url=audio_data["url"])
                oci_contents.append(AudioContent(audio_url=audio_url, type="AUDIO"))

        elif item_type == "video":
            video_data = item.get("video_url", {})
            if "url" in video_data:
                # VideoUrl accepts both HTTP URLs and data URIs (data:video/mp4;base64,...)
                video_url = VideoUrl(url=video_data["url"])
                oci_contents.append(VideoContent(video_url=video_url, type="VIDEO"))

    return oci_contents if oci_contents else [TextContent(text="")]


class OCIGenAIClient:
    """Wrapper for OCI Generative AI client."""

    def __init__(self, settings: Settings, profile: Optional[str] = None):
        """
        Initialize the OCI GenAI client.

        Args:
            settings: Application settings
            profile: Optional OCI config profile name. If not provided, the first profile from settings is used.
        """
        self.settings = settings
        self.profile = profile or settings.get_profiles()[0]
        self._client: Optional[GenerativeAiInferenceClient] = None
        self._config: Optional[dict] = None
        self._region: Optional[str] = None
        self._compartment_id: Optional[str] = None

    def _get_config(self) -> dict:
        """Get OCI configuration."""
        if self._config is None:
            if self.settings.oci_auth_type == "instance_principal":
                signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
                self._config = {"signer": signer}
            else:
                config_path = os.path.expanduser(self.settings.oci_config_file)
                self._config = oci.config.from_file(
                    file_location=config_path,
                    profile_name=self.profile
                )

        # Read region and compartment_id from the configuration
        if self._region is None:
            self._region = self._config.get("region")
        if self._compartment_id is None:
            self._compartment_id = self._config.get("tenancy")

        return self._config

    @property
    def region(self) -> Optional[str]:
        """Get the configured region."""
        if self._region is None and self._config is None:
            self._get_config()
        return self._region

    @property
    def compartment_id(self) -> Optional[str]:
        """Get the configured compartment ID."""
        if self._compartment_id is None and self._config is None:
            self._get_config()
        return self._compartment_id

    def _get_client(self) -> GenerativeAiInferenceClient:
        """Get or create OCI Generative AI Inference client with correct endpoint."""
        config = self._get_config()

        # Use the INFERENCE endpoint (not the management endpoint)
        # Official format: https://inference.generativeai.{region}.oci.oraclecloud.com
        inference_endpoint = f"https://inference.generativeai.{self.region}.oci.oraclecloud.com"

        if isinstance(config, dict) and "signer" in config:
            # For instance principal
            client = GenerativeAiInferenceClient(
                config={},
                service_endpoint=inference_endpoint,
                **config
            )
            return client

        # For API key authentication
        client = GenerativeAiInferenceClient(
            config=config,
            service_endpoint=inference_endpoint,
            retry_strategy=oci.retry.NoneRetryStrategy(),
            timeout=(10, 240)
        )

        return client

    def chat(
        self,
        model_id: str,
        messages: list,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        top_p: float = 1.0,
        stream: bool = False,
        tools: Optional[list] = None,
    ):
        """Send a chat completion request to OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config:
            raise ValueError(f"Unsupported model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        # Prepare serving mode
        if model_config.type == "dedicated" and model_config.endpoint:
            serving_mode = DedicatedServingMode(endpoint_id=model_config.endpoint)
        else:
            serving_mode = OnDemandServingMode(model_id=model_id)

        # Convert messages based on provider
        if model_config.provider == "cohere":
            chat_request = self._build_cohere_request(
                messages, temperature, max_tokens, top_p, tools, stream
            )
        elif model_config.provider in ["meta", "xai", "google", "openai"]:
            chat_request = self._build_generic_request(
                messages, temperature, max_tokens, top_p, tools, model_config.provider, stream
            )
        else:
            raise ValueError(f"Unsupported provider: {model_config.provider}")

        chat_details = ChatDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            chat_request=chat_request,
        )

        logger.debug(f"Sending chat request to OCI GenAI: {model_id}")
        response = client.chat(chat_details)
        return response

    def _build_cohere_request(
        self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], stream: bool = False
    ) -> CohereChatRequest:
        """Build Cohere chat request.

        Note: Cohere models only support text content, not multimodal.
        """
        # Convert messages to Cohere format
        chat_history = []
        message = None

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Extract text from multimodal content
            if isinstance(content, list):
                # Extract text parts only
                text_parts = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        text_parts.append(item.get("text", ""))
                content = " ".join(text_parts) if text_parts else ""

            if role == "system":
                # Cohere uses preamble for system messages
                continue
            elif role == "user":
                message = content
            elif role == "assistant":
                chat_history.append(
                    CohereMessage(role="CHATBOT", message=content)
                )
            elif role == "tool":
                # Handle tool responses if needed
                pass

        # Get preamble from system messages
        preamble_override = None
        for msg in messages:
            if msg["role"] == "system":
                preamble_override = msg["content"]
                break

        return CohereChatRequest(
            message=message,
            chat_history=chat_history if chat_history else None,
            preamble_override=preamble_override,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

    def _build_generic_request(
        self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], provider: str, stream: bool = False
    ) -> GenericChatRequest:
        """Build Generic chat request for Llama and other models."""
        # Convert messages to Generic format
        generic_messages = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Handle multimodal content
            if isinstance(content, list):
                # Build OCI ChatContent objects from multimodal content
                oci_contents = build_multimodal_content(content)
            else:
                # Simple text content
                if MULTIMODAL_SUPPORTED:
                    oci_contents = [TextContent(text=content)]
                else:
                    # Fallback: use dict format
                    oci_contents = [{"type": "text", "text": content}]

            if role == "user":
                oci_role = "USER"
            elif role in ["assistant", "model"]:
                oci_role = "ASSISTANT"
            elif role == "system":
                oci_role = "SYSTEM"
            else:
                oci_role = role.upper()

            # Create Message with role and content objects
            logger.debug(f"Creating message with role: {oci_role}, provider: {provider}, original role: {role}")

            generic_messages.append(
                Message(
                    role=oci_role,
                    content=oci_contents
                )
            )

        return GenericChatRequest(
            messages=generic_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

    def embed(
        self,
        model_id: str,
        texts: list,
        truncate: str = "END",
    ):
        """Generate embeddings using OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config or model_config.type != "embedding":
            raise ValueError(f"Invalid embedding model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        serving_mode = OnDemandServingMode(
            serving_type="ON_DEMAND",
            model_id=model_id
        )

        embed_details = EmbedTextDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            inputs=texts,
            truncate=truncate,
            is_echo=False,
            input_type="SEARCH_QUERY",
        )

        logger.debug(f"Sending embed request to OCI GenAI: {model_id}")
        response = client.embed_text(embed_details)
        return response
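Finally, a hedged end-to-end sketch of a chat call through this wrapper (the model id and messages are illustrative, and it assumes update_models_from_oci() has already populated the model registry):

# Hypothetical usage sketch; model discovery must have run first
from src.core.config import get_settings
from src.core.oci_client import OCIGenAIClient

client = OCIGenAIClient(get_settings())      # uses the first configured profile
response = client.chat(
    model_id="meta.llama-3.3-70b-instruct",  # illustrative id from discovery
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    temperature=0.2,
    max_tokens=256,
)
print(response.data)  # raw OCI ChatResult; a response adapter would map this to OpenAI format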