465 lines
21 KiB
Python
465 lines
21 KiB
Python
"""Test domain-specific enhancement pipeline.
|
|
|
|
Tests the specialized enhancement workflows for different domains,
|
|
including technical terminology enhancement, medical vocabulary optimization,
|
|
academic citation handling, and domain-specific quality metrics.
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
from typing import Dict, Any
|
|
|
|
from src.services.domain_enhancement import (
|
|
DomainEnhancementPipeline,
|
|
DomainEnhancementConfig,
|
|
DomainType,
|
|
EnhancementResult
|
|
)
|
|
|
|
|
|
class TestDomainEnhancementPipeline:
    """Test the domain-specific enhancement pipeline."""

    @pytest.fixture
    def mock_enhancement_service(self):
        """Create a mock enhancement service whose ``enhance_transcript`` is awaitable."""
        service = Mock()
        service.enhance_transcript = AsyncMock()
        return service

    @pytest.fixture
    def pipeline(self, mock_enhancement_service):
        """Create a DomainEnhancementPipeline instance backed by the mock service."""
        return DomainEnhancementPipeline(enhancement_service=mock_enhancement_service)

    @pytest.fixture
    def sample_texts(self):
        """Sample texts for different domains, keyed by DomainType."""
        return {
            DomainType.TECHNICAL: "The algorithm implements a singleton pattern for thread safety in the software system",
            DomainType.MEDICAL: "Patient presents with symptoms of hypertension and requires treatment for myocardial infarction",
            DomainType.ACADEMIC: "Research study analysis shows hypothesis testing methodology with literature review",
            DomainType.LEGAL: "Contract agreement compliance with law regulation and legal jurisdiction",
            DomainType.GENERAL: "This is a general conversation about various topics and interests"
        }

    def test_initialization(self, pipeline):
        """Test pipeline initialization."""
        assert pipeline.enhancement_service is not None
        assert pipeline.domain_detector is not None
        assert len(pipeline.strategies) == 5  # All domain types
        assert len(pipeline.quality_metrics) == 5  # All domain types

    def test_domain_type_enum(self):
        """Test domain type enumeration."""
        assert DomainType.GENERAL.value == "general"
        assert DomainType.TECHNICAL.value == "technical"
        assert DomainType.MEDICAL.value == "medical"
        assert DomainType.ACADEMIC.value == "academic"
        assert DomainType.LEGAL.value == "legal"

    def test_domain_enhancement_config(self):
        """Test domain enhancement configuration defaults."""
        config = DomainEnhancementConfig(domain=DomainType.TECHNICAL)

        assert config.domain == DomainType.TECHNICAL
        assert config.enable_terminology_enhancement is True
        assert config.enable_citation_handling is True
        assert config.enable_formatting_optimization is True
        assert config.quality_threshold == 0.8
        assert config.max_enhancement_iterations == 2
        assert config.technical_jargon_threshold == 0.7
        assert config.medical_terminology_threshold == 0.8
        assert config.academic_citation_threshold == 0.75
        assert config.legal_precision_threshold == 0.85

    @pytest.mark.asyncio
    async def test_enhance_content_with_specified_domain(self, pipeline, sample_texts):
        """Test content enhancement with specified domain."""
        text = sample_texts[DomainType.TECHNICAL]

        # Mock the enhancement service response
        pipeline.enhancement_service.enhance_transcript.return_value = {
            "enhanced_text": "The **algorithm** implements a `singleton pattern` for thread safety in the **software system**"
        }

        result = await pipeline.enhance_content(text, domain=DomainType.TECHNICAL)

        assert isinstance(result, EnhancementResult)
        assert result.original_text == text
        assert result.domain == DomainType.TECHNICAL
        assert result.confidence_score > 0
        assert len(result.improvements) > 0
        assert len(result.quality_metrics) > 0
        # The mocked service returns instantly, so on a coarse-resolution clock
        # the measured elapsed time can legitimately be 0.0. Requiring a strictly
        # positive value makes this test flaky; only require a non-negative duration.
        assert result.processing_time >= 0

    @pytest.mark.asyncio
    async def test_enhance_content_auto_detect_domain(self, pipeline, sample_texts):
        """Test content enhancement with automatic domain detection."""
        text = sample_texts[DomainType.MEDICAL]

        # Mock the enhancement service response
        pipeline.enhancement_service.enhance_transcript.return_value = {
            "enhanced_text": "**Patient** presents with symptoms of **hypertension** and requires treatment for **myocardial infarction**"
        }

        result = await pipeline.enhance_content(text)

        assert isinstance(result, EnhancementResult)
        assert result.domain in [DomainType.MEDICAL, DomainType.GENERAL]  # May fall back to general
        assert result.confidence_score > 0

    @pytest.mark.asyncio
    async def test_enhance_technical_content(self, pipeline, sample_texts):
        """Test technical content enhancement."""
        text = sample_texts[DomainType.TECHNICAL]

        # Mock the enhancement service response
        pipeline.enhancement_service.enhance_transcript.return_value = {
            "enhanced_text": "The **algorithm** implements a `singleton pattern` for thread safety in the **software system**"
        }

        config = DomainEnhancementConfig(domain=DomainType.TECHNICAL)
        enhanced_text, improvements, corrections = await pipeline._enhance_technical_content(text, config)

        assert enhanced_text != text
        assert len(improvements) > 0
        assert "Applied technical formatting standards" in improvements

    @pytest.mark.asyncio
    async def test_enhance_medical_content(self, pipeline, sample_texts):
        """Test medical content enhancement."""
        text = sample_texts[DomainType.MEDICAL]

        # Mock the enhancement service response
        pipeline.enhancement_service.enhance_transcript.return_value = {
            "enhanced_text": "**Patient** presents with symptoms of **hypertension** and requires treatment for **myocardial infarction**"
        }

        config = DomainEnhancementConfig(domain=DomainType.MEDICAL)
        enhanced_text, improvements, corrections = await pipeline._enhance_medical_content(text, config)

        assert enhanced_text != text
        assert len(improvements) > 0
        assert "Applied medical documentation standards" in improvements

    @pytest.mark.asyncio
    async def test_enhance_academic_content(self, pipeline, sample_texts):
        """Test academic content enhancement."""
        text = sample_texts[DomainType.ACADEMIC]

        # Mock the enhancement service responses; two are queued, so the academic
        # path may await the service up to twice.
        pipeline.enhancement_service.enhance_transcript.side_effect = [
            {"enhanced_text": "Research study analysis shows hypothesis testing methodology with literature review"},
            {"enhanced_text": "**Research** **study** **analysis** shows **hypothesis** testing **methodology** with **literature** review"}
        ]

        config = DomainEnhancementConfig(domain=DomainType.ACADEMIC)
        enhanced_text, improvements, corrections = await pipeline._enhance_academic_content(text, config)

        assert enhanced_text != text
        assert len(improvements) > 0
        assert "Applied academic formatting standards" in improvements

    @pytest.mark.asyncio
    async def test_enhance_legal_content(self, pipeline, sample_texts):
        """Test legal content enhancement."""
        text = sample_texts[DomainType.LEGAL]

        # Mock the enhancement service response
        pipeline.enhancement_service.enhance_transcript.return_value = {
            "enhanced_text": "**Contract** **agreement** compliance with **law** **regulation** and **legal** **jurisdiction**"
        }

        config = DomainEnhancementConfig(domain=DomainType.LEGAL)
        enhanced_text, improvements, corrections = await pipeline._enhance_legal_content(text, config)

        assert enhanced_text != text
        assert len(improvements) > 0
        assert "Applied legal precision standards" in improvements

    def test_optimize_technical_formatting(self, pipeline):
        """Test technical formatting optimization."""
        text = "The code function method class uses file path C:\\temp\\file.txt and version v1.2.3"

        enhanced = pipeline._optimize_technical_formatting(text)

        # Check that technical terms are formatted
        assert "`code`" in enhanced
        assert "`function`" in enhanced
        assert "`method`" in enhanced
        assert "`class`" in enhanced
        assert "`C:\\temp\\file.txt`" in enhanced
        assert "**v1.2.3**" in enhanced

    def test_apply_medical_formatting(self, pipeline):
        """Test medical formatting application."""
        text = "Patient takes aspirin and ibuprofen with blood pressure 120/80 mmHg and heart rate 72 bpm"

        enhanced = pipeline._apply_medical_formatting(text)

        # Check that medical terms are formatted
        assert "**aspirin**" in enhanced
        assert "**ibuprofen**" in enhanced
        assert "`120/80 mmHg`" in enhanced
        assert "`72 bpm`" in enhanced

    def test_apply_academic_formatting(self, pipeline):
        """Test academic formatting application."""
        text = "Research shows et al. findings ibid. and op. cit. references with Figure 1 and Table 2"

        enhanced = pipeline._apply_academic_formatting(text)

        # Check that academic terms are formatted
        assert "*et al.*" in enhanced
        assert "*ibid.*" in enhanced
        assert "*op. cit.*" in enhanced
        assert "**Figure 1**" in enhanced
        assert "**Table 2**" in enhanced

    def test_optimize_legal_precision(self, pipeline):
        """Test legal precision optimization."""
        text = "The contract shall must may hereby whereas therefore be executed"

        enhanced = pipeline._optimize_legal_precision(text)

        # Check that legal terms are emphasized
        assert "**shall**" in enhanced
        assert "**must**" in enhanced
        assert "**may**" in enhanced
        assert "**hereby**" in enhanced
        assert "**whereas**" in enhanced
        assert "**therefore**" in enhanced

    def test_identify_technical_corrections(self, pipeline):
        """Test technical terminology correction identification."""
        original = "The python free code uses my sequel database"
        enhanced = "The Python 3 code uses MySQL database"

        corrections = pipeline._identify_technical_corrections(original, enhanced)

        assert len(corrections) > 0
        assert any("python free" in corr and "Python 3" in corr for corr in corrections)
        assert any("my sequel" in corr and "MySQL" in corr for corr in corrections)

    def test_identify_medical_corrections(self, pipeline):
        """Test medical terminology correction identification."""
        original = "Patient has hippa compliance issues and takes prozack"
        enhanced = "Patient has HIPAA compliance issues and takes Prozac"

        corrections = pipeline._identify_medical_corrections(original, enhanced)

        assert len(corrections) > 0
        assert any("hippa" in corr and "HIPAA" in corr for corr in corrections)
        assert any("prozack" in corr and "Prozac" in corr for corr in corrections)

    def test_identify_academic_corrections(self, pipeline):
        """Test academic terminology correction identification."""
        original = "The research methodology hypothesis and literature review"
        enhanced = "The **research** **methodology** **hypothesis** and **literature** review"

        corrections = pipeline._identify_academic_corrections(original, enhanced)

        # Note: This test may not find corrections if the original text already contains correct terms
        # The identification depends on the specific correction patterns
        assert isinstance(corrections, list)

    def test_identify_legal_corrections(self, pipeline):
        """Test legal terminology correction identification."""
        original = "The contract jurisdiction statute and compliance requirements"
        enhanced = "The **contract** **jurisdiction** **statute** and **compliance** requirements"

        corrections = pipeline._identify_legal_corrections(original, enhanced)

        # Note: This test may not find corrections if the original text already contains correct terms
        assert isinstance(corrections, list)

    def test_calculate_technical_quality(self, pipeline):
        """Test technical content quality calculation."""
        enhanced_text = "The `algorithm` implements a **v1.2.3** system with `code` and `function`"
        original_text = "The algorithm implements a v1.2.3 system with code and function"

        metrics = pipeline._calculate_technical_quality(enhanced_text, original_text)

        assert 'technical_term_density' in metrics
        assert 'code_reference_accuracy' in metrics
        assert 'technical_precision' in metrics
        assert all(0 <= value <= 1 for value in metrics.values())

    def test_calculate_medical_quality(self, pipeline):
        """Test medical content quality calculation."""
        enhanced_text = "**Patient** has **diagnosis** with `120/80 mmHg` and **treatment**"
        original_text = "Patient has diagnosis with 120/80 mmHg and treatment"

        metrics = pipeline._calculate_medical_quality(enhanced_text, original_text)

        assert 'medical_terminology_accuracy' in metrics
        assert 'formatting_compliance' in metrics
        assert 'medical_precision' in metrics
        assert all(0 <= value <= 1 for value in metrics.values())

    def test_calculate_academic_quality(self, pipeline):
        """Test academic content quality calculation."""
        enhanced_text = "**Research** *et al.* shows **hypothesis** and **Figure 1**"
        original_text = "Research et al. shows hypothesis and Figure 1"

        metrics = pipeline._calculate_academic_quality(enhanced_text, original_text)

        assert 'citation_handling' in metrics
        assert 'academic_terminology' in metrics
        assert 'academic_quality' in metrics
        assert all(0 <= value <= 1 for value in metrics.values())

    def test_calculate_legal_quality(self, pipeline):
        """Test legal content quality calculation."""
        enhanced_text = "**Contract** **agreement** with `reference` and **legal** terms"
        original_text = "Contract agreement with reference and legal terms"

        metrics = pipeline._calculate_legal_quality(enhanced_text, original_text)

        assert 'legal_terminology_precision' in metrics
        assert 'legal_formatting' in metrics
        assert 'legal_quality' in metrics
        assert all(0 <= value <= 1 for value in metrics.values())

    def test_calculate_general_quality(self, pipeline):
        """Test general content quality calculation."""
        enhanced_text = "This is a general conversation. It has proper punctuation!"
        original_text = "This is a general conversation It has proper punctuation"

        metrics = pipeline._calculate_general_quality(enhanced_text, original_text)

        assert 'length_ratio' in metrics
        assert 'punctuation_improvement' in metrics
        assert 'general_quality' in metrics
        assert all(0 <= value <= 1 for value in metrics.values())

    def test_calculate_confidence_score(self, pipeline):
        """Test confidence score calculation."""
        quality_metrics = {
            'technical_precision': 0.8,
            'medical_precision': 0.9,
            'academic_quality': 0.7,
            'legal_quality': 0.85,
            'general_quality': 0.75
        }

        confidence = pipeline._calculate_confidence_score(quality_metrics)

        assert 0 <= confidence <= 1
        assert confidence > 0.7  # Should be high with good metrics

    def test_calculate_confidence_score_empty_metrics(self, pipeline):
        """Test confidence score calculation with empty metrics."""
        confidence = pipeline._calculate_confidence_score({})

        assert confidence == 0.0

    @pytest.mark.asyncio
    async def test_enhancement_service_failure_handling(self, pipeline, sample_texts):
        """Test handling of enhancement service failures."""
        text = sample_texts[DomainType.TECHNICAL]

        # Mock enhancement service to raise an exception
        pipeline.enhancement_service.enhance_transcript.side_effect = Exception("Service unavailable")

        config = DomainEnhancementConfig(domain=DomainType.TECHNICAL)
        # Reaching the next line at all asserts that the service exception is
        # handled inside the pipeline rather than propagated to the caller.
        enhanced_text, improvements, corrections = await pipeline._enhance_technical_content(text, config)

        # Should fall back to original text for terminology enhancement
        # But formatting optimization may still be applied
        assert len(corrections) == 0  # No terminology corrections
        # Note: Formatting may still be applied even if enhancement service fails

    @pytest.mark.asyncio
    async def test_domain_specific_configuration(self, pipeline, sample_texts):
        """Test domain-specific configuration options."""
        text = sample_texts[DomainType.TECHNICAL]

        # Create config with disabled terminology enhancement
        config = DomainEnhancementConfig(
            domain=DomainType.TECHNICAL,
            enable_terminology_enhancement=False,
            enable_formatting_optimization=True
        )

        enhanced_text, improvements, corrections = await pipeline._enhance_technical_content(text, config)

        # Should skip terminology enhancement but apply formatting
        assert "Applied technical formatting standards" in improvements
        assert len(corrections) == 0  # No terminology corrections

    def test_enhancement_result_structure(self):
        """Test EnhancementResult data structure."""
        result = EnhancementResult(
            original_text="Original text",
            enhanced_text="Enhanced text",
            domain=DomainType.TECHNICAL,
            confidence_score=0.85,
            improvements=["Improved formatting"],
            terminology_corrections=["Corrected term"],
            quality_metrics={"technical_precision": 0.8},
            processing_time=1.5
        )

        assert result.original_text == "Original text"
        assert result.enhanced_text == "Enhanced text"
        assert result.domain == DomainType.TECHNICAL
        assert result.confidence_score == 0.85
        assert len(result.improvements) == 1
        assert len(result.terminology_corrections) == 1
        assert len(result.quality_metrics) == 1
        assert result.processing_time == 1.5
|
|
|
|
|
|
class TestDomainEnhancementIntegration:
    """Test integration of domain enhancement with the pipeline.

    ``DomainEnhancementPipeline`` and ``DomainType`` are taken from the
    module-level import of ``src.services.domain_enhancement``.
    """

    @pytest.mark.asyncio
    async def test_end_to_end_technical_enhancement(self):
        """Test end-to-end technical content enhancement."""
        # Create pipeline with mock service
        mock_service = Mock()
        mock_service.enhance_transcript = AsyncMock(return_value={
            "enhanced_text": "The **algorithm** implements a `singleton pattern` for thread safety"
        })

        pipeline = DomainEnhancementPipeline(enhancement_service=mock_service)

        text = "The algorithm implements a singleton pattern for thread safety"
        result = await pipeline.enhance_content(text, domain=DomainType.TECHNICAL)

        assert result.domain == DomainType.TECHNICAL
        assert result.confidence_score > 0
        assert len(result.improvements) > 0
        assert "Applied technical formatting standards" in result.improvements

    @pytest.mark.asyncio
    async def test_domain_switching(self):
        """Test switching between different domains with a single pipeline instance."""
        mock_service = Mock()
        mock_service.enhance_transcript = AsyncMock(return_value={
            "enhanced_text": "Enhanced content"
        })

        pipeline = DomainEnhancementPipeline(enhancement_service=mock_service)

        # Test different domains
        domains = [DomainType.TECHNICAL, DomainType.MEDICAL, DomainType.ACADEMIC]

        for domain in domains:
            result = await pipeline.enhance_content("Test content", domain=domain)
            assert result.domain == domain
            # Confidence score may be 0 if no domain-specific terms are detected
            # This is expected behavior for generic content
            assert result.confidence_score >= 0
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit status; raising SystemExit with it
    # makes `python <this file>` exit non-zero when any test fails, so CI and
    # shell scripts can detect failures.
    raise SystemExit(pytest.main([__file__, "-v"]))
|