369 lines
11 KiB
Python
369 lines
11 KiB
Python
"""
|
|
Batch processing API endpoints
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
|
from fastapi.responses import FileResponse
|
|
from typing import List, Optional, Dict, Any
|
|
from pydantic import BaseModel, Field, validator
|
|
from datetime import datetime
|
|
import os
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from backend.models.user import User
|
|
from backend.models.batch_job import BatchJob, BatchJobItem
|
|
from backend.services.batch_processing_service import BatchProcessingService
|
|
from backend.services.summary_pipeline import SummaryPipeline
|
|
from backend.services.notification_service import NotificationService
|
|
from backend.api.auth import get_current_user
|
|
from backend.core.database import get_db
|
|
from backend.api.pipeline import get_summary_pipeline, get_notification_service
|
|
|
|
# All batch-processing endpoints are mounted under /api/batch and grouped
# under the "batch" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api/batch", tags=["batch"])
|
|
|
|
|
|
class BatchJobRequest(BaseModel):
    """Payload accepted when creating a batch job."""
    name: Optional[str] = Field(None, max_length=255, description="Name for the batch job")
    urls: List[str] = Field(..., min_items=1, max_items=100, description="List of YouTube URLs to process")
    model: str = Field("deepseek", description="AI model to use for summarization")
    summary_length: str = Field("standard", description="Length of summaries (brief, standard, detailed)")
    options: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional processing options")

    @validator('urls')
    def validate_urls(cls, urls):
        """Strip whitespace and drop blank or non-string entries."""
        cleaned = [u.strip() for u in urls if isinstance(u, str) and u.strip()]
        if not cleaned:
            raise ValueError("At least one valid URL is required")
        return cleaned

    @validator('model')
    def validate_model(cls, model):
        """Reject any model outside the supported set."""
        valid_models = ["deepseek", "openai", "anthropic"]  # DeepSeek preferred
        if model in valid_models:
            return model
        raise ValueError(f"Model must be one of: {', '.join(valid_models)}")

    @validator('summary_length')
    def validate_summary_length(cls, length):
        """Reject any summary length outside the supported set."""
        valid_lengths = ["brief", "standard", "detailed"]
        if length in valid_lengths:
            return length
        raise ValueError(f"Summary length must be one of: {', '.join(valid_lengths)}")
|
|
|
|
|
|
class BatchJobResponse(BaseModel):
    """Response model for batch job creation"""
    # Identifier of the newly created batch job (string id from the service).
    id: str
    name: str
    # Current lifecycle status as stored by the service (e.g. freshly created).
    status: str
    # Number of URLs accepted into the job.
    total_videos: int
    created_at: datetime
    # Static confirmation text included with every successful creation.
    message: str = "Batch job created successfully"
|
|
|
|
|
|
class BatchJobStatusResponse(BaseModel):
    """Response model for batch job status"""
    id: str
    name: str
    status: str
    # Aggregate progress info — shape defined by BatchProcessingService.get_batch_status.
    progress: Dict[str, Any]
    # Per-URL item details — shape defined by the service layer.
    items: List[Dict[str, Any]]
    created_at: Optional[datetime]
    # None until processing starts.
    started_at: Optional[datetime]
    # None until the job finishes.
    completed_at: Optional[datetime]
    # Download location for the export ZIP; None until available.
    export_url: Optional[str]
    total_cost_usd: float
    # Human-readable ETA string; presumably set only while processing — TODO confirm.
    estimated_completion: Optional[str]
|
|
|
|
|
|
class BatchJobListResponse(BaseModel):
    """Response model for listing batch jobs"""
    # Serialized jobs for the current page (BatchJob.to_dict output).
    batch_jobs: List[Dict[str, Any]]
    # Total number of jobs matching the filter, across all pages.
    total: int
    page: int
    page_size: int
|
|
|
|
|
|
@router.post("/create", response_model=BatchJobResponse)
|
|
async def create_batch_job(
|
|
request: BatchJobRequest,
|
|
background_tasks: BackgroundTasks,
|
|
current_user: User = Depends(get_current_user),
|
|
db: Session = Depends(get_db),
|
|
pipeline: SummaryPipeline = Depends(get_summary_pipeline),
|
|
notifications: NotificationService = Depends(get_notification_service)
|
|
):
|
|
"""
|
|
Create a new batch processing job
|
|
|
|
This endpoint accepts a list of YouTube URLs and processes them sequentially.
|
|
Progress updates are available via WebSocket or polling the status endpoint.
|
|
"""
|
|
|
|
# Create batch processing service
|
|
batch_service = BatchProcessingService(
|
|
db_session=db,
|
|
summary_pipeline=pipeline,
|
|
notification_service=notifications
|
|
)
|
|
|
|
try:
|
|
# Create the batch job
|
|
batch_job = await batch_service.create_batch_job(
|
|
user_id=current_user.id,
|
|
urls=request.urls,
|
|
name=request.name,
|
|
model=request.model,
|
|
summary_length=request.summary_length,
|
|
options=request.options
|
|
)
|
|
|
|
return BatchJobResponse(
|
|
id=batch_job.id,
|
|
name=batch_job.name,
|
|
status=batch_job.status,
|
|
total_videos=batch_job.total_videos,
|
|
created_at=batch_job.created_at
|
|
)
|
|
|
|
except ValueError as e:
|
|
raise HTTPException(status_code=400, detail=str(e))
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Failed to create batch job: {str(e)}")
|
|
|
|
|
|
@router.get("/{job_id}", response_model=BatchJobStatusResponse)
|
|
async def get_batch_status(
|
|
job_id: str,
|
|
current_user: User = Depends(get_current_user),
|
|
db: Session = Depends(get_db)
|
|
):
|
|
"""
|
|
Get the current status of a batch job
|
|
|
|
Returns detailed information about the batch job including progress,
|
|
individual item statuses, and export URL when complete.
|
|
"""
|
|
|
|
batch_service = BatchProcessingService(db_session=db)
|
|
|
|
status = await batch_service.get_batch_status(job_id, current_user.id)
|
|
|
|
if not status:
|
|
raise HTTPException(status_code=404, detail="Batch job not found")
|
|
|
|
return BatchJobStatusResponse(**status)
|
|
|
|
|
|
@router.get("/", response_model=BatchJobListResponse)
|
|
async def list_batch_jobs(
|
|
page: int = 1,
|
|
page_size: int = 20,
|
|
status: Optional[str] = None,
|
|
current_user: User = Depends(get_current_user),
|
|
db: Session = Depends(get_db)
|
|
):
|
|
"""
|
|
List all batch jobs for the current user
|
|
|
|
Supports pagination and optional filtering by status.
|
|
"""
|
|
|
|
query = db.query(BatchJob).filter(BatchJob.user_id == current_user.id)
|
|
|
|
if status:
|
|
query = query.filter(BatchJob.status == status)
|
|
|
|
# Get total count
|
|
total = query.count()
|
|
|
|
# Apply pagination
|
|
offset = (page - 1) * page_size
|
|
batch_jobs = query.order_by(BatchJob.created_at.desc()).offset(offset).limit(page_size).all()
|
|
|
|
return BatchJobListResponse(
|
|
batch_jobs=[job.to_dict() for job in batch_jobs],
|
|
total=total,
|
|
page=page,
|
|
page_size=page_size
|
|
)
|
|
|
|
|
|
@router.post("/{job_id}/cancel")
|
|
async def cancel_batch_job(
|
|
job_id: str,
|
|
current_user: User = Depends(get_current_user),
|
|
db: Session = Depends(get_db)
|
|
):
|
|
"""
|
|
Cancel a running batch job
|
|
|
|
Only jobs with status 'processing' can be cancelled.
|
|
"""
|
|
|
|
batch_service = BatchProcessingService(db_session=db)
|
|
|
|
success = await batch_service.cancel_batch_job(job_id, current_user.id)
|
|
|
|
if not success:
|
|
raise HTTPException(
|
|
status_code=404,
|
|
detail="Batch job not found or not in processing state"
|
|
)
|
|
|
|
return {"message": "Batch job cancelled successfully", "job_id": job_id}
|
|
|
|
|
|
@router.post("/{job_id}/retry")
|
|
async def retry_failed_items(
|
|
job_id: str,
|
|
current_user: User = Depends(get_current_user),
|
|
db: Session = Depends(get_db),
|
|
pipeline: SummaryPipeline = Depends(get_summary_pipeline),
|
|
notifications: NotificationService = Depends(get_notification_service)
|
|
):
|
|
"""
|
|
Retry failed items in a batch job
|
|
|
|
Creates a new batch job with only the failed items from the original job.
|
|
"""
|
|
|
|
# Get original batch job
|
|
original_job = db.query(BatchJob).filter_by(
|
|
id=job_id,
|
|
user_id=current_user.id
|
|
).first()
|
|
|
|
if not original_job:
|
|
raise HTTPException(status_code=404, detail="Batch job not found")
|
|
|
|
# Get failed items
|
|
failed_items = db.query(BatchJobItem).filter_by(
|
|
batch_job_id=job_id,
|
|
status="failed"
|
|
).all()
|
|
|
|
if not failed_items:
|
|
return {"message": "No failed items to retry"}
|
|
|
|
# Create new batch job with failed URLs
|
|
failed_urls = [item.url for item in failed_items]
|
|
|
|
batch_service = BatchProcessingService(
|
|
db_session=db,
|
|
summary_pipeline=pipeline,
|
|
notification_service=notifications
|
|
)
|
|
|
|
new_job = await batch_service.create_batch_job(
|
|
user_id=current_user.id,
|
|
urls=failed_urls,
|
|
name=f"{original_job.name} (Retry)",
|
|
model=original_job.model,
|
|
summary_length=original_job.summary_length,
|
|
options=original_job.options
|
|
)
|
|
|
|
return {
|
|
"message": f"Created retry batch job with {len(failed_urls)} items",
|
|
"new_job_id": new_job.id,
|
|
"original_job_id": job_id
|
|
}
|
|
|
|
|
|
@router.get("/{job_id}/download")
|
|
async def download_batch_export(
|
|
job_id: str,
|
|
current_user: User = Depends(get_current_user),
|
|
db: Session = Depends(get_db)
|
|
):
|
|
"""
|
|
Download the export ZIP file for a completed batch job
|
|
|
|
Returns a ZIP file containing all summaries in JSON and Markdown formats.
|
|
"""
|
|
|
|
# Get batch job
|
|
batch_job = db.query(BatchJob).filter_by(
|
|
id=job_id,
|
|
user_id=current_user.id
|
|
).first()
|
|
|
|
if not batch_job:
|
|
raise HTTPException(status_code=404, detail="Batch job not found")
|
|
|
|
if batch_job.status != "completed":
|
|
raise HTTPException(status_code=400, detail="Batch job not completed yet")
|
|
|
|
# Check if export file exists
|
|
export_path = f"/tmp/batch_exports/{job_id}.zip"
|
|
|
|
if not os.path.exists(export_path):
|
|
# Try to regenerate export
|
|
batch_service = BatchProcessingService(db_session=db)
|
|
export_url = await batch_service._generate_export(job_id)
|
|
|
|
if not export_url or not os.path.exists(export_path):
|
|
raise HTTPException(status_code=404, detail="Export file not found")
|
|
|
|
return FileResponse(
|
|
export_path,
|
|
media_type="application/zip",
|
|
filename=f"{batch_job.name.replace(' ', '_')}_summaries.zip"
|
|
)
|
|
|
|
|
|
@router.delete("/{job_id}")
|
|
async def delete_batch_job(
|
|
job_id: str,
|
|
current_user: User = Depends(get_current_user),
|
|
db: Session = Depends(get_db)
|
|
):
|
|
"""
|
|
Delete a batch job and all associated data
|
|
|
|
This will also delete any summaries created by the batch job.
|
|
"""
|
|
|
|
# Get batch job
|
|
batch_job = db.query(BatchJob).filter_by(
|
|
id=job_id,
|
|
user_id=current_user.id
|
|
).first()
|
|
|
|
if not batch_job:
|
|
raise HTTPException(status_code=404, detail="Batch job not found")
|
|
|
|
# Don't allow deletion of running jobs
|
|
if batch_job.status == "processing":
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="Cannot delete a running batch job. Cancel it first."
|
|
)
|
|
|
|
# Delete associated summaries
|
|
items = db.query(BatchJobItem).filter_by(batch_job_id=job_id).all()
|
|
for item in items:
|
|
if item.summary_id:
|
|
from backend.models.summary import Summary
|
|
summary = db.query(Summary).filter_by(id=item.summary_id).first()
|
|
if summary:
|
|
db.delete(summary)
|
|
|
|
# Delete batch job (cascade will delete items)
|
|
db.delete(batch_job)
|
|
db.commit()
|
|
|
|
# Delete export file if exists
|
|
export_path = f"/tmp/batch_exports/{job_id}.zip"
|
|
if os.path.exists(export_path):
|
|
os.remove(export_path)
|
|
|
|
return {"message": "Batch job deleted successfully", "job_id": job_id} |