trax/tests/test_diarization_service.py

329 lines
13 KiB
Python

"""Tests for diarization services."""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from src.services.diarization_types import (
DiarizationConfig, SpeakerSegment, DiarizationResult,
SpeakerProfile, ProfileMatch, ProcessingResult
)
from src.services.diarization_service import DiarizationManager
from src.services.speaker_profile_manager import SpeakerProfileManager
from src.services.parallel_processor import ParallelProcessor
@pytest.fixture
def sample_audio_path():
    """Provide the path to the 5-second sample WAV used by the tests.

    The path is resolved relative to this test module instead of the current
    working directory, so the file is found no matter where pytest is
    launched from (the former ``Path("tests/sample_5s.wav")`` only worked
    when pytest ran from the directory containing ``tests/``).

    Returns:
        Path: absolute path to ``sample_5s.wav`` next to this module.
    """
    return Path(__file__).resolve().parent / "sample_5s.wav"
@pytest.fixture
def diarization_manager():
    """Build a CPU-only DiarizationManager for the tests.

    Uses the pyannote 3.0 diarization model id, forces the ``cpu`` device and
    disables memory optimization.
    """
    cpu_config = DiarizationConfig(
        model_path="pyannote/speaker-diarization-3.0",
        device="cpu",
        memory_optimization=False,
    )
    manager = DiarizationManager(cpu_config)
    return manager
@pytest.fixture
def speaker_profile_manager(tmp_path):
    """Provide a SpeakerProfileManager backed by a fresh per-test directory.

    pytest's ``tmp_path`` gives every test its own directory. The former fixed
    ``tests/temp_profiles`` path was shared by all tests and runs, so anything
    persisted by one test could leak into the next and break assertions such
    as "no profiles at start".

    Args:
        tmp_path: pytest's unique temporary directory for the current test.
    """
    # Use a not-yet-existing subdirectory, just as "tests/temp_profiles"
    # initially was. test_initialization then still checks that the manager
    # creates its storage directory.
    return SpeakerProfileManager(storage_dir=tmp_path / "temp_profiles")
@pytest.fixture
def parallel_processor():
    """Provide a two-worker ParallelProcessor and clean it up afterwards.

    Yields:
        ParallelProcessor: configured with ``max_workers=2``.
    """
    from src.services.diarization_types import ParallelProcessingConfig
    config = ParallelProcessingConfig(max_workers=2)
    processor = ParallelProcessor(config)
    yield processor
    # Teardown: let the processor release its executor/resources so they do
    # not accumulate across tests.
    # NOTE(review): TestParallelProcessor.test_cleanup also calls cleanup(),
    # so this assumes cleanup() is safe to call twice — confirm.
    processor.cleanup()
class TestDiarizationManager:
    """Test cases for DiarizationManager."""

    def test_initialization(self, diarization_manager):
        """Test DiarizationManager initialization."""
        assert diarization_manager.config.model_path == "pyannote/speaker-diarization-3.0"
        assert diarization_manager._device in ["cpu", "cuda"]
        assert not diarization_manager._initialized

    @patch('src.services.diarization_utils.determine_device')
    def test_device_determination(self, mock_determine_device):
        """Test device determination logic.

        The manager is constructed inside the test body rather than taken
        from the ``diarization_manager`` fixture. pytest builds fixtures
        during setup, before the ``@patch`` decorator is entered. If the
        device is resolved at construction time, a fixture-built manager
        would therefore never see the mock.
        """
        mock_determine_device.return_value = "cpu"
        # NOTE(review): the patch only takes effect if diarization_service
        # looks determine_device up via the diarization_utils module. If it
        # does `from ... import determine_device`, patch
        # 'src.services.diarization_service.determine_device' instead.
        config = DiarizationConfig(
            model_path="pyannote/speaker-diarization-3.0",
            device="cpu",
            memory_optimization=False,
        )
        manager = DiarizationManager(config)
        assert manager._device == "cpu"

    @patch('pyannote.audio.Pipeline.from_pretrained')
    def test_pipeline_loading(self, mock_pipeline, diarization_manager):
        """Test pipeline loading with error handling."""
        mock_pipeline.return_value = Mock()
        pipeline = diarization_manager._load_pipeline()
        assert pipeline is not None
        assert diarization_manager._initialized

    @patch('pyannote.audio.Pipeline.from_pretrained')
    def test_pipeline_loading_error(self, mock_pipeline, diarization_manager):
        """Test pipeline loading error handling."""
        mock_pipeline.side_effect = Exception("Model loading failed")
        with pytest.raises(Exception):
            diarization_manager._load_pipeline()

    @patch.object(DiarizationManager, '_load_pipeline')
    def test_process_audio_success(self, mock_load_pipeline, diarization_manager, sample_audio_path):
        """Test successful audio processing."""
        # Mock pipeline and annotation. itertracks yields
        # (segment, track, label) tuples, mirroring pyannote's
        # Annotation.itertracks(yield_label=True).
        mock_pipeline = Mock()
        mock_annotation = Mock()
        mock_annotation.itertracks.return_value = [
            (Mock(start=0.0, end=2.0, duration=2.0), 1, "SPEAKER_00"),
            (Mock(start=2.0, end=4.0, duration=2.0), 2, "SPEAKER_01"),
        ]
        mock_pipeline.return_value = mock_annotation
        mock_load_pipeline.return_value = mock_pipeline

        result = diarization_manager.process_audio(sample_audio_path)

        assert isinstance(result, DiarizationResult)
        assert result.speaker_count == 2
        assert len(result.segments) == 2
        assert result.processing_time > 0

    def test_process_audio_file_not_found(self, diarization_manager):
        """Test audio processing with non-existent file."""
        with pytest.raises(FileNotFoundError):
            diarization_manager.process_audio(Path("nonexistent.wav"))

    def test_estimate_speaker_count(self, diarization_manager, sample_audio_path):
        """Test speaker count estimation."""
        with patch.object(diarization_manager, 'process_audio') as mock_process:
            mock_result = Mock()
            mock_result.speaker_count = 3
            mock_process.return_value = mock_result
            count = diarization_manager.estimate_speaker_count(sample_audio_path)
        assert count == 3

    def test_get_speaker_segments(self, diarization_manager, sample_audio_path):
        """Test getting segments for specific speaker."""
        with patch.object(diarization_manager, 'process_audio') as mock_process:
            mock_result = Mock()
            mock_result.segments = [
                SpeakerSegment(0.0, 2.0, "SPEAKER_00", 0.8),
                SpeakerSegment(2.0, 4.0, "SPEAKER_01", 0.9),
            ]
            mock_process.return_value = mock_result
            segments = diarization_manager.get_speaker_segments(sample_audio_path, "SPEAKER_00")
        assert len(segments) == 1
        assert segments[0].speaker_id == "SPEAKER_00"

    def test_cleanup(self, diarization_manager):
        """Test resource cleanup."""
        diarization_manager._pipeline = Mock()
        diarization_manager._initialized = True
        diarization_manager.cleanup()
        assert diarization_manager._pipeline is None
        assert not diarization_manager._initialized
class TestSpeakerProfileManager:
    """Test cases for SpeakerProfileManager."""

    @staticmethod
    def _embedding(seed=0):
        """Return a reproducible random 512-dimensional embedding.

        A seeded generator replaces the unseeded global ``np.random`` state,
        so a failing test sees the same data when re-run.

        Args:
            seed: Seed for the generator; different seeds give different vectors.
        """
        import numpy as np
        return np.random.default_rng(seed).random(512)

    def test_initialization(self, speaker_profile_manager):
        """Test SpeakerProfileManager initialization."""
        assert speaker_profile_manager.storage_dir.exists()
        assert len(speaker_profile_manager.profiles) == 0
        assert speaker_profile_manager.similarity_threshold == 0.7

    def test_add_speaker_success(self, speaker_profile_manager):
        """Test adding a speaker profile."""
        speaker_id = "test_speaker"
        embedding = self._embedding()
        profile = speaker_profile_manager.add_speaker(speaker_id, embedding, name="Test Speaker")
        assert profile.speaker_id == speaker_id
        assert profile.name == "Test Speaker"
        assert speaker_id in speaker_profile_manager.profiles
        assert speaker_id in speaker_profile_manager.embeddings_cache

    def test_add_speaker_validation_error(self, speaker_profile_manager):
        """Test adding speaker with invalid data."""
        import numpy as np
        # Empty speaker ID
        with pytest.raises(Exception):
            speaker_profile_manager.add_speaker("", self._embedding())
        # Empty embedding
        with pytest.raises(Exception):
            speaker_profile_manager.add_speaker("test", np.array([]))

    def test_get_speaker(self, speaker_profile_manager):
        """Test getting a speaker profile."""
        speaker_id = "test_speaker"
        speaker_profile_manager.add_speaker(speaker_id, self._embedding())
        profile = speaker_profile_manager.get_speaker(speaker_id)
        assert profile is not None
        assert profile.speaker_id == speaker_id

    def test_find_similar_speakers(self, speaker_profile_manager):
        """Test finding similar speakers."""
        # Two distinct, reproducible profiles.
        embedding1 = self._embedding(1)
        embedding2 = self._embedding(2)
        speaker_profile_manager.add_speaker("speaker1", embedding1)
        speaker_profile_manager.add_speaker("speaker2", embedding2)
        # Querying with speaker1's own embedding must match at least speaker1.
        matches = speaker_profile_manager.find_similar_speakers(embedding1, threshold=0.5)
        assert len(matches) >= 1

    def test_update_speaker(self, speaker_profile_manager):
        """Test updating a speaker profile."""
        import numpy as np
        speaker_id = "test_speaker"
        speaker_profile_manager.add_speaker(speaker_id, self._embedding(0))
        # A different seed guarantees the update actually changes the vector.
        new_embedding = self._embedding(1)
        updated_profile = speaker_profile_manager.update_speaker(
            speaker_id, new_embedding, name="Updated Name"
        )
        assert updated_profile.name == "Updated Name"
        assert np.array_equal(updated_profile.embedding, new_embedding)

    def test_remove_speaker(self, speaker_profile_manager):
        """Test removing a speaker profile."""
        speaker_id = "test_speaker"
        speaker_profile_manager.add_speaker(speaker_id, self._embedding())
        # Remove speaker
        success = speaker_profile_manager.remove_speaker(speaker_id)
        assert success
        assert speaker_id not in speaker_profile_manager.profiles

    def test_get_profile_stats(self, speaker_profile_manager):
        """Test getting profile statistics."""
        stats = speaker_profile_manager.get_profile_stats()
        assert "total_profiles" in stats
        assert "profiles_with_embeddings" in stats

    def test_cleanup(self, speaker_profile_manager):
        """Test cleanup method."""
        speaker_profile_manager.cleanup()
        # Should not raise any exceptions
class TestParallelProcessor:
    """Unit tests covering the ParallelProcessor service."""

    def test_initialization(self, parallel_processor):
        """A new processor keeps its config, owns an executor and has stats."""
        processor = parallel_processor
        assert processor.config.max_workers == 2
        assert processor.executor is not None
        assert len(processor.stats) > 0

    @patch.object(ParallelProcessor, '_initialize_services')
    def test_process_file_success(self, mock_init_services, parallel_processor, sample_audio_path):
        """process_file reports success when both services return results."""
        diarizer = Mock()
        transcriber = Mock()
        diarizer.process_audio.return_value = Mock()
        transcriber.transcribe_file.return_value = Mock()
        parallel_processor.diarization_manager = diarizer
        parallel_processor.transcription_service = transcriber

        outcome = parallel_processor.process_file(sample_audio_path)

        assert isinstance(outcome, ProcessingResult)
        assert outcome.success
        assert outcome.task_id is not None

    def test_process_file_not_found(self, parallel_processor):
        """A missing input file makes process_file raise."""
        missing = Path("nonexistent.wav")
        with pytest.raises(Exception):
            parallel_processor.process_file(missing)

    def test_process_batch(self, parallel_processor):
        """process_batch returns one result per input path."""
        batch = [Path("tests/sample_5s.wav"), Path("tests/sample_30s.mp3")]
        canned = ProcessingResult(task_id="test", success=True)
        with patch.object(parallel_processor, 'process_file', return_value=canned):
            outcome = parallel_processor.process_batch(batch)
        assert len(outcome) == 2

    def test_get_processing_stats(self, parallel_processor):
        """The stats mapping exposes the expected summary keys."""
        stats = parallel_processor.get_processing_stats()
        for expected_key in ("total_files_processed", "success_rate"):
            assert expected_key in stats

    def test_estimate_speedup(self, parallel_processor):
        """Speedup is the ratio of the two supplied durations."""
        assert parallel_processor.estimate_speedup(10.0, 5.0) == 2.0

    def test_cleanup(self, parallel_processor):
        """cleanup() completes without raising."""
        parallel_processor.cleanup()
class TestIntegration:
    """Integration tests for the diarization pipeline."""

    def test_full_pipeline_integration(self, sample_audio_path):
        """Test that the three services can be instantiated together.

        Running real audio through real models would need actual files and
        model weights. This test only checks construction with default
        arguments. Each component is registered for cleanup as soon as it
        exists, so resources such as the processor's executor are released
        even if a later constructor or assertion fails.
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            diarization_manager = DiarizationManager()
            stack.callback(diarization_manager.cleanup)
            profile_manager = SpeakerProfileManager()
            stack.callback(profile_manager.cleanup)
            parallel_processor = ParallelProcessor()
            stack.callback(parallel_processor.cleanup)

            assert diarization_manager is not None
            assert profile_manager is not None
            assert parallel_processor is not None

    def test_memory_optimization(self):
        """Test that the memory_optimization flag is kept on the config."""
        config = DiarizationConfig(memory_optimization=True)
        manager = DiarizationManager(config)
        assert manager.config.memory_optimization is True
if __name__ == "__main__":
    # Propagate pytest's exit status so a direct script run (e.g. from CI)
    # fails when tests fail, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))