"""Enhanced WebSocket manager for real-time progress updates with connection recovery.

The module exposes two layers:

* ``ConnectionManager`` keeps the bookkeeping (which socket listens to which
  job or chat session, queued messages for clients that are currently
  offline, progress snapshots, timing history) and performs the sends.
* ``WebSocketManager`` is a process-wide singleton facade that delegates to a
  ``ConnectionManager`` and owns the background heartbeat task.

All timestamps are naive UTC values produced by ``datetime.utcnow()``.
"""

import asyncio
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set

try:
    from fastapi import WebSocket
except ImportError:  # pragma: no cover - exercised only where fastapi is absent
    # The managers only duck-type the socket (``accept`` / ``send_text``), and
    # ``WebSocket`` is used purely in annotations, so the module stays usable
    # (e.g. in unit tests with fake sockets) without fastapi installed.
    WebSocket = Any  # type: ignore[assignment,misc]

logger = logging.getLogger(__name__)


class ProcessingStage(Enum):
    """Processing stages for video summarization."""

    INITIALIZED = "initialized"
    VALIDATING_URL = "validating_url"
    EXTRACTING_METADATA = "extracting_metadata"
    EXTRACTING_TRANSCRIPT = "extracting_transcript"
    ANALYZING_CONTENT = "analyzing_content"
    GENERATING_SUMMARY = "generating_summary"
    VALIDATING_QUALITY = "validating_quality"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class ProgressData:
    """Progress data structure for processing updates."""

    job_id: str
    stage: ProcessingStage
    percentage: float
    message: str
    time_elapsed: float
    estimated_remaining: Optional[float] = None
    sub_progress: Optional[Dict[str, Any]] = None
    details: Optional[Dict[str, Any]] = None
    # Enhanced context for user-friendly display.
    # Contains video_id, title, display_name.
    video_context: Optional[Dict[str, Any]] = None


class ConnectionManager:
    """Manages WebSocket connections for real-time updates.

    Sockets are tracked in three overlapping indexes: per job
    (``active_connections``), per chat session (``chat_connections``) and
    globally (``all_connections``). A socket whose ``send_text`` raises is
    treated as dead and removed from every index via :meth:`disconnect`.
    """

    # Upper bound on messages buffered per job while nobody is listening.
    MAX_QUEUED_MESSAGES = 100
    # Upper bound on completed-job timing records kept for estimation.
    MAX_HISTORY_RECORDS = 100
    # Number of most recent timing records averaged by the estimator.
    HISTORY_WINDOW = 10

    def __init__(self):
        # Active connections by job_id
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # Chat connections by session_id (for Story 4.6 RAG Chat)
        self.chat_connections: Dict[str, List[WebSocket]] = {}
        # All connected websockets for broadcast
        self.all_connections: Set[WebSocket] = set()
        # Connection metadata
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        # Message queue for disconnected clients
        self.message_queue: Dict[str, List[Dict[str, Any]]] = {}
        # Job progress tracking
        self.job_progress: Dict[str, ProgressData] = {}
        # Job start times for time estimation
        self.job_start_times: Dict[str, datetime] = {}
        # Historical processing times for estimation
        self.processing_history: List[Dict[str, float]] = []
        # Chat typing indicators: session_id -> set of user_ids typing
        self.chat_typing: Dict[str, Set[str]] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize(message: Dict[str, Any], error_prefix: str) -> Optional[str]:
        """Return ``message`` as JSON text, or ``None`` if it cannot be encoded.

        Encoding problems are a defect of the payload, not of any socket, so
        they are logged here and must never lead to a disconnect.
        """
        try:
            return json.dumps(message)
        except (TypeError, ValueError) as exc:
            logger.error("%s: message is not JSON serializable: %s", error_prefix, exc)
            return None

    @staticmethod
    def _progress_message(job_id: str, progress: ProgressData) -> Dict[str, Any]:
        """Build the ``progress_update`` wire message for a progress snapshot.

        The top-level ``video_*`` / ``display_name`` keys mirror what
        :meth:`send_progress_update` emits for live updates.
        """
        video_context = progress.video_context or {}
        return {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "stage": progress.stage.value,
                "percentage": progress.percentage,
                "message": progress.message,
                "time_elapsed": progress.time_elapsed,
                "estimated_remaining": progress.estimated_remaining,
                "sub_progress": progress.sub_progress,
                "details": progress.details,
                "video_context": progress.video_context,
            },
            "video_title": video_context.get("title"),
            "video_id": video_context.get("video_id"),
            "display_name": video_context.get("display_name"),
        }

    async def _send_to_connections(
        self,
        connections: Iterable[WebSocket],
        message: Dict[str, Any],
        error_prefix: str,
        should_send: Optional[Callable[[WebSocket], bool]] = None,
    ) -> None:
        """Send ``message`` to every socket in ``connections``.

        Args:
            connections: Target sockets. A snapshot is taken first because
                failed sends call :meth:`disconnect`, which mutates the
                underlying collections.
            message: JSON-serializable payload; encoded once for all sockets.
            error_prefix: Text that starts the log line when a send fails.
            should_send: Optional per-socket filter, evaluated right before
                each send; sockets for which it returns false are skipped.
        """
        payload = self._serialize(message, error_prefix)
        if payload is None:
            return
        for websocket in list(connections):
            if should_send is not None and not should_send(websocket):
                continue
            try:
                await websocket.send_text(payload)
            except Exception as exc:
                logger.error("%s: %s", error_prefix, exc)
                # A socket we cannot write to is considered broken.
                self.disconnect(websocket)

    async def _replay_queued_messages(self, websocket: WebSocket, job_id: str) -> None:
        """Deliver messages queued for ``job_id`` while it had no listeners.

        Delivered messages are removed from the queue even when a later send
        fails, so a subsequent reconnect does not receive duplicates. Messages
        that cannot be JSON-encoded are dropped (they could never be sent).
        """
        queued = self.message_queue.get(job_id)
        if queued is None:
            return
        logger.info("Sending %d queued messages for job %s", len(queued), job_id)
        handled = 0
        for message in queued:
            payload = self._serialize(message, "Dropping queued message")
            if payload is not None:
                try:
                    await websocket.send_text(payload)
                except Exception as exc:
                    logger.error("Failed to send queued message: %s", exc)
                    break
            handled += 1
        if job_id not in self.message_queue:
            # The job was completed (and its queue discarded) while we awaited.
            return
        if handled >= len(queued):
            del self.message_queue[job_id]
        else:
            self.message_queue[job_id] = queued[handled:]

    def _clear_typing_state(self, session_id: str, user_id: Optional[str]) -> None:
        """Drop typing state that a disconnect has made obsolete."""
        typing_users = self.chat_typing.get(session_id)
        if typing_users is None:
            return
        remaining = self.chat_connections.get(session_id)
        if not remaining:
            # Nobody is left in the session, so nobody can be typing in it.
            del self.chat_typing[session_id]
            return
        if user_id is None:
            return
        # Keep the indicator if the same user still has another socket open.
        user_still_connected = any(
            self.connection_metadata.get(peer, {}).get("user_id") == user_id
            for peer in remaining
        )
        if not user_still_connected:
            typing_users.discard(user_id)
            if not typing_users:
                del self.chat_typing[session_id]

    @staticmethod
    async def _is_stale(websocket: WebSocket) -> bool:
        """Report whether ``websocket`` looks dead.

        Sockets that offer an awaitable ``ping()`` are pinged. Starlette /
        FastAPI sockets have no ``ping()``; for those the connection state
        attributes are inspected instead. A socket offering neither is
        assumed alive (a failed heartbeat send will remove it).
        """
        ping = getattr(websocket, "ping", None)
        if callable(ping):
            try:
                await ping()
            except Exception:
                return True
            return False
        for attribute in ("client_state", "application_state"):
            state = getattr(websocket, attribute, None)
            if getattr(state, "name", None) == "DISCONNECTED":
                return True
        return False

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Accept and manage a new WebSocket connection with recovery support.

        When ``job_id`` is given, the socket is subscribed to that job, any
        messages queued while the job had no listeners are replayed, and the
        latest progress snapshot (if one exists) is sent.
        """
        await websocket.accept()

        now = datetime.utcnow()
        self.all_connections.add(websocket)
        self.connection_metadata[websocket] = {
            "connected_at": now,
            "job_id": job_id,
            "last_ping": now,
        }

        if job_id:
            self.active_connections.setdefault(job_id, []).append(websocket)
            await self._replay_queued_messages(websocket, job_id)
            if job_id in self.job_progress:
                await self.send_current_progress(websocket, job_id)

        logger.info(
            f"WebSocket connected. Job ID: {job_id}, Total connections: {len(self.all_connections)}"
        )

    async def connect_chat(self, websocket: WebSocket, session_id: str, user_id: Optional[str] = None):
        """Connect a WebSocket for chat functionality (Story 4.6)."""
        await websocket.accept()

        now = datetime.utcnow()
        self.all_connections.add(websocket)
        self.connection_metadata[websocket] = {
            "connected_at": now,
            "session_id": session_id,
            "user_id": user_id,
            "connection_type": "chat",
            "last_ping": now,
        }
        self.chat_connections.setdefault(session_id, []).append(websocket)

        logger.info(
            f"Chat WebSocket connected. Session ID: {session_id}, User ID: {user_id}, "
            f"Total connections: {len(self.all_connections)}"
        )

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection from every index.

        Safe to call more than once for the same socket. Also clears typing
        indicators that can no longer be valid after this socket is gone.
        """
        self.all_connections.discard(websocket)

        # Read the metadata before dropping it: it tells us which indexes
        # the socket was registered in.
        metadata = self.connection_metadata.pop(websocket, None) or {}
        job_id = metadata.get("job_id")
        session_id = metadata.get("session_id")

        if job_id and job_id in self.active_connections:
            listeners = self.active_connections[job_id]
            if websocket in listeners:
                listeners.remove(websocket)
            if not listeners:
                del self.active_connections[job_id]

        if session_id:
            if session_id in self.chat_connections:
                participants = self.chat_connections[session_id]
                if websocket in participants:
                    participants.remove(websocket)
                if not participants:
                    del self.chat_connections[session_id]
            self._clear_typing_state(session_id, metadata.get("user_id"))

        logger.info(
            f"WebSocket disconnected. Job ID: {job_id}, Remaining connections: {len(self.all_connections)}"
        )

    # ------------------------------------------------------------------
    # Job notifications
    # ------------------------------------------------------------------

    async def send_personal_message(self, message: Dict[str, Any], websocket: WebSocket):
        """Send a message to a specific WebSocket connection.

        The socket is disconnected if the send fails; a payload that cannot
        be JSON-encoded is logged and skipped without touching the socket.
        """
        await self._send_to_connections(
            [websocket], message, "Error sending personal message"
        )

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send progress update to all connections listening to a specific job."""
        if job_id not in self.active_connections:
            return

        # Surface the video context at the top level for easy client access.
        video_context = progress_data.get("video_context") or {}
        message = {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": progress_data,
            "video_title": video_context.get("title"),
            "video_id": video_context.get("video_id"),
            "display_name": video_context.get("display_name"),
        }
        await self._send_to_connections(
            self.active_connections[job_id],
            message,
            f"Error sending progress update to {job_id}",
        )

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send completion notification for a job."""
        if job_id not in self.active_connections:
            return

        video_metadata = result_data.get("video_metadata") or {}
        message = {
            "type": "completion_notification",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": result_data,
            "video_title": video_metadata.get("title"),
            "video_id": result_data.get("video_id"),
            "display_name": result_data.get("display_name"),
        }
        await self._send_to_connections(
            self.active_connections[job_id],
            message,
            f"Error sending completion notification to {job_id}",
        )

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send error notification for a job."""
        if job_id not in self.active_connections:
            return

        video_context = error_data.get("video_context") or {}
        message = {
            "type": "error_notification",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": error_data,
            "video_title": video_context.get("title"),
            "video_id": video_context.get("video_id"),
            "display_name": video_context.get("display_name"),
        }
        await self._send_to_connections(
            self.active_connections[job_id],
            message,
            f"Error sending error notification to {job_id}",
        )

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast a system message to all connected clients."""
        message = {
            "type": "system_message",
            "timestamp": datetime.utcnow().isoformat(),
            "data": message_data,
        }
        await self._send_to_connections(
            self.all_connections, message, "Error broadcasting system message"
        )

    # ------------------------------------------------------------------
    # Chat (Story 4.6)
    # ------------------------------------------------------------------

    async def send_chat_message(self, session_id: str, message_data: Dict[str, Any]):
        """Send a chat message to all connections in a chat session (Story 4.6)."""
        if session_id not in self.chat_connections:
            return

        message = {
            "type": "message",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": message_data,
        }
        await self._send_to_connections(
            self.chat_connections[session_id],
            message,
            f"Error sending chat message to {session_id}",
        )

    async def send_typing_indicator(self, session_id: str, user_id: str, is_typing: bool):
        """Send typing indicator to chat session (Story 4.6).

        The indicator is sent to every socket in the session except those
        belonging to ``user_id`` itself.
        """
        if session_id not in self.chat_connections:
            return

        if is_typing:
            self.chat_typing.setdefault(session_id, set()).add(user_id)
        else:
            typing_users = self.chat_typing.get(session_id)
            if typing_users is not None:
                typing_users.discard(user_id)
                if not typing_users:
                    # Do not keep empty sets around for idle sessions.
                    del self.chat_typing[session_id]

        message = {
            "type": "typing_start" if is_typing else "typing_end",
            "session_id": session_id,
            "user_id": user_id,
            "timestamp": datetime.utcnow().isoformat(),
        }

        def is_other_user(websocket: WebSocket) -> bool:
            # Don't echo the indicator back to the person typing.
            return self.connection_metadata.get(websocket, {}).get("user_id") != user_id

        await self._send_to_connections(
            self.chat_connections[session_id],
            message,
            f"Error sending typing indicator to {session_id}",
            should_send=is_other_user,
        )

    async def send_chat_status(self, session_id: str, status_data: Dict[str, Any]):
        """Send chat status update to session connections (Story 4.6)."""
        if session_id not in self.chat_connections:
            return

        message = {
            "type": "connection_status",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": status_data,
        }
        await self._send_to_connections(
            self.chat_connections[session_id],
            message,
            f"Error sending chat status to {session_id}",
        )

    # ------------------------------------------------------------------
    # Transcript streaming (Task 14.3)
    # ------------------------------------------------------------------

    async def send_transcript_chunk(self, job_id: str, chunk_data: Dict[str, Any]):
        """Send live transcript chunk to job connections (Task 14.3).

        Only sockets that opted in via :meth:`enable_transcript_streaming`
        receive the chunk.
        """
        if job_id not in self.active_connections:
            return

        message = {
            "type": "transcript_chunk",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": chunk_data,
        }

        def wants_transcript(websocket: WebSocket) -> bool:
            metadata = self.connection_metadata.get(websocket, {})
            return bool(metadata.get("transcript_streaming", False))

        await self._send_to_connections(
            self.active_connections[job_id],
            message,
            f"Error sending transcript chunk to {job_id}",
            should_send=wants_transcript,
        )

    async def send_transcript_complete(self, job_id: str, transcript_data: Dict[str, Any]):
        """Send complete transcript data to job connections."""
        if job_id not in self.active_connections:
            return

        message = {
            "type": "transcript_complete",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": transcript_data,
        }
        await self._send_to_connections(
            self.active_connections[job_id],
            message,
            f"Error sending complete transcript to {job_id}",
        )

    def enable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Enable transcript streaming for a specific connection."""
        if websocket in self.connection_metadata:
            self.connection_metadata[websocket]["transcript_streaming"] = True
            logger.info(f"Enabled transcript streaming for job {job_id}")

    def disable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Disable transcript streaming for a specific connection."""
        if websocket in self.connection_metadata:
            self.connection_metadata[websocket]["transcript_streaming"] = False
            logger.info(f"Disabled transcript streaming for job {job_id}")

    # ------------------------------------------------------------------
    # Health and statistics
    # ------------------------------------------------------------------

    async def send_heartbeat(self):
        """Send heartbeat to all connections to keep them alive."""
        message = {
            "type": "heartbeat",
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self._send_to_connections(
            self.all_connections, message, "Error sending heartbeat"
        )

    def get_connection_stats(self) -> Dict[str, Any]:
        """Get connection statistics."""
        return {
            "total_connections": len(self.all_connections),
            "job_connections": {
                job_id: len(sockets)
                for job_id, sockets in self.active_connections.items()
            },
            "chat_connections": {
                session_id: len(sockets)
                for session_id, sockets in self.chat_connections.items()
            },
            "active_jobs": list(self.active_connections),
            "active_chat_sessions": list(self.chat_connections),
            "typing_sessions": {
                session_id: list(typing_users)
                for session_id, typing_users in self.chat_typing.items()
                if typing_users
            },
        }

    async def cleanup_stale_connections(self):
        """Remove connections that are no longer alive.

        Sockets without a ``ping()`` method (such as Starlette / FastAPI
        sockets) are judged by their connection state rather than being
        treated as failed pings.
        """
        for websocket in list(self.all_connections):
            if await self._is_stale(websocket):
                self.disconnect(websocket)

    # ------------------------------------------------------------------
    # Progress tracking and time estimation
    # ------------------------------------------------------------------

    async def send_current_progress(self, websocket: WebSocket, job_id: str):
        """Send current progress state to a reconnecting client."""
        progress = self.job_progress.get(job_id)
        if progress is None:
            return
        await self.send_personal_message(
            self._progress_message(job_id, progress), websocket
        )

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Update job progress tracking.

        The snapshot always replaces the stored one. When the job currently
        has no listeners, the update is additionally queued (up to
        ``MAX_QUEUED_MESSAGES`` per job) for replay on reconnect.
        """
        self.job_progress[job_id] = progress_data

        # The first update seen for a job marks its start for estimation.
        if job_id not in self.job_start_times:
            self.job_start_times[job_id] = datetime.utcnow()

        if self.active_connections.get(job_id):
            return

        backlog = self.message_queue.setdefault(job_id, [])
        # Limit queue size to prevent memory issues.
        if len(backlog) < self.MAX_QUEUED_MESSAGES:
            backlog.append(self._progress_message(job_id, progress_data))

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate remaining processing time based on history.

        Returns:
            ``None`` when the job is unknown or has made no progress, ``0``
            at or beyond 100 percent, otherwise a non-negative number of
            seconds. The linear extrapolation of the current rate is blended
            70/30 with the average duration of recently completed jobs when
            that average still lies ahead of the elapsed time.
        """
        if job_id not in self.job_start_times or current_percentage <= 0:
            return None
        if current_percentage >= 100:
            return 0

        elapsed = (datetime.utcnow() - self.job_start_times[job_id]).total_seconds()

        # Linear extrapolation: seconds per percent times percent left.
        seconds_per_percent = elapsed / current_percentage
        estimated_remaining = seconds_per_percent * (100 - current_percentage)

        recent = self.processing_history[-self.HISTORY_WINDOW:]
        if recent:
            avg_total_time = sum(r.get("total_time", 0) for r in recent) / len(recent)
            historical_remaining = avg_total_time - elapsed
            if avg_total_time > 0 and historical_remaining > 0:
                estimated_remaining = estimated_remaining * 0.7 + historical_remaining * 0.3

        return max(0, estimated_remaining)

    def record_job_completion(self, job_id: str):
        """Record job completion time for future estimations.

        Also discards the job's progress snapshot and queued messages.
        """
        started_at = self.job_start_times.pop(job_id, None)
        if started_at is not None:
            finished_at = datetime.utcnow()
            self.processing_history.append({
                "job_id": job_id,
                "total_time": (finished_at - started_at).total_seconds(),
                "completed_at": finished_at.isoformat(),
            })
            # Keep only the most recent records.
            if len(self.processing_history) > self.MAX_HISTORY_RECORDS:
                self.processing_history = self.processing_history[-self.MAX_HISTORY_RECORDS:]

        self.job_progress.pop(job_id, None)
        self.message_queue.pop(job_id, None)


class WebSocketManager:
    """Main WebSocket manager with singleton pattern.

    Every instantiation returns the same object. The first job or chat
    connection starts a background heartbeat task, which is restarted on a
    later connect if it has finished.
    """

    # Seconds between heartbeat rounds.
    HEARTBEAT_INTERVAL_SECONDS = 30

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.connection_manager = ConnectionManager()
            cls._instance._heartbeat_task = None
        return cls._instance

    def __init__(self):
        # __new__ normally initializes the state; this only covers instances
        # that were created without going through it.
        if not hasattr(self, "connection_manager"):
            self.connection_manager = ConnectionManager()
            self._heartbeat_task = None

    def _ensure_heartbeat(self) -> None:
        """Start the heartbeat task unless one is already running."""
        if self._heartbeat_task is None or self._heartbeat_task.done():
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Connect a WebSocket for job updates."""
        await self.connection_manager.connect(websocket, job_id)
        self._ensure_heartbeat()

    async def connect_chat(self, websocket: WebSocket, session_id: str, user_id: Optional[str] = None):
        """Connect a WebSocket for chat functionality (Story 4.6)."""
        await self.connection_manager.connect_chat(websocket, session_id, user_id)
        self._ensure_heartbeat()

    def disconnect(self, websocket: WebSocket):
        """Disconnect a WebSocket."""
        self.connection_manager.disconnect(websocket)

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send progress update for a job."""
        await self.connection_manager.send_progress_update(job_id, progress_data)

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send completion notification for a job."""
        await self.connection_manager.send_completion_notification(job_id, result_data)

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send error notification for a job."""
        await self.connection_manager.send_error_notification(job_id, error_data)

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast system message to all connections."""
        await self.connection_manager.broadcast_system_message(message_data)

    async def send_chat_message(self, session_id: str, message_data: Dict[str, Any]):
        """Send chat message to all connections in a session (Story 4.6)."""
        await self.connection_manager.send_chat_message(session_id, message_data)

    async def send_typing_indicator(self, session_id: str, user_id: str, is_typing: bool):
        """Send typing indicator to chat session (Story 4.6)."""
        await self.connection_manager.send_typing_indicator(session_id, user_id, is_typing)

    async def send_chat_status(self, session_id: str, status_data: Dict[str, Any]):
        """Send status update to chat session (Story 4.6)."""
        await self.connection_manager.send_chat_status(session_id, status_data)

    async def send_transcript_chunk(self, job_id: str, chunk_data: Dict[str, Any]):
        """Send live transcript chunk to job connections (Task 14.3)."""
        await self.connection_manager.send_transcript_chunk(job_id, chunk_data)

    async def send_transcript_complete(self, job_id: str, transcript_data: Dict[str, Any]):
        """Send complete transcript data to job connections."""
        await self.connection_manager.send_transcript_complete(job_id, transcript_data)

    def enable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Enable transcript streaming for a connection."""
        self.connection_manager.enable_transcript_streaming(websocket, job_id)

    def disable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Disable transcript streaming for a connection."""
        self.connection_manager.disable_transcript_streaming(websocket, job_id)

    def get_stats(self) -> Dict[str, Any]:
        """Get WebSocket connection statistics."""
        return self.connection_manager.get_connection_stats()

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Update and track job progress."""
        self.connection_manager.update_job_progress(job_id, progress_data)

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate remaining processing time."""
        return self.connection_manager.estimate_remaining_time(job_id, current_percentage)

    def record_job_completion(self, job_id: str):
        """Record job completion for time estimation."""
        self.connection_manager.record_job_completion(job_id)

    async def _heartbeat_loop(self):
        """Background task to send periodic heartbeats.

        Each round sends a heartbeat to every socket and then prunes stale
        ones. Errors in a round are logged and the loop continues;
        cancellation ends the loop.
        """
        while True:
            try:
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)
                await self.connection_manager.send_heartbeat()
                await self.connection_manager.cleanup_stale_connections()
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.error("Error in heartbeat loop: %s", exc)


# Global WebSocket manager instance
websocket_manager = WebSocketManager()