"""
|
|
Unit tests for batch processing system.
|
|
|
|
Tests cover:
|
|
- Async worker pool functionality
|
|
- Queue management and priority handling
|
|
- Progress tracking and reporting
|
|
- Error recovery and retry logic
|
|
- Resource monitoring
|
|
- Task processing for different types
|
|
- Pause/resume functionality
|
|
"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
import tempfile
|
|
import os
|
|
|
|
from src.services.batch_processor import (
|
|
BatchProcessor,
|
|
BatchTask,
|
|
BatchProgress,
|
|
BatchResult,
|
|
TaskType,
|
|
create_batch_processor
|
|
)
|
|
from src.services.transcription_service import TranscriptionConfig
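
# Note: the @pytest.mark.asyncio tests below rely on the pytest-asyncio plugin.


# Illustrative helper: several tests build MagicMock transcription results with
# the same attribute set, and a small factory like this is one way to avoid
# that duplication. The attribute names mirror the mocks used in the tests
# rather than a verified TranscriptionService interface.
def make_mock_transcription_result(text="Test transcript", accuracy=95.0,
                                   processing_time=10.0):
    """Build a MagicMock shaped like a transcription result."""
    result = MagicMock()
    result.text_content = text
    result.segments = []
    result.accuracy = accuracy
    result.processing_time = processing_time
    result.quality_warnings = []
    return result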


class TestBatchTask:
    """Test BatchTask dataclass functionality."""

    def test_batch_task_creation(self):
        """Test creating a batch task with all fields."""
        task = BatchTask(
            id="test_task_1",
            task_type=TaskType.TRANSCRIBE,
            data={"file_path": "/test/file.mp3"},
            priority=1,
            max_retries=5
        )

        assert task.id == "test_task_1"
        assert task.task_type == TaskType.TRANSCRIBE
        assert task.data["file_path"] == "/test/file.mp3"
        assert task.priority == 1
        assert task.max_retries == 5
        assert task.retry_count == 0
        assert task.created_at is not None
        assert task.started_at is None
        assert task.completed_at is None
        assert task.error is None
        assert task.result is None

    def test_batch_task_defaults(self):
        """Test batch task creation with default values."""
        task = BatchTask(
            id="test_task_2",
            task_type=TaskType.ENHANCE,
            data={"transcript_id": "123"}
        )

        assert task.priority == 0
        assert task.max_retries == 3
        assert task.retry_count == 0


class TestBatchProgress:
    """Test BatchProgress dataclass functionality."""

    def test_batch_progress_creation(self):
        """Test creating batch progress with initial values."""
        progress = BatchProgress(total_tasks=10)

        assert progress.total_tasks == 10
        assert progress.completed_tasks == 0
        assert progress.failed_tasks == 0
        assert progress.in_progress_tasks == 0
        assert progress.queued_tasks == 0
        assert progress.start_time is None
        assert progress.estimated_completion is None
        assert progress.current_worker_count == 0
        assert progress.memory_usage_mb == 0.0
        assert progress.cpu_usage_percent == 0.0

    def test_success_rate_calculation(self):
        """Test success rate calculation."""
        progress = BatchProgress(total_tasks=10)
        progress.completed_tasks = 7
        progress.failed_tasks = 2

        assert progress.success_rate == 70.0

    def test_success_rate_zero_total(self):
        """Test success rate with zero total tasks."""
        progress = BatchProgress(total_tasks=0)
        assert progress.success_rate == 0.0

    def test_failure_rate_calculation(self):
        """Test failure rate calculation."""
        progress = BatchProgress(total_tasks=10)
        progress.failed_tasks = 3

        assert progress.failure_rate == 30.0

    def test_elapsed_time_calculation(self):
        """Test elapsed time calculation."""
        start_time = datetime.now(timezone.utc)
        progress = BatchProgress(total_tasks=5)
        progress.start_time = start_time

        # Should be close to 0 since we just set it
        elapsed = progress.elapsed_time
        assert elapsed is not None
        assert elapsed >= 0.0
        assert elapsed < 1.0  # Should be very small

    def test_elapsed_time_no_start(self):
        """Test elapsed time when start_time is None."""
        progress = BatchProgress(total_tasks=5)
        assert progress.elapsed_time is None


class TestBatchResult:
    """Test BatchResult dataclass functionality."""

    def test_batch_result_creation(self):
        """Test creating batch result with all fields."""
        result = BatchResult(
            success_count=8,
            failure_count=2,
            total_count=10,
            results=[{"status": "completed"}],
            failures=[{"task_id": "1", "error": "test error"}],
            processing_time=120.5,
            memory_peak_mb=512.0,
            cpu_peak_percent=75.0,
            quality_metrics={"avg_accuracy": 95.5}
        )

        assert result.success_count == 8
        assert result.failure_count == 2
        assert result.total_count == 10
        assert len(result.results) == 1
        assert len(result.failures) == 1
        assert result.processing_time == 120.5
        assert result.memory_peak_mb == 512.0
        assert result.cpu_peak_percent == 75.0
        assert result.quality_metrics["avg_accuracy"] == 95.5

    def test_success_rate_calculation(self):
        """Test success rate calculation in batch result."""
        result = BatchResult(
            success_count=9,
            failure_count=1,
            total_count=10,
            results=[],
            failures=[],
            processing_time=0.0,
            memory_peak_mb=0.0,
            cpu_peak_percent=0.0,
            quality_metrics={}
        )

        assert result.success_rate == 90.0

    def test_success_rate_zero_total(self):
        """Test success rate with zero total count."""
        result = BatchResult(
            success_count=0,
            failure_count=0,
            total_count=0,
            results=[],
            failures=[],
            processing_time=0.0,
            memory_peak_mb=0.0,
            cpu_peak_percent=0.0,
            quality_metrics={}
        )

        assert result.success_rate == 0.0


class TestBatchProcessor:
    """Test BatchProcessor functionality."""

    @pytest.fixture
    def batch_processor(self):
        """Create a batch processor for testing."""
        return BatchProcessor(max_workers=2, progress_interval=0.1)

    @pytest.fixture
    def mock_services(self):
        """Mock all required services."""
        with patch('src.services.batch_processor.create_transcription_service') as mock_trans, \
             patch('src.services.batch_processor.create_enhancement_service') as mock_enhance, \
             patch('src.services.batch_processor.create_media_service') as mock_media, \
             patch('src.services.batch_processor.create_media_repository') as mock_repo:

            mock_trans.return_value = AsyncMock()
            mock_enhance.return_value = AsyncMock()
            mock_media.return_value = AsyncMock()
            mock_repo.return_value = AsyncMock()

            yield {
                'transcription': mock_trans.return_value,
                'enhancement': mock_enhance.return_value,
                'media': mock_media.return_value,
                'repository': mock_repo.return_value
            }

    @pytest.mark.asyncio
    async def test_batch_processor_initialization(self, batch_processor):
        """Test batch processor initialization."""
        assert batch_processor.max_workers == 2
        assert batch_processor.progress_interval == 0.1
        assert not batch_processor.running
        assert not batch_processor.paused
        assert not batch_processor.stopped
        assert batch_processor.progress.total_tasks == 0
        assert len(batch_processor.workers) == 0

    @pytest.mark.asyncio
    async def test_add_task(self, batch_processor):
        """Test adding tasks to the queue."""
        task_id = await batch_processor.add_task(
            TaskType.TRANSCRIBE,
            {"file_path": "/test/file.mp3"},
            priority=1
        )

        assert task_id.startswith("task_1_transcribe")
        assert batch_processor.progress.total_tasks == 1
        assert batch_processor.progress.queued_tasks == 1
        assert not batch_processor.task_queue.empty()

    @pytest.mark.asyncio
    async def test_add_multiple_tasks(self, batch_processor):
        """Test adding multiple tasks with different priorities."""
        # Add tasks with different priorities
        await batch_processor.add_task(TaskType.TRANSCRIBE, {"file": "1.mp3"}, priority=2)
        await batch_processor.add_task(TaskType.ENHANCE, {"id": "123"}, priority=1)
        await batch_processor.add_task(TaskType.YOUTUBE, {"url": "test.com"}, priority=0)

        assert batch_processor.progress.total_tasks == 3
        assert batch_processor.progress.queued_tasks == 3

        # Check that tasks are ordered by priority (lower = higher priority)
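        # Draining the queue here assumes task_queue is an asyncio.PriorityQueue
        # holding (priority, task) tuples, so lower numbers come out first.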
        tasks = []
        while not batch_processor.task_queue.empty():
            priority, task = await batch_processor.task_queue.get()
            tasks.append((priority, task.task_type))

        # Should be ordered by priority (0, 1, 2)
        assert tasks[0][0] == 0  # YouTube task
        assert tasks[1][0] == 1  # Enhance task
        assert tasks[2][0] == 2  # Transcribe task

    @pytest.mark.asyncio
    async def test_initialize_services(self, batch_processor, mock_services):
        """Test service initialization."""
        await batch_processor._initialize_services()

        assert batch_processor.transcription_service is not None
        assert batch_processor.enhancement_service is not None
        assert batch_processor.media_service is not None

        # Verify services were initialized
        mock_services['transcription'].initialize.assert_called_once()

    @pytest.mark.asyncio
    async def test_process_transcription_task(self, batch_processor, mock_services):
        """Test processing a transcription task."""
        await batch_processor._initialize_services()

        task = BatchTask(
            id="test_task",
            task_type=TaskType.TRANSCRIBE,
            data={
                "file_path": "/test/file.mp3",
                "config": {"model": "whisper-1"}
            }
        )

        # Mock transcription result
        mock_result = MagicMock()
        mock_result.text_content = "Test transcript"
        mock_result.segments = [{"text": "Test", "start": 0, "end": 1}]
        mock_result.accuracy = 95.5
        mock_result.processing_time = 10.0
        mock_result.quality_warnings = []

        mock_services['transcription'].transcribe_file.return_value = mock_result

        result = await batch_processor._process_transcription(task)

        assert result["status"] == "completed"
        assert result["file_path"] == "/test/file.mp3"
        assert result["transcript"] == "Test transcript"
        assert result["accuracy"] == 95.5
        assert result["processing_time"] == 10.0

        mock_services['transcription'].transcribe_file.assert_called_once()

    @pytest.mark.asyncio
    async def test_process_enhancement_task(self, batch_processor, mock_services):
        """Test processing an enhancement task."""
        await batch_processor._initialize_services()

        task = BatchTask(
            id="test_task",
            task_type=TaskType.ENHANCE,
            data={"transcript_id": "123"}
        )

        # Mock enhancement result
        mock_result = MagicMock()
        mock_result.enhanced_content = "Enhanced transcript"
        mock_result.accuracy_improvement = 2.5
        mock_result.processing_time = 5.0

        mock_services['enhancement'].enhance_transcript.return_value = mock_result

        result = await batch_processor._process_enhancement(task)

        assert result["status"] == "completed"
        assert result["transcript_id"] == "123"
        assert result["enhanced_content"] == "Enhanced transcript"
        assert result["accuracy_improvement"] == 2.5

        mock_services['enhancement'].enhance_transcript.assert_called_once_with("123")

    @pytest.mark.asyncio
    async def test_task_retry_on_failure(self, batch_processor, mock_services):
        """Test task retry mechanism on failure."""
        await batch_processor._initialize_services()

        task = BatchTask(
            id="test_task",
            task_type=TaskType.TRANSCRIBE,
            data={"file_path": "/test/file.mp3"},
            max_retries=2
        )

        # Mock service to fail twice, then succeed
        mock_services['transcription'].transcribe_file.side_effect = [
            Exception("First failure"),
            Exception("Second failure"),
            MagicMock(text_content="Success", segments=[], accuracy=95.0,
                      processing_time=10.0, quality_warnings=[]),
        ]
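
        # The assertions below assume _process_transcription catches the
        # exception itself and reports a "retrying" status with an incremented
        # retry_count until max_retries is exhausted.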

        # First attempt should fail and retry
        result1 = await batch_processor._process_transcription(task)
        assert result1["status"] == "retrying"
        assert result1["retry_count"] == 1

        # Second attempt should fail and retry
        result2 = await batch_processor._process_transcription(task)
        assert result2["status"] == "retrying"
        assert result2["retry_count"] == 2

        # Third attempt should succeed
        result3 = await batch_processor._process_transcription(task)
        assert result3["status"] == "completed"

    @pytest.mark.asyncio
    async def test_task_permanent_failure(self, batch_processor, mock_services):
        """Test task permanent failure after max retries."""
        await batch_processor._initialize_services()

        task = BatchTask(
            id="test_task",
            task_type=TaskType.TRANSCRIBE,
            data={"file_path": "/test/file.mp3"},
            max_retries=1
        )

        # Mock service to always fail
        mock_services['transcription'].transcribe_file.side_effect = Exception("Permanent failure")

        # First attempt should retry
        result1 = await batch_processor._process_transcription(task)
        assert result1["status"] == "retrying"

        # Second attempt should fail permanently
        result2 = await batch_processor._process_transcription(task)
        assert result2["status"] == "failed"
        assert "Permanent failure" in result2["error"]

        # Task should be in failed tasks list
        assert len(batch_processor.failed_tasks) == 1
        assert batch_processor.failed_tasks[0].id == "test_task"

    @pytest.mark.asyncio
    async def test_pause_resume_functionality(self, batch_processor):
        """Test pause and resume functionality."""
        assert not batch_processor.paused

        # Pause when not running should do nothing
        await batch_processor.pause()
        assert not batch_processor.paused

        # Start the processor
        batch_processor.running = True

        # Pause
        await batch_processor.pause()
        assert batch_processor.paused

        # Resume
        await batch_processor.resume()
        assert not batch_processor.paused

    @pytest.mark.asyncio
    async def test_stop_functionality(self, batch_processor):
        """Test stop functionality."""
        assert not batch_processor.stopped

        await batch_processor.stop()
        assert batch_processor.stopped
        assert not batch_processor.running

    @pytest.mark.asyncio
    async def test_get_progress(self, batch_processor):
        """Test getting current progress."""
        progress = batch_processor.get_progress()

        assert isinstance(progress, BatchProgress)
        assert progress.total_tasks == 0
        assert progress.completed_tasks == 0
        assert progress.failed_tasks == 0

    @pytest.mark.asyncio
    async def test_simple_batch_processing(self, batch_processor, mock_services):
        """Test simple batch processing with one task."""
        await batch_processor._initialize_services()

        # Add a task
        await batch_processor.add_task(
            TaskType.TRANSCRIBE,
            {"file_path": "/test/file.mp3"}
        )

        # Mock successful transcription
        mock_result = MagicMock()
        mock_result.text_content = "Test transcript"
        mock_result.segments = []
        mock_result.accuracy = 95.0
        mock_result.processing_time = 10.0
        mock_result.quality_warnings = []

        mock_services['transcription'].transcribe_file.return_value = mock_result

        # Start processing
        result = await batch_processor.start()

        assert result.success_count == 1
        assert result.failure_count == 0
        assert result.total_count == 1
        assert result.success_rate == 100.0
        assert len(result.results) == 1
        assert len(result.failures) == 0


class TestCreateBatchProcessor:
    """Test batch processor factory function."""

    def test_create_batch_processor_defaults(self):
        """Test creating batch processor with default parameters."""
        processor = create_batch_processor()

        assert processor.max_workers == 8
        assert processor.queue_size == 1000
        assert processor.progress_interval == 5.0
        assert processor.memory_limit_mb == 2048.0
        assert processor.cpu_limit_percent == 90.0

    def test_create_batch_processor_custom(self):
        """Test creating batch processor with custom parameters."""
        processor = create_batch_processor(
            max_workers=4,
            queue_size=500,
            progress_interval=2.0,
            memory_limit_mb=1024.0,
            cpu_limit_percent=80.0
        )

        assert processor.max_workers == 4
        assert processor.queue_size == 500
        assert processor.progress_interval == 2.0
        assert processor.memory_limit_mb == 1024.0
        assert processor.cpu_limit_percent == 80.0


class TestBatchProcessorIntegration:
    """Integration tests for batch processor."""

    @pytest.mark.asyncio
    async def test_multiple_task_types(self):
        """Test processing multiple different task types."""
        processor = BatchProcessor(max_workers=2, progress_interval=0.1)

        # Mock services
        with patch('src.services.batch_processor.create_transcription_service') as mock_trans, \
             patch('src.services.batch_processor.create_enhancement_service') as mock_enhance, \
             patch('src.services.batch_processor.create_media_service') as mock_media, \
             patch('src.services.batch_processor.create_media_repository') as mock_repo:

            mock_trans.return_value = AsyncMock()
            mock_enhance.return_value = AsyncMock()
            mock_media.return_value = AsyncMock()
            mock_repo.return_value = AsyncMock()

            # Mock results
            mock_trans.return_value.transcribe_file.return_value = MagicMock(
                text_content="Transcript", segments=[], accuracy=95.0,
                processing_time=10.0, quality_warnings=[]
            )
            mock_enhance.return_value.enhance_transcript.return_value = MagicMock(
                enhanced_content="Enhanced", accuracy_improvement=2.0, processing_time=5.0
            )
            mock_media.return_value.download_media.return_value = MagicMock(
                file_path=Path("/test/file.mp3"), file_size=1024, duration=60.0
            )

            # Add different types of tasks
            await processor.add_task(TaskType.TRANSCRIBE, {"file_path": "/test1.mp3"})
            await processor.add_task(TaskType.ENHANCE, {"transcript_id": "123"})
            await processor.add_task(TaskType.DOWNLOAD, {"url": "https://test.com"})

            # Process all tasks
            result = await processor.start()

            assert result.success_count == 3
            assert result.failure_count == 0
            assert result.total_count == 3
            assert result.success_rate == 100.0

    @pytest.mark.asyncio
    async def test_progress_callback(self):
        """Test progress callback functionality."""
        processor = BatchProcessor(max_workers=1, progress_interval=0.1)

        progress_updates = []

        def progress_callback(progress: BatchProgress):
            progress_updates.append(progress)
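
        # Assumes start() invokes the callback periodically (roughly every
        # progress_interval seconds) and at least once before returning.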

        # Mock services
        with patch('src.services.batch_processor.create_transcription_service') as mock_trans, \
             patch('src.services.batch_processor.create_enhancement_service') as mock_enhance, \
             patch('src.services.batch_processor.create_media_service') as mock_media, \
             patch('src.services.batch_processor.create_media_repository') as mock_repo:

            mock_trans.return_value = AsyncMock()
            mock_enhance.return_value = AsyncMock()
            mock_media.return_value = AsyncMock()
            mock_repo.return_value = AsyncMock()

            mock_trans.return_value.transcribe_file.return_value = MagicMock(
                text_content="Test", segments=[], accuracy=95.0,
                processing_time=10.0, quality_warnings=[]
            )

            # Add a task
            await processor.add_task(TaskType.TRANSCRIBE, {"file_path": "/test.mp3"})

            # Process with callback
            result = await processor.start(progress_callback=progress_callback)

            # Should have received progress updates
            assert len(progress_updates) > 0

            # Check final progress
            final_progress = progress_updates[-1]
            assert final_progress.total_tasks == 1
            assert final_progress.completed_tasks == 1
            assert final_progress.failed_tasks == 0
            assert final_progress.success_rate == 100.0
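

# Typical invocation (the path is illustrative; adjust to the repository's
# test layout):
#     pytest tests/test_batch_processor.py -v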