# trax/tests/test_diarization_config_manager.py

"""Tests for the diarization configuration manager."""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
import numpy as np
from src.services.diarization_config_manager import (
DiarizationConfigManager, SystemResources, OptimizationRecommendations
)
from src.services.diarization_types import DiarizationConfig
class TestDiarizationConfigManager:
    """Test cases for DiarizationConfigManager.

    The ``config_manager`` fixture builds a real manager, whose
    ``system_resources`` reflect the host running the suite (see
    ``test_initialization``). Tests whose assertions depend on memory
    figures overwrite those fields first, so the outcome does not depend on
    how much RAM the test machine happens to have.
    """

    @pytest.fixture
    def config_manager(self):
        """Create a DiarizationConfigManager instance for testing."""
        return DiarizationConfigManager()

    @pytest.fixture
    def mock_system_resources(self):
        """Create mock system resources for testing."""
        return SystemResources(
            total_memory_gb=16.0,
            available_memory_gb=12.0,
            cpu_count=8,
            gpu_available=True,
            gpu_memory_gb=8.0,
            gpu_name="NVIDIA RTX 3080",
        )

    @staticmethod
    def _pin_memory(manager, resources):
        """Copy the memory figures of *resources* onto *manager*'s system resources.

        Only ``total_memory_gb`` and ``available_memory_gb`` are overwritten;
        CPU/GPU fields keep their host-detected values so nothing else about
        the test's behavior changes.
        """
        manager.system_resources.total_memory_gb = resources.total_memory_gb
        manager.system_resources.available_memory_gb = resources.available_memory_gb

    def test_initialization(self, config_manager):
        """Test configuration manager initialization."""
        assert config_manager.base_config is not None
        assert config_manager.system_resources is not None
        assert config_manager.memory_optimizer is not None

        # Check that system resources are analyzed
        assert config_manager.system_resources.total_memory_gb > 0
        assert config_manager.system_resources.cpu_count > 0

    # Stacked @patch decorators inject their mocks bottom-up, which is why
    # the CUDA mock is the first argument after ``self``.
    @patch('src.services.diarization_config_manager.psutil.virtual_memory')
    @patch('src.services.diarization_config_manager.psutil.cpu_count')
    @patch('src.services.diarization_config_manager.torch.cuda.is_available')
    def test_analyze_system_resources(self, mock_cuda_available, mock_cpu_count, mock_virtual_memory):
        """Test system resource analysis."""
        # Mock system resources (psutil reports memory in bytes)
        mock_memory = Mock()
        mock_memory.total = 16 * 1024**3  # 16GB
        mock_memory.available = 12 * 1024**3  # 12GB
        mock_virtual_memory.return_value = mock_memory
        mock_cpu_count.return_value = 8
        mock_cuda_available.return_value = True

        # Mock GPU properties; the manager is built and inspected while the
        # patches are active so every read sees the mocked values.
        with patch('src.services.diarization_config_manager.torch.cuda.get_device_properties') as mock_gpu_props:
            mock_gpu_props.return_value.total_memory = 8 * 1024**3  # 8GB
            with patch('src.services.diarization_config_manager.torch.cuda.get_device_name') as mock_gpu_name:
                mock_gpu_name.return_value = "NVIDIA RTX 3080"

                config_manager = DiarizationConfigManager()

                # Verify system resources
                resources = config_manager.system_resources
                assert resources.total_memory_gb == pytest.approx(16.0)
                assert resources.available_memory_gb == pytest.approx(12.0)
                assert resources.cpu_count == 8
                assert resources.gpu_available is True
                assert resources.gpu_memory_gb == pytest.approx(8.0)
                assert resources.gpu_name == "NVIDIA RTX 3080"

    def test_get_optimization_recommendations_high_memory(self, config_manager):
        """Test optimization recommendations for high memory systems."""
        # Mock high memory system
        config_manager.system_resources.available_memory_gb = 16.0

        recommendations = config_manager.get_optimization_recommendations()

        assert recommendations.recommended_batch_size == 4
        assert recommendations.recommended_chunk_duration == 900  # 15 minutes
        assert recommendations.enable_quantization is False
        assert recommendations.enable_offloading is False
        assert recommendations.enable_chunking is False
        assert recommendations.target_sample_rate == 16000
        assert "quantization" not in recommendations.memory_optimizations
        assert "model_offloading" not in recommendations.memory_optimizations

    def test_get_optimization_recommendations_low_memory(self, config_manager):
        """Test optimization recommendations for low memory systems."""
        # Mock low memory system
        config_manager.system_resources.available_memory_gb = 4.0

        recommendations = config_manager.get_optimization_recommendations()

        assert recommendations.recommended_batch_size == 1
        assert recommendations.recommended_chunk_duration == 300  # 5 minutes
        assert recommendations.enable_quantization is True
        assert recommendations.enable_offloading is True
        assert recommendations.enable_chunking is True
        assert recommendations.target_sample_rate == 8000
        assert "quantization" in recommendations.memory_optimizations
        assert "model_offloading" in recommendations.memory_optimizations
        assert "audio_chunking" in recommendations.memory_optimizations
        assert "downsampling" in recommendations.memory_optimizations

    def test_create_optimized_config(self, config_manager):
        """Test creation of optimized configuration."""
        # Pin a 12 GB system with a 6 GB GPU
        config_manager.system_resources.available_memory_gb = 12.0
        config_manager.system_resources.gpu_available = True
        config_manager.system_resources.gpu_memory_gb = 6.0

        config = config_manager.create_optimized_config(audio_duration_seconds=1800)  # 30 minutes

        assert config.batch_size == 2
        assert config.enable_quantization is False
        assert config.enable_model_offloading is False
        assert config.enable_chunking is True
        assert config.target_sample_rate == 16000
        assert config.chunk_duration_seconds == 900  # Should be 900 for 12GB available memory (15 minutes)
        assert config.device == "cuda"
        assert config.max_memory_gb <= 12.0 * 0.8  # 80% of available memory

    def test_create_optimized_config_short_audio(self, config_manager):
        """Test optimized configuration for short audio files."""
        config_manager.system_resources.available_memory_gb = 8.0

        config = config_manager.create_optimized_config(audio_duration_seconds=300)  # 5 minutes

        assert config.enable_chunking is False  # No chunking needed for short audio

    # Decorators inject bottom-up: rolloff, centroid, load.
    @patch('librosa.load')
    @patch('librosa.feature.spectral_centroid')
    @patch('librosa.feature.spectral_rolloff')
    def test_estimate_speaker_count(self, mock_rolloff, mock_centroid, mock_load, config_manager):
        """Test speaker count estimation."""
        # Mock audio analysis: seeded so the fake signal is identical on every run
        fake_audio = np.random.default_rng(0).random(16000)
        mock_load.return_value = (fake_audio, 16000)  # 1 second of audio

        # Mock spectral features
        mock_centroid.return_value = np.array([[0.5, 0.6, 0.4]])
        mock_rolloff.return_value = np.array([[0.7, 0.8, 0.6]])

        audio_path = Path("test_audio.wav")
        config = DiarizationConfig(enable_speaker_estimation=True)

        estimated_speakers = config_manager.estimate_speaker_count(audio_path, config)

        assert estimated_speakers is not None
        assert 1 <= estimated_speakers <= 4

    def test_estimate_speaker_count_disabled(self, config_manager):
        """Test speaker count estimation when disabled."""
        audio_path = Path("test_audio.wav")
        config = DiarizationConfig(
            enable_speaker_estimation=False,
            num_speakers=3,
        )

        estimated_speakers = config_manager.estimate_speaker_count(audio_path, config)

        assert estimated_speakers == 3  # Should return configured value

    def test_validate_config_valid(self, config_manager, mock_system_resources):
        """Test configuration validation with valid config."""
        # Pin 16 GB total / 12 GB available so a 4 GB budget is always
        # satisfiable (presumably validate_config reads these fields — the
        # other tests rely on the same attributes steering the manager).
        self._pin_memory(config_manager, mock_system_resources)
        config = DiarizationConfig(
            max_memory_gb=4.0,
            batch_size=2,
            chunk_duration_seconds=600,
            device="cpu",
        )

        is_valid, warnings = config_manager.validate_config(config)

        assert is_valid is True
        assert len(warnings) == 0

    def test_validate_config_invalid_memory(self, config_manager, mock_system_resources):
        """Test configuration validation with invalid memory requirements."""
        # Without pinning, a host with more than 20 GB would accept this config.
        self._pin_memory(config_manager, mock_system_resources)
        config = DiarizationConfig(
            max_memory_gb=20.0,  # More than the pinned 16 GB total / 12 GB available
            batch_size=2,
        )

        is_valid, warnings = config_manager.validate_config(config)

        assert is_valid is False
        assert len(warnings) > 0
        assert any("memory" in warning.lower() for warning in warnings)

    def test_validate_config_large_batch_size(self, config_manager, mock_system_resources):
        """Test configuration validation with large batch size."""
        self._pin_memory(config_manager, mock_system_resources)
        config = DiarizationConfig(
            max_memory_gb=4.0,
            batch_size=8,  # Large batch size
        )

        is_valid, warnings = config_manager.validate_config(config)

        assert is_valid is True  # Should still be valid but with warning
        assert len(warnings) > 0
        assert any("batch size" in warning.lower() for warning in warnings)

    def test_get_memory_usage_estimate(self, config_manager):
        """Test memory usage estimation."""
        config = DiarizationConfig(
            target_sample_rate=16000,
            enable_quantization=True,
            enable_chunking=True,
        )
        audio_duration_seconds = 3600  # 1 hour

        estimate = config_manager.get_memory_usage_estimate(config, audio_duration_seconds)

        for key in (
            "model_memory_gb",
            "audio_memory_gb",
            "processing_overhead_gb",
            "total_memory_gb",
            "available_memory_gb",
        ):
            assert key in estimate

        # Check that quantization reduces model memory
        assert estimate["model_memory_gb"] == pytest.approx(1.0)  # 50% of 2.0GB

        # Check that audio memory is calculated correctly:
        # samples/s * seconds * 4 bytes per float32 sample, in GiB (~0.21GB)
        expected_audio_memory = (16000 * 3600 * 4) / (1024**3)
        assert estimate["audio_memory_gb"] == pytest.approx(expected_audio_memory, abs=0.1)

    def test_create_merging_config_high_quality(self, config_manager):
        """Test merging configuration creation for high quality diarization."""
        diarization_config = DiarizationConfig(quality_threshold=0.9)

        merging_config = config_manager.create_merging_config(diarization_config)

        assert merging_config.min_overlap_ratio == 0.6
        assert merging_config.min_confidence_threshold == 0.5
        assert merging_config.min_segment_duration == diarization_config.min_duration

    def test_create_merging_config_low_quality(self, config_manager):
        """Test merging configuration creation for low quality diarization."""
        diarization_config = DiarizationConfig(quality_threshold=0.5)

        merging_config = config_manager.create_merging_config(diarization_config)

        assert merging_config.min_overlap_ratio == 0.4
        assert merging_config.min_confidence_threshold == 0.3
        assert merging_config.min_segment_duration == diarization_config.min_duration

    def test_create_merging_config_medium_quality(self, config_manager):
        """Test merging configuration creation for medium quality diarization."""
        diarization_config = DiarizationConfig(quality_threshold=0.7)

        merging_config = config_manager.create_merging_config(diarization_config)

        assert merging_config.min_overlap_ratio == 0.5
        assert merging_config.min_confidence_threshold == 0.4
        assert merging_config.min_segment_duration == diarization_config.min_duration