426 lines
17 KiB
Python
426 lines
17 KiB
Python
"""Enhanced WebSocket manager for real-time progress updates with connection recovery."""
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, List, Any, Optional, Set
|
|
from fastapi import WebSocket
|
|
from datetime import datetime
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ProcessingStage(Enum):
    """Named stages a video summarization job can report.

    Each member's value is the lowercase string placed on the wire
    (progress payloads are built from ``stage.value``).
    """

    # Stages reported while a job is being set up and processed.
    INITIALIZED = "initialized"
    VALIDATING_URL = "validating_url"
    EXTRACTING_METADATA = "extracting_metadata"
    EXTRACTING_TRANSCRIPT = "extracting_transcript"
    ANALYZING_CONTENT = "analyzing_content"
    GENERATING_SUMMARY = "generating_summary"
    VALIDATING_QUALITY = "validating_quality"

    # Terminal outcomes.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
|
|
|
|
@dataclass
class ProgressData:
    """Snapshot of a job's progress, as carried in progress updates.

    The first five fields are required; the remaining three are optional
    and default to ``None``.
    """

    # Identifier of the job this snapshot belongs to.
    job_id: str
    # Stage the job is currently in.
    stage: ProcessingStage
    # Completion figure for the job (presumably 0-100 -- confirm with callers).
    percentage: float
    # Human-readable status text.
    message: str
    # Time spent on the job so far (presumably seconds -- confirm with callers).
    time_elapsed: float

    # Estimated time left, when an estimate is available.
    estimated_remaining: Optional[float] = None
    # Optional finer-grained progress information for the current stage.
    sub_progress: Optional[Dict[str, Any]] = None
    # Optional free-form extra details.
    details: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections for real-time updates.

    Tracks every accepted socket, groups sockets by the job they listen
    to, queues progress updates for jobs that currently have no listener,
    and keeps a short history of job durations for time estimates.

    All diagnostics go through the module-level ``logger``.
    """

    # Maximum number of progress messages buffered per job while nobody
    # is listening; the oldest entries are discarded first.
    MAX_QUEUED_MESSAGES = 100
    # Maximum number of completed-job records retained for estimation.
    MAX_HISTORY_RECORDS = 100
    # Number of most recent history records averaged by the estimator.
    HISTORY_WINDOW = 10

    def __init__(self):
        # Active connections by job_id
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # All connected websockets for broadcast
        self.all_connections: Set[WebSocket] = set()
        # Connection metadata (connected_at, job_id, last_ping)
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        # Message queue for disconnected clients, keyed by job_id
        self.message_queue: Dict[str, List[Dict[str, Any]]] = {}
        # Latest known progress per job
        self.job_progress: Dict[str, ProgressData] = {}
        # Job start times for time estimation
        self.job_start_times: Dict[str, datetime] = {}
        # Historical processing records (job_id, total_time, completed_at)
        self.processing_history: List[Dict[str, Any]] = []

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _progress_message(job_id: str, progress: ProgressData) -> Dict[str, Any]:
        """Build the ``progress_update`` envelope for a progress snapshot."""
        return {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "stage": progress.stage.value,
                "percentage": progress.percentage,
                "message": progress.message,
                "time_elapsed": progress.time_elapsed,
                "estimated_remaining": progress.estimated_remaining,
                "sub_progress": progress.sub_progress,
                "details": progress.details,
            },
        }

    @staticmethod
    def _serialize(message: Dict[str, Any], context: str) -> Optional[str]:
        """Return ``message`` as JSON text, or ``None`` if it cannot be encoded.

        An unserializable payload is a server-side problem, so it is logged
        here instead of being mistaken for a broken connection.
        """
        try:
            return json.dumps(message)
        except (TypeError, ValueError) as e:
            logger.error("Cannot serialize message while %s: %s", context, e)
            return None

    async def _send_to_connections(self, connections, message: Dict[str, Any], context: str):
        """Send ``message`` to each socket, dropping sockets whose send fails.

        ``context`` completes the log line "Error <context>: ...".
        """
        payload = self._serialize(message, context)
        if payload is None:
            return

        # Iterate over a snapshot: disconnect() mutates the live collections.
        for websocket in list(connections):
            try:
                await websocket.send_text(payload)
            except Exception as e:
                logger.warning("Error %s: %s", context, e)
                # Remove broken connection
                self.disconnect(websocket)

    async def _replay_queued_messages(self, websocket: WebSocket, job_id: str):
        """Replay messages queued for ``job_id``; clear the queue on success."""
        queued = self.message_queue.get(job_id)
        if queued is None:
            return

        logger.info("Sending %d queued messages for job %s", len(queued), job_id)
        for message in list(queued):
            try:
                await websocket.send_text(json.dumps(message))
            except Exception as e:
                logger.error("Failed to send queued message: %s", e)
                # Keep the queue so a later reconnect can retry.
                return

        # pop() rather than del: the queue may have been cleared by
        # record_job_completion() while this coroutine was awaiting.
        self.message_queue.pop(job_id, None)

    @staticmethod
    def _reports_disconnected(websocket: WebSocket) -> bool:
        """Return True if the socket's own state says it is disconnected.

        Starlette sockets expose ``client_state`` and ``application_state``
        enums; they are compared by member name so no extra import is
        needed. Objects without these attributes are never reported stale.
        """
        for attribute in ("client_state", "application_state"):
            state = getattr(websocket, attribute, None)
            if getattr(state, "name", None) == "DISCONNECTED":
                return True
        return False

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Accept and manage a new WebSocket connection with recovery support.

        Args:
            websocket: The socket to accept and track.
            job_id: Optional job to subscribe the socket to. When given,
                any queued messages for the job are replayed and the
                latest known progress is sent.
        """
        await websocket.accept()

        self.all_connections.add(websocket)

        connected_at = datetime.utcnow()
        self.connection_metadata[websocket] = {
            "connected_at": connected_at,
            "job_id": job_id,
            "last_ping": connected_at,
        }

        if job_id:
            self.active_connections.setdefault(job_id, []).append(websocket)

            # Send queued messages if reconnecting
            await self._replay_queued_messages(websocket, job_id)

            # Send current progress if available
            if job_id in self.job_progress:
                await self.send_current_progress(websocket, job_id)

        logger.info(
            "WebSocket connected. Job ID: %s, Total connections: %d",
            job_id,
            len(self.all_connections),
        )

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection from all tracking structures.

        Safe to call for a socket that is not (or no longer) tracked.
        """
        self.all_connections.discard(websocket)

        # The job id lives in the metadata, so read it while removing it.
        metadata = self.connection_metadata.pop(websocket, None) or {}
        job_id = metadata.get("job_id")

        listeners = self.active_connections.get(job_id) if job_id else None
        if listeners is not None:
            if websocket in listeners:
                listeners.remove(websocket)
            # Clean up empty job connection lists
            if not listeners:
                del self.active_connections[job_id]

        logger.info(
            "WebSocket disconnected. Job ID: %s, Remaining connections: %d",
            job_id,
            len(self.all_connections),
        )

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------

    async def send_personal_message(self, message: Dict[str, Any], websocket: WebSocket):
        """Send a message to a specific WebSocket connection.

        If the send fails the socket is assumed closed and is removed.
        """
        payload = self._serialize(message, "sending personal message")
        if payload is None:
            return

        try:
            await websocket.send_text(payload)
        except Exception as e:
            logger.warning("Error sending personal message: %s", e)
            # Connection might be closed, remove it
            self.disconnect(websocket)

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send progress update to all connections listening to a specific job."""
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        message = {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": progress_data,
        }
        await self._send_to_connections(
            listeners, message, f"sending progress update to {job_id}"
        )

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send completion notification to all connections listening to a job."""
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        message = {
            "type": "completion_notification",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": result_data,
        }
        await self._send_to_connections(
            listeners, message, f"sending completion notification to {job_id}"
        )

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send error notification to all connections listening to a job."""
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        message = {
            "type": "error_notification",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": error_data,
        }
        await self._send_to_connections(
            listeners, message, f"sending error notification to {job_id}"
        )

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast a system message to all connected clients."""
        message = {
            "type": "system_message",
            "timestamp": datetime.utcnow().isoformat(),
            "data": message_data,
        }
        await self._send_to_connections(
            self.all_connections, message, "broadcasting system message"
        )

    async def send_heartbeat(self):
        """Send heartbeat to all connections to keep them alive."""
        message = {
            "type": "heartbeat",
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._send_to_connections(
            self.all_connections, message, "sending heartbeat"
        )

    async def send_current_progress(self, websocket: WebSocket, job_id: str):
        """Send current progress state to a reconnecting client.

        Does nothing when no progress is recorded for ``job_id``.
        """
        progress = self.job_progress.get(job_id)
        if progress is None:
            return
        await self.send_personal_message(
            self._progress_message(job_id, progress), websocket
        )

    # ------------------------------------------------------------------
    # Introspection and maintenance
    # ------------------------------------------------------------------

    def get_connection_stats(self) -> Dict[str, Any]:
        """Get connection statistics.

        Returns:
            A dict with ``total_connections`` (int), ``job_connections``
            (listener count per job id) and ``active_jobs`` (list of job
            ids that currently have listeners).
        """
        return {
            "total_connections": len(self.all_connections),
            "job_connections": {
                job_id: len(listeners)
                for job_id, listeners in self.active_connections.items()
            },
            "active_jobs": list(self.active_connections),
        }

    async def cleanup_stale_connections(self):
        """Drop connections that are known to be dead.

        A socket is removed when its own connection state reports
        ``DISCONNECTED``. Starlette's ``WebSocket`` (the class FastAPI
        exposes) has no ``ping()`` method, so calling it unconditionally
        raised ``AttributeError`` and evicted every healthy connection;
        ``ping()`` is therefore only awaited on objects that provide it.
        """
        for websocket in list(self.all_connections):
            if self._reports_disconnected(websocket):
                self.disconnect(websocket)
                continue

            ping = getattr(websocket, "ping", None)
            if not callable(ping):
                continue
            try:
                await ping()
            except Exception:
                # Connection is stale, remove it
                self.disconnect(websocket)

    # ------------------------------------------------------------------
    # Progress tracking and time estimation
    # ------------------------------------------------------------------

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Update job progress tracking.

        Stores the latest snapshot, records the job's start time on first
        sight, and, when the job has no listeners, queues the update so a
        reconnecting client can catch up.
        """
        self.job_progress[job_id] = progress_data

        # Track start time if not already tracked
        if job_id not in self.job_start_times:
            self.job_start_times[job_id] = datetime.utcnow()

        # Only queue while nobody is listening to this job.
        if self.active_connections.get(job_id):
            return

        backlog = self.message_queue.setdefault(job_id, [])
        backlog.append(self._progress_message(job_id, progress_data))

        # Bound memory use, but keep the most recent updates: a client
        # that reconnects cares about the latest state, not the oldest.
        overflow = len(backlog) - self.MAX_QUEUED_MESSAGES
        if overflow > 0:
            del backlog[:overflow]

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate remaining processing time in seconds.

        Args:
            job_id: Job whose recorded start time is used.
            current_percentage: Progress so far; values >= 100 mean done.

        Returns:
            ``None`` when the job has no recorded start time or the
            percentage is not positive, ``0`` when the job is complete,
            otherwise a non-negative estimate in seconds.
        """
        if job_id not in self.job_start_times or current_percentage <= 0:
            return None

        if current_percentage >= 100:
            return 0.0

        elapsed = (datetime.utcnow() - self.job_start_times[job_id]).total_seconds()

        # Linear extrapolation from the progress rate observed so far.
        seconds_per_percent = elapsed / current_percentage
        estimated_remaining = seconds_per_percent * (100 - current_percentage)

        # Adjust based on historical data if available
        recent = self.processing_history[-self.HISTORY_WINDOW:]
        if recent:
            avg_total_time = sum(r.get("total_time", 0) for r in recent) / len(recent)
            historical_remaining = avg_total_time - elapsed
            # Blend only while history still predicts time left; once the
            # job has outlived the average, history is uninformative.
            if avg_total_time > 0 and historical_remaining > 0:
                estimated_remaining = (
                    estimated_remaining * 0.7 + historical_remaining * 0.3
                )

        return max(0.0, estimated_remaining)

    def record_job_completion(self, job_id: str):
        """Record job completion time for future estimations.

        Appends a history record when the job's start time is known, then
        discards all per-job tracking state for ``job_id``.
        """
        started_at = self.job_start_times.pop(job_id, None)
        if started_at is not None:
            finished_at = datetime.utcnow()
            self.processing_history.append({
                "job_id": job_id,
                "total_time": (finished_at - started_at).total_seconds(),
                "completed_at": finished_at.isoformat(),
            })

            # Keep only the most recent records.
            del self.processing_history[:-self.MAX_HISTORY_RECORDS]

        # Clean up tracking
        self.job_progress.pop(job_id, None)
        self.message_queue.pop(job_id, None)
|
|
|
|
|
|
class WebSocketManager:
    """Main WebSocket manager with singleton pattern.

    Every ``WebSocketManager()`` call returns the same instance, which
    owns one ``ConnectionManager`` and a background heartbeat task. The
    public methods delegate to the connection manager.
    """

    # Seconds between heartbeat / stale-connection sweeps.
    HEARTBEAT_INTERVAL_SECONDS = 30

    _instance = None

    def __new__(cls):
        # Create and initialise the shared instance on first use only.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.connection_manager = ConnectionManager()
            instance._heartbeat_task = None
            cls._instance = instance
        return cls._instance

    def __init__(self):
        # __init__ runs on every WebSocketManager() call; only fill in
        # state that __new__ has not already set up.
        if not hasattr(self, 'connection_manager'):
            self.connection_manager = ConnectionManager()
            self._heartbeat_task = None

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Connect a WebSocket for job updates and ensure heartbeats run."""
        await self.connection_manager.connect(websocket, job_id)

        # Start heartbeat task if not running
        if self._heartbeat_task is None or self._heartbeat_task.done():
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    def disconnect(self, websocket: WebSocket):
        """Disconnect a WebSocket."""
        self.connection_manager.disconnect(websocket)

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send progress update for a job."""
        await self.connection_manager.send_progress_update(job_id, progress_data)

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send completion notification for a job."""
        await self.connection_manager.send_completion_notification(job_id, result_data)

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send error notification for a job."""
        await self.connection_manager.send_error_notification(job_id, error_data)

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast system message to all connections."""
        await self.connection_manager.broadcast_system_message(message_data)

    def get_stats(self) -> Dict[str, Any]:
        """Get WebSocket connection statistics."""
        return self.connection_manager.get_connection_stats()

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Update and track job progress."""
        self.connection_manager.update_job_progress(job_id, progress_data)

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate remaining processing time."""
        return self.connection_manager.estimate_remaining_time(job_id, current_percentage)

    def record_job_completion(self, job_id: str):
        """Record job completion for time estimation."""
        self.connection_manager.record_job_completion(job_id)

    async def _heartbeat_loop(self):
        """Background task to send periodic heartbeats.

        Sleeps for ``HEARTBEAT_INTERVAL_SECONDS``, sends a heartbeat to
        every connection, then sweeps stale connections. Runs until the
        task is cancelled; any other error is logged and the loop goes on.
        """
        while True:
            try:
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)
                await self.connection_manager.send_heartbeat()
                await self.connection_manager.cleanup_stale_connections()
            except asyncio.CancelledError:
                # Cancellation is the normal way to stop the loop.
                break
            except Exception as e:
                # Log with traceback; one bad cycle must not end heartbeats.
                logger.exception("Error in heartbeat loop: %s", e)
|
|
|
|
|
|
# Global WebSocket manager instance, created once when this module is imported.
# WebSocketManager.__new__ caches its instance, so any later WebSocketManager()
# call returns this same object.
websocket_manager = WebSocketManager()