"""Google Gemini summarization service with 2M token context support."""
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from typing import Dict, List, Optional
|
|
import httpx
|
|
import re
|
|
|
|
from .ai_service import AIService, SummaryRequest, SummaryResult, SummaryLength
|
|
from ..core.exceptions import AIServiceError, ErrorCode
|
|
|
|
|
|
class GeminiSummarizer(AIService):
    """Google Gemini-based summarization service with large context support."""

    def __init__(self, api_key: str, model: str = "gemini-1.5-pro"):
        """Initialize Gemini summarizer.

        Args:
            api_key: Google AI API key
            model: Model to use (gemini-1.5-pro for 2M context, gemini-1.5-flash for speed)
        """
        self.api_key = api_key
        self.model = model
        self.base_url = "https://generativelanguage.googleapis.com/v1beta"

        # Context window sizes: 2M tokens for 1.5 Pro, 1M for 1.5 Flash
        self.max_tokens_input = 2000000 if "1.5-pro" in model else 1000000
        self.max_tokens_output = 8192  # Standard output limit

        # Cost per 1K tokens (derived from Gemini 1.5 per-1M-token pricing)
        if "1.5-pro" in model:
            self.input_cost_per_1k = 0.007  # $7 per 1M input tokens
            self.output_cost_per_1k = 0.021  # $21 per 1M output tokens
        else:  # Flash model
            self.input_cost_per_1k = 0.00015  # $0.15 per 1M input tokens
            self.output_cost_per_1k = 0.0006  # $0.60 per 1M output tokens

        # HTTP client for API calls; 5 minute timeout for long-context requests
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate_summary(self, request: SummaryRequest) -> SummaryResult:
        """Generate a structured summary using Google Gemini's large context window."""

        # With a 2M token context, most transcripts fit in a single request without chunking.
        estimated_tokens = self.get_token_count(request.transcript)
        if estimated_tokens > 1800000:  # Leave room for the prompt and response
            # Only chunk if absolutely necessary (rare with a 2M context)
            return await self._generate_chunked_summary(request)

        prompt = self._build_summary_prompt(request)

        try:
            start_time = time.time()

            # Make the API request to Gemini
            url = f"{self.base_url}/models/{self.model}:generateContent"

            payload = {
                "contents": [
                    {
                        "parts": [
                            {"text": prompt}
                        ]
                    }
                ],
                "generationConfig": {
                    "temperature": 0.3,  # Lower temperature for consistent summaries
                    "maxOutputTokens": self._get_max_tokens(request.length),
                    "topP": 0.8,
                    "topK": 10
                },
                "safetySettings": [
                    {
                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                        "threshold": "BLOCK_NONE"
                    },
                    {
                        "category": "HARM_CATEGORY_HATE_SPEECH",
                        "threshold": "BLOCK_NONE"
                    },
                    {
                        "category": "HARM_CATEGORY_HARASSMENT",
                        "threshold": "BLOCK_NONE"
                    },
                    {
                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        "threshold": "BLOCK_NONE"
                    }
                ]
            }

            response = await self.client.post(
                url,
                params={"key": self.api_key},
                json=payload,
                headers={"Content-Type": "application/json"}
            )

            response.raise_for_status()
            result = response.json()

            processing_time = time.time() - start_time

            # Extract the response text (candidates[0].content.parts[0].text)
            if "candidates" not in result or not result["candidates"]:
                raise AIServiceError(
                    message="No response candidates from Gemini",
                    error_code=ErrorCode.AI_SERVICE_ERROR
                )

            content = result["candidates"][0]["content"]["parts"][0]["text"]

            # Parse JSON from the response
            try:
                result_data = json.loads(content)
            except json.JSONDecodeError:
                # Fall back to structured text parsing
                result_data = self._extract_structured_data(content)

            # Calculate token usage and costs from local estimates
            input_tokens = estimated_tokens
            output_tokens = self.get_token_count(content)

            input_cost = (input_tokens / 1000) * self.input_cost_per_1k
            output_cost = (output_tokens / 1000) * self.output_cost_per_1k
            total_cost = input_cost + output_cost

            # Prefer actual usage metadata when the API returns it
            if "usageMetadata" in result:
                usage = result["usageMetadata"]
                input_tokens = usage.get("promptTokenCount", input_tokens)
                output_tokens = usage.get("candidatesTokenCount", output_tokens)

                # Recalculate costs with actual usage
                input_cost = (input_tokens / 1000) * self.input_cost_per_1k
                output_cost = (output_tokens / 1000) * self.output_cost_per_1k
                total_cost = input_cost + output_cost

            return SummaryResult(
                summary=result_data.get("summary", ""),
                key_points=result_data.get("key_points", []),
                main_themes=result_data.get("main_themes", []),
                actionable_insights=result_data.get("actionable_insights", []),
                confidence_score=result_data.get("confidence_score", 0.9),
                processing_metadata={
                    "model": self.model,
                    "processing_time_seconds": processing_time,
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                    "chunks_processed": 1,
                    "context_window_used": f"{input_tokens}/{self.max_tokens_input}",
                    "large_context_advantage": "Single-pass processing - no chunking needed"
                },
                cost_data={
                    "input_cost_usd": input_cost,
                    "output_cost_usd": output_cost,
                    "total_cost_usd": total_cost,
                    "cost_per_summary": total_cost,
                    "model_efficiency": "Large context eliminates chunking overhead"
                }
            )

        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                raise AIServiceError(
                    message="Gemini API rate limit exceeded",
                    error_code=ErrorCode.RATE_LIMIT_ERROR,
                    recoverable=True
                )
            elif e.response.status_code == 400:
                error_detail = ""
                try:
                    error_data = e.response.json()
                    error_detail = error_data.get("error", {}).get("message", "")
                except Exception:
                    pass

                raise AIServiceError(
                    message=f"Gemini API request error: {error_detail}",
                    error_code=ErrorCode.AI_SERVICE_ERROR,
                    recoverable=False
                )
            else:
                raise AIServiceError(
                    message=f"Gemini API error: {e.response.status_code}",
                    error_code=ErrorCode.AI_SERVICE_ERROR,
                    recoverable=True
                )

        except AIServiceError:
            # Re-raise service errors unchanged so their error codes are preserved
            raise

        except Exception as e:
            raise AIServiceError(
                message=f"Gemini summarization failed: {str(e)}",
                error_code=ErrorCode.AI_SERVICE_ERROR,
                details={
                    "model": self.model,
                    "transcript_length": len(request.transcript),
                    "error_type": type(e).__name__
                }
            )

    def _build_summary_prompt(self, request: SummaryRequest) -> str:
        """Build optimized prompt for Gemini summary generation."""
        length_instructions = {
            SummaryLength.BRIEF: "Generate a concise summary in 100-200 words",
            SummaryLength.STANDARD: "Generate a comprehensive summary in 300-500 words",
            SummaryLength.DETAILED: "Generate a detailed summary in 500-800 words"
        }

        focus_instruction = ""
        if request.focus_areas:
            focus_instruction = f"\nPay special attention to these areas: {', '.join(request.focus_areas)}"

        return f"""
Analyze this YouTube video transcript and provide a structured summary. With your large context window, you can process the entire transcript at once for maximum coherence.

{length_instructions[request.length]}.

Please respond with a valid JSON object in this exact format:
{{
    "summary": "Main summary text here",
    "key_points": ["Point 1", "Point 2", "Point 3"],
    "main_themes": ["Theme 1", "Theme 2", "Theme 3"],
    "actionable_insights": ["Insight 1", "Insight 2"],
    "confidence_score": 0.95
}}

Guidelines:
- Extract 5-8 key points that capture the most important information
- Identify 3-5 main themes or topics discussed
- Provide 3-6 actionable insights that viewers can apply
- Assign a confidence score (0.0-1.0) based on transcript quality and coherence
- Use clear, engaging language that's accessible to a general audience
- Focus on value and practical takeaways
- Maintain narrative flow since you can see the entire transcript{focus_instruction}

Transcript:
{request.transcript}
"""

    def _extract_structured_data(self, response_text: str) -> dict:
        """Extract structured data when JSON parsing fails."""
        try:
            # Look for a JSON block in the response
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
        except json.JSONDecodeError:
            pass

        # Fallback: parse structured text line by line
        lines = response_text.split('\n')

        summary = ""
        key_points = []
        main_themes = []
        actionable_insights = []
        confidence_score = 0.9

        current_section = None

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Detect section headers
            if "summary" in line.lower() and (":" in line or line.endswith("summary")):
                current_section = "summary"
                if ":" in line:
                    summary = line.split(":", 1)[1].strip()
                continue
            elif "key points" in line.lower() or "key_points" in line.lower():
                current_section = "key_points"
                continue
            elif "main themes" in line.lower() or "themes" in line.lower():
                current_section = "main_themes"
                continue
            elif "actionable" in line.lower() or "insights" in line.lower():
                current_section = "actionable_insights"
                continue
            elif "confidence" in line.lower():
                numbers = re.findall(r'0?\.\d+|\d+', line)
                if numbers:
                    try:
                        confidence_score = float(numbers[0])
                    except ValueError:
                        pass
                continue

            # Add content to the appropriate section
            if current_section == "summary" and not summary:
                summary = line
            elif current_section == "key_points" and (line.startswith(('-', '•', '*')) or line[0].isdigit()):
                cleaned_line = re.sub(r'^[-•*0-9.)\s]+', '', line).strip()
                if cleaned_line:
                    key_points.append(cleaned_line)
            elif current_section == "main_themes" and (line.startswith(('-', '•', '*')) or line[0].isdigit()):
                cleaned_line = re.sub(r'^[-•*0-9.)\s]+', '', line).strip()
                if cleaned_line:
                    main_themes.append(cleaned_line)
            elif current_section == "actionable_insights" and (line.startswith(('-', '•', '*')) or line[0].isdigit()):
                cleaned_line = re.sub(r'^[-•*0-9.)\s]+', '', line).strip()
                if cleaned_line:
                    actionable_insights.append(cleaned_line)

        return {
            "summary": summary or response_text[:500] + "...",
            "key_points": key_points[:8],
            "main_themes": main_themes[:5],
            "actionable_insights": actionable_insights[:6],
            "confidence_score": confidence_score
        }

    async def _generate_chunked_summary(self, request: SummaryRequest) -> SummaryResult:
        """Handle extremely long transcripts (rare with 2M context) using a hierarchical approach."""

        # Split the transcript into large chunks (up to 1.5M tokens each)
        chunks = self._split_transcript_intelligently(request.transcript, max_tokens=1500000)

        # Generate a summary for each chunk
        chunk_summaries = []
        total_cost = 0.0
        total_tokens = 0

        for i, chunk in enumerate(chunks):
            chunk_request = SummaryRequest(
                transcript=chunk,
                length=SummaryLength.STANDARD,  # Standard summaries for chunks
                focus_areas=request.focus_areas,
                language=request.language
            )

            chunk_result = await self.generate_summary(chunk_request)
            chunk_summaries.append(chunk_result.summary)
            total_cost += chunk_result.cost_data["total_cost_usd"]
            total_tokens += chunk_result.processing_metadata["total_tokens"]

            # Add a delay to respect rate limits
            await asyncio.sleep(1.0)

        # Combine chunk summaries into a final summary (hierarchical pass)
        combined_text = "\n\n".join([
            f"Part {i+1}: {summary}"
            for i, summary in enumerate(chunk_summaries)
        ])

        final_request = SummaryRequest(
            transcript=combined_text,
            length=request.length,
            focus_areas=request.focus_areas,
            language=request.language
        )

        final_result = await self.generate_summary(final_request)

        # Update metadata to reflect chunked processing
        final_result.processing_metadata.update({
            "chunks_processed": len(chunks),
            "total_tokens": total_tokens + final_result.processing_metadata["total_tokens"],
            "chunking_strategy": "hierarchical_large_chunks",
            "chunk_size": "1.5M tokens per chunk"
        })

        final_result.cost_data["total_cost_usd"] = total_cost + final_result.cost_data["total_cost_usd"]

        return final_result

    def _split_transcript_intelligently(self, transcript: str, max_tokens: int = 1500000) -> List[str]:
        """Split transcript at natural boundaries while respecting large token limits."""

        # With Gemini's large context, very large chunks are acceptable
        paragraphs = transcript.split('\n\n')
        chunks = []
        current_chunk = []
        current_tokens = 0

        for paragraph in paragraphs:
            paragraph_tokens = self.get_token_count(paragraph)

            # If a single paragraph exceeds the limit, split it by sentences
            if paragraph_tokens > max_tokens:
                sentences = paragraph.split('. ')
                for sentence in sentences:
                    sentence_tokens = self.get_token_count(sentence)

                    if current_tokens + sentence_tokens > max_tokens and current_chunk:
                        chunks.append(' '.join(current_chunk))
                        current_chunk = [sentence]
                        current_tokens = sentence_tokens
                    else:
                        current_chunk.append(sentence)
                        current_tokens += sentence_tokens
            else:
                if current_tokens + paragraph_tokens > max_tokens and current_chunk:
                    chunks.append('\n\n'.join(current_chunk))
                    current_chunk = [paragraph]
                    current_tokens = paragraph_tokens
                else:
                    current_chunk.append(paragraph)
                    current_tokens += paragraph_tokens

        # Add the final chunk
        if current_chunk:
            chunks.append('\n\n'.join(current_chunk))

        return chunks

    def _get_max_tokens(self, length: SummaryLength) -> int:
        """Get max output tokens based on summary length."""
        return {
            SummaryLength.BRIEF: 400,
            SummaryLength.STANDARD: 800,
            SummaryLength.DETAILED: 1500
        }[length]

    def estimate_cost(self, transcript: str, length: SummaryLength) -> float:
        """Estimate cost for summarizing transcript."""
        input_tokens = self.get_token_count(transcript)
        output_tokens = self._get_max_tokens(length)

        input_cost = (input_tokens / 1000) * self.input_cost_per_1k
        output_cost = (output_tokens / 1000) * self.output_cost_per_1k

        return input_cost + output_cost

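    # Worked example of the estimate above (a sketch using the per-1K rates set in
    # __init__, not measured API billing): a ~400,000-token transcript summarized with
    # gemini-1.5-pro at STANDARD length reserves 800 output tokens, so
    #   input_cost  = (400000 / 1000) * 0.007 = $2.80
    #   output_cost = (800 / 1000) * 0.021    ≈ $0.017
    #   estimate_cost(...)                    ≈ $2.82
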
    def get_token_count(self, text: str) -> int:
        """Estimate token count for Gemini models (roughly 4 characters per token)."""
        # Gemini uses tokenization broadly similar to other large models
        return len(text) // 4

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit - clean up resources."""
        await self.client.aclose()
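

# Usage sketch (illustrative only, kept as a comment so the module's behavior is
# unchanged): shows the intended async-context-manager flow with the SummaryRequest
# and SummaryLength types imported above. The GOOGLE_API_KEY environment variable
# and the sample transcript/focus areas are assumptions for this example only.
#
#     import os
#
#     async def summarize_example() -> None:
#         async with GeminiSummarizer(api_key=os.environ["GOOGLE_API_KEY"]) as summarizer:
#             request = SummaryRequest(
#                 transcript="...full video transcript text...",
#                 length=SummaryLength.STANDARD,
#                 focus_areas=["key takeaways"],
#                 language="en",
#             )
#             result = await summarizer.generate_summary(request)
#             print(result.summary)
#             print(result.cost_data["total_cost_usd"])
#
#     asyncio.run(summarize_example())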