"""
|
|
Batch processing service for handling multiple video summarizations
|
|
"""
|
|
import asyncio
|
|
import re
|
|
import json
|
|
import zipfile
|
|
import tempfile
|
|
import os
|
|
from typing import List, Dict, Optional, Any
|
|
from datetime import datetime, timedelta
|
|
import uuid
|
|
from sqlalchemy.orm import Session
|
|
import logging
|
|
|
|
from backend.models.batch_job import BatchJob, BatchJobItem
|
|
from backend.models.summary import Summary
|
|
from backend.services.summary_pipeline import SummaryPipeline
|
|
from backend.services.notification_service import NotificationService
|
|
from backend.core.websocket_manager import websocket_manager
|
|
from backend.models.pipeline import PipelineConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BatchProcessingService:
    """Service for processing multiple YouTube videos in batch"""

    def __init__(
        self,
        db_session: Session,
        summary_pipeline: Optional[SummaryPipeline] = None,
        notification_service: Optional[NotificationService] = None
    ):
        self.db = db_session
        self.pipeline = summary_pipeline
        self.notifications = notification_service
        self.active_jobs: Dict[str, asyncio.Task] = {}

    def _validate_youtube_url(self, url: str) -> bool:
        """Validate if URL is a valid YouTube URL"""
        youtube_regex = (
            r'(https?://)?(www\.)?'
            r'(youtube\.com/(watch\?v=|embed/|v/)|youtu\.be/|m\.youtube\.com/watch\?v=)'
            r'[\w\-]+'
        )
        return bool(re.match(youtube_regex, url))

    def _extract_video_id(self, url: str) -> Optional[str]:
        """Extract video ID from YouTube URL"""
        patterns = [
            r'(?:v=|\/)([0-9A-Za-z_-]{11}).*',
            r'(?:embed\/)([0-9A-Za-z_-]{11})',
            r'(?:watch\?v=)([0-9A-Za-z_-]{11})'
        ]

        for pattern in patterns:
            match = re.search(pattern, url)
            if match:
                return match.group(1)
        return None

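    # Illustrative behaviour of the two URL helpers above (example inputs only,
    # not part of the service API):
    #   _validate_youtube_url("https://www.youtube.com/watch?v=dQw4w9WgXcQ")  -> True
    #   _validate_youtube_url("https://example.com/clip")                     -> False
    #   _extract_video_id("https://youtu.be/dQw4w9WgXcQ")                     -> "dQw4w9WgXcQ"
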
    async def create_batch_job(
        self,
        user_id: str,
        urls: List[str],
        name: Optional[str] = None,
        model: str = "deepseek",
        summary_length: str = "standard",
        options: Optional[Dict] = None
    ) -> BatchJob:
        """Create a new batch processing job"""

        # Validate and deduplicate URLs
        validated_urls = []
        seen_ids = set()

        for url in urls:
            if self._validate_youtube_url(url):
                video_id = self._extract_video_id(url)
                if video_id and video_id not in seen_ids:
                    validated_urls.append(url)
                    seen_ids.add(video_id)

        if not validated_urls:
            raise ValueError("No valid YouTube URLs provided")

        # Create batch job
        batch_job = BatchJob(
            user_id=user_id,
            name=name or f"Batch {datetime.now().strftime('%Y-%m-%d %H:%M')}",
            urls=validated_urls,
            total_videos=len(validated_urls),
            model=model,
            summary_length=summary_length,
            options=options or {},
            status="pending"
        )

        self.db.add(batch_job)
        self.db.flush()  # Get the ID

        # Create job items
        for idx, url in enumerate(validated_urls):
            item = BatchJobItem(
                batch_job_id=batch_job.id,
                url=url,
                position=idx,
                video_id=self._extract_video_id(url)
            )
            self.db.add(item)

        self.db.commit()

        # Start processing in background
        task = asyncio.create_task(self._process_batch(batch_job.id))
        self.active_jobs[batch_job.id] = task

        logger.info(f"Created batch job {batch_job.id} with {len(validated_urls)} videos")
        return batch_job

    async def _process_batch(self, batch_job_id: str):
        """Process all videos in a batch sequentially"""

        batch_job = None
        try:
            # Get batch job
            batch_job = self.db.query(BatchJob).filter_by(id=batch_job_id).first()
            if not batch_job:
                logger.error(f"Batch job {batch_job_id} not found")
                return

            # Update status to processing
            batch_job.status = "processing"
            batch_job.started_at = datetime.utcnow()
            self.db.commit()

            # Send initial progress update
            await self._send_progress_update(batch_job)

            # Get all items to process
            items = self.db.query(BatchJobItem).filter_by(
                batch_job_id=batch_job_id
            ).order_by(BatchJobItem.position).all()

            # Process each item
            for item in items:
                if batch_job.status == "cancelled":
                    logger.info(f"Batch job {batch_job_id} cancelled")
                    break

                await self._process_single_item(item, batch_job)

                # Update progress
                await self._send_progress_update(batch_job)

                # Small delay between videos to avoid rate limiting
                await asyncio.sleep(2)

            # Finalize batch
            if batch_job.status != "cancelled":
                batch_job.status = "completed"

            batch_job.completed_at = datetime.utcnow()

            # Calculate total processing time
            if batch_job.started_at:
                batch_job.total_processing_time = (
                    batch_job.completed_at - batch_job.started_at
                ).total_seconds()

            # Generate export file
            try:
                export_url = await self._generate_export(batch_job_id)
                batch_job.export_url = export_url
            except Exception as e:
                logger.error(f"Failed to generate export for batch {batch_job_id}: {e}")

            self.db.commit()

            # Send completion notification
            await self._send_completion_notification(batch_job)

            # Final progress update
            await self._send_progress_update(batch_job)

        except Exception as e:
            logger.error(f"Error processing batch {batch_job_id}: {e}")
            # Guard against failures that occur before the job was loaded
            if batch_job:
                batch_job.status = "failed"
                self.db.commit()

        finally:
            # Clean up active job
            if batch_job_id in self.active_jobs:
                del self.active_jobs[batch_job_id]

    async def _process_single_item(self, item: BatchJobItem, batch_job: BatchJob):
        """Process a single video item in the batch"""

        try:
            # Update item status
            item.status = "processing"
            item.started_at = datetime.utcnow()
            self.db.commit()

            # Create pipeline config
            config = PipelineConfig(
                model=batch_job.model,
                summary_length=batch_job.summary_length,
                **batch_job.options
            )

            # Process video using the pipeline
            if self.pipeline:
                # Start pipeline processing
                pipeline_job_id = await self.pipeline.process_video(
                    video_url=item.url,
                    config=config
                )

                # Wait for completion (with timeout)
                result = await self._wait_for_pipeline_completion(
                    pipeline_job_id,
                    timeout=600  # 10 minutes max per video
                )

                if result and result.status == "completed":
                    # Create summary record
                    summary = Summary(
                        user_id=batch_job.user_id,
                        video_url=item.url,
                        video_id=item.video_id,
                        video_title=result.video_metadata.get("title") if result.video_metadata else None,
                        channel_name=result.video_metadata.get("channel") if result.video_metadata else None,
                        duration_seconds=result.video_metadata.get("duration") if result.video_metadata else None,
                        summary_text=result.summary,
                        key_points=result.key_points,
                        model_used=batch_job.model,
                        confidence_score=result.confidence_score,
                        quality_score=result.quality_score,
                        processing_time=result.processing_time,
                        cost_data=result.cost_data
                    )
                    self.db.add(summary)
                    self.db.flush()

                    # Update item with success
                    item.status = "completed"
                    item.summary_id = summary.id
                    item.video_title = summary.video_title
                    item.channel_name = summary.channel_name
                    item.duration_seconds = summary.duration_seconds
                    item.cost_usd = result.cost_data.get("total_cost_usd", 0) if result.cost_data else 0

                    # Update batch counters
                    batch_job.completed_videos += 1
                    batch_job.total_cost_usd += item.cost_usd

                else:
                    # Processing failed
                    error_msg = result.error if result else "Pipeline timeout"
                    await self._handle_item_failure(item, batch_job, error_msg, "processing_error")

            else:
                # No pipeline available (shouldn't happen in production)
                await self._handle_item_failure(item, batch_job, "Pipeline not available", "system_error")

        except Exception as e:
            logger.error(f"Error processing item {item.id}: {e}")
            await self._handle_item_failure(item, batch_job, str(e), "exception")

        finally:
            # Update item completion time
            item.completed_at = datetime.utcnow()
            if item.started_at:
                item.processing_time_seconds = (
                    item.completed_at - item.started_at
                ).total_seconds()
            self.db.commit()

    async def _handle_item_failure(
        self,
        item: BatchJobItem,
        batch_job: BatchJob,
        error_message: str,
        error_type: str
    ):
        """Handle a failed item with retry logic"""

        item.retry_count += 1

        if item.retry_count < item.max_retries:
            # Will retry later
            item.status = "pending"
            logger.info(f"Item {item.id} failed, will retry ({item.retry_count}/{item.max_retries})")
        else:
            # Max retries reached
            item.status = "failed"
            item.error_message = error_message
            item.error_type = error_type
            batch_job.failed_videos += 1
            logger.error(f"Item {item.id} failed after {item.retry_count} retries: {error_message}")

    async def _wait_for_pipeline_completion(
        self,
        pipeline_job_id: str,
        timeout: int = 600
    ) -> Optional[Any]:
        """Poll the pipeline until the job completes or fails, or the timeout elapses"""

        start_time = datetime.utcnow()

        while (datetime.utcnow() - start_time).total_seconds() < timeout:
            if self.pipeline:
                result = await self.pipeline.get_pipeline_result(pipeline_job_id)

                if result and result.status in ["completed", "failed"]:
                    return result

            await asyncio.sleep(2)

        logger.warning(f"Pipeline job {pipeline_job_id} timed out after {timeout} seconds")
        return None

    async def _send_progress_update(self, batch_job: BatchJob):
        """Send progress update via WebSocket"""

        # Get current processing item
        current_item = self.db.query(BatchJobItem).filter_by(
            batch_job_id=batch_job.id,
            status="processing"
        ).first()

        progress_data = {
            "batch_job_id": batch_job.id,
            "status": batch_job.status,
            "name": batch_job.name,
            "progress": {
                "total": batch_job.total_videos,
                "completed": batch_job.completed_videos,
                "failed": batch_job.failed_videos,
                "percentage": batch_job.get_progress_percentage()
            },
            "current_item": {
                "url": current_item.url,
                "position": current_item.position + 1,
                "video_title": current_item.video_title
            } if current_item else None,
            "estimated_completion": self._estimate_completion_time(batch_job),
            "export_url": batch_job.export_url
        }

        # Send via WebSocket to subscribers
        await websocket_manager.broadcast_to_job(
            f"batch_{batch_job.id}",
            {
                "type": "batch_progress",
                "data": progress_data
            }
        )

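    # Example shape of the broadcast message above (values illustrative only):
    #   {"type": "batch_progress",
    #    "data": {"batch_job_id": "...", "status": "processing", "name": "...",
    #             "progress": {"total": 10, "completed": 4, "failed": 0, "percentage": 40.0},
    #             "current_item": {"url": "...", "position": 5, "video_title": "..."},
    #             "estimated_completion": "...", "export_url": None}}
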
    def _estimate_completion_time(self, batch_job: BatchJob) -> Optional[str]:
        """Estimate completion time based on average processing time"""

        if batch_job.completed_videos == 0 or not batch_job.started_at:
            return None

        # Calculate average time per video
        elapsed = (datetime.utcnow() - batch_job.started_at).total_seconds()
        avg_time_per_video = elapsed / batch_job.completed_videos

        # Estimate remaining time
        remaining_videos = batch_job.total_videos - batch_job.completed_videos - batch_job.failed_videos
        estimated_seconds = remaining_videos * avg_time_per_video

        estimated_completion = datetime.utcnow() + timedelta(seconds=estimated_seconds)
        return estimated_completion.isoformat()

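    # Worked example (illustrative numbers only): if 4 of 10 videos have
    # completed after 600 s, the average is 150 s per video; with 6 videos
    # remaining, completion is estimated ~900 s from now.
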
    async def _send_completion_notification(self, batch_job: BatchJob):
        """Send completion notification"""

        if self.notifications:
            await self.notifications.send_notification(
                user_id=batch_job.user_id,
                type="batch_complete",
                title=f"Batch Processing Complete: {batch_job.name}",
                message=f"Processed {batch_job.completed_videos} videos successfully, {batch_job.failed_videos} failed.",
                data={
                    "batch_job_id": batch_job.id,
                    "export_url": batch_job.export_url
                }
            )

    async def _generate_export(self, batch_job_id: str) -> str:
        """Generate ZIP export of all summaries in the batch"""

        batch_job = self.db.query(BatchJob).filter_by(id=batch_job_id).first()
        if not batch_job:
            return ""

        # Get all completed items with summaries
        items = self.db.query(BatchJobItem).filter_by(
            batch_job_id=batch_job_id,
            status="completed"
        ).all()

        if not items:
            return ""

        # Create temporary ZIP file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
            with zipfile.ZipFile(tmp_file.name, 'w') as zip_file:

                # Add metadata file
                metadata = {
                    "batch_name": batch_job.name,
                    "total_videos": batch_job.total_videos,
                    "completed": batch_job.completed_videos,
                    "failed": batch_job.failed_videos,
                    "created_at": batch_job.created_at.isoformat() if batch_job.created_at else None,
                    "completed_at": batch_job.completed_at.isoformat() if batch_job.completed_at else None,
                    "total_cost_usd": batch_job.total_cost_usd
                }
                zip_file.writestr("batch_metadata.json", json.dumps(metadata, indent=2))

                # Add each summary
                for item in items:
                    if item.summary_id:
                        summary = self.db.query(Summary).filter_by(id=item.summary_id).first()
                        if summary:
                            # Create filename from video title or ID
                            safe_title = re.sub(r'[^\w\s-]', '', summary.video_title or f"video_{item.position}")
                            safe_title = re.sub(r'[-\s]+', '-', safe_title)

                            # Export as JSON
                            summary_data = {
                                "video_url": summary.video_url,
                                "video_title": summary.video_title,
                                "channel_name": summary.channel_name,
                                "summary": summary.summary_text,
                                "key_points": summary.key_points,
                                "created_at": summary.created_at.isoformat() if summary.created_at else None
                            }

                            zip_file.writestr(
                                f"summaries/{safe_title}.json",
                                json.dumps(summary_data, indent=2)
                            )

                            # Also export as markdown
                            markdown_content = f"""# {summary.video_title}

**URL**: {summary.video_url}
**Channel**: {summary.channel_name}
**Date**: {summary.created_at.strftime('%Y-%m-%d') if summary.created_at else 'N/A'}

## Summary

{summary.summary_text}

## Key Points

{chr(10).join([f"- {point}" for point in (summary.key_points or [])])}
"""
                            zip_file.writestr(
                                f"summaries/{safe_title}.md",
                                markdown_content
                            )

            # Move to permanent location (in real app, upload to S3 or similar)
            export_path = f"/tmp/batch_exports/{batch_job_id}.zip"
            os.makedirs(os.path.dirname(export_path), exist_ok=True)
            os.rename(tmp_file.name, export_path)

        # Return URL (in real app, return S3 URL)
        return f"/api/batch/{batch_job_id}/download"

    async def cancel_batch_job(self, batch_job_id: str, user_id: str) -> bool:
        """Cancel a running batch job"""

        batch_job = self.db.query(BatchJob).filter_by(
            id=batch_job_id,
            user_id=user_id,
            status="processing"
        ).first()

        if not batch_job:
            return False

        batch_job.status = "cancelled"
        self.db.commit()

        # Cancel the async task if it exists
        if batch_job_id in self.active_jobs:
            self.active_jobs[batch_job_id].cancel()

        logger.info(f"Cancelled batch job {batch_job_id}")
        return True

    async def get_batch_status(self, batch_job_id: str, user_id: str) -> Optional[Dict]:
        """Get detailed status of a batch job"""

        batch_job = self.db.query(BatchJob).filter_by(
            id=batch_job_id,
            user_id=user_id
        ).first()

        if not batch_job:
            return None

        items = self.db.query(BatchJobItem).filter_by(
            batch_job_id=batch_job_id
        ).order_by(BatchJobItem.position).all()

        return {
            "id": batch_job.id,
            "name": batch_job.name,
            "status": batch_job.status,
            "progress": {
                "total": batch_job.total_videos,
                "completed": batch_job.completed_videos,
                "failed": batch_job.failed_videos,
                "percentage": batch_job.get_progress_percentage()
            },
            "items": [item.to_dict() for item in items],
            "created_at": batch_job.created_at.isoformat() if batch_job.created_at else None,
            "started_at": batch_job.started_at.isoformat() if batch_job.started_at else None,
            "completed_at": batch_job.completed_at.isoformat() if batch_job.completed_at else None,
            "export_url": batch_job.export_url,
            "total_cost_usd": batch_job.total_cost_usd,
            "estimated_completion": self._estimate_completion_time(batch_job)
        }
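

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). How the SQLAlchemy session and the
# SummaryPipeline are obtained is an assumption here; in this project they
# would come from the application's setup code, not from this module.
# ---------------------------------------------------------------------------
#
#   async def run_example(db_session: Session, pipeline: SummaryPipeline) -> None:
#       service = BatchProcessingService(db_session, summary_pipeline=pipeline)
#       job = await service.create_batch_job(
#           user_id="user-123",                                    # example value
#           urls=["https://www.youtube.com/watch?v=dQw4w9WgXcQ"],  # example value
#           model="deepseek",
#           summary_length="standard",
#       )
#       status = await service.get_batch_status(job.id, "user-123")
#       print(status["progress"])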