"""Tests for the merging service."""

from pathlib import Path
from unittest.mock import Mock

import pytest

from src.services.diarization_types import (
    DiarizationResult,
    MergingConfig,
    MergingError,
    SpeakerSegment,
)
from src.services.merging_service import MergedSegment, MergingResult, MergingService


class TestMergingService:
    """Test cases for MergingService."""

    @pytest.fixture
    def merging_service(self):
        """Create a MergingService instance for testing.

        The thresholds are spelled out explicitly so the tests below do not
        silently depend on whatever MergingConfig's defaults happen to be.
        """
        config = MergingConfig(
            min_overlap_ratio=0.5,
            min_confidence_threshold=0.3,
            min_segment_duration=0.5,
            conflict_threshold=0.1,
        )
        return MergingService(config)

    @pytest.fixture
    def sample_diarization_result(self):
        """Create a sample diarization result.

        Three back-to-back 2-second speaker turns (two distinct speakers)
        covering 0.0-6.0 s, all with confidence well above the fixture's
        ``min_confidence_threshold`` of 0.3.
        """
        segments = [
            SpeakerSegment(start=0.0, end=2.0, speaker_id="SPEAKER_00", confidence=0.9),
            SpeakerSegment(start=2.0, end=4.0, speaker_id="SPEAKER_01", confidence=0.8),
            SpeakerSegment(start=4.0, end=6.0, speaker_id="SPEAKER_00", confidence=0.85),
        ]
        return DiarizationResult(
            segments=segments,
            speaker_count=2,
            processing_time=1.5,
            confidence_score=0.85,
            model_used="pyannote/speaker-diarization-3.0",
            audio_duration=6.0,
        )

    @pytest.fixture
    def sample_transcription_result(self):
        """Create a sample transcription result.

        The three text segments share their time boundaries with the speaker
        turns in ``sample_diarization_result``.
        """
        return {
            "segments": [
                {
                    "start": 0.0,
                    "end": 2.0,
                    "text": "Hello, how are you?",
                    "confidence": 0.95,
                },
                {
                    "start": 2.0,
                    "end": 4.0,
                    "text": "I'm doing well, thank you.",
                    "confidence": 0.92,
                },
                {
                    "start": 4.0,
                    "end": 6.0,
                    "text": "That's great to hear.",
                    "confidence": 0.88,
                },
            ],
            "accuracy_estimate": 0.92,
            "word_count": 12,
        }

    def test_initialization(self, merging_service):
        """Test MergingService initialization.

        The service must expose the config it was constructed with; the
        values are stored literals, so exact equality is appropriate here.
        """
        assert merging_service.config.min_overlap_ratio == 0.5
        assert merging_service.config.min_confidence_threshold == 0.3
        assert merging_service.config.min_segment_duration == 0.5
        assert merging_service.config.conflict_threshold == 0.1

    def test_merge_results_success(
        self, merging_service, sample_diarization_result, sample_transcription_result
    ):
        """Test successful merging of diarization and transcription results."""
        result = merging_service.merge_results(
            sample_diarization_result, sample_transcription_result
        )

        assert isinstance(result, MergingResult)
        assert len(result.segments) == 3
        assert result.speaker_count == 2
        # The duration may be derived arithmetically from segment bounds, so
        # compare with a tolerance instead of exact float equality.
        assert result.total_duration == pytest.approx(6.0)
        assert result.confidence_score > 0.0

        # Check first segment
        first_segment = result.segments[0]
        assert first_segment.speaker_id == "SPEAKER_00"
        assert first_segment.text == "Hello, how are you?"
        assert first_segment.start == pytest.approx(0.0)
        assert first_segment.end == pytest.approx(2.0)
        assert first_segment.confidence > 0.0

    def test_merge_results_no_transcription_segments(
        self, merging_service, sample_diarization_result
    ):
        """Test merging with no transcription segments.

        An empty ``segments`` list is expected to raise MergingError with a
        message matching "No transcription segments found".
        """
        transcription_result = {"segments": [], "accuracy_estimate": 0.0}

        with pytest.raises(MergingError, match="No transcription segments found"):
            merging_service.merge_results(sample_diarization_result, transcription_result)

    def test_merge_results_empty_text_segments(
        self, merging_service, sample_diarization_result
    ):
        """Test merging with empty text segments.

        Covers both a truly empty string and a whitespace-only string.
        """
        transcription_result = {
            "segments": [
                {"start": 0.0, "end": 2.0, "text": "", "confidence": 0.95},
                {"start": 2.0, "end": 4.0, "text": "   ", "confidence": 0.92},
            ],
            "accuracy_estimate": 0.92,
        }

        result = merging_service.merge_results(sample_diarization_result, transcription_result)

        # Empty segments should be filtered out
        assert len(result.segments) == 0

    def test_overlapping_speakers_detection(self, merging_service):
        """Test detection of overlapping speakers.

        Both speaker turns share 2.0 s with the queried 1.0-4.0 s window, so
        both are expected back, in input order, each with a ratio above 0.5.
        """
        diarization_segments = [
            SpeakerSegment(start=0.0, end=3.0, speaker_id="SPEAKER_00", confidence=0.9),
            SpeakerSegment(start=2.0, end=5.0, speaker_id="SPEAKER_01", confidence=0.8),
        ]

        overlapping = merging_service._find_overlapping_speakers(
            1.0, 4.0, diarization_segments
        )

        assert len(overlapping) == 2
        assert overlapping[0]["speaker_id"] == "SPEAKER_00"
        assert overlapping[1]["speaker_id"] == "SPEAKER_01"
        assert overlapping[0]["overlap_ratio"] > 0.5
        assert overlapping[1]["overlap_ratio"] > 0.5

    def test_speaker_assignment_resolution(self, merging_service):
        """Test speaker assignment resolution with conflicts.

        Two candidates compete for the same 0.0-2.0 s window; SPEAKER_00 has
        both the higher confidence and the higher overlap ratio.
        """
        overlapping_speakers = [
            {
                "speaker_id": "SPEAKER_00",
                "confidence": 0.9,
                "overlap_ratio": 0.8,
                "overlap_duration": 1.6,
                "segment_start": 0.0,
                "segment_end": 2.0,
            },
            {
                "speaker_id": "SPEAKER_01",
                "confidence": 0.85,
                "overlap_ratio": 0.7,
                "overlap_duration": 1.4,
                "segment_start": 0.0,
                "segment_end": 2.0,
            },
        ]

        speaker_id, confidence, overlap_ratio = merging_service._resolve_speaker_assignment(
            overlapping_speakers, 0.0, 2.0
        )

        assert speaker_id == "SPEAKER_00"  # Higher weighted score
        # Tolerant comparison: the winner's values may pass through float
        # arithmetic on their way out of the resolver.
        assert confidence == pytest.approx(0.9)
        assert overlap_ratio == pytest.approx(0.8)

    def test_speaker_assignment_no_overlap(self, merging_service):
        """Test speaker assignment when no speakers overlap."""
        speaker_id, confidence, overlap_ratio = merging_service._resolve_speaker_assignment(
            [], 0.0, 2.0
        )

        assert speaker_id is None
        assert confidence == 0.0
        assert overlap_ratio == 0.0

    def test_post_processing_short_segments(self, merging_service):
        """Test post-processing of very short segments.

        The first segment lasts 0.3 s, below the fixture's
        ``min_segment_duration`` of 0.5 s, and is followed by a segment from
        the same speaker.
        """
        segments = [
            MergedSegment(
                start=0.0,
                end=0.3,
                text="Hi",
                speaker_id="SPEAKER_00",
                confidence=0.8,
                transcription_confidence=0.9,
                diarization_confidence=0.8,
                overlap_ratio=0.8,
                overlapping_speakers=[],
            ),
            MergedSegment(
                start=0.3,
                end=2.0,
                text="Hello there",
                speaker_id="SPEAKER_00",
                confidence=0.9,
                transcription_confidence=0.95,
                diarization_confidence=0.9,
                overlap_ratio=0.9,
                overlapping_speakers=[],
            ),
            MergedSegment(
                start=2.0,
                end=4.0,
                text="How are you?",
                speaker_id="SPEAKER_01",
                confidence=0.85,
                transcription_confidence=0.92,
                diarization_confidence=0.85,
                overlap_ratio=0.85,
                overlapping_speakers=[],
            ),
        ]

        processed = merging_service._post_process_segments(segments)

        # Short segment should be merged with the next one
        assert len(processed) == 2
        assert processed[0].text == "Hi Hello there"
        assert processed[0].end == 2.0
        assert processed[0].start == 0.0
        assert processed[1].text == "How are you?"
        assert processed[1].start == 2.0

    def test_post_processing_low_confidence(self, merging_service):
        """Test post-processing of low-confidence segments.

        The segment's confidence (0.2) is below the fixture's
        ``min_confidence_threshold`` of 0.3.
        """
        segments = [
            MergedSegment(
                start=0.0,
                end=2.0,
                text="Hello",
                speaker_id="SPEAKER_00",
                confidence=0.2,
                transcription_confidence=0.9,
                diarization_confidence=0.2,
                overlap_ratio=0.2,
                overlapping_speakers=[],
            )
        ]

        processed = merging_service._post_process_segments(segments)

        assert processed[0].speaker_id == "unknown"
        assert processed[0].confidence == 0.0

    def test_confidence_calculation(self, merging_service):
        """Test overall confidence calculation."""
        segments = [
            MergedSegment(
                start=0.0,
                end=2.0,
                text="Hello",
                speaker_id="SPEAKER_00",
                confidence=0.8,
                transcription_confidence=0.9,
                diarization_confidence=0.8,
                overlap_ratio=0.8,
                overlapping_speakers=[],
            ),
            MergedSegment(
                start=2.0,
                end=4.0,
                text="Hi there",
                speaker_id="SPEAKER_01",
                confidence=0.9,
                transcription_confidence=0.95,
                diarization_confidence=0.9,
                overlap_ratio=0.9,
                overlapping_speakers=[],
            ),
        ]

        confidence = merging_service._calculate_overall_confidence(segments)

        # Should be weighted average based on duration, hence strictly
        # between the two per-segment confidences.
        assert 0.8 < confidence < 0.9

    def test_confidence_calculation_empty(self, merging_service):
        """Test confidence calculation with empty segments."""
        confidence = merging_service._calculate_overall_confidence([])
        assert confidence == 0.0

    def test_metadata_generation(
        self, merging_service, sample_diarization_result, sample_transcription_result
    ):
        """Test metadata generation."""
        result = merging_service.merge_results(
            sample_diarization_result, sample_transcription_result
        )

        metadata = result.metadata
        assert metadata["speaker_count_merged"] == 2
        assert metadata["speaker_count_original"] == 2
        assert metadata["total_words"] > 0
        assert metadata["total_segments"] == 3
        assert metadata["average_segment_duration"] > 0.0
        assert metadata["unknown_speaker_segments"] == 0
        # Confidence figures are floats that may be recomputed rather than
        # copied verbatim, so avoid exact equality.
        assert metadata["diarization_confidence"] == pytest.approx(0.85)
        assert metadata["transcription_confidence"] == pytest.approx(0.92)
        assert "merging_config" in metadata

    def test_merge_with_unknown_speakers(self, merging_service):
        """Test merging when some segments have unknown speakers.

        The second speaker turn has confidence 0.2, below the fixture's
        ``min_confidence_threshold`` of 0.3.
        """
        diarization_result = DiarizationResult(
            segments=[
                SpeakerSegment(start=0.0, end=2.0, speaker_id="SPEAKER_00", confidence=0.9),
                # Low confidence
                SpeakerSegment(start=2.0, end=4.0, speaker_id="SPEAKER_01", confidence=0.2),
            ],
            speaker_count=2,
            processing_time=1.0,
            confidence_score=0.55,
            model_used="test",
            audio_duration=4.0,
        )

        transcription_result = {
            "segments": [
                {"start": 0.0, "end": 2.0, "text": "Hello", "confidence": 0.95},
                {"start": 2.0, "end": 4.0, "text": "Hi there", "confidence": 0.92},
            ],
            "accuracy_estimate": 0.93,
        }

        result = merging_service.merge_results(diarization_result, transcription_result)

        # Second segment should have unknown speaker due to low confidence
        assert result.segments[1].speaker_id == "unknown"
        assert result.metadata["unknown_speaker_segments"] == 1


class TestMergingConfig:
    """Test cases for MergingConfig."""

    def test_default_config(self):
        """Test MergingConfig default values.

        Pins the defaults of a config constructed with no arguments. The
        numeric fields are compared exactly because they are expected to be
        stored literals, not computed values.
        """
        config = MergingConfig()

        assert config.min_overlap_ratio == 0.5
        assert config.min_confidence_threshold == 0.3
        assert config.min_segment_duration == 0.5
        assert config.conflict_threshold == 0.1
        assert config.enable_post_processing is True
        assert config.merge_short_segments is True
        assert config.unknown_speaker_label == "unknown"

    def test_custom_config(self):
        """Test MergingConfig with custom values.

        Every field is set to a value that differs from the defaults checked
        in ``test_default_config`` so a field that ignores its argument
        cannot pass by coincidence.
        """
        config = MergingConfig(
            min_overlap_ratio=0.7,
            min_confidence_threshold=0.5,
            min_segment_duration=1.0,
            conflict_threshold=0.2,
            enable_post_processing=False,
            merge_short_segments=False,
            unknown_speaker_label="UNKNOWN",
        )

        assert config.min_overlap_ratio == 0.7
        assert config.min_confidence_threshold == 0.5
        assert config.min_segment_duration == 1.0
        assert config.conflict_threshold == 0.2
        assert config.enable_post_processing is False
        assert config.merge_short_segments is False
        assert config.unknown_speaker_label == "UNKNOWN"