"""
Unit tests for WebSocket functionality.
"""

import pytest
import time
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime

from src.api.websocket_enhanced import (
    ProcessingStage,
    JobMetrics,
    JobManager,
    AdvancedProgressTracker,
    BatchProgressTracker,
    create_enhanced_websocket_handlers
)


def _collect_registered_handlers(socketio):
    """Map event names to the handlers registered on a mocked ``socketio``.

    Two registration styles are recognised:

    * direct calls, ``socketio.on('event', handler)`` -- the handler is the
      second positional argument of the recorded ``on`` call;
    * decorator use, ``@socketio.on('event')`` -- the mock's ``on('event')``
      returns the shared ``socketio.on.return_value`` mock, which is then
      called with the handler function.

    Args:
        socketio: A ``Mock`` that was passed to the handler factory.

    Returns:
        dict mapping event name -> handler callable.
    """
    handlers = {}
    # All decorator-style registrations go through the same shared
    # ``return_value`` mock, in the same order as the single-argument
    # ``on(...)`` calls, so the two call lists can be consumed in lockstep.
    decorated = iter(socketio.on.return_value.call_args_list)
    for registration in socketio.on.call_args_list:
        args = registration[0]
        if not args:
            continue
        event_name = args[0]
        if len(args) > 1:
            handler = args[1]
        else:
            decorator_call = next(decorated, None)
            if decorator_call is None or not decorator_call[0]:
                continue
            handler = decorator_call[0][0]
        handlers[event_name] = handler
    return handlers


class TestProcessingStage:
    """Test ProcessingStage enum."""

    def test_stage_properties(self):
        """Test stage properties."""
        stage = ProcessingStage.TRANSCRIPTION
        assert stage.stage_name == 'transcription'
        assert stage.start_pct == 10
        assert stage.end_pct == 50

    def test_calculate_overall_progress(self):
        """Test overall progress calculation."""
        stage = ProcessingStage.TRANSCRIPTION
        # 0% of stage
        assert stage.calculate_overall_progress(0) == 10
        # 50% of stage
        assert stage.calculate_overall_progress(50) == 30
        # 100% of stage
        assert stage.calculate_overall_progress(100) == 50

    def test_all_stages_coverage(self):
        """Test that stages start at 0%, end at 100%, and are ordered."""
        stages = list(ProcessingStage)

        # Check first stage starts at 0
        assert stages[0].start_pct == 0

        # Check last stage ends at 100
        assert stages[-1].end_pct == 100

        # Check each stage ends no later than the next one starts, i.e. the
        # stages are ordered and non-overlapping.
        # NOTE(review): ``<=`` tolerates gaps between stages; tighten to
        # ``==`` if stages are meant to be strictly contiguous.
        for current, following in zip(stages, stages[1:]):
            assert current.end_pct <= following.start_pct


class TestJobMetrics:
    """Test JobMetrics dataclass."""

    def test_job_metrics_creation(self):
        """Test creating job metrics."""
        metrics = JobMetrics(
            job_id='test-job',
            start_time=time.time(),
            current_stage='transcription',
            overall_progress=25.0,
            stage_progress=50.0
        )

        assert metrics.job_id == 'test-job'
        assert metrics.current_stage == 'transcription'
        assert metrics.overall_progress == 25.0
        assert metrics.stage_progress == 50.0
        # Defaults for fields not passed above.
        assert metrics.files_processed == 0
        assert metrics.total_files == 1

    def test_job_metrics_to_dict(self):
        """Test converting metrics to dictionary."""
        start_time = time.time()
        metrics = JobMetrics(
            job_id='test-job',
            start_time=start_time,
            current_stage='detection',
            overall_progress=60.0,
            stage_progress=75.0,
            words_detected=10,
            words_censored=8
        )

        data = metrics.to_dict()

        assert data['job_id'] == 'test-job'
        assert data['current_stage'] == 'detection'
        assert data['overall_progress'] == 60.0
        assert data['words_detected'] == 10
        assert data['words_censored'] == 8
        # Derived/generated fields are only checked for presence.
        assert 'elapsed_time' in data
        assert 'timestamp' in data


class TestJobManager:
    """Test JobManager class."""

    def test_create_job(self):
        """Test creating a new job."""
        manager = JobManager()

        # Create job with auto-generated ID
        job_id = manager.create_job()
        assert job_id is not None
        assert len(manager.jobs) == 1

        # Create job with specific ID
        custom_id = 'custom-job-id'
        job_id2 = manager.create_job(custom_id)
        assert job_id2 == custom_id
        assert len(manager.jobs) == 2

    def test_update_job(self):
        """Test updating job metrics."""
        manager = JobManager()
        job_id = manager.create_job()

        # Update job
        updated = manager.update_job(
            job_id,
            current_stage='detection',
            overall_progress=50.0,
            words_detected=5
        )

        assert updated is not None
        assert updated.current_stage == 'detection'
        assert updated.overall_progress == 50.0
        assert updated.words_detected == 5

    def test_update_nonexistent_job(self):
        """Test updating non-existent job."""
        manager = JobManager()
        result = manager.update_job('nonexistent', overall_progress=50.0)
        assert result is None

    def test_get_job(self):
        """Test getting job metrics."""
        manager = JobManager()
        job_id = manager.create_job()

        job = manager.get_job(job_id)
        assert job is not None
        assert job.job_id == job_id

        # Get non-existent job
        assert manager.get_job('nonexistent') is None

    def test_remove_job(self):
        """Test removing a job."""
        manager = JobManager()
        job_id = manager.create_job()

        # Remove existing job
        assert manager.remove_job(job_id) is True
        assert len(manager.jobs) == 0

        # Remove non-existent job
        assert manager.remove_job('nonexistent') is False

    def test_get_active_jobs(self):
        """Test getting all active jobs."""
        manager = JobManager()

        # Create multiple jobs
        for _ in range(3):
            manager.create_job()

        active = manager.get_active_jobs()
        assert len(active) == 3
        assert all(isinstance(job, JobMetrics) for job in active)

    def test_estimated_time_calculation(self):
        """Test estimated time remaining calculation."""
        manager = JobManager()
        job_id = manager.create_job()

        # Let some wall-clock time elapse so the estimate has a non-zero
        # elapsed duration to extrapolate from.
        time.sleep(0.1)
        manager.update_job(job_id, overall_progress=25.0)

        job = manager.get_job(job_id)
        assert job.estimated_time_remaining > 0


class TestAdvancedProgressTracker:
    """Test AdvancedProgressTracker class."""

    def test_tracker_initialization(self):
        """Test tracker initialization."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(
            socketio=socketio,
            job_manager=manager,
            job_id=job_id,
            debug_mode=True,
            emit_interval=1.0
        )

        assert tracker.job_id == job_id
        assert tracker.debug_mode is True
        assert tracker.emit_interval == 1.0
        assert tracker.current_stage == ProcessingStage.INITIALIZING

    def test_change_stage(self):
        """Test changing processing stage."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = AdvancedProgressTracker(socketio, manager, job_id)

        # Change to transcription stage
        tracker.change_stage(ProcessingStage.TRANSCRIPTION, "Starting transcription")

        job = manager.get_job(job_id)
        assert job.current_stage == 'transcription'
        assert job.overall_progress == ProcessingStage.TRANSCRIPTION.start_pct

        # Verify emit was called
        socketio.emit.assert_called()

    def test_update_stage_progress(self):
        """Test updating progress within a stage."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = AdvancedProgressTracker(socketio, manager, job_id)

        tracker.change_stage(ProcessingStage.DETECTION)

        # Update progress
        tracker.update_stage_progress(
            percent=50.0,
            message="Detecting words...",
            details={'words_detected': 10}
        )

        job = manager.get_job(job_id)
        assert job.stage_progress == 50.0
        assert job.words_detected == 10
        assert job.overall_progress == ProcessingStage.DETECTION.calculate_overall_progress(50.0)

    def test_emit_throttling(self):
        """Test that emit is throttled."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(
            socketio=socketio,
            job_manager=manager,
            job_id=job_id,
            emit_interval=1.0
        )

        # First update should emit
        tracker.update_stage_progress(10.0)
        assert socketio.emit.call_count == 1

        # Immediate second update should not emit (throttled)
        tracker.update_stage_progress(20.0)
        assert socketio.emit.call_count == 1

        # After waiting longer than emit_interval, should emit again.
        # NOTE(review): real sleep keeps this test slow (~1s); consider
        # patching the clock used by the tracker once that is confirmed.
        time.sleep(1.1)
        tracker.update_stage_progress(30.0)
        assert socketio.emit.call_count == 2

    def test_emit_completed(self):
        """Test emitting completion event."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = AdvancedProgressTracker(socketio, manager, job_id)

        # Emit completion
        tracker.emit_completed(
            output_filename='output.mp3',
            summary={'total_words': 10, 'duration': 30.0}
        )

        # Verify job was removed
        assert manager.get_job(job_id) is None

        # Verify emit was called
        socketio.emit.assert_called()
        call_args = socketio.emit.call_args[0]
        assert call_args[0] == 'job_completed'

    def test_emit_error(self):
        """Test emitting error event."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = AdvancedProgressTracker(socketio, manager, job_id)

        # Emit recoverable error
        tracker.emit_error(
            error_type='network_timeout',
            error_message='Connection timeout',
            recoverable=True,
            retry_suggestion='Check network and retry'
        )

        # Job should still exist (recoverable)
        assert manager.get_job(job_id) is not None

        # Emit non-recoverable error
        tracker.emit_error(
            error_type='file_corrupted',
            error_message='File is corrupted',
            recoverable=False
        )

        # Job should be removed (non-recoverable)
        assert manager.get_job(job_id) is None


class TestBatchProgressTracker:
    """Test BatchProgressTracker class."""

    def test_batch_tracker_initialization(self):
        """Test batch tracker initialization."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = BatchProgressTracker(
            socketio=socketio,
            job_manager=manager,
            job_id=job_id,
            total_files=5,
            debug_mode=False
        )

        assert tracker.total_files == 5
        assert tracker.current_file_index == 0
        assert tracker.file_progress_weight == 0.2  # 1/5

        job = manager.get_job(job_id)
        assert job.total_files == 5

    def test_start_file(self):
        """Test starting a new file in batch."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = BatchProgressTracker(socketio, manager, job_id, total_files=3)

        # Start first file
        tracker.start_file(0, 'file1.mp3')

        job = manager.get_job(job_id)
        assert job.files_processed == 0

        # Check emit was called with correct (1-based) message
        socketio.emit.assert_called()
        call_args = socketio.emit.call_args[0]
        data = call_args[1]
        assert 'Processing file 1/3' in data['message']

    def test_batch_progress_calculation(self):
        """Test batch progress calculation."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = BatchProgressTracker(socketio, manager, job_id, total_files=4)

        # First file at 50%
        tracker.current_file_index = 0
        progress = tracker._calculate_batch_progress(50.0)
        assert progress == 12.5  # (0 * 100 + 50) / 4

        # Second file at 75%
        tracker.current_file_index = 1
        progress = tracker._calculate_batch_progress(75.0)
        assert progress == 43.75  # (1 * 100 + 75) / 4

        # Last file at 100%
        tracker.current_file_index = 3
        progress = tracker._calculate_batch_progress(100.0)
        assert progress == 100.0  # (3 * 100 + 100) / 4


class TestWebSocketHandlers:
    """Test WebSocket event handlers."""

    @patch('src.api.websocket_enhanced.request')
    def test_handle_connect(self, mock_request):
        """Test that a client connection handler is registered."""
        mock_request.sid = 'test-sid'
        socketio = Mock()
        manager = JobManager()

        create_enhanced_websocket_handlers(socketio, manager)

        # Inspect the registrations the mock already recorded. Patching
        # ``socketio.on`` at this point would replace it with a fresh mock
        # whose call list is empty, so nothing would be checked.
        handlers = _collect_registered_handlers(socketio)

        assert 'connect' in handlers
        assert callable(handlers['connect'])

    def test_job_room_management(self):
        """Test joining and leaving job rooms."""
        socketio = Mock()
        manager = JobManager()

        # Handler registration itself must not raise.
        create_enhanced_websocket_handlers(socketio, manager)

        # Room join/leave behaviour needs a real Socket.IO test client
        # (integration test); skip rather than pass without assertions.
        pytest.skip("requires a Socket.IO test client")

    def test_get_active_jobs_handler(self):
        """Test getting active jobs."""
        socketio = Mock()
        manager = JobManager()

        # Create some jobs
        manager.create_job()
        manager.create_job()

        # Handler registration itself must not raise.
        create_enhanced_websocket_handlers(socketio, manager)

        # Invoking the handler needs a real Socket.IO test client
        # (integration test); skip rather than pass without assertions.
        pytest.skip("requires a Socket.IO test client")


if __name__ == '__main__':
    pytest.main([__file__, '-v'])