"""Tests for the domain performance optimization services.

Covers the background adapter loader, batched inference manager,
progressive model loader, and the top-level DomainPerformanceOptimizer,
including cache behavior, statistics, and concurrent access.
"""

import threading
import time

import pytest
from unittest.mock import MagicMock, Mock, patch

from src.services.domain_adaptation import DomainAdapter, DomainDetector
from src.services.domain_performance_optimizer import (
    BackgroundLoader,
    BatchedInferenceManager,
    DomainPerformanceOptimizer,
    PerformanceStats,
    ProgressiveLoader,
)


class TestBackgroundLoader:
    """Test the background loader for domain adapters."""

    @pytest.fixture
    def background_loader(self):
        """Create a background loader for testing."""
        return BackgroundLoader(max_workers=2)

    def test_background_loader_initialization(self, background_loader):
        """Test background loader initialization."""
        assert background_loader.max_workers == 2
        assert len(background_loader.loaded_adapters) == 0
        assert len(background_loader.loading_futures) == 0
        assert background_loader._worker_thread.is_alive()

    def test_preload_adapter(self, background_loader):
        """Test preloading an adapter."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")
        assert "technical" in background_loader.loading_futures
        assert "technical" not in background_loader.loaded_adapters

    def test_get_adapter_success(self, background_loader):
        """Test getting a successfully loaded adapter."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")

        # Wait for loading to complete
        adapter = background_loader.get_adapter("technical", timeout=2.0)

        assert adapter is not None
        assert adapter["domain"] == "technical"
        assert adapter["path"] == "/path/to/technical.pt"
        assert "technical" in background_loader.loaded_adapters

    def test_get_adapter_timeout(self, background_loader):
        """Test getting an adapter with timeout."""
        # Don't preload anything
        adapter = background_loader.get_adapter("nonexistent", timeout=0.1)
        assert adapter is None

    def test_preload_duplicate_adapter(self, background_loader):
        """Test preloading the same adapter twice."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")
        background_loader.preload_adapter("technical", "/path/to/technical.pt")  # Duplicate

        # Should only have one loading future
        assert len(background_loader.loading_futures) == 1

    def test_shutdown(self, background_loader):
        """Test background loader shutdown."""
        background_loader.shutdown()
        # Verify executor is shutdown
        assert background_loader.executor._shutdown


class TestBatchedInferenceManager:
    """Test the batched inference manager."""

    @pytest.fixture
    def batched_inference(self):
        """Create a batched inference manager for testing."""
        return BatchedInferenceManager(batch_size=3, max_wait_time=0.5)

    def test_batched_inference_initialization(self, batched_inference):
        """Test batched inference manager initialization."""
        assert batched_inference.batch_size == 3
        assert batched_inference.max_wait_time == 0.5
        assert len(batched_inference.pending_requests) == 0
        assert len(batched_inference.results) == 0

    def test_add_request(self, batched_inference):
        """Test adding a request to the batch."""
        request_id = batched_inference.add_request("audio1", "technical")

        assert request_id == 0
        assert len(batched_inference.pending_requests) == 1
        assert batched_inference.pending_requests[0][0] == 0
        assert batched_inference.pending_requests[0][1] == ("audio1", "technical")

    def test_batch_processing(self, batched_inference):
        """Test batch processing when batch is full."""
        # Add requests to fill the batch
        request_id1 = batched_inference.add_request("audio1", "technical")
        request_id2 = batched_inference.add_request("audio2", "medical")
        request_id3 = batched_inference.add_request("audio3", "academic")

        # Batch should be processed automatically
        assert len(batched_inference.pending_requests) == 0

        # Get results
        result1 = batched_inference.get_result(request_id1, timeout=1.0)
        result2 = batched_inference.get_result(request_id2, timeout=1.0)
        result3 = batched_inference.get_result(request_id3, timeout=1.0)

        assert result1 == "[TECHNICAL] Processed audio for technical"
        assert result2 == "[MEDICAL] Processed audio for medical"
        assert result3 == "[ACADEMIC] Processed audio for academic"

    def test_get_result_timeout(self, batched_inference):
        """Test getting result with timeout."""
        result = batched_inference.get_result(999, timeout=0.1)
        assert result is None

    def test_multiple_batches(self, batched_inference):
        """Test processing multiple batches."""
        # First batch
        request_id1 = batched_inference.add_request("audio1", "technical")
        request_id2 = batched_inference.add_request("audio2", "medical")
        request_id3 = batched_inference.add_request("audio3", "academic")

        # Second batch
        request_id4 = batched_inference.add_request("audio4", "general")
        request_id5 = batched_inference.add_request("audio5", "technical")
        request_id6 = batched_inference.add_request("audio6", "medical")

        # Get all results
        results = []
        for request_id in [request_id1, request_id2, request_id3,
                           request_id4, request_id5, request_id6]:
            result = batched_inference.get_result(request_id, timeout=1.0)
            results.append(result)

        assert len(results) == 6
        assert all(result is not None for result in results)


class TestProgressiveLoader:
    """Test the progressive loader for large models."""

    @pytest.fixture
    def progressive_loader(self):
        """Create a progressive loader for testing."""
        return ProgressiveLoader(chunk_size=1024)

    def test_progressive_loader_initialization(self, progressive_loader):
        """Test progressive loader initialization."""
        assert progressive_loader.chunk_size == 1024
        assert len(progressive_loader.loaded_chunks) == 0

    def test_load_model_progressively(self, progressive_loader):
        """Test progressive model loading."""
        model = progressive_loader.load_model_progressively("/path/to/model.pt", 3000)

        assert model["model_path"] == "/path/to/model.pt"
        assert model["chunks"] == 3  # 3000 / 1024 = 3 chunks

        # Verify chunks were loaded
        assert "/path/to/model.pt" in progressive_loader.loaded_chunks
        assert len(progressive_loader.loaded_chunks["/path/to/model.pt"]) == 3

    def test_load_model_smaller_than_chunk(self, progressive_loader):
        """Test loading a model smaller than chunk size."""
        model = progressive_loader.load_model_progressively("/path/to/small_model.pt", 512)
        assert model["chunks"] == 1  # Should be 1 chunk even though it's smaller

    def test_load_chunk(self, progressive_loader):
        """Test loading individual chunks."""
        chunk = progressive_loader._load_chunk("/path/to/model.pt", 0)
        assert chunk["chunk_idx"] == 0
        assert chunk["data"] == "chunk_0"

    def test_combine_chunks(self, progressive_loader):
        """Test combining chunks into a model."""
        # Add some test chunks
        progressive_loader.loaded_chunks["/path/to/model.pt"] = [
            {"chunk_idx": 0, "data": "chunk_0"},
            {"chunk_idx": 1, "data": "chunk_1"}
        ]

        model = progressive_loader._combine_chunks("/path/to/model.pt")
        assert model["model_path"] == "/path/to/model.pt"
        assert model["chunks"] == 2


class TestDomainPerformanceOptimizer:
    """Test the domain performance optimizer."""

    @pytest.fixture
    def performance_optimizer(self):
        """Create a performance optimizer for testing."""
        return DomainPerformanceOptimizer(
            cache_size=5,
            background_workers=2,
            batch_size=3,
            enable_progressive_loading=True
        )

    @pytest.fixture
    def mock_domain_adapter(self):
        """Create a mock domain adapter."""
        return Mock(spec=DomainAdapter)

    @pytest.fixture
    def mock_domain_detector(self):
        """Create a mock domain detector."""
        return Mock(spec=DomainDetector)

    def test_performance_optimizer_initialization(self, performance_optimizer):
        """Test performance optimizer initialization."""
        assert performance_optimizer.cache_size == 5
        assert performance_optimizer.background_loader.max_workers == 2
        assert performance_optimizer.batched_inference.batch_size == 3
        assert performance_optimizer.progressive_loader is not None
        assert performance_optimizer.memory_optimizer is not None
        assert len(performance_optimizer.inference_times) == 0
        assert performance_optimizer.cache_hits == 0
        assert performance_optimizer.cache_misses == 0

    def test_optimize_transcription_with_batching(self, performance_optimizer,
                                                  mock_domain_adapter, mock_domain_detector):
        """Test transcription optimization with batching."""
        audio = "test_audio_data"
        domain = "technical"

        result = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=True, use_background_loading=False
        )

        assert result is not None
        assert "[TECHNICAL]" in result
        assert len(performance_optimizer.inference_times) > 0
        assert performance_optimizer.cache_misses == 1

    def test_optimize_transcription_cache_hit(self, performance_optimizer,
                                              mock_domain_adapter, mock_domain_detector):
        """Test transcription optimization with cache hit."""
        audio = "test_audio_data"
        domain = "technical"

        # First call - should miss cache
        result1 = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=False, use_background_loading=False
        )

        # Second call with same audio and domain - should hit cache
        result2 = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=False, use_background_loading=False
        )

        assert result1 == result2
        assert performance_optimizer.cache_hits == 1
        assert performance_optimizer.cache_misses == 1

    def test_preload_domain_adapters(self, performance_optimizer):
        """Test preloading domain adapters."""
        domains = ["technical", "medical", "academic"]
        adapter_paths = {
            "technical": "/path/to/technical.pt",
            "medical": "/path/to/medical.pt",
            "academic": "/path/to/academic.pt"
        }

        performance_optimizer.preload_domain_adapters(domains, adapter_paths)

        # Verify adapters are being loaded
        assert len(performance_optimizer.background_loader.loading_futures) == 3

    def test_get_performance_stats(self, performance_optimizer,
                                   mock_domain_adapter, mock_domain_detector):
        """Test getting performance statistics."""
        # Perform some operations to generate stats
        audio = "test_audio_data"
        for _ in range(3):
            performance_optimizer.optimize_transcription(
                audio, "technical", mock_domain_adapter, mock_domain_detector,
                use_batching=False, use_background_loading=False
            )

        stats = performance_optimizer.get_performance_stats()

        assert isinstance(stats, PerformanceStats)
        assert stats.inference_time_ms > 0
        assert stats.memory_usage_mb > 0
        assert 0 <= stats.cache_hit_rate <= 1
        assert stats.throughput_requests_per_second > 0

    def test_cache_eviction(self, performance_optimizer,
                            mock_domain_adapter, mock_domain_detector):
        """Test cache eviction when cache is full."""
        # Fill the cache (size 5)
        for i in range(6):
            audio = f"audio_{i}"
            performance_optimizer.optimize_transcription(
                audio, "technical", mock_domain_adapter, mock_domain_detector,
                use_batching=False, use_background_loading=False
            )

        # Cache should have evicted the oldest entry
        assert len(performance_optimizer._cache) == 5

    def test_shutdown(self, performance_optimizer):
        """Test performance optimizer shutdown."""
        performance_optimizer.shutdown()
        # Verify background loader is shutdown
        assert performance_optimizer.background_loader.executor._shutdown


class TestPerformanceOptimizationIntegration:
    """Integration tests for performance optimization features."""

    @pytest.fixture
    def performance_optimizer(self):
        """Create a performance optimizer for integration testing."""
        return DomainPerformanceOptimizer(
            cache_size=10,
            background_workers=2,
            batch_size=4,
            enable_progressive_loading=True
        )

    def test_end_to_end_performance_optimization(self, performance_optimizer):
        """Test end-to-end performance optimization workflow."""
        mock_domain_adapter = Mock(spec=DomainAdapter)
        mock_domain_detector = Mock(spec=DomainDetector)

        # Preload adapters
        domains = ["technical", "medical"]
        adapter_paths = {
            "technical": "/path/to/technical.pt",
            "medical": "/path/to/medical.pt"
        }
        performance_optimizer.preload_domain_adapters(domains, adapter_paths)

        # Perform multiple transcriptions
        results = []
        for i in range(5):
            audio = f"audio_{i}"
            domain = "technical" if i % 2 == 0 else "medical"
            result = performance_optimizer.optimize_transcription(
                audio, domain, mock_domain_adapter, mock_domain_detector,
                use_batching=True, use_background_loading=True
            )
            results.append(result)

        # Verify results
        assert len(results) == 5
        assert all(result is not None for result in results)

        # Check performance stats
        stats = performance_optimizer.get_performance_stats()
        assert stats.inference_time_ms > 0
        assert stats.throughput_requests_per_second > 0

    def test_concurrent_access(self, performance_optimizer):
        """Test concurrent access to performance optimizer."""
        mock_domain_adapter = Mock(spec=DomainAdapter)
        mock_domain_detector = Mock(spec=DomainDetector)

        def transcription_worker(worker_id):
            # Each worker issues a few transcriptions with distinct audio keys.
            for i in range(3):
                audio = f"audio_worker_{worker_id}_{i}"
                domain = "technical" if i % 2 == 0 else "medical"
                result = performance_optimizer.optimize_transcription(
                    audio, domain, mock_domain_adapter, mock_domain_detector,
                    use_batching=True, use_background_loading=False
                )
                assert result is not None

        # Create multiple threads
        threads = []
        for i in range(3):
            thread = threading.Thread(target=transcription_worker, args=(i,))
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Verify performance stats
        stats = performance_optimizer.get_performance_stats()
        assert stats.throughput_requests_per_second > 0


if __name__ == "__main__":
    pytest.main([__file__])