# trax/tests/test_domain_performance_optimizer.py
# NOTE: page-scrape metadata ("399 lines", "16 KiB", "Python") converted to
# this comment header so the module parses as valid Python.
import pytest
import time
import threading
from unittest.mock import Mock, patch, MagicMock
from src.services.domain_performance_optimizer import (
BackgroundLoader,
BatchedInferenceManager,
ProgressiveLoader,
DomainPerformanceOptimizer,
PerformanceStats
)
from src.services.domain_adaptation import DomainAdapter, DomainDetector
class TestBackgroundLoader:
    """Test the background loader for domain adapters.

    Covers initialization, preloading, retrieval (success and timeout),
    duplicate-preload handling, and shutdown.
    """

    @pytest.fixture
    def background_loader(self):
        """Create a background loader for testing."""
        return BackgroundLoader(max_workers=2)

    def test_background_loader_initialization(self, background_loader):
        """Test background loader initialization."""
        assert background_loader.max_workers == 2
        assert len(background_loader.loaded_adapters) == 0
        assert len(background_loader.loading_futures) == 0
        # The loader spins up its worker thread eagerly on construction.
        assert background_loader._worker_thread.is_alive()

    def test_preload_adapter(self, background_loader):
        """Test preloading an adapter."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")
        # Preload schedules a future; the adapter is not loaded synchronously.
        assert "technical" in background_loader.loading_futures
        assert "technical" not in background_loader.loaded_adapters

    def test_get_adapter_success(self, background_loader):
        """Test getting a successfully loaded adapter."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")
        # Wait for loading to complete
        adapter = background_loader.get_adapter("technical", timeout=2.0)
        assert adapter is not None
        assert adapter["domain"] == "technical"
        assert adapter["path"] == "/path/to/technical.pt"
        assert "technical" in background_loader.loaded_adapters

    def test_get_adapter_timeout(self, background_loader):
        """Test getting an adapter with timeout."""
        # Don't preload anything; a short timeout should yield None.
        adapter = background_loader.get_adapter("nonexistent", timeout=0.1)
        assert adapter is None

    def test_preload_duplicate_adapter(self, background_loader):
        """Test preloading the same adapter twice."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")
        background_loader.preload_adapter("technical", "/path/to/technical.pt")  # Duplicate
        # Should only have one loading future
        assert len(background_loader.loading_futures) == 1

    def test_shutdown(self, background_loader):
        """Test background loader shutdown."""
        background_loader.shutdown()
        # Verify executor is shutdown (checks the executor's private flag;
        # assumes a concurrent.futures-style executor — TODO confirm).
        assert background_loader.executor._shutdown
class TestBatchedInferenceManager:
    """Test the batched inference manager.

    Covers initialization, enqueuing requests, automatic batch flushing,
    result retrieval with timeout, and processing multiple batches.
    """

    @pytest.fixture
    def batched_inference(self):
        """Create a batched inference manager for testing."""
        return BatchedInferenceManager(batch_size=3, max_wait_time=0.5)

    def test_batched_inference_initialization(self, batched_inference):
        """Test batched inference manager initialization."""
        assert batched_inference.batch_size == 3
        assert batched_inference.max_wait_time == 0.5
        assert len(batched_inference.pending_requests) == 0
        assert len(batched_inference.results) == 0

    def test_add_request(self, batched_inference):
        """Test adding a request to the batch."""
        request_id = batched_inference.add_request("audio1", "technical")
        # Request ids are sequential starting from 0.
        assert request_id == 0
        assert len(batched_inference.pending_requests) == 1
        # Each pending entry is (request_id, (audio, domain)).
        assert batched_inference.pending_requests[0][0] == 0
        assert batched_inference.pending_requests[0][1] == ("audio1", "technical")

    def test_batch_processing(self, batched_inference):
        """Test batch processing when batch is full."""
        # Add requests to fill the batch (batch_size=3).
        request_id1 = batched_inference.add_request("audio1", "technical")
        request_id2 = batched_inference.add_request("audio2", "medical")
        request_id3 = batched_inference.add_request("audio3", "academic")
        # Batch should be processed automatically once full.
        assert len(batched_inference.pending_requests) == 0
        # Get results
        result1 = batched_inference.get_result(request_id1, timeout=1.0)
        result2 = batched_inference.get_result(request_id2, timeout=1.0)
        result3 = batched_inference.get_result(request_id3, timeout=1.0)
        assert result1 == "[TECHNICAL] Processed audio for technical"
        assert result2 == "[MEDICAL] Processed audio for medical"
        assert result3 == "[ACADEMIC] Processed audio for academic"

    def test_get_result_timeout(self, batched_inference):
        """Test getting result with timeout."""
        # Unknown request id with a short timeout returns None.
        result = batched_inference.get_result(999, timeout=0.1)
        assert result is None

    def test_multiple_batches(self, batched_inference):
        """Test processing multiple batches."""
        # First batch
        request_id1 = batched_inference.add_request("audio1", "technical")
        request_id2 = batched_inference.add_request("audio2", "medical")
        request_id3 = batched_inference.add_request("audio3", "academic")
        # Second batch
        request_id4 = batched_inference.add_request("audio4", "general")
        request_id5 = batched_inference.add_request("audio5", "technical")
        request_id6 = batched_inference.add_request("audio6", "medical")
        # Get all results
        results = []
        for request_id in [request_id1, request_id2, request_id3,
                           request_id4, request_id5, request_id6]:
            result = batched_inference.get_result(request_id, timeout=1.0)
            results.append(result)
        assert len(results) == 6
        assert all(result is not None for result in results)
class TestProgressiveLoader:
    """Test the progressive loader for large models.

    Covers chunked model loading, sub-chunk-size models, individual chunk
    loading, and chunk recombination.
    """

    @pytest.fixture
    def progressive_loader(self):
        """Create a progressive loader for testing."""
        return ProgressiveLoader(chunk_size=1024)

    def test_progressive_loader_initialization(self, progressive_loader):
        """Test progressive loader initialization."""
        assert progressive_loader.chunk_size == 1024
        assert len(progressive_loader.loaded_chunks) == 0

    def test_load_model_progressively(self, progressive_loader):
        """Test progressive model loading."""
        model = progressive_loader.load_model_progressively("/path/to/model.pt", 3000)
        assert model["model_path"] == "/path/to/model.pt"
        assert model["chunks"] == 3  # ceil(3000 / 1024) = 3 chunks
        # Verify chunks were loaded and cached under the model path.
        assert "/path/to/model.pt" in progressive_loader.loaded_chunks
        assert len(progressive_loader.loaded_chunks["/path/to/model.pt"]) == 3

    def test_load_model_smaller_than_chunk(self, progressive_loader):
        """Test loading a model smaller than chunk size."""
        model = progressive_loader.load_model_progressively("/path/to/small_model.pt", 512)
        assert model["chunks"] == 1  # Should be 1 chunk even though it's smaller

    def test_load_chunk(self, progressive_loader):
        """Test loading individual chunks."""
        chunk = progressive_loader._load_chunk("/path/to/model.pt", 0)
        assert chunk["chunk_idx"] == 0
        assert chunk["data"] == "chunk_0"

    def test_combine_chunks(self, progressive_loader):
        """Test combining chunks into a model."""
        # Add some test chunks directly to the cache.
        progressive_loader.loaded_chunks["/path/to/model.pt"] = [
            {"chunk_idx": 0, "data": "chunk_0"},
            {"chunk_idx": 1, "data": "chunk_1"},
        ]
        model = progressive_loader._combine_chunks("/path/to/model.pt")
        assert model["model_path"] == "/path/to/model.pt"
        assert model["chunks"] == 2
class TestDomainPerformanceOptimizer:
    """Test the domain performance optimizer.

    Covers initialization, batched and cached transcription, adapter
    preloading, performance statistics, cache eviction, and shutdown.
    """

    @pytest.fixture
    def performance_optimizer(self):
        """Create a performance optimizer for testing."""
        return DomainPerformanceOptimizer(
            cache_size=5,
            background_workers=2,
            batch_size=3,
            enable_progressive_loading=True,
        )

    @pytest.fixture
    def mock_domain_adapter(self):
        """Create a mock domain adapter."""
        return Mock(spec=DomainAdapter)

    @pytest.fixture
    def mock_domain_detector(self):
        """Create a mock domain detector."""
        return Mock(spec=DomainDetector)

    def test_performance_optimizer_initialization(self, performance_optimizer):
        """Test performance optimizer initialization."""
        assert performance_optimizer.cache_size == 5
        assert performance_optimizer.background_loader.max_workers == 2
        assert performance_optimizer.batched_inference.batch_size == 3
        assert performance_optimizer.progressive_loader is not None
        assert performance_optimizer.memory_optimizer is not None
        assert len(performance_optimizer.inference_times) == 0
        assert performance_optimizer.cache_hits == 0
        assert performance_optimizer.cache_misses == 0

    def test_optimize_transcription_with_batching(self, performance_optimizer,
                                                  mock_domain_adapter, mock_domain_detector):
        """Test transcription optimization with batching."""
        audio = "test_audio_data"
        domain = "technical"
        result = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=True, use_background_loading=False
        )
        assert result is not None
        assert "[TECHNICAL]" in result
        # Timing was recorded and the first call is always a cache miss.
        assert len(performance_optimizer.inference_times) > 0
        assert performance_optimizer.cache_misses == 1

    def test_optimize_transcription_cache_hit(self, performance_optimizer,
                                              mock_domain_adapter, mock_domain_detector):
        """Test transcription optimization with cache hit."""
        audio = "test_audio_data"
        domain = "technical"
        # First call - should miss cache
        result1 = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=False, use_background_loading=False
        )
        # Second call with same audio and domain - should hit cache
        result2 = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=False, use_background_loading=False
        )
        assert result1 == result2
        assert performance_optimizer.cache_hits == 1
        assert performance_optimizer.cache_misses == 1

    def test_preload_domain_adapters(self, performance_optimizer):
        """Test preloading domain adapters."""
        domains = ["technical", "medical", "academic"]
        adapter_paths = {
            "technical": "/path/to/technical.pt",
            "medical": "/path/to/medical.pt",
            "academic": "/path/to/academic.pt",
        }
        performance_optimizer.preload_domain_adapters(domains, adapter_paths)
        # Verify adapters are being loaded (one future per domain).
        assert len(performance_optimizer.background_loader.loading_futures) == 3

    def test_get_performance_stats(self, performance_optimizer,
                                   mock_domain_adapter, mock_domain_detector):
        """Test getting performance statistics."""
        # Perform some operations to generate stats
        audio = "test_audio_data"
        for _ in range(3):
            performance_optimizer.optimize_transcription(
                audio, "technical", mock_domain_adapter, mock_domain_detector,
                use_batching=False, use_background_loading=False
            )
        stats = performance_optimizer.get_performance_stats()
        assert isinstance(stats, PerformanceStats)
        assert stats.inference_time_ms > 0
        assert stats.memory_usage_mb > 0
        assert 0 <= stats.cache_hit_rate <= 1
        assert stats.throughput_requests_per_second > 0

    def test_cache_eviction(self, performance_optimizer,
                            mock_domain_adapter, mock_domain_detector):
        """Test cache eviction when cache is full."""
        # Fill the cache past its capacity (size 5) with distinct audio keys.
        for i in range(6):
            audio = f"audio_{i}"
            performance_optimizer.optimize_transcription(
                audio, "technical", mock_domain_adapter, mock_domain_detector,
                use_batching=False, use_background_loading=False
            )
        # Cache should have evicted the oldest entry
        assert len(performance_optimizer._cache) == 5

    def test_shutdown(self, performance_optimizer):
        """Test performance optimizer shutdown."""
        performance_optimizer.shutdown()
        # Verify background loader is shutdown (inspects the executor's
        # private flag; assumes a concurrent.futures-style executor).
        assert performance_optimizer.background_loader.executor._shutdown
class TestPerformanceOptimizationIntegration:
    """Integration tests for performance optimization features.

    Exercises the optimizer end to end (preload + batched transcription)
    and under concurrent access from multiple threads.
    """

    @pytest.fixture
    def performance_optimizer(self):
        """Create a performance optimizer for integration testing."""
        return DomainPerformanceOptimizer(
            cache_size=10,
            background_workers=2,
            batch_size=4,
            enable_progressive_loading=True,
        )

    def test_end_to_end_performance_optimization(self, performance_optimizer):
        """Test end-to-end performance optimization workflow."""
        mock_domain_adapter = Mock(spec=DomainAdapter)
        mock_domain_detector = Mock(spec=DomainDetector)
        # Preload adapters
        domains = ["technical", "medical"]
        adapter_paths = {
            "technical": "/path/to/technical.pt",
            "medical": "/path/to/medical.pt",
        }
        performance_optimizer.preload_domain_adapters(domains, adapter_paths)
        # Perform multiple transcriptions, alternating domains.
        results = []
        for i in range(5):
            audio = f"audio_{i}"
            domain = "technical" if i % 2 == 0 else "medical"
            result = performance_optimizer.optimize_transcription(
                audio, domain, mock_domain_adapter, mock_domain_detector,
                use_batching=True, use_background_loading=True
            )
            results.append(result)
        # Verify results
        assert len(results) == 5
        assert all(result is not None for result in results)
        # Check performance stats
        stats = performance_optimizer.get_performance_stats()
        assert stats.inference_time_ms > 0
        assert stats.throughput_requests_per_second > 0

    def test_concurrent_access(self, performance_optimizer):
        """Test concurrent access to performance optimizer."""
        mock_domain_adapter = Mock(spec=DomainAdapter)
        mock_domain_detector = Mock(spec=DomainDetector)

        def transcription_worker(worker_id):
            # Each worker issues a few transcriptions with unique audio keys.
            for i in range(3):
                audio = f"audio_worker_{worker_id}_{i}"
                domain = "technical" if i % 2 == 0 else "medical"
                result = performance_optimizer.optimize_transcription(
                    audio, domain, mock_domain_adapter, mock_domain_detector,
                    use_batching=True, use_background_loading=False
                )
                # NOTE: asserts in worker threads won't fail the test on
                # their own thread; kept from the original behavior.
                assert result is not None

        # Create multiple threads
        threads = []
        for i in range(3):
            thread = threading.Thread(target=transcription_worker, args=(i,))
            threads.append(thread)
            thread.start()
        # Wait for all threads to complete
        for thread in threads:
            thread.join()
        # Verify performance stats
        stats = performance_optimizer.get_performance_stats()
        assert stats.throughput_requests_per_second > 0
if __name__ == "__main__":
    # Allow running this test module directly: `python test_domain_performance_optimizer.py`.
    pytest.main([__file__])