"""Integration tests for MultiPassTranscriptionPipeline with DomainEnhancementPipeline.

Tests the integration of the domain-specific enhancement pipeline with the
multi-pass transcription pipeline for Task 8.3.
"""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.services.domain_enhancement import (
    DomainEnhancementConfig,
    DomainEnhancementPipeline,
    EnhancementResult,
)
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
class TestMultiPassDomainEnhancementIntegration:
    """Test integration between MultiPassTranscriptionPipeline and DomainEnhancementPipeline."""

    @pytest.fixture
    def mock_model_manager(self):
        """Mock model manager for testing."""
        mock = MagicMock()
        mock.load_model.return_value = MagicMock()
        mock.get_current_model.return_value = MagicMock()
        return mock

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Mock domain adaptation manager for testing."""
        mock = MagicMock()
        # Detector always reports "technical" so tests exercise the enhancement path.
        mock.domain_detector = MagicMock()
        mock.domain_detector.detect_domain_from_text.return_value = "technical"
        mock.domain_detector.detect_domain_from_path.return_value = "technical"

        # Mock the domain adapter to return True (success)
        mock.domain_adapter = MagicMock()
        mock.domain_adapter.domain_adapters = {"technical": MagicMock()}
        mock.domain_adapter.switch_adapter.return_value = MagicMock()

        return mock

    @pytest.fixture
    def mock_domain_enhancement_pipeline(self):
        """Mock domain enhancement pipeline for testing."""
        mock = AsyncMock(spec=DomainEnhancementPipeline)

        # Canned enhancement result returned by every enhance_content() call.
        mock_result = EnhancementResult(
            original_text="Test technical content",
            enhanced_text="Enhanced technical content with proper terminology",
            domain="technical",
            confidence_score=0.85,
            improvements=["Terminology enhancement", "Formatting optimization"],
            terminology_corrections=["tech -> technical", "algo -> algorithm"],
            quality_metrics={"technical_term_density": 0.8, "formatting_consistency": 0.9},
            processing_time=1.5,
        )

        mock.enhance_content.return_value = mock_result
        return mock

    @pytest.fixture
    def pipeline(self, mock_model_manager, mock_domain_adaptation_manager):
        """Create MultiPassTranscriptionPipeline instance for testing."""
        config = DomainEnhancementConfig(
            domain="technical",
            enable_terminology_enhancement=True,
            enable_citation_handling=True,
            enable_formatting_optimization=True,
            quality_threshold=0.7,
            max_enhancement_iterations=3,
        )

        return MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            domain_adapter=mock_domain_adaptation_manager,
            auto_detect_domain=True,
            domain_enhancement_config=config,
        )

    @pytest.mark.asyncio
    async def test_domain_enhancement_pipeline_initialization(self, pipeline):
        """Test that domain enhancement pipeline is properly initialized."""
        # The pipeline should be None initially (lazily created on first use).
        assert pipeline.domain_enhancement_pipeline is None

        # NOTE(review): this patches the class at its definition site; if
        # multi_pass_transcription imports DomainEnhancementPipeline into its
        # own namespace, the patch target should be
        # 'src.services.multi_pass_transcription.DomainEnhancementPipeline'
        # instead — confirm against the production import style.
        with patch('src.services.domain_enhancement.DomainEnhancementPipeline', autospec=True) as mock_class:
            mock_instance = MagicMock()
            mock_class.return_value = mock_instance

            # Mock the enhance_content method to return a valid result.
            mock_instance.enhance_content = AsyncMock(return_value=MagicMock(
                enhanced_text="Enhanced content",
                confidence_score=0.8,
                improvements=["Test improvement"],
                terminology_corrections=["Test correction"],
                quality_metrics={"test": 0.8},
            ))

            # Call the enhancement method to trigger lazy initialization.
            segments = [{"text": "Test technical content", "start": 0.0, "end": 1.0}]
            await pipeline._perform_enhancement_pass(segments, domain="technical")

            # Verify the pipeline was initialized exactly once.
            assert pipeline.domain_enhancement_pipeline is not None
            mock_class.assert_called_once()

    @pytest.mark.asyncio
    async def test_domain_enhancement_with_technical_content(self, pipeline, mock_domain_enhancement_pipeline):
        """Test domain enhancement with technical content."""
        # Inject the pre-built mock so no real pipeline is constructed.
        pipeline.domain_enhancement_pipeline = mock_domain_enhancement_pipeline

        segments = [
            {"text": "Test technical content", "start": 0.0, "end": 1.0},
            {"text": "with algorithms", "start": 1.0, "end": 2.0},
        ]

        result = await pipeline._perform_enhancement_pass(segments, domain="technical")

        # Verify enhancement was called.
        mock_domain_enhancement_pipeline.enhance_content.assert_called_once()

        # Verify result structure: one output segment per input segment, each
        # annotated with the enhancement metadata fields.
        assert len(result) == 2
        assert result[0]["domain"] == "technical"
        assert "enhancement_confidence" in result[0]
        assert "enhancement_improvements" in result[0]
        assert "enhancement_terminology_corrections" in result[0]
        assert "enhancement_quality_metrics" in result[0]

    @pytest.mark.asyncio
    async def test_domain_enhancement_fallback_on_failure(self, pipeline):
        """Test that enhancement falls back gracefully on failure."""
        # Mock a failing enhancement pipeline.
        mock_failing_pipeline = AsyncMock(spec=DomainEnhancementPipeline)
        mock_failing_pipeline.enhance_content.side_effect = Exception("Enhancement failed")
        pipeline.domain_enhancement_pipeline = mock_failing_pipeline

        segments = [
            {"text": "Test content", "start": 0.0, "end": 1.0},
        ]

        # Perform enhancement - should fall back to basic enhancement rather
        # than propagate the exception.
        result = await pipeline._perform_enhancement_pass(segments, domain="technical")

        # Fallback keeps the segment and prefixes the text with the domain tag.
        assert len(result) == 1
        assert result[0]["domain"] == "technical"
        assert result[0]["text"].startswith("[TECHNICAL]")

    @pytest.mark.asyncio
    async def test_domain_enhancement_with_general_domain(self, pipeline):
        """Test that general domain content doesn't trigger enhancement."""
        segments = [
            {"text": "General content", "start": 0.0, "end": 1.0},
        ]

        result = await pipeline._perform_enhancement_pass(segments, domain="general")

        # Verify no enhancement was applied: text passes through unchanged.
        assert len(result) == 1
        assert result[0]["domain"] == "general"
        assert result[0]["text"] == "General content"

    def test_domain_enhancement_config_integration(self, pipeline):
        """Test that domain enhancement configuration is properly used.

        Plain sync test: nothing here is awaited, so the asyncio marker the
        original carried was unnecessary.
        """
        assert pipeline.domain_enhancement_config is not None
        assert isinstance(pipeline.domain_enhancement_config, DomainEnhancementConfig)
        assert pipeline.domain_enhancement_config.enable_terminology_enhancement is True
        assert pipeline.domain_enhancement_config.quality_threshold == 0.7

    @pytest.mark.asyncio
    async def test_async_transcription_with_enhancement(self, pipeline, mock_domain_enhancement_pipeline):
        """Test that the main transcription method works with async enhancement."""
        pipeline.domain_enhancement_pipeline = mock_domain_enhancement_pipeline

        # Patch out every internal pass so only the orchestration logic runs.
        with patch.object(pipeline, '_perform_first_pass_with_domain_awareness') as mock_first_pass, \
             patch.object(pipeline, '_calculate_confidence') as mock_calc_conf, \
             patch.object(pipeline, '_identify_low_confidence_segments') as mock_identify, \
             patch.object(pipeline, '_perform_refinement_pass') as mock_refine, \
             patch.object(pipeline, '_merge_transcription_results') as mock_merge, \
             patch.object(pipeline, '_merge_with_diarization') as mock_merge_diar:

            # First pass yields (segments, detected_domain, domain_confidence).
            mock_first_pass.return_value = (
                [{"text": "Test", "start": 0.0, "end": 1.0}],
                "technical",
                0.8,
            )
            mock_calc_conf.return_value = [{"text": "Test", "start": 0.0, "end": 1.0, "confidence": 0.9}]
            # No low-confidence segments, so the refinement pass has nothing to do.
            mock_identify.return_value = []
            mock_merge.return_value = [{"text": "Test", "start": 0.0, "end": 1.0}]
            mock_merge_diar.return_value = [{"text": "Test", "start": 0.0, "end": 1.0, "speaker": "UNKNOWN"}]

            # Mock audio file (never opened because the passes are patched).
            audio_path = Path("test_audio.wav")

            result = await pipeline.transcribe_with_parallel_processing(
                audio_path,
                speaker_diarization=False,
                domain="technical",
            )

            # Verify result structure.
            assert "transcript" in result
            assert "processing_time" in result
            assert "confidence_score" in result
            assert len(result["transcript"]) == 1
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))