453 lines
14 KiB
Python
453 lines
14 KiB
Python
"""
|
|
Unit tests for WebSocket functionality.
|
|
"""
|
|
|
|
import pytest
|
|
import time
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from datetime import datetime
|
|
|
|
from src.api.websocket_enhanced import (
|
|
ProcessingStage,
|
|
JobMetrics,
|
|
JobManager,
|
|
AdvancedProgressTracker,
|
|
BatchProgressTracker,
|
|
create_enhanced_websocket_handlers
|
|
)
|
|
|
|
|
|
class TestProcessingStage:
    """Test ProcessingStage enum."""

    def test_stage_properties(self):
        """Test stage properties of TRANSCRIPTION (name and 10-50% window)."""
        stage = ProcessingStage.TRANSCRIPTION

        assert stage.stage_name == 'transcription'
        assert stage.start_pct == 10
        assert stage.end_pct == 50

    def test_calculate_overall_progress(self):
        """Test mapping stage-local progress onto overall progress.

        TRANSCRIPTION spans 10-50% overall, so 0/50/100% of the stage are
        expected to land at the start, midpoint and end of that window.
        """
        stage = ProcessingStage.TRANSCRIPTION

        # 0% of stage -> start of the window
        assert stage.calculate_overall_progress(0) == 10
        # 50% of stage -> midpoint of the window
        assert stage.calculate_overall_progress(50) == 30
        # 100% of stage -> end of the window
        assert stage.calculate_overall_progress(100) == 50

    def test_all_stages_coverage(self):
        """Test that stages start at 0%, end at 100%, and never overlap."""
        stages = list(ProcessingStage)

        # First stage starts at 0, last stage ends at 100
        assert stages[0].start_pct == 0
        assert stages[-1].end_pct == 100

        # Every stage has a non-negative width (zero-width stages are allowed)
        for stage in stages:
            assert stage.start_pct <= stage.end_pct, stage

        # Consecutive stages are ordered and do not overlap. Gaps between
        # stages are tolerated here, so this does not prove the stages tile
        # 0-100% without holes.
        for current, following in zip(stages, stages[1:]):
            assert current.end_pct <= following.start_pct, (current, following)
|
|
|
|
|
|
class TestJobMetrics:
    """Tests for the JobMetrics dataclass."""

    def test_job_metrics_creation(self):
        """Constructor stores the given fields and defaults the file counters."""
        metrics = JobMetrics(
            job_id='test-job',
            start_time=time.time(),
            current_stage='transcription',
            overall_progress=25.0,
            stage_progress=50.0,
        )

        expected = {
            'job_id': 'test-job',
            'current_stage': 'transcription',
            'overall_progress': 25.0,
            'stage_progress': 50.0,
            # Not passed to the constructor above, so these are the defaults.
            'files_processed': 0,
            'total_files': 1,
        }
        for attr, value in expected.items():
            assert getattr(metrics, attr) == value, attr

    def test_job_metrics_to_dict(self):
        """to_dict() exposes the stored fields plus timing keys."""
        metrics = JobMetrics(
            job_id='test-job',
            start_time=time.time(),
            current_stage='detection',
            overall_progress=60.0,
            stage_progress=75.0,
            words_detected=10,
            words_censored=8,
        )

        payload = metrics.to_dict()

        expected_values = {
            'job_id': 'test-job',
            'current_stage': 'detection',
            'overall_progress': 60.0,
            'words_detected': 10,
            'words_censored': 8,
        }
        for key, value in expected_values.items():
            assert payload[key] == value, key

        # These keys are not passed in; only their presence is checked.
        for key in ('elapsed_time', 'timestamp'):
            assert key in payload
|
|
|
|
|
|
class TestJobManager:
    """Test JobManager class."""

    def test_create_job(self):
        """Test creating jobs with generated and caller-supplied IDs."""
        manager = JobManager()

        # Create job with auto-generated ID
        job_id = manager.create_job()
        assert job_id is not None
        assert len(manager.jobs) == 1
        assert manager.get_job(job_id) is not None

        # Create job with specific ID
        custom_id = 'custom-job-id'
        job_id2 = manager.create_job(custom_id)
        assert job_id2 == custom_id
        assert len(manager.jobs) == 2
        assert manager.get_job(custom_id) is not None

    def test_update_job(self):
        """Test updating job metrics and that the update is persisted."""
        manager = JobManager()
        job_id = manager.create_job()

        updated = manager.update_job(
            job_id,
            current_stage='detection',
            overall_progress=50.0,
            words_detected=5,
        )

        assert updated is not None
        assert updated.current_stage == 'detection'
        assert updated.overall_progress == 50.0
        assert updated.words_detected == 5

        # The stored job must reflect the update, not just the return value.
        stored = manager.get_job(job_id)
        assert stored is not None
        assert stored.overall_progress == 50.0

    def test_update_nonexistent_job(self):
        """Test that updating an unknown job returns None."""
        manager = JobManager()

        result = manager.update_job('nonexistent', overall_progress=50.0)
        assert result is None

    def test_get_job(self):
        """Test getting job metrics."""
        manager = JobManager()
        job_id = manager.create_job()

        job = manager.get_job(job_id)
        assert job is not None
        assert job.job_id == job_id

        # Get non-existent job
        assert manager.get_job('nonexistent') is None

    def test_remove_job(self):
        """Test removing a job."""
        manager = JobManager()
        job_id = manager.create_job()

        # Remove existing job
        assert manager.remove_job(job_id) is True
        assert len(manager.jobs) == 0
        assert manager.get_job(job_id) is None

        # Remove non-existent job
        assert manager.remove_job('nonexistent') is False

    def test_get_active_jobs(self):
        """Test that get_active_jobs returns exactly the created jobs."""
        manager = JobManager()

        # Create multiple jobs; auto-generated IDs must be distinct.
        created = {manager.create_job() for _ in range(3)}
        assert len(created) == 3

        active = manager.get_active_jobs()
        assert len(active) == 3
        assert all(isinstance(job, JobMetrics) for job in active)
        assert {job.job_id for job in active} == created

    def test_estimated_time_calculation(self):
        """Test estimated time remaining calculation."""
        manager = JobManager()
        job_id = manager.create_job()

        # Let some wall-clock time pass so an ETA can be extrapolated
        # (presumably from elapsed time vs. reported progress).
        time.sleep(0.1)
        manager.update_job(job_id, overall_progress=25.0)

        job = manager.get_job(job_id)
        assert job is not None
        assert job.estimated_time_remaining is not None
        assert job.estimated_time_remaining > 0
|
|
|
|
|
|
class TestAdvancedProgressTracker:
    """Test AdvancedProgressTracker class."""

    def test_tracker_initialization(self):
        """Test tracker initialization."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(
            socketio=socketio,
            job_manager=manager,
            job_id=job_id,
            debug_mode=True,
            emit_interval=1.0,
        )

        assert tracker.job_id == job_id
        assert tracker.debug_mode is True
        assert tracker.emit_interval == 1.0
        assert tracker.current_stage == ProcessingStage.INITIALIZING

    def test_change_stage(self):
        """Test changing processing stage."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(socketio, manager, job_id)

        # Change to transcription stage
        tracker.change_stage(ProcessingStage.TRANSCRIPTION, "Starting transcription")

        job = manager.get_job(job_id)
        assert job.current_stage == 'transcription'
        assert job.overall_progress == ProcessingStage.TRANSCRIPTION.start_pct

        # Verify emit was called
        socketio.emit.assert_called()

    def test_update_stage_progress(self):
        """Test updating progress within a stage."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(socketio, manager, job_id)
        tracker.change_stage(ProcessingStage.DETECTION)

        # Update progress
        tracker.update_stage_progress(
            percent=50.0,
            message="Detecting words...",
            details={'words_detected': 10},
        )

        job = manager.get_job(job_id)
        assert job.stage_progress == 50.0
        assert job.words_detected == 10
        assert job.overall_progress == ProcessingStage.DETECTION.calculate_overall_progress(50.0)

    def test_emit_throttling(self):
        """Test that updates closer together than emit_interval are not emitted."""
        # A short interval keeps the test fast. The sleep below must exceed
        # it, while the two back-to-back updates must fall well inside it.
        interval = 0.2
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(
            socketio=socketio,
            job_manager=manager,
            job_id=job_id,
            emit_interval=interval,
        )

        # First update should emit
        tracker.update_stage_progress(10.0)
        assert socketio.emit.call_count == 1

        # Immediate second update should not emit (throttled)
        tracker.update_stage_progress(20.0)
        assert socketio.emit.call_count == 1

        # Once the interval has elapsed, the next update emits again
        time.sleep(interval * 1.5)
        tracker.update_stage_progress(30.0)
        assert socketio.emit.call_count == 2

    def test_emit_completed(self):
        """Test emitting completion event."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(socketio, manager, job_id)

        # Emit completion
        tracker.emit_completed(
            output_filename='output.mp3',
            summary={'total_words': 10, 'duration': 30.0},
        )

        # Verify job was removed
        assert manager.get_job(job_id) is None

        # Verify emit was called with the completion event name
        socketio.emit.assert_called()
        call_args = socketio.emit.call_args[0]
        assert call_args[0] == 'job_completed'

    def test_emit_error(self):
        """Test emitting recoverable and non-recoverable error events."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()

        tracker = AdvancedProgressTracker(socketio, manager, job_id)

        # Emit recoverable error
        tracker.emit_error(
            error_type='network_timeout',
            error_message='Connection timeout',
            recoverable=True,
            retry_suggestion='Check network and retry',
        )

        # The error must reach clients, and the job must survive (recoverable)
        socketio.emit.assert_called()
        assert manager.get_job(job_id) is not None

        # Emit non-recoverable error
        tracker.emit_error(
            error_type='file_corrupted',
            error_message='File is corrupted',
            recoverable=False,
        )

        # Job should be removed (non-recoverable)
        assert manager.get_job(job_id) is None
|
|
|
|
|
|
class TestBatchProgressTracker:
    """Tests for BatchProgressTracker (progress aggregated across files)."""

    @staticmethod
    def _build(total_files, **extra):
        """Return (socketio mock, manager, job_id, tracker) for a fresh batch job."""
        socketio = Mock()
        manager = JobManager()
        job_id = manager.create_job()
        tracker = BatchProgressTracker(
            socketio=socketio,
            job_manager=manager,
            job_id=job_id,
            total_files=total_files,
            **extra,
        )
        return socketio, manager, job_id, tracker

    def test_batch_tracker_initialization(self):
        """A new tracker starts at file 0 and weights each file equally."""
        _socketio, manager, job_id, tracker = self._build(5, debug_mode=False)

        assert tracker.total_files == 5
        assert tracker.current_file_index == 0
        assert tracker.file_progress_weight == 1 / 5

        # The total is also recorded on the managed job.
        assert manager.get_job(job_id).total_files == 5

    def test_start_file(self):
        """Starting a file emits a 1-based 'Processing file i/N' message."""
        socketio, manager, job_id, tracker = self._build(3)

        tracker.start_file(0, 'file1.mp3')

        # Starting a file does not count it as processed yet.
        assert manager.get_job(job_id).files_processed == 0

        socketio.emit.assert_called()
        payload = socketio.emit.call_args[0][1]
        assert 'Processing file 1/3' in payload['message']

    def test_batch_progress_calculation(self):
        """Batch progress is (index * 100 + file progress) / total files."""
        tracker = self._build(4)[3]

        # (file index, progress within that file, expected batch progress)
        cases = (
            (0, 50.0, 12.5),    # (0 * 100 + 50) / 4
            (1, 75.0, 43.75),   # (1 * 100 + 75) / 4
            (3, 100.0, 100.0),  # (3 * 100 + 100) / 4
        )
        for index, file_pct, expected in cases:
            tracker.current_file_index = index
            assert tracker._calculate_batch_progress(file_pct) == expected
|
|
|
|
|
|
class TestWebSocketHandlers:
    """Test WebSocket event handlers."""

    @staticmethod
    def _registered_events(socketio):
        """Return the event names registered on a Mock ``socketio``.

        Covers ``socketio.on(name)`` (typically used as a decorator) and
        ``socketio.on_event(name, handler)``. In both forms, the event name
        is the first positional argument the mock recorded.
        """
        calls = list(socketio.on.call_args_list) + list(socketio.on_event.call_args_list)
        return {recorded[0][0] for recorded in calls if recorded[0]}

    def test_handle_connect(self):
        """Creating the handlers registers a 'connect' event handler."""
        socketio = Mock()
        manager = JobManager()

        create_enhanced_websocket_handlers(socketio, manager)

        # Inspect the calls recorded during registration directly. Patching
        # socketio.on at this point would replace the mock and discard them.
        assert 'connect' in self._registered_events(socketio)

    def test_job_room_management(self):
        """Smoke test: handler creation succeeds for a fresh job manager.

        Joining and leaving job rooms needs a real Socket.IO test client
        (integration test); this only checks registration does not raise.
        """
        socketio = Mock()
        manager = JobManager()

        create_enhanced_websocket_handlers(socketio, manager)

    def test_get_active_jobs_handler(self):
        """Smoke test: handler creation succeeds with existing jobs.

        Invoking the active-jobs handler needs a real Socket.IO test client;
        this checks registration does not raise or disturb existing jobs.
        """
        socketio = Mock()
        manager = JobManager()

        # Create some jobs
        manager.create_job()
        manager.create_job()

        create_enhanced_websocket_handlers(socketio, manager)

        assert len(manager.get_active_jobs()) == 2
|
|
|
|
|
|
if __name__ == '__main__':
    # Propagate pytest's exit status so scripted runs fail when tests fail.
    raise SystemExit(pytest.main([__file__, '-v']))