"""Unit tests for Domain Adaptation System with LoRA Adapters.

Tests the domain adaptation system including LoRA adapters, domain detection,
and integration with the ModelManager.
"""

import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from typing import Dict, List, Any

import torch
import numpy as np
from transformers import WhisperForConditionalGeneration

from src.services.domain_adaptation import (
    DomainAdapter,
    DomainDetector
)
from src.services.domain_adaptation_manager import DomainAdaptationManager


class TestDomainAdapter:
    """Test cases for LoRA adapter architecture."""

    @pytest.fixture
    def mock_base_model(self):
        """Create a mock Whisper base model with a minimal config."""
        model = Mock(spec=WhisperForConditionalGeneration)
        model.config = Mock()
        model.config.hidden_size = 768
        return model

    @pytest.fixture
    def domain_adapter(self, mock_base_model):
        """Create a DomainAdapter whose base-model loading is patched out.

        ``from_pretrained`` is replaced so that constructing the adapter never
        downloads or loads real Whisper weights.
        """
        with patch(
            'src.services.domain_adaptation.WhisperForConditionalGeneration.from_pretrained',
            return_value=mock_base_model,
        ):
            return DomainAdapter(base_model_id="openai/whisper-large-v2")

    def test_domain_adapter_initialization(self, domain_adapter):
        """Test DomainAdapter initialization."""
        assert domain_adapter.base_model is not None
        assert isinstance(domain_adapter.domain_adapters, dict)
        assert len(domain_adapter.domain_adapters) == 0

    def test_create_adapter(self, domain_adapter):
        """Test creating a new LoRA adapter."""
        with patch('src.services.domain_adaptation.get_peft_model') as mock_get_peft:
            mock_adapter = Mock()
            mock_get_peft.return_value = mock_adapter

            adapter = domain_adapter.create_adapter("technical")

            assert adapter is mock_adapter
            assert "technical" in domain_adapter.domain_adapters
            mock_get_peft.assert_called_once()

    def test_load_adapter(self, domain_adapter):
        """Test loading a pre-trained adapter from a missing path fails loudly."""
        # get_peft_model is patched so no real PEFT model is built even if the
        # implementation touches it before validating the path.
        with patch('src.services.domain_adaptation.get_peft_model') as mock_get_peft:
            mock_adapter = Mock()
            mock_get_peft.return_value = mock_adapter

            # Loading a non-existent adapter should raise FileNotFoundError.
            with pytest.raises(FileNotFoundError,
                               match="Adapter path not found: /path/to/adapter"):
                domain_adapter.load_adapter("medical", "/path/to/adapter")

    def test_switch_adapter_existing(self, domain_adapter):
        """Test switching to an existing adapter."""
        mock_adapter = Mock()
        domain_adapter.domain_adapters["technical"] = mock_adapter

        result = domain_adapter.switch_adapter("technical")

        assert result is mock_adapter

    def test_switch_adapter_not_found(self, domain_adapter):
        """Test switching to non-existent adapter raises error."""
        with pytest.raises(ValueError, match="Domain adapter 'unknown' not found"):
            domain_adapter.switch_adapter("unknown")


class TestDomainDetector:
    """Test cases for domain detection system."""

    @pytest.fixture
    def domain_detector(self):
        """Create a DomainDetector instance for testing."""
        return DomainDetector()

    @pytest.fixture
    def sample_training_data(self):
        """Create sample training data: two examples per domain."""
        texts = [
            "The API endpoint returns a JSON response with status code 200",
            "Patient shows symptoms of acute myocardial infarction",
            "The research methodology follows a quantitative approach",
            "Hello world, how are you today?",
            "Implement the singleton pattern for thread safety",
            "Administer 500mg of acetaminophen every 6 hours",
            "The study population consisted of 100 participants",
            "This is a general conversation about the weather",
        ]
        labels = ["technical", "medical", "academic", "general",
                  "technical", "medical", "academic", "general"]
        return texts, labels

    def test_domain_detector_initialization(self, domain_detector):
        """Test DomainDetector initialization."""
        assert domain_detector.vectorizer is not None
        assert domain_detector.classifier is not None
        assert "general" in domain_detector.domains
        assert "technical" in domain_detector.domains
        assert "medical" in domain_detector.domains
        assert "academic" in domain_detector.domains

    def test_train_domain_detector(self, domain_detector, sample_training_data):
        """Test training the domain detector."""
        texts, labels = sample_training_data

        # Should not raise any exceptions
        domain_detector.train(texts, labels)

        # Verify vectorizer was fitted
        assert hasattr(domain_detector.vectorizer, 'vocabulary_')

    def test_detect_domain_high_confidence(self, domain_detector, sample_training_data):
        """Test domain detection with high confidence."""
        texts, labels = sample_training_data
        domain_detector.train(texts, labels)

        # Test technical domain
        result = domain_detector.detect_domain("API endpoint configuration", threshold=0.6)
        assert result in domain_detector.domains

    def test_detect_domain_low_confidence(self, domain_detector, sample_training_data):
        """Test domain detection with low confidence returns general."""
        texts, labels = sample_training_data
        domain_detector.train(texts, labels)

        # An ambiguous text with a strict threshold should fall back to "general".
        result = domain_detector.detect_domain("random ambiguous text", threshold=0.9)
        assert result == "general"

    def test_detect_domain_empty_text(self, domain_detector, sample_training_data):
        """Test domain detection with empty text."""
        texts, labels = sample_training_data
        domain_detector.train(texts, labels)

        result = domain_detector.detect_domain("", threshold=0.6)
        assert result == "general"


class TestDomainAdaptationManager:
    """Test cases for DomainAdaptationManager integration."""

    @pytest.fixture
    def mock_model_manager(self):
        """Create a mock ModelManager."""
        manager = Mock()
        manager.get_base_model.return_value = Mock()
        return manager

    @pytest.fixture
    def domain_adaptation_manager(self, mock_model_manager):
        """Create a DomainAdaptationManager backed by a mock ModelManager.

        The ModelManager class is patched only while the manager is constructed,
        which is where it gets instantiated.
        """
        with patch('src.services.domain_adaptation_manager.ModelManager',
                   return_value=mock_model_manager):
            return DomainAdaptationManager()

    def test_domain_adaptation_manager_initialization(self, domain_adaptation_manager):
        """Test DomainAdaptationManager initialization."""
        assert domain_adaptation_manager.model_manager is not None
        assert domain_adaptation_manager.domain_adapter is not None
        assert domain_adaptation_manager.domain_detector is not None

    def test_load_default_adapters(self, domain_adaptation_manager):
        """Test loading default domain adapters."""
        # The default adapters don't exist on disk, so this should only log
        # info messages. The test passes if no exception is raised.
        domain_adaptation_manager._load_default_adapters()

    def test_transcribe_with_domain_adaptation_auto_detect(self, domain_adaptation_manager):
        """Test transcription with automatic domain detection."""
        mock_audio = Mock()
        mock_transcription = "API endpoint configuration for microservices"

        # Mock the model manager transcription
        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        # Register a technical adapter so the manager has something to switch to.
        domain_adaptation_manager.domain_adapter.domain_adapters["technical"] = Mock()

        # Force domain detection to "technical" and intercept adapter switching.
        with patch.object(domain_adaptation_manager.domain_detector, 'detect_domain',
                          return_value="technical"):
            with patch.object(domain_adaptation_manager.domain_adapter,
                              'switch_adapter') as mock_switch:
                mock_switch.return_value = Mock()

                result = domain_adaptation_manager.transcribe_with_domain_adaptation(mock_audio)

                # Should return enhanced transcription with domain prefix
                assert result == "[TECHNICAL] API endpoint configuration for microservices"
                domain_adaptation_manager.model_manager.transcribe.assert_called_once_with(
                    mock_audio
                )
                mock_switch.assert_called_once_with("technical")

    def test_transcribe_with_domain_adaptation_specified_domain(self, domain_adaptation_manager):
        """Test transcription with specified domain."""
        mock_audio = Mock()
        mock_transcription = "Medical transcription"

        # Mock the model manager transcription
        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        # Register a medical adapter so the manager has something to switch to.
        domain_adaptation_manager.domain_adapter.domain_adapters["medical"] = Mock()

        # Intercept adapter switching
        with patch.object(domain_adaptation_manager.domain_adapter,
                          'switch_adapter') as mock_switch:
            mock_switch.return_value = Mock()

            result = domain_adaptation_manager.transcribe_with_domain_adaptation(
                mock_audio, auto_detect=False, domain="medical"
            )

            assert result == "[MEDICAL] Medical transcription"
            mock_switch.assert_called_once_with("medical")

    def test_transcribe_with_domain_adaptation_general_domain(self, domain_adaptation_manager):
        """Test transcription with general domain (no adaptation)."""
        mock_audio = Mock()
        mock_transcription = "General conversation"

        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        with patch.object(domain_adaptation_manager.domain_detector, 'detect_domain',
                          return_value="general"):
            with patch.object(domain_adaptation_manager.domain_adapter,
                              'switch_adapter') as mock_switch:
                result = domain_adaptation_manager.transcribe_with_domain_adaptation(mock_audio)

        # General transcriptions are returned unchanged (no domain prefix) ...
        assert result == mock_transcription
        # ... and no adapter switch should happen for the general domain.
        mock_switch.assert_not_called()

    def test_train_custom_domain(self, domain_adaptation_manager):
        """Test training a custom domain adapter."""
        training_data = Mock()

        # Create the adapter first so it exists in the domain_adapters dict
        mock_adapter = Mock()
        domain_adaptation_manager.domain_adapter.domain_adapters["legal"] = mock_adapter

        with patch.object(domain_adaptation_manager, '_setup_trainer') as mock_setup:
            mock_trainer = Mock()
            mock_setup.return_value = mock_trainer

            domain_adaptation_manager.train_custom_domain("legal", training_data)

            mock_setup.assert_called_once()
            mock_trainer.train.assert_called_once_with(training_data)

    def test_setup_trainer(self, domain_adaptation_manager):
        """Test trainer setup for adapter fine-tuning."""
        mock_model = Mock()

        with patch('src.services.domain_adaptation_manager.Seq2SeqTrainer') as mock_trainer_class, \
                patch('src.services.domain_adaptation_manager.Seq2SeqTrainingArguments') as mock_args_class:
            mock_args = Mock()
            mock_args_class.return_value = mock_args

            mock_trainer = Mock()
            mock_trainer_class.return_value = mock_trainer

            result = domain_adaptation_manager._setup_trainer(mock_model, "test_output_dir")

            assert result is mock_trainer
            mock_args_class.assert_called_once()
            mock_trainer_class.assert_called_once_with(
                model=mock_model,
                args=mock_args
            )


class TestDomainAdaptationIntegration:
    """Integration tests for domain adaptation system."""

    @pytest.fixture
    def temp_adapters_dir(self):
        """Yield a temporary directory for adapter storage; removed on teardown."""
        with tempfile.TemporaryDirectory() as temp_dir:
            yield Path(temp_dir)

    def test_end_to_end_domain_adaptation(self):
        """Test end-to-end domain adaptation workflow.

        A true end-to-end run would require real model loading and training;
        for now this checks that all integration points are wired up.
        """
        with patch('src.services.domain_adaptation_manager.ModelManager') as mock_model_manager_class:
            mock_model_manager_class.return_value = Mock()

            manager = DomainAdaptationManager()

            # Verify all components are properly initialized
            assert manager.model_manager is not None
            assert manager.domain_adapter is not None
            assert manager.domain_detector is not None

    # Placeholders are skipped rather than left as `pass`, so they don't show up
    # as passing tests that claim coverage they don't provide.
    @pytest.mark.skip(reason="Memory optimization features not implemented yet")
    def test_memory_optimization_integration(self, temp_adapters_dir):
        """Test memory optimization features (adapter swapping, memory management)."""

    @pytest.mark.skip(reason="Performance optimization features not implemented yet")
    def test_performance_optimization_integration(self, temp_adapters_dir):
        """Test performance optimization features (caching, batched inference)."""


if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to the shell/CI.
    raise SystemExit(pytest.main([__file__]))