youtube-summarizer/backend/core/websocket_manager.py

667 lines
28 KiB
Python

"""Enhanced WebSocket manager for real-time progress updates with connection recovery."""
import json
import asyncio
import logging
from typing import Dict, List, Any, Optional, Set
from fastapi import WebSocket
from datetime import datetime
from dataclasses import dataclass
from enum import Enum
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ProcessingStage(Enum):
    """Named stages of a video summarization job.

    Each member's value is the plain string that is placed in the ``stage``
    field of the progress messages sent to clients.
    """

    INITIALIZED = "initialized"
    VALIDATING_URL = "validating_url"
    EXTRACTING_METADATA = "extracting_metadata"
    EXTRACTING_TRANSCRIPT = "extracting_transcript"
    ANALYZING_CONTENT = "analyzing_content"
    GENERATING_SUMMARY = "generating_summary"
    VALIDATING_QUALITY = "validating_quality"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
@dataclass
class ProgressData:
    """Snapshot of one job's progress, as tracked by the connection manager."""

    # --- Required fields -------------------------------------------------
    job_id: str
    stage: ProcessingStage
    percentage: float
    message: str
    time_elapsed: float

    # --- Optional fields (all default to None) ---------------------------
    estimated_remaining: Optional[float] = None
    sub_progress: Optional[Dict[str, Any]] = None
    details: Optional[Dict[str, Any]] = None
    # Extra context for user-friendly display: video_id, title, display_name.
    video_context: Optional[Dict[str, Any]] = None
class ConnectionManager:
    """Track WebSocket connections and fan messages out to them.

    Connections are indexed per job (``active_connections``), per chat session
    (``chat_connections``) and globally (``all_connections``).  Progress
    recorded while a job has no listeners is buffered in ``message_queue`` and
    replayed to the next client that connects for that job.  A connection
    whose send raises is removed from every index.
    """

    # Upper bound on buffered progress messages per job while nobody listens.
    MAX_QUEUED_MESSAGES = 100
    # Number of completed-job records retained for time estimation.
    MAX_HISTORY_RECORDS = 100
    # Number of most recent records averaged when estimating remaining time.
    HISTORY_SAMPLE_SIZE = 10

    def __init__(self):
        """Create a manager with no connections and no tracked jobs."""
        # Active connections by job_id
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # Chat connections by session_id (for Story 4.6 RAG Chat)
        self.chat_connections: Dict[str, List[WebSocket]] = {}
        # All connected websockets for broadcast
        self.all_connections: Set[WebSocket] = set()
        # Connection metadata
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        # Message queue for disconnected clients
        self.message_queue: Dict[str, List[Dict[str, Any]]] = {}
        # Job progress tracking
        self.job_progress: Dict[str, ProgressData] = {}
        # Job start times for time estimation
        self.job_start_times: Dict[str, datetime] = {}
        # Historical processing times for estimation (job_id, total_time, completed_at)
        self.processing_history: List[Dict[str, Any]] = []
        # Chat typing indicators: session_id -> set of user_ids typing
        self.chat_typing: Dict[str, Set[str]] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize(message: Dict[str, Any], error_context: str) -> Optional[str]:
        """Return ``message`` as JSON text, or ``None`` (after logging) if it cannot be encoded."""
        try:
            return json.dumps(message)
        except (TypeError, ValueError) as e:
            logger.error(f"{error_context}: message is not JSON serializable: {e}")
            return None

    async def _send_to_connections(self, connections, message: Dict[str, Any],
                                   error_context: str, should_send=None):
        """Send ``message`` to every websocket in ``connections``.

        Args:
            connections: Iterable of websockets; it is copied first because
                ``disconnect`` mutates the underlying collections.
            message: JSON-serializable payload.
            error_context: Prefix used for the error log entries.
            should_send: Optional predicate ``websocket -> bool``; sockets for
                which it returns a falsy value are skipped.

        The payload is serialized once, before the loop, so that an
        unserializable message is logged and dropped instead of being treated
        as a failure of every single connection.
        """
        payload = self._serialize(message, error_context)
        if payload is None:
            return
        for websocket in list(connections):
            try:
                if should_send is not None and not should_send(websocket):
                    continue
                await websocket.send_text(payload)
            except Exception as e:
                logger.error(f"{error_context}: {e}")
                # Remove broken connection
                self.disconnect(websocket)

    @staticmethod
    def _progress_message(job_id: str, progress: ProgressData) -> Dict[str, Any]:
        """Build the ``progress_update`` message for a stored progress snapshot."""
        return {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "stage": progress.stage.value,
                "percentage": progress.percentage,
                "message": progress.message,
                "time_elapsed": progress.time_elapsed,
                "estimated_remaining": progress.estimated_remaining,
                "sub_progress": progress.sub_progress,
                "details": progress.details
            }
        }

    @staticmethod
    def _is_connected(websocket) -> bool:
        """Return False when the socket reports a state other than CONNECTED.

        The states are compared by enum member name so that no extra import is
        needed; a socket that exposes no state attributes is assumed alive.
        """
        for attribute in ("client_state", "application_state"):
            state = getattr(websocket, attribute, None)
            if state is not None and getattr(state, "name", None) != "CONNECTED":
                return False
        return True

    def _clear_typing_state(self, session_id: str, user_id: Optional[str]):
        """Drop typing markers that can no longer be cleared by a departed client."""
        typing_users = self.chat_typing.get(session_id)
        if typing_users is None:
            return
        remaining = self.chat_connections.get(session_id)
        if not remaining:
            # Nobody is left in the session, so the whole entry is dead weight.
            del self.chat_typing[session_id]
            return
        # Keep the marker if the same user still has another open connection.
        still_connected = any(
            self.connection_metadata.get(ws, {}).get("user_id") == user_id
            for ws in remaining
        )
        if not still_connected:
            typing_users.discard(user_id)

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Accept a job WebSocket and replay anything it missed.

        When ``job_id`` is given, the socket is registered for that job, any
        queued messages for the job are sent to it, and the latest progress
        snapshot (if one is tracked) is sent afterwards.
        """
        await websocket.accept()
        self.all_connections.add(websocket)
        now = datetime.utcnow()
        self.connection_metadata[websocket] = {
            "connected_at": now,
            "job_id": job_id,
            "last_ping": now
        }
        if job_id:
            self.active_connections.setdefault(job_id, []).append(websocket)
            # Send queued messages if reconnecting
            if job_id in self.message_queue:
                queued = self.message_queue[job_id]
                logger.info(f"Sending {len(queued)} queued messages for job {job_id}")
                for message in list(queued):
                    try:
                        await websocket.send_text(json.dumps(message))
                    except Exception as e:
                        logger.error(f"Failed to send queued message: {e}")
                        # Keep the whole queue so a later reconnect can replay it.
                        break
                else:
                    # Clear queue if all messages sent successfully
                    self.message_queue.pop(job_id, None)
            # Send current progress if available
            if job_id in self.job_progress:
                await self.send_current_progress(websocket, job_id)
        logger.info(f"WebSocket connected. Job ID: {job_id}, Total connections: {len(self.all_connections)}")

    async def connect_chat(self, websocket: WebSocket, session_id: str, user_id: Optional[str] = None):
        """Accept a WebSocket and register it for a chat session (Story 4.6)."""
        await websocket.accept()
        self.all_connections.add(websocket)
        now = datetime.utcnow()
        self.connection_metadata[websocket] = {
            "connected_at": now,
            "session_id": session_id,
            "user_id": user_id,
            "connection_type": "chat",
            "last_ping": now
        }
        self.chat_connections.setdefault(session_id, []).append(websocket)
        logger.info(f"Chat WebSocket connected. Session ID: {session_id}, User ID: {user_id}, Total connections: {len(self.all_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket from every index; safe to call more than once."""
        self.all_connections.discard(websocket)
        # The metadata tells us which job / chat session the socket belonged to.
        metadata = self.connection_metadata.pop(websocket, {})
        job_id = metadata.get("job_id")
        session_id = metadata.get("session_id")

        job_sockets = self.active_connections.get(job_id) if job_id else None
        if job_sockets is not None:
            if websocket in job_sockets:
                job_sockets.remove(websocket)
            # Clean up empty job connection lists
            if not job_sockets:
                del self.active_connections[job_id]

        chat_sockets = self.chat_connections.get(session_id) if session_id else None
        if chat_sockets is not None:
            if websocket in chat_sockets:
                chat_sockets.remove(websocket)
            # Clean up empty chat connection lists
            if not chat_sockets:
                del self.chat_connections[session_id]
            self._clear_typing_state(session_id, metadata.get("user_id"))

        logger.info(f"WebSocket disconnected. Job ID: {job_id}, Remaining connections: {len(self.all_connections)}")

    # ------------------------------------------------------------------
    # Job messages
    # ------------------------------------------------------------------

    async def send_personal_message(self, message: Dict[str, Any], websocket: WebSocket):
        """Send a message to one WebSocket; unregister the socket if the send fails."""
        payload = self._serialize(message, "Error sending personal message")
        if payload is None:
            return
        try:
            await websocket.send_text(payload)
        except Exception as e:
            logger.error(f"Error sending personal message: {e}")
            # Connection might be closed, remove it
            self.disconnect(websocket)

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send a progress update to all connections listening to ``job_id``."""
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        # ``or {}`` also covers a key that is present but explicitly None.
        video_context = progress_data.get('video_context') or {}
        message = {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": progress_data,
            "video_title": video_context.get('title'),
            "video_id": video_context.get('video_id'),
            "display_name": video_context.get('display_name')
        }
        await self._send_to_connections(
            connections, message, f"Error sending progress update to {job_id}")

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send a completion notification to all connections listening to ``job_id``."""
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        # ``or {}`` also covers a key that is present but explicitly None.
        video_metadata = result_data.get('video_metadata') or {}
        message = {
            "type": "completion_notification",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": result_data,
            "video_title": video_metadata.get('title'),
            "video_id": result_data.get('video_id'),
            "display_name": result_data.get('display_name')
        }
        await self._send_to_connections(
            connections, message, f"Error sending completion notification to {job_id}")

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send an error notification to all connections listening to ``job_id``."""
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        # ``or {}`` also covers a key that is present but explicitly None.
        video_context = error_data.get('video_context') or {}
        message = {
            "type": "error_notification",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": error_data,
            "video_title": video_context.get('title'),
            "video_id": video_context.get('video_id'),
            "display_name": video_context.get('display_name')
        }
        await self._send_to_connections(
            connections, message, f"Error sending error notification to {job_id}")

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast a system message to all connected clients."""
        message = {
            "type": "system_message",
            "timestamp": datetime.utcnow().isoformat(),
            "data": message_data
        }
        await self._send_to_connections(
            self.all_connections, message, "Error broadcasting system message")

    # ------------------------------------------------------------------
    # Chat messages (Story 4.6)
    # ------------------------------------------------------------------

    async def send_chat_message(self, session_id: str, message_data: Dict[str, Any]):
        """Send a chat message to all connections in a chat session."""
        connections = self.chat_connections.get(session_id)
        if connections is None:
            return
        message = {
            "type": "message",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": message_data
        }
        await self._send_to_connections(
            connections, message, f"Error sending chat message to {session_id}")

    async def send_typing_indicator(self, session_id: str, user_id: str, is_typing: bool):
        """Record a user's typing state and notify the other session members."""
        connections = self.chat_connections.get(session_id)
        if connections is None:
            return
        # Update typing state
        typing_users = self.chat_typing.setdefault(session_id, set())
        if is_typing:
            typing_users.add(user_id)
        else:
            typing_users.discard(user_id)
        message = {
            "type": "typing_start" if is_typing else "typing_end",
            "session_id": session_id,
            "user_id": user_id,
            "timestamp": datetime.utcnow().isoformat()
        }

        def is_other_user(websocket) -> bool:
            # Don't send typing indicator back to the person typing
            return self.connection_metadata.get(websocket, {}).get("user_id") != user_id

        await self._send_to_connections(
            connections, message,
            f"Error sending typing indicator to {session_id}",
            should_send=is_other_user)

    async def send_chat_status(self, session_id: str, status_data: Dict[str, Any]):
        """Send a ``connection_status`` update to all connections in a chat session."""
        connections = self.chat_connections.get(session_id)
        if connections is None:
            return
        message = {
            "type": "connection_status",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": status_data
        }
        await self._send_to_connections(
            connections, message, f"Error sending chat status to {session_id}")

    # ------------------------------------------------------------------
    # Transcript streaming (Task 14.3)
    # ------------------------------------------------------------------

    async def send_transcript_chunk(self, job_id: str, chunk_data: Dict[str, Any]):
        """Send a live transcript chunk to job connections that enabled streaming."""
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        message = {
            "type": "transcript_chunk",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": chunk_data
        }

        def wants_transcript(websocket) -> bool:
            # Check if this connection has transcript streaming enabled
            return self.connection_metadata.get(websocket, {}).get("transcript_streaming", False)

        await self._send_to_connections(
            connections, message,
            f"Error sending transcript chunk to {job_id}",
            should_send=wants_transcript)

    async def send_transcript_complete(self, job_id: str, transcript_data: Dict[str, Any]):
        """Send the complete transcript to all connections listening to ``job_id``."""
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        message = {
            "type": "transcript_complete",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": transcript_data
        }
        await self._send_to_connections(
            connections, message, f"Error sending complete transcript to {job_id}")

    def enable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Flag a registered connection as wanting transcript chunks.

        ``job_id`` is only used in the log message.
        """
        if websocket in self.connection_metadata:
            self.connection_metadata[websocket]["transcript_streaming"] = True
            logger.info(f"Enabled transcript streaming for job {job_id}")

    def disable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Flag a registered connection as no longer wanting transcript chunks.

        ``job_id`` is only used in the log message.
        """
        if websocket in self.connection_metadata:
            self.connection_metadata[websocket]["transcript_streaming"] = False
            logger.info(f"Disabled transcript streaming for job {job_id}")

    # ------------------------------------------------------------------
    # Health and statistics
    # ------------------------------------------------------------------

    async def send_heartbeat(self):
        """Send a heartbeat message to every connection, dropping those that fail."""
        message = {
            "type": "heartbeat",
            "timestamp": datetime.utcnow().isoformat()
        }
        await self._send_to_connections(
            self.all_connections, message, "Error sending heartbeat")

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return connection counts per job and chat session plus typing state."""
        return {
            "total_connections": len(self.all_connections),
            "job_connections": {
                job_id: len(sockets)
                for job_id, sockets in self.active_connections.items()
            },
            "chat_connections": {
                session_id: len(sockets)
                for session_id, sockets in self.chat_connections.items()
            },
            "active_jobs": list(self.active_connections),
            "active_chat_sessions": list(self.chat_connections),
            "typing_sessions": {
                session_id: list(typing_users)
                for session_id, typing_users in self.chat_typing.items()
                if typing_users
            }
        }

    async def cleanup_stale_connections(self):
        """Unregister connections whose socket no longer reports CONNECTED.

        The FastAPI/Starlette ``WebSocket`` has no ``ping()`` method, so
        calling it raised ``AttributeError`` for every socket and unregistered
        healthy connections.  The socket's own state is inspected instead;
        sockets that fail on send are already dropped by the send paths.
        """
        for websocket in list(self.all_connections):
            if not self._is_connected(websocket):
                # Connection is stale, remove it
                self.disconnect(websocket)

    # ------------------------------------------------------------------
    # Progress tracking and time estimation
    # ------------------------------------------------------------------

    async def send_current_progress(self, websocket: WebSocket, job_id: str):
        """Send the latest tracked progress of ``job_id`` to one client, if any."""
        progress = self.job_progress.get(job_id)
        if progress is None:
            return
        await self.send_personal_message(self._progress_message(job_id, progress), websocket)

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Store the latest progress of a job and buffer it if nobody listens.

        The first update for a job also records its start time.  At most
        ``MAX_QUEUED_MESSAGES`` messages are buffered per job; further
        updates are not queued once the limit is reached.
        """
        self.job_progress[job_id] = progress_data
        # Track start time if not already tracked
        if job_id not in self.job_start_times:
            self.job_start_times[job_id] = datetime.utcnow()
        if self.active_connections.get(job_id):
            return
        # Store in message queue if no active connections
        queue = self.message_queue.setdefault(job_id, [])
        # Limit queue size to prevent memory issues
        if len(queue) < self.MAX_QUEUED_MESSAGES:
            queue.append(self._progress_message(job_id, progress_data))

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate the remaining processing time of a job in seconds.

        Returns ``None`` when the job has no recorded start time or
        ``current_percentage`` is not positive, and ``0`` at 100 % or above.
        Otherwise the current progress rate is extrapolated and, when history
        is available, blended 70/30 with the historical average.
        """
        started_at = self.job_start_times.get(job_id)
        if started_at is None or current_percentage <= 0:
            return None
        if current_percentage >= 100:
            return 0
        elapsed = (datetime.utcnow() - started_at).total_seconds()
        # Estimate based on current progress rate
        seconds_per_percent = elapsed / current_percentage
        estimated_remaining = seconds_per_percent * (100 - current_percentage)
        # Adjust based on historical data if available
        recent = self.processing_history[-self.HISTORY_SAMPLE_SIZE:]
        if recent:
            avg_total_time = sum(record.get('total_time', 0) for record in recent) / len(recent)
            historical_remaining = avg_total_time - elapsed
            if avg_total_time > 0 and historical_remaining > 0:
                # Weighted average of current estimate and historical average
                estimated_remaining = estimated_remaining * 0.7 + historical_remaining * 0.3
        return max(0, estimated_remaining)

    def record_job_completion(self, job_id: str):
        """Record a job's total duration and discard its tracking state."""
        started_at = self.job_start_times.pop(job_id, None)
        if started_at is not None:
            finished_at = datetime.utcnow()
            self.processing_history.append({
                "job_id": job_id,
                "total_time": (finished_at - started_at).total_seconds(),
                "completed_at": finished_at.isoformat()
            })
            # Keep only the most recent records
            if len(self.processing_history) > self.MAX_HISTORY_RECORDS:
                self.processing_history = self.processing_history[-self.MAX_HISTORY_RECORDS:]
        # Clean up tracking
        self.job_progress.pop(job_id, None)
        self.message_queue.pop(job_id, None)
class WebSocketManager:
    """Singleton facade over one shared :class:`ConnectionManager`.

    Every call to ``WebSocketManager()`` returns the same instance.  Most
    methods delegate straight to the connection manager; ``connect`` and
    ``connect_chat`` additionally make sure the background heartbeat task is
    running.
    """

    # Seconds between heartbeat / stale-connection sweeps.
    HEARTBEAT_INTERVAL_SECONDS = 30

    _instance = None

    def __new__(cls):
        """Create the shared instance on first use and return it afterwards."""
        if cls._instance is None:
            instance = super(WebSocketManager, cls).__new__(cls)
            instance.connection_manager = ConnectionManager()
            instance._heartbeat_task = None
            cls._instance = instance
        return cls._instance

    def __init__(self):
        """Initialise state only if ``__new__`` has not already done so."""
        # __init__ runs on every WebSocketManager() call; the guard keeps the
        # shared connection manager from being replaced.
        if not hasattr(self, 'connection_manager'):
            self.connection_manager = ConnectionManager()
            self._heartbeat_task = None

    def _ensure_heartbeat(self):
        """Start the heartbeat task if it is not running (needs a running event loop)."""
        if self._heartbeat_task is None or self._heartbeat_task.done():
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Connect a WebSocket for job updates."""
        await self.connection_manager.connect(websocket, job_id)
        # Start heartbeat task if not running
        self._ensure_heartbeat()

    async def connect_chat(self, websocket: WebSocket, session_id: str, user_id: Optional[str] = None):
        """Connect a WebSocket for chat functionality (Story 4.6)."""
        await self.connection_manager.connect_chat(websocket, session_id, user_id)
        # Start heartbeat task if not running
        self._ensure_heartbeat()

    def disconnect(self, websocket: WebSocket):
        """Disconnect a WebSocket."""
        self.connection_manager.disconnect(websocket)

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send progress update for a job."""
        await self.connection_manager.send_progress_update(job_id, progress_data)

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send completion notification for a job."""
        await self.connection_manager.send_completion_notification(job_id, result_data)

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send error notification for a job."""
        await self.connection_manager.send_error_notification(job_id, error_data)

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast system message to all connections."""
        await self.connection_manager.broadcast_system_message(message_data)

    async def send_chat_message(self, session_id: str, message_data: Dict[str, Any]):
        """Send chat message to all connections in a session (Story 4.6)."""
        await self.connection_manager.send_chat_message(session_id, message_data)

    async def send_typing_indicator(self, session_id: str, user_id: str, is_typing: bool):
        """Send typing indicator to chat session (Story 4.6)."""
        await self.connection_manager.send_typing_indicator(session_id, user_id, is_typing)

    async def send_chat_status(self, session_id: str, status_data: Dict[str, Any]):
        """Send status update to chat session (Story 4.6)."""
        await self.connection_manager.send_chat_status(session_id, status_data)

    async def send_transcript_chunk(self, job_id: str, chunk_data: Dict[str, Any]):
        """Send live transcript chunk to job connections (Task 14.3)."""
        await self.connection_manager.send_transcript_chunk(job_id, chunk_data)

    async def send_transcript_complete(self, job_id: str, transcript_data: Dict[str, Any]):
        """Send complete transcript data to job connections."""
        await self.connection_manager.send_transcript_complete(job_id, transcript_data)

    def enable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Enable transcript streaming for a connection."""
        self.connection_manager.enable_transcript_streaming(websocket, job_id)

    def disable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Disable transcript streaming for a connection."""
        self.connection_manager.disable_transcript_streaming(websocket, job_id)

    def get_stats(self) -> Dict[str, Any]:
        """Get WebSocket connection statistics."""
        return self.connection_manager.get_connection_stats()

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Update and track job progress."""
        self.connection_manager.update_job_progress(job_id, progress_data)

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate remaining processing time."""
        return self.connection_manager.estimate_remaining_time(job_id, current_percentage)

    def record_job_completion(self, job_id: str):
        """Record job completion for time estimation."""
        self.connection_manager.record_job_completion(job_id)

    async def _heartbeat_loop(self):
        """Background task: periodically send heartbeats and sweep stale connections.

        Runs until cancelled.  Any other error is logged and the loop
        continues with the next interval.
        """
        while True:
            try:
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)
                await self.connection_manager.send_heartbeat()
                await self.connection_manager.cleanup_stale_connections()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in heartbeat loop: {e}")
# Global WebSocket manager instance, created at import time.
# WebSocketManager is a singleton, so later WebSocketManager() calls return this same object.
websocket_manager = WebSocketManager()