"""AI Model Registry for managing multiple AI providers.""" import logging from enum import Enum from typing import Dict, List, Optional, Any, Type from dataclasses import dataclass, field from datetime import datetime, timedelta import asyncio from ..services.ai_service import AIService, SummaryRequest, SummaryResult logger = logging.getLogger(__name__) class ModelProvider(Enum): """Supported AI model providers.""" OPENAI = "openai" ANTHROPIC = "anthropic" DEEPSEEK = "deepseek" GOOGLE = "google" class ModelCapability(Enum): """Model capabilities for matching.""" SHORT_FORM = "short_form" # < 5 min videos MEDIUM_FORM = "medium_form" # 5-30 min videos LONG_FORM = "long_form" # 30+ min videos TECHNICAL = "technical" # Code, tutorials EDUCATIONAL = "educational" # Lectures, courses CONVERSATIONAL = "conversational" # Interviews, podcasts NEWS = "news" # News, current events CREATIVE = "creative" # Music, art, entertainment @dataclass class ModelConfig: """Configuration for an AI model.""" provider: ModelProvider model_name: str display_name: str max_tokens: int context_window: int # Cost per 1K tokens in USD input_cost_per_1k: float output_cost_per_1k: float # Performance characteristics average_latency_ms: float = 1000.0 reliability_score: float = 0.95 # 0-1 scale quality_score: float = 0.90 # 0-1 scale # Capabilities capabilities: List[ModelCapability] = field(default_factory=list) supported_languages: List[str] = field(default_factory=lambda: ["en"]) # Rate limits requests_per_minute: int = 60 tokens_per_minute: int = 90000 # Status is_available: bool = True last_error: Optional[str] = None last_error_time: Optional[datetime] = None def get_total_cost(self, input_tokens: int, output_tokens: int) -> float: """Calculate total cost for token usage.""" input_cost = (input_tokens / 1000) * self.input_cost_per_1k output_cost = (output_tokens / 1000) * self.output_cost_per_1k return input_cost + output_cost @dataclass class ModelMetrics: """Performance metrics for a model.""" total_requests: int = 0 successful_requests: int = 0 failed_requests: int = 0 total_input_tokens: int = 0 total_output_tokens: int = 0 total_cost: float = 0.0 total_latency_ms: float = 0.0 last_used: Optional[datetime] = None @property def success_rate(self) -> float: """Calculate success rate.""" if self.total_requests == 0: return 1.0 return self.successful_requests / self.total_requests @property def average_latency(self) -> float: """Calculate average latency.""" if self.successful_requests == 0: return 0.0 return self.total_latency_ms / self.successful_requests class ModelSelectionStrategy(Enum): """Strategy for selecting models.""" COST_OPTIMIZED = "cost_optimized" # Minimize cost QUALITY_OPTIMIZED = "quality_optimized" # Maximize quality SPEED_OPTIMIZED = "speed_optimized" # Minimize latency BALANCED = "balanced" # Balance all factors @dataclass class ModelSelectionContext: """Context for model selection.""" content_length: int # Characters content_type: Optional[ModelCapability] = None language: str = "en" strategy: ModelSelectionStrategy = ModelSelectionStrategy.BALANCED max_cost: Optional[float] = None # Maximum cost in USD max_latency_ms: Optional[float] = None required_quality: float = 0.8 # Minimum quality score user_preference: Optional[ModelProvider] = None class AIModelRegistry: """Registry for managing multiple AI model providers.""" def __init__(self): """Initialize the model registry.""" self.models: Dict[ModelProvider, ModelConfig] = {} self.services: Dict[ModelProvider, AIService] = {} self.metrics: 

@dataclass
class ModelMetrics:
    """Performance metrics for a model."""

    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    total_latency_ms: float = 0.0
    last_used: Optional[datetime] = None

    @property
    def success_rate(self) -> float:
        """Calculate success rate."""
        if self.total_requests == 0:
            return 1.0
        return self.successful_requests / self.total_requests

    @property
    def average_latency(self) -> float:
        """Calculate average latency."""
        if self.successful_requests == 0:
            return 0.0
        return self.total_latency_ms / self.successful_requests


class ModelSelectionStrategy(Enum):
    """Strategy for selecting models."""

    COST_OPTIMIZED = "cost_optimized"        # Minimize cost
    QUALITY_OPTIMIZED = "quality_optimized"  # Maximize quality
    SPEED_OPTIMIZED = "speed_optimized"      # Minimize latency
    BALANCED = "balanced"                    # Balance all factors


@dataclass
class ModelSelectionContext:
    """Context for model selection."""

    content_length: int  # Characters
    content_type: Optional[ModelCapability] = None
    language: str = "en"
    strategy: ModelSelectionStrategy = ModelSelectionStrategy.BALANCED
    max_cost: Optional[float] = None  # Maximum cost in USD
    max_latency_ms: Optional[float] = None
    required_quality: float = 0.8  # Minimum quality score
    user_preference: Optional[ModelProvider] = None
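
# Illustrative context (hypothetical values): a roughly 30-minute technical video
# transcript, cost-capped at one cent, favoring cheaper models:
#
#     context = ModelSelectionContext(
#         content_length=45000,
#         content_type=ModelCapability.TECHNICAL,
#         strategy=ModelSelectionStrategy.COST_OPTIMIZED,
#         max_cost=0.01,
#     )
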

class AIModelRegistry:
    """Registry for managing multiple AI model providers."""

    def __init__(self):
        """Initialize the model registry."""
        self.models: Dict[ModelProvider, ModelConfig] = {}
        self.services: Dict[ModelProvider, AIService] = {}
        self.metrics: Dict[ModelProvider, ModelMetrics] = {}
        self.fallback_chain: List[ModelProvider] = []

        # Initialize default model configurations
        self._initialize_default_models()

    def _initialize_default_models(self):
        """Initialize default model configurations."""
        # OpenAI GPT-4o-mini
        self.register_model(ModelConfig(
            provider=ModelProvider.OPENAI,
            model_name="gpt-4o-mini",
            display_name="GPT-4 Omni Mini",
            max_tokens=16384,
            context_window=128000,
            input_cost_per_1k=0.00015,
            output_cost_per_1k=0.0006,
            average_latency_ms=800,
            reliability_score=0.95,
            quality_score=0.88,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CONVERSATIONAL,
            ],
            supported_languages=["en", "es", "fr", "de", "zh", "ja", "ko"],
            requests_per_minute=500,
            tokens_per_minute=200000,
        ))

        # Anthropic Claude 3.5 Haiku
        self.register_model(ModelConfig(
            provider=ModelProvider.ANTHROPIC,
            model_name="claude-3-5-haiku-20241022",
            display_name="Claude 3.5 Haiku",
            max_tokens=8192,
            context_window=200000,
            input_cost_per_1k=0.00025,
            output_cost_per_1k=0.00125,
            average_latency_ms=500,
            reliability_score=0.98,
            quality_score=0.92,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.LONG_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CREATIVE,
            ],
            supported_languages=["en", "es", "fr", "de", "pt", "it", "nl"],
            requests_per_minute=100,
            tokens_per_minute=100000,
        ))

        # DeepSeek V2
        self.register_model(ModelConfig(
            provider=ModelProvider.DEEPSEEK,
            model_name="deepseek-chat",
            display_name="DeepSeek V2",
            max_tokens=4096,
            context_window=32000,
            input_cost_per_1k=0.00014,
            output_cost_per_1k=0.00028,
            average_latency_ms=1200,
            reliability_score=0.90,
            quality_score=0.85,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
            ],
            supported_languages=["en", "zh"],
            requests_per_minute=60,
            tokens_per_minute=90000,
        ))

        # Google Gemini 1.5 Pro - very large (2M token) context window
        self.register_model(ModelConfig(
            provider=ModelProvider.GOOGLE,
            model_name="gemini-1.5-pro",
            display_name="Gemini 1.5 Pro (2M Context)",
            max_tokens=8192,
            context_window=2000000,    # 2 million token context
            input_cost_per_1k=0.007,   # $7 per 1M tokens - competitive for this context size
            output_cost_per_1k=0.021,  # $21 per 1M tokens
            average_latency_ms=2000,   # Slightly higher due to large-context processing
            reliability_score=0.96,
            quality_score=0.94,  # Excellent quality with full context
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.LONG_FORM,  # Excels at long-form content
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CONVERSATIONAL,
                ModelCapability.NEWS,
                ModelCapability.CREATIVE,
            ],
            supported_languages=["en", "es", "fr", "de", "pt", "it", "nl", "ja", "ko", "zh", "hi"],
            requests_per_minute=60,
            tokens_per_minute=32000,  # Large context means fewer but higher-quality requests
        ))

        # Default fallback chain - Gemini first, since its context size suits long content
        self.fallback_chain = [
            ModelProvider.GOOGLE,     # Best for long-form content
            ModelProvider.ANTHROPIC,  # High-quality fallback
            ModelProvider.OPENAI,     # Reliable alternative
            ModelProvider.DEEPSEEK,   # Cost-effective option
        ]

    def register_model(self, config: ModelConfig):
        """Register a model configuration."""
        self.models[config.provider] = config
        self.metrics[config.provider] = ModelMetrics()
        logger.info(f"Registered model: {config.display_name} ({config.provider.value})")

    def register_service(self, provider: ModelProvider, service: AIService):
        """Register an AI service implementation."""
        if provider not in self.models:
            raise ValueError(f"Model {provider} not registered")
        self.services[provider] = service
        logger.info(f"Registered service for {provider.value}")

    def get_model_config(self, provider: ModelProvider) -> Optional[ModelConfig]:
        """Get model configuration."""
        return self.models.get(provider)

    def get_service(self, provider: ModelProvider) -> Optional[AIService]:
        """Get AI service for a provider."""
        return self.services.get(provider)

    def select_model(self, context: ModelSelectionContext) -> Optional[ModelProvider]:
        """Select the best model based on context.

        Args:
            context: Selection context with requirements

        Returns:
            Selected model provider, or None if no suitable model is available
        """
        available_models = self._get_available_models(context)
        if not available_models:
            logger.warning("No available models match the context requirements")
            return None

        # Apply user preference if specified
        if context.user_preference and context.user_preference in available_models:
            return context.user_preference

        # Score and rank models
        scored_models = []
        for provider in available_models:
            score = self._score_model(provider, context)
            scored_models.append((provider, score))

        # Sort by score (higher is better)
        scored_models.sort(key=lambda x: x[1], reverse=True)

        if scored_models:
            selected = scored_models[0][0]
            logger.info(f"Selected model: {selected.value} (score: {scored_models[0][1]:.2f})")
            return selected

        return None

    def _get_available_models(self, context: ModelSelectionContext) -> List[ModelProvider]:
        """Get available models that meet context requirements."""
        available = []

        for provider, config in self.models.items():
            # Check availability
            if not config.is_available:
                continue

            # Check language support
            if context.language not in config.supported_languages:
                continue

            # Check quality requirement
            if config.quality_score < context.required_quality:
                continue

            # Check cost constraint
            if context.max_cost is not None:
                estimated_tokens = context.content_length / 4  # Rough characters-to-tokens estimate
                estimated_cost = config.get_total_cost(estimated_tokens, estimated_tokens / 2)
                if estimated_cost > context.max_cost:
                    continue

            # Check latency constraint
            if context.max_latency_ms and config.average_latency_ms > context.max_latency_ms:
                continue

            # Check capabilities match
            if context.content_type and context.content_type not in config.capabilities:
                continue

            available.append(provider)

        return available

    def _score_model(self, provider: ModelProvider, context: ModelSelectionContext) -> float:
        """Score a model based on the selection strategy.

        Args:
            provider: Model provider to score
            context: Selection context

        Returns:
            Score from 0-100
        """
        config = self.models[provider]
        metrics = self.metrics[provider]

        # Base scores (0-1 scale; clamped so expensive or slow models do not go negative)
        cost_score = max(0.0, 1.0 - (config.input_cost_per_1k / 0.001))   # Normalize to a $0.001/1K baseline
        quality_score = config.quality_score
        speed_score = max(0.0, 1.0 - (config.average_latency_ms / 5000))  # Normalize to a 5s baseline
        reliability_score = config.reliability_score * metrics.success_rate

        # Apply strategy weights
        if context.strategy == ModelSelectionStrategy.COST_OPTIMIZED:
            weights = {"cost": 0.6, "quality": 0.2, "speed": 0.1, "reliability": 0.1}
        elif context.strategy == ModelSelectionStrategy.QUALITY_OPTIMIZED:
            weights = {"cost": 0.1, "quality": 0.6, "speed": 0.1, "reliability": 0.2}
        elif context.strategy == ModelSelectionStrategy.SPEED_OPTIMIZED:
            weights = {"cost": 0.1, "quality": 0.2, "speed": 0.5, "reliability": 0.2}
        else:  # BALANCED
            weights = {"cost": 0.25, "quality": 0.35, "speed": 0.2, "reliability": 0.2}

        # Calculate weighted score
        score = (
            cost_score * weights["cost"] +
            quality_score * weights["quality"] +
            speed_score * weights["speed"] +
            reliability_score * weights["reliability"]
        ) * 100

        # Boost the score if the model explicitly supports the requested content type
        if context.content_type in config.capabilities:
            score += 10

        return min(score, 100)  # Cap at 100
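
    # Worked example of the scoring above (illustrative; uses the default gpt-4o-mini
    # numbers, a fresh success rate of 1.0, the BALANCED strategy, and a matching content type):
    #   cost_score  = 1 - 0.00015 / 0.001 = 0.85
    #   speed_score = 1 - 800 / 5000      = 0.84
    #   score = (0.85*0.25 + 0.88*0.35 + 0.84*0.20 + 0.95*0.20) * 100 + 10 ≈ 97.9
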

    async def execute_with_fallback(
        self,
        request: SummaryRequest,
        context: Optional[ModelSelectionContext] = None,
        max_retries: int = 3
    ) -> tuple[SummaryResult, ModelProvider]:
        """Execute a request with automatic fallback.

        Args:
            request: Summary request
            context: Selection context
            max_retries: Maximum retry attempts per provider

        Returns:
            Tuple of (result, provider used)
        """
        if not context:
            # Create a default context from the request
            context = ModelSelectionContext(
                content_length=len(request.transcript),
                strategy=ModelSelectionStrategy.BALANCED
            )

        # Pick the primary model for this context
        primary = self.select_model(context)
        if not primary:
            raise ValueError("No suitable model available")

        # Build the fallback list: primary first, then the rest of the chain
        fallback_providers = [primary]
        for provider in self.fallback_chain:
            if provider != primary and provider in self.services:
                fallback_providers.append(provider)

        last_error = None

        for provider in fallback_providers:
            service = self.services.get(provider)
            if not service:
                continue

            config = self.models[provider]
            if not config.is_available:
                continue

            for attempt in range(max_retries):
                try:
                    # Execute request
                    start_time = datetime.utcnow()
                    result = await service.generate_summary(request)
                    latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000

                    # Update metrics
                    await self._update_metrics(
                        provider,
                        success=True,
                        latency_ms=latency_ms,
                        input_tokens=result.usage.input_tokens,
                        output_tokens=result.usage.output_tokens
                    )

                    # Mark as available
                    config.is_available = True

                    logger.info(f"Successfully used {provider.value} (attempt {attempt + 1})")
                    return result, provider

                except Exception as e:
                    last_error = e
                    logger.warning(f"Failed with {provider.value} (attempt {attempt + 1}): {e}")

                    # Update failure metrics
                    await self._update_metrics(provider, success=False)

                    # Mark as unavailable after repeated failures
                    if attempt == max_retries - 1:
                        config.is_available = False
                        config.last_error = str(e)
                        config.last_error_time = datetime.utcnow()

                    # Wait before retrying
                    if attempt < max_retries - 1:
                        await asyncio.sleep(2 ** attempt)  # Exponential backoff

        # All attempts failed
        raise Exception(f"All models failed. Last error: {last_error}")

    async def _update_metrics(
        self,
        provider: ModelProvider,
        success: bool,
        latency_ms: float = 0,
        input_tokens: int = 0,
        output_tokens: int = 0
    ):
        """Update model metrics.

        Args:
            provider: Model provider
            success: Whether the request succeeded
            latency_ms: Request latency
            input_tokens: Input token count
            output_tokens: Output token count
        """
        metrics = self.metrics[provider]
        config = self.models[provider]

        metrics.total_requests += 1

        if success:
            metrics.successful_requests += 1
            metrics.total_latency_ms += latency_ms
            metrics.total_input_tokens += input_tokens
            metrics.total_output_tokens += output_tokens

            # Calculate cost
            cost = config.get_total_cost(input_tokens, output_tokens)
            metrics.total_cost += cost
        else:
            metrics.failed_requests += 1

        metrics.last_used = datetime.utcnow()

    def get_metrics(self, provider: Optional[ModelProvider] = None) -> Dict[str, Any]:
        """Get metrics for models.

        Args:
            provider: Specific provider, or None for all providers

        Returns:
            Metrics dictionary
        """
        if provider:
            metrics = self.metrics.get(provider)
            if not metrics:
                return {}
            return {
                "provider": provider.value,
                "total_requests": metrics.total_requests,
                "success_rate": metrics.success_rate,
                "average_latency_ms": metrics.average_latency,
                "total_cost_usd": metrics.total_cost,
                "total_tokens": metrics.total_input_tokens + metrics.total_output_tokens,
            }

        # Return metrics for all providers
        all_metrics = {}
        for prov, metrics in self.metrics.items():
            all_metrics[prov.value] = {
                "total_requests": metrics.total_requests,
                "success_rate": metrics.success_rate,
                "average_latency_ms": metrics.average_latency,
                "total_cost_usd": metrics.total_cost,
                "total_tokens": metrics.total_input_tokens + metrics.total_output_tokens,
            }
        return all_metrics

    def get_cost_comparison(self, token_count: int) -> Dict[str, Dict[str, Any]]:
        """Get a cost comparison across registered models.

        Args:
            token_count: Estimated token count

        Returns:
            Cost comparison dictionary keyed by provider
        """
        comparison = {}
        for provider, config in self.models.items():
            # Assume a 1:1 input/output ratio for the estimate
            cost = config.get_total_cost(token_count, token_count)
            comparison[provider.value] = {
                "cost_usd": cost,
                "model": config.display_name,
                "quality_score": config.quality_score,
                "latency_ms": config.average_latency_ms,
            }
        return comparison

    def reset_availability(self, provider: Optional[ModelProvider] = None):
        """Reset model availability status.

        Args:
            provider: Specific provider, or None for all providers
        """
        if provider:
            if provider in self.models:
                self.models[provider].is_available = True
                self.models[provider].last_error = None
                logger.info(f"Reset availability for {provider.value}")
        else:
            for config in self.models.values():
                config.is_available = True
                config.last_error = None
            logger.info("Reset availability for all models")
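

# Usage sketch (illustrative only): the concrete service class and the SummaryRequest
# constructor fields below are assumptions for demonstration, not part of this module.
#
#     registry = AIModelRegistry()
#     registry.register_service(ModelProvider.OPENAI, OpenAISummaryService())  # hypothetical AIService subclass
#
#     context = ModelSelectionContext(
#         content_length=len(transcript),
#         content_type=ModelCapability.EDUCATIONAL,
#         strategy=ModelSelectionStrategy.BALANCED,
#     )
#     result, provider = await registry.execute_with_fallback(
#         SummaryRequest(transcript=transcript),  # field name assumed from request.transcript above
#         context,
#     )
#     logger.info(registry.get_metrics(provider))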