"""Integration tests for MultiPassTranscriptionPipeline with DomainEnhancementPipeline. Tests the integration of the domain-specific enhancement pipeline with the multi-pass transcription pipeline for Task 8.3. """ import pytest import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline from src.services.domain_enhancement import DomainEnhancementPipeline, DomainEnhancementConfig, EnhancementResult class TestMultiPassDomainEnhancementIntegration: """Test integration between MultiPassTranscriptionPipeline and DomainEnhancementPipeline.""" @pytest.fixture def mock_model_manager(self): """Mock model manager for testing.""" mock = MagicMock() mock.load_model.return_value = MagicMock() mock.get_current_model.return_value = MagicMock() return mock @pytest.fixture def mock_domain_adaptation_manager(self): """Mock domain adaptation manager for testing.""" mock = MagicMock() mock.domain_detector = MagicMock() mock.domain_detector.detect_domain_from_text.return_value = "technical" mock.domain_detector.detect_domain_from_path.return_value = "technical" # Mock the domain adapter to return True (success) mock.domain_adapter = MagicMock() mock.domain_adapter.domain_adapters = {"technical": MagicMock()} mock.domain_adapter.switch_adapter.return_value = MagicMock() return mock @pytest.fixture def mock_domain_enhancement_pipeline(self): """Mock domain enhancement pipeline for testing.""" mock = AsyncMock(spec=DomainEnhancementPipeline) # Mock enhancement result mock_result = EnhancementResult( original_text="Test technical content", enhanced_text="Enhanced technical content with proper terminology", domain="technical", confidence_score=0.85, improvements=["Terminology enhancement", "Formatting optimization"], terminology_corrections=["tech -> technical", "algo -> algorithm"], quality_metrics={"technical_term_density": 0.8, "formatting_consistency": 0.9}, processing_time=1.5 ) mock.enhance_content.return_value = mock_result return mock @pytest.fixture def pipeline(self, mock_model_manager, mock_domain_adaptation_manager): """Create MultiPassTranscriptionPipeline instance for testing.""" config = DomainEnhancementConfig( domain="technical", enable_terminology_enhancement=True, enable_citation_handling=True, enable_formatting_optimization=True, quality_threshold=0.7, max_enhancement_iterations=3 ) return MultiPassTranscriptionPipeline( model_manager=mock_model_manager, domain_adapter=mock_domain_adaptation_manager, auto_detect_domain=True, domain_enhancement_config=config ) @pytest.mark.asyncio async def test_domain_enhancement_pipeline_initialization(self, pipeline): """Test that domain enhancement pipeline is properly initialized.""" # The pipeline should be None initially assert pipeline.domain_enhancement_pipeline is None # Mock the domain enhancement pipeline with patch('src.services.domain_enhancement.DomainEnhancementPipeline', autospec=True) as mock_class: mock_instance = MagicMock() mock_class.return_value = mock_instance # Mock the enhance_content method to return a valid result mock_instance.enhance_content = AsyncMock(return_value=MagicMock( enhanced_text="Enhanced content", confidence_score=0.8, improvements=["Test improvement"], terminology_corrections=["Test correction"], quality_metrics={"test": 0.8} )) # Call the enhancement method to trigger initialization segments = [{"text": "Test technical content", "start": 0.0, "end": 1.0}] result = await 

    @pytest.mark.asyncio
    async def test_domain_enhancement_with_technical_content(
        self, pipeline, mock_domain_enhancement_pipeline
    ):
        """Test domain enhancement with technical content."""
        # Set up the mock pipeline
        pipeline.domain_enhancement_pipeline = mock_domain_enhancement_pipeline

        # Test segments
        segments = [
            {"text": "Test technical content", "start": 0.0, "end": 1.0},
            {"text": "with algorithms", "start": 1.0, "end": 2.0},
        ]

        # Perform enhancement
        result = await pipeline._perform_enhancement_pass(segments, domain="technical")

        # Verify enhancement was called
        mock_domain_enhancement_pipeline.enhance_content.assert_called_once()

        # Verify result structure
        assert len(result) == 2
        assert result[0]["domain"] == "technical"
        assert "enhancement_confidence" in result[0]
        assert "enhancement_improvements" in result[0]
        assert "enhancement_terminology_corrections" in result[0]
        assert "enhancement_quality_metrics" in result[0]

    @pytest.mark.asyncio
    async def test_domain_enhancement_fallback_on_failure(self, pipeline):
        """Test that enhancement falls back gracefully on failure."""
        # Mock a failing enhancement pipeline
        mock_failing_pipeline = AsyncMock(spec=DomainEnhancementPipeline)
        mock_failing_pipeline.enhance_content.side_effect = Exception("Enhancement failed")
        pipeline.domain_enhancement_pipeline = mock_failing_pipeline

        # Test segments
        segments = [{"text": "Test content", "start": 0.0, "end": 1.0}]

        # Perform enhancement - should fall back to basic enhancement
        result = await pipeline._perform_enhancement_pass(segments, domain="technical")

        # Verify fallback behavior
        assert len(result) == 1
        assert result[0]["domain"] == "technical"
        assert result[0]["text"].startswith("[TECHNICAL]")

    @pytest.mark.asyncio
    async def test_domain_enhancement_with_general_domain(self, pipeline):
        """Test that general-domain content doesn't trigger enhancement."""
        # Test segments with general domain
        segments = [{"text": "General content", "start": 0.0, "end": 1.0}]

        # Perform enhancement with general domain
        result = await pipeline._perform_enhancement_pass(segments, domain="general")

        # Verify no enhancement was applied
        assert len(result) == 1
        assert result[0]["domain"] == "general"
        assert result[0]["text"] == "General content"

    @pytest.mark.asyncio
    async def test_domain_enhancement_config_integration(self, pipeline):
        """Test that the domain enhancement configuration is properly used."""
        # Verify configuration is set
        assert pipeline.domain_enhancement_config is not None
        assert isinstance(pipeline.domain_enhancement_config, DomainEnhancementConfig)
        assert pipeline.domain_enhancement_config.enable_terminology_enhancement is True
        assert pipeline.domain_enhancement_config.quality_threshold == 0.7

    @pytest.mark.asyncio
    async def test_async_transcription_with_enhancement(
        self, pipeline, mock_domain_enhancement_pipeline
    ):
        """Test that the main transcription method works with async enhancement."""
        # Set up mocks
        pipeline.domain_enhancement_pipeline = mock_domain_enhancement_pipeline

        # Mock the transcription methods
        with patch.object(pipeline, '_perform_first_pass_with_domain_awareness') as mock_first_pass, \
             patch.object(pipeline, '_calculate_confidence') as mock_calc_conf, \
             patch.object(pipeline, '_identify_low_confidence_segments') as mock_identify, \
             patch.object(pipeline, '_perform_refinement_pass') as mock_refine, \
             patch.object(pipeline, '_merge_transcription_results') as mock_merge, \
             patch.object(pipeline, '_merge_with_diarization') as mock_merge_diar:

            # Set up mock return values
            mock_first_pass.return_value = (
                [{"text": "Test", "start": 0.0, "end": 1.0}],
                "technical",
                0.8,
            )
            mock_calc_conf.return_value = [
                {"text": "Test", "start": 0.0, "end": 1.0, "confidence": 0.9}
            ]
            mock_identify.return_value = []
            mock_merge.return_value = [{"text": "Test", "start": 0.0, "end": 1.0}]
            mock_merge_diar.return_value = [
                {"text": "Test", "start": 0.0, "end": 1.0, "speaker": "UNKNOWN"}
            ]

            # Mock audio file
            audio_path = Path("test_audio.wav")

            # Test transcription
            result = await pipeline.transcribe_with_parallel_processing(
                audio_path,
                speaker_diarization=False,
                domain="technical",
            )

            # Verify result structure
            assert "transcript" in result
            assert "processing_time" in result
            assert "confidence_score" in result
            assert len(result["transcript"]) == 1
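
    # Illustrative sketch, not part of the original suite: inspects the canned
    # EnhancementResult wired into the mock_domain_enhancement_pipeline fixture
    # without invoking the mock, so no assumption is made about the real
    # enhance_content signature. The expected values simply mirror the fixture.
    def test_mock_enhancement_fixture_configuration(
        self, mock_domain_enhancement_pipeline
    ):
        """Sketch: the fixture's canned result carries the fields asserted elsewhere."""
        canned = mock_domain_enhancement_pipeline.enhance_content.return_value
        assert canned.domain == "technical"
        assert canned.confidence_score == 0.85
        assert "Terminology enhancement" in canned.improvements
        assert canned.quality_metrics["technical_term_density"] == 0.8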


if __name__ == "__main__":
    pytest.main([__file__, "-v"])