# trax/tests/test_multi_pass_confidence.py

# 68 lines
# 1.8 KiB
# Python

"""Tests for confidence scoring and low-confidence selection (subtask 7.2)."""
from __future__ import annotations
from typing import List, Dict, Any
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
def _build_segments(values: List[tuple[float, float, str]]) -> List[Dict[str, Any]]:
# values: (avg_logprob, no_speech_prob, label)
segments: List[Dict[str, Any]] = []
t = 0.0
for lp, ns, label in values:
segments.append(
{
"start": t,
"end": t + 1.0,
"text": label,
"avg_logprob": lp,
"no_speech_prob": ns,
}
)
t += 1.1
return segments
def test_calculate_confidence_basic_monotonicity():
    """Scores land in their expected bands and are strictly ordered.

    Three fixture segments (worst, medium, best) are passed through
    ``_calculate_confidence``. The resulting ``confidence`` values must fall
    inside per-label ranges and satisfy high > medium > very_low.
    """
    pipeline = MultiPassTranscriptionPipeline()
    # Each tuple is (avg_logprob, no_speech_prob, label).
    fixtures = [
        (-5.0, 1.0, "very_low"),  # worst
        (-2.5, 0.5, "medium"),  # medium
        (0.0, 0.0, "high"),  # best
    ]
    scored = pipeline._calculate_confidence(_build_segments(fixtures))

    # Index the scores by segment label for readable assertions.
    confidence_by_label = {}
    for segment in scored:
        confidence_by_label[segment["text"]] = segment["confidence"]

    # Inclusive (lower, upper) band each label's confidence must fall within.
    expected_bands = {
        "very_low": (0.0, 0.3),
        "medium": (0.3, 0.8),
        "high": (0.9, 1.0),
    }
    for label, (lower, upper) in expected_bands.items():
        assert lower <= confidence_by_label[label] <= upper

    worst = confidence_by_label["very_low"]
    middle = confidence_by_label["medium"]
    best = confidence_by_label["high"]
    assert best > middle > worst
def test_identify_low_confidence_segments_thresholding():
    """Low-confidence selection respects ``confidence_threshold``.

    With the threshold set to 0.75, the weakest fixture segment ("low1")
    must be selected and the strongest ("high1") must not. The outcome for
    "mid1" is not asserted.
    """
    pipeline = MultiPassTranscriptionPipeline()
    pipeline.confidence_threshold = 0.75
    # Each tuple is (avg_logprob, no_speech_prob, label).
    fixtures = [
        (-4.0, 0.8, "low1"),
        (-1.0, 0.2, "mid1"),
        (0.0, 0.0, "high1"),
    ]
    scored = pipeline._calculate_confidence(_build_segments(fixtures))

    # Collect the labels of every segment flagged as low confidence.
    flagged_labels = set()
    for segment in pipeline._identify_low_confidence_segments(scored):
        flagged_labels.add(segment["text"])

    assert "low1" in flagged_labels
    assert "high1" not in flagged_labels