63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
"""Tests for refinement pass (subtask 7.3)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
|
|
from src.services.model_manager import ModelManager
|
|
|
|
|
|
def _segments_for_refinement() -> List[Dict[str, Any]]:
|
|
return [
|
|
{"start": 0.0, "end": 1.0, "text": "low", "confidence": 0.4},
|
|
{"start": 2.0, "end": 3.2, "text": "also low", "confidence": 0.5},
|
|
]
|
|
|
|
|
|
class _MockSeg:
|
|
def __init__(self, start: float, end: float, text: str):
|
|
self.start = start
|
|
self.end = end
|
|
self.text = text
|
|
self.avg_logprob = -0.5
|
|
self.no_speech_prob = 0.05
|
|
|
|
|
|
@patch("src.services.multi_pass_transcription.ModelManager")
def test_refinement_replaces_low_conf_segments(mock_mm_cls: MagicMock, tmp_path: Path):
    """Refinement pass returns one segment per input, on the original time windows.

    ``ModelManager`` is patched inside the pipeline module so that
    ``load_model`` hands back a mock model whose ``transcribe`` always yields
    a single segment with the text ``"refined text"``. ``subprocess.run`` is
    patched for the duration of the refinement call so no external process is
    started.
    """
    # Mock model for refinement pass
    mock_mm = MagicMock(spec=ModelManager)
    mock_model = MagicMock()

    def _fake_transcribe(path: str, **kwargs):  # returns (iter, info)
        # Return refined text for any segment file
        refined_segments = [_MockSeg(0.0, 1.0, "refined text")]
        info = MagicMock(language="en", language_probability=0.99)
        return iter(refined_segments), info

    mock_model.transcribe.side_effect = _fake_transcribe
    mock_mm.load_model.return_value = mock_model
    mock_mm_cls.return_value = mock_mm

    pipeline = MultiPassTranscriptionPipeline()
    segments_for_refinement = _segments_for_refinement()
    # Pristine copy of the fixture: if the pipeline mutates its input dicts in
    # place, comparing against the very same objects would pass trivially.
    expected_segments = _segments_for_refinement()

    # Create a fake audio file path
    audio_path = tmp_path / "audio.wav"
    audio_path.write_bytes(b"fake")

    # Patch out ffmpeg invocation within refinement (we'll write temp files)
    with patch("subprocess.run"):
        refined = pipeline._perform_refinement_pass(audio_path, segments_for_refinement)

    # Expect refined segments mapped back to original time windows
    assert len(refined) == len(expected_segments)
    for r, orig in zip(refined, expected_segments):
        assert r["start"] == pytest.approx(orig["start"])
        assert r["end"] == pytest.approx(orig["end"])
        assert isinstance(r["text"], str) and len(r["text"]) > 0
|