# trax/tests/test_batch_processor.py
#
# 581 lines
# 22 KiB
# Python

"""
Unit tests for batch processing system.
Tests cover:
- Async worker pool functionality
- Queue management and priority handling
- Progress tracking and reporting
- Error recovery and retry logic
- Resource monitoring
- Task processing for different types
- Pause/resume functionality
"""
import asyncio
import pytest
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import tempfile
import os
from src.services.batch_processor import (
BatchProcessor,
BatchTask,
BatchProgress,
BatchResult,
TaskType,
create_batch_processor
)
from src.services.transcription_service import TranscriptionConfig
class TestBatchTask:
    """Checks on how BatchTask instances are constructed."""

    def test_batch_task_creation(self):
        """Build a task with every argument supplied and inspect each field."""
        payload = {"file_path": "/test/file.mp3"}
        created = BatchTask(
            id="test_task_1",
            task_type=TaskType.TRANSCRIBE,
            data=payload,
            priority=1,
            max_retries=5,
        )

        # Values passed to the constructor are read back unchanged.
        assert created.id == "test_task_1"
        assert created.task_type == TaskType.TRANSCRIBE
        assert created.data["file_path"] == "/test/file.mp3"
        assert created.priority == 1
        assert created.max_retries == 5

        # Fields the caller did not pass.
        assert created.retry_count == 0
        assert created.created_at is not None
        for unset_field in ("started_at", "completed_at", "error", "result"):
            assert getattr(created, unset_field) is None

    def test_batch_task_defaults(self):
        """Omit the optional arguments and check the expected default values."""
        minimal = BatchTask(
            id="test_task_2",
            task_type=TaskType.ENHANCE,
            data={"transcript_id": "123"},
        )

        expected_defaults = {"priority": 0, "max_retries": 3, "retry_count": 0}
        for attribute, default in expected_defaults.items():
            assert getattr(minimal, attribute) == default
class TestBatchProgress:
    """Checks on BatchProgress initial state and its derived values."""

    def test_batch_progress_creation(self):
        """A progress object built from only total_tasks has zeroed/unset fields."""
        tracker = BatchProgress(total_tasks=10)

        assert tracker.total_tasks == 10
        for counter in (
            "completed_tasks",
            "failed_tasks",
            "in_progress_tasks",
            "queued_tasks",
        ):
            assert getattr(tracker, counter) == 0
        assert tracker.start_time is None
        assert tracker.estimated_completion is None
        assert tracker.current_worker_count == 0
        assert tracker.memory_usage_mb == 0.0
        assert tracker.cpu_usage_percent == 0.0

    def test_success_rate_calculation(self):
        """7 completed out of 10 total is expected to report 70.0."""
        tracker = BatchProgress(total_tasks=10)
        tracker.completed_tasks, tracker.failed_tasks = 7, 2

        assert tracker.success_rate == 70.0

    def test_success_rate_zero_total(self):
        """With no tasks at all, success_rate is expected to be 0.0."""
        empty_tracker = BatchProgress(total_tasks=0)

        assert empty_tracker.success_rate == 0.0

    def test_failure_rate_calculation(self):
        """3 failed out of 10 total is expected to report 30.0."""
        tracker = BatchProgress(total_tasks=10)
        tracker.failed_tasks = 3

        assert tracker.failure_rate == 30.0

    def test_elapsed_time_calculation(self):
        """elapsed_time read right after setting start_time is a small non-negative number."""
        tracker = BatchProgress(total_tasks=5)
        tracker.start_time = datetime.now(timezone.utc)

        measured = tracker.elapsed_time

        assert measured is not None
        # start_time was set a moment ago, so the reading should be tiny.
        assert 0.0 <= measured < 1.0

    def test_elapsed_time_no_start(self):
        """Without a start_time, elapsed_time is expected to be None."""
        unstarted = BatchProgress(total_tasks=5)

        assert unstarted.elapsed_time is None
class TestBatchResult:
    """Checks on BatchResult construction and its success_rate value."""

    @staticmethod
    def _result_with_counts(success_count, failure_count, total_count):
        """Build a BatchResult with the given counts and every other field empty or zero."""
        return BatchResult(
            success_count=success_count,
            failure_count=failure_count,
            total_count=total_count,
            results=[],
            failures=[],
            processing_time=0.0,
            memory_peak_mb=0.0,
            cpu_peak_percent=0.0,
            quality_metrics={},
        )

    def test_batch_result_creation(self):
        """Build a result with every field supplied and read each one back."""
        field_values = {
            "success_count": 8,
            "failure_count": 2,
            "total_count": 10,
            "results": [{"status": "completed"}],
            "failures": [{"task_id": "1", "error": "test error"}],
            "processing_time": 120.5,
            "memory_peak_mb": 512.0,
            "cpu_peak_percent": 75.0,
            "quality_metrics": {"avg_accuracy": 95.5},
        }
        outcome = BatchResult(**field_values)

        assert outcome.success_count == 8
        assert outcome.failure_count == 2
        assert outcome.total_count == 10
        assert len(outcome.results) == 1
        assert len(outcome.failures) == 1
        assert outcome.processing_time == 120.5
        assert outcome.memory_peak_mb == 512.0
        assert outcome.cpu_peak_percent == 75.0
        assert outcome.quality_metrics["avg_accuracy"] == 95.5

    def test_success_rate_calculation(self):
        """9 successes out of 10 total is expected to report 90.0."""
        outcome = self._result_with_counts(9, 1, 10)

        assert outcome.success_rate == 90.0

    def test_success_rate_zero_total(self):
        """A result with a total count of zero is expected to report 0.0."""
        outcome = self._result_with_counts(0, 0, 0)

        assert outcome.success_rate == 0.0
class TestBatchProcessor:
    """Tests for BatchProcessor: queueing, task handlers, retries and lifecycle flags."""

    @pytest.fixture
    def batch_processor(self):
        """Provide a BatchProcessor with two workers and a 0.1 progress interval."""
        processor = BatchProcessor(max_workers=2, progress_interval=0.1)
        return processor

    @pytest.fixture
    def mock_services(self):
        """Patch the four service factories and yield the AsyncMock each one returns.

        The yielded dict is keyed by 'transcription', 'enhancement', 'media'
        and 'repository'. The patches stay active for the duration of the test.
        """
        with patch('src.services.batch_processor.create_transcription_service') as transcription_factory, \
                patch('src.services.batch_processor.create_enhancement_service') as enhancement_factory, \
                patch('src.services.batch_processor.create_media_service') as media_factory, \
                patch('src.services.batch_processor.create_media_repository') as repository_factory:
            factories = {
                'transcription': transcription_factory,
                'enhancement': enhancement_factory,
                'media': media_factory,
                'repository': repository_factory,
            }
            # Each patched factory hands back its own AsyncMock service.
            for factory in factories.values():
                factory.return_value = AsyncMock()
            yield {
                service_name: factory.return_value
                for service_name, factory in factories.items()
            }

    @pytest.mark.asyncio
    async def test_batch_processor_initialization(self, batch_processor):
        """A new processor keeps its settings and has no flags set, tasks or workers."""
        assert batch_processor.max_workers == 2
        assert batch_processor.progress_interval == 0.1
        for flag in ("running", "paused", "stopped"):
            assert not getattr(batch_processor, flag)
        assert batch_processor.progress.total_tasks == 0
        assert len(batch_processor.workers) == 0

    @pytest.mark.asyncio
    async def test_add_task(self, batch_processor):
        """Adding one task returns an id and updates the counters and the queue."""
        payload = {"file_path": "/test/file.mp3"}

        new_id = await batch_processor.add_task(
            TaskType.TRANSCRIBE, payload, priority=1
        )

        assert new_id.startswith("task_1_transcribe")
        assert batch_processor.progress.total_tasks == 1
        assert batch_processor.progress.queued_tasks == 1
        assert not batch_processor.task_queue.empty()

    @pytest.mark.asyncio
    async def test_add_multiple_tasks(self, batch_processor):
        """Tasks added with priorities 2, 1, 0 come off the queue as 0, 1, 2."""
        submissions = [
            (TaskType.TRANSCRIBE, {"file": "1.mp3"}, 2),
            (TaskType.ENHANCE, {"id": "123"}, 1),
            (TaskType.YOUTUBE, {"url": "test.com"}, 0),
        ]
        for kind, payload, level in submissions:
            await batch_processor.add_task(kind, payload, priority=level)

        assert batch_processor.progress.total_tasks == 3
        assert batch_processor.progress.queued_tasks == 3

        # Drain the queue, recording the order in which entries come out.
        drained = []
        queue = batch_processor.task_queue
        while not queue.empty():
            level, queued_task = await queue.get()
            drained.append((level, queued_task.task_type))

        # Lower number first: YouTube (0), then enhance (1), then transcribe (2).
        assert [entry[0] for entry in drained[:3]] == [0, 1, 2]

    @pytest.mark.asyncio
    async def test_initialize_services(self, batch_processor, mock_services):
        """_initialize_services populates the service attributes and initializes transcription."""
        await batch_processor._initialize_services()

        for attribute in (
            "transcription_service",
            "enhancement_service",
            "media_service",
        ):
            assert getattr(batch_processor, attribute) is not None
        mock_services['transcription'].initialize.assert_called_once()

    @pytest.mark.asyncio
    async def test_process_transcription_task(self, batch_processor, mock_services):
        """_process_transcription maps the service result into a 'completed' dict."""
        await batch_processor._initialize_services()
        job = BatchTask(
            id="test_task",
            task_type=TaskType.TRANSCRIBE,
            data={
                "file_path": "/test/file.mp3",
                "config": {"model": "whisper-1"},
            },
        )
        # Stand-in for what the transcription service returns.
        transcription = MagicMock(
            text_content="Test transcript",
            segments=[{"text": "Test", "start": 0, "end": 1}],
            accuracy=95.5,
            processing_time=10.0,
            quality_warnings=[],
        )
        transcribe = mock_services['transcription'].transcribe_file
        transcribe.return_value = transcription

        outcome = await batch_processor._process_transcription(job)

        expected = {
            "status": "completed",
            "file_path": "/test/file.mp3",
            "transcript": "Test transcript",
            "accuracy": 95.5,
            "processing_time": 10.0,
        }
        for key, value in expected.items():
            assert outcome[key] == value
        transcribe.assert_called_once()

    @pytest.mark.asyncio
    async def test_process_enhancement_task(self, batch_processor, mock_services):
        """_process_enhancement maps the service result into a 'completed' dict."""
        await batch_processor._initialize_services()
        job = BatchTask(
            id="test_task",
            task_type=TaskType.ENHANCE,
            data={"transcript_id": "123"},
        )
        # Stand-in for what the enhancement service returns.
        enhancement = MagicMock(
            enhanced_content="Enhanced transcript",
            accuracy_improvement=2.5,
            processing_time=5.0,
        )
        enhance = mock_services['enhancement'].enhance_transcript
        enhance.return_value = enhancement

        outcome = await batch_processor._process_enhancement(job)

        expected = {
            "status": "completed",
            "transcript_id": "123",
            "enhanced_content": "Enhanced transcript",
            "accuracy_improvement": 2.5,
        }
        for key, value in expected.items():
            assert outcome[key] == value
        enhance.assert_called_once_with("123")

    @pytest.mark.asyncio
    async def test_task_retry_on_failure(self, batch_processor, mock_services):
        """Two failing attempts report 'retrying'; the third attempt completes."""
        await batch_processor._initialize_services()
        job = BatchTask(
            id="test_task",
            task_type=TaskType.TRANSCRIBE,
            data={"file_path": "/test/file.mp3"},
            max_retries=2,
        )
        eventual_success = MagicMock(
            text_content="Success",
            segments=[],
            accuracy=95.0,
            processing_time=10.0,
            quality_warnings=[],
        )
        # The service raises twice and then returns a result.
        mock_services['transcription'].transcribe_file.side_effect = [
            Exception("First failure"),
            Exception("Second failure"),
            eventual_success,
        ]

        # Attempts one and two fail, each reporting the running retry count.
        for attempt in (1, 2):
            retry = await batch_processor._process_transcription(job)
            assert retry["status"] == "retrying"
            assert retry["retry_count"] == attempt

        final = await batch_processor._process_transcription(job)
        assert final["status"] == "completed"

    @pytest.mark.asyncio
    async def test_task_permanent_failure(self, batch_processor, mock_services):
        """With max_retries=1, the second failing attempt is reported as 'failed'."""
        await batch_processor._initialize_services()
        job = BatchTask(
            id="test_task",
            task_type=TaskType.TRANSCRIBE,
            data={"file_path": "/test/file.mp3"},
            max_retries=1,
        )
        # Every call to the service raises.
        transcribe = mock_services['transcription'].transcribe_file
        transcribe.side_effect = Exception("Permanent failure")

        first = await batch_processor._process_transcription(job)
        assert first["status"] == "retrying"

        second = await batch_processor._process_transcription(job)
        assert second["status"] == "failed"
        assert "Permanent failure" in second["error"]

        # The task is recorded on the processor's failed list.
        failed = batch_processor.failed_tasks
        assert len(failed) == 1
        assert failed[0].id == "test_task"

    @pytest.mark.asyncio
    async def test_pause_resume_functionality(self, batch_processor):
        """pause() only takes effect while running; resume() clears the flag."""
        assert not batch_processor.paused

        # Not running yet, so pausing leaves the flag unset.
        await batch_processor.pause()
        assert not batch_processor.paused

        # Mark the processor as running without actually starting workers.
        batch_processor.running = True

        await batch_processor.pause()
        assert batch_processor.paused

        await batch_processor.resume()
        assert not batch_processor.paused

    @pytest.mark.asyncio
    async def test_stop_functionality(self, batch_processor):
        """stop() sets the stopped flag and leaves running unset."""
        assert not batch_processor.stopped

        await batch_processor.stop()

        assert batch_processor.stopped
        assert not batch_processor.running

    @pytest.mark.asyncio
    async def test_get_progress(self, batch_processor):
        """get_progress() on a fresh processor returns a zeroed BatchProgress."""
        snapshot = batch_processor.get_progress()

        assert isinstance(snapshot, BatchProgress)
        for counter in ("total_tasks", "completed_tasks", "failed_tasks"):
            assert getattr(snapshot, counter) == 0

    @pytest.mark.asyncio
    async def test_simple_batch_processing(self, batch_processor, mock_services):
        """Run start() over a single transcription task that succeeds."""
        await batch_processor._initialize_services()
        await batch_processor.add_task(
            TaskType.TRANSCRIBE,
            {"file_path": "/test/file.mp3"},
        )
        # The transcription service returns a successful result.
        mock_services['transcription'].transcribe_file.return_value = MagicMock(
            text_content="Test transcript",
            segments=[],
            accuracy=95.0,
            processing_time=10.0,
            quality_warnings=[],
        )

        summary = await batch_processor.start()

        assert summary.success_count == 1
        assert summary.failure_count == 0
        assert summary.total_count == 1
        assert summary.success_rate == 100.0
        assert len(summary.results) == 1
        assert len(summary.failures) == 0
class TestCreateBatchProcessor:
    """Checks on the create_batch_processor factory function."""

    def test_create_batch_processor_defaults(self):
        """Called with no arguments, the factory yields the expected default settings."""
        processor = create_batch_processor()

        expected_defaults = {
            "max_workers": 8,
            "queue_size": 1000,
            "progress_interval": 5.0,
            "memory_limit_mb": 2048.0,
            "cpu_limit_percent": 90.0,
        }
        for attribute, default in expected_defaults.items():
            assert getattr(processor, attribute) == default

    def test_create_batch_processor_custom(self):
        """Each keyword argument passed to the factory is read back from the processor."""
        overrides = {
            "max_workers": 4,
            "queue_size": 500,
            "progress_interval": 2.0,
            "memory_limit_mb": 1024.0,
            "cpu_limit_percent": 80.0,
        }

        processor = create_batch_processor(**overrides)

        for attribute, supplied in overrides.items():
            assert getattr(processor, attribute) == supplied
class TestBatchProcessorIntegration:
    """Integration tests for batch processor.

    Each test builds its own BatchProcessor and patches the four service
    factories in ``src.services.batch_processor`` for the whole run.
    """

    @pytest.mark.asyncio
    async def test_multiple_task_types(self):
        """Test processing multiple different task types.

        Queues one transcribe, one enhance and one download task against
        mocked services and expects all three to be counted as successes.
        """
        processor = BatchProcessor(max_workers=2, progress_interval=0.1)

        # Mock services
        with patch('src.services.batch_processor.create_transcription_service') as mock_trans, \
                patch('src.services.batch_processor.create_enhancement_service') as mock_enhance, \
                patch('src.services.batch_processor.create_media_service') as mock_media, \
                patch('src.services.batch_processor.create_media_repository') as mock_repo:
            mock_trans.return_value = AsyncMock()
            mock_enhance.return_value = AsyncMock()
            mock_media.return_value = AsyncMock()
            mock_repo.return_value = AsyncMock()

            # Mock results
            mock_trans.return_value.transcribe_file.return_value = MagicMock(
                text_content="Transcript",
                segments=[],
                accuracy=95.0,
                processing_time=10.0,
                quality_warnings=[],
            )
            mock_enhance.return_value.enhance_transcript.return_value = MagicMock(
                enhanced_content="Enhanced",
                accuracy_improvement=2.0,
                processing_time=5.0,
            )
            mock_media.return_value.download_media.return_value = MagicMock(
                file_path=Path("/test/file.mp3"),
                file_size=1024,
                duration=60.0,
            )

            # Add different types of tasks
            await processor.add_task(TaskType.TRANSCRIBE, {"file_path": "/test1.mp3"})
            await processor.add_task(TaskType.ENHANCE, {"transcript_id": "123"})
            await processor.add_task(TaskType.DOWNLOAD, {"url": "https://test.com"})

            # Process all tasks.
            # NOTE(review): services are not initialized explicitly here, so this
            # presumably relies on start() doing it while the patches are active —
            # confirm against BatchProcessor.start().
            result = await processor.start()

            assert result.success_count == 3
            assert result.failure_count == 0
            assert result.total_count == 3
            assert result.success_rate == 100.0

    @pytest.mark.asyncio
    async def test_progress_callback(self):
        """Test progress callback functionality.

        Runs a single transcription task with a callback that records every
        progress object it receives, then checks the last one recorded.
        """
        processor = BatchProcessor(max_workers=1, progress_interval=0.1)
        progress_updates = []

        def progress_callback(progress: BatchProgress):
            progress_updates.append(progress)

        # Mock services
        with patch('src.services.batch_processor.create_transcription_service') as mock_trans, \
                patch('src.services.batch_processor.create_enhancement_service') as mock_enhance, \
                patch('src.services.batch_processor.create_media_service') as mock_media, \
                patch('src.services.batch_processor.create_media_repository') as mock_repo:
            mock_trans.return_value = AsyncMock()
            mock_enhance.return_value = AsyncMock()
            mock_media.return_value = AsyncMock()
            mock_repo.return_value = AsyncMock()
            mock_trans.return_value.transcribe_file.return_value = MagicMock(
                text_content="Test",
                segments=[],
                accuracy=95.0,
                processing_time=10.0,
                quality_warnings=[],
            )

            # Add a task
            await processor.add_task(TaskType.TRANSCRIBE, {"file_path": "/test.mp3"})

            # Process with callback. The value returned by start() is not
            # inspected here; only what the callback observed is checked.
            await processor.start(progress_callback=progress_callback)

            # Should have received progress updates
            assert len(progress_updates) > 0

            # Check final progress
            final_progress = progress_updates[-1]
            assert final_progress.total_tasks == 1
            assert final_progress.completed_tasks == 1
            assert final_progress.failed_tasks == 0
            assert final_progress.success_rate == 100.0