"""
|
|
OCI Generative AI client wrapper.
|
|
"""
|
|
import os
|
|
import logging
|
|
from typing import Optional, AsyncIterator
|
|
import oci
|
|
from oci.generative_ai_inference import GenerativeAiInferenceClient
|
|
from oci.generative_ai_inference.models import (
|
|
ChatDetails,
|
|
CohereChatRequest,
|
|
GenericChatRequest,
|
|
OnDemandServingMode,
|
|
DedicatedServingMode,
|
|
CohereMessage,
|
|
Message,
|
|
TextContent,
|
|
EmbedTextDetails,
|
|
)
|
|
|
|
# Try to import multimodal content types
|
|
try:
|
|
from oci.generative_ai_inference.models import (
|
|
ImageContent,
|
|
ImageUrl,
|
|
AudioContent,
|
|
AudioUrl,
|
|
VideoContent,
|
|
VideoUrl,
|
|
)
|
|
MULTIMODAL_SUPPORTED = True
|
|
logger_init = logging.getLogger(__name__)
|
|
logger_init.info("OCI SDK multimodal content types available")
|
|
except ImportError:
|
|
MULTIMODAL_SUPPORTED = False
|
|
logger_init = logging.getLogger(__name__)
|
|
logger_init.warning("OCI SDK does not support multimodal content types, using dict format as fallback")
|
|
|
|
from .config import Settings
|
|
from .models import get_model_config, ModelConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def build_multimodal_content(content_list: list) -> list:
    """
    Build an OCI ChatContent object array from an adapted content list.

    Supports both HTTP URLs and Base64 data URIs (data:image/jpeg;base64,...).

    Args:
        content_list: List of content items from the request adapter

    Returns:
        List of OCI ChatContent objects, or dicts (fallback)
    """
    if not MULTIMODAL_SUPPORTED:
        # Fallback: return dict format, OCI SDK might auto-convert
        return content_list

    oci_contents = []
    for item in content_list:
        if not isinstance(item, dict):
            continue

        item_type = item.get("type")

        if item_type == "text":
            oci_contents.append(TextContent(text=item.get("text", "")))

        elif item_type == "image_url":
            image_data = item.get("image_url", {})
            if "url" in image_data:
                # ImageUrl accepts both HTTP URLs and data URIs (data:image/jpeg;base64,...)
                img_url = ImageUrl(url=image_data["url"])
                # Optional: support 'detail' parameter if provided
                if "detail" in image_data:
                    img_url.detail = image_data["detail"]
                oci_contents.append(ImageContent(image_url=img_url, type="IMAGE"))

        elif item_type == "audio":
            audio_data = item.get("audio_url", {})
            if "url" in audio_data:
                # AudioUrl accepts both HTTP URLs and data URIs (data:audio/wav;base64,...)
                audio_url = AudioUrl(url=audio_data["url"])
                oci_contents.append(AudioContent(audio_url=audio_url, type="AUDIO"))

        elif item_type == "video":
            video_data = item.get("video_url", {})
            if "url" in video_data:
                # VideoUrl accepts both HTTP URLs and data URIs (data:video/mp4;base64,...)
                video_url = VideoUrl(url=video_data["url"])
                oci_contents.append(VideoContent(video_url=video_url, type="VIDEO"))

    return oci_contents if oci_contents else [TextContent(text="")]
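
# A minimal sketch of what this function consumes and produces (illustrative
# values only): an OpenAI-style multimodal content list goes in, OCI content
# objects come out.
#
#   content = [
#       {"type": "text", "text": "Describe this image"},
#       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/..."}},
#   ]
#   build_multimodal_content(content)
#   # -> [TextContent(text="Describe this image"),
#   #     ImageContent(image_url=ImageUrl(url="data:image/jpeg;base64,/9j/..."), type="IMAGE")]
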
class OCIGenAIClient:
    """Wrapper for OCI Generative AI client."""

    def __init__(self, settings: Settings, profile: Optional[str] = None):
        """
        Initialize the OCI GenAI client.

        Args:
            settings: Application settings
            profile: Optional OCI config profile name. If not provided, the
                first profile from settings is used.
        """
        self.settings = settings
        self.profile = profile or settings.get_profiles()[0]
        self._client: Optional[GenerativeAiInferenceClient] = None
        # oci.config.from_file returns a plain dict (there is no
        # oci.config.Config class); instance principal auth stores
        # {"signer": ...} instead.
        self._config: Optional[dict] = None
        self._region: Optional[str] = None
        self._compartment_id: Optional[str] = None

    def _get_config(self) -> dict:
        """Get OCI configuration."""
        if self._config is None:
            if self.settings.oci_auth_type == "instance_principal":
                signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
                self._config = {"signer": signer}
                # Under instance principal auth the config dict has no
                # region/tenancy keys; take them from the signer instead.
                if self._region is None:
                    self._region = getattr(signer, "region", None)
                if self._compartment_id is None:
                    self._compartment_id = getattr(signer, "tenancy_id", None)
            else:
                config_path = os.path.expanduser(self.settings.oci_config_file)
                self._config = oci.config.from_file(
                    file_location=config_path,
                    profile_name=self.profile
                )

        # Read region and compartment_id from the config
        if self._region is None:
            self._region = self._config.get("region")
        if self._compartment_id is None:
            self._compartment_id = self._config.get("tenancy")

        return self._config

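    # For API key auth, a typical ~/.oci/config profile this reads looks like
    # the following (placeholder values; see the OCI SDK docs for details):
    #
    #   [DEFAULT]
    #   user=ocid1.user.oc1..example
    #   fingerprint=xx:xx:xx:xx:xx:xx:xx:xx:xx:xx:xx:xx:xx:xx:xx:xx
    #   key_file=~/.oci/oci_api_key.pem
    #   tenancy=ocid1.tenancy.oc1..example
    #   region=us-chicago-1
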
    @property
    def region(self) -> Optional[str]:
        """Get the region from the current configuration."""
        if self._region is None and self._config is None:
            self._get_config()
        return self._region

    @property
    def compartment_id(self) -> Optional[str]:
        """Get the compartment ID from the current configuration."""
        if self._compartment_id is None and self._config is None:
            self._get_config()
        return self._compartment_id

    def _get_client(self) -> GenerativeAiInferenceClient:
        """Get or create an OCI Generative AI Inference client with the correct endpoint."""
        config = self._get_config()

        # Use the INFERENCE endpoint (not the management endpoint)
        # Official format: https://inference.generativeai.{region}.oci.oraclecloud.com
        inference_endpoint = f"https://inference.generativeai.{self.region}.oci.oraclecloud.com"

        if isinstance(config, dict) and "signer" in config:
            # For instance principal
            return GenerativeAiInferenceClient(
                config={},
                service_endpoint=inference_endpoint,
                **config
            )

        # For API key authentication
        return GenerativeAiInferenceClient(
            config=config,
            service_endpoint=inference_endpoint,
            retry_strategy=oci.retry.NoneRetryStrategy(),
            timeout=(10, 240)
        )

    def chat(
        self,
        model_id: str,
        messages: list,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        top_p: float = 1.0,
        stream: bool = False,
        tools: Optional[list] = None,
    ):
        """Send a chat completion request to OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config:
            raise ValueError(f"Unsupported model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        # Prepare serving mode
        if model_config.type == "dedicated" and model_config.endpoint:
            serving_mode = DedicatedServingMode(endpoint_id=model_config.endpoint)
        else:
            serving_mode = OnDemandServingMode(model_id=model_id)

        # Convert messages based on provider
        if model_config.provider == "cohere":
            chat_request = self._build_cohere_request(
                messages, temperature, max_tokens, top_p, tools, stream
            )
        elif model_config.provider in ["meta", "xai", "google", "openai"]:
            chat_request = self._build_generic_request(
                messages, temperature, max_tokens, top_p, tools, model_config.provider, stream
            )
        else:
            raise ValueError(f"Unsupported provider: {model_config.provider}")

        chat_details = ChatDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            chat_request=chat_request,
        )

        logger.debug(f"Sending chat request to OCI GenAI: {model_id}")
        response = client.chat(chat_details)
        return response

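    # A minimal usage sketch (hypothetical model ID and settings object; the
    # accepted IDs come from get_model_config's registry in .models):
    #
    #   client = OCIGenAIClient(settings)
    #   response = client.chat(
    #       model_id="cohere.command-r-plus",
    #       messages=[{"role": "user", "content": "Hello"}],
    #       temperature=0.2,
    #   )
    #   result = response.data  # oci.response.Response wrapping the chat result
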
    def _build_cohere_request(
        self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], stream: bool = False
    ) -> CohereChatRequest:
        """Build a Cohere chat request.

        Note: Cohere models only support text content, not multimodal.
        """
        # Convert messages to Cohere format
        chat_history = []
        message = None

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Extract text from multimodal content
            if isinstance(content, list):
                # Extract text parts only
                text_parts = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        text_parts.append(item.get("text", ""))
                content = " ".join(text_parts) if text_parts else ""

            if role == "system":
                # Cohere uses preamble for system messages
                continue
            elif role == "user":
                chat_history.append(
                    CohereMessage(role="USER", message=content)
                )
            elif role == "assistant":
                chat_history.append(
                    CohereMessage(role="CHATBOT", message=content)
                )
            elif role == "tool":
                # Handle tool responses if needed
                pass

        # The last user turn becomes the current message; earlier turns stay
        # in the history in their original order
        if chat_history and chat_history[-1].role == "USER":
            message = chat_history.pop().message

        # Get preamble from the first system message
        preamble_override = None
        for msg in messages:
            if msg["role"] == "system":
                preamble_override = msg["content"]
                break

        return CohereChatRequest(
            message=message,
            chat_history=chat_history if chat_history else None,
            preamble_override=preamble_override,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

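    # How the OpenAI-style roles map in the builder above (illustrative):
    #
    #   [{"role": "system", "content": "Be brief"},    -> preamble_override
    #    {"role": "user", "content": "Hi"},            -> chat_history (USER)
    #    {"role": "assistant", "content": "Hello!"},   -> chat_history (CHATBOT)
    #    {"role": "user", "content": "Help me"}]       -> message
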
    def _build_generic_request(
        self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], provider: str, stream: bool = False
    ) -> GenericChatRequest:
        """Build a Generic chat request for Llama and other models."""
        # Convert messages to Generic format
        generic_messages = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Handle multimodal content
            if isinstance(content, list):
                # Build OCI ChatContent objects from multimodal content
                oci_contents = build_multimodal_content(content)
            else:
                # Simple text content
                if MULTIMODAL_SUPPORTED:
                    oci_contents = [TextContent(text=content)]
                else:
                    # Fallback: use dict format
                    oci_contents = [{"type": "text", "text": content}]

            if role == "user":
                oci_role = "USER"
            elif role in ["assistant", "model"]:
                oci_role = "ASSISTANT"
            elif role == "system":
                oci_role = "SYSTEM"
            else:
                oci_role = role.upper()

            # Create Message with role and content objects
            logger.debug(f"Creating message with role: {oci_role}, provider: {provider}, original role: {role}")

            generic_messages.append(
                Message(
                    role=oci_role,
                    content=oci_contents
                )
            )

        return GenericChatRequest(
            messages=generic_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

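    # Shape of the request this builder produces (illustrative):
    #
    #   GenericChatRequest(messages=[
    #       Message(role="SYSTEM", content=[TextContent(text="Be brief")]),
    #       Message(role="USER", content=[TextContent(text="Hi")]),
    #   ], temperature=0.7, max_tokens=1024, top_p=1.0, is_stream=False)
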
    def embed(
        self,
        model_id: str,
        texts: list,
        truncate: str = "END",
    ):
        """Generate embeddings using OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config or model_config.type != "embedding":
            raise ValueError(f"Invalid embedding model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        serving_mode = OnDemandServingMode(
            serving_type="ON_DEMAND",
            model_id=model_id
        )

        embed_details = EmbedTextDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            inputs=texts,
            truncate=truncate,
            is_echo=False,
            input_type="SEARCH_QUERY",
        )

        logger.debug(f"Sending embed request to OCI GenAI: {model_id}")
        response = client.embed_text(embed_details)
        return response
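
# A minimal embedding usage sketch (hypothetical model ID; accepted IDs come
# from the model registry in .models):
#
#   client = OCIGenAIClient(settings)
#   response = client.embed(
#       model_id="cohere.embed-multilingual-v3.0",
#       texts=["hello world"],
#   )
#   vectors = response.data.embeddings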