"""Test LoRA adapter integration with transcription workflow.

Tests for Task #8.1: Connect LoRA Adapters to Transcription Workflow
"""

import time
from contextlib import ExitStack, contextmanager
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
from src.services.domain_adaptation_manager import DomainAdaptationManager
from src.services.model_manager import ModelManager

# Upper bound (seconds) asserted by the performance test, both for the
# measured elapsed time and for the pipeline-reported "processing_time".
_MAX_PROCESSING_SECONDS = 10.0

# Patch target used by the speaker-diarization test.
_DIARIZATION_MANAGER_TARGET = "src.services.multi_pass_transcription.DiarizationManager"


def _segment(text):
    """Return a transcript segment dict for *text* spanning 0.0-5.0 seconds."""
    return {"text": text, "start": 0.0, "end": 5.0}


def _standard_passes(segments, confidence=0.9):
    """Build canned return values for the four commonly patched pipeline passes.

    Args:
        segments: Segment dicts returned by the mocked first pass.
        confidence: Value stored under "confidence" in every segment returned
            by the mocked confidence calculation.

    Returns:
        Mapping of pipeline method name to its canned return value. Insertion
        order is the order in which the methods get patched.
    """
    # Fresh dicts are built for the scored segments so the first-pass and
    # confidence mocks never share (and cannot cross-mutate) segment objects.
    scored = [{**segment, "confidence": confidence} for segment in segments]
    return {
        "_perform_first_pass": segments,
        "_perform_refinement_pass": [],
        "_calculate_confidence": scored,
        "_identify_low_confidence_segments": [],
    }


@contextmanager
def _patched(pipeline, return_values):
    """Patch several pipeline methods at once, each with a fixed return value.

    Args:
        pipeline: Object whose attributes are patched via ``patch.object``.
        return_values: Mapping of attribute name to the value its mock returns.

    Yields:
        Mapping of attribute name to the mock that replaced it. All patches
        are undone when the context exits.
    """
    with ExitStack() as stack:
        mocks = {
            name: stack.enter_context(
                patch.object(pipeline, name, return_value=value)
            )
            for name, value in return_values.items()
        }
        yield mocks


def _stub_domain_detector(pipeline, domain):
    """Install a mock domain detector on *pipeline* that reports *domain*."""
    detector = Mock()
    detector.detect_domain.return_value = domain
    pipeline.domain_detector = detector
    return detector


def _install_adapter_manager(pipeline, domain):
    """Give *pipeline* a mock adaptation manager with one adapter for *domain*."""
    manager = Mock()
    manager.domain_adapter.domain_adapters = {domain: Mock()}
    manager.domain_adapter.switch_adapter.return_value = Mock()
    pipeline.domain_adaptation_manager = manager
    return manager


class TestLoRAIntegration:
    """Test LoRA adapter integration with transcription workflow."""

    @pytest.fixture
    def mock_model_manager(self):
        """Create a mock ModelManager."""
        mock_manager = Mock(spec=ModelManager)
        mock_manager.load_model.return_value = Mock()
        return mock_manager

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Create a mock DomainAdaptationManager with two registered adapters."""
        mock_manager = Mock(spec=DomainAdaptationManager)
        mock_domain_adapter = Mock()
        mock_domain_adapter.domain_adapters = {
            "technical": Mock(),
            "medical": Mock(),
        }
        mock_manager.domain_adapter = mock_domain_adapter
        return mock_manager

    @pytest.fixture
    def pipeline_with_lora(self, mock_model_manager, mock_domain_adaptation_manager):
        """Create pipeline with LoRA integration and auto-detection enabled."""
        pipeline = MultiPassTranscriptionPipeline(model_manager=mock_model_manager)
        pipeline.domain_adaptation_manager = mock_domain_adaptation_manager
        pipeline.auto_detect_domain = True
        return pipeline

    @pytest.fixture
    def sample_audio_path(self, tmp_path):
        """Create a placeholder audio file and return its path."""
        audio_file = tmp_path / "test_audio.wav"
        # Content is not real audio; the tests mock every pass that would read it.
        audio_file.write_bytes(b"fake_audio_data")
        return audio_file

    def test_pipeline_initialization_with_lora(self, mock_model_manager):
        """Test pipeline initialization with LoRA support."""
        pipeline = MultiPassTranscriptionPipeline(model_manager=mock_model_manager)

        assert hasattr(pipeline, 'domain_adaptation_manager')
        assert hasattr(pipeline, 'auto_detect_domain')
        assert hasattr(pipeline, 'domain_detector')

    def test_domain_detection_integration(self, pipeline_with_lora, sample_audio_path):
        """Test domain detection integration in transcription pipeline."""
        _stub_domain_detector(pipeline_with_lora, "technical")

        # Only the first pass and the enhancement pass are mocked here; the
        # remaining passes run for real.
        canned = {
            "_perform_first_pass": [_segment("algorithm implementation")],
            "_perform_enhancement_pass": [
                _segment("[TECHNICAL] algorithm implementation")
            ],
        }
        with _patched(pipeline_with_lora, canned):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )

        assert result["transcript"][0]["text"].startswith("[TECHNICAL]")

    def test_domain_specific_model_switching(self, pipeline_with_lora, sample_audio_path):
        """Test domain-specific model switching during transcription."""
        detector = _stub_domain_detector(pipeline_with_lora, "technical")

        canned = _standard_passes([_segment("software architecture")])
        with _patched(pipeline_with_lora, canned):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        # No explicit domain was passed, so the detector must have been consulted.
        detector.detect_domain.assert_called()

    def test_domain_adapter_application(self, pipeline_with_lora):
        """Test direct domain adapter application."""
        # Without an adaptation manager no adapter can be applied.
        pipeline_with_lora.domain_adaptation_manager = None
        assert pipeline_with_lora._apply_domain_adapter("technical") is False

        # With a manager that has the requested adapter, application succeeds.
        _install_adapter_manager(pipeline_with_lora, "technical")
        assert pipeline_with_lora._apply_domain_adapter("technical") is True

    def test_automatic_domain_detection(self, pipeline_with_lora, sample_audio_path):
        """Test automatic domain detection without explicit domain specification."""
        pipeline_with_lora.auto_detect_domain = True
        detector = _stub_domain_detector(pipeline_with_lora, "academic")

        canned = _standard_passes([_segment("research methodology")])
        with _patched(pipeline_with_lora, canned):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        detector.detect_domain.assert_called()

    def test_lora_adapter_loading_during_enhancement(self, pipeline_with_lora, sample_audio_path):
        """Test LoRA adapter loading during enhancement pass."""
        _stub_domain_detector(pipeline_with_lora, "medical")

        canned = _standard_passes([_segment("patient diagnosis")])
        with _patched(pipeline_with_lora, canned):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_fallback_to_general_domain(self, pipeline_with_lora, sample_audio_path):
        """Test fallback to general domain when LoRA adapter is not available."""
        # "unknown" is not among the adapters registered by the fixture.
        _stub_domain_detector(pipeline_with_lora, "unknown")

        canned = _standard_passes([_segment("general content")])
        with _patched(pipeline_with_lora, canned):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        assert "transcript" in result
        # No "[DOMAIN]"-style prefix is expected on the fallback output.
        assert not result["transcript"][0]["text"].startswith("[")

    def test_lora_adapter_caching(self, pipeline_with_lora, sample_audio_path):
        """Test LoRA adapter caching for repeated domain usage."""
        _stub_domain_detector(pipeline_with_lora, "technical")

        canned = _standard_passes([_segment("technical content")])
        with _patched(pipeline_with_lora, canned):
            # First transcription
            result1 = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )

            # Second transcription (should use cached adapter)
            result2 = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )

        # Verify both results are consistent
        assert result1["transcript"][0]["text"] == result2["transcript"][0]["text"]

    def test_error_handling_with_lora_failure(self, pipeline_with_lora, sample_audio_path):
        """Test the pipeline still yields a transcript when enhancement adds nothing.

        NOTE(review): despite the test name, no adapter failure is injected.
        The enhancement pass is mocked to return the segment unchanged, so this
        only covers the "enhancement is a no-op" path.
        """
        _stub_domain_detector(pipeline_with_lora, "medical")

        canned = _standard_passes([_segment("medical content")])
        # Patched last, after the four standard passes.
        canned["_perform_enhancement_pass"] = [_segment("medical content")]
        with _patched(pipeline_with_lora, canned):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        # Verify fallback behavior
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_performance_with_lora_adapters(self, pipeline_with_lora, sample_audio_path):
        """Test performance impact of LoRA adapters."""
        _stub_domain_detector(pipeline_with_lora, "technical")

        canned = _standard_passes([_segment("technical content")])
        with _patched(pipeline_with_lora, canned):
            # perf_counter is monotonic, so the measurement cannot be skewed
            # by system clock adjustments during the run.
            started = time.perf_counter()
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )
            processing_time = time.perf_counter() - started

        # Verify performance is within acceptable limits
        assert processing_time < _MAX_PROCESSING_SECONDS
        assert result["processing_time"] < _MAX_PROCESSING_SECONDS

    def test_memory_management_with_lora(self, pipeline_with_lora, sample_audio_path):
        """Test memory management during LoRA operations."""
        # Mock memory monitoring
        pipeline_with_lora.model_manager.get_memory_usage = Mock(return_value={
            "rss_mb": 1500.0,
            "cuda_allocated_mb": 800.0,
        })
        _stub_domain_detector(pipeline_with_lora, "medical")

        canned = _standard_passes([_segment("clinical assessment")])
        with _patched(pipeline_with_lora, canned):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        # Verify memory monitoring capability is available.
        # NOTE(review): this only checks the attribute assigned above; it does
        # not verify that the pipeline actually queries memory usage.
        assert hasattr(pipeline_with_lora.model_manager, 'get_memory_usage')

    def test_domain_detection_method(self, pipeline_with_lora):
        """Test domain detection method."""
        # Test without domain detector.
        # NOTE(review): the fixture never sets a detector, so this relies on
        # the pipeline's default `domain_detector` state - confirm it is unset.
        result = pipeline_with_lora._detect_domain(Path("dummy"))
        assert result is None

        # Test with mock domain detector
        _stub_domain_detector(pipeline_with_lora, "technical")
        result = pipeline_with_lora._detect_domain(
            Path("dummy"), "algorithm implementation"
        )
        assert result == "technical"

    def test_edge_case_empty_segments(self, pipeline_with_lora, sample_audio_path):
        """Test edge case with empty transcription segments."""
        _stub_domain_detector(pipeline_with_lora, "technical")

        # Every pass returns an empty list.
        canned = _standard_passes([])
        with _patched(pipeline_with_lora, canned):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        # Should handle empty segments gracefully
        assert "transcript" in result
        assert len(result["transcript"]) == 0

    def test_edge_case_very_long_text(self, pipeline_with_lora, sample_audio_path):
        """Test edge case with very long transcription text."""
        _stub_domain_detector(pipeline_with_lora, "academic")

        # Create very long text
        long_text = "This is a very long academic text " * 1000

        canned = _standard_passes([_segment(long_text)])
        with _patched(pipeline_with_lora, canned):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        # Should handle long text gracefully
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_integration_with_speaker_diarization(self, pipeline_with_lora, sample_audio_path):
        """Test LoRA integration with speaker diarization enabled."""
        _stub_domain_detector(pipeline_with_lora, "technical")

        canned = _standard_passes([_segment("technical discussion")])
        # Mock diarization manager (patched after the pipeline passes).
        with _patched(pipeline_with_lora, canned), \
                patch(_DIARIZATION_MANAGER_TARGET) as mock_diarization_class:
            mock_diarization = Mock()
            mock_diarization.process_audio.return_value = Mock(
                segments=[{"start": 0.0, "end": 5.0, "speaker": "Speaker_1"}]
            )
            mock_diarization_class.return_value = mock_diarization

            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=True,
            )

        # Should include speaker information
        assert "transcript" in result
        assert len(result["transcript"]) > 0
        assert "speaker" in result["transcript"][0]

    def test_domain_adapter_switching_behavior(self, pipeline_with_lora):
        """Test domain adapter switching behavior."""
        # Specialised domains should attempt (and succeed at) adapter switching
        # when the manager has a matching adapter registered.
        for domain in ("technical", "medical", "academic"):
            _install_adapter_manager(pipeline_with_lora, domain)
            assert pipeline_with_lora._apply_domain_adapter(domain) is True

        # General domain should not require adapter switching. At this point
        # the manager left over from the loop only knows the "academic" adapter.
        assert pipeline_with_lora._apply_domain_adapter("general") is False

    def test_confidence_calculation_with_domain_context(self, pipeline_with_lora, sample_audio_path):
        """Test confidence calculation maintains domain context."""
        _stub_domain_detector(pipeline_with_lora, "medical")

        canned = _standard_passes([_segment("medical terminology")], confidence=0.95)
        with _patched(pipeline_with_lora, canned):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        # Verify confidence is maintained
        assert "confidence_score" in result
        assert result["confidence_score"] > 0.9