#!/usr/bin/env python3
"""
Test Parallel Chunk Processing for M3 Transcription Optimization.

Following TDD principles, these tests were written BEFORE the implementation;
they define the expected behavior of the parallel processing system.
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
import time
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import List, Dict
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
|
|
# Import the classes we will implement
|
|
from src.services.parallel_transcription import (
|
|
ParallelTranscriber,
|
|
TranscriptionResult,
|
|
ChunkResult
|
|
)
|
|
|
|
|
|
class TestParallelProcessing:
    """Test suite for parallel chunk processing - 2-4x speed improvement."""

    @pytest.fixture
    def sample_audio_30s(self):
        """Real 30-second audio file for testing."""
        return Path("tests/fixtures/audio/sample_30s.wav")

    @pytest.fixture
    def sample_audio_2m(self):
        """Real 2-minute audio file for testing."""
        return Path("tests/fixtures/audio/sample_2m.wav")

    @pytest.fixture
    def sample_audio_5m(self):
        """Real 5-minute audio file for testing."""
        return Path("tests/fixtures/audio/sample_5m.wav")

    @pytest.fixture
    def mock_whisper_model(self):
        """Mock Whisper model for testing without actual ML inference."""
        model = MagicMock()
        model.transcribe = MagicMock(return_value={"text": "Test transcription"})
        return model

    @pytest.mark.asyncio
    async def test_parallel_faster_than_sequential(self, sample_audio_2m):
        """Test that parallel processing is 2-4x faster than sequential."""
        transcriber = ParallelTranscriber(max_workers=4, chunk_size_seconds=30)

        # NOTE(review): the sequential run goes first. If the transcriber loads
        # its model lazily, that one-off cost lands in ``sequential_time`` and
        # inflates the apparent speedup; consider a warm-up call.
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        seq_result = await transcriber.transcribe_sequential(sample_audio_2m)
        sequential_time = time.perf_counter() - start

        start = time.perf_counter()
        par_result = await transcriber.transcribe_parallel(sample_audio_2m)
        parallel_time = time.perf_counter() - start

        assert seq_result.text == par_result.text  # Same output
        assert parallel_time < sequential_time * 0.5  # At least 2x faster
        assert len(par_result.chunks) >= 4  # Used multiple chunks
        assert par_result.speedup_factor >= 2.0  # Documented speedup

    @pytest.mark.asyncio
    async def test_chunk_splitting_logic(self):
        """Test audio is correctly split into overlapping chunks."""
        transcriber = ParallelTranscriber(
            max_workers=4,
            chunk_size_seconds=30,
            overlap_seconds=2,
        )

        # Synthetic 2-minute (120 s) signal. A seeded generator keeps the
        # input reproducible; only its length matters for splitting.
        sample_rate = 16000
        total_seconds = 120
        rng = np.random.default_rng(0)
        audio_array = rng.standard_normal(sample_rate * total_seconds).astype(np.float32)

        chunks = await transcriber._split_audio(audio_array, sample_rate)

        assert len(chunks) > 1  # Multiple chunks created

        for chunk in chunks:
            for key in ("audio", "start_time", "end_time", "chunk_id"):
                assert key in chunk

        # Check consecutive pairs only: the last chunk may be shorter and has
        # no successor to overlap with.
        for chunk, next_chunk in zip(chunks, chunks[1:]):
            chunk_duration = chunk["end_time"] - chunk["start_time"]
            assert 28 <= chunk_duration <= 30  # Approximately chunk_size_seconds

            chunk_overlap = chunk["end_time"] - next_chunk["start_time"]
            assert 1.5 <= chunk_overlap <= 2.5  # Approximately overlap_seconds

    @pytest.mark.asyncio
    async def test_chunk_merging_handles_overlaps(self):
        """Test that overlapping transcriptions are merged correctly."""
        transcriber = ParallelTranscriber()

        # Create overlapping chunk results
        chunks = [
            ChunkResult(
                text="This is the first chunk of text.",
                start_time=0.0,
                end_time=10.0,
                chunk_id=0,
            ),
            ChunkResult(
                text="chunk of text. This is the second",
                start_time=8.0,
                end_time=18.0,
                chunk_id=1,
            ),
            ChunkResult(
                text="the second chunk with more content.",
                start_time=16.0,
                end_time=26.0,
                chunk_id=2,
            ),
        ]

        merged_text = await transcriber._merge_transcriptions(chunks)

        # Should intelligently merge overlapping text
        expected = "This is the first chunk of text. This is the second chunk with more content."
        assert merged_text == expected

    @pytest.mark.asyncio
    async def test_semaphore_limits_concurrent_workers(self):
        """Test that semaphore properly limits concurrent processing."""
        max_workers = 2
        transcriber = ParallelTranscriber(max_workers=max_workers)

        # Track concurrent executions
        concurrent_count = 0
        max_concurrent = 0

        async def mock_process_chunk(chunk):
            nonlocal concurrent_count, max_concurrent
            # No await sits between these reads and writes, so they cannot
            # interleave on the single-threaded event loop; no lock needed.
            concurrent_count += 1
            max_concurrent = max(max_concurrent, concurrent_count)
            try:
                await asyncio.sleep(0.1)  # Simulate processing
            finally:
                concurrent_count -= 1

            return ChunkResult(
                text=f"Chunk {chunk['chunk_id']}",
                start_time=chunk["start_time"],
                end_time=chunk["end_time"],
                chunk_id=chunk["chunk_id"],
            )

        # Replace process method with mock
        transcriber._process_chunk = mock_process_chunk

        chunks = [
            {"chunk_id": i, "start_time": i * 10, "end_time": (i + 1) * 10}
            for i in range(6)
        ]

        # Go through the transcriber's own fan-out so its semaphore is in
        # play. Gathering the mock directly would bypass the limit entirely.
        results = await transcriber._process_chunks_parallel(chunks)

        assert len(results) == len(chunks)
        # Limit respected...
        assert max_concurrent <= max_workers
        # ...and chunks genuinely overlapped rather than running one by one.
        assert max_concurrent > 1

    @pytest.mark.asyncio
    async def test_memory_usage_under_2gb(self, sample_audio_5m):
        """Test that memory usage stays under 2GB target."""
        import gc

        # psutil is an optional third-party dependency; skip cleanly without it.
        psutil = pytest.importorskip("psutil")

        gc.collect()
        process = psutil.Process()
        baseline_memory = process.memory_info().rss / (1024 * 1024)  # MB

        transcriber = ParallelTranscriber(max_workers=4)
        result = await transcriber.transcribe_parallel(sample_audio_5m)

        # NOTE: this is the RSS *after* the run, not the true peak during it;
        # a transient spike freed before returning would go unnoticed here.
        final_memory = process.memory_info().rss / (1024 * 1024)  # MB
        memory_used = final_memory - baseline_memory

        # Should stay well under 2GB (2048 MB)
        assert memory_used < 2048
        assert result.memory_usage_mb < 2048

    @pytest.mark.asyncio
    async def test_handles_chunk_failures_gracefully(self):
        """Test error handling when a chunk fails to process."""
        transcriber = ParallelTranscriber(max_workers=2)

        # Mock process to fail on a specific chunk
        async def mock_process(chunk):
            if chunk["chunk_id"] == 2:
                raise RuntimeError("Processing failed for chunk 2")
            return ChunkResult(
                text=f"Chunk {chunk['chunk_id']}",
                start_time=chunk["start_time"],
                end_time=chunk["end_time"],
                chunk_id=chunk["chunk_id"],
            )

        transcriber._process_chunk = mock_process

        chunks = [
            {"chunk_id": i, "start_time": i * 10, "end_time": (i + 1) * 10}
            for i in range(4)
        ]

        # Should handle failure and continue with other chunks
        results = await transcriber._process_chunks_parallel(chunks)

        assert len(results) == 3  # One chunk failed
        assert all(r.chunk_id != 2 for r in results)  # Chunk 2 missing

    @pytest.mark.asyncio
    async def test_adaptive_chunk_sizing(self):
        """Test that chunk size adapts based on audio characteristics."""
        transcriber = ParallelTranscriber(adaptive_chunking=True)

        # Short audio should use smaller chunks
        short_chunks = await transcriber._determine_chunk_size(duration_seconds=30)
        assert short_chunks <= 15

        # Long audio (10 minutes) should use larger chunks
        long_chunks = await transcriber._determine_chunk_size(duration_seconds=600)
        assert long_chunks >= 30

    @pytest.mark.asyncio
    async def test_performance_metrics_accurate(self, sample_audio_30s):
        """Test that performance metrics are accurately reported."""
        transcriber = ParallelTranscriber(max_workers=2)

        # perf_counter: monotonic clock suited to measuring intervals.
        start = time.perf_counter()
        result = await transcriber.transcribe_parallel(sample_audio_30s)
        actual_time = time.perf_counter() - start

        assert result.processing_time > 0
        assert abs(result.processing_time - actual_time) < 0.1  # Within 100ms
        assert result.chunks_processed >= 1
        assert result.speedup_factor >= 1.0
        assert result.worker_utilization > 0

    @pytest.mark.asyncio
    async def test_maintains_transcription_quality(self, sample_audio_30s):
        """Test that parallel processing maintains transcription accuracy."""
        from difflib import SequenceMatcher

        transcriber = ParallelTranscriber(max_workers=4)

        # Sequential result is the baseline the parallel output is compared to
        seq_result = await transcriber.transcribe_sequential(sample_audio_30s)
        par_result = await transcriber.transcribe_parallel(sample_audio_30s)

        similarity = SequenceMatcher(None, seq_result.text, par_result.text).ratio()
        assert similarity > 0.95  # At least 95% similar

    @pytest.mark.asyncio
    async def test_cli_integration(self, sample_audio_2m):
        """Test that parallel processing integrates with CLI properly."""
        from src.cli.main import transcribe_command

        # Mock the CLI context so the command uses our transcriber
        with patch("src.cli.main.get_transcriber") as mock_get:
            transcriber = ParallelTranscriber(max_workers=4)
            mock_get.return_value = transcriber

            # Run CLI command with parallel flag
            result = await transcribe_command(
                audio_path=str(sample_audio_2m),
                parallel=True,
                chunks=4,
                show_progress=True,
            )

            assert result.success
            assert "Speedup" in result.message
            assert result.speedup_factor >= 2.0
|
|
|
|
|
|
class TestPerformanceBenchmarks:
    """Performance benchmarks to validate 2-4x speed improvement."""

    # Fixtures declared on TestParallelProcessing are scoped to that class and
    # are not visible here, so this class declares the ones it needs.
    @pytest.fixture
    def sample_audio_30s(self):
        """Real 30-second audio file for benchmarking."""
        return Path("tests/fixtures/audio/sample_30s.wav")

    @pytest.fixture
    def sample_audio_5m(self):
        """Real 5-minute audio file for benchmarking."""
        return Path("tests/fixtures/audio/sample_5m.wav")

    @staticmethod
    def _benchmark_parallel(benchmark, transcriber, audio_path):
        """Time one complete ``transcribe_parallel`` run and return its result.

        The ``benchmark`` fixture is synchronous. Passing it the coroutine
        function directly would time only coroutine *creation*. Driving the
        coroutine with ``asyncio.run`` inside the timed callable measures the
        real transcription. A single round keeps multi-minute runs from being
        repeated by the benchmark's calibration.
        """
        return benchmark.pedantic(
            lambda: asyncio.run(transcriber.transcribe_parallel(audio_path)),
            rounds=1,
            iterations=1,
        )

    @pytest.mark.benchmark
    def test_benchmark_30s_audio(self, benchmark, sample_audio_30s):
        """Benchmark 30-second audio processing."""
        transcriber = ParallelTranscriber(max_workers=4)

        result = self._benchmark_parallel(benchmark, transcriber, sample_audio_30s)

        assert result.processing_time < 15  # Should process in <15s

    @pytest.mark.benchmark
    def test_benchmark_5m_audio(self, benchmark, sample_audio_5m):
        """Benchmark 5-minute audio - should meet <30s target."""
        transcriber = ParallelTranscriber(max_workers=4)

        result = self._benchmark_parallel(benchmark, transcriber, sample_audio_5m)

        # Must meet v1 target: 5-minute audio in <30 seconds
        assert result.processing_time < 30
        assert result.speedup_factor >= 2.0