# Listing metadata from the code viewer: 461 lines, 25 KiB, Python
"""Test LoRA adapter integration with transcription workflow.
|
|
|
|
Tests for Task #8.1: Connect LoRA Adapters to Transcription Workflow
|
|
"""
|
|
|
|
import time
from contextlib import ExitStack, contextmanager
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from src.services.domain_adaptation_manager import DomainAdaptationManager
from src.services.model_manager import ModelManager
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline


class TestLoRAIntegration:
    """Test LoRA adapter integration with transcription workflow.

    The tests run against a ``MultiPassTranscriptionPipeline`` built around
    mocked managers. Individual pipeline passes are replaced with mocks that
    return canned segments; the private helpers below hold that shared setup.
    """

    # ------------------------------------------------------------------
    # Private helpers (not collected by pytest: names do not start with
    # "test").
    # ------------------------------------------------------------------

    @staticmethod
    def _segment(text, confidence=None):
        """Build one transcript segment dict covering 0.0-5.0.

        Args:
            text: Value stored under ``"text"``.
            confidence: Optional score. The ``"confidence"`` key is only
                added when a value is supplied.

        Returns:
            dict with ``"text"``, ``"start"``, ``"end"`` and, optionally,
            ``"confidence"``.
        """
        segment = {"text": text, "start": 0.0, "end": 5.0}
        if confidence is not None:
            segment["confidence"] = confidence
        return segment

    @staticmethod
    def _install_detector(pipeline, domain):
        """Replace ``pipeline.domain_detector`` with a mock reporting ``domain``.

        Returns:
            The installed mock, so callers can assert on ``detect_domain``.
        """
        detector = Mock()
        detector.detect_domain.return_value = domain
        pipeline.domain_detector = detector
        return detector

    @staticmethod
    @contextmanager
    def _stub_methods(pipeline, **return_values):
        """Patch attributes of ``pipeline`` so each returns a canned value.

        Attributes are patched with ``patch.object`` in keyword order and
        restored when the context exits.

        Args:
            pipeline: Object whose attributes are patched.
            **return_values: Attribute name -> value the replacement mock
                returns when called.

        Yields:
            dict mapping each attribute name to the mock that replaced it.
        """
        with ExitStack() as stack:
            stubs = {}
            for name, value in return_values.items():
                stub = stack.enter_context(patch.object(pipeline, name))
                stub.return_value = value
                stubs[name] = stub
            yield stubs

    @classmethod
    def _stub_passes(cls, pipeline, text, confidence=0.9, **extra):
        """Stub the four passes most tests replace.

        ``_perform_first_pass`` returns one segment with ``text``,
        ``_calculate_confidence`` returns the same segment plus
        ``confidence``, and ``_perform_refinement_pass`` /
        ``_identify_low_confidence_segments`` return empty lists.

        Args:
            pipeline: Pipeline under test.
            text: Text of the single canned segment.
            confidence: Score attached to the confidence-pass segment.
            **extra: Further attribute name -> return value pairs, patched
                after the four standard ones.

        Returns:
            Context manager from ``_stub_methods``.
        """
        return cls._stub_methods(
            pipeline,
            _perform_first_pass=[cls._segment(text)],
            _perform_refinement_pass=[],
            _calculate_confidence=[cls._segment(text, confidence)],
            _identify_low_confidence_segments=[],
            **extra,
        )

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_model_manager(self):
        """Create a mock ModelManager whose ``load_model`` returns a mock."""
        mock_manager = Mock(spec=ModelManager)
        mock_manager.load_model.return_value = Mock()
        return mock_manager

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Create a mock DomainAdaptationManager.

        Its ``domain_adapter.domain_adapters`` holds mock adapters for the
        "technical" and "medical" domains.
        """
        mock_manager = Mock(spec=DomainAdaptationManager)
        mock_domain_adapter = Mock()
        mock_domain_adapter.domain_adapters = {"technical": Mock(), "medical": Mock()}
        mock_manager.domain_adapter = mock_domain_adapter
        return mock_manager

    @pytest.fixture
    def pipeline_with_lora(self, mock_model_manager, mock_domain_adaptation_manager):
        """Create a pipeline with the mocked adaptation manager attached.

        ``auto_detect_domain`` is switched on.
        """
        pipeline = MultiPassTranscriptionPipeline(model_manager=mock_model_manager)
        pipeline.domain_adaptation_manager = mock_domain_adaptation_manager
        pipeline.auto_detect_domain = True
        return pipeline

    @pytest.fixture
    def sample_audio_path(self, tmp_path):
        """Create a placeholder ``.wav`` file and return its path.

        The bytes are not valid audio; the file only has to exist.
        """
        audio_file = tmp_path / "test_audio.wav"
        audio_file.write_bytes(b"fake_audio_data")
        return audio_file

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_pipeline_initialization_with_lora(self, mock_model_manager):
        """A freshly built pipeline exposes the LoRA-related attributes."""
        pipeline = MultiPassTranscriptionPipeline(model_manager=mock_model_manager)
        assert hasattr(pipeline, 'domain_adaptation_manager')
        assert hasattr(pipeline, 'auto_detect_domain')
        assert hasattr(pipeline, 'domain_detector')

    def test_domain_detection_integration(self, pipeline_with_lora, sample_audio_path):
        """The enhancement-pass output reaches the final transcript."""
        self._install_detector(pipeline_with_lora, "technical")

        # Only the first and enhancement passes are stubbed here; the other
        # passes run unpatched.
        with self._stub_methods(
            pipeline_with_lora,
            _perform_first_pass=[self._segment("algorithm implementation")],
            _perform_enhancement_pass=[
                self._segment("[TECHNICAL] algorithm implementation")
            ],
        ):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )

        assert result["transcript"][0]["text"].startswith("[TECHNICAL]")

    def test_domain_specific_model_switching(self, pipeline_with_lora, sample_audio_path):
        """The detector is consulted when no explicit domain is passed."""
        detector = self._install_detector(pipeline_with_lora, "technical")

        with self._stub_passes(pipeline_with_lora, "software architecture"):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        detector.detect_domain.assert_called()

    def test_domain_adapter_application(self, pipeline_with_lora):
        """``_apply_domain_adapter`` reports whether an adapter was applied."""
        # Without an adaptation manager nothing can be applied.
        pipeline_with_lora.domain_adaptation_manager = None
        result = pipeline_with_lora._apply_domain_adapter("technical")
        assert result is False

        # With a manager that knows the domain, the call succeeds.
        manager = Mock()
        manager.domain_adapter.domain_adapters = {"technical": Mock()}
        manager.domain_adapter.switch_adapter.return_value = Mock()
        pipeline_with_lora.domain_adaptation_manager = manager

        result = pipeline_with_lora._apply_domain_adapter("technical")
        assert result is True

    def test_automatic_domain_detection(self, pipeline_with_lora, sample_audio_path):
        """With auto-detection on and no explicit domain, the detector runs."""
        pipeline_with_lora.auto_detect_domain = True
        detector = self._install_detector(pipeline_with_lora, "academic")

        with self._stub_passes(pipeline_with_lora, "research methodology"):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        detector.detect_domain.assert_called()

    def test_lora_adapter_loading_during_enhancement(self, pipeline_with_lora, sample_audio_path):
        """A "medical" run with an unpatched enhancement pass yields segments."""
        self._install_detector(pipeline_with_lora, "medical")

        with self._stub_passes(pipeline_with_lora, "patient diagnosis"):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_fallback_to_general_domain(self, pipeline_with_lora, sample_audio_path):
        """An unrecognised detected domain leaves the text without a tag prefix."""
        self._install_detector(pipeline_with_lora, "unknown")

        with self._stub_passes(pipeline_with_lora, "general content"):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        assert "transcript" in result
        assert not result["transcript"][0]["text"].startswith("[")

    def test_lora_adapter_caching(self, pipeline_with_lora, sample_audio_path):
        """Two consecutive runs for the same domain give the same text.

        NOTE(review): only output consistency is asserted; nothing here
        observes whether an adapter was actually reused.
        """
        self._install_detector(pipeline_with_lora, "technical")

        with self._stub_passes(pipeline_with_lora, "technical content"):
            # First transcription
            result1 = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )
            # Second transcription (should use cached adapter)
            result2 = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )

        # Verify both results are consistent
        assert result1["transcript"][0]["text"] == result2["transcript"][0]["text"]

    def test_error_handling_with_lora_failure(self, pipeline_with_lora, sample_audio_path):
        """A run with a stubbed enhancement pass still returns a transcript.

        NOTE(review): the enhancement mock returns a normal segment and never
        raises, so no failure path is exercised despite the test name.
        """
        self._install_detector(pipeline_with_lora, "medical")

        with self._stub_passes(
            pipeline_with_lora,
            "medical content",
            _perform_enhancement_pass=[self._segment("medical content")],
        ):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        # Verify fallback behavior
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_performance_with_lora_adapters(self, pipeline_with_lora, sample_audio_path):
        """A stubbed run finishes in under ten seconds."""
        self._install_detector(pipeline_with_lora, "technical")

        with self._stub_passes(pipeline_with_lora, "technical content"):
            # perf_counter is monotonic, so the measured interval cannot be
            # skewed by system clock adjustments.
            start_time = time.perf_counter()
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )
            processing_time = time.perf_counter() - start_time

        # Verify performance is within acceptable limits
        assert processing_time < 10.0  # Should complete within 10 seconds
        assert result["processing_time"] < 10.0

    def test_memory_management_with_lora(self, pipeline_with_lora, sample_audio_path):
        """A "medical" run completes with a memory-usage hook on the manager.

        NOTE(review): the final assertion is satisfied by the attribute this
        test assigns itself; it does not show that the pipeline calls it.
        """
        # Mock memory monitoring
        pipeline_with_lora.model_manager.get_memory_usage = Mock(return_value={
            "rss_mb": 1500.0,
            "cuda_allocated_mb": 800.0
        })
        self._install_detector(pipeline_with_lora, "medical")

        with self._stub_passes(pipeline_with_lora, "clinical assessment"):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        # Verify memory monitoring capability is available
        assert hasattr(pipeline_with_lora.model_manager, 'get_memory_usage')

    def test_domain_detection_method(self, pipeline_with_lora):
        """``_detect_domain`` returns None first, then the detector's answer."""
        # No detector has been injected by this test yet.
        result = pipeline_with_lora._detect_domain(Path("dummy"))
        assert result is None

        # Test with mock domain detector
        self._install_detector(pipeline_with_lora, "technical")

        result = pipeline_with_lora._detect_domain(Path("dummy"), "algorithm implementation")
        assert result == "technical"

    def test_edge_case_empty_segments(self, pipeline_with_lora, sample_audio_path):
        """Every stubbed pass returning no segments yields an empty transcript."""
        self._install_detector(pipeline_with_lora, "technical")

        with self._stub_methods(
            pipeline_with_lora,
            _perform_first_pass=[],  # Empty segments
            _perform_refinement_pass=[],
            _calculate_confidence=[],
            _identify_low_confidence_segments=[],
        ):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        # Should handle empty segments gracefully
        assert "transcript" in result
        assert len(result["transcript"]) == 0

    def test_edge_case_very_long_text(self, pipeline_with_lora, sample_audio_path):
        """A single segment with very long text still produces a transcript."""
        self._install_detector(pipeline_with_lora, "academic")

        # Create very long text
        long_text = "This is a very long academic text " * 1000

        with self._stub_passes(pipeline_with_lora, long_text):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )

        # Should handle long text gracefully
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_integration_with_speaker_diarization(self, pipeline_with_lora, sample_audio_path):
        """With diarization enabled, transcript segments carry a speaker key."""
        self._install_detector(pipeline_with_lora, "technical")

        # Mock diarization manager
        with self._stub_passes(pipeline_with_lora, "technical discussion"), \
                patch('src.services.multi_pass_transcription.DiarizationManager') as mock_diarization_class:
            mock_diarization = Mock()
            mock_diarization.process_audio.return_value = Mock(
                segments=[{"start": 0.0, "end": 5.0, "speaker": "Speaker_1"}]
            )
            mock_diarization_class.return_value = mock_diarization

            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=True,
            )

        # Should include speaker information
        assert "transcript" in result
        assert len(result["transcript"]) > 0
        assert "speaker" in result["transcript"][0]

    def test_domain_adapter_switching_behavior(self, pipeline_with_lora):
        """Known domains switch adapters; "general" does not."""
        # Each adapter-backed domain gets a fresh manager that knows only
        # that domain.
        for domain in ("technical", "medical", "academic"):
            manager = Mock()
            manager.domain_adapter.domain_adapters = {domain: Mock()}
            manager.domain_adapter.switch_adapter.return_value = Mock()
            pipeline_with_lora.domain_adaptation_manager = manager

            result = pipeline_with_lora._apply_domain_adapter(domain)
            assert result is True

        # General domain should not require adapter switching. This runs
        # against the manager left by the last loop iteration, which only
        # has an "academic" adapter.
        result = pipeline_with_lora._apply_domain_adapter("general")
        assert result is False

    def test_confidence_calculation_with_domain_context(self, pipeline_with_lora, sample_audio_path):
        """A 0.95 segment confidence yields an overall score above 0.9."""
        self._install_detector(pipeline_with_lora, "medical")

        with self._stub_passes(pipeline_with_lora, "medical terminology", confidence=0.95):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )

        # Verify confidence is maintained
        assert "confidence_score" in result
        assert result["confidence_score"] > 0.9