# 263 lines | 12 KiB | Python
"""Tests for the diarization configuration manager."""
|
|
|
|
import pytest
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
import numpy as np
|
|
|
|
from src.services.diarization_config_manager import (
|
|
DiarizationConfigManager, SystemResources, OptimizationRecommendations
|
|
)
|
|
from src.services.diarization_types import DiarizationConfig
|
|
|
|
|
|
class TestDiarizationConfigManager:
    """Test cases for DiarizationConfigManager.

    Most tests obtain a manager from the ``config_manager`` fixture and then
    overwrite fields on ``config_manager.system_resources`` so that the
    expected values do not depend on the machine running the suite.
    """

    @pytest.fixture
    def config_manager(self):
        """Create a DiarizationConfigManager instance for testing."""
        return DiarizationConfigManager()

    @pytest.fixture
    def mock_system_resources(self):
        """Create mock system resources for testing."""
        return SystemResources(
            total_memory_gb=16.0,
            available_memory_gb=12.0,
            cpu_count=8,
            gpu_available=True,
            gpu_memory_gb=8.0,
            gpu_name="NVIDIA RTX 3080",
        )

    def test_initialization(self, config_manager):
        """Test configuration manager initialization."""
        assert config_manager.base_config is not None
        assert config_manager.system_resources is not None
        assert config_manager.memory_optimizer is not None

        # Check that system resources are analyzed
        assert config_manager.system_resources.total_memory_gb > 0
        assert config_manager.system_resources.cpu_count > 0

    @patch('src.services.diarization_config_manager.psutil.virtual_memory')
    @patch('src.services.diarization_config_manager.psutil.cpu_count')
    @patch('src.services.diarization_config_manager.torch.cuda.is_available')
    def test_analyze_system_resources(
        self, mock_cuda_available, mock_cpu_count, mock_virtual_memory
    ):
        """Test system resource analysis.

        psutil and torch.cuda are patched, then a fresh manager is built so
        that its ``system_resources`` reflect the mocked values.
        """
        bytes_per_gib = 1024**3

        # Mock system resources: 16 GiB total, 12 GiB available, 8 CPUs, CUDA on
        mock_memory = Mock()
        mock_memory.total = 16 * bytes_per_gib
        mock_memory.available = 12 * bytes_per_gib
        mock_virtual_memory.return_value = mock_memory

        mock_cpu_count.return_value = 8
        mock_cuda_available.return_value = True

        # Mock GPU properties: an 8 GiB device with a fixed name
        with patch(
            'src.services.diarization_config_manager.torch.cuda.get_device_properties'
        ) as mock_gpu_props, patch(
            'src.services.diarization_config_manager.torch.cuda.get_device_name'
        ) as mock_gpu_name:
            mock_gpu_props.return_value.total_memory = 8 * bytes_per_gib
            mock_gpu_name.return_value = "NVIDIA RTX 3080"

            # Built while every patch is active
            config_manager = DiarizationConfigManager()

            # Verify system resources
            resources = config_manager.system_resources
            assert resources.total_memory_gb == 16.0
            assert resources.available_memory_gb == 12.0
            assert resources.cpu_count == 8
            assert resources.gpu_available is True
            assert resources.gpu_memory_gb == 8.0
            assert resources.gpu_name == "NVIDIA RTX 3080"

    def test_get_optimization_recommendations_high_memory(self, config_manager):
        """Test optimization recommendations for high memory systems."""
        # Mock high memory system
        config_manager.system_resources.available_memory_gb = 16.0

        recommendations = config_manager.get_optimization_recommendations()

        assert recommendations.recommended_batch_size == 4
        assert recommendations.recommended_chunk_duration == 900  # 15 minutes
        assert recommendations.enable_quantization is False
        assert recommendations.enable_offloading is False
        assert recommendations.enable_chunking is False
        assert recommendations.target_sample_rate == 16000
        assert "quantization" not in recommendations.memory_optimizations
        assert "model_offloading" not in recommendations.memory_optimizations

    def test_get_optimization_recommendations_low_memory(self, config_manager):
        """Test optimization recommendations for low memory systems."""
        # Mock low memory system
        config_manager.system_resources.available_memory_gb = 4.0

        recommendations = config_manager.get_optimization_recommendations()

        assert recommendations.recommended_batch_size == 1
        assert recommendations.recommended_chunk_duration == 300  # 5 minutes
        assert recommendations.enable_quantization is True
        assert recommendations.enable_offloading is True
        assert recommendations.enable_chunking is True
        assert recommendations.target_sample_rate == 8000
        assert "quantization" in recommendations.memory_optimizations
        assert "model_offloading" in recommendations.memory_optimizations
        assert "audio_chunking" in recommendations.memory_optimizations
        assert "downsampling" in recommendations.memory_optimizations

    def test_create_optimized_config(self, config_manager):
        """Test creation of optimized configuration."""
        # Pin a host with 12 GB of available memory and a 6 GB GPU
        config_manager.system_resources.available_memory_gb = 12.0
        config_manager.system_resources.gpu_available = True
        config_manager.system_resources.gpu_memory_gb = 6.0

        config = config_manager.create_optimized_config(
            audio_duration_seconds=1800  # 30 minutes
        )

        assert config.batch_size == 2
        assert config.enable_quantization is False
        assert config.enable_model_offloading is False
        assert config.enable_chunking is True
        assert config.target_sample_rate == 16000
        # Should be 900 for 12GB available memory (15 minutes)
        assert config.chunk_duration_seconds == 900
        assert config.device == "cuda"
        assert config.max_memory_gb <= 12.0 * 0.8  # 80% of available memory

    def test_create_optimized_config_short_audio(self, config_manager):
        """Test optimized configuration for short audio files."""
        config_manager.system_resources.available_memory_gb = 8.0

        config = config_manager.create_optimized_config(
            audio_duration_seconds=300  # 5 minutes
        )

        # No chunking needed for short audio
        assert config.enable_chunking is False

    @patch('librosa.load')
    @patch('librosa.feature.spectral_centroid')
    @patch('librosa.feature.spectral_rolloff')
    def test_estimate_speaker_count(
        self, mock_rolloff, mock_centroid, mock_load, config_manager
    ):
        """Test speaker count estimation."""
        # Mock audio analysis: 1 second of audio at 16 kHz
        mock_load.return_value = (np.random.random(16000), 16000)

        # Mock spectral features
        mock_centroid.return_value = np.array([[0.5, 0.6, 0.4]])
        mock_rolloff.return_value = np.array([[0.7, 0.8, 0.6]])

        audio_path = Path("test_audio.wav")
        config = DiarizationConfig(enable_speaker_estimation=True)

        estimated_speakers = config_manager.estimate_speaker_count(audio_path, config)

        assert estimated_speakers is not None
        assert 1 <= estimated_speakers <= 4

    def test_estimate_speaker_count_disabled(self, config_manager):
        """Test speaker count estimation when disabled."""
        audio_path = Path("test_audio.wav")
        config = DiarizationConfig(
            enable_speaker_estimation=False,
            num_speakers=3,
        )

        estimated_speakers = config_manager.estimate_speaker_count(audio_path, config)

        assert estimated_speakers == 3  # Should return configured value

    def test_validate_config_valid(self, config_manager):
        """Test configuration validation with valid config."""
        # Pin host memory so the 4 GB request is valid on any runner.
        # NOTE(review): assumes validate_config compares against
        # config_manager.system_resources - confirm against the implementation.
        config_manager.system_resources.total_memory_gb = 16.0
        config_manager.system_resources.available_memory_gb = 12.0

        config = DiarizationConfig(
            max_memory_gb=4.0,
            batch_size=2,
            chunk_duration_seconds=600,
            device="cpu",
        )

        is_valid, warnings = config_manager.validate_config(config)

        assert is_valid is True
        assert len(warnings) == 0

    def test_validate_config_invalid_memory(self, config_manager):
        """Test configuration validation with invalid memory requirements."""
        # Pin host memory below the 20 GB request; without this the test
        # would depend on how much RAM the runner happens to have free.
        # NOTE(review): assumes validate_config compares against
        # config_manager.system_resources - confirm against the implementation.
        config_manager.system_resources.total_memory_gb = 16.0
        config_manager.system_resources.available_memory_gb = 12.0

        config = DiarizationConfig(
            max_memory_gb=20.0,  # More than available
            batch_size=2,
        )

        is_valid, warnings = config_manager.validate_config(config)

        assert is_valid is False
        assert len(warnings) > 0
        assert any("memory" in warning.lower() for warning in warnings)

    def test_validate_config_large_batch_size(self, config_manager):
        """Test configuration validation with large batch size."""
        # NOTE(review): host memory is deliberately not pinned here because
        # the batch-size warning rule is not visible from this file and may
        # itself depend on available memory - confirm before pinning.
        config = DiarizationConfig(
            max_memory_gb=4.0,
            batch_size=8,  # Large batch size
        )

        is_valid, warnings = config_manager.validate_config(config)

        assert is_valid is True  # Should still be valid but with warning
        assert len(warnings) > 0
        assert any("batch size" in warning.lower() for warning in warnings)

    def test_get_memory_usage_estimate(self, config_manager):
        """Test memory usage estimation."""
        config = DiarizationConfig(
            target_sample_rate=16000,
            enable_quantization=True,
            enable_chunking=True,
        )

        audio_duration_seconds = 3600  # 1 hour

        estimate = config_manager.get_memory_usage_estimate(
            config, audio_duration_seconds
        )

        for key in (
            "model_memory_gb",
            "audio_memory_gb",
            "processing_overhead_gb",
            "total_memory_gb",
            "available_memory_gb",
        ):
            assert key in estimate

        # Check that quantization reduces model memory (50% of 2.0GB);
        # approx avoids relying on exact float equality.
        assert estimate["model_memory_gb"] == pytest.approx(1.0)

        # Check that audio memory is calculated correctly
        expected_audio_memory = (16000 * 3600 * 4) / (1024**3)  # ~0.21GB
        assert abs(estimate["audio_memory_gb"] - expected_audio_memory) < 0.1

    def test_create_merging_config_high_quality(self, config_manager):
        """Test merging configuration creation for high quality diarization."""
        diarization_config = DiarizationConfig(quality_threshold=0.9)

        merging_config = config_manager.create_merging_config(diarization_config)

        assert merging_config.min_overlap_ratio == 0.6
        assert merging_config.min_confidence_threshold == 0.5
        assert merging_config.min_segment_duration == diarization_config.min_duration

    def test_create_merging_config_low_quality(self, config_manager):
        """Test merging configuration creation for low quality diarization."""
        diarization_config = DiarizationConfig(quality_threshold=0.5)

        merging_config = config_manager.create_merging_config(diarization_config)

        assert merging_config.min_overlap_ratio == 0.4
        assert merging_config.min_confidence_threshold == 0.3
        assert merging_config.min_segment_duration == diarization_config.min_duration

    def test_create_merging_config_medium_quality(self, config_manager):
        """Test merging configuration creation for medium quality diarization."""
        diarization_config = DiarizationConfig(quality_threshold=0.7)

        merging_config = config_manager.create_merging_config(diarization_config)

        assert merging_config.min_overlap_ratio == 0.5
        assert merging_config.min_confidence_threshold == 0.4
        assert merging_config.min_segment_duration == diarization_config.min_duration