"""Tests for the adapter LRU cache and the domain memory optimizer."""

import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import psutil
import pytest
import torch

from src.services.domain_memory_optimizer import (
    AdapterCache,
    DomainMemoryOptimizer,
    MemoryStats,
)


@pytest.fixture
def temp_swap_dir():
    """Yield a fresh temporary directory for swap files and remove it afterwards.

    Defined once at module level so every test class shares the same fixture
    instead of carrying its own copy.
    """
    temp_dir = tempfile.mkdtemp()
    yield Path(temp_dir)
    shutil.rmtree(temp_dir)


class TestAdapterCache:
    """Test the LRU cache for adapters."""

    @pytest.fixture
    def adapter_cache(self):
        """Create an adapter cache limited to 3 entries and 100 MB."""
        return AdapterCache(max_size=3, max_memory_mb=100)

    def test_adapter_cache_initialization(self, adapter_cache):
        """Test that a new cache stores its limits and starts empty."""
        assert adapter_cache.max_size == 3
        assert adapter_cache.max_memory_mb == 100
        assert len(adapter_cache.cache) == 0
        assert len(adapter_cache.adapter_sizes) == 0

    def test_put_and_get_adapter(self, adapter_cache):
        """Test putting and getting adapters from cache."""
        mock_adapter = Mock()
        adapter_cache.put("technical", mock_adapter, 50)

        result = adapter_cache.get("technical")

        assert result == mock_adapter
        assert "technical" in adapter_cache.cache
        assert adapter_cache.adapter_sizes["technical"] == 50

    def test_cache_eviction_by_size(self, adapter_cache):
        """Test LRU eviction when cache size limit is reached."""
        # Add 4 adapters to a cache with max_size=3
        for name in ["tech1", "tech2", "tech3", "tech4"]:
            adapter_cache.put(name, Mock(), 10)

        # First adapter should be evicted
        assert "tech1" not in adapter_cache.cache
        assert "tech4" in adapter_cache.cache
        assert len(adapter_cache.cache) == 3

    def test_cache_eviction_by_memory(self, adapter_cache):
        """Test eviction when memory limit is exceeded."""
        # Add adapters that exceed memory limit
        adapter_cache.put("large1", Mock(), 60)  # 60MB
        adapter_cache.put("large2", Mock(), 60)  # 120MB total, exceeds 100MB limit

        # First adapter should be evicted
        assert "large1" not in adapter_cache.cache
        assert "large2" in adapter_cache.cache

    def test_get_nonexistent_adapter(self, adapter_cache):
        """Test getting an adapter that doesn't exist."""
        result = adapter_cache.get("nonexistent")

        assert result is None

    def test_clear_cache(self, adapter_cache):
        """Test clearing the cache."""
        adapter_cache.put("tech1", Mock(), 10)
        adapter_cache.put("tech2", Mock(), 10)

        adapter_cache.clear()

        assert len(adapter_cache.cache) == 0
        assert len(adapter_cache.adapter_sizes) == 0

    def test_get_stats(self, adapter_cache):
        """Test getting cache statistics."""
        adapter_cache.put("tech1", Mock(), 30)
        adapter_cache.put("tech2", Mock(), 40)

        stats = adapter_cache.get_stats()

        assert stats["size"] == 2
        assert stats["memory_used_mb"] == 70
        assert stats["max_size"] == 3
        assert stats["max_memory_mb"] == 100
        assert "tech1" in stats["domains"]
        assert "tech2" in stats["domains"]

    def test_cache_hit_miss_tracking(self, adapter_cache):
        """Test that a hit followed by a miss leaves the cached entry in place."""
        adapter_cache.put("tech1", Mock(), 10)

        # Hit
        adapter_cache.get("tech1")
        # Miss
        adapter_cache.get("nonexistent")

        # The cache doesn't track hits/misses, just verify the adapter is still there
        stats = adapter_cache.get_stats()
        assert stats["size"] == 1
        assert "tech1" in stats["domains"]


class TestDomainMemoryOptimizer:
    """Test the domain memory optimizer."""

    @pytest.fixture
    def memory_optimizer(self, temp_swap_dir):
        """Create a memory optimizer whose swap directory is ``temp_swap_dir``.

        ``Path`` is patched only while the optimizer is constructed; the patch
        is undone as soon as this fixture returns, so the tests themselves run
        against the real ``Path``.
        """
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = temp_swap_dir
            return DomainMemoryOptimizer(cache_size=2, max_memory_mb=100)

    def test_memory_optimizer_initialization(self, memory_optimizer, temp_swap_dir):
        """Test memory optimizer initialization."""
        assert memory_optimizer.cache.max_size == 2
        assert memory_optimizer.cache.max_memory_mb == 100
        assert memory_optimizer.swap_dir == temp_swap_dir

    @patch('psutil.Process')
    def test_get_memory_stats(self, mock_process, memory_optimizer):
        """Test getting memory statistics."""
        mock_memory_info = Mock()
        mock_memory_info.rss = 2 * 1024 * 1024 * 1024  # 2GB
        mock_memory_info.vms = 4 * 1024 * 1024 * 1024  # 4GB

        mock_process_instance = Mock()
        mock_process_instance.memory_info.return_value = mock_memory_info
        mock_process_instance.memory_percent.return_value = 25.0
        mock_process.return_value = mock_process_instance

        with patch('torch.cuda.is_available', return_value=False):
            stats = memory_optimizer.get_memory_stats()

        assert stats.rss_mb == 2048.0
        assert stats.vms_mb == 4096.0
        assert stats.percent == 25.0

    def test_estimate_adapter_size(self, memory_optimizer):
        """Test adapter size estimation."""
        mock_param1 = Mock()
        mock_param1.numel.return_value = 1000
        mock_param2 = Mock()
        mock_param2.numel.return_value = 2000

        mock_adapter = Mock()
        mock_adapter.parameters.return_value = [mock_param1, mock_param2]

        size_mb = memory_optimizer.estimate_adapter_size(mock_adapter)

        # Roughly (1000 + 2000) * 2 / (1024 * 1024) ≈ 0.0057 MB if the estimate
        # uses 2 bytes per parameter; only the bounds are asserted here.
        assert size_mb >= 0
        assert size_mb < 1  # Should be small

    def test_swap_adapter_to_disk(self, memory_optimizer, temp_swap_dir):
        """Test swapping adapter to disk."""
        mock_adapter = Mock()
        mock_adapter.state_dict.return_value = {"param1": torch.tensor([1, 2, 3])}
        expected_swap_path = temp_swap_dir / "test_adapter_swapped.pt"

        with patch('torch.save') as mock_save:
            result = memory_optimizer.swap_adapter_to_disk("test_adapter", mock_adapter)

        assert result == str(expected_swap_path)
        mock_save.assert_called_once()

    def test_load_adapter_from_disk(self, memory_optimizer, temp_swap_dir):
        """Test loading adapter from disk."""
        mock_base_model = Mock()
        swap_path = str(temp_swap_dir / "test_adapter_swapped.pt")
        saved_state = {"param1": torch.tensor([1, 2, 3])}

        with patch('torch.load', return_value=saved_state) as mock_load, \
                patch('peft.LoraConfig'), \
                patch('peft.get_peft_model') as mock_get_peft:
            mock_adapter = Mock()
            mock_get_peft.return_value = mock_adapter

            result = memory_optimizer.load_adapter_from_disk(
                "test_adapter", swap_path, mock_base_model
            )

        assert result == mock_adapter
        mock_load.assert_called_once()

    def test_load_adapter_from_disk_not_found(self, memory_optimizer):
        """Test loading adapter that doesn't exist on disk."""
        mock_base_model = Mock()
        non_existent_path = "/path/to/nonexistent.pt"

        with patch('torch.load', side_effect=FileNotFoundError("File not found")), \
                pytest.raises(FileNotFoundError):
            memory_optimizer.load_adapter_from_disk(
                "nonexistent", non_existent_path, mock_base_model
            )

    def test_optimize_memory_usage(self, memory_optimizer):
        """Test memory optimization strategy."""
        # Mock memory stats to indicate high memory usage
        high_usage = MemoryStats(
            rss_mb=5000.0,  # High memory usage
            vms_mb=8000.0,
            percent=80.0,
        )
        current_adapters = {"tech1": Mock(), "tech2": Mock()}
        mock_base_model = Mock()

        with patch.object(memory_optimizer, 'get_memory_stats') as mock_stats, \
                patch.object(memory_optimizer, 'swap_adapter_to_disk') as mock_swap:
            mock_stats.return_value = high_usage
            mock_swap.return_value = "/path/to/swap.pt"

            memory_optimizer.optimize_memory_usage(current_adapters, mock_base_model)

        # Should trigger swapping when memory is high
        assert mock_swap.call_count == 2

    def test_cleanup_swap_files(self, memory_optimizer, temp_swap_dir):
        """Test cleanup of swap files."""
        # Create some test swap files
        test_files = ["adapter1_swapped.pt", "adapter2_swapped.pt", "adapter3_swapped.pt"]
        for filename in test_files:
            (temp_swap_dir / filename).touch()

        # Create a non-swap file
        (temp_swap_dir / "not_a_swap.txt").touch()

        memory_optimizer.cleanup_swap_files()

        # Should only delete *_swapped.pt files
        assert not (temp_swap_dir / "adapter1_swapped.pt").exists()
        assert not (temp_swap_dir / "adapter2_swapped.pt").exists()
        assert not (temp_swap_dir / "adapter3_swapped.pt").exists()
        assert (temp_swap_dir / "not_a_swap.txt").exists()  # Non-swap file should remain

    def test_get_optimization_stats(self, memory_optimizer):
        """Test getting optimization statistics."""
        stats = memory_optimizer.get_optimization_stats()

        assert "memory_usage" in stats
        assert "cache_stats" in stats
        assert "swap_files" in stats

    def test_memory_optimization_with_actual_adapters(self, memory_optimizer):
        """Test memory optimization with realistic adapter scenarios."""
        # Add adapters to cache
        mock_adapter1 = Mock()
        mock_adapter2 = Mock()
        mock_adapter3 = Mock()

        memory_optimizer.cache.put("tech1", mock_adapter1, 30)
        memory_optimizer.cache.put("tech2", mock_adapter2, 40)
        memory_optimizer.cache.put("tech3", mock_adapter3, 50)  # Should trigger eviction

        # Verify cache size limit is respected
        assert len(memory_optimizer.cache.cache) == 2
        assert "tech1" not in memory_optimizer.cache.cache  # First one evicted
        assert "tech3" in memory_optimizer.cache.cache  # Latest one kept


class TestMemoryOptimizationIntegration:
    """Integration tests for memory optimization features."""

    def test_adapter_swapping_workflow(self, temp_swap_dir):
        """Test complete adapter swapping workflow."""
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = temp_swap_dir
            optimizer = DomainMemoryOptimizer(cache_size=1, max_memory_mb=50)

            # Create mock adapters
            mock_adapter1 = Mock()
            mock_adapter2 = Mock()

            # Add first adapter
            optimizer.cache.put("adapter1", mock_adapter1, 30)

            # Add second adapter (should trigger eviction of first)
            optimizer.cache.put("adapter2", mock_adapter2, 40)

            # Verify cache state
            assert "adapter2" in optimizer.cache.cache
            assert "adapter1" not in optimizer.cache.cache

    def test_memory_pressure_response(self, temp_swap_dir):
        """Test system response to memory pressure."""
        with patch('src.services.domain_memory_optimizer.Path') as mock_path:
            mock_path.return_value = temp_swap_dir
            optimizer = DomainMemoryOptimizer(cache_size=3, max_memory_mb=100)

            # Simulate memory pressure
            pressure_stats = MemoryStats(
                rss_mb=5000.0,  # High memory usage
                vms_mb=8000.0,
                percent=95.0,
            )
            current_adapters = {"tech1": Mock(), "tech2": Mock()}
            mock_base_model = Mock()

            with patch.object(optimizer, 'get_memory_stats') as mock_stats, \
                    patch.object(optimizer, 'swap_adapter_to_disk') as mock_swap:
                mock_stats.return_value = pressure_stats
                mock_swap.return_value = "/path/to/swap.pt"

                optimizer.optimize_memory_usage(current_adapters, mock_base_model)

            # Should trigger swapping when memory is high
            assert mock_swap.call_count == 2


if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run reports failures to the shell.
    raise SystemExit(pytest.main([__file__]))