# Listing metadata (from the code viewer this file was copied from): 667 lines, 28 KiB, Python
"""Enhanced WebSocket manager for real-time progress updates with connection recovery."""
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, List, Any, Optional, Set
|
|
from fastapi import WebSocket
|
|
from datetime import datetime
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
|
|
# Module-level logger, named after this module through ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ProcessingStage(Enum):
    """Processing stages for video summarization.

    Each member's value is a lowercase string; that string (``stage.value``)
    is what gets placed into serialized progress messages.
    """

    INITIALIZED = "initialized"
    VALIDATING_URL = "validating_url"
    EXTRACTING_METADATA = "extracting_metadata"
    EXTRACTING_TRANSCRIPT = "extracting_transcript"
    ANALYZING_CONTENT = "analyzing_content"
    GENERATING_SUMMARY = "generating_summary"
    VALIDATING_QUALITY = "validating_quality"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
|
|
|
|
@dataclass
class ProgressData:
    """Progress data structure for processing updates.

    Only the first five fields are required; every optional field defaults
    to ``None``.
    """

    # Identifier of the job this snapshot belongs to.
    job_id: str
    # Pipeline stage the job is currently in.
    stage: ProcessingStage
    # Overall completion percentage (0-100 scale assumed - TODO confirm).
    percentage: float
    # Human-readable status text.
    message: str
    # Time spent on the job so far (presumably seconds - verify against caller).
    time_elapsed: float
    # Estimated time still needed, when an estimate is available.
    estimated_remaining: Optional[float] = None
    # Optional finer-grained progress information for the current stage.
    sub_progress: Optional[Dict[str, Any]] = None
    # Optional free-form extra details.
    details: Optional[Dict[str, Any]] = None
    # Enhanced context for user-friendly display.
    # Contains video_id, title, display_name.
    video_context: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections for real-time updates.

    Every accepted socket is registered in ``all_connections``. Job listeners
    are additionally grouped by job id in ``active_connections`` and chat
    sockets by session id in ``chat_connections``; per-socket details are
    kept in ``connection_metadata``.

    All outgoing messages are sent as JSON text frames. A socket whose send
    raises is treated as dead and unregistered through :meth:`disconnect`.
    """

    # Maximum number of progress messages buffered per job while nobody listens.
    MAX_QUEUED_MESSAGES = 100
    # Maximum number of completed-job timing records that are retained.
    MAX_HISTORY_RECORDS = 100
    # Number of most recent timing records averaged by estimate_remaining_time.
    HISTORY_WINDOW = 10

    def __init__(self):
        # Active connections by job_id
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # Chat connections by session_id (for Story 4.6 RAG Chat)
        self.chat_connections: Dict[str, List[WebSocket]] = {}
        # All connected websockets for broadcast
        self.all_connections: Set[WebSocket] = set()
        # Connection metadata
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        # Message queue for disconnected clients
        self.message_queue: Dict[str, List[Dict[str, Any]]] = {}
        # Job progress tracking
        self.job_progress: Dict[str, ProgressData] = {}
        # Job start times for time estimation
        self.job_start_times: Dict[str, datetime] = {}
        # Historical processing times for estimation
        # (records hold job_id, total_time and completed_at)
        self.processing_history: List[Dict[str, Any]] = []
        # Chat typing indicators: session_id -> set of user_ids typing
        self.chat_typing: Dict[str, Set[str]] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _timestamp() -> str:
        """Return the current naive-UTC time as an ISO-8601 string."""
        return datetime.utcnow().isoformat()

    @staticmethod
    def _progress_message(job_id: str, progress: ProgressData) -> Dict[str, Any]:
        """Build the ``progress_update`` message for a stored progress snapshot."""
        return {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "stage": progress.stage.value,
                "percentage": progress.percentage,
                "message": progress.message,
                "time_elapsed": progress.time_elapsed,
                "estimated_remaining": progress.estimated_remaining,
                "sub_progress": progress.sub_progress,
                "details": progress.details,
            },
        }

    @staticmethod
    def _remove_from_group(
        groups: Dict[str, List[WebSocket]],
        key: Optional[str],
        websocket: WebSocket,
    ) -> None:
        """Remove ``websocket`` from ``groups[key]`` and drop the group once empty."""
        if not key or key not in groups:
            return
        members = groups[key]
        if websocket in members:
            members.remove(websocket)
        if not members:
            del groups[key]

    def _clear_typing_state(self, session_id: str, user_id: Optional[str]) -> None:
        """Forget typing state that can no longer be updated after a disconnect."""
        typing_users = self.chat_typing.get(session_id)
        if typing_users is None:
            return
        remaining = self.chat_connections.get(session_id, [])
        if not remaining:
            # The session has no sockets left, so nobody can send "typing_end".
            del self.chat_typing[session_id]
            return
        # Keep the flag while the same user still has another socket open.
        still_connected = any(
            self.connection_metadata.get(ws, {}).get("user_id") == user_id
            for ws in remaining
        )
        if not still_connected:
            typing_users.discard(user_id)

    async def _send_to_connections(
        self,
        connections: List[WebSocket],
        message: Dict[str, Any],
        error_prefix: str,
    ) -> None:
        """Serialize ``message`` once and send it to every socket in ``connections``.

        ``connections`` must be a snapshot (not a live registry list) because
        failing sockets are unregistered while iterating. A payload that cannot
        be serialized is logged and dropped without touching any connection,
        since that failure says nothing about the sockets' health.
        """
        try:
            payload = json.dumps(message)
        except (TypeError, ValueError) as exc:
            logger.error(f"{error_prefix}: message is not JSON serializable ({exc})")
            return

        for websocket in connections:
            try:
                await websocket.send_text(payload)
            except Exception as exc:
                logger.error(f"{error_prefix}: {exc}")
                # Remove broken connection
                self.disconnect(websocket)

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Accept and manage a new WebSocket connection with recovery support.

        When ``job_id`` is given the socket is subscribed to that job, any
        messages queued while no listener was connected are replayed, and the
        latest known progress snapshot is sent.
        """
        await websocket.accept()

        self.all_connections.add(websocket)
        now = datetime.utcnow()
        self.connection_metadata[websocket] = {
            "connected_at": now,
            "job_id": job_id,
            "last_ping": now,
        }

        if job_id:
            self.active_connections.setdefault(job_id, []).append(websocket)

            # Send queued messages if reconnecting
            queued = self.message_queue.get(job_id)
            if queued is not None:
                logger.info(f"Sending {len(queued)} queued messages for job {job_id}")
                for message in queued:
                    try:
                        await websocket.send_text(json.dumps(message))
                    except Exception as e:
                        logger.error(f"Failed to send queued message: {e}")
                        break
                else:
                    # Clear queue if all messages sent successfully. pop() is
                    # used because the queue may have been removed while this
                    # coroutine was suspended in send_text().
                    self.message_queue.pop(job_id, None)

            # Send current progress if available
            if job_id in self.job_progress:
                await self.send_current_progress(websocket, job_id)

        logger.info(f"WebSocket connected. Job ID: {job_id}, Total connections: {len(self.all_connections)}")

    async def connect_chat(self, websocket: WebSocket, session_id: str, user_id: Optional[str] = None):
        """Connect a WebSocket for chat functionality (Story 4.6)."""
        await websocket.accept()

        self.all_connections.add(websocket)
        now = datetime.utcnow()
        self.connection_metadata[websocket] = {
            "connected_at": now,
            "session_id": session_id,
            "user_id": user_id,
            "connection_type": "chat",
            "last_ping": now,
        }
        self.chat_connections.setdefault(session_id, []).append(websocket)

        logger.info(f"Chat WebSocket connected. Session ID: {session_id}, User ID: {user_id}, Total connections: {len(self.all_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection from every registry.

        Safe to call for sockets that are unknown or already removed. The
        socket itself is not closed here.
        """
        self.all_connections.discard(websocket)

        # Take the metadata out first; it tells us which groups to clean.
        metadata = self.connection_metadata.pop(websocket, {})
        job_id = metadata.get("job_id")
        session_id = metadata.get("session_id")

        self._remove_from_group(self.active_connections, job_id, websocket)
        self._remove_from_group(self.chat_connections, session_id, websocket)

        if session_id:
            self._clear_typing_state(session_id, metadata.get("user_id"))

        logger.info(f"WebSocket disconnected. Job ID: {job_id}, Remaining connections: {len(self.all_connections)}")

    # ------------------------------------------------------------------
    # Job messages
    # ------------------------------------------------------------------

    async def send_personal_message(self, message: Dict[str, Any], websocket: WebSocket):
        """Send a message to a specific WebSocket connection."""
        await self._send_to_connections([websocket], message, "Error sending personal message")

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send progress update to all connections listening to a specific job."""
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        # ``or {}`` also covers an explicit ``"video_context": None``.
        video_context = progress_data.get('video_context') or {}
        message = {
            "type": "progress_update",
            "job_id": job_id,
            "timestamp": self._timestamp(),
            "data": progress_data,
            "video_title": video_context.get('title'),
            "video_id": video_context.get('video_id'),
            "display_name": video_context.get('display_name'),
        }
        await self._send_to_connections(
            list(listeners), message, f"Error sending progress update to {job_id}"
        )

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send completion notification for a job."""
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        video_metadata = result_data.get('video_metadata') or {}
        message = {
            "type": "completion_notification",
            "job_id": job_id,
            "timestamp": self._timestamp(),
            "data": result_data,
            "video_title": video_metadata.get('title'),
            "video_id": result_data.get('video_id'),
            "display_name": result_data.get('display_name'),
        }
        await self._send_to_connections(
            list(listeners), message, f"Error sending completion notification to {job_id}"
        )

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send error notification for a job."""
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        video_context = error_data.get('video_context') or {}
        message = {
            "type": "error_notification",
            "job_id": job_id,
            "timestamp": self._timestamp(),
            "data": error_data,
            "video_title": video_context.get('title'),
            "video_id": video_context.get('video_id'),
            "display_name": video_context.get('display_name'),
        }
        await self._send_to_connections(
            list(listeners), message, f"Error sending error notification to {job_id}"
        )

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast a system message to all connected clients."""
        message = {
            "type": "system_message",
            "timestamp": self._timestamp(),
            "data": message_data,
        }
        await self._send_to_connections(
            list(self.all_connections), message, "Error broadcasting system message"
        )

    # ------------------------------------------------------------------
    # Chat messages (Story 4.6)
    # ------------------------------------------------------------------

    async def send_chat_message(self, session_id: str, message_data: Dict[str, Any]):
        """Send a chat message to all connections in a chat session (Story 4.6)."""
        session_sockets = self.chat_connections.get(session_id)
        if not session_sockets:
            return

        message = {
            "type": "message",
            "session_id": session_id,
            "timestamp": self._timestamp(),
            "data": message_data,
        }
        await self._send_to_connections(
            list(session_sockets), message, f"Error sending chat message to {session_id}"
        )

    async def send_typing_indicator(self, session_id: str, user_id: str, is_typing: bool):
        """Send typing indicator to chat session (Story 4.6).

        The typing state is recorded in ``chat_typing`` and the indicator is
        sent to every socket of the session except those of ``user_id``.
        """
        session_sockets = self.chat_connections.get(session_id)
        if not session_sockets:
            return

        typing_users = self.chat_typing.setdefault(session_id, set())
        if is_typing:
            typing_users.add(user_id)
        else:
            typing_users.discard(user_id)

        message = {
            "type": "typing_start" if is_typing else "typing_end",
            "session_id": session_id,
            "user_id": user_id,
            "timestamp": self._timestamp(),
        }

        # Don't send typing indicator back to the person typing
        others = [
            ws for ws in session_sockets
            if self.connection_metadata.get(ws, {}).get("user_id") != user_id
        ]
        await self._send_to_connections(
            others, message, f"Error sending typing indicator to {session_id}"
        )

    async def send_chat_status(self, session_id: str, status_data: Dict[str, Any]):
        """Send chat status update to session connections (Story 4.6)."""
        session_sockets = self.chat_connections.get(session_id)
        if not session_sockets:
            return

        message = {
            "type": "connection_status",
            "session_id": session_id,
            "timestamp": self._timestamp(),
            "data": status_data,
        }
        await self._send_to_connections(
            list(session_sockets), message, f"Error sending chat status to {session_id}"
        )

    # ------------------------------------------------------------------
    # Transcript streaming
    # ------------------------------------------------------------------

    async def send_transcript_chunk(self, job_id: str, chunk_data: Dict[str, Any]):
        """Send live transcript chunk to job connections (Task 14.3).

        Only sockets that opted in through :meth:`enable_transcript_streaming`
        receive the chunk.
        """
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        message = {
            "type": "transcript_chunk",
            "job_id": job_id,
            "timestamp": self._timestamp(),
            "data": chunk_data,
        }
        subscribers = [
            ws for ws in listeners
            if self.connection_metadata.get(ws, {}).get("transcript_streaming", False)
        ]
        await self._send_to_connections(
            subscribers, message, f"Error sending transcript chunk to {job_id}"
        )

    async def send_transcript_complete(self, job_id: str, transcript_data: Dict[str, Any]):
        """Send complete transcript data to job connections."""
        listeners = self.active_connections.get(job_id)
        if not listeners:
            return

        message = {
            "type": "transcript_complete",
            "job_id": job_id,
            "timestamp": self._timestamp(),
            "data": transcript_data,
        }
        await self._send_to_connections(
            list(listeners), message, f"Error sending complete transcript to {job_id}"
        )

    def enable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Enable transcript streaming for a specific connection."""
        if websocket in self.connection_metadata:
            self.connection_metadata[websocket]["transcript_streaming"] = True
            logger.info(f"Enabled transcript streaming for job {job_id}")

    def disable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Disable transcript streaming for a specific connection."""
        if websocket in self.connection_metadata:
            self.connection_metadata[websocket]["transcript_streaming"] = False
            logger.info(f"Disabled transcript streaming for job {job_id}")

    # ------------------------------------------------------------------
    # Keep-alive and statistics
    # ------------------------------------------------------------------

    async def send_heartbeat(self):
        """Send heartbeat to all connections to keep them alive."""
        message = {
            "type": "heartbeat",
            "timestamp": self._timestamp(),
        }
        await self._send_to_connections(
            list(self.all_connections), message, "Error sending heartbeat"
        )

    def get_connection_stats(self) -> Dict[str, Any]:
        """Get connection statistics."""
        return {
            "total_connections": len(self.all_connections),
            "job_connections": {
                job_id: len(sockets)
                for job_id, sockets in self.active_connections.items()
            },
            "chat_connections": {
                session_id: len(sockets)
                for session_id, sockets in self.chat_connections.items()
            },
            "active_jobs": list(self.active_connections),
            "active_chat_sessions": list(self.chat_connections),
            "typing_sessions": {
                session_id: list(typing_users)
                for session_id, typing_users in self.chat_typing.items()
                if typing_users
            },
        }

    async def cleanup_stale_connections(self):
        """Unregister connections that are no longer usable.

        Sockets exposing a ``ping()`` coroutine are pinged and removed when the
        ping raises. Sockets without ``ping()`` are removed only when their
        ``client_state`` or ``application_state`` reports ``DISCONNECTED``.

        NOTE(review): Starlette's ``WebSocket`` (re-exported by FastAPI) has no
        ``ping()`` method, so calling it unconditionally raised
        ``AttributeError`` and unregistered every live connection on each
        cleanup pass. Confirm against the installed Starlette version.
        """
        for websocket in list(self.all_connections):
            ping = getattr(websocket, "ping", None)
            if ping is None:
                states = (
                    getattr(websocket, "client_state", None),
                    getattr(websocket, "application_state", None),
                )
                # Compare by enum member name to avoid importing Starlette here.
                if any(getattr(state, "name", None) == "DISCONNECTED" for state in states):
                    self.disconnect(websocket)
                continue

            try:
                await ping()
            except Exception:
                # Connection is stale, remove it
                self.disconnect(websocket)

    # ------------------------------------------------------------------
    # Progress tracking and time estimation
    # ------------------------------------------------------------------

    async def send_current_progress(self, websocket: WebSocket, job_id: str):
        """Send current progress state to a reconnecting client."""
        progress = self.job_progress.get(job_id)
        if progress is None:
            return
        await self.send_personal_message(self._progress_message(job_id, progress), websocket)

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Update job progress tracking.

        Stores the latest snapshot, records the job's start time on first
        sight, and, while the job has no listeners, buffers the update (up to
        ``MAX_QUEUED_MESSAGES`` per job) for replay on reconnect.
        """
        self.job_progress[job_id] = progress_data

        # Track start time if not already tracked
        if job_id not in self.job_start_times:
            self.job_start_times[job_id] = datetime.utcnow()

        if self.active_connections.get(job_id):
            return

        # Limit queue size to prevent memory issues; once full, newer
        # updates are not buffered.
        queue = self.message_queue.setdefault(job_id, [])
        if len(queue) < self.MAX_QUEUED_MESSAGES:
            queue.append(self._progress_message(job_id, progress_data))

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate remaining processing time based on history.

        Returns ``None`` when the job has no recorded start time or
        ``current_percentage`` is not positive, ``0`` at 100 % or more, and
        otherwise a non-negative number of seconds.
        """
        if job_id not in self.job_start_times or current_percentage <= 0:
            return None
        if current_percentage >= 100:
            return 0

        elapsed = (datetime.utcnow() - self.job_start_times[job_id]).total_seconds()

        # Linear extrapolation from the progress rate so far.
        seconds_per_percent = elapsed / current_percentage
        estimated_remaining = seconds_per_percent * (100 - current_percentage)

        # Adjust based on historical data if available
        recent = self.processing_history[-self.HISTORY_WINDOW:]
        if recent:
            avg_total_time = sum(record.get('total_time', 0) for record in recent) / len(recent)
            if avg_total_time > 0:
                historical_remaining = avg_total_time - elapsed
                if historical_remaining > 0:
                    # Weighted average of current estimate and historical average
                    estimated_remaining = estimated_remaining * 0.7 + historical_remaining * 0.3

        return max(0, estimated_remaining)

    def record_job_completion(self, job_id: str):
        """Record job completion time for future estimations.

        Does nothing for jobs without a recorded start time. Otherwise the
        total duration is appended to ``processing_history`` (trimmed to the
        last ``MAX_HISTORY_RECORDS`` entries) and all tracking state for the
        job is dropped.
        """
        started_at = self.job_start_times.pop(job_id, None)
        if started_at is None:
            return

        total_time = (datetime.utcnow() - started_at).total_seconds()
        self.processing_history.append({
            "job_id": job_id,
            "total_time": total_time,
            "completed_at": datetime.utcnow().isoformat(),
        })

        if len(self.processing_history) > self.MAX_HISTORY_RECORDS:
            self.processing_history = self.processing_history[-self.MAX_HISTORY_RECORDS:]

        # Clean up tracking
        self.job_progress.pop(job_id, None)
        self.message_queue.pop(job_id, None)
|
|
|
|
|
|
class WebSocketManager:
    """Main WebSocket manager with singleton pattern.

    Every instantiation returns the same object. The instance owns one
    ``ConnectionManager``, to which almost every public method delegates, and
    a background heartbeat task that is started when a socket connects.
    """

    _instance = None

    # Seconds between two heartbeat / stale-connection cleanup rounds.
    HEARTBEAT_INTERVAL_SECONDS = 30

    def __new__(cls):
        if cls._instance is None:
            instance = super(WebSocketManager, cls).__new__(cls)
            instance.connection_manager = ConnectionManager()
            instance._heartbeat_task = None
            cls._instance = instance
        return cls._instance

    def __init__(self):
        # __init__ runs on every WebSocketManager() call; the guard keeps the
        # state created in __new__ from being reset.
        if not hasattr(self, 'connection_manager'):
            self.connection_manager = ConnectionManager()
            self._heartbeat_task = None

    def _ensure_heartbeat(self) -> None:
        """Start the heartbeat task if it is not running.

        Must be called from within a running event loop because it uses
        ``asyncio.create_task``.
        """
        if self._heartbeat_task is None or self._heartbeat_task.done():
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def connect(self, websocket: WebSocket, job_id: Optional[str] = None):
        """Connect a WebSocket for job updates."""
        await self.connection_manager.connect(websocket, job_id)
        self._ensure_heartbeat()

    async def connect_chat(self, websocket: WebSocket, session_id: str, user_id: Optional[str] = None):
        """Connect a WebSocket for chat functionality (Story 4.6)."""
        await self.connection_manager.connect_chat(websocket, session_id, user_id)
        self._ensure_heartbeat()

    def disconnect(self, websocket: WebSocket):
        """Disconnect a WebSocket."""
        self.connection_manager.disconnect(websocket)

    async def send_progress_update(self, job_id: str, progress_data: Dict[str, Any]):
        """Send progress update for a job."""
        await self.connection_manager.send_progress_update(job_id, progress_data)

    async def send_completion_notification(self, job_id: str, result_data: Dict[str, Any]):
        """Send completion notification for a job."""
        await self.connection_manager.send_completion_notification(job_id, result_data)

    async def send_error_notification(self, job_id: str, error_data: Dict[str, Any]):
        """Send error notification for a job."""
        await self.connection_manager.send_error_notification(job_id, error_data)

    async def broadcast_system_message(self, message_data: Dict[str, Any]):
        """Broadcast system message to all connections."""
        await self.connection_manager.broadcast_system_message(message_data)

    async def send_chat_message(self, session_id: str, message_data: Dict[str, Any]):
        """Send chat message to all connections in a session (Story 4.6)."""
        await self.connection_manager.send_chat_message(session_id, message_data)

    async def send_typing_indicator(self, session_id: str, user_id: str, is_typing: bool):
        """Send typing indicator to chat session (Story 4.6)."""
        await self.connection_manager.send_typing_indicator(session_id, user_id, is_typing)

    async def send_chat_status(self, session_id: str, status_data: Dict[str, Any]):
        """Send status update to chat session (Story 4.6)."""
        await self.connection_manager.send_chat_status(session_id, status_data)

    async def send_transcript_chunk(self, job_id: str, chunk_data: Dict[str, Any]):
        """Send live transcript chunk to job connections (Task 14.3)."""
        await self.connection_manager.send_transcript_chunk(job_id, chunk_data)

    async def send_transcript_complete(self, job_id: str, transcript_data: Dict[str, Any]):
        """Send complete transcript data to job connections."""
        await self.connection_manager.send_transcript_complete(job_id, transcript_data)

    def enable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Enable transcript streaming for a connection."""
        self.connection_manager.enable_transcript_streaming(websocket, job_id)

    def disable_transcript_streaming(self, websocket: WebSocket, job_id: str):
        """Disable transcript streaming for a connection."""
        self.connection_manager.disable_transcript_streaming(websocket, job_id)

    def get_stats(self) -> Dict[str, Any]:
        """Get WebSocket connection statistics."""
        return self.connection_manager.get_connection_stats()

    def update_job_progress(self, job_id: str, progress_data: ProgressData):
        """Update and track job progress."""
        self.connection_manager.update_job_progress(job_id, progress_data)

    def estimate_remaining_time(self, job_id: str, current_percentage: float) -> Optional[float]:
        """Estimate remaining processing time."""
        return self.connection_manager.estimate_remaining_time(job_id, current_percentage)

    def record_job_completion(self, job_id: str):
        """Record job completion for time estimation."""
        self.connection_manager.record_job_completion(job_id)

    async def _heartbeat_loop(self):
        """Background task to send periodic heartbeats.

        Each round sleeps, sends a heartbeat to every connection and then
        cleans up stale connections. The loop ends on cancellation; any other
        error is logged and the loop continues with the next round.
        """
        while True:
            try:
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)
                await self.connection_manager.send_heartbeat()
                await self.connection_manager.cleanup_stale_connections()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in heartbeat loop: {e}")
|
|
|
|
|
|
# Global WebSocket manager instance (WebSocketManager is a singleton, so
# constructing it again elsewhere yields this same object).
websocket_manager = WebSocketManager()