# clean-tracks/tests/unit/test_websocket.py
#
# 453 lines
# 14 KiB
# Python

"""
Unit tests for WebSocket functionality.
"""
import pytest
import time
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime
from src.api.websocket_enhanced import (
ProcessingStage,
JobMetrics,
JobManager,
AdvancedProgressTracker,
BatchProgressTracker,
create_enhanced_websocket_handlers
)
class TestProcessingStage:
    """Tests for the ``ProcessingStage`` enum and its progress mapping."""

    def test_stage_properties(self):
        """TRANSCRIPTION exposes its name and its start/end percentages."""
        stage = ProcessingStage.TRANSCRIPTION
        assert stage.stage_name == 'transcription'
        assert stage.start_pct == 10
        assert stage.end_pct == 50

    def test_calculate_overall_progress(self):
        """Stage-local progress is mapped onto the stage's overall window."""
        stage = ProcessingStage.TRANSCRIPTION
        # 0% of stage
        assert stage.calculate_overall_progress(0) == 10
        # 50% of stage
        assert stage.calculate_overall_progress(50) == 30
        # 100% of stage
        assert stage.calculate_overall_progress(100) == 50

    def test_all_stages_coverage(self):
        """Stages start at 0%, end at 100%, and never overlap."""
        stages = list(ProcessingStage)
        # Check first stage starts at 0
        assert stages[0].start_pct == 0
        # Check last stage ends at 100
        assert stages[-1].end_pct == 100
        # Each stage must finish no later than the next one begins. ``<=``
        # tolerates a gap between neighbours, so this verifies ordering and
        # non-overlap rather than strict continuity.
        for current, following in zip(stages, stages[1:]):
            assert current.end_pct <= following.start_pct
class TestJobMetrics:
    """Checks for the ``JobMetrics`` dataclass."""

    def test_job_metrics_creation(self):
        """Explicit fields are stored and the file counters default to 0 and 1."""
        metrics = JobMetrics(
            job_id='test-job',
            start_time=time.time(),
            current_stage='transcription',
            overall_progress=25.0,
            stage_progress=50.0,
        )
        expected_attributes = [
            ('job_id', 'test-job'),
            ('current_stage', 'transcription'),
            ('overall_progress', 25.0),
            ('stage_progress', 50.0),
            # Not passed to the constructor: these are the defaults.
            ('files_processed', 0),
            ('total_files', 1),
        ]
        for attribute, expected in expected_attributes:
            assert getattr(metrics, attribute) == expected

    def test_job_metrics_to_dict(self):
        """``to_dict`` echoes the fields and adds elapsed_time and timestamp."""
        started_at = time.time()
        metrics = JobMetrics(
            job_id='test-job',
            start_time=started_at,
            current_stage='detection',
            overall_progress=60.0,
            stage_progress=75.0,
            words_detected=10,
            words_censored=8,
        )
        payload = metrics.to_dict()

        expected_items = [
            ('job_id', 'test-job'),
            ('current_stage', 'detection'),
            ('overall_progress', 60.0),
            ('words_detected', 10),
            ('words_censored', 8),
        ]
        for key, expected in expected_items:
            assert payload[key] == expected

        # Only the presence of these keys is checked, not their values.
        for derived_key in ('elapsed_time', 'timestamp'):
            assert derived_key in payload
class TestJobManager:
    """Tests for ``JobManager`` job bookkeeping."""

    def test_create_job(self):
        """Jobs can be created with a generated ID or a caller-supplied one."""
        manager = JobManager()
        # Create job with auto-generated ID
        job_id = manager.create_job()
        assert job_id is not None
        assert len(manager.jobs) == 1
        # Create job with specific ID
        custom_id = 'custom-job-id'
        job_id2 = manager.create_job(custom_id)
        assert job_id2 == custom_id
        assert len(manager.jobs) == 2

    def test_update_job(self):
        """``update_job`` applies keyword updates and returns the metrics."""
        manager = JobManager()
        job_id = manager.create_job()
        # Update job
        updated = manager.update_job(
            job_id,
            current_stage='detection',
            overall_progress=50.0,
            words_detected=5,
        )
        assert updated is not None
        assert updated.current_stage == 'detection'
        assert updated.overall_progress == 50.0
        assert updated.words_detected == 5

    def test_update_nonexistent_job(self):
        """Updating an unknown job ID returns ``None``."""
        manager = JobManager()
        result = manager.update_job('nonexistent', overall_progress=50.0)
        assert result is None

    def test_get_job(self):
        """``get_job`` returns the metrics for a known ID, ``None`` otherwise."""
        manager = JobManager()
        job_id = manager.create_job()
        job = manager.get_job(job_id)
        assert job is not None
        assert job.job_id == job_id
        # Get non-existent job
        assert manager.get_job('nonexistent') is None

    def test_remove_job(self):
        """``remove_job`` reports whether a job was actually removed."""
        manager = JobManager()
        job_id = manager.create_job()
        # Remove existing job
        assert manager.remove_job(job_id) is True
        assert len(manager.jobs) == 0
        # Remove non-existent job
        assert manager.remove_job('nonexistent') is False

    def test_get_active_jobs(self):
        """``get_active_jobs`` returns one ``JobMetrics`` per created job."""
        manager = JobManager()
        # Create multiple jobs; the generated IDs are not needed here.
        for _ in range(3):
            manager.create_job()
        active = manager.get_active_jobs()
        assert len(active) == 3
        assert all(isinstance(job, JobMetrics) for job in active)

    def test_estimated_time_calculation(self):
        """A job with elapsed time and partial progress gets a positive ETA."""
        manager = JobManager()
        job_id = manager.create_job()
        # Let some wall-clock time pass before reporting progress.
        time.sleep(0.1)
        manager.update_job(job_id, overall_progress=25.0)
        job = manager.get_job(job_id)
        # Fail with a clear assertion (not a TypeError from ``None > 0``)
        # if no estimate was produced.
        assert job.estimated_time_remaining is not None
        assert job.estimated_time_remaining > 0
class TestAdvancedProgressTracker:
    """Checks for ``AdvancedProgressTracker`` driven through a mock socket."""

    @staticmethod
    def _build_tracker(**tracker_options):
        """Create a mock socket, a manager with one job, and a tracker for it.

        Returns ``(socketio, manager, job_id, tracker)``. Extra keyword
        arguments are forwarded to the tracker constructor.
        """
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = AdvancedProgressTracker(
            socketio, manager, job_id, **tracker_options
        )
        return socketio, manager, job_id, tracker

    def test_tracker_initialization(self):
        """Constructor arguments are stored and the stage starts at INITIALIZING."""
        _, _, job_id, tracker = self._build_tracker(
            debug_mode=True,
            emit_interval=1.0,
        )
        assert tracker.job_id == job_id
        assert tracker.debug_mode is True
        assert tracker.emit_interval == 1.0
        assert tracker.current_stage == ProcessingStage.INITIALIZING

    def test_change_stage(self):
        """Switching stage updates the job's stage and progress, and emits."""
        socketio, manager, job_id, tracker = self._build_tracker()
        target_stage = ProcessingStage.TRANSCRIPTION

        tracker.change_stage(target_stage, "Starting transcription")

        metrics = manager.get_job(job_id)
        assert metrics.current_stage == 'transcription'
        assert metrics.overall_progress == target_stage.start_pct
        # An event must have gone out on the socket.
        socketio.emit.assert_called()

    def test_update_stage_progress(self):
        """Stage progress, details and derived overall progress reach the job."""
        _, manager, job_id, tracker = self._build_tracker()
        tracker.change_stage(ProcessingStage.DETECTION)

        tracker.update_stage_progress(
            percent=50.0,
            message="Detecting words...",
            details={'words_detected': 10},
        )

        metrics = manager.get_job(job_id)
        expected_overall = ProcessingStage.DETECTION.calculate_overall_progress(50.0)
        assert metrics.stage_progress == 50.0
        assert metrics.words_detected == 10
        assert metrics.overall_progress == expected_overall

    def test_emit_throttling(self):
        """Updates inside ``emit_interval`` do not emit; later ones do."""
        socketio, _, _, tracker = self._build_tracker(emit_interval=1.0)
        emit = socketio.emit

        # The first update goes out.
        tracker.update_stage_progress(10.0)
        assert emit.call_count == 1

        # A second update right away is suppressed by the throttle.
        tracker.update_stage_progress(20.0)
        assert emit.call_count == 1

        # Once the interval has elapsed, the next update goes out again.
        time.sleep(1.1)
        tracker.update_stage_progress(30.0)
        assert emit.call_count == 2

    def test_emit_completed(self):
        """Completion removes the job and emits ``job_completed``."""
        socketio, manager, job_id, tracker = self._build_tracker()

        tracker.emit_completed(
            output_filename='output.mp3',
            summary={'total_words': 10, 'duration': 30.0},
        )

        # The finished job is no longer tracked.
        assert manager.get_job(job_id) is None
        # The last emit carries the completion event name as first argument.
        socketio.emit.assert_called()
        positional_args = socketio.emit.call_args[0]
        event_name = positional_args[0]
        assert event_name == 'job_completed'

    def test_emit_error(self):
        """Recoverable errors keep the job; non-recoverable ones remove it."""
        _, manager, job_id, tracker = self._build_tracker()

        tracker.emit_error(
            error_type='network_timeout',
            error_message='Connection timeout',
            recoverable=True,
            retry_suggestion='Check network and retry',
        )
        # Recoverable: the job is still tracked.
        assert manager.get_job(job_id) is not None

        tracker.emit_error(
            error_type='file_corrupted',
            error_message='File is corrupted',
            recoverable=False,
        )
        # Non-recoverable: the job is gone.
        assert manager.get_job(job_id) is None
class TestBatchProgressTracker:
    """Tests for ``BatchProgressTracker`` multi-file progress handling."""

    def test_batch_tracker_initialization(self):
        """The tracker stores the batch size and mirrors it onto the job."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = BatchProgressTracker(
            socketio=socketio,
            job_manager=manager,
            job_id=job_id,
            total_files=5,
            debug_mode=False,
        )
        assert tracker.total_files == 5
        assert tracker.current_file_index == 0
        # 1/5; compared approximately because the weight is a computed float.
        assert tracker.file_progress_weight == pytest.approx(0.2)
        job = manager.get_job(job_id)
        assert job.total_files == 5

    def test_start_file(self):
        """Starting the first file emits a 'Processing file 1/3' message."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = BatchProgressTracker(socketio, manager, job_id, total_files=3)
        # Start first file
        tracker.start_file(0, 'file1.mp3')
        job = manager.get_job(job_id)
        assert job.files_processed == 0
        # Check emit was called with correct message. The payload is read
        # from the second positional argument of the most recent emit call.
        socketio.emit.assert_called()
        call_args = socketio.emit.call_args[0]
        data = call_args[1]
        assert 'Processing file 1/3' in data['message']

    def test_batch_progress_calculation(self):
        """Batch progress combines the file index with per-file progress."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = BatchProgressTracker(socketio, manager, job_id, total_files=4)
        # Results are floats produced by arithmetic, so compare with a
        # tolerance instead of exact equality.
        # First file at 50%
        tracker.current_file_index = 0
        progress = tracker._calculate_batch_progress(50.0)
        assert progress == pytest.approx(12.5)  # (0 * 100 + 50) / 4
        # Second file at 75%
        tracker.current_file_index = 1
        progress = tracker._calculate_batch_progress(75.0)
        assert progress == pytest.approx(43.75)  # (1 * 100 + 75) / 4
        # Last file at 100%
        tracker.current_file_index = 3
        progress = tracker._calculate_batch_progress(100.0)
        assert progress == pytest.approx(100.0)  # (3 * 100 + 100) / 4
class TestWebSocketHandlers:
    """Tests for WebSocket event handler registration.

    Invoking the handlers themselves needs a real Socket.IO test client, so
    these tests only cover what can be observed with a mock ``socketio``.
    """

    @patch('src.api.websocket_enhanced.request')
    def test_handle_connect(self, mock_request):
        """Creating the handlers registers events on the socket object."""
        mock_request.sid = 'test-sid'
        socketio = Mock()
        manager = JobManager()

        create_enhanced_websocket_handlers(socketio, manager)

        # Read the calls recorded on the mock that was actually handed to the
        # factory. Patching ``socketio.on`` at this point would swap in a
        # fresh mock with an empty call list and hide every registration.
        registered_events = [
            call[0][0] for call in socketio.on.call_args_list if call[0]
        ]
        # NOTE(review): assumes handlers are registered through
        # ``socketio.on(<event>, ...)``; once confirmed, tighten this to
        # check for the specific connect event name.
        assert registered_events, (
            "expected at least one event to be registered via socketio.on"
        )

    def test_job_room_management(self):
        """Smoke test: handler creation succeeds for an empty job manager."""
        socketio = Mock()
        manager = JobManager()
        # Joining and leaving rooms requires an actual Socket.IO test client
        # (integration test); here we only verify that registration does not
        # raise.
        create_enhanced_websocket_handlers(socketio, manager)

    def test_get_active_jobs_handler(self):
        """Registering handlers leaves the manager's active jobs untouched."""
        socketio = Mock()
        manager = JobManager()
        # Create some jobs
        manager.create_job()
        manager.create_job()

        create_enhanced_websocket_handlers(socketio, manager)

        # Calling the handler itself requires an actual Socket.IO test
        # client; check the data it would serve instead.
        assert len(manager.get_active_jobs()) == 2
if __name__ == '__main__':
    # Propagate pytest's exit code so running this file directly reports
    # test failures to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, '-v']))