"""Tests for the merging service.""" import pytest from pathlib import Path from unittest.mock import Mock from src.services.merging_service import MergingService, MergedSegment, MergingResult from src.services.diarization_types import ( DiarizationResult, SpeakerSegment, MergingConfig, MergingError ) class TestMergingService: """Test cases for MergingService.""" @pytest.fixture def merging_service(self): """Create a MergingService instance for testing.""" config = MergingConfig( min_overlap_ratio=0.5, min_confidence_threshold=0.3, min_segment_duration=0.5, conflict_threshold=0.1 ) return MergingService(config) @pytest.fixture def sample_diarization_result(self): """Create a sample diarization result.""" segments = [ SpeakerSegment(start=0.0, end=2.0, speaker_id="SPEAKER_00", confidence=0.9), SpeakerSegment(start=2.0, end=4.0, speaker_id="SPEAKER_01", confidence=0.8), SpeakerSegment(start=4.0, end=6.0, speaker_id="SPEAKER_00", confidence=0.85), ] return DiarizationResult( segments=segments, speaker_count=2, processing_time=1.5, confidence_score=0.85, model_used="pyannote/speaker-diarization-3.0", audio_duration=6.0 ) @pytest.fixture def sample_transcription_result(self): """Create a sample transcription result.""" return { "segments": [ {"start": 0.0, "end": 2.0, "text": "Hello, how are you?", "confidence": 0.95}, {"start": 2.0, "end": 4.0, "text": "I'm doing well, thank you.", "confidence": 0.92}, {"start": 4.0, "end": 6.0, "text": "That's great to hear.", "confidence": 0.88}, ], "accuracy_estimate": 0.92, "word_count": 12 } def test_initialization(self, merging_service): """Test MergingService initialization.""" assert merging_service.config.min_overlap_ratio == 0.5 assert merging_service.config.min_confidence_threshold == 0.3 assert merging_service.config.min_segment_duration == 0.5 assert merging_service.config.conflict_threshold == 0.1 def test_merge_results_success(self, merging_service, sample_diarization_result, sample_transcription_result): """Test successful merging of diarization and transcription results.""" result = merging_service.merge_results(sample_diarization_result, sample_transcription_result) assert isinstance(result, MergingResult) assert len(result.segments) == 3 assert result.speaker_count == 2 assert result.total_duration == 6.0 assert result.confidence_score > 0.0 # Check first segment first_segment = result.segments[0] assert first_segment.speaker_id == "SPEAKER_00" assert first_segment.text == "Hello, how are you?" assert first_segment.start == 0.0 assert first_segment.end == 2.0 assert first_segment.confidence > 0.0 def test_merge_results_no_transcription_segments(self, merging_service, sample_diarization_result): """Test merging with no transcription segments.""" transcription_result = {"segments": [], "accuracy_estimate": 0.0} with pytest.raises(MergingError, match="No transcription segments found"): merging_service.merge_results(sample_diarization_result, transcription_result) def test_merge_results_empty_text_segments(self, merging_service, sample_diarization_result): """Test merging with empty text segments.""" transcription_result = { "segments": [ {"start": 0.0, "end": 2.0, "text": "", "confidence": 0.95}, {"start": 2.0, "end": 4.0, "text": " ", "confidence": 0.92}, ], "accuracy_estimate": 0.92 } result = merging_service.merge_results(sample_diarization_result, transcription_result) assert len(result.segments) == 0 # Empty segments should be filtered out def test_overlapping_speakers_detection(self, merging_service): """Test detection of overlapping speakers.""" diarization_segments = [ SpeakerSegment(start=0.0, end=3.0, speaker_id="SPEAKER_00", confidence=0.9), SpeakerSegment(start=2.0, end=5.0, speaker_id="SPEAKER_01", confidence=0.8), ] overlapping = merging_service._find_overlapping_speakers(1.0, 4.0, diarization_segments) assert len(overlapping) == 2 assert overlapping[0]["speaker_id"] == "SPEAKER_00" assert overlapping[1]["speaker_id"] == "SPEAKER_01" assert overlapping[0]["overlap_ratio"] > 0.5 assert overlapping[1]["overlap_ratio"] > 0.5 def test_speaker_assignment_resolution(self, merging_service): """Test speaker assignment resolution with conflicts.""" overlapping_speakers = [ { "speaker_id": "SPEAKER_00", "confidence": 0.9, "overlap_ratio": 0.8, "overlap_duration": 1.6, "segment_start": 0.0, "segment_end": 2.0 }, { "speaker_id": "SPEAKER_01", "confidence": 0.85, "overlap_ratio": 0.7, "overlap_duration": 1.4, "segment_start": 0.0, "segment_end": 2.0 } ] speaker_id, confidence, overlap_ratio = merging_service._resolve_speaker_assignment( overlapping_speakers, 0.0, 2.0 ) assert speaker_id == "SPEAKER_00" # Higher weighted score assert confidence == 0.9 assert overlap_ratio == 0.8 def test_speaker_assignment_no_overlap(self, merging_service): """Test speaker assignment when no speakers overlap.""" speaker_id, confidence, overlap_ratio = merging_service._resolve_speaker_assignment( [], 0.0, 2.0 ) assert speaker_id is None assert confidence == 0.0 assert overlap_ratio == 0.0 def test_post_processing_short_segments(self, merging_service): """Test post-processing of very short segments.""" segments = [ MergedSegment( start=0.0, end=0.3, text="Hi", speaker_id="SPEAKER_00", confidence=0.8, transcription_confidence=0.9, diarization_confidence=0.8, overlap_ratio=0.8, overlapping_speakers=[] ), MergedSegment( start=0.3, end=2.0, text="Hello there", speaker_id="SPEAKER_00", confidence=0.9, transcription_confidence=0.95, diarization_confidence=0.9, overlap_ratio=0.9, overlapping_speakers=[] ), MergedSegment( start=2.0, end=4.0, text="How are you?", speaker_id="SPEAKER_01", confidence=0.85, transcription_confidence=0.92, diarization_confidence=0.85, overlap_ratio=0.85, overlapping_speakers=[] ) ] processed = merging_service._post_process_segments(segments) # Short segment should be merged with the next one assert len(processed) == 2 assert processed[0].text == "Hi Hello there" assert processed[0].end == 2.0 assert processed[0].start == 0.0 assert processed[1].text == "How are you?" assert processed[1].start == 2.0 def test_post_processing_low_confidence(self, merging_service): """Test post-processing of low-confidence segments.""" segments = [ MergedSegment( start=0.0, end=2.0, text="Hello", speaker_id="SPEAKER_00", confidence=0.2, transcription_confidence=0.9, diarization_confidence=0.2, overlap_ratio=0.2, overlapping_speakers=[] ) ] processed = merging_service._post_process_segments(segments) assert processed[0].speaker_id == "unknown" assert processed[0].confidence == 0.0 def test_confidence_calculation(self, merging_service): """Test overall confidence calculation.""" segments = [ MergedSegment( start=0.0, end=2.0, text="Hello", speaker_id="SPEAKER_00", confidence=0.8, transcription_confidence=0.9, diarization_confidence=0.8, overlap_ratio=0.8, overlapping_speakers=[] ), MergedSegment( start=2.0, end=4.0, text="Hi there", speaker_id="SPEAKER_01", confidence=0.9, transcription_confidence=0.95, diarization_confidence=0.9, overlap_ratio=0.9, overlapping_speakers=[] ) ] confidence = merging_service._calculate_overall_confidence(segments) # Should be weighted average based on duration assert 0.8 < confidence < 0.9 def test_confidence_calculation_empty(self, merging_service): """Test confidence calculation with empty segments.""" confidence = merging_service._calculate_overall_confidence([]) assert confidence == 0.0 def test_metadata_generation(self, merging_service, sample_diarization_result, sample_transcription_result): """Test metadata generation.""" result = merging_service.merge_results(sample_diarization_result, sample_transcription_result) metadata = result.metadata assert metadata["speaker_count_merged"] == 2 assert metadata["speaker_count_original"] == 2 assert metadata["total_words"] > 0 assert metadata["total_segments"] == 3 assert metadata["average_segment_duration"] > 0.0 assert metadata["unknown_speaker_segments"] == 0 assert metadata["diarization_confidence"] == 0.85 assert metadata["transcription_confidence"] == 0.92 assert "merging_config" in metadata def test_merge_with_unknown_speakers(self, merging_service): """Test merging when some segments have unknown speakers.""" diarization_result = DiarizationResult( segments=[ SpeakerSegment(start=0.0, end=2.0, speaker_id="SPEAKER_00", confidence=0.9), SpeakerSegment(start=2.0, end=4.0, speaker_id="SPEAKER_01", confidence=0.2), # Low confidence ], speaker_count=2, processing_time=1.0, confidence_score=0.55, model_used="test", audio_duration=4.0 ) transcription_result = { "segments": [ {"start": 0.0, "end": 2.0, "text": "Hello", "confidence": 0.95}, {"start": 2.0, "end": 4.0, "text": "Hi there", "confidence": 0.92}, ], "accuracy_estimate": 0.93 } result = merging_service.merge_results(diarization_result, transcription_result) # Second segment should have unknown speaker due to low confidence assert result.segments[1].speaker_id == "unknown" assert result.metadata["unknown_speaker_segments"] == 1 class TestMergingConfig: """Test cases for MergingConfig.""" def test_default_config(self): """Test MergingConfig default values.""" config = MergingConfig() assert config.min_overlap_ratio == 0.5 assert config.min_confidence_threshold == 0.3 assert config.min_segment_duration == 0.5 assert config.conflict_threshold == 0.1 assert config.enable_post_processing is True assert config.merge_short_segments is True assert config.unknown_speaker_label == "unknown" def test_custom_config(self): """Test MergingConfig with custom values.""" config = MergingConfig( min_overlap_ratio=0.7, min_confidence_threshold=0.5, min_segment_duration=1.0, conflict_threshold=0.2, enable_post_processing=False, merge_short_segments=False, unknown_speaker_label="UNKNOWN" ) assert config.min_overlap_ratio == 0.7 assert config.min_confidence_threshold == 0.5 assert config.min_segment_duration == 1.0 assert config.conflict_threshold == 0.2 assert config.enable_post_processing is False assert config.merge_short_segments is False assert config.unknown_speaker_label == "UNKNOWN"