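"""Tests for the domain performance optimization service.

Covers background adapter loading, batched inference, progressive model
loading, and the combined DomainPerformanceOptimizer, including integration
and concurrency scenarios.
"""
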
import threading
from unittest.mock import Mock

import pytest

from src.services.domain_adaptation import DomainAdapter, DomainDetector
from src.services.domain_performance_optimizer import (
    BackgroundLoader,
    BatchedInferenceManager,
    DomainPerformanceOptimizer,
    PerformanceStats,
    ProgressiveLoader,
)


class TestBackgroundLoader:
    """Test the background loader for domain adapters."""

    @pytest.fixture
    def background_loader(self):
        """Create a background loader for testing."""
        return BackgroundLoader(max_workers=2)

    def test_background_loader_initialization(self, background_loader):
        """Test background loader initialization."""
        assert background_loader.max_workers == 2
        assert len(background_loader.loaded_adapters) == 0
        assert len(background_loader.loading_futures) == 0
        assert background_loader._worker_thread.is_alive()

    def test_preload_adapter(self, background_loader):
        """Test preloading an adapter."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")

        assert "technical" in background_loader.loading_futures
        assert "technical" not in background_loader.loaded_adapters

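    # get_adapter() is assumed to block on the pending load until it completes
    # or the timeout expires; the tests below rely on that behaviour.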
    def test_get_adapter_success(self, background_loader):
        """Test getting a successfully loaded adapter."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")

        # Wait for loading to complete.
        adapter = background_loader.get_adapter("technical", timeout=2.0)

        assert adapter is not None
        assert adapter["domain"] == "technical"
        assert adapter["path"] == "/path/to/technical.pt"
        assert "technical" in background_loader.loaded_adapters

    def test_get_adapter_timeout(self, background_loader):
        """Test that get_adapter returns None when the timeout expires."""
        # Nothing was preloaded, so the lookup should time out.
        adapter = background_loader.get_adapter("nonexistent", timeout=0.1)

        assert adapter is None

    def test_preload_duplicate_adapter(self, background_loader):
        """Test preloading the same adapter twice."""
        background_loader.preload_adapter("technical", "/path/to/technical.pt")
        background_loader.preload_adapter("technical", "/path/to/technical.pt")  # Duplicate

        # The duplicate request should be deduplicated into a single future.
        assert len(background_loader.loading_futures) == 1

    def test_shutdown(self, background_loader):
        """Test background loader shutdown."""
        background_loader.shutdown()

        # Verify the executor has been shut down.
        assert background_loader.executor._shutdown


class TestBatchedInferenceManager:
    """Test the batched inference manager."""

    @pytest.fixture
    def batched_inference(self):
        """Create a batched inference manager for testing."""
        return BatchedInferenceManager(batch_size=3, max_wait_time=0.5)

    def test_batched_inference_initialization(self, batched_inference):
        """Test batched inference manager initialization."""
        assert batched_inference.batch_size == 3
        assert batched_inference.max_wait_time == 0.5
        assert len(batched_inference.pending_requests) == 0
        assert len(batched_inference.results) == 0

    def test_add_request(self, batched_inference):
        """Test adding a request to the batch."""
        request_id = batched_inference.add_request("audio1", "technical")

        assert request_id == 0
        assert len(batched_inference.pending_requests) == 1
        assert batched_inference.pending_requests[0][0] == 0
        assert batched_inference.pending_requests[0][1] == ("audio1", "technical")

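    # batch_size is 3 in the fixture, so the third add_request() below is
    # expected to flush the batch immediately (inferred from the assertion on
    # pending_requests).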
    def test_batch_processing(self, batched_inference):
        """Test batch processing when the batch is full."""
        # Add enough requests to fill the batch.
        request_id1 = batched_inference.add_request("audio1", "technical")
        request_id2 = batched_inference.add_request("audio2", "medical")
        request_id3 = batched_inference.add_request("audio3", "academic")

        # The full batch should be processed automatically.
        assert len(batched_inference.pending_requests) == 0

        # Get results.
        result1 = batched_inference.get_result(request_id1, timeout=1.0)
        result2 = batched_inference.get_result(request_id2, timeout=1.0)
        result3 = batched_inference.get_result(request_id3, timeout=1.0)

        assert result1 == "[TECHNICAL] Processed audio for technical"
        assert result2 == "[MEDICAL] Processed audio for medical"
        assert result3 == "[ACADEMIC] Processed audio for academic"

    def test_get_result_timeout(self, batched_inference):
        """Test that get_result returns None for an unknown request id."""
        result = batched_inference.get_result(999, timeout=0.1)

        assert result is None

    def test_multiple_batches(self, batched_inference):
        """Test processing multiple batches."""
        # First batch.
        request_id1 = batched_inference.add_request("audio1", "technical")
        request_id2 = batched_inference.add_request("audio2", "medical")
        request_id3 = batched_inference.add_request("audio3", "academic")

        # Second batch.
        request_id4 = batched_inference.add_request("audio4", "general")
        request_id5 = batched_inference.add_request("audio5", "technical")
        request_id6 = batched_inference.add_request("audio6", "medical")

        # Collect all results.
        results = []
        for request_id in [request_id1, request_id2, request_id3,
                           request_id4, request_id5, request_id6]:
            result = batched_inference.get_result(request_id, timeout=1.0)
            results.append(result)

        assert len(results) == 6
        assert all(result is not None for result in results)


class TestProgressiveLoader:
    """Test the progressive loader for large models."""

    @pytest.fixture
    def progressive_loader(self):
        """Create a progressive loader for testing."""
        return ProgressiveLoader(chunk_size=1024)

    def test_progressive_loader_initialization(self, progressive_loader):
        """Test progressive loader initialization."""
        assert progressive_loader.chunk_size == 1024
        assert len(progressive_loader.loaded_chunks) == 0

    def test_load_model_progressively(self, progressive_loader):
        """Test progressive model loading."""
        model = progressive_loader.load_model_progressively("/path/to/model.pt", 3000)

        assert model["model_path"] == "/path/to/model.pt"
        assert model["chunks"] == 3  # ceil(3000 / 1024) = 3 chunks

        # Verify the chunks were loaded.
        assert "/path/to/model.pt" in progressive_loader.loaded_chunks
        assert len(progressive_loader.loaded_chunks["/path/to/model.pt"]) == 3

    def test_load_model_smaller_than_chunk(self, progressive_loader):
        """Test loading a model smaller than the chunk size."""
        model = progressive_loader.load_model_progressively("/path/to/small_model.pt", 512)

        assert model["chunks"] == 1  # Still one chunk, even though 512 < 1024.

    def test_load_chunk(self, progressive_loader):
        """Test loading individual chunks."""
        chunk = progressive_loader._load_chunk("/path/to/model.pt", 0)

        assert chunk["chunk_idx"] == 0
        assert chunk["data"] == "chunk_0"

    def test_combine_chunks(self, progressive_loader):
        """Test combining chunks into a model."""
        # Seed some test chunks.
        progressive_loader.loaded_chunks["/path/to/model.pt"] = [
            {"chunk_idx": 0, "data": "chunk_0"},
            {"chunk_idx": 1, "data": "chunk_1"},
        ]

        model = progressive_loader._combine_chunks("/path/to/model.pt")

        assert model["model_path"] == "/path/to/model.pt"
        assert model["chunks"] == 2


class TestDomainPerformanceOptimizer:
    """Test the domain performance optimizer."""

    @pytest.fixture
    def performance_optimizer(self):
        """Create a performance optimizer for testing."""
        return DomainPerformanceOptimizer(
            cache_size=5,
            background_workers=2,
            batch_size=3,
            enable_progressive_loading=True,
        )

    @pytest.fixture
    def mock_domain_adapter(self):
        """Create a mock domain adapter."""
        return Mock(spec=DomainAdapter)

    @pytest.fixture
    def mock_domain_detector(self):
        """Create a mock domain detector."""
        return Mock(spec=DomainDetector)

    def test_performance_optimizer_initialization(self, performance_optimizer):
        """Test performance optimizer initialization."""
        assert performance_optimizer.cache_size == 5
        assert performance_optimizer.background_loader.max_workers == 2
        assert performance_optimizer.batched_inference.batch_size == 3
        assert performance_optimizer.progressive_loader is not None
        assert performance_optimizer.memory_optimizer is not None
        assert len(performance_optimizer.inference_times) == 0
        assert performance_optimizer.cache_hits == 0
        assert performance_optimizer.cache_misses == 0

    def test_optimize_transcription_with_batching(self, performance_optimizer, mock_domain_adapter, mock_domain_detector):
        """Test transcription optimization with batching."""
        audio = "test_audio_data"
        domain = "technical"

        result = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=True, use_background_loading=False
        )

        assert result is not None
        assert "[TECHNICAL]" in result
        assert len(performance_optimizer.inference_times) > 0
        assert performance_optimizer.cache_misses == 1

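    # The result cache is assumed to be keyed on the (audio, domain) pair, so
    # repeating an identical call should be served from the cache.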
    def test_optimize_transcription_cache_hit(self, performance_optimizer, mock_domain_adapter, mock_domain_detector):
        """Test transcription optimization with a cache hit."""
        audio = "test_audio_data"
        domain = "technical"

        # First call: should miss the cache.
        result1 = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=False, use_background_loading=False
        )

        # Second call with the same audio and domain: should hit the cache.
        result2 = performance_optimizer.optimize_transcription(
            audio, domain, mock_domain_adapter, mock_domain_detector,
            use_batching=False, use_background_loading=False
        )

        assert result1 == result2
        assert performance_optimizer.cache_hits == 1
        assert performance_optimizer.cache_misses == 1

    def test_preload_domain_adapters(self, performance_optimizer):
        """Test preloading domain adapters."""
        domains = ["technical", "medical", "academic"]
        adapter_paths = {
            "technical": "/path/to/technical.pt",
            "medical": "/path/to/medical.pt",
            "academic": "/path/to/academic.pt",
        }

        performance_optimizer.preload_domain_adapters(domains, adapter_paths)

        # Verify that all three adapters are being loaded.
        assert len(performance_optimizer.background_loader.loading_futures) == 3

    def test_get_performance_stats(self, performance_optimizer, mock_domain_adapter, mock_domain_detector):
        """Test getting performance statistics."""
        # Perform some operations to generate stats.
        audio = "test_audio_data"
        for _ in range(3):
            performance_optimizer.optimize_transcription(
                audio, "technical", mock_domain_adapter, mock_domain_detector,
                use_batching=False, use_background_loading=False
            )

        stats = performance_optimizer.get_performance_stats()

        assert isinstance(stats, PerformanceStats)
        assert stats.inference_time_ms > 0
        assert stats.memory_usage_mb > 0
        assert 0 <= stats.cache_hit_rate <= 1
        assert stats.throughput_requests_per_second > 0

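    # cache_size is 5 in the fixture, so inserting a sixth distinct entry
    # should evict the oldest one.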
    def test_cache_eviction(self, performance_optimizer, mock_domain_adapter, mock_domain_detector):
        """Test cache eviction when the cache is full."""
        # Overfill the cache (size 5) with six distinct entries.
        for i in range(6):
            audio = f"audio_{i}"
            performance_optimizer.optimize_transcription(
                audio, "technical", mock_domain_adapter, mock_domain_detector,
                use_batching=False, use_background_loading=False
            )

        # The cache should have evicted the oldest entry.
        assert len(performance_optimizer._cache) == 5

    def test_shutdown(self, performance_optimizer):
        """Test performance optimizer shutdown."""
        performance_optimizer.shutdown()

        # Verify the background loader has been shut down.
        assert performance_optimizer.background_loader.executor._shutdown


class TestPerformanceOptimizationIntegration:
    """Integration tests for performance optimization features."""

    @pytest.fixture
    def performance_optimizer(self):
        """Create a performance optimizer for integration testing."""
        return DomainPerformanceOptimizer(
            cache_size=10,
            background_workers=2,
            batch_size=4,
            enable_progressive_loading=True,
        )

    def test_end_to_end_performance_optimization(self, performance_optimizer):
        """Test the end-to-end performance optimization workflow."""
        mock_domain_adapter = Mock(spec=DomainAdapter)
        mock_domain_detector = Mock(spec=DomainDetector)

        # Preload adapters.
        domains = ["technical", "medical"]
        adapter_paths = {
            "technical": "/path/to/technical.pt",
            "medical": "/path/to/medical.pt",
        }
        performance_optimizer.preload_domain_adapters(domains, adapter_paths)

        # Perform multiple transcriptions, alternating domains.
        results = []
        for i in range(5):
            audio = f"audio_{i}"
            domain = "technical" if i % 2 == 0 else "medical"

            result = performance_optimizer.optimize_transcription(
                audio, domain, mock_domain_adapter, mock_domain_detector,
                use_batching=True, use_background_loading=True
            )
            results.append(result)

        # Verify the results.
        assert len(results) == 5
        assert all(result is not None for result in results)

        # Check the performance stats.
        stats = performance_optimizer.get_performance_stats()
        assert stats.inference_time_ms > 0
        assert stats.throughput_requests_per_second > 0

    def test_concurrent_access(self, performance_optimizer):
        """Test concurrent access to the performance optimizer."""
        mock_domain_adapter = Mock(spec=DomainAdapter)
        mock_domain_detector = Mock(spec=DomainDetector)
        results = []

        def transcription_worker(worker_id):
            for i in range(3):
                audio = f"audio_worker_{worker_id}_{i}"
                domain = "technical" if i % 2 == 0 else "medical"

                result = performance_optimizer.optimize_transcription(
                    audio, domain, mock_domain_adapter, mock_domain_detector,
                    use_batching=True, use_background_loading=False
                )
                results.append(result)

        # Run several workers concurrently.
        threads = []
        for i in range(3):
            thread = threading.Thread(target=transcription_worker, args=(i,))
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete.
        for thread in threads:
            thread.join()

        # Assert in the main thread: assertion failures raised inside worker
        # threads are not reported as test failures by pytest.
        assert len(results) == 9
        assert all(result is not None for result in results)

        # Verify performance stats.
        stats = performance_optimizer.get_performance_stats()
        assert stats.throughput_requests_per_second > 0


if __name__ == "__main__":
    pytest.main([__file__])