175 lines
7.9 KiB
Python
175 lines
7.9 KiB
Python
"""Test domain detection integration with the transcription pipeline.

Tests the integration of domain detection into the transcription pipeline,
including the methods added to DomainDetector for text-based and path-based
domain detection.
"""
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
from src.services.domain_adaptation import DomainDetector
|
|
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
|
|
|
|
|
|
class TestDomainDetectionIntegration:
    """Test domain detection integration with the transcription pipeline.

    Covers:
    * text-based detection via ``DomainDetector.detect_domain_from_text``;
    * filename-based detection via ``DomainDetector.detect_domain_from_path``;
    * the untrained fallbacks of ``detect_domain`` and
      ``get_domain_probabilities``;
    * whether ``MultiPassTranscriptionPipeline`` creates a domain detector,
      depending on its ``auto_detect_domain`` flag.
    """

    @pytest.fixture
    def domain_detector(self):
        """Create a fresh DomainDetector instance for testing."""
        return DomainDetector()

    @pytest.fixture
    def mock_model_manager(self):
        """Create a mock ModelManager for testing.

        ``load_model`` and ``get_current_model`` are stubbed to return plain
        ``Mock`` objects, so no real model is involved.
        """
        mock_manager = Mock()
        mock_manager.load_model.return_value = Mock()
        mock_manager.get_current_model.return_value = Mock()
        return mock_manager

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Create a mock DomainAdaptationManager for testing.

        The manager itself is a ``Mock``, but ``domain_detector`` is a real
        DomainDetector. ``domain_adapter.domain_adapters`` maps the
        "medical", "technical" and "academic" domains to mock adapters.
        """
        mock_manager = Mock()
        mock_manager.domain_detector = DomainDetector()
        mock_manager.domain_adapter.domain_adapters = {
            "medical": Mock(),
            "technical": Mock(),
            "academic": Mock(),
        }
        return mock_manager

    # ------------------------------------------------------------------
    # Text-based detection
    # ------------------------------------------------------------------

    def test_detect_domain_from_text_medical(self, domain_detector):
        """Test domain detection from medical text."""
        medical_text = "The patient shows symptoms of acute myocardial infarction"
        detected_domain = domain_detector.detect_domain_from_text(medical_text)
        assert detected_domain == "medical"

    def test_detect_domain_from_text_technical(self, domain_detector):
        """Test domain detection from technical text."""
        technical_text = "The algorithm implements a singleton pattern for thread safety in the software system"
        detected_domain = domain_detector.detect_domain_from_text(technical_text)
        assert detected_domain == "technical"

    def test_detect_domain_from_text_academic(self, domain_detector):
        """Test domain detection from academic text."""
        academic_text = "The research methodology follows a quantitative approach"
        detected_domain = domain_detector.detect_domain_from_text(academic_text)
        assert detected_domain == "academic"

    def test_detect_domain_from_text_general(self, domain_detector):
        """Test that text without domain vocabulary is classified as "general"."""
        general_text = "Hello world, how are you today?"
        detected_domain = domain_detector.detect_domain_from_text(general_text)
        assert detected_domain == "general"

    # ------------------------------------------------------------------
    # Path-based detection
    # ------------------------------------------------------------------

    def test_detect_domain_from_path_medical(self, domain_detector):
        """Test domain detection from a medical audio file path."""
        medical_path = Path("data/media/medical_interview_patient_123.wav")
        detected_domain = domain_detector.detect_domain_from_path(medical_path)
        assert detected_domain == "medical"

    def test_detect_domain_from_path_technical(self, domain_detector):
        """Test domain detection from a technical audio file path."""
        technical_path = Path("data/media/tech_tutorial_python_programming.mp3")
        detected_domain = domain_detector.detect_domain_from_path(technical_path)
        assert detected_domain == "technical"

    def test_detect_domain_from_path_academic(self, domain_detector):
        """Test domain detection from an academic audio file path."""
        academic_path = Path("data/media/research_presentation_university_lecture.wav")
        detected_domain = domain_detector.detect_domain_from_path(academic_path)
        assert detected_domain == "academic"

    def test_detect_domain_from_path_no_indicators(self, domain_detector):
        """Test path detection when the filename carries no domain hint.

        Unlike text detection, which falls back to "general", path detection
        is expected to return ``None``.
        """
        general_path = Path("data/media/recording_001.wav")
        detected_domain = domain_detector.detect_domain_from_path(general_path)
        assert detected_domain is None

    # ------------------------------------------------------------------
    # Untrained (rule-based) fallbacks
    # ------------------------------------------------------------------

    def test_rule_based_detection_fallback(self, domain_detector):
        """Test rule-based detection fallback when the ML model is not trained."""
        # A freshly constructed DomainDetector starts untrained, so
        # detect_domain should take the rule-based path.
        medical_text = "The patient requires immediate medical attention"
        detected_domain = domain_detector.detect_domain(medical_text)
        assert detected_domain == "medical"

    def test_domain_probabilities_fallback(self, domain_detector):
        """Test domain probabilities fallback when the ML model is not trained."""
        medical_text = "Patient shows symptoms of hypertension"
        probabilities = domain_detector.get_domain_probabilities(medical_text)

        # The fallback must still report a probability for every known domain.
        for domain in ("medical", "general", "technical", "academic"):
            assert domain in probabilities

        # Medical text should outscore "general". Only this pairwise ordering
        # is asserted, not that "medical" is the overall maximum.
        assert probabilities["medical"] > probabilities["general"]

    # ------------------------------------------------------------------
    # Pipeline wiring
    # ------------------------------------------------------------------

    def test_pipeline_domain_detection_integration(self, mock_model_manager, mock_domain_adaptation_manager):
        """Test that the pipeline sets up domain detection when enabled."""
        pipeline = MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            domain_adapter=mock_domain_adaptation_manager,
            auto_detect_domain=True,
        )

        # With auto-detection on, the pipeline must expose a domain detector.
        assert pipeline.domain_detector is not None
        assert pipeline.auto_detect_domain is True

    def test_pipeline_domain_detection_disabled(self, mock_model_manager):
        """Test pipeline behavior when domain detection is disabled."""
        # No domain_adapter is passed here; only the model manager is required.
        pipeline = MultiPassTranscriptionPipeline(
            model_manager=mock_model_manager,
            auto_detect_domain=False,
        )

        # With auto-detection off, no domain detector should be created.
        assert pipeline.domain_detector is None
        assert pipeline.auto_detect_domain is False

    # ------------------------------------------------------------------
    # Confidence and edge cases
    # ------------------------------------------------------------------

    def test_domain_detection_confidence_scoring(self, domain_detector):
        """Test domain detection confidence scoring."""
        # Clearly medical text: the medical probability must exceed 0.5.
        medical_text = "The patient exhibits symptoms of diabetes mellitus"
        probabilities = domain_detector.get_domain_probabilities(medical_text)
        medical_prob = probabilities.get("medical", 0.0)
        assert medical_prob > 0.5

        # Ambiguous text: only a loose lower bound on "general" is asserted.
        # This does not check that "general" is the top-scoring domain.
        ambiguous_text = "This is a general conversation about various topics"
        ambiguous_probs = domain_detector.get_domain_probabilities(ambiguous_text)
        general_prob = ambiguous_probs.get("general", 0.0)
        assert general_prob > 0.3

    def test_domain_detection_edge_cases(self, domain_detector):
        """Test domain detection with edge cases."""
        # Empty, very short, and punctuation-only inputs carry no domain
        # signal and should fall back to "general".
        for degenerate_text in ("", "Hi", "...!?"):
            result = domain_detector.detect_domain_from_text(degenerate_text)
            assert result == "general", repr(degenerate_text)

        # Text mixing medical and technical vocabulary should resolve to one
        # of those two domains rather than falling back to "general". Which
        # one wins is deliberately not pinned here.
        mixed_text = "The patient needs to implement the algorithm for diagnosis"
        mixed_result = domain_detector.detect_domain_from_text(mixed_text)
        assert mixed_result in ["medical", "technical"]
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session exit code instead of exiting, so
    # propagate it explicitly. Otherwise running this file directly would
    # exit with status 0 even when tests fail.
    raise SystemExit(pytest.main([__file__]))
|