# trax/tests/test_merging_service.py
"""Tests for the merging service."""
import pytest
from pathlib import Path
from unittest.mock import Mock
from src.services.merging_service import MergingService, MergedSegment, MergingResult
from src.services.diarization_types import (
    DiarizationResult, SpeakerSegment, MergingConfig, MergingError
)


class TestMergingService:
    """Test cases for MergingService."""

    @pytest.fixture
    def merging_service(self):
        """Create a MergingService instance for testing."""
        config = MergingConfig(
            min_overlap_ratio=0.5,
            min_confidence_threshold=0.3,
            min_segment_duration=0.5,
            conflict_threshold=0.1
        )
        return MergingService(config)

    @pytest.fixture
    def sample_diarization_result(self):
        """Create a sample diarization result."""
        segments = [
            SpeakerSegment(start=0.0, end=2.0, speaker_id="SPEAKER_00", confidence=0.9),
            SpeakerSegment(start=2.0, end=4.0, speaker_id="SPEAKER_01", confidence=0.8),
            SpeakerSegment(start=4.0, end=6.0, speaker_id="SPEAKER_00", confidence=0.85),
        ]
        return DiarizationResult(
            segments=segments,
            speaker_count=2,
            processing_time=1.5,
            confidence_score=0.85,
            model_used="pyannote/speaker-diarization-3.0",
            audio_duration=6.0
        )

    @pytest.fixture
    def sample_transcription_result(self):
        """Create a sample transcription result."""
        return {
            "segments": [
                {"start": 0.0, "end": 2.0, "text": "Hello, how are you?", "confidence": 0.95},
                {"start": 2.0, "end": 4.0, "text": "I'm doing well, thank you.", "confidence": 0.92},
                {"start": 4.0, "end": 6.0, "text": "That's great to hear.", "confidence": 0.88},
            ],
            "accuracy_estimate": 0.92,
            "word_count": 12
        }

    def test_initialization(self, merging_service):
        """Test MergingService initialization."""
        assert merging_service.config.min_overlap_ratio == 0.5
        assert merging_service.config.min_confidence_threshold == 0.3
        assert merging_service.config.min_segment_duration == 0.5
        assert merging_service.config.conflict_threshold == 0.1

    def test_merge_results_success(self, merging_service, sample_diarization_result, sample_transcription_result):
        """Test successful merging of diarization and transcription results."""
        result = merging_service.merge_results(sample_diarization_result, sample_transcription_result)
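        # The fixtures line up exactly: each transcription segment falls within a
        # single-speaker diarization segment, so merging should simply attach a
        # speaker label to each of the three transcribed segments.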
        assert isinstance(result, MergingResult)
        assert len(result.segments) == 3
        assert result.speaker_count == 2
        assert result.total_duration == 6.0
        assert result.confidence_score > 0.0

        # Check first segment
        first_segment = result.segments[0]
        assert first_segment.speaker_id == "SPEAKER_00"
        assert first_segment.text == "Hello, how are you?"
        assert first_segment.start == 0.0
        assert first_segment.end == 2.0
        assert first_segment.confidence > 0.0

    def test_merge_results_no_transcription_segments(self, merging_service, sample_diarization_result):
        """Test merging with no transcription segments."""
        transcription_result = {"segments": [], "accuracy_estimate": 0.0}
        with pytest.raises(MergingError, match="No transcription segments found"):
            merging_service.merge_results(sample_diarization_result, transcription_result)

    def test_merge_results_empty_text_segments(self, merging_service, sample_diarization_result):
        """Test merging with empty text segments."""
        transcription_result = {
            "segments": [
                {"start": 0.0, "end": 2.0, "text": "", "confidence": 0.95},
                {"start": 2.0, "end": 4.0, "text": " ", "confidence": 0.92},
            ],
            "accuracy_estimate": 0.92
        }
        result = merging_service.merge_results(sample_diarization_result, transcription_result)
        assert len(result.segments) == 0  # Empty segments should be filtered out

    def test_overlapping_speakers_detection(self, merging_service):
        """Test detection of overlapping speakers."""
        diarization_segments = [
            SpeakerSegment(start=0.0, end=3.0, speaker_id="SPEAKER_00", confidence=0.9),
            SpeakerSegment(start=2.0, end=5.0, speaker_id="SPEAKER_01", confidence=0.8),
        ]
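        # The 1.0-4.0 query window overlaps SPEAKER_00 (0.0-3.0) for 2.0 s and
        # SPEAKER_01 (2.0-5.0) for 2.0 s; whether the ratio is taken relative to
        # the window or to the speaker segment (both 3.0 s long), it works out to
        # 2.0 / 3.0 ≈ 0.67, comfortably above the asserted 0.5.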
        overlapping = merging_service._find_overlapping_speakers(1.0, 4.0, diarization_segments)
        assert len(overlapping) == 2
        assert overlapping[0]["speaker_id"] == "SPEAKER_00"
        assert overlapping[1]["speaker_id"] == "SPEAKER_01"
        assert overlapping[0]["overlap_ratio"] > 0.5
        assert overlapping[1]["overlap_ratio"] > 0.5

    def test_speaker_assignment_resolution(self, merging_service):
        """Test speaker assignment resolution with conflicts."""
        overlapping_speakers = [
            {
                "speaker_id": "SPEAKER_00",
                "confidence": 0.9,
                "overlap_ratio": 0.8,
                "overlap_duration": 1.6,
                "segment_start": 0.0,
                "segment_end": 2.0
            },
            {
                "speaker_id": "SPEAKER_01",
                "confidence": 0.85,
                "overlap_ratio": 0.7,
                "overlap_duration": 1.4,
                "segment_start": 0.0,
                "segment_end": 2.0
            }
        ]
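        # Assuming the resolver scores each candidate by confidence weighted by
        # overlap ratio, SPEAKER_00 scores 0.9 * 0.8 = 0.72 against roughly
        # 0.85 * 0.7 ≈ 0.60 for SPEAKER_01, so SPEAKER_00 should be chosen.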
        speaker_id, confidence, overlap_ratio = merging_service._resolve_speaker_assignment(
            overlapping_speakers, 0.0, 2.0
        )
        assert speaker_id == "SPEAKER_00"  # Higher weighted score
        assert confidence == 0.9
        assert overlap_ratio == 0.8

    def test_speaker_assignment_no_overlap(self, merging_service):
        """Test speaker assignment when no speakers overlap."""
        speaker_id, confidence, overlap_ratio = merging_service._resolve_speaker_assignment(
            [], 0.0, 2.0
        )
        assert speaker_id is None
        assert confidence == 0.0
        assert overlap_ratio == 0.0

    def test_post_processing_short_segments(self, merging_service):
        """Test post-processing of very short segments."""
        segments = [
            MergedSegment(
                start=0.0, end=0.3, text="Hi", speaker_id="SPEAKER_00",
                confidence=0.8, transcription_confidence=0.9, diarization_confidence=0.8,
                overlap_ratio=0.8, overlapping_speakers=[]
            ),
            MergedSegment(
                start=0.3, end=2.0, text="Hello there", speaker_id="SPEAKER_00",
                confidence=0.9, transcription_confidence=0.95, diarization_confidence=0.9,
                overlap_ratio=0.9, overlapping_speakers=[]
            ),
            MergedSegment(
                start=2.0, end=4.0, text="How are you?", speaker_id="SPEAKER_01",
                confidence=0.85, transcription_confidence=0.92, diarization_confidence=0.85,
                overlap_ratio=0.85, overlapping_speakers=[]
            )
        ]
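        # The first segment lasts only 0.3 s, below the configured
        # min_segment_duration of 0.5 s, so with short-segment merging enabled
        # (the default) it should be folded into the adjacent same-speaker segment.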
        processed = merging_service._post_process_segments(segments)

        # Short segment should be merged with the next one
        assert len(processed) == 2
        assert processed[0].text == "Hi Hello there"
        assert processed[0].end == 2.0
        assert processed[0].start == 0.0
        assert processed[1].text == "How are you?"
        assert processed[1].start == 2.0

    def test_post_processing_low_confidence(self, merging_service):
        """Test post-processing of low-confidence segments."""
        segments = [
            MergedSegment(
                start=0.0, end=2.0, text="Hello", speaker_id="SPEAKER_00",
                confidence=0.2, transcription_confidence=0.9, diarization_confidence=0.2,
                overlap_ratio=0.2, overlapping_speakers=[]
            )
        ]
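        # The segment's confidence of 0.2 falls below the configured
        # min_confidence_threshold of 0.3, so post-processing is expected to
        # relabel it with the unknown-speaker label and zero out its confidence.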
        processed = merging_service._post_process_segments(segments)
        assert processed[0].speaker_id == "unknown"
        assert processed[0].confidence == 0.0

    def test_confidence_calculation(self, merging_service):
        """Test overall confidence calculation."""
        segments = [
            MergedSegment(
                start=0.0, end=2.0, text="Hello", speaker_id="SPEAKER_00",
                confidence=0.8, transcription_confidence=0.9, diarization_confidence=0.8,
                overlap_ratio=0.8, overlapping_speakers=[]
            ),
            MergedSegment(
                start=2.0, end=4.0, text="Hi there", speaker_id="SPEAKER_01",
                confidence=0.9, transcription_confidence=0.95, diarization_confidence=0.9,
                overlap_ratio=0.9, overlapping_speakers=[]
            )
        ]
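        # Both segments are 2.0 s long, so a duration-weighted average works out
        # to (0.8 * 2.0 + 0.9 * 2.0) / 4.0 = 0.85, inside the asserted range.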
        confidence = merging_service._calculate_overall_confidence(segments)

        # Should be weighted average based on duration
        assert 0.8 < confidence < 0.9

    def test_confidence_calculation_empty(self, merging_service):
        """Test confidence calculation with empty segments."""
        confidence = merging_service._calculate_overall_confidence([])
        assert confidence == 0.0

    def test_metadata_generation(self, merging_service, sample_diarization_result, sample_transcription_result):
        """Test metadata generation."""
        result = merging_service.merge_results(sample_diarization_result, sample_transcription_result)
        metadata = result.metadata

        assert metadata["speaker_count_merged"] == 2
        assert metadata["speaker_count_original"] == 2
        assert metadata["total_words"] > 0
        assert metadata["total_segments"] == 3
        assert metadata["average_segment_duration"] > 0.0
        assert metadata["unknown_speaker_segments"] == 0
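        # These two values are expected to mirror the fixtures: 0.85 is the
        # diarization confidence_score and 0.92 the transcription accuracy_estimate.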
        assert metadata["diarization_confidence"] == 0.85
        assert metadata["transcription_confidence"] == 0.92
        assert "merging_config" in metadata

    def test_merge_with_unknown_speakers(self, merging_service):
        """Test merging when some segments have unknown speakers."""
        diarization_result = DiarizationResult(
            segments=[
                SpeakerSegment(start=0.0, end=2.0, speaker_id="SPEAKER_00", confidence=0.9),
                SpeakerSegment(start=2.0, end=4.0, speaker_id="SPEAKER_01", confidence=0.2),  # Low confidence
            ],
            speaker_count=2,
            processing_time=1.0,
            confidence_score=0.55,
            model_used="test",
            audio_duration=4.0
        )
        transcription_result = {
            "segments": [
                {"start": 0.0, "end": 2.0, "text": "Hello", "confidence": 0.95},
                {"start": 2.0, "end": 4.0, "text": "Hi there", "confidence": 0.92},
            ],
            "accuracy_estimate": 0.93
        }
        result = merging_service.merge_results(diarization_result, transcription_result)

        # Second segment should have unknown speaker due to low confidence
        assert result.segments[1].speaker_id == "unknown"
        assert result.metadata["unknown_speaker_segments"] == 1


class TestMergingConfig:
    """Test cases for MergingConfig."""

    def test_default_config(self):
        """Test MergingConfig default values."""
        config = MergingConfig()
        assert config.min_overlap_ratio == 0.5
        assert config.min_confidence_threshold == 0.3
        assert config.min_segment_duration == 0.5
        assert config.conflict_threshold == 0.1
        assert config.enable_post_processing is True
        assert config.merge_short_segments is True
        assert config.unknown_speaker_label == "unknown"

    def test_custom_config(self):
        """Test MergingConfig with custom values."""
        config = MergingConfig(
            min_overlap_ratio=0.7,
            min_confidence_threshold=0.5,
            min_segment_duration=1.0,
            conflict_threshold=0.2,
            enable_post_processing=False,
            merge_short_segments=False,
            unknown_speaker_label="UNKNOWN"
        )
        assert config.min_overlap_ratio == 0.7
        assert config.min_confidence_threshold == 0.5
        assert config.min_segment_duration == 1.0
        assert config.conflict_threshold == 0.2
        assert config.enable_post_processing is False
        assert config.merge_short_segments is False
        assert config.unknown_speaker_label == "UNKNOWN"
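
# These tests can be run directly with `pytest trax/tests/test_merging_service.py -v`,
# assuming the repository's src/ package is importable (e.g. via a conftest.py or
# a pythonpath setting in the pytest configuration).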