# youtube-summarizer/backend/services/ai_model_registry.py
"""AI Model Registry for managing multiple AI providers."""
import logging
from enum import Enum
from typing import Dict, List, Optional, Any, Type
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import asyncio
from ..services.ai_service import AIService, SummaryRequest, SummaryResult
logger = logging.getLogger(__name__)


class ModelProvider(Enum):
    """Supported AI model providers."""
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    DEEPSEEK = "deepseek"
    GOOGLE = "google"


class ModelCapability(Enum):
    """Model capabilities for matching."""
    SHORT_FORM = "short_form"          # < 5 min videos
    MEDIUM_FORM = "medium_form"        # 5-30 min videos
    LONG_FORM = "long_form"            # 30+ min videos
    TECHNICAL = "technical"            # Code, tutorials
    EDUCATIONAL = "educational"        # Lectures, courses
    CONVERSATIONAL = "conversational"  # Interviews, podcasts
    NEWS = "news"                      # News, current events
    CREATIVE = "creative"              # Music, art, entertainment


@dataclass
class ModelConfig:
    """Configuration for an AI model."""
    provider: ModelProvider
    model_name: str
    display_name: str
    max_tokens: int
    context_window: int
    # Cost per 1K tokens in USD
    input_cost_per_1k: float
    output_cost_per_1k: float
    # Performance characteristics
    average_latency_ms: float = 1000.0
    reliability_score: float = 0.95  # 0-1 scale
    quality_score: float = 0.90  # 0-1 scale
    # Capabilities
    capabilities: List[ModelCapability] = field(default_factory=list)
    supported_languages: List[str] = field(default_factory=lambda: ["en"])
    # Rate limits
    requests_per_minute: int = 60
    tokens_per_minute: int = 90000
    # Status
    is_available: bool = True
    last_error: Optional[str] = None
    last_error_time: Optional[datetime] = None

    def get_total_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Calculate total cost for token usage."""
        input_cost = (input_tokens / 1000) * self.input_cost_per_1k
        output_cost = (output_tokens / 1000) * self.output_cost_per_1k
        return input_cost + output_cost
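
    # Worked example (illustrative, using the gpt-4o-mini rates registered below):
    # summarizing a ~10,000-token transcript into ~1,000 output tokens costs
    # (10000 / 1000) * 0.00015 + (1000 / 1000) * 0.0006 = 0.0015 + 0.0006 = $0.0021.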


@dataclass
class ModelMetrics:
    """Performance metrics for a model."""
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    total_latency_ms: float = 0.0
    last_used: Optional[datetime] = None

    @property
    def success_rate(self) -> float:
        """Calculate success rate."""
        if self.total_requests == 0:
            return 1.0
        return self.successful_requests / self.total_requests

    @property
    def average_latency(self) -> float:
        """Calculate average latency."""
        if self.successful_requests == 0:
            return 0.0
        return self.total_latency_ms / self.successful_requests


class ModelSelectionStrategy(Enum):
    """Strategy for selecting models."""
    COST_OPTIMIZED = "cost_optimized"        # Minimize cost
    QUALITY_OPTIMIZED = "quality_optimized"  # Maximize quality
    SPEED_OPTIMIZED = "speed_optimized"      # Minimize latency
    BALANCED = "balanced"                    # Balance all factors


@dataclass
class ModelSelectionContext:
    """Context for model selection."""
    content_length: int  # Characters
    content_type: Optional[ModelCapability] = None
    language: str = "en"
    strategy: ModelSelectionStrategy = ModelSelectionStrategy.BALANCED
    max_cost: Optional[float] = None  # Maximum cost in USD
    max_latency_ms: Optional[float] = None
    required_quality: float = 0.8  # Minimum quality score
    user_preference: Optional[ModelProvider] = None


class AIModelRegistry:
    """Registry for managing multiple AI model providers."""

    def __init__(self):
        """Initialize the model registry."""
        self.models: Dict[ModelProvider, ModelConfig] = {}
        self.services: Dict[ModelProvider, AIService] = {}
        self.metrics: Dict[ModelProvider, ModelMetrics] = {}
        self.fallback_chain: List[ModelProvider] = []
        # Initialize default model configurations
        self._initialize_default_models()

    def _initialize_default_models(self):
        """Initialize default model configurations."""
        # OpenAI GPT-4o-mini
        self.register_model(ModelConfig(
            provider=ModelProvider.OPENAI,
            model_name="gpt-4o-mini",
            display_name="GPT-4 Omni Mini",
            max_tokens=16384,
            context_window=128000,
            input_cost_per_1k=0.00015,
            output_cost_per_1k=0.0006,
            average_latency_ms=800,
            reliability_score=0.95,
            quality_score=0.88,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CONVERSATIONAL
            ],
            supported_languages=["en", "es", "fr", "de", "zh", "ja", "ko"],
            requests_per_minute=500,
            tokens_per_minute=200000
        ))

        # Anthropic Claude 3.5 Haiku
        self.register_model(ModelConfig(
            provider=ModelProvider.ANTHROPIC,
            model_name="claude-3-5-haiku-20241022",
            display_name="Claude 3.5 Haiku",
            max_tokens=8192,
            context_window=200000,
            input_cost_per_1k=0.00025,
            output_cost_per_1k=0.00125,
            average_latency_ms=500,
            reliability_score=0.98,
            quality_score=0.92,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.LONG_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CREATIVE
            ],
            supported_languages=["en", "es", "fr", "de", "pt", "it", "nl"],
            requests_per_minute=100,
            tokens_per_minute=100000
        ))

        # DeepSeek V2
        self.register_model(ModelConfig(
            provider=ModelProvider.DEEPSEEK,
            model_name="deepseek-chat",
            display_name="DeepSeek V2",
            max_tokens=4096,
            context_window=32000,
            input_cost_per_1k=0.00014,
            output_cost_per_1k=0.00028,
            average_latency_ms=1200,
            reliability_score=0.90,
            quality_score=0.85,
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL
            ],
            supported_languages=["en", "zh"],
            requests_per_minute=60,
            tokens_per_minute=90000
        ))

        # Google Gemini 1.5 Pro - 2M-token context window
        self.register_model(ModelConfig(
            provider=ModelProvider.GOOGLE,
            model_name="gemini-1.5-pro",
            display_name="Gemini 1.5 Pro (2M Context)",
            max_tokens=8192,
            context_window=2000000,    # 2 million token context
            input_cost_per_1k=0.007,   # $7 per 1M tokens - competitive for the massive context
            output_cost_per_1k=0.021,  # $21 per 1M tokens
            average_latency_ms=2000,   # Slightly higher due to large-context processing
            reliability_score=0.96,
            quality_score=0.94,        # Excellent quality with full context
            capabilities=[
                ModelCapability.SHORT_FORM,
                ModelCapability.MEDIUM_FORM,
                ModelCapability.LONG_FORM,  # Excels at long-form content
                ModelCapability.TECHNICAL,
                ModelCapability.EDUCATIONAL,
                ModelCapability.CONVERSATIONAL,
                ModelCapability.NEWS,
                ModelCapability.CREATIVE
            ],
            supported_languages=["en", "es", "fr", "de", "pt", "it", "nl", "ja", "ko", "zh", "hi"],
            requests_per_minute=60,
            tokens_per_minute=32000  # Large context means fewer but higher-quality requests
        ))

        # Default fallback chain - Gemini first for long content because of its massive context window
        self.fallback_chain = [
            ModelProvider.GOOGLE,     # Best for long-form content
            ModelProvider.ANTHROPIC,  # Great quality fallback
            ModelProvider.OPENAI,     # Reliable alternative
            ModelProvider.DEEPSEEK    # Cost-effective option
        ]

    def register_model(self, config: ModelConfig):
        """Register a model configuration."""
        self.models[config.provider] = config
        self.metrics[config.provider] = ModelMetrics()
        logger.info(f"Registered model: {config.display_name} ({config.provider.value})")
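
    # Note: self.models is keyed by provider, so registering a second config for the
    # same provider replaces the previous entry and resets its ModelMetrics. For
    # example (illustrative only), re-registering ModelProvider.OPENAI with
    # model_name="gpt-4o" would override the gpt-4o-mini defaults above.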

    def register_service(self, provider: ModelProvider, service: AIService):
        """Register an AI service implementation."""
        if provider not in self.models:
            raise ValueError(f"Model {provider} not registered")
        self.services[provider] = service
        logger.info(f"Registered service for {provider.value}")

    def get_model_config(self, provider: ModelProvider) -> Optional[ModelConfig]:
        """Get model configuration."""
        return self.models.get(provider)

    def get_service(self, provider: ModelProvider) -> Optional[AIService]:
        """Get AI service for a provider."""
        return self.services.get(provider)

    def select_model(self, context: ModelSelectionContext) -> Optional[ModelProvider]:
        """Select the best model based on context.

        Args:
            context: Selection context with requirements

        Returns:
            Selected model provider or None if no suitable model
        """
        available_models = self._get_available_models(context)
        if not available_models:
            logger.warning("No available models match the context requirements")
            return None

        # Apply user preference if specified
        if context.user_preference and context.user_preference in available_models:
            return context.user_preference

        # Score and rank models
        scored_models = []
        for provider in available_models:
            score = self._score_model(provider, context)
            scored_models.append((provider, score))

        # Sort by score (higher is better)
        scored_models.sort(key=lambda x: x[1], reverse=True)
        if scored_models:
            selected = scored_models[0][0]
            logger.info(f"Selected model: {selected.value} (score: {scored_models[0][1]:.2f})")
            return selected
        return None

    def _get_available_models(self, context: ModelSelectionContext) -> List[ModelProvider]:
        """Get available models that meet context requirements."""
        available = []
        for provider, config in self.models.items():
            # Check availability
            if not config.is_available:
                continue
            # Check language support
            if context.language not in config.supported_languages:
                continue
            # Check quality requirement
            if config.quality_score < context.required_quality:
                continue
            # Check cost constraint
            if context.max_cost:
                estimated_tokens = context.content_length / 4  # Rough estimate: ~4 characters per token
                estimated_cost = config.get_total_cost(estimated_tokens, estimated_tokens / 2)
                if estimated_cost > context.max_cost:
                    continue
            # Check latency constraint
            if context.max_latency_ms and config.average_latency_ms > context.max_latency_ms:
                continue
            # Check capabilities match
            if context.content_type and context.content_type not in config.capabilities:
                continue
            available.append(provider)
        return available
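
    # Worked example (illustrative): a 40,000-character transcript is estimated at
    # ~10,000 input tokens and ~5,000 output tokens, so the gpt-4o-mini config above is
    # checked against context.max_cost at roughly 10 * 0.00015 + 5 * 0.0006 = $0.0045.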

    def _score_model(self, provider: ModelProvider, context: ModelSelectionContext) -> float:
        """Score a model based on selection strategy.

        Args:
            provider: Model provider to score
            context: Selection context

        Returns:
            Score from 0-100
        """
        config = self.models[provider]
        metrics = self.metrics[provider]

        # Base scores (0-1 scale); clamp at 0 so expensive or slow models cannot go negative
        cost_score = max(0.0, 1.0 - (config.input_cost_per_1k / 0.001))  # Normalize to $0.001/1K baseline
        quality_score = config.quality_score
        speed_score = max(0.0, 1.0 - (config.average_latency_ms / 5000))  # Normalize to 5s baseline
        reliability_score = config.reliability_score * metrics.success_rate

        # Apply strategy weights
        if context.strategy == ModelSelectionStrategy.COST_OPTIMIZED:
            weights = {"cost": 0.6, "quality": 0.2, "speed": 0.1, "reliability": 0.1}
        elif context.strategy == ModelSelectionStrategy.QUALITY_OPTIMIZED:
            weights = {"cost": 0.1, "quality": 0.6, "speed": 0.1, "reliability": 0.2}
        elif context.strategy == ModelSelectionStrategy.SPEED_OPTIMIZED:
            weights = {"cost": 0.1, "quality": 0.2, "speed": 0.5, "reliability": 0.2}
        else:  # BALANCED
            weights = {"cost": 0.25, "quality": 0.35, "speed": 0.2, "reliability": 0.2}

        # Calculate weighted score
        score = (
            cost_score * weights["cost"] +
            quality_score * weights["quality"] +
            speed_score * weights["speed"] +
            reliability_score * weights["reliability"]
        ) * 100

        # Boost score if the model has the specifically requested capability
        if context.content_type and context.content_type in config.capabilities:
            score += 10
        return min(score, 100)  # Cap at 100
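
    # Worked example (illustrative, BALANCED strategy, fresh metrics so success_rate = 1.0):
    # for the Claude 3.5 Haiku config above, cost_score = 1 - 0.00025 / 0.001 = 0.75,
    # quality_score = 0.92, speed_score = 1 - 500 / 5000 = 0.90, reliability = 0.98 * 1.0,
    # so score = (0.75 * 0.25 + 0.92 * 0.35 + 0.90 * 0.20 + 0.98 * 0.20) * 100 ≈ 88.6,
    # plus the +10 capability boost when content_type matches (capped at 100).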

    async def execute_with_fallback(
        self,
        request: SummaryRequest,
        context: Optional[ModelSelectionContext] = None,
        max_retries: int = 3
    ) -> tuple[SummaryResult, ModelProvider]:
        """Execute a request with automatic fallback across providers.

        Args:
            request: Summary request
            context: Selection context
            max_retries: Maximum retry attempts per provider

        Returns:
            Tuple of (result, provider used)
        """
        if not context:
            # Create default context from the request
            context = ModelSelectionContext(
                content_length=len(request.transcript),
                strategy=ModelSelectionStrategy.BALANCED
            )

        # Get fallback chain
        primary = self.select_model(context)
        if not primary:
            raise ValueError("No suitable model available")

        # Build fallback list
        fallback_providers = [primary]
        for provider in self.fallback_chain:
            if provider != primary and provider in self.services:
                fallback_providers.append(provider)

        last_error = None
        for provider in fallback_providers:
            service = self.services.get(provider)
            if not service:
                continue
            config = self.models[provider]
            if not config.is_available:
                continue

            for attempt in range(max_retries):
                try:
                    # Execute request
                    start_time = datetime.utcnow()
                    result = await service.generate_summary(request)
                    latency_ms = (datetime.utcnow() - start_time).total_seconds() * 1000

                    # Update metrics
                    await self._update_metrics(
                        provider,
                        success=True,
                        latency_ms=latency_ms,
                        input_tokens=result.usage.input_tokens,
                        output_tokens=result.usage.output_tokens
                    )

                    # Mark as available
                    config.is_available = True
                    logger.info(f"Successfully used {provider.value} (attempt {attempt + 1})")
                    return result, provider
                except Exception as e:
                    last_error = e
                    logger.warning(f"Failed with {provider.value} (attempt {attempt + 1}): {e}")

                    # Update failure metrics
                    await self._update_metrics(provider, success=False)

                    # Mark as unavailable after exhausting retries for this provider
                    if attempt == max_retries - 1:
                        config.is_available = False
                        config.last_error = str(e)
                        config.last_error_time = datetime.utcnow()

                    # Wait before retry
                    if attempt < max_retries - 1:
                        await asyncio.sleep(2 ** attempt)  # Exponential backoff

        # All providers and attempts failed
        raise Exception(f"All models failed. Last error: {last_error}")
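
    # Illustrative call pattern (assumes concrete AIService subclasses, e.g. a hypothetical
    # OpenAISummaryService, are implemented elsewhere and registered via register_service):
    #     registry.register_service(ModelProvider.OPENAI, OpenAISummaryService(...))
    #     result, provider = await registry.execute_with_fallback(request)
    # Providers are tried in order: the select_model() choice first, then the remaining
    # entries of self.fallback_chain that have a registered service.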

    async def _update_metrics(
        self,
        provider: ModelProvider,
        success: bool,
        latency_ms: float = 0,
        input_tokens: int = 0,
        output_tokens: int = 0
    ):
        """Update model metrics.

        Args:
            provider: Model provider
            success: Whether request succeeded
            latency_ms: Request latency
            input_tokens: Input token count
            output_tokens: Output token count
        """
        metrics = self.metrics[provider]
        config = self.models[provider]

        metrics.total_requests += 1
        if success:
            metrics.successful_requests += 1
            metrics.total_latency_ms += latency_ms
            metrics.total_input_tokens += input_tokens
            metrics.total_output_tokens += output_tokens
            # Calculate cost
            cost = config.get_total_cost(input_tokens, output_tokens)
            metrics.total_cost += cost
        else:
            metrics.failed_requests += 1
        metrics.last_used = datetime.utcnow()

    def get_metrics(self, provider: Optional[ModelProvider] = None) -> Dict[str, Any]:
        """Get metrics for models.

        Args:
            provider: Specific provider or None for all

        Returns:
            Metrics dictionary
        """
        if provider:
            metrics = self.metrics.get(provider)
            if not metrics:
                return {}
            return {
                "provider": provider.value,
                "total_requests": metrics.total_requests,
                "success_rate": metrics.success_rate,
                "average_latency_ms": metrics.average_latency,
                "total_cost_usd": metrics.total_cost,
                "total_tokens": metrics.total_input_tokens + metrics.total_output_tokens
            }

        # Return metrics for all providers
        all_metrics = {}
        for prov, metrics in self.metrics.items():
            all_metrics[prov.value] = {
                "total_requests": metrics.total_requests,
                "success_rate": metrics.success_rate,
                "average_latency_ms": metrics.average_latency,
                "total_cost_usd": metrics.total_cost,
                "total_tokens": metrics.total_input_tokens + metrics.total_output_tokens
            }
        return all_metrics

    def get_cost_comparison(self, token_count: int) -> Dict[str, Dict[str, Any]]:
        """Get a cost comparison across models.

        Args:
            token_count: Estimated token count

        Returns:
            Cost comparison dictionary keyed by provider
        """
        comparison = {}
        for provider, config in self.models.items():
            # Estimate 1:1 input/output ratio
            cost = config.get_total_cost(token_count, token_count)
            comparison[provider.value] = {
                "cost_usd": cost,
                "model": config.display_name,
                "quality_score": config.quality_score,
                "latency_ms": config.average_latency_ms
            }
        return comparison

    def reset_availability(self, provider: Optional[ModelProvider] = None):
        """Reset model availability status.

        Args:
            provider: Specific provider or None for all
        """
        if provider:
            if provider in self.models:
                self.models[provider].is_available = True
                self.models[provider].last_error = None
                logger.info(f"Reset availability for {provider.value}")
        else:
            for config in self.models.values():
                config.is_available = True
                config.last_error = None
            logger.info("Reset availability for all models")
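

# --- Illustrative usage (a minimal sketch, not part of the registry API) ---
# select_model() and get_cost_comparison() only consult the static ModelConfig entries,
# so they work without any AIService registered. Because of the relative import above,
# run this as a module from the package root (e.g. `python -m backend.services.ai_model_registry`),
# not as a standalone script.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    registry = AIModelRegistry()

    # Hypothetical selection context: a long-form transcript, quality-first strategy.
    context = ModelSelectionContext(
        content_length=45000,                   # ~45K characters of transcript
        content_type=ModelCapability.LONG_FORM,
        strategy=ModelSelectionStrategy.QUALITY_OPTIMIZED,
        max_latency_ms=3000,
    )
    chosen = registry.select_model(context)
    print(f"Selected provider: {chosen.value if chosen else 'none'}")

    # Compare the estimated cost of summarizing ~12K tokens across all registered models.
    for name, info in registry.get_cost_comparison(token_count=12000).items():
        print(f"{name}: ${info['cost_usd']:.4f} ({info['model']})")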