Initial commit
All checks were successful
Build and Push OCI GenAI Gateway Docker Image / docker-build-push (push) Successful in 32m3s
src/core/oci_client.py · 361 lines · Normal file
@@ -0,0 +1,361 @@
"""
OCI Generative AI client wrapper.
"""
import os
import logging
from typing import Optional, AsyncIterator
import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import (
    ChatDetails,
    CohereChatRequest,
    GenericChatRequest,
    OnDemandServingMode,
    DedicatedServingMode,
    CohereMessage,
    Message,
    TextContent,
    EmbedTextDetails,
)

# Try to import multimodal content types
try:
    from oci.generative_ai_inference.models import (
        ImageContent,
        ImageUrl,
        AudioContent,
        AudioUrl,
        VideoContent,
        VideoUrl,
    )
    MULTIMODAL_SUPPORTED = True
    logger_init = logging.getLogger(__name__)
    logger_init.info("OCI SDK multimodal content types available")
except ImportError:
    MULTIMODAL_SUPPORTED = False
    logger_init = logging.getLogger(__name__)
    logger_init.warning("OCI SDK does not support multimodal content types, using dict format as fallback")

from .config import Settings
from .models import get_model_config, ModelConfig

logger = logging.getLogger(__name__)


def build_multimodal_content(content_list: list) -> list:
    """
    Build OCI ChatContent object array from adapted content list.

    Supports both HTTP URLs and Base64 data URIs (data:image/jpeg;base64,...).

    Args:
        content_list: List of content items from request adapter

    Returns:
        List of OCI ChatContent objects or dicts (fallback)
    """
    if not MULTIMODAL_SUPPORTED:
        # Fallback: return dict format, OCI SDK might auto-convert
        return content_list

    oci_contents = []
    for item in content_list:
        if not isinstance(item, dict):
            continue

        item_type = item.get("type")

        if item_type == "text":
            oci_contents.append(TextContent(text=item.get("text", "")))

        elif item_type == "image_url":
            image_data = item.get("image_url", {})
            if "url" in image_data:
                # ImageUrl accepts both HTTP URLs and data URIs (data:image/jpeg;base64,...)
                img_url = ImageUrl(url=image_data["url"])
                # Optional: support 'detail' parameter if provided
                if "detail" in image_data:
                    img_url.detail = image_data["detail"]
                oci_contents.append(ImageContent(image_url=img_url, type="IMAGE"))

        elif item_type == "audio":
            audio_data = item.get("audio_url", {})
            if "url" in audio_data:
                # AudioUrl accepts both HTTP URLs and data URIs (data:audio/wav;base64,...)
                audio_url = AudioUrl(url=audio_data["url"])
                oci_contents.append(AudioContent(audio_url=audio_url, type="AUDIO"))

        elif item_type == "video":
            video_data = item.get("video_url", {})
            if "url" in video_data:
                # VideoUrl accepts both HTTP URLs and data URIs (data:video/mp4;base64,...)
                video_url = VideoUrl(url=video_data["url"])
                oci_contents.append(VideoContent(video_url=video_url, type="VIDEO"))

    return oci_contents if oci_contents else [TextContent(text="")]

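# Illustrative example of the expected input shape (an assumption: the
# request adapter emits OpenAI-style content lists):
#
#   content = [
#       {"type": "text", "text": "Describe this picture"},
#       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
#   ]
#   build_multimodal_content(content)
#   # -> [TextContent(text="Describe this picture"),
#   #     ImageContent(image_url=ImageUrl(url="data:image/jpeg;base64,..."), type="IMAGE")]
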
class OCIGenAIClient:
    """Wrapper for OCI Generative AI client."""

    def __init__(self, settings: Settings, profile: Optional[str] = None):
        """
        Initialize the OCI GenAI client.

        Args:
            settings: Application settings
            profile: Optional OCI config profile name. If not provided, the
                first profile from settings is used.
        """
        self.settings = settings
        self.profile = profile or settings.get_profiles()[0]
        self._client: Optional[GenerativeAiInferenceClient] = None
        # oci.config.from_file() returns a plain dict, so type it as such
        self._config: Optional[dict] = None
        self._region: Optional[str] = None
        self._compartment_id: Optional[str] = None

    def _get_config(self) -> dict:
        """Get OCI configuration."""
        if self._config is None:
            if self.settings.oci_auth_type == "instance_principal":
                signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
                self._config = {"signer": signer}
                # The signer dict has no "region"/"tenancy" keys to read below;
                # instance-principal signers expose these attributes directly
                self._region = getattr(signer, "region", None)
                self._compartment_id = getattr(signer, "tenancy_id", None)
            else:
                config_path = os.path.expanduser(self.settings.oci_config_file)
                self._config = oci.config.from_file(
                    file_location=config_path,
                    profile_name=self.profile
                )

        # Read region and compartment_id from the config
        if self._region is None:
            self._region = self._config.get("region")
        if self._compartment_id is None:
            self._compartment_id = self._config.get("tenancy")

        return self._config

    @property
    def region(self) -> Optional[str]:
        """Region of the current configuration."""
        if self._region is None and self._config is None:
            self._get_config()
        return self._region

    @property
    def compartment_id(self) -> Optional[str]:
        """Compartment ID of the current configuration."""
        if self._compartment_id is None and self._config is None:
            self._get_config()
        return self._compartment_id

    def _get_client(self) -> GenerativeAiInferenceClient:
        """Get or create OCI Generative AI Inference client with correct endpoint."""
        config = self._get_config()

        # Use INFERENCE endpoint (not management endpoint)
        # Official format: https://inference.generativeai.{region}.oci.oraclecloud.com
        inference_endpoint = f"https://inference.generativeai.{self.region}.oci.oraclecloud.com"

        if isinstance(config, dict) and "signer" in config:
            # For instance principal
            client = GenerativeAiInferenceClient(
                config={},
                service_endpoint=inference_endpoint,
                **config
            )
            return client

        # For API key authentication
        client = GenerativeAiInferenceClient(
            config=config,
            service_endpoint=inference_endpoint,
            retry_strategy=oci.retry.NoneRetryStrategy(),
            timeout=(10, 240)
        )

        return client

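    # Example of the derived endpoint (illustrative; the region value is an
    # assumption): a profile with region "us-chicago-1" yields
    #   https://inference.generativeai.us-chicago-1.oci.oraclecloud.com
    # The timeout tuple (10, 240) is (connect, read) seconds.
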
    def chat(
        self,
        model_id: str,
        messages: list,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        top_p: float = 1.0,
        stream: bool = False,
        tools: Optional[list] = None,
    ):
        """Send a chat completion request to OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config:
            raise ValueError(f"Unsupported model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        # Prepare serving mode
        if model_config.type == "dedicated" and model_config.endpoint:
            serving_mode = DedicatedServingMode(endpoint_id=model_config.endpoint)
        else:
            serving_mode = OnDemandServingMode(model_id=model_id)

        # Convert messages based on provider
        if model_config.provider == "cohere":
            chat_request = self._build_cohere_request(
                messages, temperature, max_tokens, top_p, tools, stream
            )
        elif model_config.provider in ["meta", "xai", "google", "openai"]:
            chat_request = self._build_generic_request(
                messages, temperature, max_tokens, top_p, tools, model_config.provider, stream
            )
        else:
            raise ValueError(f"Unsupported provider: {model_config.provider}")

        chat_details = ChatDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            chat_request=chat_request,
        )

        logger.debug(f"Sending chat request to OCI GenAI: {model_id}")
        response = client.chat(chat_details)
        return response

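    # Minimal usage sketch (illustrative; the settings object and model ID
    # are assumptions):
    #
    #   client = OCIGenAIClient(settings)
    #   response = client.chat(
    #       model_id="cohere.command-r-plus",
    #       messages=[{"role": "user", "content": "Hello"}],
    #   )
    #   result = response.data  # raw OCI ChatResult
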
    def _build_cohere_request(
        self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], stream: bool = False
    ) -> CohereChatRequest:
        """Build Cohere chat request.

        Note: Cohere models only support text content, not multimodal.
        """
        # Convert messages to Cohere format
        chat_history = []
        message = None

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Extract text from multimodal content
            if isinstance(content, list):
                # Extract text parts only
                text_parts = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        text_parts.append(item.get("text", ""))
                content = " ".join(text_parts) if text_parts else ""

            if role == "system":
                # Cohere uses preamble for system messages
                continue
            elif role == "user":
                # Keep earlier user turns in the history; the final user
                # message becomes the Cohere `message` field
                if message is not None:
                    chat_history.append(CohereMessage(role="USER", message=message))
                message = content
            elif role == "assistant":
                chat_history.append(
                    CohereMessage(role="CHATBOT", message=content)
                )
            elif role == "tool":
                # Handle tool responses if needed
                pass

        # Get preamble from system messages
        preamble_override = None
        for msg in messages:
            if msg["role"] == "system":
                preamble_override = msg["content"]
                break

        return CohereChatRequest(
            message=message,
            chat_history=chat_history if chat_history else None,
            preamble_override=preamble_override,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

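    # Mapping sketch (illustrative): an OpenAI-style conversation
    #   [{"role": "system", "content": "Be brief"},
    #    {"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello!"},
    #    {"role": "user", "content": "Bye"}]
    # becomes roughly
    #   preamble_override="Be brief"
    #   chat_history=[CohereMessage(role="USER", message="Hi"),
    #                 CohereMessage(role="CHATBOT", message="Hello!")]
    #   message="Bye"
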
    def _build_generic_request(
        self, messages: list, temperature: float, max_tokens: int, top_p: float, tools: Optional[list], provider: str, stream: bool = False
    ) -> GenericChatRequest:
        """Build Generic chat request for Llama and other models."""
        # Convert messages to Generic format
        generic_messages = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Handle multimodal content
            if isinstance(content, list):
                # Build OCI ChatContent objects from multimodal content
                oci_contents = build_multimodal_content(content)
            else:
                # Simple text content
                if MULTIMODAL_SUPPORTED:
                    oci_contents = [TextContent(text=content)]
                else:
                    # Fallback: use dict format
                    oci_contents = [{"type": "text", "text": content}]

            if role == "user":
                oci_role = "USER"
            elif role in ["assistant", "model"]:
                oci_role = "ASSISTANT"
            elif role == "system":
                oci_role = "SYSTEM"
            else:
                oci_role = role.upper()

            # Create Message with role and content objects
            logger.debug(f"Creating message with role: {oci_role}, provider: {provider}, original role: {role}")

            generic_messages.append(
                Message(
                    role=oci_role,
                    content=oci_contents
                )
            )

        return GenericChatRequest(
            messages=generic_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            is_stream=stream,
        )

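    # Note (illustrative): an OpenAI "assistant" turn maps to
    # Message(role="ASSISTANT", content=[TextContent(text=...)]). The `tools`
    # argument is accepted for interface symmetry but is not yet forwarded
    # into the GenericChatRequest.
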
    def embed(
        self,
        model_id: str,
        texts: list,
        truncate: str = "END",
    ):
        """Generate embeddings using OCI GenAI."""
        model_config = get_model_config(model_id)
        if not model_config or model_config.type != "embedding":
            raise ValueError(f"Invalid embedding model: {model_id}")

        if not self.compartment_id:
            raise ValueError("Compartment ID is required")

        client = self._get_client()

        serving_mode = OnDemandServingMode(
            serving_type="ON_DEMAND",
            model_id=model_id
        )

        embed_details = EmbedTextDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            inputs=texts,
            truncate=truncate,
            is_echo=False,
            input_type="SEARCH_QUERY",
        )

        logger.debug(f"Sending embed request to OCI GenAI: {model_id}")
        response = client.embed_text(embed_details)
        return response
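
# Minimal embedding usage sketch (illustrative; the model ID is an assumption):
#
#   client = OCIGenAIClient(settings)
#   response = client.embed(
#       model_id="cohere.embed-multilingual-v3.0",
#       texts=["hello world"],
#   )
#   vectors = response.data.embeddings
#
# Note: input_type is hard-coded to "SEARCH_QUERY" above; document-side
# embeddings would typically use "SEARCH_DOCUMENT".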