"""
Integration tests for WebSocket real-time features.
"""

import pytest
import time
import json
from unittest.mock import Mock, patch, MagicMock
from threading import Thread

from src.api import create_app
from src.api.websocket_enhanced import ProcessingStage, JobManager

# NOTE(review): json, Mock, patch, MagicMock, Thread, create_app,
# ProcessingStage and JobManager are not referenced by the tests below;
# they are kept as-is rather than removed.


def _messages_named(received, event_name):
    """Return the messages in *received* whose ``'name'`` equals *event_name*.

    Args:
        received: list of message dicts as returned by the Socket.IO test
            client's ``get_received()``; these tests read each message's
            ``'name'`` and ``'args'`` keys.
        event_name: event name to keep.

    Returns:
        A new list containing the matching messages, in their original order.
    """
    return [msg for msg in received if msg['name'] == event_name]


class TestWebSocketRealTimeUpdates:
    """Test real-time WebSocket updates during processing."""

    def test_full_processing_workflow(self, socketio_client):
        """Test complete processing workflow with WebSocket updates."""
        job_id = 'test-job-workflow'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear initial messages

        # Simulate processing stages
        stages = [
            ('initializing', 5),
            ('loading', 10),
            ('transcription', 30),
            ('detection', 60),
            ('censoring', 80),
            ('finalizing', 95),
        ]

        for stage, progress in stages:
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'stage': stage,
                'overall_progress': progress,
                'message': f'Processing: {stage}',
            })
            # Small delay between stages; presumably keeps each update
            # outside any server-side progress throttle window (see
            # test_progress_throttling) - confirm against the server.
            time.sleep(0.1)

        # Complete the job
        socketio_client.emit('job_completed', {
            'job_id': job_id,
            'output_file': 'output.mp3',
            'summary': {
                'words_detected': 10,
                'words_censored': 8,
                'duration': 30.5,
            },
        })

        # Get all received messages
        received = socketio_client.get_received()

        # Verify we received progress updates
        progress_messages = _messages_named(received, 'job_progress')
        assert len(progress_messages) >= len(stages)

        # Verify completion message
        completion_messages = _messages_named(received, 'job_completed')
        assert len(completion_messages) == 1
        assert completion_messages[0]['args'][0]['job_id'] == job_id

    def test_concurrent_job_updates(self, socketio_client):
        """Test handling concurrent job updates."""
        job_ids = ['job1', 'job2', 'job3']

        # Join multiple job rooms
        for job_id in job_ids:
            socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear initial messages

        # Send updates for all jobs
        for i, job_id in enumerate(job_ids):
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'overall_progress': (i + 1) * 25,
                'stage': 'transcription',
            })

        received = socketio_client.get_received()

        # Should receive updates for all jobs
        progress_updates = _messages_named(received, 'job_progress')
        assert len(progress_updates) == len(job_ids)

        # Verify each job's update
        received_job_ids = [msg['args'][0]['job_id'] for msg in progress_updates]
        assert set(received_job_ids) == set(job_ids)

    def test_error_handling_in_websocket(self, socketio_client):
        """Test error handling through WebSocket."""
        job_id = 'error-job'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()

        # Send error
        socketio_client.emit('job_error', {
            'job_id': job_id,
            'error_type': 'transcription_failed',
            'error_message': 'Failed to transcribe audio',
            'recoverable': False,
        })

        received = socketio_client.get_received()

        # Should receive error message
        error_messages = _messages_named(received, 'job_error')
        assert len(error_messages) == 1
        assert error_messages[0]['args'][0]['error_type'] == 'transcription_failed'

    def test_batch_processing_updates(self, socketio_client):
        """Test batch processing with WebSocket updates."""
        batch_id = 'batch-123'
        job_ids = ['batch-job1', 'batch-job2', 'batch-job3']

        # Join batch room
        socketio_client.emit('join_batch', {'batch_id': batch_id})
        socketio_client.get_received()

        # Process each file in batch
        for i, job_id in enumerate(job_ids):
            # File start
            socketio_client.emit('batch_file_start', {
                'batch_id': batch_id,
                'job_id': job_id,
                'file_index': i,
                'total_files': len(job_ids),
                'filename': f'file{i+1}.mp3',
            })

            # File progress
            for progress in [25, 50, 75, 100]:
                socketio_client.emit('batch_file_progress', {
                    'batch_id': batch_id,
                    'job_id': job_id,
                    'file_progress': progress,
                    'overall_progress': (i * 100 + progress) / len(job_ids),
                })
                time.sleep(0.05)

            # File complete
            socketio_client.emit('batch_file_complete', {
                'batch_id': batch_id,
                'job_id': job_id,
                'file_index': i,
                'results': {
                    'words_detected': 5,
                    'words_censored': 4,
                },
            })

        # Batch complete
        socketio_client.emit('batch_complete', {
            'batch_id': batch_id,
            'total_files': len(job_ids),
            'successful': len(job_ids),
            'failed': 0,
        })

        received = socketio_client.get_received()

        # Verify batch messages
        batch_complete = _messages_named(received, 'batch_complete')
        assert len(batch_complete) == 1
        assert batch_complete[0]['args'][0]['successful'] == len(job_ids)


class TestWebSocketJobManagement:
    """Test WebSocket job management features."""

    def test_get_active_jobs(self, socketio_client):
        """Test getting list of active jobs."""
        # Request active jobs
        socketio_client.emit('get_active_jobs', {})

        received = socketio_client.get_received()

        # Should receive active jobs list
        active_jobs_msgs = _messages_named(received, 'active_jobs')
        assert len(active_jobs_msgs) >= 1

        # The assertion above guarantees at least one message, so no
        # emptiness guard is needed before indexing.
        jobs = active_jobs_msgs[0]['args'][0]['jobs']
        assert isinstance(jobs, list)

    def test_cancel_job_via_websocket(self, socketio_client):
        """Test canceling a job through WebSocket."""
        job_id = 'cancel-test-job'

        # Start a job
        socketio_client.emit('join_job', {'job_id': job_id})

        # Send some progress
        socketio_client.emit('job_progress', {
            'job_id': job_id,
            'overall_progress': 30,
            'stage': 'transcription',
        })

        # Cancel the job
        socketio_client.emit('cancel_job', {'job_id': job_id})

        received = socketio_client.get_received()

        # Should receive cancellation confirmation
        cancelled_msgs = _messages_named(received, 'job_cancelled')
        assert len(cancelled_msgs) >= 1
        assert cancelled_msgs[0]['args'][0]['job_id'] == job_id

    def test_reconnection_handling(self, socketio_client):
        """Test handling of client reconnection."""
        job_id = 'reconnect-test'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})

        # Simulate disconnect and reconnect
        socketio_client.disconnect()
        time.sleep(0.1)
        socketio_client.connect()

        # Rejoin job room after reconnection
        socketio_client.emit('join_job', {'job_id': job_id})

        # Should be able to receive updates
        socketio_client.emit('job_progress', {
            'job_id': job_id,
            'overall_progress': 50,
            'stage': 'detection',
        })

        received = socketio_client.get_received()

        # Verify we can still receive updates after reconnection
        progress_msgs = _messages_named(received, 'job_progress')
        assert len(progress_msgs) >= 1


class TestWebSocketPerformance:
    """Test WebSocket performance and throttling."""

    def test_progress_throttling(self, socketio_client):
        """Test that progress updates are throttled."""
        job_id = 'throttle-test'

        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear

        # Send many rapid updates
        for i in range(100):
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'overall_progress': i,
                'stage': 'transcription',
            })

        received = socketio_client.get_received()

        # Should receive fewer messages due to throttling
        progress_msgs = _messages_named(received, 'job_progress')

        # Exact number depends on throttle settings
        # But should be significantly less than 100
        assert len(progress_msgs) < 50

    def test_multiple_clients_same_job(self, app):
        """Test multiple clients monitoring same job."""
        socketio = app.socketio
        job_id = 'multi-client-job'

        # Clients are created inside the try so that any already-created
        # client is disconnected even if a later creation or an assertion
        # fails.
        clients = []
        try:
            # Create multiple test clients
            for _ in range(5):
                clients.append(socketio.test_client(app))

            # All clients join same job
            for client in clients:
                client.emit('join_job', {'job_id': job_id})

            # Clear initial messages
            for client in clients:
                client.get_received()

            # Send progress update
            clients[0].emit('job_progress', {
                'job_id': job_id,
                'overall_progress': 75,
                'stage': 'censoring',
            })

            # All clients should receive the update
            for client in clients:
                received = client.get_received()
                progress_msgs = _messages_named(received, 'job_progress')
                assert len(progress_msgs) >= 1
        finally:
            # Cleanup
            for client in clients:
                client.disconnect()


class TestWebSocketSecurity:
    """Test WebSocket security features."""

    def test_unauthorized_room_access(self, socketio_client):
        """Test that clients can't join unauthorized rooms."""
        # Try to join a room without proper authorization
        socketio_client.emit('join_job', {
            'job_id': 'unauthorized-job',
            'token': 'invalid-token',
        })

        received = socketio_client.get_received()

        # Should receive error or no confirmation
        # Depends on security implementation
        join_confirmations = _messages_named(received, 'joined_job')

        # If security is implemented, should not join
        # This test depends on actual security implementation
        # NOTE(review): this assertion is always true; tighten it to
        # `== 0` once the server's authorization behaviour is defined.
        assert len(join_confirmations) >= 0

    def test_rate_limiting_websocket(self, socketio_client):
        """Test WebSocket message rate limiting."""
        # Send many messages rapidly
        for _ in range(1000):
            socketio_client.emit('get_active_jobs', {})

        received = socketio_client.get_received()

        # Should not process all messages if rate limiting is active
        # Exact behavior depends on rate limiting implementation
        assert len(received) <= 1000

    def test_message_validation(self, socketio_client):
        """Test that invalid messages are rejected."""
        # Send invalid message format
        socketio_client.emit('job_progress', {
            'invalid_field': 'test',
            # Missing required fields
        })

        received = socketio_client.get_received()

        # Should receive error or no processing
        error_msgs = _messages_named(received, 'error')

        # Depends on validation implementation
        # NOTE(review): this assertion is always true; replace it with a
        # concrete expectation once validation behaviour is defined.
        assert len(error_msgs) >= 0


class TestWebSocketMetrics:
    """Test WebSocket metrics and monitoring."""

    def test_connection_metrics(self, app):
        """Test tracking of WebSocket connections."""
        socketio = app.socketio

        clients = []
        try:
            # Create multiple connections
            for _ in range(10):
                clients.append(socketio.test_client(app))

            # Minimal non-vacuous check: every test client must actually be
            # connected before any metrics could count it.
            assert all(client.is_connected() for client in clients)

            # Get metrics (if implemented)
            # This depends on actual metrics implementation
        finally:
            # Cleanup
            for client in clients:
                client.disconnect()

    def test_message_metrics(self, socketio_client):
        """Test tracking of message metrics."""
        # Send various message types
        message_types = [
            ('join_job', {'job_id': 'metrics-test'}),
            ('get_active_jobs', {}),
            ('job_progress', {
                'job_id': 'metrics-test',
                'overall_progress': 50,
            }),
        ]

        for msg_type, data in message_types:
            socketio_client.emit(msg_type, data)

        # Metrics should be tracked (implementation dependent)
        received = socketio_client.get_received()
        # NOTE(review): always true; replace with a real metrics check once
        # a metrics API exists.
        assert len(received) >= 0


if __name__ == '__main__':
    pytest.main([__file__, '-v'])