trax/tests/test_multi_pass_refinement.py

63 lines
2.2 KiB
Python

"""Tests for refinement pass (subtask 7.3)."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
import pytest
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
from src.services.model_manager import ModelManager
def _segments_for_refinement() -> List[Dict[str, Any]]:
return [
{"start": 0.0, "end": 1.0, "text": "low", "confidence": 0.4},
{"start": 2.0, "end": 3.2, "text": "also low", "confidence": 0.5},
]
class _MockSeg:
def __init__(self, start: float, end: float, text: str):
self.start = start
self.end = end
self.text = text
self.avg_logprob = -0.5
self.no_speech_prob = 0.05
@patch("src.services.multi_pass_transcription.ModelManager")
def test_refinement_replaces_low_conf_segments(mock_mm_cls: MagicMock, tmp_path: Path):
    """Low-confidence segments are re-transcribed and keep their time windows.

    ``ModelManager`` is patched in the pipeline module. Assuming the pipeline
    instantiates it from there, the pipeline gets a mock whose ``load_model``
    returns a fake model. That model's ``transcribe`` always yields a single
    "refined text" segment.
    """
    # Mock model for the refinement pass.
    mock_mm = MagicMock(spec=ModelManager)
    mock_model = MagicMock()

    def _fake_transcribe(path: str, **kwargs: Any):
        # Returns an (iterator of segments, info) pair; every segment file
        # gets the same refined text.
        refined = [_MockSeg(0.0, 1.0, "refined text")]
        return iter(refined), MagicMock(language="en", language_probability=0.99)

    mock_model.transcribe.side_effect = _fake_transcribe
    mock_mm.load_model.return_value = mock_model
    mock_mm_cls.return_value = mock_mm

    pipeline = MultiPassTranscriptionPipeline()
    segments_for_refinement = _segments_for_refinement()

    # Placeholder audio file; its bytes are never decoded because both the
    # subprocess call and the model are mocked.
    audio_path = tmp_path / "audio.wav"
    audio_path.write_bytes(b"fake")

    # Patch out the ffmpeg invocation inside the refinement pass.
    # NOTE(review): this patches the global ``subprocess.run``. It only takes
    # effect if the pipeline module calls it as ``subprocess.run`` (not via
    # ``from subprocess import run``) — confirm.
    with patch("subprocess.run"):
        refined = pipeline._perform_refinement_pass(
            audio_path, segments_for_refinement
        )

    # The fake model must actually have been used. Otherwise the text checks
    # below could pass on the untouched original segments.
    assert mock_model.transcribe.called

    # Expect refined segments mapped back to the original time windows.
    assert len(refined) == len(segments_for_refinement)
    for r, orig in zip(refined, segments_for_refinement):
        assert r["start"] == pytest.approx(orig["start"])
        assert r["end"] == pytest.approx(orig["end"])
        assert isinstance(r["text"], str)
        # The original texts ("low", "also low") must have been replaced.
        assert "refined text" in r["text"]