# Source listing: 442 lines, 14 KiB, Python
"""
Integration tests for WebSocket real-time features.
"""
|
|
|
|
import pytest
|
|
import time
|
|
import json
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from threading import Thread
|
|
|
|
from src.api import create_app
|
|
from src.api.websocket_enhanced import ProcessingStage, JobManager
|
|
|
|
|
|
class TestWebSocketRealTimeUpdates:
    """Test real-time WebSocket updates during processing.

    Each test drives the ``socketio_client`` fixture (defined outside this
    file) by emitting events and then inspecting ``get_received()``, whose
    items are read here as dicts with ``'name'`` and ``'args'`` keys.
    """

    def test_full_processing_workflow(self, socketio_client):
        """Test complete processing workflow with WebSocket updates."""
        job_id = 'test-job-workflow'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear initial messages

        # Simulate processing stages as (stage name, overall progress %).
        stages = [
            ('initializing', 5),
            ('loading', 10),
            ('transcription', 30),
            ('detection', 60),
            ('censoring', 80),
            ('finalizing', 95)
        ]

        for stage, progress in stages:
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'stage': stage,
                'overall_progress': progress,
                'message': f'Processing: {stage}'
            })
            # Small delay between stages. NOTE(review): presumably this keeps
            # the updates from being collapsed by progress throttling.
            time.sleep(0.1)

        # Complete the job
        socketio_client.emit('job_completed', {
            'job_id': job_id,
            'output_file': 'output.mp3',
            'summary': {
                'words_detected': 10,
                'words_censored': 8,
                'duration': 30.5
            }
        })

        # Get all received messages
        received = socketio_client.get_received()

        # Verify we received progress updates
        progress_messages = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]
        assert len(progress_messages) >= len(stages)

        # Verify completion message
        completion_messages = [
            msg for msg in received
            if msg['name'] == 'job_completed'
        ]
        assert len(completion_messages) == 1
        assert completion_messages[0]['args'][0]['job_id'] == job_id

    def test_concurrent_job_updates(self, socketio_client):
        """Test handling concurrent job updates."""
        job_ids = ['job1', 'job2', 'job3']

        # Join multiple job rooms
        for job_id in job_ids:
            socketio_client.emit('join_job', {'job_id': job_id})

        socketio_client.get_received()  # Clear initial messages

        # Send one update per job, each with a distinct progress value.
        for i, job_id in enumerate(job_ids):
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'overall_progress': (i + 1) * 25,
                'stage': 'transcription'
            })

        received = socketio_client.get_received()

        # Should receive exactly one update per job
        progress_updates = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]
        assert len(progress_updates) == len(job_ids)

        # Verify each job's update arrived (length + set equality means
        # every job appears exactly once).
        received_job_ids = [
            msg['args'][0]['job_id']
            for msg in progress_updates
        ]
        assert set(received_job_ids) == set(job_ids)

    def test_error_handling_in_websocket(self, socketio_client):
        """Test error handling through WebSocket."""
        job_id = 'error-job'

        # Join job room
        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()

        # Send error
        socketio_client.emit('job_error', {
            'job_id': job_id,
            'error_type': 'transcription_failed',
            'error_message': 'Failed to transcribe audio',
            'recoverable': False
        })

        received = socketio_client.get_received()

        # Should receive exactly one error message
        error_messages = [
            msg for msg in received
            if msg['name'] == 'job_error'
        ]
        assert len(error_messages) == 1
        assert error_messages[0]['args'][0]['error_type'] == 'transcription_failed'

    def test_batch_processing_updates(self, socketio_client):
        """Test batch processing with WebSocket updates."""
        batch_id = 'batch-123'
        job_ids = ['batch-job1', 'batch-job2', 'batch-job3']

        # Join batch room
        socketio_client.emit('join_batch', {'batch_id': batch_id})
        socketio_client.get_received()

        # Process each file in batch
        for i, job_id in enumerate(job_ids):
            # File start
            socketio_client.emit('batch_file_start', {
                'batch_id': batch_id,
                'job_id': job_id,
                'file_index': i,
                'total_files': len(job_ids),
                'filename': f'file{i+1}.mp3'
            })

            # File progress; overall progress averages completed files
            # (100 each) plus the current file's progress across the batch.
            for progress in [25, 50, 75, 100]:
                socketio_client.emit('batch_file_progress', {
                    'batch_id': batch_id,
                    'job_id': job_id,
                    'file_progress': progress,
                    'overall_progress': (i * 100 + progress) / len(job_ids)
                })
                time.sleep(0.05)

            # File complete
            socketio_client.emit('batch_file_complete', {
                'batch_id': batch_id,
                'job_id': job_id,
                'file_index': i,
                'results': {
                    'words_detected': 5,
                    'words_censored': 4
                }
            })

        # Batch complete
        socketio_client.emit('batch_complete', {
            'batch_id': batch_id,
            'total_files': len(job_ids),
            'successful': len(job_ids),
            'failed': 0
        })

        received = socketio_client.get_received()

        # Verify batch messages
        batch_complete = [
            msg for msg in received
            if msg['name'] == 'batch_complete'
        ]
        assert len(batch_complete) == 1
        # Derive the expectation from the batch itself rather than a literal,
        # so it stays correct if job_ids changes.
        assert batch_complete[0]['args'][0]['successful'] == len(job_ids)
|
|
|
|
|
|
class TestWebSocketJobManagement:
    """Exercise job-management events sent over the WebSocket connection.

    Covers listing active jobs, cancelling a job, and rejoining a job room
    after the client disconnects and reconnects. Replies are read from
    ``socketio_client.get_received()``, whose items are treated as dicts
    with ``'name'`` and ``'args'`` keys.
    """

    def test_get_active_jobs(self, socketio_client):
        """Requesting active jobs yields an ``active_jobs`` reply with a list."""
        socketio_client.emit('get_active_jobs', {})
        inbox = socketio_client.get_received()

        # First 'active_jobs' reply, or None when the server sent none.
        reply = next(
            (event for event in inbox if event['name'] == 'active_jobs'),
            None,
        )
        assert reply is not None

        payload = reply['args'][0]
        assert isinstance(payload['jobs'], list)

    def test_cancel_job_via_websocket(self, socketio_client):
        """Cancelling a job produces a ``job_cancelled`` event for that job."""
        target = 'cancel-test-job'

        socketio_client.emit('join_job', {'job_id': target})
        # Send a progress update before cancelling, as a running job would.
        socketio_client.emit('job_progress', {
            'job_id': target,
            'overall_progress': 30,
            'stage': 'transcription'
        })
        socketio_client.emit('cancel_job', {'job_id': target})

        inbox = socketio_client.get_received()
        cancellations = [e for e in inbox if e['name'] == 'job_cancelled']

        assert cancellations
        assert cancellations[0]['args'][0]['job_id'] == target

    def test_reconnection_handling(self, socketio_client):
        """After disconnect/reconnect and rejoining, progress updates still arrive."""
        job_id = 'reconnect-test'

        socketio_client.emit('join_job', {'job_id': job_id})

        # Drop the connection briefly, then reconnect.
        socketio_client.disconnect()
        time.sleep(0.1)
        socketio_client.connect()

        # Room membership is re-requested explicitly after reconnecting.
        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.emit('job_progress', {
            'job_id': job_id,
            'overall_progress': 50,
            'stage': 'detection'
        })

        inbox = socketio_client.get_received()
        assert any(e['name'] == 'job_progress' for e in inbox)
|
|
|
|
|
|
class TestWebSocketPerformance:
    """Test WebSocket performance and throttling."""

    def test_progress_throttling(self, socketio_client):
        """Test that progress updates are throttled."""
        job_id = 'throttle-test'

        socketio_client.emit('join_job', {'job_id': job_id})
        socketio_client.get_received()  # Clear

        # Send many rapid updates with no delay between them
        for i in range(100):
            socketio_client.emit('job_progress', {
                'job_id': job_id,
                'overall_progress': i,
                'stage': 'transcription'
            })

        received = socketio_client.get_received()

        # Should receive fewer messages due to throttling
        progress_msgs = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]

        # Exact number depends on throttle settings,
        # but should be significantly less than 100
        assert len(progress_msgs) < 50

    def test_multiple_clients_same_job(self, app):
        """Test multiple clients monitoring same job.

        Opens five test clients on ``app.socketio``, joins them all to one job
        room, emits a single progress update from the first client, and
        checks every client receives it. Clients are disconnected in
        ``finally`` so a failing assertion does not leave connections open.
        """
        socketio = app.socketio
        job_id = 'multi-client-job'
        clients = []

        try:
            # Create clients one at a time so any created before a failure
            # are still cleaned up below.
            for _ in range(5):
                clients.append(socketio.test_client(app))

            # All clients join same job
            for client in clients:
                client.emit('join_job', {'job_id': job_id})

            # Clear initial messages
            for client in clients:
                client.get_received()

            # Send progress update
            clients[0].emit('job_progress', {
                'job_id': job_id,
                'overall_progress': 75,
                'stage': 'censoring'
            })

            # All clients should receive the update
            for client in clients:
                received = client.get_received()
                progress_msgs = [
                    msg for msg in received
                    if msg['name'] == 'job_progress'
                ]
                assert len(progress_msgs) >= 1
        finally:
            for client in clients:
                # NOTE(review): assumes a Flask-SocketIO test client, whose
                # disconnect() raises when not connected; skipping those keeps
                # the original failure from being masked.
                if client.is_connected():
                    client.disconnect()
|
|
|
|
|
|
class TestWebSocketSecurity:
    """Test WebSocket security features.

    The server-side behaviour targeted here (room authorization, rate
    limiting, payload validation) may not be implemented yet. Rather than
    asserting tautologies such as ``len(x) >= 0``, each test checks the
    intended behaviour. When the feature is observably absent, the test calls
    ``pytest.xfail`` with a reason, so the gap is reported instead of silently
    passing.
    """

    def test_unauthorized_room_access(self, socketio_client):
        """Test that clients can't join unauthorized rooms."""
        # Try to join a room without proper authorization
        socketio_client.emit('join_job', {
            'job_id': 'unauthorized-job',
            'token': 'invalid-token'
        })

        received = socketio_client.get_received()

        # Should receive an error or no confirmation
        join_confirmations = [
            msg for msg in received
            if msg['name'] == 'joined_job'
        ]

        # A confirmation means the invalid token was accepted.
        if join_confirmations:
            pytest.xfail(
                'join_job accepted an invalid token; '
                'room authorization is not enforced'
            )

    def test_rate_limiting_websocket(self, socketio_client):
        """Test WebSocket message rate limiting."""
        request_count = 1000

        # Send many messages rapidly
        for _ in range(request_count):
            socketio_client.emit('get_active_jobs', {})

        received = socketio_client.get_received()

        # Sanity bound: flooding must not yield more messages than requests.
        assert len(received) <= request_count

        # With rate limiting active, not every request should be answered.
        replies = [
            msg for msg in received
            if msg['name'] == 'active_jobs'
        ]
        if len(replies) >= request_count:
            pytest.xfail(
                'every get_active_jobs request was answered; '
                'no rate limiting observed'
            )

    def test_message_validation(self, socketio_client):
        """Test that invalid messages are rejected."""
        # Send invalid message format
        socketio_client.emit('job_progress', {
            'invalid_field': 'test',
            # Missing required fields
        })

        received = socketio_client.get_received()

        error_msgs = [
            msg for msg in received
            if msg['name'] == 'error'
        ]
        relayed = [
            msg for msg in received
            if msg['name'] == 'job_progress'
        ]

        # Expected: either an explicit error, or the payload is not processed.
        if relayed and not error_msgs:
            pytest.xfail(
                'job_progress without required fields was relayed without '
                'an error; payload validation not observed'
            )
|
|
|
|
|
|
class TestWebSocketMetrics:
    """Test WebSocket metrics and monitoring."""

    def test_connection_metrics(self, app):
        """Test tracking of WebSocket connections.

        No metrics API is exercised yet. The test verifies that ten
        concurrent test clients on ``app.socketio`` connect, and always
        disconnects them afterwards.
        """
        socketio = app.socketio
        clients = []

        try:
            # Create multiple connections (appended one at a time so any
            # created before a failure are still cleaned up).
            for _ in range(10):
                clients.append(socketio.test_client(app))

            # TODO: assert on server-side connection metrics once implemented.
            assert all(client.is_connected() for client in clients)
        finally:
            for client in clients:
                # NOTE(review): assumes a Flask-SocketIO test client, whose
                # disconnect() raises when not connected.
                if client.is_connected():
                    client.disconnect()

    def test_message_metrics(self, socketio_client):
        """Test tracking of message metrics."""
        # Send various message types
        message_types = [
            ('join_job', {'job_id': 'metrics-test'}),
            ('get_active_jobs', {}),
            ('job_progress', {
                'job_id': 'metrics-test',
                'overall_progress': 50
            })
        ]

        for msg_type, data in message_types:
            socketio_client.emit(msg_type, data)

        received = socketio_client.get_received()

        # Metrics tracking itself is implementation dependent. The
        # get_active_jobs request must still be answered, which
        # test_get_active_jobs also expects.
        assert any(msg['name'] == 'active_jobs' for msg in received)
|
|
|
|
|
|
if __name__ == '__main__':
    # Propagate pytest's exit code so running this file directly reports
    # failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, '-v']))