""" OCI Generative AI client wrapper. """ import os import logging from typing import Optional, AsyncIterator import oci from oci.generative_ai_inference import GenerativeAiInferenceClient from oci.generative_ai_inference.models import ( ChatDetails, CohereChatRequest, GenericChatRequest, OnDemandServingMode, DedicatedServingMode, CohereMessage, Message, TextContent, EmbedTextDetails, ) # Try to import multimodal content types try: from oci.generative_ai_inference.models import ( ImageContent, ImageUrl, AudioContent, AudioUrl, VideoContent, VideoUrl, ) MULTIMODAL_SUPPORTED = True logger_init = logging.getLogger(__name__) logger_init.info("OCI SDK multimodal content types available") except ImportError: MULTIMODAL_SUPPORTED = False logger_init = logging.getLogger(__name__) logger_init.warning("OCI SDK does not support multimodal content types, using dict format as fallback") from .config import Settings from .models import get_model_config, ModelConfig logger = logging.getLogger(__name__) def build_multimodal_content(content_list: list) -> list: """ Build OCI ChatContent object array from adapted content list. Supports both HTTP URLs and Base64 data URIs (data:image/jpeg;base64,...). Args: content_list: List of content items from request adapter Returns: List of OCI ChatContent objects or dicts (fallback) """ if not MULTIMODAL_SUPPORTED: # Fallback: return dict format, OCI SDK might auto-convert return content_list oci_contents = [] for item in content_list: if not isinstance(item, dict): continue item_type = item.get("type") if item_type == "text": oci_contents.append(TextContent(text=item.get("text", ""))) elif item_type == "image_url": image_data = item.get("image_url", {}) if "url" in image_data: # ImageUrl accepts both HTTP URLs and data URIs (data:image/jpeg;base64,...) img_url = ImageUrl(url=image_data["url"]) # Optional: support 'detail' parameter if provided if "detail" in image_data: img_url.detail = image_data["detail"] oci_contents.append(ImageContent(image_url=img_url, type="IMAGE")) elif item_type == "audio": audio_data = item.get("audio_url", {}) if "url" in audio_data: # AudioUrl accepts both HTTP URLs and data URIs (data:audio/wav;base64,...) audio_url = AudioUrl(url=audio_data["url"]) oci_contents.append(AudioContent(audio_url=audio_url, type="AUDIO")) elif item_type == "video": video_data = item.get("video_url", {}) if "url" in video_data: # VideoUrl accepts both HTTP URLs and data URIs (data:video/mp4;base64,...) 


class OCIGenAIClient:
    """Wrapper for the OCI Generative AI client."""

    def __init__(self, settings: Settings, profile: Optional[str] = None):
        """
        Initialize the OCI GenAI client.

        Args:
            settings: Application settings
            profile: Optional OCI config profile name. If not provided,
                the first profile from settings is used.
        """
        self.settings = settings
        self.profile = profile or settings.get_profiles()[0]
        self._client: Optional[GenerativeAiInferenceClient] = None
        # oci.config.from_file returns a plain dict, not a dedicated Config class
        self._config: Optional[dict] = None
        self._region: Optional[str] = None
        self._compartment_id: Optional[str] = None

    def _get_config(self) -> dict:
        """Get the OCI configuration."""
        if self._config is None:
            if self.settings.oci_auth_type == "instance_principal":
                signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
                self._config = {"signer": signer}
                # Instance principal: the signer discovers the region from
                # instance metadata (guarded with getattr in case the SDK
                # version does not expose it)
                if self._region is None:
                    self._region = getattr(signer, "region", None)
            else:
                config_path = os.path.expanduser(self.settings.oci_config_file)
                self._config = oci.config.from_file(
                    file_location=config_path,
                    profile_name=self.profile
                )

            # Read region and compartment_id from the config; the tenancy
            # OCID serves as the default compartment
            if self._region is None:
                self._region = self._config.get("region")
            if self._compartment_id is None:
                self._compartment_id = self._config.get("tenancy")

        return self._config

    @property
    def region(self) -> Optional[str]:
        """Region from the currently loaded configuration."""
        if self._region is None and self._config is None:
            self._get_config()
        return self._region

    @property
    def compartment_id(self) -> Optional[str]:
        """Compartment ID from the currently loaded configuration."""
        if self._compartment_id is None and self._config is None:
            self._get_config()
        return self._compartment_id

    def _get_client(self) -> GenerativeAiInferenceClient:
        """Get or create an OCI Generative AI Inference client with the correct endpoint."""
        config = self._get_config()

        # Use the INFERENCE endpoint (not the management endpoint).
        # Official format: https://inference.generativeai.{region}.oci.oraclecloud.com
        inference_endpoint = (
            f"https://inference.generativeai.{self.region}.oci.oraclecloud.com"
        )

        if isinstance(config, dict) and "signer" in config:
            # Instance principal authentication
            client = GenerativeAiInferenceClient(
                config={},
                service_endpoint=inference_endpoint,
                **config
            )
            return client

        # API key authentication
        client = GenerativeAiInferenceClient(
            config=config,
            service_endpoint=inference_endpoint,
            retry_strategy=oci.retry.NoneRetryStrategy(),
            timeout=(10, 240)  # (connect, read) timeouts in seconds
        )
        return client

    def chat(
        self,
        model_id: str,
        messages: list,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        top_p: float = 1.0,
        stream: bool = False,
        tools: Optional[list] = None,
    ):
        """Send a chat completion request to OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config:
            raise ValueError(f"Unsupported model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        # Prepare serving mode
        if model_config.type == "dedicated" and model_config.endpoint:
            serving_mode = DedicatedServingMode(endpoint_id=model_config.endpoint)
        else:
            serving_mode = OnDemandServingMode(model_id=model_id)

        # Convert messages based on provider
        if model_config.provider == "cohere":
            chat_request = self._build_cohere_request(
                messages, temperature, max_tokens, top_p, tools, stream
            )
        elif model_config.provider in ["meta", "xai", "google", "openai"]:
            chat_request = self._build_generic_request(
                messages, temperature, max_tokens, top_p, tools,
                model_config.provider, stream
            )
        else:
            raise ValueError(f"Unsupported provider: {model_config.provider}")

        chat_details = ChatDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            chat_request=chat_request,
        )

        logger.debug(f"Sending chat request to OCI GenAI: {model_id}")
        response = client.chat(chat_details)
        return response
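
    # Illustrative call sketch (the model name and settings wiring are
    # assumptions, not part of this module):
    #
    #     client = OCIGenAIClient(settings)
    #     response = client.chat(
    #         model_id="meta.llama-3.1-70b-instruct",
    #         messages=[{"role": "user", "content": "Hello"}],
    #         stream=False,
    #     )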

    def _build_cohere_request(
        self,
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float,
        tools: Optional[list],
        stream: bool = False
    ) -> CohereChatRequest:
        """Build a Cohere chat request.

        Note: Cohere models only support text content, not multimodal.
        """
        chat_history = []
        preamble_override = None

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Flatten multimodal content to text, since Cohere is text-only
            if isinstance(content, list):
                text_parts = [
                    item.get("text", "")
                    for item in content
                    if isinstance(item, dict) and item.get("type") == "text"
                ]
                content = " ".join(text_parts) if text_parts else ""

            if role == "system":
                # Cohere uses a preamble for system messages; keep the first one
                if preamble_override is None:
                    preamble_override = content
            elif role == "user":
                chat_history.append(CohereMessage(role="USER", message=content))
            elif role == "assistant":
                chat_history.append(CohereMessage(role="CHATBOT", message=content))
            elif role == "tool":
                # Handle tool responses if needed
                pass

        # The final user turn becomes the prompt; everything before it stays
        # in the history (earlier user turns must not be dropped)
        message = ""
        if chat_history and chat_history[-1].role == "USER":
            message = chat_history.pop().message

        return CohereChatRequest(
            message=message,
            chat_history=chat_history if chat_history else None,
            preamble_override=preamble_override,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

    def _build_generic_request(
        self,
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float,
        tools: Optional[list],
        provider: str,
        stream: bool = False
    ) -> GenericChatRequest:
        """Build a Generic chat request for Llama and other non-Cohere models."""
        generic_messages = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Handle multimodal content
            if isinstance(content, list):
                # Build OCI ChatContent objects from the multimodal content
                oci_contents = build_multimodal_content(content)
            else:
                # Simple text content
                if MULTIMODAL_SUPPORTED:
                    oci_contents = [TextContent(text=content)]
                else:
                    # Fallback: use the dict format
                    oci_contents = [{"type": "text", "text": content}]

            if role == "user":
                oci_role = "USER"
            elif role in ["assistant", "model"]:
                oci_role = "ASSISTANT"
            elif role == "system":
                oci_role = "SYSTEM"
            else:
                oci_role = role.upper()

            # Create a Message with the role and content objects
            logger.debug(
                f"Creating message with role: {oci_role}, provider: {provider}, "
                f"original role: {role}"
            )
            generic_messages.append(
                Message(
                    role=oci_role,
                    content=oci_contents
                )
            )

        return GenericChatRequest(
            messages=generic_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

    def embed(
        self,
        model_id: str,
        texts: list,
        truncate: str = "END",
    ):
        """Generate embeddings using OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config or model_config.type != "embedding":
            raise ValueError(f"Invalid embedding model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        serving_mode = OnDemandServingMode(model_id=model_id)

        embed_details = EmbedTextDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            inputs=texts,
            truncate=truncate,
            is_echo=False,
            input_type="SEARCH_QUERY",
        )

        logger.debug(f"Sending embed request to OCI GenAI: {model_id}")
        response = client.embed_text(embed_details)
        return response
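
# Illustrative embedding call sketch (the model name is an assumption and may
# not be enabled in a given tenancy):
#
#     client = OCIGenAIClient(settings)
#     response = client.embed(
#         model_id="cohere.embed-multilingual-v3.0",
#         texts=["hello world"],
#     )
#     vectors = response.data.embeddings  # one vector per input text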