320 lines
13 KiB
Python
320 lines
13 KiB
Python
import pytest
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
import torch
|
|
import psutil
|
|
|
|
from src.services.domain_memory_optimizer import AdapterCache, DomainMemoryOptimizer, MemoryStats
|
|
|
|
|
|
class TestAdapterCache:
    """Tests for ``AdapterCache``, the count- and memory-bounded adapter cache."""

    @pytest.fixture
    def adapter_cache(self):
        """Return a cache limited to 3 adapters and 100 MB."""
        return AdapterCache(max_size=3, max_memory_mb=100)

    def test_adapter_cache_initialization(self, adapter_cache):
        """A new cache records its limits and starts empty."""
        assert adapter_cache.max_size == 3
        assert adapter_cache.max_memory_mb == 100
        assert len(adapter_cache.cache) == 0
        assert len(adapter_cache.adapter_sizes) == 0

    def test_put_and_get_adapter(self, adapter_cache):
        """An adapter that was put can be fetched back, and its size is recorded."""
        mock_adapter = Mock()
        adapter_cache.put("technical", mock_adapter, 50)

        result = adapter_cache.get("technical")

        # Identity, not equality: the cache must hand back the very object stored.
        assert result is mock_adapter
        assert "technical" in adapter_cache.cache
        assert adapter_cache.adapter_sizes["technical"] == 50

    def test_cache_eviction_by_size(self, adapter_cache):
        """Inserting past ``max_size`` evicts the first-inserted adapter."""
        # Four 10 MB adapters into a max_size=3 cache. Total memory (40 MB) stays
        # under the 100 MB limit, so only the count limit can cause eviction.
        for name in ("tech1", "tech2", "tech3", "tech4"):
            adapter_cache.put(name, Mock(), 10)

        assert "tech1" not in adapter_cache.cache
        assert "tech4" in adapter_cache.cache
        assert len(adapter_cache.cache) == 3

    def test_cache_eviction_by_memory(self, adapter_cache):
        """Inserting past ``max_memory_mb`` evicts the first-inserted adapter."""
        adapter_cache.put("large1", Mock(), 60)  # 60 MB
        adapter_cache.put("large2", Mock(), 60)  # 120 MB total, exceeds the 100 MB limit

        assert "large1" not in adapter_cache.cache
        assert "large2" in adapter_cache.cache

    def test_get_nonexistent_adapter(self, adapter_cache):
        """Looking up a domain that was never added returns ``None``."""
        assert adapter_cache.get("nonexistent") is None

    def test_clear_cache(self, adapter_cache):
        """``clear()`` drops every adapter together with its size record."""
        adapter_cache.put("tech1", Mock(), 10)
        adapter_cache.put("tech2", Mock(), 10)

        adapter_cache.clear()

        assert len(adapter_cache.cache) == 0
        assert len(adapter_cache.adapter_sizes) == 0

    def test_get_stats(self, adapter_cache):
        """``get_stats()`` reports entry count, memory used, both limits, and domains."""
        adapter_cache.put("tech1", Mock(), 30)
        adapter_cache.put("tech2", Mock(), 40)

        stats = adapter_cache.get_stats()

        assert stats["size"] == 2
        assert stats["memory_used_mb"] == 70
        assert stats["max_size"] == 3
        assert stats["max_memory_mb"] == 100
        assert "tech1" in stats["domains"]
        assert "tech2" in stats["domains"]

    def test_cache_hit_miss_tracking(self, adapter_cache):
        """A hit returns the stored adapter; a miss returns ``None`` and changes nothing.

        The cache does not keep hit/miss counters, so this test checks the
        observable result of each lookup instead.
        """
        stored = Mock()
        adapter_cache.put("tech1", stored, 10)

        assert adapter_cache.get("tech1") is stored  # hit
        assert adapter_cache.get("nonexistent") is None  # miss

        # The miss must not have disturbed the existing entry.
        stats = adapter_cache.get_stats()
        assert stats["size"] == 1
        assert "tech1" in stats["domains"]
|
|
|
|
|
|
class TestDomainMemoryOptimizer:
    """Tests for ``DomainMemoryOptimizer``: stats, size estimation, swapping and cleanup."""

    @pytest.fixture
    def temp_swap_dir(self):
        """Yield a fresh temporary directory for swap files, deleted after the test."""
        with tempfile.TemporaryDirectory() as temp_dir:
            yield Path(temp_dir)

    @pytest.fixture
    def memory_optimizer(self, temp_swap_dir):
        """Return an optimizer (cache of 2 adapters / 100 MB) whose swap dir is ``temp_swap_dir``.

        ``Path`` is patched in the module under test only while the optimizer
        is constructed. Any ``Path(...)`` call made during ``__init__`` therefore
        returns the temporary directory.
        """
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = temp_swap_dir
            return DomainMemoryOptimizer(cache_size=2, max_memory_mb=100)

    def test_memory_optimizer_initialization(self, memory_optimizer, temp_swap_dir):
        """Constructor arguments reach the cache, and the swap dir is the patched path."""
        assert memory_optimizer.cache.max_size == 2
        assert memory_optimizer.cache.max_memory_mb == 100
        assert memory_optimizer.swap_dir == temp_swap_dir

    @patch('psutil.Process')
    def test_get_memory_stats(self, mock_process, memory_optimizer):
        """RSS/VMS bytes are converted to MB, and the memory percent is passed through."""
        gib = 1024 * 1024 * 1024
        memory_info = Mock(rss=2 * gib, vms=4 * gib)

        process = Mock()
        process.memory_info.return_value = memory_info
        process.memory_percent.return_value = 25.0
        mock_process.return_value = process

        # Force the CPU-only path so the result does not depend on the host GPU.
        with patch('torch.cuda.is_available', return_value=False):
            stats = memory_optimizer.get_memory_stats()

        assert stats.rss_mb == 2048.0
        assert stats.vms_mb == 4096.0
        assert stats.percent == 25.0

    def test_estimate_adapter_size(self, memory_optimizer):
        """A 3000-parameter adapter gets a small, non-negative size estimate."""
        param_small = Mock()
        param_small.numel.return_value = 1000
        param_large = Mock()
        param_large.numel.return_value = 2000

        mock_adapter = Mock()
        mock_adapter.parameters.return_value = [param_small, param_large]

        size_mb = memory_optimizer.estimate_adapter_size(mock_adapter)

        # At e.g. 2 bytes/param, 3000 params would be ~0.0057 MB. The exact
        # bytes-per-param is an implementation detail, so only bound the result.
        assert 0 <= size_mb < 1

    def test_swap_adapter_to_disk(self, memory_optimizer, temp_swap_dir):
        """Swapping saves the adapter once and returns ``<swap_dir>/<name>_swapped.pt``."""
        mock_adapter = Mock()
        mock_adapter.state_dict.return_value = {"param1": torch.tensor([1, 2, 3])}
        expected_swap_path = temp_swap_dir / "test_adapter_swapped.pt"

        with patch('torch.save') as mock_save:
            result = memory_optimizer.swap_adapter_to_disk("test_adapter", mock_adapter)

        assert result == str(expected_swap_path)
        mock_save.assert_called_once()

    def test_load_adapter_from_disk(self, memory_optimizer, temp_swap_dir):
        """Loading reads the saved state once and returns the model from ``peft.get_peft_model``."""
        mock_base_model = Mock()
        swap_path = str(temp_swap_dir / "test_adapter_swapped.pt")
        mock_adapter = Mock()
        saved_state = {"param1": torch.tensor([1, 2, 3])}

        with patch('torch.load', return_value=saved_state) as mock_load, \
                patch('peft.LoraConfig'), \
                patch('peft.get_peft_model', return_value=mock_adapter):
            result = memory_optimizer.load_adapter_from_disk(
                "test_adapter", swap_path, mock_base_model
            )

        assert result is mock_adapter
        mock_load.assert_called_once()

    def test_load_adapter_from_disk_not_found(self, memory_optimizer):
        """A missing swap file surfaces as ``FileNotFoundError``."""
        mock_base_model = Mock()
        non_existent_path = "/path/to/nonexistent.pt"

        with patch('torch.load', side_effect=FileNotFoundError("File not found")), \
                pytest.raises(FileNotFoundError):
            memory_optimizer.load_adapter_from_disk(
                "nonexistent", non_existent_path, mock_base_model
            )

    def test_optimize_memory_usage(self, memory_optimizer):
        """Under high memory usage, every currently loaded adapter is swapped to disk."""
        high_usage = MemoryStats(
            rss_mb=5000.0,  # high memory usage
            vms_mb=8000.0,
            percent=80.0,
        )
        current_adapters = {"tech1": Mock(), "tech2": Mock()}
        mock_base_model = Mock()

        with patch.object(memory_optimizer, 'get_memory_stats', return_value=high_usage), \
                patch.object(memory_optimizer, 'swap_adapter_to_disk',
                             return_value="/path/to/swap.pt") as mock_swap:
            memory_optimizer.optimize_memory_usage(current_adapters, mock_base_model)

        # One swap per loaded adapter.
        assert mock_swap.call_count == 2

    def test_cleanup_swap_files(self, memory_optimizer, temp_swap_dir):
        """``cleanup_swap_files()`` deletes ``*_swapped.pt`` files and leaves others alone."""
        swap_files = ["adapter1_swapped.pt", "adapter2_swapped.pt", "adapter3_swapped.pt"]
        for filename in swap_files:
            (temp_swap_dir / filename).touch()

        unrelated = temp_swap_dir / "not_a_swap.txt"
        unrelated.touch()

        memory_optimizer.cleanup_swap_files()

        for filename in swap_files:
            assert not (temp_swap_dir / filename).exists(), filename
        assert unrelated.exists()  # non-swap files must survive cleanup

    def test_get_optimization_stats(self, memory_optimizer):
        """``get_optimization_stats()`` exposes memory, cache, and swap-file sections."""
        stats = memory_optimizer.get_optimization_stats()

        assert "memory_usage" in stats
        assert "cache_stats" in stats
        assert "swap_files" in stats

    def test_memory_optimization_with_actual_adapters(self, memory_optimizer):
        """The optimizer's cache (max 2 adapters) evicts the oldest entry on overflow."""
        mock_adapter1 = Mock()
        mock_adapter2 = Mock()
        mock_adapter3 = Mock()

        memory_optimizer.cache.put("tech1", mock_adapter1, 30)
        memory_optimizer.cache.put("tech2", mock_adapter2, 40)
        # A third entry exceeds cache_size=2 (memory 40 + 50 = 90 MB stays under 100 MB).
        memory_optimizer.cache.put("tech3", mock_adapter3, 50)

        assert len(memory_optimizer.cache.cache) == 2
        assert "tech1" not in memory_optimizer.cache.cache  # first one evicted
        assert "tech3" in memory_optimizer.cache.cache  # latest one kept
|
|
|
|
|
|
class TestMemoryOptimizationIntegration:
    """Integration-style tests combining the optimizer's cache and swap decisions."""

    @pytest.fixture
    def temp_swap_dir(self):
        """Yield a fresh temporary directory for swap files, deleted after the test."""
        with tempfile.TemporaryDirectory() as temp_dir:
            yield Path(temp_dir)

    @staticmethod
    def _make_optimizer(swap_dir, **kwargs):
        """Construct a ``DomainMemoryOptimizer`` while ``Path(...)`` in its module returns *swap_dir*."""
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = swap_dir
            return DomainMemoryOptimizer(**kwargs)

    def test_adapter_swapping_workflow(self, temp_swap_dir):
        """With a single-slot, 50 MB cache, adding a second adapter evicts the first."""
        optimizer = self._make_optimizer(temp_swap_dir, cache_size=1, max_memory_mb=50)

        mock_adapter1 = Mock()
        mock_adapter2 = Mock()

        optimizer.cache.put("adapter1", mock_adapter1, 30)
        # Exceeds both the count limit (1) and the memory limit (30 + 40 > 50 MB).
        optimizer.cache.put("adapter2", mock_adapter2, 40)

        assert "adapter2" in optimizer.cache.cache
        assert "adapter1" not in optimizer.cache.cache

    def test_memory_pressure_response(self, temp_swap_dir):
        """At 95% memory usage, every currently loaded adapter is swapped to disk."""
        optimizer = self._make_optimizer(temp_swap_dir, cache_size=3, max_memory_mb=100)

        high_usage = MemoryStats(
            rss_mb=5000.0,  # high memory usage
            vms_mb=8000.0,
            percent=95.0,
        )
        current_adapters = {"tech1": Mock(), "tech2": Mock()}
        mock_base_model = Mock()

        with patch.object(optimizer, 'get_memory_stats', return_value=high_usage), \
                patch.object(optimizer, 'swap_adapter_to_disk',
                             return_value="/path/to/swap.pt") as mock_swap:
            optimizer.optimize_memory_usage(current_adapters, mock_base_model)

        # One swap per loaded adapter.
        assert mock_swap.call_count == 2
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session exit code. Propagate it so that running
    # this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__]))
|