# trax/tests/test_domain_adaptation.py
#
# 323 lines
# 14 KiB
# Python

"""Unit tests for Domain Adaptation System with LoRA Adapters.
Tests the domain adaptation system including LoRA adapters, domain detection,
and integration with the ModelManager.
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from typing import Dict, List, Any
import torch
import numpy as np
from transformers import WhisperForConditionalGeneration
from src.services.domain_adaptation import (
DomainAdapter,
DomainDetector
)
from src.services.domain_adaptation_manager import DomainAdaptationManager
class TestDomainAdapter:
    """Tests for DomainAdapter: construction plus creating, loading and switching adapters."""

    @pytest.fixture
    def mock_base_model(self):
        """Build a Whisper-spec'd mock whose config reports a hidden size of 768."""
        base = Mock(spec=WhisperForConditionalGeneration)
        base.config = Mock(hidden_size=768)
        return base

    @pytest.fixture
    def domain_adapter(self, mock_base_model):
        """Construct a DomainAdapter while ``from_pretrained`` is patched to return the mock."""
        loader_path = 'src.services.domain_adaptation.WhisperForConditionalGeneration.from_pretrained'
        with patch(loader_path, return_value=mock_base_model):
            return DomainAdapter(base_model_id="openai/whisper-large-v2")

    def test_domain_adapter_initialization(self, domain_adapter):
        """A new adapter has a base model and an empty dict of domain adapters."""
        assert domain_adapter.base_model is not None
        registry = domain_adapter.domain_adapters
        assert isinstance(registry, dict)
        assert not registry

    def test_create_adapter(self, domain_adapter):
        """create_adapter returns the patched get_peft_model result and registers the domain."""
        expected = Mock()
        peft_path = 'src.services.domain_adaptation.get_peft_model'
        with patch(peft_path, return_value=expected) as peft_factory:
            created = domain_adapter.create_adapter("technical")

        assert created == expected
        assert "technical" in domain_adapter.domain_adapters
        peft_factory.assert_called_once()

    def test_load_adapter(self, domain_adapter):
        """Loading from a path that does not exist raises FileNotFoundError."""
        peft_path = 'src.services.domain_adaptation.get_peft_model'
        missing_path_error = pytest.raises(
            FileNotFoundError, match="Adapter path not found: /path/to/adapter"
        )
        # get_peft_model stays patched for the duration of the call, as in
        # the original test, even though only the error is asserted.
        with patch(peft_path, return_value=Mock()), missing_path_error:
            domain_adapter.load_adapter("medical", "/path/to/adapter")

    def test_switch_adapter_existing(self, domain_adapter):
        """Switching to a registered domain returns that domain's adapter."""
        registered = Mock()
        domain_adapter.domain_adapters["technical"] = registered

        assert domain_adapter.switch_adapter("technical") == registered

    def test_switch_adapter_not_found(self, domain_adapter):
        """Switching to an unregistered domain raises ValueError."""
        unknown_domain_error = pytest.raises(
            ValueError, match="Domain adapter 'unknown' not found"
        )
        with unknown_domain_error:
            domain_adapter.switch_adapter("unknown")
class TestDomainDetector:
    """Tests for DomainDetector: initial state, training and domain detection."""

    @pytest.fixture
    def domain_detector(self):
        """Return a fresh, untrained DomainDetector."""
        return DomainDetector()

    @pytest.fixture
    def sample_training_data(self):
        """Return ``(texts, labels)``: eight sentences, two per domain."""
        labelled_samples = [
            ("technical", "The API endpoint returns a JSON response with status code 200"),
            ("medical", "Patient shows symptoms of acute myocardial infarction"),
            ("academic", "The research methodology follows a quantitative approach"),
            ("general", "Hello world, how are you today?"),
            ("technical", "Implement the singleton pattern for thread safety"),
            ("medical", "Administer 500mg of acetaminophen every 6 hours"),
            ("academic", "The study population consisted of 100 participants"),
            ("general", "This is a general conversation about the weather"),
        ]
        texts = [sentence for _, sentence in labelled_samples]
        labels = [label for label, _ in labelled_samples]
        return texts, labels

    def test_domain_detector_initialization(self, domain_detector):
        """A new detector has a vectorizer, a classifier and the four built-in domains."""
        assert domain_detector.vectorizer is not None
        assert domain_detector.classifier is not None
        for expected_domain in ("general", "technical", "medical", "academic"):
            assert expected_domain in domain_detector.domains

    def test_train_domain_detector(self, domain_detector, sample_training_data):
        """Training completes without raising and leaves the vectorizer fitted."""
        domain_detector.train(*sample_training_data)

        # A fitted vectorizer exposes ``vocabulary_``.
        assert hasattr(domain_detector.vectorizer, 'vocabulary_')

    def test_detect_domain_high_confidence(self, domain_detector, sample_training_data):
        """Detection on a technical-sounding phrase yields one of the known domains."""
        domain_detector.train(*sample_training_data)

        detected = domain_detector.detect_domain("API endpoint configuration", threshold=0.6)

        assert detected in domain_detector.domains

    def test_detect_domain_low_confidence(self, domain_detector, sample_training_data):
        """Ambiguous text with a 0.9 threshold is reported as "general"."""
        domain_detector.train(*sample_training_data)

        detected = domain_detector.detect_domain("random ambiguous text", threshold=0.9)

        assert detected == "general"

    def test_detect_domain_empty_text(self, domain_detector, sample_training_data):
        """Empty input is reported as "general"."""
        domain_detector.train(*sample_training_data)

        detected = domain_detector.detect_domain("", threshold=0.6)

        assert detected == "general"
class TestDomainAdaptationManager:
    """Test cases for DomainAdaptationManager integration.

    Every test runs against a manager whose ``ModelManager`` is replaced by a
    mock, so no real model is loaded or invoked.
    """

    @pytest.fixture
    def mock_model_manager(self):
        """Create a mock ModelManager whose ``get_base_model`` returns a mock."""
        manager = Mock()
        manager.get_base_model.return_value = Mock()
        return manager

    @pytest.fixture
    def domain_adaptation_manager(self, mock_model_manager):
        """Create a DomainAdaptationManager with ``ModelManager`` patched out."""
        with patch('src.services.domain_adaptation_manager.ModelManager', return_value=mock_model_manager):
            return DomainAdaptationManager()

    def test_domain_adaptation_manager_initialization(self, domain_adaptation_manager):
        """The manager exposes a model manager, a domain adapter and a domain detector."""
        assert domain_adaptation_manager.model_manager is not None
        assert domain_adaptation_manager.domain_adapter is not None
        assert domain_adaptation_manager.domain_detector is not None

    def test_load_default_adapters(self, domain_adaptation_manager):
        """Loading default adapters must not raise.

        The test passes as long as the call returns normally; nothing else is
        asserted.
        """
        domain_adaptation_manager._load_default_adapters()

    def test_transcribe_with_domain_adaptation_auto_detect(self, domain_adaptation_manager):
        """Auto-detected "technical" domain switches adapter and prefixes the text."""
        mock_audio = Mock()
        mock_transcription = "API endpoint configuration for microservices"
        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        detector = domain_adaptation_manager.domain_detector
        adapter = domain_adaptation_manager.domain_adapter
        # Register the technical adapter so that switch_adapter gets called.
        adapter.domain_adapters["technical"] = Mock()

        with patch.object(detector, 'detect_domain', return_value="technical"):
            with patch.object(adapter, 'switch_adapter', return_value=Mock()) as mock_switch:
                result = domain_adaptation_manager.transcribe_with_domain_adaptation(mock_audio)

        # The transcription comes back with the upper-cased domain as a prefix.
        assert result == "[TECHNICAL] API endpoint configuration for microservices"
        domain_adaptation_manager.model_manager.transcribe.assert_called_once_with(mock_audio)
        mock_switch.assert_called_once_with("technical")

    def test_transcribe_with_domain_adaptation_specified_domain(self, domain_adaptation_manager):
        """An explicitly requested domain is used when auto-detection is disabled."""
        mock_audio = Mock()
        mock_transcription = "Medical transcription"
        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        adapter = domain_adaptation_manager.domain_adapter
        # Register the medical adapter so that switch_adapter gets called.
        adapter.domain_adapters["medical"] = Mock()

        with patch.object(adapter, 'switch_adapter', return_value=Mock()) as mock_switch:
            result = domain_adaptation_manager.transcribe_with_domain_adaptation(
                mock_audio, auto_detect=False, domain="medical"
            )

        assert result == "[MEDICAL] Medical transcription"
        mock_switch.assert_called_once_with("medical")

    def test_transcribe_with_domain_adaptation_general_domain(self, domain_adaptation_manager):
        """The "general" domain returns the raw transcription without switching adapters."""
        mock_audio = Mock()
        mock_transcription = "General conversation"
        domain_adaptation_manager.model_manager.transcribe.return_value = mock_transcription

        detector = domain_adaptation_manager.domain_detector
        adapter = domain_adaptation_manager.domain_adapter

        with patch.object(detector, 'detect_domain', return_value="general"):
            # Patch switch_adapter so the "no adaptation" expectation is
            # actually verified rather than only stated in a comment.
            with patch.object(adapter, 'switch_adapter') as mock_switch:
                result = domain_adaptation_manager.transcribe_with_domain_adaptation(mock_audio)

        assert result == mock_transcription
        mock_switch.assert_not_called()

    def test_train_custom_domain(self, domain_adaptation_manager):
        """Training a custom domain builds a trainer once and trains on the given data."""
        training_data = Mock()
        # Register the adapter first so it exists in the domain_adapters dict.
        domain_adaptation_manager.domain_adapter.domain_adapters["legal"] = Mock()

        mock_trainer = Mock()
        with patch.object(domain_adaptation_manager, '_setup_trainer', return_value=mock_trainer) as mock_setup:
            domain_adaptation_manager.train_custom_domain("legal", training_data)

        mock_setup.assert_called_once()
        mock_trainer.train.assert_called_once_with(training_data)

    def test_setup_trainer(self, domain_adaptation_manager):
        """_setup_trainer builds training arguments once and passes them to Seq2SeqTrainer."""
        mock_model = Mock()
        mock_args = Mock()
        mock_trainer = Mock()

        trainer_path = 'src.services.domain_adaptation_manager.Seq2SeqTrainer'
        args_path = 'src.services.domain_adaptation_manager.Seq2SeqTrainingArguments'
        with patch(trainer_path, return_value=mock_trainer) as mock_trainer_class:
            with patch(args_path, return_value=mock_args) as mock_args_class:
                result = domain_adaptation_manager._setup_trainer(mock_model, "test_output_dir")

        assert result == mock_trainer
        mock_args_class.assert_called_once()
        mock_trainer_class.assert_called_once_with(
            model=mock_model,
            args=mock_args
        )
class TestDomainAdaptationIntegration:
    """Integration tests for the domain adaptation system.

    Only component wiring is checked at present. The memory and performance
    tests are placeholders and are reported as skipped.
    """

    @pytest.fixture
    def temp_adapters_dir(self):
        """Yield a temporary directory as a ``Path``; it is removed on teardown."""
        # The context manager handles cleanup, so no manual rmtree is needed.
        with tempfile.TemporaryDirectory() as temp_dir:
            yield Path(temp_dir)

    def test_end_to_end_domain_adaptation(self, temp_adapters_dir):
        """With ModelManager mocked, the manager initializes all of its components.

        A true end-to-end run would need real model loading and training, so
        only the integration points are verified here.
        """
        with patch('src.services.domain_adaptation_manager.ModelManager') as mock_model_manager_class:
            mock_model_manager_class.return_value = Mock()
            manager = DomainAdaptationManager()

        assert manager.model_manager is not None
        assert manager.domain_adapter is not None
        assert manager.domain_detector is not None

    def test_memory_optimization_integration(self, temp_adapters_dir):
        """Placeholder for adapter swapping and memory management tests."""
        # Skip explicitly so an empty test is not reported as passing.
        pytest.skip("memory optimization integration test is not implemented yet")

    def test_performance_optimization_integration(self, temp_adapters_dir):
        """Placeholder for caching and batched inference tests."""
        # Skip explicitly so an empty test is not reported as passing.
        pytest.skip("performance optimization integration test is not implemented yet")
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly exits
    # non-zero when tests fail, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))