# trax/tests/test_domain_detection_integ... (175 lines, 7.9 KiB, Python)

"""Test domain detection integration with the transcription pipeline.

Tests the integration of domain detection into the transcription pipeline,
including the new methods added to DomainDetector (e.g.
``detect_domain_from_text`` and ``detect_domain_from_path``).
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from src.services.domain_adaptation import DomainDetector
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
class TestDomainDetectionIntegration:
    """Integration tests for domain detection in the transcription pipeline.

    Exercises DomainDetector's text-, path- and probability-based entry
    points, and checks how MultiPassTranscriptionPipeline sets up its
    detector depending on ``auto_detect_domain``.
    """

    @pytest.fixture
    def domain_detector(self):
        """Provide a freshly constructed DomainDetector."""
        return DomainDetector()

    @pytest.fixture
    def mock_model_manager(self):
        """Provide a Mock ModelManager stand-in.

        ``load_model`` and ``get_current_model`` are each configured to
        return a Mock.
        """
        model_manager = Mock()
        model_manager.configure_mock(
            **{
                "load_model.return_value": Mock(),
                "get_current_model.return_value": Mock(),
            }
        )
        return model_manager

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Provide a Mock DomainAdaptationManager stand-in.

        It carries a real DomainDetector and one Mock adapter per
        specialised domain (medical, technical, academic).
        """
        adaptation_manager = Mock()
        adaptation_manager.domain_detector = DomainDetector()
        adaptation_manager.domain_adapter.domain_adapters = {
            domain: Mock() for domain in ("medical", "technical", "academic")
        }
        return adaptation_manager

    def test_detect_domain_from_text_medical(self, domain_detector):
        """Clinical vocabulary should be classified as medical."""
        text = "The patient shows symptoms of acute myocardial infarction"
        assert domain_detector.detect_domain_from_text(text) == "medical"

    def test_detect_domain_from_text_technical(self, domain_detector):
        """Software-engineering vocabulary should be classified as technical."""
        text = "The algorithm implements a singleton pattern for thread safety in the software system"
        assert domain_detector.detect_domain_from_text(text) == "technical"

    def test_detect_domain_from_text_academic(self, domain_detector):
        """Research vocabulary should be classified as academic."""
        text = "The research methodology follows a quantitative approach"
        assert domain_detector.detect_domain_from_text(text) == "academic"

    def test_detect_domain_from_text_general(self, domain_detector):
        """Small talk with no domain terms should fall back to general."""
        text = "Hello world, how are you today?"
        assert domain_detector.detect_domain_from_text(text) == "general"

    def test_detect_domain_from_path_medical(self, domain_detector):
        """A file name mentioning a patient interview maps to medical."""
        audio = Path("data/media/medical_interview_patient_123.wav")
        assert domain_detector.detect_domain_from_path(audio) == "medical"

    def test_detect_domain_from_path_technical(self, domain_detector):
        """A file name mentioning a programming tutorial maps to technical."""
        audio = Path("data/media/tech_tutorial_python_programming.mp3")
        assert domain_detector.detect_domain_from_path(audio) == "technical"

    def test_detect_domain_from_path_academic(self, domain_detector):
        """A file name mentioning a university lecture maps to academic."""
        audio = Path("data/media/research_presentation_university_lecture.wav")
        assert domain_detector.detect_domain_from_path(audio) == "academic"

    def test_detect_domain_from_path_no_indicators(self, domain_detector):
        """A file name with no domain hints yields None, not a guess."""
        audio = Path("data/media/recording_001.wav")
        assert domain_detector.detect_domain_from_path(audio) is None

    def test_rule_based_detection_fallback(self, domain_detector):
        """detect_domain should still classify text without a trained model."""
        # A freshly built DomainDetector is expected to be untrained, so this
        # call is meant to exercise the rule-based fallback path.
        text = "The patient requires immediate medical attention"
        assert domain_detector.detect_domain(text) == "medical"

    def test_domain_probabilities_fallback(self, domain_detector):
        """get_domain_probabilities should report every domain without a model."""
        probs = domain_detector.get_domain_probabilities(
            "Patient shows symptoms of hypertension"
        )
        for label in ("medical", "general", "technical", "academic"):
            assert label in probs
        # Clinical wording must score higher for medical than for general.
        assert probs["medical"] > probs["general"]

    def test_pipeline_domain_detection_integration(self, mock_model_manager, mock_domain_adaptation_manager):
        """With auto-detection on, the pipeline must own a domain detector."""
        pipeline = MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            domain_adapter=mock_domain_adaptation_manager,
            auto_detect_domain=True,
        )
        assert pipeline.domain_detector is not None
        assert pipeline.auto_detect_domain is True

    def test_pipeline_domain_detection_disabled(self, mock_model_manager):
        """With auto-detection off, the pipeline must not build a detector."""
        pipeline = MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            auto_detect_domain=False,
        )
        assert pipeline.domain_detector is None
        assert pipeline.auto_detect_domain is False

    def test_domain_detection_confidence_scoring(self, domain_detector):
        """Probability scores should reflect how strong the domain signal is."""
        medical_probs = domain_detector.get_domain_probabilities(
            "The patient exhibits symptoms of diabetes mellitus"
        )
        # Unambiguous clinical text: the medical score must exceed 0.5.
        assert medical_probs.get("medical", 0.0) > 0.5

        general_probs = domain_detector.get_domain_probabilities(
            "This is a general conversation about various topics"
        )
        # Neutral text: only a moderate general score (above 0.3) is required.
        assert general_probs.get("general", 0.0) > 0.3

    def test_domain_detection_edge_cases(self, domain_detector):
        """Degenerate and mixed inputs should resolve sensibly."""
        # Empty, very short, and punctuation-only inputs carry no domain
        # signal, so each must fall back to general.
        for text in ("", "Hi", "...!?"):
            assert domain_detector.detect_domain_from_text(text) == "general"

        # Mixed medical/technical vocabulary must pick one of the competing
        # specialised domains rather than general.
        mixed = domain_detector.detect_domain_from_text(
            "The patient needs to implement the algorithm for diagnosis"
        )
        assert mixed in ("medical", "technical")
if __name__ == "__main__":
    # pytest.main() returns the session exit code instead of exiting, so it
    # must be propagated explicitly; otherwise running this file directly
    # reports success (status 0) even when tests fail.
    raise SystemExit(pytest.main([__file__]))