# trax/tests/test_domain_enhancement.py
"""Test domain-specific enhancement pipeline.
Tests the specialized enhancement workflows for different domains,
including technical terminology enhancement, medical vocabulary optimization,
academic citation handling, and domain-specific quality metrics.
"""
import pytest
from unittest.mock import Mock, AsyncMock
from src.services.domain_enhancement import (
DomainEnhancementPipeline,
DomainEnhancementConfig,
DomainType,
EnhancementResult
)
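# For orientation, a rough sketch of the surface these tests exercise, inferred
# from the assertions below (the authoritative definitions live in
# src.services.domain_enhancement and may differ in detail):
#
#   DomainType                      - enum: GENERAL, TECHNICAL, MEDICAL, ACADEMIC, LEGAL
#   DomainEnhancementConfig(domain) - per-domain toggles and quality thresholds
#   DomainEnhancementPipeline(enhancement_service=...)
#       .enhance_content(text, domain=None) -> EnhancementResult
#   EnhancementResult               - original/enhanced text, domain, confidence_score,
#                                     improvements, terminology_corrections,
#                                     quality_metrics, processing_time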
class TestDomainEnhancementPipeline:
"""Test the domain-specific enhancement pipeline."""
@pytest.fixture
def mock_enhancement_service(self):
"""Create a mock enhancement service."""
service = Mock()
service.enhance_transcript = AsyncMock()
return service
@pytest.fixture
def pipeline(self, mock_enhancement_service):
"""Create a DomainEnhancementPipeline instance."""
return DomainEnhancementPipeline(enhancement_service=mock_enhancement_service)
@pytest.fixture
def sample_texts(self):
"""Sample texts for different domains."""
return {
DomainType.TECHNICAL: "The algorithm implements a singleton pattern for thread safety in the software system",
DomainType.MEDICAL: "Patient presents with symptoms of hypertension and requires treatment for myocardial infarction",
DomainType.ACADEMIC: "Research study analysis shows hypothesis testing methodology with literature review",
DomainType.LEGAL: "Contract agreement compliance with law regulation and legal jurisdiction",
DomainType.GENERAL: "This is a general conversation about various topics and interests"
}
def test_initialization(self, pipeline):
"""Test pipeline initialization."""
assert pipeline.enhancement_service is not None
assert pipeline.domain_detector is not None
assert len(pipeline.strategies) == 5 # All domain types
assert len(pipeline.quality_metrics) == 5 # All domain types
def test_domain_type_enum(self):
"""Test domain type enumeration."""
assert DomainType.GENERAL.value == "general"
assert DomainType.TECHNICAL.value == "technical"
assert DomainType.MEDICAL.value == "medical"
assert DomainType.ACADEMIC.value == "academic"
assert DomainType.LEGAL.value == "legal"
def test_domain_enhancement_config(self):
"""Test domain enhancement configuration."""
config = DomainEnhancementConfig(domain=DomainType.TECHNICAL)
assert config.domain == DomainType.TECHNICAL
assert config.enable_terminology_enhancement is True
assert config.enable_citation_handling is True
assert config.enable_formatting_optimization is True
assert config.quality_threshold == 0.8
assert config.max_enhancement_iterations == 2
assert config.technical_jargon_threshold == 0.7
assert config.medical_terminology_threshold == 0.8
assert config.academic_citation_threshold == 0.75
assert config.legal_precision_threshold == 0.85
@pytest.mark.asyncio
async def test_enhance_content_with_specified_domain(self, pipeline, sample_texts):
"""Test content enhancement with specified domain."""
text = sample_texts[DomainType.TECHNICAL]
# Mock the enhancement service response
pipeline.enhancement_service.enhance_transcript.return_value = {
"enhanced_text": "The **algorithm** implements a `singleton pattern` for thread safety in the **software system**"
}
result = await pipeline.enhance_content(text, domain=DomainType.TECHNICAL)
assert isinstance(result, EnhancementResult)
assert result.original_text == text
assert result.domain == DomainType.TECHNICAL
assert result.confidence_score > 0
assert len(result.improvements) > 0
assert len(result.quality_metrics) > 0
assert result.processing_time > 0
@pytest.mark.asyncio
async def test_enhance_content_auto_detect_domain(self, pipeline, sample_texts):
"""Test content enhancement with automatic domain detection."""
text = sample_texts[DomainType.MEDICAL]
# Mock the enhancement service response
pipeline.enhancement_service.enhance_transcript.return_value = {
"enhanced_text": "**Patient** presents with symptoms of **hypertension** and requires treatment for **myocardial infarction**"
}
result = await pipeline.enhance_content(text)
assert isinstance(result, EnhancementResult)
assert result.domain in [DomainType.MEDICAL, DomainType.GENERAL] # May fall back to general
assert result.confidence_score > 0
@pytest.mark.asyncio
async def test_enhance_technical_content(self, pipeline, sample_texts):
"""Test technical content enhancement."""
text = sample_texts[DomainType.TECHNICAL]
# Mock the enhancement service response
pipeline.enhancement_service.enhance_transcript.return_value = {
"enhanced_text": "The **algorithm** implements a `singleton pattern` for thread safety in the **software system**"
}
config = DomainEnhancementConfig(domain=DomainType.TECHNICAL)
enhanced_text, improvements, corrections = await pipeline._enhance_technical_content(text, config)
assert enhanced_text != text
assert len(improvements) > 0
assert "Applied technical formatting standards" in improvements
@pytest.mark.asyncio
async def test_enhance_medical_content(self, pipeline, sample_texts):
"""Test medical content enhancement."""
text = sample_texts[DomainType.MEDICAL]
# Mock the enhancement service response
pipeline.enhancement_service.enhance_transcript.return_value = {
"enhanced_text": "**Patient** presents with symptoms of **hypertension** and requires treatment for **myocardial infarction**"
}
config = DomainEnhancementConfig(domain=DomainType.MEDICAL)
enhanced_text, improvements, corrections = await pipeline._enhance_medical_content(text, config)
assert enhanced_text != text
assert len(improvements) > 0
assert "Applied medical documentation standards" in improvements
@pytest.mark.asyncio
async def test_enhance_academic_content(self, pipeline, sample_texts):
"""Test academic content enhancement."""
text = sample_texts[DomainType.ACADEMIC]
# Mock the enhancement service responses
pipeline.enhancement_service.enhance_transcript.side_effect = [
{"enhanced_text": "Research study analysis shows hypothesis testing methodology with literature review"},
{"enhanced_text": "**Research** **study** **analysis** shows **hypothesis** testing **methodology** with **literature** review"}
]
config = DomainEnhancementConfig(domain=DomainType.ACADEMIC)
enhanced_text, improvements, corrections = await pipeline._enhance_academic_content(text, config)
assert enhanced_text != text
assert len(improvements) > 0
assert "Applied academic formatting standards" in improvements
@pytest.mark.asyncio
async def test_enhance_legal_content(self, pipeline, sample_texts):
"""Test legal content enhancement."""
text = sample_texts[DomainType.LEGAL]
# Mock the enhancement service response
pipeline.enhancement_service.enhance_transcript.return_value = {
"enhanced_text": "**Contract** **agreement** compliance with **law** **regulation** and **legal** **jurisdiction**"
}
config = DomainEnhancementConfig(domain=DomainType.LEGAL)
enhanced_text, improvements, corrections = await pipeline._enhance_legal_content(text, config)
assert enhanced_text != text
assert len(improvements) > 0
assert "Applied legal precision standards" in improvements
def test_optimize_technical_formatting(self, pipeline):
"""Test technical formatting optimization."""
text = "The code function method class uses file path C:\\temp\\file.txt and version v1.2.3"
enhanced = pipeline._optimize_technical_formatting(text)
# Check that technical terms are formatted
assert "`code`" in enhanced
assert "`function`" in enhanced
assert "`method`" in enhanced
assert "`class`" in enhanced
assert "`C:\\temp\\file.txt`" in enhanced
assert "**v1.2.3**" in enhanced
def test_apply_medical_formatting(self, pipeline):
"""Test medical formatting application."""
text = "Patient takes aspirin and ibuprofen with blood pressure 120/80 mmHg and heart rate 72 bpm"
enhanced = pipeline._apply_medical_formatting(text)
# Check that medical terms are formatted
assert "**aspirin**" in enhanced
assert "**ibuprofen**" in enhanced
assert "`120/80 mmHg`" in enhanced
assert "`72 bpm`" in enhanced
def test_apply_academic_formatting(self, pipeline):
"""Test academic formatting application."""
text = "Research shows et al. findings ibid. and op. cit. references with Figure 1 and Table 2"
enhanced = pipeline._apply_academic_formatting(text)
# Check that academic terms are formatted
assert "*et al.*" in enhanced
assert "*ibid.*" in enhanced
assert "*op. cit.*" in enhanced
assert "**Figure 1**" in enhanced
assert "**Table 2**" in enhanced
def test_optimize_legal_precision(self, pipeline):
"""Test legal precision optimization."""
text = "The contract shall must may hereby whereas therefore be executed"
enhanced = pipeline._optimize_legal_precision(text)
# Check that legal terms are emphasized
assert "**shall**" in enhanced
assert "**must**" in enhanced
assert "**may**" in enhanced
assert "**hereby**" in enhanced
assert "**whereas**" in enhanced
assert "**therefore**" in enhanced
def test_identify_technical_corrections(self, pipeline):
"""Test technical terminology correction identification."""
original = "The python free code uses my sequel database"
enhanced = "The Python 3 code uses MySQL database"
corrections = pipeline._identify_technical_corrections(original, enhanced)
assert len(corrections) > 0
assert any("python free" in corr and "Python 3" in corr for corr in corrections)
assert any("my sequel" in corr and "MySQL" in corr for corr in corrections)
def test_identify_medical_corrections(self, pipeline):
"""Test medical terminology correction identification."""
original = "Patient has hippa compliance issues and takes prozack"
enhanced = "Patient has HIPAA compliance issues and takes Prozac"
corrections = pipeline._identify_medical_corrections(original, enhanced)
assert len(corrections) > 0
assert any("hippa" in corr and "HIPAA" in corr for corr in corrections)
assert any("prozack" in corr and "Prozac" in corr for corr in corrections)
def test_identify_academic_corrections(self, pipeline):
"""Test academic terminology correction identification."""
original = "The research methodology hypothesis and literature review"
enhanced = "The **research** **methodology** **hypothesis** and **literature** review"
corrections = pipeline._identify_academic_corrections(original, enhanced)
# Note: This test may not find corrections if the original text already contains correct terms
# The identification depends on the specific correction patterns
assert isinstance(corrections, list)
def test_identify_legal_corrections(self, pipeline):
"""Test legal terminology correction identification."""
original = "The contract jurisdiction statute and compliance requirements"
enhanced = "The **contract** **jurisdiction** **statute** and **compliance** requirements"
corrections = pipeline._identify_legal_corrections(original, enhanced)
# Note: This test may not find corrections if the original text already contains correct terms
assert isinstance(corrections, list)
def test_calculate_technical_quality(self, pipeline):
"""Test technical content quality calculation."""
enhanced_text = "The `algorithm` implements a **v1.2.3** system with `code` and `function`"
original_text = "The algorithm implements a v1.2.3 system with code and function"
metrics = pipeline._calculate_technical_quality(enhanced_text, original_text)
assert 'technical_term_density' in metrics
assert 'code_reference_accuracy' in metrics
assert 'technical_precision' in metrics
assert all(0 <= value <= 1 for value in metrics.values())
def test_calculate_medical_quality(self, pipeline):
"""Test medical content quality calculation."""
enhanced_text = "**Patient** has **diagnosis** with `120/80 mmHg` and **treatment**"
original_text = "Patient has diagnosis with 120/80 mmHg and treatment"
metrics = pipeline._calculate_medical_quality(enhanced_text, original_text)
assert 'medical_terminology_accuracy' in metrics
assert 'formatting_compliance' in metrics
assert 'medical_precision' in metrics
assert all(0 <= value <= 1 for value in metrics.values())
def test_calculate_academic_quality(self, pipeline):
"""Test academic content quality calculation."""
enhanced_text = "**Research** *et al.* shows **hypothesis** and **Figure 1**"
original_text = "Research et al. shows hypothesis and Figure 1"
metrics = pipeline._calculate_academic_quality(enhanced_text, original_text)
assert 'citation_handling' in metrics
assert 'academic_terminology' in metrics
assert 'academic_quality' in metrics
assert all(0 <= value <= 1 for value in metrics.values())
def test_calculate_legal_quality(self, pipeline):
"""Test legal content quality calculation."""
enhanced_text = "**Contract** **agreement** with `reference` and **legal** terms"
original_text = "Contract agreement with reference and legal terms"
metrics = pipeline._calculate_legal_quality(enhanced_text, original_text)
assert 'legal_terminology_precision' in metrics
assert 'legal_formatting' in metrics
assert 'legal_quality' in metrics
assert all(0 <= value <= 1 for value in metrics.values())
def test_calculate_general_quality(self, pipeline):
"""Test general content quality calculation."""
enhanced_text = "This is a general conversation. It has proper punctuation!"
original_text = "This is a general conversation It has proper punctuation"
metrics = pipeline._calculate_general_quality(enhanced_text, original_text)
assert 'length_ratio' in metrics
assert 'punctuation_improvement' in metrics
assert 'general_quality' in metrics
assert all(0 <= value <= 1 for value in metrics.values())
def test_calculate_confidence_score(self, pipeline):
"""Test confidence score calculation."""
quality_metrics = {
'technical_precision': 0.8,
'medical_precision': 0.9,
'academic_quality': 0.7,
'legal_quality': 0.85,
'general_quality': 0.75
}
confidence = pipeline._calculate_confidence_score(quality_metrics)
assert 0 <= confidence <= 1
assert confidence > 0.7 # Should be high with good metrics
def test_calculate_confidence_score_empty_metrics(self, pipeline):
"""Test confidence score calculation with empty metrics."""
confidence = pipeline._calculate_confidence_score({})
assert confidence == 0.0
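# Together these two tests pin down the assumed contract of
# _calculate_confidence_score: an empty metrics dict yields 0.0, while strong
# per-domain scores (0.7-0.9 above) aggregate to something above 0.7,
# consistent with a simple average or similar weighting.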
@pytest.mark.asyncio
async def test_enhancement_service_failure_handling(self, pipeline, sample_texts):
"""Test handling of enhancement service failures."""
text = sample_texts[DomainType.TECHNICAL]
# Mock enhancement service to raise an exception
pipeline.enhancement_service.enhance_transcript.side_effect = Exception("Service unavailable")
config = DomainEnhancementConfig(domain=DomainType.TECHNICAL)
enhanced_text, improvements, corrections = await pipeline._enhance_technical_content(text, config)
# Should fall back to original text for terminology enhancement
# But formatting optimization may still be applied
assert len(corrections) == 0 # No terminology corrections
# Note: Formatting may still be applied even if enhancement service fails
@pytest.mark.asyncio
async def test_domain_specific_configuration(self, pipeline, sample_texts):
"""Test domain-specific configuration options."""
text = sample_texts[DomainType.TECHNICAL]
# Create config with disabled terminology enhancement
config = DomainEnhancementConfig(
domain=DomainType.TECHNICAL,
enable_terminology_enhancement=False,
enable_formatting_optimization=True
)
enhanced_text, improvements, corrections = await pipeline._enhance_technical_content(text, config)
# Should skip terminology enhancement but apply formatting
assert "Applied technical formatting standards" in improvements
assert len(corrections) == 0 # No terminology corrections
def test_enhancement_result_structure(self):
"""Test EnhancementResult data structure."""
result = EnhancementResult(
original_text="Original text",
enhanced_text="Enhanced text",
domain=DomainType.TECHNICAL,
confidence_score=0.85,
improvements=["Improved formatting"],
terminology_corrections=["Corrected term"],
quality_metrics={"technical_precision": 0.8},
processing_time=1.5
)
assert result.original_text == "Original text"
assert result.enhanced_text == "Enhanced text"
assert result.domain == DomainType.TECHNICAL
assert result.confidence_score == 0.85
assert len(result.improvements) == 1
assert len(result.terminology_corrections) == 1
assert len(result.quality_metrics) == 1
assert result.processing_time == 1.5
class TestDomainEnhancementIntegration:
"""Test integration of domain enhancement with the pipeline."""
@pytest.mark.asyncio
async def test_end_to_end_technical_enhancement(self):
"""Test end-to-end technical content enhancement."""
from src.services.domain_enhancement import DomainEnhancementPipeline
# Create pipeline with mock service
mock_service = Mock()
mock_service.enhance_transcript = AsyncMock(return_value={
"enhanced_text": "The **algorithm** implements a `singleton pattern` for thread safety"
})
pipeline = DomainEnhancementPipeline(enhancement_service=mock_service)
text = "The algorithm implements a singleton pattern for thread safety"
result = await pipeline.enhance_content(text, domain=DomainType.TECHNICAL)
assert result.domain == DomainType.TECHNICAL
assert result.confidence_score > 0
assert len(result.improvements) > 0
assert "Applied technical formatting standards" in result.improvements
@pytest.mark.asyncio
async def test_domain_switching(self):
"""Test switching between different domains."""
from src.services.domain_enhancement import DomainEnhancementPipeline
mock_service = Mock()
mock_service.enhance_transcript = AsyncMock(return_value={
"enhanced_text": "Enhanced content"
})
pipeline = DomainEnhancementPipeline(enhancement_service=mock_service)
# Test different domains
domains = [DomainType.TECHNICAL, DomainType.MEDICAL, DomainType.ACADEMIC]
for domain in domains:
result = await pipeline.enhance_content("Test content", domain=domain)
assert result.domain == domain
# Confidence score may be 0 if no domain-specific terms are detected
# This is expected behavior for generic content
assert result.confidence_score >= 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])