youtube-summarizer/backend/services/gemini_summarizer.py

"""Google Gemini summarization service with 2M token context support."""
import asyncio
import json
import time
from typing import Dict, List, Optional
import httpx
import re
from .ai_service import AIService, SummaryRequest, SummaryResult, SummaryLength
from ..core.exceptions import AIServiceError, ErrorCode
class GeminiSummarizer(AIService):
"""Google Gemini-based summarization service with large context support."""
def __init__(self, api_key: str, model: str = "gemini-1.5-pro"):
"""Initialize Gemini summarizer.
Args:
api_key: Google AI API key
model: Model to use (gemini-1.5-pro for 2M context, gemini-1.5-flash for speed)
"""
self.api_key = api_key
self.model = model
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
# Context window sizes
self.max_tokens_input = 2000000 if "1.5-pro" in model else 1000000 # 2M for Pro, 1M for Flash
self.max_tokens_output = 8192 # Standard output limit
# Cost per 1M tokens (Gemini 1.5 Pro pricing - very competitive)
if "1.5-pro" in model:
self.input_cost_per_1k = 0.007 # $7 per 1M input tokens
self.output_cost_per_1k = 0.021 # $21 per 1M output tokens
else: # Flash model
self.input_cost_per_1k = 0.00015 # $0.15 per 1M input tokens
self.output_cost_per_1k = 0.0006 # $0.60 per 1M output tokens
# HTTP client for API calls
self.client = httpx.AsyncClient(timeout=300.0) # 5 minute timeout for long context
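        # NOTE: self.client keeps pooled connections open; use this class as an
        # async context manager (see __aenter__/__aexit__ below) or call
        # await self.client.aclose() when finished with it.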

    async def generate_summary(self, request: SummaryRequest) -> SummaryResult:
        """Generate structured summary using Google Gemini with large context."""
        # With a 2M token context, most transcripts fit in a single pass
        estimated_tokens = self.get_token_count(request.transcript)
        if estimated_tokens > 1_800_000:  # Leave room for prompt and response
            # Only chunk if absolutely necessary (very rare with 2M context)
            return await self._generate_chunked_summary(request)

        prompt = self._build_summary_prompt(request)
        try:
            start_time = time.time()

            # Make API request to Gemini
            url = f"{self.base_url}/models/{self.model}:generateContent"
            payload = {
                "contents": [
                    {
                        "parts": [
                            {"text": prompt}
                        ]
                    }
                ],
                "generationConfig": {
                    "temperature": 0.3,  # Lower temperature for consistent summaries
                    "maxOutputTokens": self._get_max_tokens(request.length),
                    "topP": 0.8,
                    "topK": 10
                },
                "safetySettings": [
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}
                ]
            }

            response = await self.client.post(
                url,
                params={"key": self.api_key},
                json=payload,
                headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()
            result = response.json()
            processing_time = time.time() - start_time

            # Extract response text
            if "candidates" not in result or not result["candidates"]:
                raise AIServiceError(
                    message="No response candidates from Gemini",
                    error_code=ErrorCode.AI_SERVICE_ERROR
                )
            content = result["candidates"][0]["content"]["parts"][0]["text"]

            # Parse JSON from the response; fall back to structured-text
            # parsing if the model wrapped or malformed the JSON
            try:
                result_data = json.loads(content)
            except json.JSONDecodeError:
                result_data = self._extract_structured_data(content)

            # Estimate token usage from text length, preferring the actual
            # numbers when the API reports them
            input_tokens = estimated_tokens
            output_tokens = self.get_token_count(content)
            if "usageMetadata" in result:
                usage = result["usageMetadata"]
                input_tokens = usage.get("promptTokenCount", input_tokens)
                output_tokens = usage.get("candidatesTokenCount", output_tokens)

            input_cost = (input_tokens / 1000) * self.input_cost_per_1k
            output_cost = (output_tokens / 1000) * self.output_cost_per_1k
            total_cost = input_cost + output_cost

            return SummaryResult(
                summary=result_data.get("summary", ""),
                key_points=result_data.get("key_points", []),
                main_themes=result_data.get("main_themes", []),
                actionable_insights=result_data.get("actionable_insights", []),
                confidence_score=result_data.get("confidence_score", 0.9),
                processing_metadata={
                    "model": self.model,
                    "processing_time_seconds": processing_time,
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                    "chunks_processed": 1,
                    "context_window_used": f"{input_tokens}/{self.max_tokens_input}",
                    "large_context_advantage": "Single pass processing - no chunking needed"
                },
                cost_data={
                    "input_cost_usd": input_cost,
                    "output_cost_usd": output_cost,
                    "total_cost_usd": total_cost,
                    "cost_per_summary": total_cost,
                    "model_efficiency": "Large context eliminates chunking overhead"
                }
            )
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                raise AIServiceError(
                    message="Gemini API rate limit exceeded",
                    error_code=ErrorCode.RATE_LIMIT_ERROR,
                    recoverable=True
                )
            elif e.response.status_code == 400:
                error_detail = ""
                try:
                    error_data = e.response.json()
                    error_detail = error_data.get("error", {}).get("message", "")
                except Exception:
                    pass
                raise AIServiceError(
                    message=f"Gemini API request error: {error_detail}",
                    error_code=ErrorCode.AI_SERVICE_ERROR,
                    recoverable=False
                )
            else:
                raise AIServiceError(
                    message=f"Gemini API error: {e.response.status_code}",
                    error_code=ErrorCode.AI_SERVICE_ERROR,
                    recoverable=True
                )
        except AIServiceError:
            # Re-raise our own errors unchanged rather than re-wrapping them below
            raise
        except Exception as e:
            raise AIServiceError(
                message=f"Gemini summarization failed: {str(e)}",
                error_code=ErrorCode.AI_SERVICE_ERROR,
                details={
                    "model": self.model,
                    "transcript_length": len(request.transcript),
                    "error_type": type(e).__name__
                }
            )
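
    # A caller might retry recoverable failures; a minimal sketch, assuming
    # AIServiceError exposes the `recoverable` flag it is constructed with
    # (hypothetical helper, not part of this module):
    #
    #     async def summarize_with_retry(summarizer, req, attempts=3):
    #         for attempt in range(attempts):
    #             try:
    #                 return await summarizer.generate_summary(req)
    #             except AIServiceError as err:
    #                 if not err.recoverable or attempt == attempts - 1:
    #                     raise
    #                 await asyncio.sleep(2 ** attempt)  # simple exponential backoff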

    def _build_summary_prompt(self, request: SummaryRequest) -> str:
        """Build optimized prompt for Gemini summary generation."""
        length_instructions = {
            SummaryLength.BRIEF: "Generate a concise summary in 100-200 words",
            SummaryLength.STANDARD: "Generate a comprehensive summary in 300-500 words",
            SummaryLength.DETAILED: "Generate a detailed summary in 500-800 words"
        }

        focus_instruction = ""
        if request.focus_areas:
            focus_instruction = f"\nPay special attention to these areas: {', '.join(request.focus_areas)}"

        return f"""
Analyze this YouTube video transcript and provide a structured summary. With your large context window, you can process the entire transcript at once for maximum coherence.

{length_instructions[request.length]}.

Please respond with a valid JSON object in this exact format:
{{
    "summary": "Main summary text here",
    "key_points": ["Point 1", "Point 2", "Point 3"],
    "main_themes": ["Theme 1", "Theme 2", "Theme 3"],
    "actionable_insights": ["Insight 1", "Insight 2"],
    "confidence_score": 0.95
}}

Guidelines:
- Extract 5-8 key points that capture the most important information
- Identify 3-5 main themes or topics discussed
- Provide 3-6 actionable insights that viewers can apply
- Assign a confidence score (0.0-1.0) based on transcript quality and coherence
- Use clear, engaging language that's accessible to a general audience
- Focus on value and practical takeaways
- Maintain narrative flow since you can see the entire transcript{focus_instruction}

Transcript:
{request.transcript}
"""

    def _extract_structured_data(self, response_text: str) -> dict:
        """Extract structured data when JSON parsing fails."""
        try:
            # Look for a JSON block embedded in the response
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
        except json.JSONDecodeError:
            pass

        # Fallback: parse structured text
        lines = response_text.split('\n')
        summary = ""
        key_points = []
        main_themes = []
        actionable_insights = []
        confidence_score = 0.9
        current_section = None

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Detect section headers
            if "summary" in line.lower() and (":" in line or line.endswith("summary")):
                current_section = "summary"
                if ":" in line:
                    summary = line.split(":", 1)[1].strip()
                continue
            elif "key points" in line.lower() or "key_points" in line.lower():
                current_section = "key_points"
                continue
            elif "main themes" in line.lower() or "themes" in line.lower():
                current_section = "main_themes"
                continue
            elif "actionable" in line.lower() or "insights" in line.lower():
                current_section = "actionable_insights"
                continue
            elif "confidence" in line.lower():
                numbers = re.findall(r'0?\.\d+|\d+', line)
                if numbers:
                    try:
                        confidence_score = float(numbers[0])
                    except ValueError:
                        pass
                continue

            # Add bulleted or numbered content to the appropriate section
            if current_section == "summary" and not summary:
                summary = line
            elif current_section == "key_points" and (line.startswith(('-', '•', '*')) or line[0].isdigit()):
                cleaned_line = re.sub(r'^[-•*0-9.)\s]+', '', line).strip()
                if cleaned_line:
                    key_points.append(cleaned_line)
            elif current_section == "main_themes" and (line.startswith(('-', '•', '*')) or line[0].isdigit()):
                cleaned_line = re.sub(r'^[-•*0-9.)\s]+', '', line).strip()
                if cleaned_line:
                    main_themes.append(cleaned_line)
            elif current_section == "actionable_insights" and (line.startswith(('-', '•', '*')) or line[0].isdigit()):
                cleaned_line = re.sub(r'^[-•*0-9.)\s]+', '', line).strip()
                if cleaned_line:
                    actionable_insights.append(cleaned_line)

        return {
            "summary": summary or response_text[:500] + "...",
            "key_points": key_points[:8],
            "main_themes": main_themes[:5],
            "actionable_insights": actionable_insights[:6],
            "confidence_score": confidence_score
        }
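
    # The fallback above tolerates loosely structured output; a hypothetical
    # model response it can parse (illustration only):
    #
    #     Summary: The video walks through building a REST API.
    #     Key Points:
    #     - Use routers to group related endpoints
    #     1. Validate input with typed request models
    #     Confidence: 0.85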

    async def _generate_chunked_summary(self, request: SummaryRequest) -> SummaryResult:
        """Handle extremely long transcripts (rare with 2M context) using a hierarchical approach."""
        # Split transcript into large chunks (1.5M tokens each)
        chunks = self._split_transcript_intelligently(request.transcript, max_tokens=1_500_000)

        # Generate a summary for each chunk
        chunk_summaries = []
        total_cost = 0.0
        total_tokens = 0
        for chunk in chunks:
            chunk_request = SummaryRequest(
                transcript=chunk,
                length=SummaryLength.STANDARD,  # Standard-length summaries for chunks
                focus_areas=request.focus_areas,
                language=request.language
            )
            chunk_result = await self.generate_summary(chunk_request)
            chunk_summaries.append(chunk_result.summary)
            total_cost += chunk_result.cost_data["total_cost_usd"]
            total_tokens += chunk_result.processing_metadata["total_tokens"]

            # Add a delay to respect rate limits
            await asyncio.sleep(1.0)

        # Combine chunk summaries into a final summary (hierarchical reduce step)
        combined_text = "\n\n".join(
            f"Part {i + 1}: {summary}"
            for i, summary in enumerate(chunk_summaries)
        )
        final_request = SummaryRequest(
            transcript=combined_text,
            length=request.length,
            focus_areas=request.focus_areas,
            language=request.language
        )
        final_result = await self.generate_summary(final_request)

        # Update metadata to reflect chunked processing
        final_result.processing_metadata.update({
            "chunks_processed": len(chunks),
            "total_tokens": total_tokens + final_result.processing_metadata["total_tokens"],
            "chunking_strategy": "hierarchical_large_chunks",
            "chunk_size": "1.5M tokens per chunk"
        })
        final_result.cost_data["total_cost_usd"] = total_cost + final_result.cost_data["total_cost_usd"]
        return final_result

    def _split_transcript_intelligently(self, transcript: str, max_tokens: int = 1_500_000) -> List[str]:
        """Split transcript at natural boundaries while respecting large token limits."""
        # With Gemini's large context, we can use very large chunks
        paragraphs = transcript.split('\n\n')
        chunks = []
        current_chunk = []
        current_tokens = 0

        for paragraph in paragraphs:
            paragraph_tokens = self.get_token_count(paragraph)

            # If a single paragraph exceeds the limit, split it by sentences
            if paragraph_tokens > max_tokens:
                sentences = paragraph.split('. ')
                for sentence in sentences:
                    sentence_tokens = self.get_token_count(sentence)
                    if current_tokens + sentence_tokens > max_tokens and current_chunk:
                        chunks.append(' '.join(current_chunk))
                        current_chunk = [sentence]
                        current_tokens = sentence_tokens
                    else:
                        current_chunk.append(sentence)
                        current_tokens += sentence_tokens
            else:
                if current_tokens + paragraph_tokens > max_tokens and current_chunk:
                    chunks.append('\n\n'.join(current_chunk))
                    current_chunk = [paragraph]
                    current_tokens = paragraph_tokens
                else:
                    current_chunk.append(paragraph)
                    current_tokens += paragraph_tokens

        # Add the final chunk
        if current_chunk:
            chunks.append('\n\n'.join(current_chunk))
        return chunks

    def _get_max_tokens(self, length: SummaryLength) -> int:
        """Get max output tokens based on summary length."""
        return {
            SummaryLength.BRIEF: 400,
            SummaryLength.STANDARD: 800,
            SummaryLength.DETAILED: 1500
        }[length]

    def estimate_cost(self, transcript: str, length: SummaryLength) -> float:
        """Estimate cost for summarizing a transcript."""
        input_tokens = self.get_token_count(transcript)
        output_tokens = self._get_max_tokens(length)
        input_cost = (input_tokens / 1000) * self.input_cost_per_1k
        output_cost = (output_tokens / 1000) * self.output_cost_per_1k
        return input_cost + output_cost

    def get_token_count(self, text: str) -> int:
        """Estimate token count for Gemini models (roughly 4 characters per token)."""
        # Gemini's tokenization is broadly similar to that of other LLMs
        return len(text) // 4
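
    # NOTE: for exact counts, the Generative Language API also offers a
    # models/{model}:countTokens endpoint; the heuristic above trades some
    # precision to avoid an extra network round trip per estimate.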

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit - cleanup resources."""
        await self.client.aclose()
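

# Example usage -- a minimal sketch, assuming an API key in a GOOGLE_API_KEY
# environment variable and a plain-text transcript file (both hypothetical;
# illustrative only, since this module is normally imported as part of the
# backend package):
#
#     import asyncio, os
#
#     async def demo():
#         async with GeminiSummarizer(api_key=os.environ["GOOGLE_API_KEY"]) as service:
#             with open("transcript.txt", encoding="utf-8") as f:
#                 request = SummaryRequest(
#                     transcript=f.read(),
#                     length=SummaryLength.STANDARD,
#                     focus_areas=None,
#                     language="en",
#                 )
#             print(f"Estimated cost: ${service.estimate_cost(request.transcript, request.length):.4f}")
#             result = await service.generate_summary(request)
#             print(result.summary)
#
#     asyncio.run(demo())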