# trax/tests/test_multi_pass_merge_paral...  (path truncated in the original listing)
#
# 81 lines
# 2.9 KiB
# Python

"""Tests for merging and parallel orchestration (subtask 7.5)."""
from __future__ import annotations
from pathlib import Path
from typing import List, Dict, Any
from unittest.mock import MagicMock, patch
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
def _orig_segments() -> List[Dict[str, Any]]:
    """Return first-pass transcript segments used as merge input.

    The second segment is the interesting one: it carries a misspelling
    ("wurld") and a confidence of 0.4, and it shares its start/end times
    with the single segment returned by ``_refined_segments``.

    A new list of new dicts is built on every call, so a test that mutates
    the result cannot leak state into another test.
    """
    return [
        {"start": 0.0, "end": 1.0, "text": "hello", "confidence": 0.9},
        {"start": 1.1, "end": 2.2, "text": "wurld", "confidence": 0.4},
    ]
def _refined_segments() -> List[Dict[str, Any]]:
    """Return the refinement-pass output: one corrected segment.

    The segment has the same start/end times (1.1-2.2) as the second
    segment of ``_orig_segments`` but with the text "world" and a
    confidence of 0.95. A new list is built on every call.
    """
    return [
        {"start": 1.1, "end": 2.2, "text": "world", "confidence": 0.95},
    ]
def _diarization_segments() -> List[Dict[str, Any]]:
    """Return two back-to-back speaker turns used as diarization input.

    Speaker "S1" covers 0.0-1.5 and "S2" covers 1.5-3.0, so the boundary at
    1.5 falls inside the second segment of ``_orig_segments`` (1.1-2.2).
    A new list is built on every call.
    """
    return [
        {"start": 0.0, "end": 1.5, "speaker": "S1"},
        {"start": 1.5, "end": 3.0, "speaker": "S2"},
    ]
def test_merge_transcription_results_replaces_low_conf_segment() -> None:
    """Merging refined segments keeps the count and updates the second text."""
    pipeline = MultiPassTranscriptionPipeline()
    merged = pipeline._merge_transcription_results(_orig_segments(), _refined_segments())
    # Two segments in, two out: the refined segment must not be appended
    # as an extra entry.
    assert len(merged) == 2
    # The misspelled "wurld" is expected to be replaced by the refined text.
    assert merged[1]["text"] == "world"
def test_merge_with_diarization_assigns_speakers() -> None:
    """Every transcript segment receives a speaker label from diarization."""
    pipeline = MultiPassTranscriptionPipeline()
    merged = pipeline._merge_with_diarization(_orig_segments(), _diarization_segments())
    assert len(merged) == 2
    # First segment (0.0-1.0) lies entirely inside S1's turn (0.0-1.5).
    assert merged[0]["speaker"] == "S1"
    # Second segment (1.1-2.2) straddles the 1.5 boundary: it overlaps S1 for
    # 0.4 s and S2 for 0.7 s. A maximum-overlap rule would pick S2, but this
    # assertion deliberately accepts either label and only checks that one
    # of the known speakers was assigned.
    # NOTE(review): tighten to == "S2" if max-overlap is the guaranteed rule.
    assert merged[1]["speaker"] in {"S1", "S2"}
@patch("src.services.multi_pass_transcription.DiarizationManager")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._perform_first_pass")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._perform_refinement_pass")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._calculate_confidence")
@patch("src.services.multi_pass_transcription.MultiPassTranscriptionPipeline._identify_low_confidence_segments")
def test_parallel_orchestration(
    mock_identify: MagicMock,
    mock_calc: MagicMock,
    mock_refine: MagicMock,
    mock_first: MagicMock,
    mock_diar: MagicMock,
    tmp_path: Path,
) -> None:
    """The parallel entry point returns a transcript with speaker labels.

    All pipeline stages and the diarization manager are patched out, so the
    test exercises only the orchestration and merging done by
    ``transcribe_with_parallel_processing``.

    ``@patch`` decorators are applied bottom-up, so the mock parameters are
    listed in the reverse order of the decorators: the bottom-most patch
    (``_identify_low_confidence_segments``) arrives as the first argument
    and the top-most (``DiarizationManager``) as the last mock.
    ``tmp_path`` is the pytest fixture providing a per-test directory.
    """
    pipeline = MultiPassTranscriptionPipeline()

    # The stages are mocked, so the file only has to exist; its content is
    # never decoded as audio by this test.
    audio = tmp_path / "a.wav"
    audio.write_bytes(b"fake")

    first = _orig_segments()
    mock_first.return_value = first
    mock_calc.return_value = first
    mock_identify.return_value = _refined_segments()
    mock_refine.return_value = _refined_segments()

    # The patched DiarizationManager class returns mock_mgr when called;
    # its process_audio() result exposes the diarization turns as .segments.
    mock_mgr = MagicMock()
    mock_mgr.process_audio.return_value = MagicMock(segments=_diarization_segments())
    mock_diar.return_value = mock_mgr

    # Execute method under test
    result = pipeline.transcribe_with_parallel_processing(audio, speaker_diarization=True, domain=None)

    assert "transcript" in result and isinstance(result["transcript"], list)
    assert any("speaker" in s for s in result["transcript"])  # diarization merged