# Listing metadata: 81 lines, 2.9 KiB, Python
"""Tests for merging and parallel orchestration (subtask 7.5)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
|
|
|
|
|
|
def _orig_segments() -> List[Dict[str, Any]]:
|
|
return [
|
|
{"start": 0.0, "end": 1.0, "text": "hello", "confidence": 0.9},
|
|
{"start": 1.1, "end": 2.2, "text": "wurld", "confidence": 0.4},
|
|
]
|
|
|
|
|
|
def _refined_segments() -> List[Dict[str, Any]]:
|
|
return [
|
|
{"start": 1.1, "end": 2.2, "text": "world", "confidence": 0.95},
|
|
]
|
|
|
|
|
|
def _diarization_segments() -> List[Dict[str, Any]]:
|
|
return [
|
|
{"start": 0.0, "end": 1.5, "speaker": "S1"},
|
|
{"start": 1.5, "end": 3.0, "speaker": "S2"},
|
|
]
|
|
|
|
|
|
def test_merge_transcription_results_replaces_low_conf_segment():
    """Merging refined segments keeps the count and swaps in the refined text.

    The refined segment shares the time range of the low-confidence
    original, so the merged output is expected to still hold two segments,
    with the second one reading "world" instead of "wurld".
    """
    pipeline = MultiPassTranscriptionPipeline()
    merged = pipeline._merge_transcription_results(
        _orig_segments(), _refined_segments()
    )

    assert len(merged) == 2
    assert merged[1]["text"] == "world"
|
|
|
|
|
|
def test_merge_with_diarization_assigns_speakers():
    """Merging with diarization turns attaches a speaker label to each segment."""
    pipeline = MultiPassTranscriptionPipeline()
    merged = pipeline._merge_with_diarization(
        _orig_segments(), _diarization_segments()
    )

    assert len(merged) == 2
    # First segment (0.0-1.0) lies entirely inside the S1 turn (0.0-1.5).
    assert merged[0]["speaker"] == "S1"
    # Second segment (1.1-2.2) overlaps both turns: 0.4 s with S1 and 0.7 s
    # with S2, so a maximum-overlap rule would pick S2. The check below is
    # intentionally loose and accepts either label.
    # NOTE(review): tighten to == "S2" once the merge rule is confirmed.
    assert merged[1]["speaker"] in {"S1", "S2"}
|
|
|
|
|
|
# NOTE: stacked @patch decorators are applied bottom-up, so the mock for the
# lowest decorator is passed first (mock_identify) and the mock for the
# topmost decorator is passed last (mock_diar).
@patch("src.services.multi_pass_transcription.DiarizationManager")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._perform_first_pass")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._perform_refinement_pass")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._calculate_confidence")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._identify_low_confidence_segments")
def test_parallel_orchestration(
    mock_identify: MagicMock,
    mock_calc: MagicMock,
    mock_refine: MagicMock,
    mock_first: MagicMock,
    mock_diar: MagicMock,
    tmp_path: Path,
):
    """Run ``transcribe_with_parallel_processing`` with every stage mocked.

    The first pass, confidence calculation, low-confidence identification
    and refinement pass are replaced with canned segment lists, and
    ``DiarizationManager`` is replaced with a mock whose ``process_audio``
    returns an object exposing ``segments``. The test then checks that the
    result contains a ``"transcript"`` list in which at least one segment
    carries a ``"speaker"`` key.

    Args:
        mock_identify: Patch for ``_identify_low_confidence_segments``.
        mock_calc: Patch for ``_calculate_confidence``.
        mock_refine: Patch for ``_perform_refinement_pass``.
        mock_first: Patch for ``_perform_first_pass``.
        mock_diar: Patch for the ``DiarizationManager`` class.
        tmp_path: pytest-provided temporary directory for the dummy audio file.
    """
    pipeline = MultiPassTranscriptionPipeline()

    # The content is never decoded because every processing stage is mocked;
    # the file only needs to exist on disk.
    audio = tmp_path / "a.wav"
    audio.write_bytes(b"fake")

    first = _orig_segments()
    mock_first.return_value = first
    mock_calc.return_value = first
    mock_identify.return_value = _refined_segments()
    mock_refine.return_value = _refined_segments()

    # Instantiating the patched DiarizationManager yields mock_mgr.
    mock_mgr = MagicMock()
    mock_mgr.process_audio.return_value = MagicMock(segments=_diarization_segments())
    mock_diar.return_value = mock_mgr

    # Execute method under test
    result = pipeline.transcribe_with_parallel_processing(
        audio, speaker_diarization=True, domain=None
    )

    assert "transcript" in result
    assert isinstance(result["transcript"], list)
    assert any("speaker" in s for s in result["transcript"])  # diarization merged
|