# youtube-summarizer/backend/api/batch.py
"""
Batch processing API endpoints
"""

import os
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field, validator
from sqlalchemy.orm import Session

from backend.api.auth import get_current_user
from backend.api.pipeline import get_notification_service, get_summary_pipeline
from backend.core.database import get_db
from backend.models.batch_job import BatchJob, BatchJobItem
from backend.models.summary import Summary
from backend.models.user import User
from backend.services.batch_processing_service import BatchProcessingService
from backend.services.notification_service import NotificationService
from backend.services.summary_pipeline import SummaryPipeline

router = APIRouter(prefix="/api/batch", tags=["batch"])


class BatchJobRequest(BaseModel):
    """Request model for creating a batch job"""

    name: Optional[str] = Field(None, max_length=255, description="Name for the batch job")
    urls: List[str] = Field(..., min_items=1, max_items=100, description="List of YouTube URLs to process")
    model: str = Field("deepseek", description="AI model to use for summarization")
    summary_length: str = Field("standard", description="Length of summaries (brief, standard, detailed)")
    options: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional processing options")

    @validator('urls')
    def validate_urls(cls, urls):
        """Ensure URLs are strings and not empty"""
        cleaned = []
        for url in urls:
            if isinstance(url, str) and url.strip():
                cleaned.append(url.strip())
        if not cleaned:
            raise ValueError("At least one valid URL is required")
        return cleaned

    @validator('model')
    def validate_model(cls, model):
        """Validate model selection"""
        valid_models = ["deepseek", "openai", "anthropic"]  # DeepSeek preferred
        if model not in valid_models:
            raise ValueError(f"Model must be one of: {', '.join(valid_models)}")
        return model

    @validator('summary_length')
    def validate_summary_length(cls, length):
        """Validate summary length"""
        valid_lengths = ["brief", "standard", "detailed"]
        if length not in valid_lengths:
            raise ValueError(f"Summary length must be one of: {', '.join(valid_lengths)}")
        return length


class BatchJobResponse(BaseModel):
    """Response model for batch job creation"""

    id: str
    name: str
    status: str
    total_videos: int
    created_at: datetime
    message: str = "Batch job created successfully"


class BatchJobStatusResponse(BaseModel):
    """Response model for batch job status"""

    id: str
    name: str
    status: str
    progress: Dict[str, Any]
    items: List[Dict[str, Any]]
    created_at: Optional[datetime]
    started_at: Optional[datetime]
    completed_at: Optional[datetime]
    export_url: Optional[str]
    total_cost_usd: float
    estimated_completion: Optional[str]


class BatchJobListResponse(BaseModel):
    """Response model for listing batch jobs"""

    batch_jobs: List[Dict[str, Any]]
    total: int
    page: int
    page_size: int
@router.post("/create", response_model=BatchJobResponse)
async def create_batch_job(
request: BatchJobRequest,
background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
pipeline: SummaryPipeline = Depends(get_summary_pipeline),
notifications: NotificationService = Depends(get_notification_service)
):
"""
Create a new batch processing job
This endpoint accepts a list of YouTube URLs and processes them sequentially.
Progress updates are available via WebSocket or polling the status endpoint.
"""
# Create batch processing service
batch_service = BatchProcessingService(
db_session=db,
summary_pipeline=pipeline,
notification_service=notifications
)
try:
# Create the batch job
batch_job = await batch_service.create_batch_job(
user_id=current_user.id,
urls=request.urls,
name=request.name,
model=request.model,
summary_length=request.summary_length,
options=request.options
)
return BatchJobResponse(
id=batch_job.id,
name=batch_job.name,
status=batch_job.status,
total_videos=batch_job.total_videos,
created_at=batch_job.created_at
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create batch job: {str(e)}")
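
# Example: creating a job from the command line (illustrative; assumes the API is
# reachable at http://localhost:8000 and that auth uses a bearer token in $TOKEN):
#
#   curl -X POST http://localhost:8000/api/batch/create \
#        -H "Authorization: Bearer $TOKEN" \
#        -H "Content-Type: application/json" \
#        -d '{"urls": ["https://www.youtube.com/watch?v=dQw4w9WgXcQ"]}'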
@router.get("/{job_id}", response_model=BatchJobStatusResponse)
async def get_batch_status(
job_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Get the current status of a batch job
Returns detailed information about the batch job including progress,
individual item statuses, and export URL when complete.
"""
batch_service = BatchProcessingService(db_session=db)
status = await batch_service.get_batch_status(job_id, current_user.id)
if not status:
raise HTTPException(status_code=404, detail="Batch job not found")
return BatchJobStatusResponse(**status)
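
# Illustrative status payload; the top-level keys come from BatchJobStatusResponse,
# but the inner shapes of "progress" and "items" are defined by
# BatchProcessingService.get_batch_status and may differ from this sketch:
#
#   {
#       "id": "…", "name": "…", "status": "processing",
#       "progress": {"completed": 3, "failed": 1, "total": 10},
#       "items": [{"url": "…", "status": "completed"}],
#       "created_at": "2024-01-01T00:00:00", "started_at": "…",
#       "completed_at": null, "export_url": null,
#       "total_cost_usd": 0.12, "estimated_completion": "…"
#   }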
@router.get("/", response_model=BatchJobListResponse)
async def list_batch_jobs(
page: int = 1,
page_size: int = 20,
status: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
List all batch jobs for the current user
Supports pagination and optional filtering by status.
"""
query = db.query(BatchJob).filter(BatchJob.user_id == current_user.id)
if status:
query = query.filter(BatchJob.status == status)
# Get total count
total = query.count()
# Apply pagination
offset = (page - 1) * page_size
batch_jobs = query.order_by(BatchJob.created_at.desc()).offset(offset).limit(page_size).all()
return BatchJobListResponse(
batch_jobs=[job.to_dict() for job in batch_jobs],
total=total,
page=page,
page_size=page_size
)
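
# Example: second page of completed jobs (query parameters match the signature above):
#
#   GET /api/batch/?page=2&page_size=20&status=completed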
@router.post("/{job_id}/cancel")
async def cancel_batch_job(
job_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Cancel a running batch job
Only jobs with status 'processing' can be cancelled.
"""
batch_service = BatchProcessingService(db_session=db)
success = await batch_service.cancel_batch_job(job_id, current_user.id)
if not success:
raise HTTPException(
status_code=404,
detail="Batch job not found or not in processing state"
)
return {"message": "Batch job cancelled successfully", "job_id": job_id}
@router.post("/{job_id}/retry")
async def retry_failed_items(
job_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
pipeline: SummaryPipeline = Depends(get_summary_pipeline),
notifications: NotificationService = Depends(get_notification_service)
):
"""
Retry failed items in a batch job
Creates a new batch job with only the failed items from the original job.
"""
# Get original batch job
original_job = db.query(BatchJob).filter_by(
id=job_id,
user_id=current_user.id
).first()
if not original_job:
raise HTTPException(status_code=404, detail="Batch job not found")
# Get failed items
failed_items = db.query(BatchJobItem).filter_by(
batch_job_id=job_id,
status="failed"
).all()
if not failed_items:
return {"message": "No failed items to retry"}
# Create new batch job with failed URLs
failed_urls = [item.url for item in failed_items]
batch_service = BatchProcessingService(
db_session=db,
summary_pipeline=pipeline,
notification_service=notifications
)
new_job = await batch_service.create_batch_job(
user_id=current_user.id,
urls=failed_urls,
name=f"{original_job.name} (Retry)",
model=original_job.model,
summary_length=original_job.summary_length,
options=original_job.options
)
return {
"message": f"Created retry batch job with {len(failed_urls)} items",
"new_job_id": new_job.id,
"original_job_id": job_id
}
@router.get("/{job_id}/download")
async def download_batch_export(
job_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Download the export ZIP file for a completed batch job
Returns a ZIP file containing all summaries in JSON and Markdown formats.
"""
# Get batch job
batch_job = db.query(BatchJob).filter_by(
id=job_id,
user_id=current_user.id
).first()
if not batch_job:
raise HTTPException(status_code=404, detail="Batch job not found")
if batch_job.status != "completed":
raise HTTPException(status_code=400, detail="Batch job not completed yet")
# Check if export file exists
export_path = f"/tmp/batch_exports/{job_id}.zip"
if not os.path.exists(export_path):
# Try to regenerate export
batch_service = BatchProcessingService(db_session=db)
export_url = await batch_service._generate_export(job_id)
if not export_url or not os.path.exists(export_path):
raise HTTPException(status_code=404, detail="Export file not found")
return FileResponse(
export_path,
media_type="application/zip",
filename=f"{batch_job.name.replace(' ', '_')}_summaries.zip"
)
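
# Example download (illustrative; assumes bearer-token auth as above):
#
#   curl -H "Authorization: Bearer $TOKEN" \
#        -o summaries.zip http://localhost:8000/api/batch/<job_id>/download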
@router.delete("/{job_id}")
async def delete_batch_job(
job_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Delete a batch job and all associated data
This will also delete any summaries created by the batch job.
"""
# Get batch job
batch_job = db.query(BatchJob).filter_by(
id=job_id,
user_id=current_user.id
).first()
if not batch_job:
raise HTTPException(status_code=404, detail="Batch job not found")
# Don't allow deletion of running jobs
if batch_job.status == "processing":
raise HTTPException(
status_code=400,
detail="Cannot delete a running batch job. Cancel it first."
)
# Delete associated summaries
items = db.query(BatchJobItem).filter_by(batch_job_id=job_id).all()
for item in items:
if item.summary_id:
from backend.models.summary import Summary
summary = db.query(Summary).filter_by(id=item.summary_id).first()
if summary:
db.delete(summary)
# Delete batch job (cascade will delete items)
db.delete(batch_job)
db.commit()
# Delete export file if exists
export_path = f"/tmp/batch_exports/{job_id}.zip"
if os.path.exists(export_path):
os.remove(export_path)
return {"message": "Batch job deleted successfully", "job_id": job_id}