"""Unit tests for the Domain Adaptation System with LoRA adapters.

Tests the domain adaptation system, including LoRA adapters, domain
detection, and integration with the ModelManager.
"""
|
|
|
|
import pytest
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from typing import Dict, List, Any
|
|
|
|
import torch
|
|
import numpy as np
|
|
from transformers import WhisperForConditionalGeneration
|
|
|
|
from src.services.domain_adaptation import (
|
|
DomainAdapter,
|
|
DomainDetector
|
|
)
|
|
from src.services.domain_adaptation_manager import DomainAdaptationManager
|
|
|
|
|
|
class TestDomainAdapter:
    """Test cases for the LoRA adapter architecture (``DomainAdapter``)."""

    @pytest.fixture
    def mock_base_model(self):
        """Return a Whisper-spec'd mock model with a minimal config.

        Only ``config.hidden_size`` (768) is populated explicitly.
        """
        model = Mock(spec=WhisperForConditionalGeneration)
        model.config = Mock()
        model.config.hidden_size = 768
        return model

    @pytest.fixture
    def domain_adapter(self, mock_base_model):
        """Build a ``DomainAdapter`` without loading real weights.

        ``from_pretrained`` is patched only while the adapter is constructed,
        so any ``from_pretrained`` call made during construction returns
        ``mock_base_model``.
        """
        with patch(
            'src.services.domain_adaptation.WhisperForConditionalGeneration.from_pretrained',
            return_value=mock_base_model,
        ):
            return DomainAdapter(base_model_id="openai/whisper-large-v2")

    def test_domain_adapter_initialization(self, domain_adapter):
        """A fresh adapter has a base model and an empty adapter registry."""
        assert domain_adapter.base_model is not None
        assert isinstance(domain_adapter.domain_adapters, dict)
        assert len(domain_adapter.domain_adapters) == 0

    def test_create_adapter(self, domain_adapter):
        """Creating an adapter goes through ``get_peft_model`` and registers the domain."""
        with patch('src.services.domain_adaptation.get_peft_model') as mock_get_peft:
            mock_adapter = Mock()
            mock_get_peft.return_value = mock_adapter

            adapter = domain_adapter.create_adapter("technical")

            # Identity check: the exact object produced by get_peft_model
            # must be handed back, not merely an "equal" one.
            assert adapter is mock_adapter
            assert "technical" in domain_adapter.domain_adapters
            mock_get_peft.assert_called_once()

    def test_load_adapter(self, domain_adapter, tmp_path):
        """Loading an adapter from a missing path raises ``FileNotFoundError``.

        The path lives under pytest's per-test ``tmp_path`` and is never
        created, so it is guaranteed not to exist on any machine (unlike a
        hard-coded absolute path).
        """
        missing_path = str(tmp_path / "missing_adapter")

        # get_peft_model stays patched as a safety net, so a regression that
        # gets past the path check never touches real PEFT machinery.
        with patch('src.services.domain_adaptation.get_peft_model'):
            with pytest.raises(FileNotFoundError) as excinfo:
                domain_adapter.load_adapter("medical", missing_path)

        # Substring checks rather than ``match=``: a filesystem path may
        # contain regex metacharacters (e.g. backslashes on Windows).
        message = str(excinfo.value)
        assert "Adapter path not found" in message
        assert missing_path in message

    def test_switch_adapter_existing(self, domain_adapter):
        """Switching to a registered domain returns that domain's adapter."""
        mock_adapter = Mock()
        domain_adapter.domain_adapters["technical"] = mock_adapter

        result = domain_adapter.switch_adapter("technical")
        assert result is mock_adapter

    def test_switch_adapter_not_found(self, domain_adapter):
        """Switching to an unregistered domain raises ``ValueError``."""
        with pytest.raises(ValueError, match="Domain adapter 'unknown' not found"):
            domain_adapter.switch_adapter("unknown")
|
|
|
|
|
|
class TestDomainDetector:
    """Tests for the text-based domain classifier (``DomainDetector``)."""

    @pytest.fixture
    def domain_detector(self):
        """Provide a fresh, untrained ``DomainDetector``."""
        detector = DomainDetector()
        return detector

    @pytest.fixture
    def sample_training_data(self):
        """Provide a tiny labelled corpus with two sentences per domain.

        Returns a ``(texts, labels)`` tuple where ``labels[i]`` is the domain
        of ``texts[i]``; sentences cycle technical, medical, academic, general.
        """
        corpus = [
            "The API endpoint returns a JSON response with status code 200",
            "Patient shows symptoms of acute myocardial infarction",
            "The research methodology follows a quantitative approach",
            "Hello world, how are you today?",
            "Implement the singleton pattern for thread safety",
            "Administer 500mg of acetaminophen every 6 hours",
            "The study population consisted of 100 participants",
            "This is a general conversation about the weather",
        ]
        # The corpus walks the same four-domain cycle twice.
        domain_cycle = ["technical", "medical", "academic", "general"]
        return corpus, domain_cycle * 2

    def test_domain_detector_initialization(self, domain_detector):
        """An untrained detector exposes its vectorizer, classifier and known domains."""
        assert domain_detector.vectorizer is not None
        assert domain_detector.classifier is not None
        for expected in ("general", "technical", "medical", "academic"):
            assert expected in domain_detector.domains

    def test_train_domain_detector(self, domain_detector, sample_training_data):
        """Training on the sample corpus completes and fits the vectorizer."""
        domain_detector.train(*sample_training_data)

        # Fitting is expected to populate ``vocabulary_`` on the vectorizer.
        assert hasattr(domain_detector.vectorizer, 'vocabulary_')

    def test_detect_domain_high_confidence(self, domain_detector, sample_training_data):
        """Detection on technical-sounding text yields one of the known domains.

        Only membership in ``domains`` is asserted, not a specific label.
        """
        domain_detector.train(*sample_training_data)

        detected = domain_detector.detect_domain("API endpoint configuration", threshold=0.6)
        assert detected in domain_detector.domains

    def test_detect_domain_low_confidence(self, domain_detector, sample_training_data):
        """Ambiguous text under a strict threshold falls back to ``"general"``."""
        domain_detector.train(*sample_training_data)

        detected = domain_detector.detect_domain("random ambiguous text", threshold=0.9)
        assert detected == "general"

    def test_detect_domain_empty_text(self, domain_detector, sample_training_data):
        """Empty input falls back to ``"general"``."""
        domain_detector.train(*sample_training_data)

        detected = domain_detector.detect_domain("", threshold=0.6)
        assert detected == "general"
|
|
|
|
|
|
class TestDomainAdaptationManager:
    """Test cases for ``DomainAdaptationManager`` integration."""

    @pytest.fixture
    def mock_model_manager(self):
        """Create a stand-in ModelManager whose ``get_base_model`` returns a Mock."""
        manager = Mock()
        manager.get_base_model.return_value = Mock()
        return manager

    @pytest.fixture
    def domain_adaptation_manager(self, mock_model_manager):
        """Create a ``DomainAdaptationManager`` with ``ModelManager`` patched out.

        The patch is active only during construction: instantiating
        ``ModelManager`` inside ``domain_adaptation_manager`` in that window
        yields ``mock_model_manager``.
        """
        with patch('src.services.domain_adaptation_manager.ModelManager', return_value=mock_model_manager):
            return DomainAdaptationManager()

    def test_domain_adaptation_manager_initialization(self, domain_adaptation_manager):
        """The manager wires up a model manager, adapter and detector."""
        assert domain_adaptation_manager.model_manager is not None
        assert domain_adaptation_manager.domain_adapter is not None
        assert domain_adaptation_manager.domain_detector is not None

    def test_load_default_adapters(self, domain_adaptation_manager):
        """Loading default adapters tolerates missing adapter files.

        The default adapters are not present in the test environment, so the
        call should only log; the test passes if no exception is raised.
        """
        domain_adaptation_manager._load_default_adapters()

    def test_transcribe_with_domain_adaptation_auto_detect(self, domain_adaptation_manager):
        """Auto-detected domain switches the adapter and prefixes the transcript."""
        mock_audio = Mock()
        mock_transcription = "API endpoint configuration for microservices"

        # Stub the base transcription.
        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        # Register a technical adapter so the manager has something to switch to.
        domain_adaptation_manager.domain_adapter.domain_adapters["technical"] = Mock()

        with patch.object(domain_adaptation_manager.domain_detector, 'detect_domain', return_value="technical"):
            with patch.object(domain_adaptation_manager.domain_adapter, 'switch_adapter') as mock_switch:
                mock_adapter = Mock()
                mock_switch.return_value = mock_adapter

                result = domain_adaptation_manager.transcribe_with_domain_adaptation(mock_audio)

                # Enhanced transcription carries an upper-cased domain prefix.
                assert result == "[TECHNICAL] API endpoint configuration for microservices"
                domain_adaptation_manager.model_manager.transcribe.assert_called_once_with(mock_audio)
                mock_switch.assert_called_once_with("technical")

    def test_transcribe_with_domain_adaptation_specified_domain(self, domain_adaptation_manager):
        """An explicitly specified domain is used when auto-detection is off."""
        mock_audio = Mock()
        mock_transcription = "Medical transcription"

        # Stub the base transcription.
        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        # Register a medical adapter so the manager has something to switch to.
        domain_adaptation_manager.domain_adapter.domain_adapters["medical"] = Mock()

        with patch.object(domain_adaptation_manager.domain_adapter, 'switch_adapter') as mock_switch:
            mock_adapter = Mock()
            mock_switch.return_value = mock_adapter

            result = domain_adaptation_manager.transcribe_with_domain_adaptation(
                mock_audio, auto_detect=False, domain="medical"
            )

            assert result == "[MEDICAL] Medical transcription"
            mock_switch.assert_called_once_with("medical")

    def test_transcribe_with_domain_adaptation_general_domain(self, domain_adaptation_manager):
        """The general domain returns the raw transcript with no adapter switch."""
        mock_audio = Mock()
        mock_transcription = "General conversation"

        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        with patch.object(domain_adaptation_manager.domain_detector, 'detect_domain', return_value="general"):
            with patch.object(domain_adaptation_manager.domain_adapter, 'switch_adapter') as mock_switch:
                result = domain_adaptation_manager.transcribe_with_domain_adaptation(mock_audio)

        assert result == mock_transcription
        domain_adaptation_manager.model_manager.transcribe.assert_called_once_with(mock_audio)
        # No adapter may be activated for the general domain.
        mock_switch.assert_not_called()

    def test_train_custom_domain(self, domain_adaptation_manager):
        """Training a custom domain builds a trainer and trains on the data."""
        training_data = Mock()

        # Register the adapter up front so it exists in domain_adapters.
        mock_adapter = Mock()
        domain_adaptation_manager.domain_adapter.domain_adapters["legal"] = mock_adapter

        with patch.object(domain_adaptation_manager, '_setup_trainer') as mock_setup:
            mock_trainer = Mock()
            mock_setup.return_value = mock_trainer

            domain_adaptation_manager.train_custom_domain("legal", training_data)

            mock_setup.assert_called_once()
            mock_trainer.train.assert_called_once_with(training_data)

    def test_setup_trainer(self, domain_adaptation_manager):
        """Trainer setup builds training args and a Seq2SeqTrainer around the model."""
        mock_model = Mock()

        with patch('src.services.domain_adaptation_manager.Seq2SeqTrainer') as mock_trainer_class:
            with patch('src.services.domain_adaptation_manager.Seq2SeqTrainingArguments') as mock_args_class:
                mock_args = Mock()
                mock_args_class.return_value = mock_args
                mock_trainer = Mock()
                mock_trainer_class.return_value = mock_trainer

                result = domain_adaptation_manager._setup_trainer(mock_model, "test_output_dir")

                assert result == mock_trainer
                mock_args_class.assert_called_once()
                mock_trainer_class.assert_called_once_with(
                    model=mock_model,
                    args=mock_args
                )
|
|
|
|
|
|
class TestDomainAdaptationIntegration:
    """Integration tests for the domain adaptation system."""

    @pytest.fixture
    def temp_adapters_dir(self):
        """Yield a temporary directory for adapter storage.

        The directory and its contents are removed when the test finishes.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            yield Path(temp_dir)

    def test_end_to_end_domain_adaptation(self, temp_adapters_dir):
        """Check the integration points of the end-to-end workflow.

        A full run would need real model loading and training; for now only
        the wiring of the manager's components is verified, with
        ``ModelManager`` patched out.
        """
        with patch('src.services.domain_adaptation_manager.ModelManager') as mock_model_manager_class:
            mock_model_manager = Mock()
            mock_model_manager_class.return_value = mock_model_manager

            manager = DomainAdaptationManager()

            # All components must be initialized.
            assert manager.model_manager is not None
            assert manager.domain_adapter is not None
            assert manager.domain_detector is not None

    @pytest.mark.skip(reason="Memory optimization integration test not implemented yet")
    def test_memory_optimization_integration(self, temp_adapters_dir):
        """Test memory optimization features (adapter swapping, memory management).

        Placeholder: the implementation depends on the memory optimization
        features. Marked as skipped so it is not reported as a passing test.
        """

    @pytest.mark.skip(reason="Performance optimization integration test not implemented yet")
    def test_performance_optimization_integration(self, temp_adapters_dir):
        """Test performance optimization features (caching, batched inference).

        Placeholder: the implementation depends on the performance
        optimization features. Marked as skipped so it is not reported as a
        passing test.
        """
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so running
    # this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__]))
|