"""Test domain detection integration with transcription pipeline.

Tests the integration of domain detection into the transcription pipeline,
including the new methods added to DomainDetector.
"""

from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from src.services.domain_adaptation import DomainDetector
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline


class TestDomainDetectionIntegration:
    """Test domain detection integration with transcription pipeline."""

    @pytest.fixture
    def domain_detector(self):
        """Create a fresh DomainDetector instance for each test."""
        return DomainDetector()

    @pytest.fixture
    def mock_model_manager(self):
        """Create a mock ModelManager whose model accessors return Mocks."""
        mock_manager = Mock()
        mock_manager.load_model.return_value = Mock()
        mock_manager.get_current_model.return_value = Mock()
        return mock_manager

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Create a mock DomainAdaptationManager.

        It carries a real DomainDetector plus stub adapters registered
        for the "medical", "technical" and "academic" domains.
        """
        mock_manager = Mock()
        mock_manager.domain_detector = DomainDetector()
        mock_manager.domain_adapter.domain_adapters = {
            "medical": Mock(),
            "technical": Mock(),
            "academic": Mock(),
        }
        return mock_manager

    # --- detect_domain_from_text -------------------------------------------

    def test_detect_domain_from_text_medical(self, domain_detector):
        """Test domain detection from medical text."""
        medical_text = "The patient shows symptoms of acute myocardial infarction"
        detected_domain = domain_detector.detect_domain_from_text(medical_text)
        assert detected_domain == "medical"

    def test_detect_domain_from_text_technical(self, domain_detector):
        """Test domain detection from technical text."""
        technical_text = "The algorithm implements a singleton pattern for thread safety in the software system"
        detected_domain = domain_detector.detect_domain_from_text(technical_text)
        assert detected_domain == "technical"

    def test_detect_domain_from_text_academic(self, domain_detector):
        """Test domain detection from academic text."""
        academic_text = "The research methodology follows a quantitative approach"
        detected_domain = domain_detector.detect_domain_from_text(academic_text)
        assert detected_domain == "academic"

    def test_detect_domain_from_text_general(self, domain_detector):
        """Test that text without domain vocabulary is classified as general."""
        general_text = "Hello world, how are you today?"
        detected_domain = domain_detector.detect_domain_from_text(general_text)
        assert detected_domain == "general"

    # --- detect_domain_from_path -------------------------------------------

    def test_detect_domain_from_path_medical(self, domain_detector):
        """Test domain detection from medical audio path."""
        medical_path = Path("data/media/medical_interview_patient_123.wav")
        detected_domain = domain_detector.detect_domain_from_path(medical_path)
        assert detected_domain == "medical"

    def test_detect_domain_from_path_technical(self, domain_detector):
        """Test domain detection from technical audio path."""
        technical_path = Path("data/media/tech_tutorial_python_programming.mp3")
        detected_domain = domain_detector.detect_domain_from_path(technical_path)
        assert detected_domain == "technical"

    def test_detect_domain_from_path_academic(self, domain_detector):
        """Test domain detection from academic audio path."""
        academic_path = Path("data/media/research_presentation_university_lecture.wav")
        detected_domain = domain_detector.detect_domain_from_path(academic_path)
        assert detected_domain == "academic"

    def test_detect_domain_from_path_no_indicators(self, domain_detector):
        """Test that a path with no domain indicators yields None (not "general")."""
        general_path = Path("data/media/recording_001.wav")
        detected_domain = domain_detector.detect_domain_from_path(general_path)
        assert detected_domain is None

    # --- untrained-model fallbacks -----------------------------------------

    def test_rule_based_detection_fallback(self, domain_detector):
        """Test rule-based detection fallback when ML model not trained."""
        # DomainDetector starts untrained, so should use rule-based detection
        medical_text = "The patient requires immediate medical attention"
        detected_domain = domain_detector.detect_domain(medical_text)
        assert detected_domain == "medical"

    def test_domain_probabilities_fallback(self, domain_detector):
        """Test domain probabilities fallback when ML model not trained."""
        medical_text = "Patient shows symptoms of hypertension"
        probabilities = domain_detector.get_domain_probabilities(medical_text)

        # Every supported domain must be present in the result.
        assert "medical" in probabilities
        assert "general" in probabilities
        assert "technical" in probabilities
        assert "academic" in probabilities

        # Medical text must score higher for "medical" than for "general".
        assert probabilities["medical"] > probabilities["general"]

    # --- pipeline wiring ---------------------------------------------------

    def test_pipeline_domain_detection_integration(
        self, mock_model_manager, mock_domain_adaptation_manager
    ):
        """Test domain detection integration in the transcription pipeline."""
        pipeline = MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            domain_adapter=mock_domain_adaptation_manager,
            auto_detect_domain=True,
        )

        # Test that domain detector is properly initialized
        assert pipeline.domain_detector is not None
        assert pipeline.auto_detect_domain is True

    def test_pipeline_domain_detection_disabled(self, mock_model_manager):
        """Test pipeline behavior when domain detection is disabled."""
        pipeline = MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            auto_detect_domain=False,
        )

        # Test that domain detector is not initialized when disabled
        assert pipeline.domain_detector is None
        assert pipeline.auto_detect_domain is False

    # --- scoring and edge cases --------------------------------------------

    def test_domain_detection_confidence_scoring(self, domain_detector):
        """Test domain detection confidence scoring."""
        # Clearly medical text should score "medical" above 0.5.
        medical_text = "The patient exhibits symptoms of diabetes mellitus"
        probabilities = domain_detector.get_domain_probabilities(medical_text)
        medical_prob = probabilities.get("medical", 0.0)
        assert medical_prob > 0.5  # Should have high confidence

        # Ambiguous text should still give "general" a meaningful score.
        # This only checks a floor; it does not assert "general" is the top
        # domain.
        ambiguous_text = "This is a general conversation about various topics"
        ambiguous_probs = domain_detector.get_domain_probabilities(ambiguous_text)
        general_prob = ambiguous_probs.get("general", 0.0)
        assert general_prob > 0.3  # Should have reasonable confidence

    def test_domain_detection_edge_cases(self, domain_detector):
        """Test domain detection with edge cases."""
        # Empty text
        empty_result = domain_detector.detect_domain_from_text("")
        assert empty_result == "general"

        # Very short text
        short_result = domain_detector.detect_domain_from_text("Hi")
        assert short_result == "general"

        # Text with only punctuation
        punct_result = domain_detector.detect_domain_from_text("...!?")
        assert punct_result == "general"

        # Mixed domain text: either strong signal is acceptable, but it
        # must not fall back to "general".
        mixed_text = "The patient needs to implement the algorithm for diagnosis"
        mixed_result = domain_detector.detect_domain_from_text(mixed_text)
        assert mixed_result in ["medical", "technical"]


if __name__ == "__main__":
    # pytest.main() returns the session exit code. Propagate it so running
    # this file directly reports failure (non-zero exit) when tests fail.
    raise SystemExit(pytest.main([__file__]))