310 lines
11 KiB
Python
310 lines
11 KiB
Python
"""Multi-model AI service with intelligent selection and fallback."""
|
|
|
|
import os
|
|
import logging
|
|
from typing import Optional, Dict, Any
|
|
from enum import Enum
|
|
|
|
from .ai_service import AIService, SummaryRequest, SummaryResult
|
|
from .ai_model_registry import (
|
|
AIModelRegistry,
|
|
ModelProvider,
|
|
ModelSelectionContext,
|
|
ModelSelectionStrategy,
|
|
ModelCapability
|
|
)
|
|
from .openai_summarizer import OpenAISummarizer
|
|
from .anthropic_summarizer import AnthropicSummarizer
|
|
from .deepseek_summarizer import DeepSeekSummarizer
|
|
from .gemini_summarizer import GeminiSummarizer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MultiModelService:
    """Orchestrates multiple AI models with intelligent selection and fallback."""

    # Keyword heuristics for content-type detection: (capability, keywords,
    # minimum number of distinct keyword hits required).  Rules are checked
    # in order and the first match wins, so more specific categories
    # (technical, educational) are listed before broader ones.
    _CONTENT_TYPE_RULES = (
        (ModelCapability.TECHNICAL,
         ("code", "function", "algorithm", "debug", "compile", "api", "database"), 3),
        (ModelCapability.EDUCATIONAL,
         ("learn", "explain", "understand", "lesson", "tutorial", "course"), 3),
        (ModelCapability.CONVERSATIONAL,
         ("interview", "discussion", "talk", "conversation", "podcast"), 2),
        (ModelCapability.NEWS,
         ("breaking", "news", "report", "update", "announcement"), 2),
        (ModelCapability.CREATIVE,
         ("art", "music", "creative", "design", "performance"), 2),
    )

    def __init__(
        self,
        openai_api_key: Optional[str] = None,
        anthropic_api_key: Optional[str] = None,
        deepseek_api_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
        default_strategy: ModelSelectionStrategy = ModelSelectionStrategy.BALANCED
    ):
        """Initialize multi-model service.

        Args:
            openai_api_key: OpenAI API key
            anthropic_api_key: Anthropic API key
            deepseek_api_key: DeepSeek API key
            google_api_key: Google Gemini API key
            default_strategy: Default model selection strategy

        Raises:
            ValueError: If no provider could be initialized (no API key
                available from arguments or environment variables).
        """
        self.registry = AIModelRegistry()
        self.default_strategy = default_strategy

        # Initialize available services
        self._initialize_services(openai_api_key, anthropic_api_key, deepseek_api_key, google_api_key)

        # Track active providers (only those that registered successfully)
        self.active_providers = list(self.registry.services.keys())

        if not self.active_providers:
            raise ValueError("No AI service API keys provided. At least one is required.")

        logger.info(f"Initialized multi-model service with providers: {[p.value for p in self.active_providers]}")

    def _initialize_services(
        self,
        openai_api_key: Optional[str],
        anthropic_api_key: Optional[str],
        deepseek_api_key: Optional[str],
        google_api_key: Optional[str]
    ):
        """Initialize AI services based on available API keys.

        Explicit keys take precedence; otherwise the matching environment
        variable is consulted.  A provider whose summarizer fails to
        construct is skipped with a warning rather than aborting startup.
        """
        # (api key, registry provider, summarizer factory, display name,
        #  extra suffix for the success log message)
        candidates = (
            (openai_api_key or os.getenv("OPENAI_API_KEY"),
             ModelProvider.OPENAI, OpenAISummarizer, "OpenAI", ""),
            (anthropic_api_key or os.getenv("ANTHROPIC_API_KEY"),
             ModelProvider.ANTHROPIC, AnthropicSummarizer, "Anthropic", ""),
            (deepseek_api_key or os.getenv("DEEPSEEK_API_KEY"),
             ModelProvider.DEEPSEEK, DeepSeekSummarizer, "DeepSeek", ""),
            # Google Gemini - BEST for long-form content with 2M context!
            (google_api_key or os.getenv("GOOGLE_API_KEY"),
             ModelProvider.GOOGLE,
             lambda api_key: GeminiSummarizer(api_key=api_key, model="gemini-1.5-pro"),
             "Google Gemini", " (2M token context)"),
        )

        for api_key, provider, factory, name, extra in candidates:
            if not api_key:
                continue
            try:
                self.registry.register_service(provider, factory(api_key=api_key))
                logger.info(f"Initialized {name} service{extra}")
            except Exception as e:
                # Best-effort: one bad provider must not block the others.
                logger.warning(f"Failed to initialize {name} service: {e}")

    def _determine_content_type(self, transcript: str) -> Optional[ModelCapability]:
        """Determine content type from transcript.

        Args:
            transcript: Video transcript

        Returns:
            Detected content type.  A keyword rule (see
            ``_CONTENT_TYPE_RULES``) wins if enough distinct keywords
            appear; otherwise the transcript is bucketed by word count.
        """
        transcript_lower = transcript.lower()

        # First matching keyword rule wins.
        for capability, keywords, threshold in self._CONTENT_TYPE_RULES:
            if sum(1 for k in keywords if k in transcript_lower) >= threshold:
                return capability

        # No topical match: determine by length instead.
        word_count = len(transcript.split())
        if word_count < 1000:
            return ModelCapability.SHORT_FORM
        elif word_count < 5000:
            return ModelCapability.MEDIUM_FORM
        else:
            return ModelCapability.LONG_FORM

    async def generate_summary(
        self,
        request: SummaryRequest,
        strategy: Optional[ModelSelectionStrategy] = None,
        preferred_provider: Optional[ModelProvider] = None,
        max_cost: Optional[float] = None
    ) -> tuple[SummaryResult, ModelProvider]:
        """Generate summary using intelligent model selection.

        Args:
            request: Summary request
            strategy: Model selection strategy (uses default if None)
            preferred_provider: Preferred model provider
            max_cost: Maximum cost constraint in USD

        Returns:
            Tuple of (summary result, provider used)
        """
        # Create selection context
        context = ModelSelectionContext(
            content_length=len(request.transcript),
            content_type=self._determine_content_type(request.transcript),
            language="en",  # TODO: Detect language
            strategy=strategy or self.default_strategy,
            max_cost=max_cost,
            user_preference=preferred_provider
        )

        # Execute with fallback
        result, provider = await self.registry.execute_with_fallback(request, context)

        # Record which provider/strategy actually produced the summary.
        result.processing_metadata["provider"] = provider.value
        result.processing_metadata["strategy"] = context.strategy.value

        return result, provider

    async def generate_summary_simple(self, request: SummaryRequest) -> SummaryResult:
        """Generate summary with default settings (AIService interface).

        Args:
            request: Summary request

        Returns:
            Summary result
        """
        result, _ = await self.generate_summary(request)
        return result

    def get_metrics(self) -> Dict[str, Any]:
        """Get metrics for all models.

        Returns:
            Metrics dictionary
        """
        return self.registry.get_metrics()

    def get_provider_metrics(self, provider: ModelProvider) -> Dict[str, Any]:
        """Get metrics for specific provider.

        Args:
            provider: Model provider

        Returns:
            Provider metrics
        """
        return self.registry.get_metrics(provider)

    def estimate_cost(self, transcript_length: int) -> Dict[str, Any]:
        """Estimate cost across different models.

        Args:
            transcript_length: Length of transcript in characters

        Returns:
            Cost comparison across models, with cost/quality/speed
            recommendations (empty when no models are available).
        """
        # Estimate tokens (roughly 1 token per 4 characters)
        estimated_tokens = transcript_length // 4

        comparison = self.registry.get_cost_comparison(estimated_tokens)

        # Add recommendations
        recommendations = []

        # Guard: min()/max() below raise ValueError on an empty mapping.
        if comparison:
            # Find cheapest
            cheapest = min(comparison.items(), key=lambda x: x[1]["cost_usd"])
            recommendations.append({
                "type": "cost_optimized",
                "provider": cheapest[0],
                "reason": f"Lowest cost at ${cheapest[1]['cost_usd']:.4f}"
            })

            # Find highest quality
            highest_quality = max(comparison.items(), key=lambda x: x[1]["quality_score"])
            recommendations.append({
                "type": "quality_optimized",
                "provider": highest_quality[0],
                "reason": f"Highest quality score at {highest_quality[1]['quality_score']:.2f}"
            })

            # Find fastest
            fastest = min(comparison.items(), key=lambda x: x[1]["latency_ms"])
            recommendations.append({
                "type": "speed_optimized",
                "provider": fastest[0],
                "reason": f"Fastest processing at {fastest[1]['latency_ms']:.0f}ms"
            })

        return {
            "estimated_tokens": estimated_tokens,
            "comparison": comparison,
            "recommendations": recommendations
        }

    def reset_model_availability(self, provider: Optional[ModelProvider] = None):
        """Reset model availability after errors.

        Args:
            provider: Specific provider or None for all
        """
        self.registry.reset_availability(provider)

    def get_available_models(self) -> list[str]:
        """Get list of available model providers.

        Returns:
            List of available provider names
        """
        return [p.value for p in self.active_providers]

    def set_default_strategy(self, strategy: ModelSelectionStrategy):
        """Set default model selection strategy.

        Args:
            strategy: New default strategy
        """
        self.default_strategy = strategy
        logger.info(f"Set default strategy to: {strategy.value}")
|
|
|
|
|
|
# Factory function for dependency injection
def get_multi_model_service() -> MultiModelService:
    """Build a :class:`MultiModelService` from application settings.

    Returns:
        MultiModelService instance
    """
    from ..core.config import settings

    # This could be a singleton or created per-request.  Collect the
    # provider credentials first so the construction site stays flat.
    api_keys = {
        "openai_api_key": settings.OPENAI_API_KEY,
        "anthropic_api_key": settings.ANTHROPIC_API_KEY,
        "deepseek_api_key": settings.DEEPSEEK_API_KEY,
        "google_api_key": settings.GOOGLE_API_KEY,  # 🚀 Gemini with 2M token context!
    }
    return MultiModelService(**api_keys)