trax/tests/test_parallel_processing.py

330 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Test Parallel Chunk Processing for M3 Transcription Optimization.
Following TDD principles - tests written BEFORE implementation.
These tests define the expected behavior of the parallel processing system.
"""
import pytest
import asyncio
import time
import numpy as np
from pathlib import Path
from typing import List, Dict
from unittest.mock import MagicMock, AsyncMock, patch
# Import the classes we will implement
from src.services.parallel_transcription import (
ParallelTranscriber,
TranscriptionResult,
ChunkResult
)
class TestParallelProcessing:
    """Test suite for parallel chunk processing - 2-4x speed improvement."""

    # Resolve fixture audio relative to this test file instead of the current
    # working directory, so the suite passes wherever pytest is launched from.
    FIXTURE_AUDIO_DIR = Path(__file__).resolve().parent / "fixtures" / "audio"

    @pytest.fixture
    def sample_audio_30s(self):
        """Real 30-second audio file for testing."""
        return self.FIXTURE_AUDIO_DIR / "sample_30s.wav"

    @pytest.fixture
    def sample_audio_2m(self):
        """Real 2-minute audio file for testing."""
        return self.FIXTURE_AUDIO_DIR / "sample_2m.wav"

    @pytest.fixture
    def sample_audio_5m(self):
        """Real 5-minute audio file for testing."""
        return self.FIXTURE_AUDIO_DIR / "sample_5m.wav"

    @pytest.fixture
    def mock_whisper_model(self):
        """Mock Whisper model for testing without actual ML inference."""
        model = MagicMock()
        model.transcribe = MagicMock(return_value={"text": "Test transcription"})
        return model

    @pytest.mark.asyncio
    async def test_parallel_faster_than_sequential(self, sample_audio_2m):
        """Test that parallel processing is 2-4x faster than sequential."""
        from difflib import SequenceMatcher

        transcriber = ParallelTranscriber(max_workers=4, chunk_size_seconds=30)

        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump, distorting duration measurements.
        start = time.perf_counter()
        seq_result = await transcriber.transcribe_sequential(sample_audio_2m)
        sequential_time = time.perf_counter() - start

        start = time.perf_counter()
        par_result = await transcriber.transcribe_parallel(sample_audio_2m)
        parallel_time = time.perf_counter() - start

        # Chunk boundaries can shift words slightly, so hold the output to the
        # same >95% similarity bar as test_maintains_transcription_quality
        # rather than requiring byte-identical text.
        similarity = SequenceMatcher(None, seq_result.text, par_result.text).ratio()
        assert similarity > 0.95
        assert parallel_time < sequential_time * 0.5  # At least 2x faster
        assert len(par_result.chunks) >= 4  # Used multiple chunks
        assert par_result.speedup_factor >= 2.0  # Documented speedup

    @pytest.mark.asyncio
    async def test_chunk_splitting_logic(self):
        """Test audio is correctly split into overlapping chunks."""
        chunk_size = 30
        overlap = 2
        transcriber = ParallelTranscriber(
            max_workers=4,
            chunk_size_seconds=chunk_size,
            overlap_seconds=overlap,
        )

        # Synthetic 2-minute (120 s) signal; a seeded generator keeps the
        # input identical from run to run.
        sample_rate = 16000
        total_seconds = 120
        rng = np.random.default_rng(0)
        audio_array = rng.standard_normal(sample_rate * total_seconds).astype(np.float32)

        chunks = await transcriber._split_audio(audio_array, sample_rate)

        assert len(chunks) > 1  # Multiple chunks created
        for chunk in chunks:
            for key in ("audio", "start_time", "end_time", "chunk_id"):
                assert key in chunk

        # The chunks must span the whole signal: nothing dropped at either end.
        assert chunks[0]["start_time"] == pytest.approx(0.0, abs=0.1)
        assert chunks[-1]["end_time"] == pytest.approx(total_seconds, abs=0.1)

        # Every chunk except the last is ~chunk_size long and overlaps its
        # successor by ~overlap seconds.
        for current, following in zip(chunks, chunks[1:]):
            chunk_duration = current["end_time"] - current["start_time"]
            assert chunk_size - 2 <= chunk_duration <= chunk_size
            actual_overlap = current["end_time"] - following["start_time"]
            assert overlap - 0.5 <= actual_overlap <= overlap + 0.5

    @pytest.mark.asyncio
    async def test_chunk_merging_handles_overlaps(self):
        """Test that overlapping transcriptions are merged correctly."""
        transcriber = ParallelTranscriber()

        # Consecutive chunks whose texts repeat the words in their time overlap.
        chunks = [
            ChunkResult(
                text="This is the first chunk of text.",
                start_time=0.0,
                end_time=10.0,
                chunk_id=0,
            ),
            ChunkResult(
                text="chunk of text. This is the second",
                start_time=8.0,
                end_time=18.0,
                chunk_id=1,
            ),
            ChunkResult(
                text="the second chunk with more content.",
                start_time=16.0,
                end_time=26.0,
                chunk_id=2,
            ),
        ]

        merged_text = await transcriber._merge_transcriptions(chunks)

        # Duplicated overlap text must appear exactly once in the merge.
        expected = "This is the first chunk of text. This is the second chunk with more content."
        assert merged_text == expected

    @pytest.mark.asyncio
    async def test_semaphore_limits_concurrent_workers(self):
        """Test that semaphore properly limits concurrent processing."""
        max_workers = 2
        transcriber = ParallelTranscriber(max_workers=max_workers)

        # Track how many mock chunk jobs are in flight at once. No lock is
        # needed: asyncio runs on one thread and there is no await between
        # reading and writing these counters.
        concurrent_count = 0
        max_concurrent = 0

        async def mock_process_chunk(chunk):
            nonlocal concurrent_count, max_concurrent
            concurrent_count += 1
            max_concurrent = max(max_concurrent, concurrent_count)
            try:
                await asyncio.sleep(0.1)  # Simulate processing
            finally:
                concurrent_count -= 1
            return ChunkResult(
                text=f"Chunk {chunk['chunk_id']}",
                start_time=chunk["start_time"],
                end_time=chunk["end_time"],
                chunk_id=chunk["chunk_id"],
            )

        transcriber._process_chunk = mock_process_chunk

        chunks = [
            {"chunk_id": i, "start_time": i * 10, "end_time": (i + 1) * 10}
            for i in range(6)
        ]

        # Drive the work through the transcriber's own fan-out so that its
        # semaphore is what is under test. Gathering transcriber._process_chunk
        # directly would bypass the semaphore and run all six jobs at once.
        results = await transcriber._process_chunks_parallel(chunks)

        assert len(results) == len(chunks)  # Every chunk was processed
        # The limit is saturated (work really ran in parallel) but never exceeded.
        assert max_concurrent == max_workers

    @pytest.mark.asyncio
    async def test_memory_usage_under_2gb(self, sample_audio_5m):
        """Test that peak memory usage stays under the 2GB target."""
        import gc
        import threading

        psutil = pytest.importorskip("psutil")

        gc.collect()
        process = psutil.Process()
        baseline_memory = process.memory_info().rss / (1024 * 1024)  # MB

        # RSS read after the run only shows what is still held, not the
        # high-water mark, so sample it on a background thread during the run.
        peak_rss = process.memory_info().rss
        stop_sampling = threading.Event()

        def sample_rss():
            nonlocal peak_rss
            while not stop_sampling.is_set():
                peak_rss = max(peak_rss, process.memory_info().rss)
                stop_sampling.wait(0.01)

        sampler = threading.Thread(target=sample_rss, daemon=True)
        sampler.start()
        try:
            transcriber = ParallelTranscriber(max_workers=4)
            result = await transcriber.transcribe_parallel(sample_audio_5m)
        finally:
            stop_sampling.set()
            sampler.join()
        peak_rss = max(peak_rss, process.memory_info().rss)

        peak_memory = peak_rss / (1024 * 1024)  # MB
        memory_used = peak_memory - baseline_memory

        # Should stay well under 2GB (2048 MB)
        assert memory_used < 2048
        assert result.memory_usage_mb < 2048

    @pytest.mark.asyncio
    async def test_handles_chunk_failures_gracefully(self):
        """Test error handling when a chunk fails to process."""
        transcriber = ParallelTranscriber(max_workers=2)

        # Mock processing that fails only for chunk 2.
        async def mock_process(chunk):
            if chunk["chunk_id"] == 2:
                raise Exception("Processing failed for chunk 2")
            return ChunkResult(
                text=f"Chunk {chunk['chunk_id']}",
                start_time=chunk["start_time"],
                end_time=chunk["end_time"],
                chunk_id=chunk["chunk_id"],
            )

        transcriber._process_chunk = mock_process

        chunks = [
            {"chunk_id": i, "start_time": i * 10, "end_time": (i + 1) * 10}
            for i in range(4)
        ]

        # Should handle failure and continue with other chunks
        results = await transcriber._process_chunks_parallel(chunks)

        assert len(results) == 3  # One chunk failed
        # Exactly the healthy chunks survive, and chunk 2 is absent.
        assert sorted(r.chunk_id for r in results) == [0, 1, 3]

    @pytest.mark.asyncio
    async def test_adaptive_chunk_sizing(self, sample_audio_2m):
        """Test that chunk size adapts based on audio characteristics."""
        transcriber = ParallelTranscriber(adaptive_chunking=True)

        # Short audio should use smaller chunks
        short_chunk_size = await transcriber._determine_chunk_size(
            duration_seconds=30
        )
        assert short_chunk_size <= 15

        # Long audio should use larger chunks
        long_chunk_size = await transcriber._determine_chunk_size(
            duration_seconds=600  # 10 minutes
        )
        assert long_chunk_size >= 30

    @pytest.mark.asyncio
    async def test_performance_metrics_accurate(self, sample_audio_30s):
        """Test that performance metrics are accurately reported."""
        transcriber = ParallelTranscriber(max_workers=2)

        # Monotonic clock for the reference duration.
        start = time.perf_counter()
        result = await transcriber.transcribe_parallel(sample_audio_30s)
        actual_time = time.perf_counter() - start

        assert result.processing_time > 0
        assert abs(result.processing_time - actual_time) < 0.1  # Within 100ms
        assert result.chunks_processed >= 1
        assert result.speedup_factor >= 1.0
        assert result.worker_utilization > 0

    @pytest.mark.asyncio
    async def test_maintains_transcription_quality(self, sample_audio_30s):
        """Test that parallel processing maintains transcription accuracy."""
        from difflib import SequenceMatcher

        transcriber = ParallelTranscriber(max_workers=4)

        # Sequential result is the baseline the parallel output is judged against.
        seq_result = await transcriber.transcribe_sequential(sample_audio_30s)
        par_result = await transcriber.transcribe_parallel(sample_audio_30s)

        similarity = SequenceMatcher(None, seq_result.text, par_result.text).ratio()
        assert similarity > 0.95  # At least 95% similar

    @pytest.mark.asyncio
    async def test_cli_integration(self, sample_audio_2m):
        """Test that parallel processing integrates with CLI properly."""
        from src.cli.main import transcribe_command

        # Mock the CLI context so the command uses our transcriber instance.
        with patch("src.cli.main.get_transcriber") as mock_get:
            transcriber = ParallelTranscriber(max_workers=4)
            mock_get.return_value = transcriber

            # Run CLI command with parallel flag
            result = await transcribe_command(
                audio_path=str(sample_audio_2m),
                parallel=True,
                chunks=4,
                show_progress=True,
            )

            assert result.success
            assert "Speedup" in result.message
            assert result.speedup_factor >= 2.0
class TestPerformanceBenchmarks:
    """Performance benchmarks to validate 2-4x speed improvement."""

    # pytest fixtures defined inside a class are only visible to that class,
    # so the audio fixtures of TestParallelProcessing cannot be used here;
    # declare the ones these benchmarks need. Paths are resolved relative to
    # this file so the benchmarks do not depend on the working directory.
    FIXTURE_AUDIO_DIR = Path(__file__).resolve().parent / "fixtures" / "audio"

    @pytest.fixture
    def sample_audio_30s(self):
        """Real 30-second audio file for benchmarking."""
        return self.FIXTURE_AUDIO_DIR / "sample_30s.wav"

    @pytest.fixture
    def sample_audio_5m(self):
        """Real 5-minute audio file for benchmarking."""
        return self.FIXTURE_AUDIO_DIR / "sample_5m.wav"

    @staticmethod
    def _run_parallel(benchmark, audio_path):
        """Benchmark one full ``transcribe_parallel`` run and return its result.

        pytest-benchmark times synchronous callables. Given a coroutine
        function, it would only time coroutine creation and hand back an
        un-awaited coroutine. Wrapping the call in ``asyncio.run`` makes the
        timed call perform the real transcription.

        A single round/iteration keeps multi-second transcriptions from being
        repeated many times.
        """
        transcriber = ParallelTranscriber(max_workers=4)

        def run_once():
            return asyncio.run(transcriber.transcribe_parallel(audio_path))

        return benchmark.pedantic(run_once, rounds=1, iterations=1)

    # These tests are deliberately synchronous: asyncio.run() cannot be called
    # from inside an already-running event loop.
    @pytest.mark.benchmark
    def test_benchmark_30s_audio(self, benchmark, sample_audio_30s):
        """Benchmark 30-second audio processing."""
        result = self._run_parallel(benchmark, sample_audio_30s)
        assert result.processing_time < 15  # Should process in <15s

    @pytest.mark.benchmark
    def test_benchmark_5m_audio(self, benchmark, sample_audio_5m):
        """Benchmark 5-minute audio - should meet <30s target."""
        result = self._run_parallel(benchmark, sample_audio_5m)

        # Must meet v1 target: 5-minute audio in <30 seconds
        assert result.processing_time < 30
        assert result.speedup_factor >= 2.0