329 lines
13 KiB
Python
329 lines
13 KiB
Python
"""Tests for diarization services."""
|
|
|
|
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock

import numpy as np
import pytest

from src.services.diarization_types import (
    DiarizationConfig, SpeakerSegment, DiarizationResult,
    SpeakerProfile, ProfileMatch, ProcessingResult
)
from src.services.diarization_service import DiarizationManager
from src.services.speaker_profile_manager import SpeakerProfileManager
from src.services.parallel_processor import ParallelProcessor
|
|
|
|
|
|
@pytest.fixture
def sample_audio_path():
    """Return the path of the short WAV clip used as test input.

    The path is relative, so it is resolved against the directory pytest is
    launched from (presumably the repository root).
    """
    return Path("tests") / "sample_5s.wav"
|
|
|
|
|
|
@pytest.fixture
def diarization_manager():
    """Build a DiarizationManager pinned to CPU for unit tests.

    Uses the ``pyannote/speaker-diarization-3.0`` model id with memory
    optimisation switched off.
    """
    settings = {
        "model_path": "pyannote/speaker-diarization-3.0",
        "device": "cpu",
        "memory_optimization": False,
    }
    return DiarizationManager(DiarizationConfig(**settings))
|
|
|
|
|
|
@pytest.fixture
def speaker_profile_manager(tmp_path):
    """Provide a SpeakerProfileManager backed by a per-test temporary directory.

    Each test gets its own fresh storage directory under pytest's ``tmp_path``.
    A fixed, shared in-repo directory would let any profiles the manager
    persists to disk leak from one test (or run) into the next, breaking
    assertions such as an empty ``profiles`` mapping, and would leave files
    behind in the source tree.

    The sub-directory does not exist yet, mirroring the original setup where
    the manager is handed a directory it is expected to create.
    """
    return SpeakerProfileManager(storage_dir=tmp_path / "temp_profiles")
|
|
|
|
|
|
@pytest.fixture
def parallel_processor():
    """Create a ParallelProcessor configured with a two-worker limit."""
    from src.services.diarization_types import ParallelProcessingConfig

    return ParallelProcessor(ParallelProcessingConfig(max_workers=2))
|
|
|
|
|
|
class TestDiarizationManager:
    """Test cases for DiarizationManager.

    Tests that would load or run a model replace the pyannote pipeline (or
    the manager's own processing methods) with mocks.
    """

    def test_initialization(self, diarization_manager):
        """A freshly built manager keeps its config and has not loaded a pipeline."""
        assert diarization_manager.config.model_path == "pyannote/speaker-diarization-3.0"
        assert diarization_manager._device in ["cpu", "cuda"]
        assert not diarization_manager._initialized

    @patch('src.services.diarization_utils.determine_device')
    def test_device_determination(self, mock_determine_device):
        """Device selection yields CPU when CPU is requested.

        The manager is constructed *inside* the test so the ``determine_device``
        patch is already active while it is built. Fixtures are created before
        a ``@patch`` decorator is entered, so a fixture-built manager would
        have resolved its device (if that happens at construction) without
        ever seeing the mock.
        """
        # NOTE(review): the patch only takes effect if diarization_service looks
        # the function up through the diarization_utils module at call time.
        mock_determine_device.return_value = "cpu"
        manager = DiarizationManager(
            DiarizationConfig(
                model_path="pyannote/speaker-diarization-3.0",
                device="cpu",
                memory_optimization=False,
            )
        )
        assert manager._device == "cpu"

    @patch('pyannote.audio.Pipeline.from_pretrained')
    def test_pipeline_loading(self, mock_pipeline, diarization_manager):
        """_load_pipeline returns a pipeline and marks the manager initialised."""
        mock_pipeline.return_value = Mock()

        pipeline = diarization_manager._load_pipeline()
        assert pipeline is not None
        assert diarization_manager._initialized

    @patch('pyannote.audio.Pipeline.from_pretrained')
    def test_pipeline_loading_error(self, mock_pipeline, diarization_manager):
        """A failure while loading the model surfaces as an exception."""
        mock_pipeline.side_effect = Exception("Model loading failed")

        # Deliberately broad: the manager may re-raise as its own error type.
        with pytest.raises(Exception):
            diarization_manager._load_pipeline()

    @patch.object(DiarizationManager, '_load_pipeline')
    def test_process_audio_success(self, mock_load_pipeline, diarization_manager, sample_audio_path):
        """Two mocked speaker tracks become a two-speaker DiarizationResult."""
        # Mock annotation: itertracks returns (segment, track, label) triples.
        mock_annotation = Mock()
        mock_annotation.itertracks.return_value = [
            (Mock(start=0.0, end=2.0, duration=2.0), 1, "SPEAKER_00"),
            (Mock(start=2.0, end=4.0, duration=2.0), 2, "SPEAKER_01"),
        ]
        # Calling the mocked pipeline yields the annotation above.
        mock_pipeline = Mock(return_value=mock_annotation)
        mock_load_pipeline.return_value = mock_pipeline

        result = diarization_manager.process_audio(sample_audio_path)

        assert isinstance(result, DiarizationResult)
        assert result.speaker_count == 2
        assert len(result.segments) == 2
        # With a mocked pipeline the call completes almost instantly; on clocks
        # with coarse resolution the measured duration can be exactly 0.0, so
        # only require a non-negative value.
        assert result.processing_time >= 0

    def test_process_audio_file_not_found(self, diarization_manager):
        """Processing a path that does not exist raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            diarization_manager.process_audio(Path("nonexistent.wav"))

    def test_estimate_speaker_count(self, diarization_manager, sample_audio_path):
        """estimate_speaker_count reports the speaker_count from process_audio."""
        with patch.object(diarization_manager, 'process_audio') as mock_process:
            mock_result = Mock()
            mock_result.speaker_count = 3
            mock_process.return_value = mock_result

            count = diarization_manager.estimate_speaker_count(sample_audio_path)
            assert count == 3

    def test_get_speaker_segments(self, diarization_manager, sample_audio_path):
        """get_speaker_segments keeps only the segments of the requested speaker."""
        with patch.object(diarization_manager, 'process_audio') as mock_process:
            mock_result = Mock()
            mock_result.segments = [
                SpeakerSegment(0.0, 2.0, "SPEAKER_00", 0.8),
                SpeakerSegment(2.0, 4.0, "SPEAKER_01", 0.9),
            ]
            mock_process.return_value = mock_result

            segments = diarization_manager.get_speaker_segments(sample_audio_path, "SPEAKER_00")
            assert len(segments) == 1
            assert segments[0].speaker_id == "SPEAKER_00"

    def test_cleanup(self, diarization_manager):
        """cleanup() drops the loaded pipeline and resets the initialised flag."""
        diarization_manager._pipeline = Mock()
        diarization_manager._initialized = True

        diarization_manager.cleanup()

        assert diarization_manager._pipeline is None
        assert not diarization_manager._initialized
|
|
|
|
|
|
class TestSpeakerProfileManager:
    """Test cases for SpeakerProfileManager."""

    # Length of the synthetic embeddings fed to the manager.
    EMBEDDING_DIM = 512

    @classmethod
    def _embedding(cls, seed):
        """Return a reproducible random embedding of EMBEDDING_DIM floats in [0, 1).

        A seeded generator keeps every test deterministic, so a failure can be
        reproduced exactly; different seeds give distinct embeddings.
        """
        return np.random.default_rng(seed).random(cls.EMBEDDING_DIM)

    def test_initialization(self, speaker_profile_manager):
        """A new manager has its storage dir, no profiles and the 0.7 threshold."""
        assert speaker_profile_manager.storage_dir.exists()
        assert len(speaker_profile_manager.profiles) == 0
        assert speaker_profile_manager.similarity_threshold == 0.7

    def test_add_speaker_success(self, speaker_profile_manager):
        """add_speaker returns the profile and registers it and its embedding."""
        speaker_id = "test_speaker"
        embedding = self._embedding(0)

        profile = speaker_profile_manager.add_speaker(speaker_id, embedding, name="Test Speaker")

        assert profile.speaker_id == speaker_id
        assert profile.name == "Test Speaker"
        assert speaker_id in speaker_profile_manager.profiles
        assert speaker_id in speaker_profile_manager.embeddings_cache

    def test_add_speaker_validation_error(self, speaker_profile_manager):
        """add_speaker rejects an empty speaker id and an empty embedding."""
        # Empty speaker ID
        with pytest.raises(Exception):
            speaker_profile_manager.add_speaker("", self._embedding(0))

        # Empty embedding
        with pytest.raises(Exception):
            speaker_profile_manager.add_speaker("test", np.array([]))

    def test_get_speaker(self, speaker_profile_manager):
        """get_speaker returns the profile previously added under that id."""
        speaker_id = "test_speaker"
        speaker_profile_manager.add_speaker(speaker_id, self._embedding(0))

        profile = speaker_profile_manager.get_speaker(speaker_id)
        assert profile is not None
        assert profile.speaker_id == speaker_id

    def test_find_similar_speakers(self, speaker_profile_manager):
        """Querying with a stored embedding finds at least that speaker."""
        embedding1 = self._embedding(1)
        embedding2 = self._embedding(2)

        speaker_profile_manager.add_speaker("speaker1", embedding1)
        speaker_profile_manager.add_speaker("speaker2", embedding2)

        # The query is identical to speaker1's embedding, so it must match.
        matches = speaker_profile_manager.find_similar_speakers(embedding1, threshold=0.5)
        assert len(matches) >= 1

    def test_update_speaker(self, speaker_profile_manager):
        """update_speaker replaces the embedding and name of an existing profile."""
        speaker_id = "test_speaker"
        speaker_profile_manager.add_speaker(speaker_id, self._embedding(0))

        # A different seed guarantees the new embedding differs from the old one.
        new_embedding = self._embedding(1)
        updated_profile = speaker_profile_manager.update_speaker(
            speaker_id, new_embedding, name="Updated Name"
        )

        assert updated_profile.name == "Updated Name"
        assert np.array_equal(updated_profile.embedding, new_embedding)

    def test_remove_speaker(self, speaker_profile_manager):
        """remove_speaker reports success and forgets the profile."""
        speaker_id = "test_speaker"
        speaker_profile_manager.add_speaker(speaker_id, self._embedding(0))

        success = speaker_profile_manager.remove_speaker(speaker_id)
        assert success
        assert speaker_id not in speaker_profile_manager.profiles

    def test_get_profile_stats(self, speaker_profile_manager):
        """get_profile_stats exposes total and with-embedding profile counts."""
        stats = speaker_profile_manager.get_profile_stats()
        assert "total_profiles" in stats
        assert "profiles_with_embeddings" in stats

    def test_cleanup(self, speaker_profile_manager):
        """cleanup() on a fresh manager completes without raising."""
        speaker_profile_manager.cleanup()
        # Should not raise any exceptions
|
|
|
|
|
|
class TestParallelProcessor:
    """Exercise ParallelProcessor: configuration, per-file and batch runs, stats, cleanup."""

    def test_initialization(self, parallel_processor):
        """The fixture's worker limit is applied and an executor plus stats exist."""
        proc = parallel_processor
        assert proc.config.max_workers == 2
        assert proc.executor is not None
        assert len(proc.stats) >= 1

    @patch.object(ParallelProcessor, '_initialize_services')
    def test_process_file_success(self, mock_init_services, parallel_processor, sample_audio_path):
        """process_file succeeds when both mocked services return results."""
        diarization_result = Mock()
        transcription_result = Mock()

        # Swap in mocks whose processing calls return the canned results.
        parallel_processor.diarization_manager = Mock(
            **{"process_audio.return_value": diarization_result}
        )
        parallel_processor.transcription_service = Mock(
            **{"transcribe_file.return_value": transcription_result}
        )

        outcome = parallel_processor.process_file(sample_audio_path)

        assert isinstance(outcome, ProcessingResult)
        assert outcome.success
        assert outcome.task_id is not None

    def test_process_file_not_found(self, parallel_processor):
        """process_file raises when given a path that does not exist."""
        missing = Path("nonexistent.wav")
        with pytest.raises(Exception):
            parallel_processor.process_file(missing)

    def test_process_batch(self, parallel_processor):
        """process_batch yields one result per input path."""
        audio_paths = [
            Path("tests") / file_name
            for file_name in ("sample_5s.wav", "sample_30s.mp3")
        ]

        with patch.object(parallel_processor, 'process_file',
                          return_value=ProcessingResult(task_id="test", success=True)):
            results = parallel_processor.process_batch(audio_paths)
            assert len(results) == 2

    def test_get_processing_stats(self, parallel_processor):
        """The stats mapping exposes the processed-file count and success rate."""
        stats = parallel_processor.get_processing_stats()
        for expected_key in ("total_files_processed", "success_rate"):
            assert expected_key in stats

    def test_estimate_speedup(self, parallel_processor):
        """estimate_speedup(10.0, 5.0) reports a factor of 2.0."""
        assert parallel_processor.estimate_speedup(10.0, 5.0) == 2.0

    def test_cleanup(self, parallel_processor):
        """cleanup() on a fresh processor completes without raising."""
        parallel_processor.cleanup()
        # Should not raise any exceptions
|
|
|
|
|
|
class TestIntegration:
    """Integration tests for the diarization pipeline."""

    def test_full_pipeline_integration(self, tmp_path):
        """The three services can be constructed side by side and torn down.

        Running real audio through the pipeline would need actual audio files
        and model weights, so this only checks that the components can be
        instantiated together and released again. The profile manager is
        pointed at ``tmp_path`` so the test never writes into its default
        storage location. Every service is cleaned up in ``finally`` so
        resources such as the processor's executor do not outlive the test.
        """
        diarization_manager = DiarizationManager()
        profile_manager = SpeakerProfileManager(storage_dir=tmp_path / "profiles")
        parallel_processor = ParallelProcessor()

        try:
            assert isinstance(diarization_manager, DiarizationManager)
            assert isinstance(profile_manager, SpeakerProfileManager)
            assert isinstance(parallel_processor, ParallelProcessor)
        finally:
            parallel_processor.cleanup()
            profile_manager.cleanup()
            diarization_manager.cleanup()

    def test_memory_optimization(self):
        """A memory_optimization=True config is kept by the manager."""
        config = DiarizationConfig(memory_optimization=True)
        manager = DiarizationManager(config)

        assert manager.config.memory_optimization is True
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (the return value was previously dropped).
    raise SystemExit(pytest.main([__file__, "-v"]))
|