"""Tests for diarization services.""" import pytest from pathlib import Path from unittest.mock import Mock, patch, MagicMock from src.services.diarization_types import ( DiarizationConfig, SpeakerSegment, DiarizationResult, SpeakerProfile, ProfileMatch, ProcessingResult ) from src.services.diarization_service import DiarizationManager from src.services.speaker_profile_manager import SpeakerProfileManager from src.services.parallel_processor import ParallelProcessor @pytest.fixture def sample_audio_path(): """Provide a sample audio file path for testing.""" return Path("tests/sample_5s.wav") @pytest.fixture def diarization_manager(): """Provide a DiarizationManager instance for testing.""" config = DiarizationConfig( model_path="pyannote/speaker-diarization-3.0", device="cpu", memory_optimization=False ) return DiarizationManager(config) @pytest.fixture def speaker_profile_manager(): """Provide a SpeakerProfileManager instance for testing.""" return SpeakerProfileManager(storage_dir=Path("tests/temp_profiles")) @pytest.fixture def parallel_processor(): """Provide a ParallelProcessor instance for testing.""" from src.services.diarization_types import ParallelProcessingConfig config = ParallelProcessingConfig(max_workers=2) return ParallelProcessor(config) class TestDiarizationManager: """Test cases for DiarizationManager.""" def test_initialization(self, diarization_manager): """Test DiarizationManager initialization.""" assert diarization_manager.config.model_path == "pyannote/speaker-diarization-3.0" assert diarization_manager._device in ["cpu", "cuda"] assert not diarization_manager._initialized @patch('src.services.diarization_utils.determine_device') def test_device_determination(self, mock_determine_device, diarization_manager): """Test device determination logic.""" mock_determine_device.return_value = "cpu" device = diarization_manager._device assert device == "cpu" @patch('pyannote.audio.Pipeline.from_pretrained') def test_pipeline_loading(self, mock_pipeline, diarization_manager): """Test pipeline loading with error handling.""" mock_pipeline.return_value = Mock() pipeline = diarization_manager._load_pipeline() assert pipeline is not None assert diarization_manager._initialized @patch('pyannote.audio.Pipeline.from_pretrained') def test_pipeline_loading_error(self, mock_pipeline, diarization_manager): """Test pipeline loading error handling.""" mock_pipeline.side_effect = Exception("Model loading failed") with pytest.raises(Exception): diarization_manager._load_pipeline() @patch.object(DiarizationManager, '_load_pipeline') def test_process_audio_success(self, mock_load_pipeline, diarization_manager, sample_audio_path): """Test successful audio processing.""" # Mock pipeline and annotation mock_pipeline = Mock() mock_annotation = Mock() mock_annotation.itertracks.return_value = [ (Mock(start=0.0, end=2.0, duration=2.0), 1, "SPEAKER_00"), (Mock(start=2.0, end=4.0, duration=2.0), 2, "SPEAKER_01") ] mock_pipeline.return_value = mock_annotation mock_load_pipeline.return_value = mock_pipeline result = diarization_manager.process_audio(sample_audio_path) assert isinstance(result, DiarizationResult) assert result.speaker_count == 2 assert len(result.segments) == 2 assert result.processing_time > 0 def test_process_audio_file_not_found(self, diarization_manager): """Test audio processing with non-existent file.""" with pytest.raises(FileNotFoundError): diarization_manager.process_audio(Path("nonexistent.wav")) def test_estimate_speaker_count(self, diarization_manager, sample_audio_path): """Test speaker count estimation.""" with patch.object(diarization_manager, 'process_audio') as mock_process: mock_result = Mock() mock_result.speaker_count = 3 mock_process.return_value = mock_result count = diarization_manager.estimate_speaker_count(sample_audio_path) assert count == 3 def test_get_speaker_segments(self, diarization_manager, sample_audio_path): """Test getting segments for specific speaker.""" with patch.object(diarization_manager, 'process_audio') as mock_process: mock_result = Mock() mock_result.segments = [ SpeakerSegment(0.0, 2.0, "SPEAKER_00", 0.8), SpeakerSegment(2.0, 4.0, "SPEAKER_01", 0.9) ] mock_process.return_value = mock_result segments = diarization_manager.get_speaker_segments(sample_audio_path, "SPEAKER_00") assert len(segments) == 1 assert segments[0].speaker_id == "SPEAKER_00" def test_cleanup(self, diarization_manager): """Test resource cleanup.""" diarization_manager._pipeline = Mock() diarization_manager._initialized = True diarization_manager.cleanup() assert diarization_manager._pipeline is None assert not diarization_manager._initialized class TestSpeakerProfileManager: """Test cases for SpeakerProfileManager.""" def test_initialization(self, speaker_profile_manager): """Test SpeakerProfileManager initialization.""" assert speaker_profile_manager.storage_dir.exists() assert len(speaker_profile_manager.profiles) == 0 assert speaker_profile_manager.similarity_threshold == 0.7 def test_add_speaker_success(self, speaker_profile_manager): """Test adding a speaker profile.""" import numpy as np speaker_id = "test_speaker" embedding = np.random.rand(512) profile = speaker_profile_manager.add_speaker(speaker_id, embedding, name="Test Speaker") assert profile.speaker_id == speaker_id assert profile.name == "Test Speaker" assert speaker_id in speaker_profile_manager.profiles assert speaker_id in speaker_profile_manager.embeddings_cache def test_add_speaker_validation_error(self, speaker_profile_manager): """Test adding speaker with invalid data.""" import numpy as np # Empty speaker ID with pytest.raises(Exception): speaker_profile_manager.add_speaker("", np.random.rand(512)) # Empty embedding with pytest.raises(Exception): speaker_profile_manager.add_speaker("test", np.array([])) def test_get_speaker(self, speaker_profile_manager): """Test getting a speaker profile.""" import numpy as np speaker_id = "test_speaker" embedding = np.random.rand(512) speaker_profile_manager.add_speaker(speaker_id, embedding) profile = speaker_profile_manager.get_speaker(speaker_id) assert profile is not None assert profile.speaker_id == speaker_id def test_find_similar_speakers(self, speaker_profile_manager): """Test finding similar speakers.""" import numpy as np # Add test profiles embedding1 = np.random.rand(512) embedding2 = np.random.rand(512) speaker_profile_manager.add_speaker("speaker1", embedding1) speaker_profile_manager.add_speaker("speaker2", embedding2) # Find similar speakers matches = speaker_profile_manager.find_similar_speakers(embedding1, threshold=0.5) assert len(matches) >= 1 def test_update_speaker(self, speaker_profile_manager): """Test updating a speaker profile.""" import numpy as np speaker_id = "test_speaker" embedding = np.random.rand(512) speaker_profile_manager.add_speaker(speaker_id, embedding) new_embedding = np.random.rand(512) updated_profile = speaker_profile_manager.update_speaker( speaker_id, new_embedding, name="Updated Name" ) assert updated_profile.name == "Updated Name" assert np.array_equal(updated_profile.embedding, new_embedding) def test_remove_speaker(self, speaker_profile_manager): """Test removing a speaker profile.""" import numpy as np speaker_id = "test_speaker" embedding = np.random.rand(512) speaker_profile_manager.add_speaker(speaker_id, embedding) # Remove speaker success = speaker_profile_manager.remove_speaker(speaker_id) assert success assert speaker_id not in speaker_profile_manager.profiles def test_get_profile_stats(self, speaker_profile_manager): """Test getting profile statistics.""" stats = speaker_profile_manager.get_profile_stats() assert "total_profiles" in stats assert "profiles_with_embeddings" in stats def test_cleanup(self, speaker_profile_manager): """Test cleanup method.""" speaker_profile_manager.cleanup() # Should not raise any exceptions class TestParallelProcessor: """Test cases for ParallelProcessor.""" def test_initialization(self, parallel_processor): """Test ParallelProcessor initialization.""" assert parallel_processor.config.max_workers == 2 assert parallel_processor.executor is not None assert len(parallel_processor.stats) > 0 @patch.object(ParallelProcessor, '_initialize_services') def test_process_file_success(self, mock_init_services, parallel_processor, sample_audio_path): """Test successful file processing.""" # Mock services parallel_processor.diarization_manager = Mock() parallel_processor.transcription_service = Mock() # Mock results mock_diarization_result = Mock() mock_transcription_result = Mock() parallel_processor.diarization_manager.process_audio.return_value = mock_diarization_result parallel_processor.transcription_service.transcribe_file.return_value = mock_transcription_result result = parallel_processor.process_file(sample_audio_path) assert isinstance(result, ProcessingResult) assert result.success assert result.task_id is not None def test_process_file_not_found(self, parallel_processor): """Test processing non-existent file.""" with pytest.raises(Exception): parallel_processor.process_file(Path("nonexistent.wav")) def test_process_batch(self, parallel_processor): """Test batch processing.""" audio_paths = [Path("tests/sample_5s.wav"), Path("tests/sample_30s.mp3")] with patch.object(parallel_processor, 'process_file') as mock_process: mock_process.return_value = ProcessingResult(task_id="test", success=True) results = parallel_processor.process_batch(audio_paths) assert len(results) == 2 def test_get_processing_stats(self, parallel_processor): """Test getting processing statistics.""" stats = parallel_processor.get_processing_stats() assert "total_files_processed" in stats assert "success_rate" in stats def test_estimate_speedup(self, parallel_processor): """Test speedup estimation.""" speedup = parallel_processor.estimate_speedup(10.0, 5.0) assert speedup == 2.0 def test_cleanup(self, parallel_processor): """Test cleanup method.""" parallel_processor.cleanup() # Should not raise any exceptions class TestIntegration: """Integration tests for the diarization pipeline.""" def test_full_pipeline_integration(self, sample_audio_path): """Test full pipeline integration.""" # This would require actual audio files and models # For now, we'll just test that the components can be instantiated together diarization_manager = DiarizationManager() profile_manager = SpeakerProfileManager() parallel_processor = ParallelProcessor() assert diarization_manager is not None assert profile_manager is not None assert parallel_processor is not None def test_memory_optimization(self): """Test memory optimization features.""" config = DiarizationConfig(memory_optimization=True) manager = DiarizationManager(config) assert manager.config.memory_optimization is True if __name__ == "__main__": pytest.main([__file__, "-v"])