clean-tracks/tests/integration/test_websocket_integration.py

442 lines
14 KiB
Python

"""
Integration tests for WebSocket real-time features.
"""
import pytest
import time
import json
from unittest.mock import Mock, patch, MagicMock
from threading import Thread
from src.api import create_app
from src.api.websocket_enhanced import ProcessingStage, JobManager
class TestWebSocketRealTimeUpdates:
    """Test real-time WebSocket updates during processing.

    Each test drives the ``socketio_client`` fixture: it joins a job or batch
    room, emits the events a processing run would produce, and inspects the
    messages returned by ``get_received()``.
    """

    def test_full_processing_workflow(self, socketio_client):
        """Test complete processing workflow with WebSocket updates."""
        job_id = 'test-job-workflow'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear initial messages

        # Simulate processing stages as (stage, overall progress %) pairs.
        stages = [
            ('initializing', 5),
            ('loading', 10),
            ('transcription', 30),
            ('detection', 60),
            ('censoring', 80),
            ('finalizing', 95),
        ]
        for stage, progress in stages:
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'stage': stage,
                'overall_progress': progress,
                'message': f'Processing: {stage}',
            })
            # Small delay between stages. NOTE(review): presumably spaces the
            # updates beyond any progress-throttling window — confirm against
            # the throttle configuration.
            time.sleep(0.1)

        # Complete the job
        socketio_client.emit('job_completed', {
            'job_id': job_id,
            'output_file': 'output.mp3',
            'summary': {
                'words_detected': 10,
                'words_censored': 8,
                'duration': 30.5,
            },
        })

        # Get all received messages
        received = socketio_client.get_received()

        # Every stage update must come through (extra messages are tolerated).
        progress_messages = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]
        assert len(progress_messages) >= len(stages)

        # Exactly one completion event, carrying this job's id.
        completion_messages = [
            msg for msg in received
            if msg['name'] == 'job_completed'
        ]
        assert len(completion_messages) == 1
        assert completion_messages[0]['args'][0]['job_id'] == job_id

    def test_concurrent_job_updates(self, socketio_client):
        """Test handling concurrent job updates."""
        job_ids = ['job1', 'job2', 'job3']

        # Join multiple job rooms
        for job_id in job_ids:
            socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear initial messages

        # Send one update per job, each with a distinct progress value.
        for i, job_id in enumerate(job_ids):
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'overall_progress': (i + 1) * 25,
                'stage': 'transcription',
            })

        received = socketio_client.get_received()

        # Should receive exactly one update per job
        progress_updates = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]
        assert len(progress_updates) == len(job_ids)

        # Verify each job's update is represented
        received_job_ids = [
            msg['args'][0]['job_id']
            for msg in progress_updates
        ]
        assert set(received_job_ids) == set(job_ids)

    def test_error_handling_in_websocket(self, socketio_client):
        """Test error handling through WebSocket."""
        job_id = 'error-job'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()

        # Send error
        socketio_client.emit('job_error', {
            'job_id': job_id,
            'error_type': 'transcription_failed',
            'error_message': 'Failed to transcribe audio',
            'recoverable': False,
        })

        received = socketio_client.get_received()

        # Should receive exactly one error message with the original type
        error_messages = [
            msg for msg in received
            if msg['name'] == 'job_error'
        ]
        assert len(error_messages) == 1
        assert error_messages[0]['args'][0]['error_type'] == 'transcription_failed'

    def test_batch_processing_updates(self, socketio_client):
        """Test batch processing with WebSocket updates."""
        batch_id = 'batch-123'
        job_ids = ['batch-job1', 'batch-job2', 'batch-job3']
        total_files = len(job_ids)

        # Join batch room
        socketio_client.emit('join_batch', {'batch_id': batch_id})
        socketio_client.get_received()

        # Process each file in batch
        for i, job_id in enumerate(job_ids):
            # File start
            socketio_client.emit('batch_file_start', {
                'batch_id': batch_id,
                'job_id': job_id,
                'file_index': i,
                'total_files': total_files,
                'filename': f'file{i + 1}.mp3',
            })

            # File progress; overall progress spreads per-file progress
            # evenly across the whole batch.
            for progress in [25, 50, 75, 100]:
                socketio_client.emit('batch_file_progress', {
                    'batch_id': batch_id,
                    'job_id': job_id,
                    'file_progress': progress,
                    'overall_progress': (i * 100 + progress) / total_files,
                })
                time.sleep(0.05)

            # File complete
            socketio_client.emit('batch_file_complete', {
                'batch_id': batch_id,
                'job_id': job_id,
                'file_index': i,
                'results': {
                    'words_detected': 5,
                    'words_censored': 4,
                },
            })

        # Batch complete
        socketio_client.emit('batch_complete', {
            'batch_id': batch_id,
            'total_files': total_files,
            'successful': total_files,
            'failed': 0,
        })

        received = socketio_client.get_received()

        # Verify batch messages
        batch_complete = [
            msg for msg in received
            if msg['name'] == 'batch_complete'
        ]
        assert len(batch_complete) == 1
        # Compare against the count actually sent rather than a hard-coded
        # number, so the assertion tracks changes to job_ids.
        assert batch_complete[0]['args'][0]['successful'] == total_files
class TestWebSocketJobManagement:
    """Test WebSocket job management features."""

    def test_get_active_jobs(self, socketio_client):
        """Test getting list of active jobs."""
        # Request active jobs
        socketio_client.emit('get_active_jobs', {})

        received = socketio_client.get_received()

        # Should receive active jobs list
        active_jobs_msgs = [
            msg for msg in received
            if msg['name'] == 'active_jobs'
        ]
        assert len(active_jobs_msgs) >= 1

        # The assert above guarantees at least one reply, so inspect it
        # unconditionally rather than behind a redundant guard.
        jobs = active_jobs_msgs[0]['args'][0]['jobs']
        assert isinstance(jobs, list)

    def test_cancel_job_via_websocket(self, socketio_client):
        """Test canceling a job through WebSocket."""
        job_id = 'cancel-test-job'

        # Start a job
        socketio_client.emit('join_job', {'job_id': job_id})

        # Send some progress
        socketio_client.emit('job_progress', {
            'job_id': job_id,
            'overall_progress': 30,
            'stage': 'transcription',
        })

        # Cancel the job
        socketio_client.emit('cancel_job', {'job_id': job_id})

        received = socketio_client.get_received()

        # Should receive cancellation confirmation for this job
        cancelled_msgs = [
            msg for msg in received
            if msg['name'] == 'job_cancelled'
        ]
        assert len(cancelled_msgs) >= 1
        assert cancelled_msgs[0]['args'][0]['job_id'] == job_id

    def test_reconnection_handling(self, socketio_client):
        """Test handling of client reconnection."""
        job_id = 'reconnect-test'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})

        # Simulate disconnect and reconnect
        socketio_client.disconnect()
        time.sleep(0.1)
        socketio_client.connect()
        # The rest of the test is only meaningful if the reconnect succeeded.
        assert socketio_client.is_connected()

        # Rejoin job room after reconnection (room membership does not
        # survive the disconnect).
        socketio_client.emit('join_job', {'job_id': job_id})

        # Should be able to receive updates
        socketio_client.emit('job_progress', {
            'job_id': job_id,
            'overall_progress': 50,
            'stage': 'detection',
        })

        received = socketio_client.get_received()

        # Verify we can still receive updates after reconnection
        progress_msgs = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]
        assert len(progress_msgs) >= 1
class TestWebSocketPerformance:
    """Test WebSocket performance and throttling."""

    def test_progress_throttling(self, socketio_client):
        """Test that progress updates are throttled."""
        job_id = 'throttle-test'

        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear

        # Send many rapid updates with no delay between them.
        for i in range(100):
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'overall_progress': i,
                'stage': 'transcription',
            })

        received = socketio_client.get_received()

        # Should receive fewer messages due to throttling
        progress_msgs = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]

        # Exact number depends on throttle settings,
        # but should be significantly less than 100.
        assert len(progress_msgs) < 50

    def test_multiple_clients_same_job(self, app):
        """Test multiple clients monitoring same job.

        Every client created here is disconnected in ``finally``, so a failed
        assertion (or a failure while creating later clients) does not leak
        the connections that were already opened.
        """
        socketio = app.socketio
        job_id = 'multi-client-job'
        clients = []

        try:
            # Create multiple test clients, one at a time so that cleanup
            # covers every client that was successfully created.
            for _ in range(5):
                clients.append(socketio.test_client(app))

            # All clients join same job
            for client in clients:
                client.emit('join_job', {'job_id': job_id})

            # Clear initial messages
            for client in clients:
                client.get_received()

            # Send progress update from a single client
            clients[0].emit('job_progress', {
                'job_id': job_id,
                'overall_progress': 75,
                'stage': 'censoring',
            })

            # All clients in the room should receive the update
            for index, client in enumerate(clients):
                received = client.get_received()
                progress_msgs = [
                    msg for msg in received
                    if msg['name'] == 'job_progress'
                ]
                assert len(progress_msgs) >= 1, (
                    f'client {index} did not receive the job_progress update'
                )
        finally:
            # Cleanup
            for client in clients:
                client.disconnect()
class TestWebSocketSecurity:
    """Test WebSocket security features.

    Room authorization and payload validation are not pinned down by the
    server yet. When the server demonstrably did not enforce a check, these
    tests skip with an explicit reason instead of reporting a pass that
    could never have failed.
    """

    def test_unauthorized_room_access(self, socketio_client):
        """Test that clients can't join unauthorized rooms.

        Passes when no ``joined_job`` confirmation is sent for an invalid
        token; skips when one is, since that means authorization is not
        enforced yet.
        """
        # Try to join a room without proper authorization
        socketio_client.emit('join_job', {
            'job_id': 'unauthorized-job',
            'token': 'invalid-token',
        })

        received = socketio_client.get_received()

        join_confirmations = [
            msg for msg in received
            if msg['name'] == 'joined_job'
        ]

        # A bare length check here would be vacuous. A confirmation for an
        # invalid token means authorization is missing, so report that
        # honestly as a skip.
        if join_confirmations:
            pytest.skip('join_job does not enforce room authorization yet')

    def test_rate_limiting_websocket(self, socketio_client):
        """Test WebSocket message rate limiting.

        Whatever the rate-limiting policy, the server must never answer a
        ``get_active_jobs`` request more than once.
        """
        request_count = 1000

        # Discard anything queued before the burst (e.g. connection-time
        # events) so that only replies to this burst are counted.
        socketio_client.get_received()

        # Send many messages rapidly
        for _ in range(request_count):
            socketio_client.emit('get_active_jobs', {})

        received = socketio_client.get_received()

        # With rate limiting active, fewer than request_count replies are
        # expected; the exact number depends on the rate limiting
        # implementation.
        replies = [
            msg for msg in received
            if msg['name'] == 'active_jobs'
        ]
        assert len(replies) <= request_count

    def test_message_validation(self, socketio_client):
        """Test that invalid messages are rejected.

        A rejection is either an ``error`` event or the malformed update not
        being relayed back as ``job_progress``. Skips when the payload was
        relayed without any error, since validation is then not implemented.
        """
        # Discard anything queued earlier so the checks below only see the
        # response to the malformed message.
        socketio_client.get_received()

        # Send invalid message format: required fields such as job_id are missing.
        socketio_client.emit('job_progress', {
            'invalid_field': 'test',
        })

        received = socketio_client.get_received()

        error_msgs = [
            msg for msg in received
            if msg['name'] == 'error'
        ]
        relayed_msgs = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]

        if relayed_msgs and not error_msgs:
            pytest.skip('job_progress payloads are not validated yet')
class TestWebSocketMetrics:
    """Test WebSocket metrics and monitoring."""

    def test_connection_metrics(self, app):
        """Test tracking of WebSocket connections.

        Opens several simultaneous test clients and checks that every one is
        connected. All created clients are disconnected in ``finally``.
        """
        socketio = app.socketio
        clients = []

        try:
            # Create multiple connections
            for _ in range(10):
                clients.append(socketio.test_client(app))

            # Every client should hold a live connection at the same time.
            for index, client in enumerate(clients):
                assert client.is_connected(), f'client {index} is not connected'

            # TODO: assert on connection metrics once the server exposes them.
        finally:
            # Cleanup
            for client in clients:
                client.disconnect()

    def test_message_metrics(self, socketio_client):
        """Test tracking of message metrics.

        Sends a mix of event types. Until message metrics are exposed, the
        observable check is that the ``get_active_jobs`` request in the mix
        was answered with an ``active_jobs`` event.
        """
        # Send various message types
        message_types = [
            ('join_job', {'job_id': 'metrics-test'}),
            ('get_active_jobs', {}),
            ('job_progress', {
                'job_id': 'metrics-test',
                'overall_progress': 50,
            }),
        ]

        for msg_type, data in message_types:
            socketio_client.emit(msg_type, data)

        received = socketio_client.get_received()

        # A length check alone would be vacuous; require the reply that
        # test_get_active_jobs already expects the server to send.
        # TODO: assert on message counters once the server exposes them.
        assert any(msg['name'] == 'active_jobs' for msg in received)
if __name__ == '__main__':
    # Propagate pytest's exit status so that running this file directly exits
    # non-zero when tests fail, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))