# trax/tests/test_domain_memory_optimize...
#
# 320 lines
# 13 KiB
# Python

import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
import torch
import psutil
from src.services.domain_memory_optimizer import AdapterCache, DomainMemoryOptimizer, MemoryStats
class TestAdapterCache:
    """Test the LRU cache for adapters."""

    @pytest.fixture
    def adapter_cache(self):
        """Create an adapter cache limited to 3 entries and 100 MB."""
        return AdapterCache(max_size=3, max_memory_mb=100)

    def test_adapter_cache_initialization(self, adapter_cache):
        """Test adapter cache initialization."""
        assert adapter_cache.max_size == 3
        assert adapter_cache.max_memory_mb == 100
        assert len(adapter_cache.cache) == 0
        assert len(adapter_cache.adapter_sizes) == 0

    def test_put_and_get_adapter(self, adapter_cache):
        """Test putting and getting adapters from cache."""
        mock_adapter = Mock()
        adapter_cache.put("technical", mock_adapter, 50)
        result = adapter_cache.get("technical")
        # The cache must hand back the very object that was stored.
        assert result is mock_adapter
        assert "technical" in adapter_cache.cache
        assert adapter_cache.adapter_sizes["technical"] == 50

    def test_cache_eviction_by_size(self, adapter_cache):
        """Test LRU eviction when cache size limit is reached."""
        # Add 4 adapters to a cache with max_size=3
        for name in ("tech1", "tech2", "tech3", "tech4"):
            adapter_cache.put(name, Mock(), 10)
        # The first (least recently used) adapter should be evicted
        assert "tech1" not in adapter_cache.cache
        assert "tech4" in adapter_cache.cache
        assert len(adapter_cache.cache) == 3

    def test_get_refreshes_recency(self, adapter_cache):
        """Test that get() marks an adapter as most recently used (LRU, not FIFO)."""
        for name in ("tech1", "tech2", "tech3"):
            adapter_cache.put(name, Mock(), 10)
        # Touch tech1 so tech2 becomes the least recently used entry.
        adapter_cache.get("tech1")
        adapter_cache.put("tech4", Mock(), 10)
        assert "tech1" in adapter_cache.cache
        assert "tech2" not in adapter_cache.cache
        assert "tech4" in adapter_cache.cache

    def test_cache_eviction_by_memory(self, adapter_cache):
        """Test eviction when memory limit is exceeded."""
        # Add adapters that exceed memory limit
        adapter_cache.put("large1", Mock(), 60)  # 60MB
        adapter_cache.put("large2", Mock(), 60)  # 120MB total, exceeds 100MB limit
        # First adapter should be evicted
        assert "large1" not in adapter_cache.cache
        assert "large2" in adapter_cache.cache
        # Whatever remains must fit inside the configured memory budget.
        stats = adapter_cache.get_stats()
        assert stats["memory_used_mb"] <= adapter_cache.max_memory_mb

    def test_get_nonexistent_adapter(self, adapter_cache):
        """Test getting an adapter that doesn't exist."""
        result = adapter_cache.get("nonexistent")
        assert result is None

    def test_clear_cache(self, adapter_cache):
        """Test clearing the cache."""
        adapter_cache.put("tech1", Mock(), 10)
        adapter_cache.put("tech2", Mock(), 10)
        adapter_cache.clear()
        assert len(adapter_cache.cache) == 0
        assert len(adapter_cache.adapter_sizes) == 0

    def test_get_stats(self, adapter_cache):
        """Test getting cache statistics."""
        adapter_cache.put("tech1", Mock(), 30)
        adapter_cache.put("tech2", Mock(), 40)
        stats = adapter_cache.get_stats()
        assert stats["size"] == 2
        assert stats["memory_used_mb"] == 70
        assert stats["max_size"] == 3
        assert stats["max_memory_mb"] == 100
        assert "tech1" in stats["domains"]
        assert "tech2" in stats["domains"]

    def test_cache_hit_miss_tracking(self, adapter_cache):
        """Test cache hits and misses.

        The cache does not keep hit/miss counters, so this verifies the
        observable outcome of each: a hit returns the stored adapter, a miss
        returns None, and neither changes the cache contents.
        """
        stored = Mock()
        adapter_cache.put("tech1", stored, 10)
        # Hit
        assert adapter_cache.get("tech1") is stored
        # Miss
        assert adapter_cache.get("nonexistent") is None
        stats = adapter_cache.get_stats()
        assert stats["size"] == 1
        assert "tech1" in stats["domains"]
class TestDomainMemoryOptimizer:
    """Test the domain memory optimizer."""

    @pytest.fixture
    def temp_swap_dir(self, tmp_path):
        """Provide a per-test temporary directory for swap files.

        Delegates to pytest's built-in ``tmp_path`` fixture, so pytest manages
        the directory's creation and cleanup/retention. This avoids a manual
        ``rmtree`` teardown that would error if the code under test removed
        the directory itself.
        """
        return tmp_path

    @pytest.fixture
    def memory_optimizer(self, temp_swap_dir):
        """Create a memory optimizer whose ``Path(...)`` calls resolve to the temp dir."""
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = temp_swap_dir
            return DomainMemoryOptimizer(cache_size=2, max_memory_mb=100)

    def test_memory_optimizer_initialization(self, memory_optimizer, temp_swap_dir):
        """Test memory optimizer initialization."""
        assert memory_optimizer.cache.max_size == 2
        assert memory_optimizer.cache.max_memory_mb == 100
        assert memory_optimizer.swap_dir == temp_swap_dir

    @patch('psutil.Process')
    def test_get_memory_stats(self, mock_process, memory_optimizer):
        """Test getting memory statistics."""
        mock_memory_info = Mock()
        mock_memory_info.rss = 2 * 1024 * 1024 * 1024  # 2GB
        mock_memory_info.vms = 4 * 1024 * 1024 * 1024  # 4GB
        mock_process_instance = Mock()
        mock_process_instance.memory_info.return_value = mock_memory_info
        mock_process_instance.memory_percent.return_value = 25.0
        mock_process.return_value = mock_process_instance
        with patch('torch.cuda.is_available', return_value=False):
            stats = memory_optimizer.get_memory_stats()
        assert stats.rss_mb == 2048.0
        assert stats.vms_mb == 4096.0
        assert stats.percent == 25.0

    def test_estimate_adapter_size(self, memory_optimizer):
        """Test adapter size estimation."""
        mock_param1 = Mock()
        mock_param1.numel.return_value = 1000
        mock_param2 = Mock()
        mock_param2.numel.return_value = 2000
        mock_adapter = Mock()
        mock_adapter.parameters.return_value = [mock_param1, mock_param2]
        size_mb = memory_optimizer.estimate_adapter_size(mock_adapter)
        # (1000 + 2000) * 2 / (1024 * 1024) ≈ 0.0057 MB
        assert 0 <= size_mb < 1  # Should be small and non-negative

    def test_swap_adapter_to_disk(self, memory_optimizer, temp_swap_dir):
        """Test swapping adapter to disk."""
        mock_adapter = Mock()
        mock_adapter.state_dict.return_value = {"param1": torch.tensor([1, 2, 3])}
        expected_swap_path = temp_swap_dir / "test_adapter_swapped.pt"
        with patch('torch.save') as mock_save:
            result = memory_optimizer.swap_adapter_to_disk("test_adapter", mock_adapter)
        assert result == str(expected_swap_path)
        mock_save.assert_called_once()

    def test_load_adapter_from_disk(self, memory_optimizer, temp_swap_dir):
        """Test loading adapter from disk."""
        mock_base_model = Mock()
        mock_adapter = Mock()
        swap_path = str(temp_swap_dir / "test_adapter_swapped.pt")
        saved_state = {"param1": torch.tensor([1, 2, 3])}
        # LoraConfig is patched only to keep the real one out of the way;
        # its mock is never inspected, so it is not bound to a name.
        with patch('torch.load', return_value=saved_state) as mock_load, \
                patch('peft.LoraConfig'), \
                patch('peft.get_peft_model', return_value=mock_adapter):
            result = memory_optimizer.load_adapter_from_disk(
                "test_adapter", swap_path, mock_base_model
            )
        assert result == mock_adapter
        mock_load.assert_called_once()

    def test_load_adapter_from_disk_not_found(self, memory_optimizer):
        """Test loading adapter that doesn't exist on disk."""
        mock_base_model = Mock()
        non_existent_path = "/path/to/nonexistent.pt"
        with patch('torch.load', side_effect=FileNotFoundError("File not found")):
            with pytest.raises(FileNotFoundError):
                memory_optimizer.load_adapter_from_disk(
                    "nonexistent", non_existent_path, mock_base_model
                )

    def test_optimize_memory_usage(self, memory_optimizer):
        """Test memory optimization strategy."""
        # Mock memory stats to indicate high memory usage
        with patch.object(memory_optimizer, 'get_memory_stats') as mock_stats:
            mock_stats.return_value = MemoryStats(
                rss_mb=5000.0,  # High memory usage
                vms_mb=8000.0,
                percent=80.0,
            )
            current_adapters = {"tech1": Mock(), "tech2": Mock()}
            mock_base_model = Mock()
            with patch.object(memory_optimizer, 'swap_adapter_to_disk') as mock_swap:
                mock_swap.return_value = "/path/to/swap.pt"
                memory_optimizer.optimize_memory_usage(current_adapters, mock_base_model)
            # Should trigger swapping when memory is high
            assert mock_swap.call_count == 2

    def test_cleanup_swap_files(self, memory_optimizer, temp_swap_dir):
        """Test cleanup of swap files."""
        swap_files = ["adapter1_swapped.pt", "adapter2_swapped.pt", "adapter3_swapped.pt"]
        for filename in swap_files:
            (temp_swap_dir / filename).touch()
        # Create a non-swap file
        (temp_swap_dir / "not_a_swap.txt").touch()
        memory_optimizer.cleanup_swap_files()
        # Should only delete *_swapped.pt files
        for filename in swap_files:
            assert not (temp_swap_dir / filename).exists()
        assert (temp_swap_dir / "not_a_swap.txt").exists()  # Non-swap file should remain

    def test_get_optimization_stats(self, memory_optimizer):
        """Test getting optimization statistics."""
        stats = memory_optimizer.get_optimization_stats()
        assert "memory_usage" in stats
        assert "cache_stats" in stats
        assert "swap_files" in stats

    def test_memory_optimization_with_actual_adapters(self, memory_optimizer):
        """Test memory optimization with realistic adapter scenarios."""
        # Add adapters to a cache that holds at most 2 entries
        memory_optimizer.cache.put("tech1", Mock(), 30)
        memory_optimizer.cache.put("tech2", Mock(), 40)
        memory_optimizer.cache.put("tech3", Mock(), 50)  # Should trigger eviction
        # Verify cache size limit is respected
        assert len(memory_optimizer.cache.cache) == 2
        assert "tech1" not in memory_optimizer.cache.cache  # First one evicted
        assert "tech3" in memory_optimizer.cache.cache  # Latest one kept
class TestMemoryOptimizationIntegration:
    """Integration tests for memory optimization features."""

    @pytest.fixture
    def temp_swap_dir(self, tmp_path):
        """Provide a per-test temporary directory for swap files.

        Delegates to pytest's built-in ``tmp_path`` fixture, so pytest manages
        the directory's creation and cleanup/retention. No hand-written
        ``mkdtemp``/``rmtree`` pair is needed.
        """
        return tmp_path

    def test_adapter_swapping_workflow(self, temp_swap_dir):
        """Test complete adapter swapping workflow."""
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = temp_swap_dir
            optimizer = DomainMemoryOptimizer(cache_size=1, max_memory_mb=50)
            # Add first adapter
            optimizer.cache.put("adapter1", Mock(), 30)
            # Add second adapter (a 1-slot cache must evict the first)
            optimizer.cache.put("adapter2", Mock(), 40)
            # Verify cache state
            assert "adapter2" in optimizer.cache.cache
            assert "adapter1" not in optimizer.cache.cache

    def test_memory_pressure_response(self, temp_swap_dir):
        """Test system response to memory pressure."""
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = temp_swap_dir
            optimizer = DomainMemoryOptimizer(cache_size=3, max_memory_mb=100)
        # Simulate memory pressure
        with patch.object(optimizer, 'get_memory_stats') as mock_stats:
            mock_stats.return_value = MemoryStats(
                rss_mb=5000.0,  # High memory usage
                vms_mb=8000.0,
                percent=95.0,
            )
            current_adapters = {"tech1": Mock(), "tech2": Mock()}
            mock_base_model = Mock()
            with patch.object(optimizer, 'swap_adapter_to_disk') as mock_swap:
                mock_swap.return_value = "/path/to/swap.pt"
                optimizer.optimize_memory_usage(current_adapters, mock_base_model)
            # Should trigger swapping when memory is high
            assert mock_swap.call_count == 2
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of raising SystemExit.
    # Forward that code so running this file directly reports test failures
    # to the shell/CI rather than always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))