trax/tests/test_multi_pass_integration.py

"""Integration tests for MultiPassTranscriptionPipeline with DomainEnhancementPipeline.
Tests the integration of the domain-specific enhancement pipeline with the
multi-pass transcription pipeline for Task 8.3.
"""
import pytest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
from src.services.domain_enhancement import (
    DomainEnhancementConfig,
    DomainEnhancementPipeline,
    EnhancementResult,
)


class TestMultiPassDomainEnhancementIntegration:
    """Test integration between MultiPassTranscriptionPipeline and DomainEnhancementPipeline."""

    @pytest.fixture
    def mock_model_manager(self):
        """Mock model manager for testing."""
        mock = MagicMock()
        mock.load_model.return_value = MagicMock()
        mock.get_current_model.return_value = MagicMock()
        return mock

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Mock domain adaptation manager for testing."""
        mock = MagicMock()
        mock.domain_detector = MagicMock()
        mock.domain_detector.detect_domain_from_text.return_value = "technical"
        mock.domain_detector.detect_domain_from_path.return_value = "technical"
        # Mock the domain adapter so adapter switching reports success
        mock.domain_adapter = MagicMock()
        mock.domain_adapter.domain_adapters = {"technical": MagicMock()}
        mock.domain_adapter.switch_adapter.return_value = MagicMock()
        return mock

    @pytest.fixture
    def mock_domain_enhancement_pipeline(self):
        """Mock domain enhancement pipeline for testing."""
        mock = AsyncMock(spec=DomainEnhancementPipeline)
        # Canned result returned by enhance_content
        mock_result = EnhancementResult(
            original_text="Test technical content",
            enhanced_text="Enhanced technical content with proper terminology",
            domain="technical",
            confidence_score=0.85,
            improvements=["Terminology enhancement", "Formatting optimization"],
            terminology_corrections=["tech -> technical", "algo -> algorithm"],
            quality_metrics={"technical_term_density": 0.8, "formatting_consistency": 0.9},
            processing_time=1.5,
        )
        mock.enhance_content.return_value = mock_result
        return mock

    @pytest.fixture
    def pipeline(self, mock_model_manager, mock_domain_adaptation_manager):
        """Create MultiPassTranscriptionPipeline instance for testing."""
        config = DomainEnhancementConfig(
            domain="technical",
            enable_terminology_enhancement=True,
            enable_citation_handling=True,
            enable_formatting_optimization=True,
            quality_threshold=0.7,
            max_enhancement_iterations=3,
        )
        return MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            domain_adapter=mock_domain_adaptation_manager,
            auto_detect_domain=True,
            domain_enhancement_config=config,
        )
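
    # All tests below feed segments in the shape produced by the transcription
    # passes, {"text": str, "start": float, "end": float}; after a successful
    # enhancement pass each segment is expected to gain a "domain" key plus
    # the "enhancement_*" metadata keys asserted in the tests.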

    @pytest.mark.asyncio
    async def test_domain_enhancement_pipeline_initialization(self, pipeline):
        """Test that the domain enhancement pipeline is lazily initialized on first use."""
        # The pipeline should be None initially
        assert pipeline.domain_enhancement_pipeline is None
        # Patch DomainEnhancementPipeline at its definition site; this assumes
        # MultiPassTranscriptionPipeline resolves the class at call time rather
        # than binding it at import time.
        with patch('src.services.domain_enhancement.DomainEnhancementPipeline', autospec=True) as mock_class:
            mock_instance = MagicMock()
            mock_class.return_value = mock_instance
            # Mock the enhance_content method to return a valid result
            mock_instance.enhance_content = AsyncMock(return_value=MagicMock(
                enhanced_text="Enhanced content",
                confidence_score=0.8,
                improvements=["Test improvement"],
                terminology_corrections=["Test correction"],
                quality_metrics={"test": 0.8},
            ))
            # Call the enhancement method to trigger lazy initialization
            segments = [{"text": "Test technical content", "start": 0.0, "end": 1.0}]
            await pipeline._perform_enhancement_pass(segments, domain="technical")
            # Verify the pipeline was initialized
            assert pipeline.domain_enhancement_pipeline is not None
            mock_class.assert_called_once()

    @pytest.mark.asyncio
    async def test_domain_enhancement_with_technical_content(self, pipeline, mock_domain_enhancement_pipeline):
        """Test domain enhancement with technical content."""
        # Set up the mock pipeline
        pipeline.domain_enhancement_pipeline = mock_domain_enhancement_pipeline
        # Test segments
        segments = [
            {"text": "Test technical content", "start": 0.0, "end": 1.0},
            {"text": "with algorithms", "start": 1.0, "end": 2.0},
        ]
        # Perform enhancement
        result = await pipeline._perform_enhancement_pass(segments, domain="technical")
        # Verify enhancement was called
        mock_domain_enhancement_pipeline.enhance_content.assert_called_once()
        # Verify result structure
        assert len(result) == 2
        assert result[0]["domain"] == "technical"
        assert "enhancement_confidence" in result[0]
        assert "enhancement_improvements" in result[0]
        assert "enhancement_terminology_corrections" in result[0]
        assert "enhancement_quality_metrics" in result[0]

    @pytest.mark.asyncio
    async def test_domain_enhancement_fallback_on_failure(self, pipeline):
        """Test that enhancement falls back gracefully on failure."""
        # Mock a failing enhancement pipeline
        mock_failing_pipeline = AsyncMock(spec=DomainEnhancementPipeline)
        mock_failing_pipeline.enhance_content.side_effect = Exception("Enhancement failed")
        pipeline.domain_enhancement_pipeline = mock_failing_pipeline
        # Test segments
        segments = [
            {"text": "Test content", "start": 0.0, "end": 1.0},
        ]
        # Perform enhancement - should fall back to basic enhancement
        result = await pipeline._perform_enhancement_pass(segments, domain="technical")
        # Verify fallback behavior
        assert len(result) == 1
        assert result[0]["domain"] == "technical"
        assert result[0]["text"].startswith("[TECHNICAL]")

    @pytest.mark.asyncio
    async def test_domain_enhancement_with_general_domain(self, pipeline):
        """Test that general-domain content doesn't trigger enhancement."""
        # Test segments with general domain
        segments = [
            {"text": "General content", "start": 0.0, "end": 1.0},
        ]
        # Perform enhancement with general domain
        result = await pipeline._perform_enhancement_pass(segments, domain="general")
        # Verify no enhancement was applied: text passes through unchanged
        assert len(result) == 1
        assert result[0]["domain"] == "general"
        assert result[0]["text"] == "General content"

    def test_domain_enhancement_config_integration(self, pipeline):
        """Test that the domain enhancement configuration is properly used."""
        # Synchronous test: it only inspects the fixture's configuration
        assert pipeline.domain_enhancement_config is not None
        assert isinstance(pipeline.domain_enhancement_config, DomainEnhancementConfig)
        assert pipeline.domain_enhancement_config.enable_terminology_enhancement is True
        assert pipeline.domain_enhancement_config.quality_threshold == 0.7

    @pytest.mark.asyncio
    async def test_async_transcription_with_enhancement(self, pipeline, mock_domain_enhancement_pipeline):
        """Test that the main transcription method works with async enhancement."""
        # Set up mocks
        pipeline.domain_enhancement_pipeline = mock_domain_enhancement_pipeline
        # Mock the transcription methods
        with patch.object(pipeline, '_perform_first_pass_with_domain_awareness') as mock_first_pass, \
             patch.object(pipeline, '_calculate_confidence') as mock_calc_conf, \
             patch.object(pipeline, '_identify_low_confidence_segments') as mock_identify, \
             patch.object(pipeline, '_perform_refinement_pass'), \
             patch.object(pipeline, '_merge_transcription_results') as mock_merge, \
             patch.object(pipeline, '_merge_with_diarization') as mock_merge_diar:
            # Set up mock return values
            mock_first_pass.return_value = (
                [{"text": "Test", "start": 0.0, "end": 1.0}],
                "technical",
                0.8,
            )
            mock_calc_conf.return_value = [{"text": "Test", "start": 0.0, "end": 1.0, "confidence": 0.9}]
            # No low-confidence segments, so the refinement pass is skipped
            mock_identify.return_value = []
            mock_merge.return_value = [{"text": "Test", "start": 0.0, "end": 1.0}]
            mock_merge_diar.return_value = [{"text": "Test", "start": 0.0, "end": 1.0, "speaker": "UNKNOWN"}]
            # Dummy audio path; the transcription passes that would read it are mocked
            audio_path = Path("test_audio.wav")
            # Test transcription
            result = await pipeline.transcribe_with_parallel_processing(
                audio_path,
                speaker_diarization=False,
                domain="technical",
            )
            # Verify result structure
            assert "transcript" in result
            assert "processing_time" in result
            assert "confidence_score" in result
            assert len(result["transcript"]) == 1


if __name__ == "__main__":
    pytest.main([__file__, "-v"])