68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
"""Tests for confidence scoring and low-confidence selection (subtask 7.2)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import List, Dict, Any
|
|
|
|
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
|
|
|
|
|
|
def _build_segments(values: List[tuple[float, float, str]]) -> List[Dict[str, Any]]:
|
|
# values: (avg_logprob, no_speech_prob, label)
|
|
segments: List[Dict[str, Any]] = []
|
|
t = 0.0
|
|
for lp, ns, label in values:
|
|
segments.append(
|
|
{
|
|
"start": t,
|
|
"end": t + 1.0,
|
|
"text": label,
|
|
"avg_logprob": lp,
|
|
"no_speech_prob": ns,
|
|
}
|
|
)
|
|
t += 1.1
|
|
return segments
|
|
|
|
|
|
def test_calculate_confidence_basic_monotonicity():
    """Check confidence bands and ordering for worst, medium and best inputs.

    Three synthetic segments are scored with ``_calculate_confidence``; each
    score must land in its expected band and the scores must be strictly
    ordered high > medium > very_low.
    """
    pipeline = MultiPassTranscriptionPipeline()

    inputs = [
        (-5.0, 1.0, "very_low"),  # worst
        (-2.5, 0.5, "medium"),  # medium
        (0.0, 0.0, "high"),  # best
    ]
    scored = pipeline._calculate_confidence(_build_segments(inputs))

    # Index the scores by label so each band can be checked by name.
    by_label = {}
    for seg in scored:
        by_label[seg["text"]] = seg["confidence"]

    very_low = by_label["very_low"]
    medium = by_label["medium"]
    high = by_label["high"]

    assert 0.0 <= very_low <= 0.3
    assert 0.3 <= medium <= 0.8
    assert 0.9 <= high <= 1.0
    assert high > medium > very_low
|
|
|
|
|
|
def test_identify_low_confidence_segments_thresholding():
    """Check which segments are flagged as low confidence at threshold 0.75.

    The clearly poor segment ("low1") must be returned by
    ``_identify_low_confidence_segments`` and the perfect one ("high1") must
    not be.
    """
    pipeline = MultiPassTranscriptionPipeline()
    pipeline.confidence_threshold = 0.75

    inputs = [
        (-4.0, 0.8, "low1"),
        (-1.0, 0.2, "mid1"),
        (0.0, 0.0, "high1"),
    ]
    scored = pipeline._calculate_confidence(_build_segments(inputs))

    flagged_labels = set()
    for seg in pipeline._identify_low_confidence_segments(scored):
        flagged_labels.add(seg["text"])

    # "mid1" is intentionally not asserted: this test does not pin which side
    # of the threshold a mid-range segment falls on.
    assert "low1" in flagged_labels
    assert "high1" not in flagged_labels
|
|
|
|
|