"""AI Model Registry for managing multiple AI providers."""

import logging
from enum import Enum
from typing import Dict, List, Optional, Any, Type
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import asyncio

from ..services.ai_service import AIService, SummaryRequest, SummaryResult

logger = logging.getLogger(__name__)


class ModelProvider(Enum):
    """Supported AI model providers."""
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    DEEPSEEK = "deepseek"
    GOOGLE = "google"


class ModelCapability(Enum):
    """Model capabilities for matching."""
    SHORT_FORM = "short_form"          # < 5 min videos
    MEDIUM_FORM = "medium_form"        # 5-30 min videos
    LONG_FORM = "long_form"            # 30+ min videos
    TECHNICAL = "technical"            # Code, tutorials
    EDUCATIONAL = "educational"        # Lectures, courses
    CONVERSATIONAL = "conversational"  # Interviews, podcasts
    NEWS = "news"                      # News, current events
    CREATIVE = "creative"              # Music, art, entertainment


@dataclass
class ModelConfig:
    """Configuration for an AI model."""
    provider: ModelProvider
    model_name: str
    display_name: str
    max_tokens: int
    context_window: int

    # Cost per 1K tokens in USD
    input_cost_per_1k: float
    output_cost_per_1k: float

    # Performance characteristics
    average_latency_ms: float = 1000.0
    reliability_score: float = 0.95  # 0-1 scale
    quality_score: float = 0.90  # 0-1 scale

    # Capabilities
    capabilities: List[ModelCapability] = field(default_factory=list)
    supported_languages: List[str] = field(default_factory=lambda: ["en"])

    # Rate limits
    requests_per_minute: int = 60
    tokens_per_minute: int = 90000

    # Status
    is_available: bool = True
    last_error: Optional[str] = None
    last_error_time: Optional[datetime] = None

    def get_total_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Calculate total cost for token usage."""
        input_cost = (input_tokens / 1000) * self.input_cost_per_1k
        output_cost = (output_tokens / 1000) * self.output_cost_per_1k
        return input_cost + output_cost


@dataclass
class ModelMetrics:
    """Performance metrics for a model."""
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    total_latency_ms: float = 0.0
    last_used: Optional[datetime] = None

    @property
    def success_rate(self) -> float:
        """Calculate success rate."""
        if self.total_requests == 0:
            return 1.0
        return self.successful_requests / self.total_requests

    @property
    def average_latency(self) -> float:
        """Calculate average latency."""
        if self.successful_requests == 0:
            return 0.0
        return self.total_latency_ms / self.successful_requests


class ModelSelectionStrategy(Enum):
    """Strategy for selecting models."""
    COST_OPTIMIZED = "cost_optimized"        # Minimize cost
    QUALITY_OPTIMIZED = "quality_optimized"  # Maximize quality
    SPEED_OPTIMIZED = "speed_optimized"      # Minimize latency
    BALANCED = "balanced"                    # Balance all factors


@dataclass
class ModelSelectionContext:
    """Context for model selection."""
    content_length: int  # Characters
    content_type: Optional[ModelCapability] = None
    language: str = "en"
    strategy: ModelSelectionStrategy = ModelSelectionStrategy.BALANCED
    max_cost: Optional[float] = None  # Maximum cost in USD
    max_latency_ms: Optional[float] = None
    required_quality: float = 0.8  # Minimum quality score
    user_preference: Optional[ModelProvider] = None


class AIModelRegistry:
    """Registry for managing multiple AI model providers."""

    def __init__(self):
        """Initialize the model registry."""
        self.models: Dict[ModelProvider, ModelConfig] = {}
        self.services: Dict[ModelProvider, AIService] = {}
        self.metrics: Dict[ModelProvider, ModelMetrics] = {}
        self.fallback_chain: List[ModelProvider] = []

        # Initialize default model configurations
        self._initialize_default_models()

    def _initialize_default_models(self):
        """Initialize default model configurations."""
        # OpenAI GPT-4o-mini
        self.register_model(ModelConfig(
            provider=ModelProvider.OPENAI,
            model_name="gpt-4o-mini",
            display_name="GPT-4 Omni Mini",
            max_tokens=16384,
            context_window=128000,
            input_cost_per_1k=0.00015,
            output_cost_per_1k=0.0006,
            average_latency_ms=800,
            reliability_score=0.95,
            quality_score=0.88,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CONVERSATIONAL
            ],
            supported_languages=["en", "es", "fr", "de", "zh", "ja", "ko"],
            requests_per_minute=500,
            tokens_per_minute=200000
        ))

        # Anthropic Claude 3.5 Haiku
        self.register_model(ModelConfig(
            provider=ModelProvider.ANTHROPIC,
            model_name="claude-3-5-haiku-20241022",
            display_name="Claude 3.5 Haiku",
            max_tokens=8192,
            context_window=200000,
            input_cost_per_1k=0.00025,
            output_cost_per_1k=0.00125,
            average_latency_ms=500,
            reliability_score=0.98,
            quality_score=0.92,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.LONG_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CREATIVE
            ],
            supported_languages=["en", "es", "fr", "de", "pt", "it", "nl"],
            requests_per_minute=100,
            tokens_per_minute=100000
        ))

        # DeepSeek V2
        self.register_model(ModelConfig(
            provider=ModelProvider.DEEPSEEK,
            model_name="deepseek-chat",
            display_name="DeepSeek V2",
            max_tokens=4096,
            context_window=32000,
            input_cost_per_1k=0.00014,
            output_cost_per_1k=0.00028,
            average_latency_ms=1200,
            reliability_score=0.90,
            quality_score=0.85,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL
            ],
            supported_languages=["en", "zh"],
            requests_per_minute=60,
            tokens_per_minute=90000
        ))

        # Google Gemini 1.5 Pro - 2M-token context window
        self.register_model(ModelConfig(
            provider=ModelProvider.GOOGLE,
            model_name="gemini-1.5-pro",
            display_name="Gemini 1.5 Pro (2M Context)",
            max_tokens=8192,
            context_window=2000000,  # 2M-token context window
            input_cost_per_1k=0.007,  # $7 per 1M tokens - competitive for this context size
            output_cost_per_1k=0.021,  # $21 per 1M tokens
            average_latency_ms=2000,  # Slightly higher due to large-context processing
            reliability_score=0.96,
            quality_score=0.94,  # Excellent quality with full context
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.LONG_FORM,  # Excels at long-form content
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CONVERSATIONAL,
                ModelCapability.NEWS,
                ModelCapability.CREATIVE
            ],
            supported_languages=["en", "es", "fr", "de", "pt", "it", "nl", "ja", "ko", "zh", "hi"],
            requests_per_minute=60,
            tokens_per_minute=32000  # Large context means fewer, higher-quality requests
        ))

        # Set default fallback chain - Gemini first for long content due to its massive context window
        self.fallback_chain = [
            ModelProvider.GOOGLE,     # Best for long-form content
            ModelProvider.ANTHROPIC,  # Great quality fallback
            ModelProvider.OPENAI,     # Reliable alternative
            ModelProvider.DEEPSEEK    # Cost-effective option
        ]

    def register_model(self, config: ModelConfig):
        """Register a model configuration."""
        self.models[config.provider] = config
        self.metrics[config.provider] = ModelMetrics()
        logger.info(f"Registered model: {config.display_name} ({config.provider.value})")

    def register_service(self, provider: ModelProvider, service: AIService):
        """Register an AI service implementation."""
        if provider not in self.models:
            raise ValueError(f"Model {provider} not registered")
        self.services[provider] = service
        logger.info(f"Registered service for {provider.value}")

    def get_model_config(self, provider: ModelProvider) -> Optional[ModelConfig]:
        """Get model configuration."""
        return self.models.get(provider)

    def get_service(self, provider: ModelProvider) -> Optional[AIService]:
        """Get AI service for a provider."""
        return self.services.get(provider)

    def select_model(self, context: ModelSelectionContext) -> Optional[ModelProvider]:
        """Select the best model based on context.

        Args:
            context: Selection context with requirements

        Returns:
            Selected model provider or None if no suitable model
        """
        available_models = self._get_available_models(context)

        if not available_models:
            logger.warning("No available models match the context requirements")
            return None

        # Apply user preference if specified
        if context.user_preference and context.user_preference in available_models:
            return context.user_preference

        # Score and rank models
        scored_models = []
        for provider in available_models:
            score = self._score_model(provider, context)
            scored_models.append((provider, score))

        # Sort by score (higher is better)
        scored_models.sort(key=lambda x: x[1], reverse=True)

        if scored_models:
            selected = scored_models[0][0]
            logger.info(f"Selected model: {selected.value} (score: {scored_models[0][1]:.2f})")
            return selected

        return None

    def _get_available_models(self, context: ModelSelectionContext) -> List[ModelProvider]:
        """Get available models that meet context requirements."""
        available = []

        for provider, config in self.models.items():
            # Check availability
            if not config.is_available:
                continue

            # Check language support
            if context.language not in config.supported_languages:
                continue

            # Check quality requirement
            if config.quality_score < context.required_quality:
                continue

            # Check cost constraint
            if context.max_cost:
                estimated_tokens = context.content_length / 4  # Rough estimate
                estimated_cost = config.get_total_cost(estimated_tokens, estimated_tokens / 2)
                if estimated_cost > context.max_cost:
                    continue

            # Check latency constraint
            if context.max_latency_ms and config.average_latency_ms > context.max_latency_ms:
                continue

            # Check capabilities match
            if context.content_type and context.content_type not in config.capabilities:
                continue

            available.append(provider)

        return available

    def _score_model(self, provider: ModelProvider, context: ModelSelectionContext) -> float:
        """Score a model based on selection strategy.

        Args:
            provider: Model provider to score
            context: Selection context

        Returns:
            Score from 0-100
        """
        config = self.models[provider]
        metrics = self.metrics[provider]

        # Base scores (0-1 scale); clamped at 0 so expensive or slow models
        # cannot drag the weighted total negative
        cost_score = max(0.0, 1.0 - (config.input_cost_per_1k / 0.001))  # Normalize to $0.001 baseline
        quality_score = config.quality_score
        speed_score = max(0.0, 1.0 - (config.average_latency_ms / 5000))  # Normalize to 5s baseline
        reliability_score = config.reliability_score * metrics.success_rate

        # Apply strategy weights
        if context.strategy == ModelSelectionStrategy.COST_OPTIMIZED:
            weights = {"cost": 0.6, "quality": 0.2, "speed": 0.1, "reliability": 0.1}
        elif context.strategy == ModelSelectionStrategy.QUALITY_OPTIMIZED:
            weights = {"cost": 0.1, "quality": 0.6, "speed": 0.1, "reliability": 0.2}
        elif context.strategy == ModelSelectionStrategy.SPEED_OPTIMIZED:
            weights = {"cost": 0.1, "quality": 0.2, "speed": 0.5, "reliability": 0.2}
        else:  # BALANCED
            weights = {"cost": 0.25, "quality": 0.35, "speed": 0.2, "reliability": 0.2}

        # Calculate weighted score
        score = (
            cost_score * weights["cost"] +
            quality_score * weights["quality"] +
            speed_score * weights["speed"] +
            reliability_score * weights["reliability"]
        ) * 100

        # Boost score if the model explicitly supports the requested content type
        if context.content_type and context.content_type in config.capabilities:
            score += 10

        return min(score, 100)  # Cap at 100

    async def execute_with_fallback(
        self,
        request: SummaryRequest,
        context: Optional[ModelSelectionContext] = None,
        max_retries: int = 3
    ) -> tuple[SummaryResult, ModelProvider]:
        """Execute request with automatic fallback.

        Args:
            request: Summary request
            context: Selection context
            max_retries: Maximum retry attempts

        Returns:
            Tuple of (result, provider used)
        """
        if not context:
            # Create default context from request
            context = ModelSelectionContext(
                content_length=len(request.transcript),
                strategy=ModelSelectionStrategy.BALANCED
            )

        # Get fallback chain
        primary = self.select_model(context)
        if not primary:
            raise ValueError("No suitable model available")

        # Build fallback list
        fallback_providers = [primary]
        for provider in self.fallback_chain:
            if provider != primary and provider in self.services:
                fallback_providers.append(provider)

        last_error = None
        for provider in fallback_providers:
            service = self.services.get(provider)
            if not service:
                continue

            config = self.models[provider]
            if not config.is_available:
                continue

            for attempt in range(max_retries):
                try:
                    # Execute request
                    start_time = datetime.utcnow()
                    result = await service.generate_summary(request)
                    latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000

                    # Update metrics
                    await self._update_metrics(
                        provider,
                        success=True,
                        latency_ms=latency_ms,
                        input_tokens=result.usage.input_tokens,
                        output_tokens=result.usage.output_tokens
                    )

                    # Mark as available
                    config.is_available = True

                    logger.info(f"Successfully used {provider.value} (attempt {attempt + 1})")
                    return result, provider

                except Exception as e:
                    last_error = e
                    logger.warning(f"Failed with {provider.value} (attempt {attempt + 1}): {e}")

                    # Update failure metrics
                    await self._update_metrics(provider, success=False)

                    # Mark as unavailable after multiple failures
                    if attempt == max_retries - 1:
                        config.is_available = False
                        config.last_error = str(e)
                        config.last_error_time = datetime.utcnow()

                    # Wait before retry
                    if attempt < max_retries - 1:
                        await asyncio.sleep(2 ** attempt)  # Exponential backoff

        # All attempts failed
        raise Exception(f"All models failed. Last error: {last_error}")

    async def _update_metrics(
        self,
        provider: ModelProvider,
        success: bool,
        latency_ms: float = 0,
        input_tokens: int = 0,
        output_tokens: int = 0
    ):
        """Update model metrics.

        Args:
            provider: Model provider
            success: Whether request succeeded
            latency_ms: Request latency
            input_tokens: Input token count
            output_tokens: Output token count
        """
        metrics = self.metrics[provider]
        config = self.models[provider]

        metrics.total_requests += 1
        if success:
            metrics.successful_requests += 1
            metrics.total_latency_ms += latency_ms
            metrics.total_input_tokens += input_tokens
            metrics.total_output_tokens += output_tokens

            # Calculate cost
            cost = config.get_total_cost(input_tokens, output_tokens)
            metrics.total_cost += cost
        else:
            metrics.failed_requests += 1

        metrics.last_used = datetime.utcnow()

    def get_metrics(self, provider: Optional[ModelProvider] = None) -> Dict[str, Any]:
        """Get metrics for models.

        Args:
            provider: Specific provider or None for all

        Returns:
            Metrics dictionary
        """
        if provider:
            metrics = self.metrics.get(provider)
            if not metrics:
                return {}

            return {
                "provider": provider.value,
                "total_requests": metrics.total_requests,
                "success_rate": metrics.success_rate,
                "average_latency_ms": metrics.average_latency,
                "total_cost_usd": metrics.total_cost,
                "total_tokens": metrics.total_input_tokens + metrics.total_output_tokens
            }

        # Return metrics for all providers
        all_metrics = {}
        for prov, metrics in self.metrics.items():
            all_metrics[prov.value] = {
                "total_requests": metrics.total_requests,
                "success_rate": metrics.success_rate,
                "average_latency_ms": metrics.average_latency,
                "total_cost_usd": metrics.total_cost,
                "total_tokens": metrics.total_input_tokens + metrics.total_output_tokens
            }

        return all_metrics

    def get_cost_comparison(self, token_count: int) -> Dict[str, Dict[str, Any]]:
        """Get cost comparison across models.

        Args:
            token_count: Estimated token count

        Returns:
            Cost comparison dictionary keyed by provider
        """
        comparison = {}
        for provider, config in self.models.items():
            # Estimate 1:1 input/output ratio
            cost = config.get_total_cost(token_count, token_count)
            comparison[provider.value] = {
                "cost_usd": cost,
                "model": config.display_name,
                "quality_score": config.quality_score,
                "latency_ms": config.average_latency_ms
            }

        return comparison

    def reset_availability(self, provider: Optional[ModelProvider] = None):
        """Reset model availability status.

        Args:
            provider: Specific provider or None for all
        """
        if provider:
            if provider in self.models:
                self.models[provider].is_available = True
                self.models[provider].last_error = None
                self.models[provider].last_error_time = None
                logger.info(f"Reset availability for {provider.value}")
        else:
            for config in self.models.values():
                config.is_available = True
                config.last_error = None
                config.last_error_time = None
            logger.info("Reset availability for all models")