# trax/tests/test_lora_integration.py

# 461 lines
# 25 KiB
# Python

"""Test LoRA adapter integration with transcription workflow.
Tests for Task #8.1: Connect LoRA Adapters to Transcription Workflow
"""
import time
from contextlib import ExitStack, contextmanager
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from src.services.domain_adaptation_manager import DomainAdaptationManager
from src.services.model_manager import ModelManager
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
class TestLoRAIntegration:
    """Test LoRA adapter integration with transcription workflow."""

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_model_manager(self):
        """Create a mock ModelManager whose ``load_model`` returns a Mock."""
        mock_manager = Mock(spec=ModelManager)
        mock_manager.load_model.return_value = Mock()
        return mock_manager

    @pytest.fixture
    def mock_domain_adaptation_manager(self):
        """Create a mock DomainAdaptationManager with "technical" and "medical" adapters."""
        mock_manager = Mock(spec=DomainAdaptationManager)
        mock_domain_adapter = Mock()
        mock_domain_adapter.domain_adapters = {"technical": Mock(), "medical": Mock()}
        mock_manager.domain_adapter = mock_domain_adapter
        return mock_manager

    @pytest.fixture
    def pipeline_with_lora(self, mock_model_manager, mock_domain_adaptation_manager):
        """Create a pipeline wired to the mock adaptation manager, with auto-detection on."""
        pipeline = MultiPassTranscriptionPipeline(model_manager=mock_model_manager)
        pipeline.domain_adaptation_manager = mock_domain_adaptation_manager
        pipeline.auto_detect_domain = True
        return pipeline

    @pytest.fixture
    def sample_audio_path(self, tmp_path):
        """Write placeholder bytes (not real audio) to a ``.wav`` path and return it."""
        audio_file = tmp_path / "test_audio.wav"
        audio_file.write_bytes(b"fake_audio_data")
        return audio_file

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: no ``test`` prefix)
    # ------------------------------------------------------------------

    @staticmethod
    def _stub_domain_detector(pipeline, domain):
        """Replace the pipeline's domain detector with a Mock that returns *domain*."""
        pipeline.domain_detector = Mock()
        pipeline.domain_detector.detect_domain.return_value = domain

    @staticmethod
    def _install_adapter_manager(pipeline, domain):
        """Attach a Mock adaptation manager that has an adapter registered for *domain* only."""
        manager = Mock()
        manager.domain_adapter.domain_adapters = {domain: Mock()}
        manager.domain_adapter.switch_adapter.return_value = Mock()
        pipeline.domain_adaptation_manager = manager

    @staticmethod
    @contextmanager
    def _patched_passes(pipeline, texts, confidence=0.9):
        """Patch the pipeline's internal passes with canned segment results.

        While active:
          * ``_perform_first_pass`` returns one 0.0-5.0 s segment per entry in *texts*;
          * ``_calculate_confidence`` returns the same segments tagged with *confidence*;
          * ``_perform_refinement_pass`` and ``_identify_low_confidence_segments``
            return empty lists.

        Pass ``texts=[]`` to simulate a transcription with no segments.
        """
        first_pass = [{"text": text, "start": 0.0, "end": 5.0} for text in texts]
        # Fresh dicts, so the first-pass and scored segments are independent objects.
        scored = [{**segment, "confidence": confidence} for segment in first_pass]
        with ExitStack() as stack:
            for attribute, value in (
                ("_perform_first_pass", first_pass),
                ("_perform_refinement_pass", []),
                ("_calculate_confidence", scored),
                ("_identify_low_confidence_segments", []),
            ):
                stack.enter_context(patch.object(pipeline, attribute, return_value=value))
            yield

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_pipeline_initialization_with_lora(self, mock_model_manager):
        """A fresh pipeline exposes the attributes the LoRA integration relies on."""
        pipeline = MultiPassTranscriptionPipeline(model_manager=mock_model_manager)
        assert hasattr(pipeline, "domain_adaptation_manager")
        assert hasattr(pipeline, "auto_detect_domain")
        assert hasattr(pipeline, "domain_detector")

    def test_domain_detection_integration(self, pipeline_with_lora, sample_audio_path):
        """The enhancement pass output (domain-tagged text) reaches the final transcript."""
        self._stub_domain_detector(pipeline_with_lora, "technical")
        # Only the first and enhancement passes are stubbed here. The enhancement
        # stub supplies the "[TECHNICAL]" prefix that the assertion looks for.
        first_pass = patch.object(
            pipeline_with_lora,
            "_perform_first_pass",
            return_value=[{"text": "algorithm implementation", "start": 0.0, "end": 5.0}],
        )
        enhancement = patch.object(
            pipeline_with_lora,
            "_perform_enhancement_pass",
            return_value=[
                {"text": "[TECHNICAL] algorithm implementation", "start": 0.0, "end": 5.0}
            ],
        )
        with first_pass, enhancement:
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )
        assert result["transcript"][0]["text"].startswith("[TECHNICAL]")

    def test_domain_specific_model_switching(self, pipeline_with_lora, sample_audio_path):
        """Without an explicit domain, the pipeline consults the domain detector."""
        self._stub_domain_detector(pipeline_with_lora, "technical")
        with self._patched_passes(pipeline_with_lora, ["software architecture"]):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )
        pipeline_with_lora.domain_detector.detect_domain.assert_called()

    def test_domain_adapter_application(self, pipeline_with_lora):
        """_apply_domain_adapter returns False without a manager, True with a matching adapter."""
        pipeline_with_lora.domain_adaptation_manager = None
        assert pipeline_with_lora._apply_domain_adapter("technical") is False

        self._install_adapter_manager(pipeline_with_lora, "technical")
        assert pipeline_with_lora._apply_domain_adapter("technical") is True

    def test_automatic_domain_detection(self, pipeline_with_lora, sample_audio_path):
        """With auto-detection on and no domain given, the detector is consulted."""
        pipeline_with_lora.auto_detect_domain = True
        self._stub_domain_detector(pipeline_with_lora, "academic")
        with self._patched_passes(pipeline_with_lora, ["research methodology"]):
            pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )
        pipeline_with_lora.domain_detector.detect_domain.assert_called()

    def test_lora_adapter_loading_during_enhancement(self, pipeline_with_lora, sample_audio_path):
        """Transcribing with an explicit domain that has an adapter yields a transcript."""
        self._stub_domain_detector(pipeline_with_lora, "medical")
        with self._patched_passes(pipeline_with_lora, ["patient diagnosis"]):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_fallback_to_general_domain(self, pipeline_with_lora, sample_audio_path):
        """An unrecognised detected domain leaves the text untagged."""
        self._stub_domain_detector(pipeline_with_lora, "unknown")
        with self._patched_passes(pipeline_with_lora, ["general content"]):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )
        assert "transcript" in result
        assert not result["transcript"][0]["text"].startswith("[")

    def test_lora_adapter_caching(self, pipeline_with_lora, sample_audio_path):
        """Repeated transcriptions with the same domain produce consistent output."""
        self._stub_domain_detector(pipeline_with_lora, "technical")
        with self._patched_passes(pipeline_with_lora, ["technical content"]):
            first = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )
            # The second run is where a cached adapter would be reused.
            second = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )
        # NOTE(review): this only checks the two runs agree. It does not yet assert
        # that the adapter was loaded once; add that once the caching API is fixed.
        assert first["transcript"][0]["text"] == second["transcript"][0]["text"]

    def test_error_handling_with_lora_failure(self, pipeline_with_lora, sample_audio_path):
        """Transcription still yields a transcript when enhancement adds no domain tagging."""
        self._stub_domain_detector(pipeline_with_lora, "medical")
        # The enhancement stub returns the text unchanged (no domain prefix), standing
        # in for an adapter that could not be applied.
        # NOTE(review): the stub does not raise, so this does not exercise a real
        # adapter-load exception. Consider ``side_effect`` once the pipeline's
        # error-handling contract is settled.
        enhancement = patch.object(
            pipeline_with_lora,
            "_perform_enhancement_pass",
            return_value=[{"text": "medical content", "start": 0.0, "end": 5.0}],
        )
        with self._patched_passes(pipeline_with_lora, ["medical content"]), enhancement:
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_performance_with_lora_adapters(self, pipeline_with_lora, sample_audio_path):
        """A stubbed LoRA transcription completes well within the time budget."""
        self._stub_domain_detector(pipeline_with_lora, "technical")
        with self._patched_passes(pipeline_with_lora, ["technical content"]):
            # perf_counter is monotonic, unlike time.time(), so wall-clock
            # adjustments cannot distort the measured duration.
            start = time.perf_counter()
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="technical",
                speaker_diarization=False,
            )
            elapsed = time.perf_counter() - start
        assert elapsed < 10.0  # should complete within 10 seconds
        assert result["processing_time"] < 10.0

    def test_memory_management_with_lora(self, pipeline_with_lora, sample_audio_path):
        """Transcription succeeds while the model manager reports memory usage."""
        pipeline_with_lora.model_manager.get_memory_usage = Mock(
            return_value={"rss_mb": 1500.0, "cuda_allocated_mb": 800.0}
        )
        self._stub_domain_detector(pipeline_with_lora, "medical")
        with self._patched_passes(pipeline_with_lora, ["clinical assessment"]):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )
        # The original hasattr() check only re-read the attribute assigned above.
        # Assert on the transcription outcome instead.
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_domain_detection_method(self, pipeline_with_lora):
        """_detect_domain returns None with no detector and the detector's answer otherwise."""
        # Clear the detector explicitly so the "no detector" case does not depend
        # on whatever the pipeline constructor installed.
        pipeline_with_lora.domain_detector = None
        assert pipeline_with_lora._detect_domain(Path("dummy")) is None

        self._stub_domain_detector(pipeline_with_lora, "technical")
        detected = pipeline_with_lora._detect_domain(Path("dummy"), "algorithm implementation")
        assert detected == "technical"

    def test_edge_case_empty_segments(self, pipeline_with_lora, sample_audio_path):
        """A first pass with no segments produces an empty transcript."""
        self._stub_domain_detector(pipeline_with_lora, "technical")
        with self._patched_passes(pipeline_with_lora, []):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )
        assert "transcript" in result
        assert len(result["transcript"]) == 0

    def test_edge_case_very_long_text(self, pipeline_with_lora, sample_audio_path):
        """A single very long segment is carried through to the transcript."""
        self._stub_domain_detector(pipeline_with_lora, "academic")
        long_text = "This is a very long academic text " * 1000
        with self._patched_passes(pipeline_with_lora, [long_text]):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=False,
            )
        assert "transcript" in result
        assert len(result["transcript"]) > 0

    def test_integration_with_speaker_diarization(self, pipeline_with_lora, sample_audio_path):
        """With diarization enabled, transcript segments carry speaker labels."""
        self._stub_domain_detector(pipeline_with_lora, "technical")
        diarization = Mock()
        diarization.process_audio.return_value = Mock(
            segments=[{"start": 0.0, "end": 5.0, "speaker": "Speaker_1"}]
        )
        diarization_patch = patch(
            "src.services.multi_pass_transcription.DiarizationManager",
            return_value=diarization,
        )
        with self._patched_passes(pipeline_with_lora, ["technical discussion"]), diarization_patch:
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                speaker_diarization=True,
            )
        assert "transcript" in result
        assert len(result["transcript"]) > 0
        assert "speaker" in result["transcript"][0]

    def test_domain_adapter_switching_behavior(self, pipeline_with_lora):
        """Specialised domains switch adapters; "general" does not."""
        for domain in ("technical", "medical", "academic"):
            self._install_adapter_manager(pipeline_with_lora, domain)
            assert pipeline_with_lora._apply_domain_adapter(domain) is True

        # "general" has no registered LoRA adapter. The manager left from the last
        # loop iteration only knows "academic", so no switch should happen.
        assert pipeline_with_lora._apply_domain_adapter("general") is False

    def test_confidence_calculation_with_domain_context(self, pipeline_with_lora, sample_audio_path):
        """The overall confidence score reflects the per-segment confidences."""
        self._stub_domain_detector(pipeline_with_lora, "medical")
        with self._patched_passes(pipeline_with_lora, ["medical terminology"], confidence=0.95):
            result = pipeline_with_lora.transcribe_with_parallel_processing(
                sample_audio_path,
                domain="medical",
                speaker_diarization=False,
            )
        assert "confidence_score" in result
        assert result["confidence_score"] > 0.9