# trax/tests/test_quality_assessment.py (432 lines, 16 KiB, Python)

"""
Unit tests for quality assessment system.
Tests accuracy estimation, quality warnings, confidence scoring,
and transcript comparison functionality.
"""
import pytest
from unittest.mock import MagicMock, patch
from typing import Dict, List, Any
from src.services.quality_assessment import (
QualityAssessor,
QualityMetrics,
QualityWarning,
WarningSeverity
)
from src.services.confidence_scorer import (
ConfidenceScorer,
SegmentConfidence,
ConfidenceLevel
)
from src.services.transcript_comparer import (
TranscriptComparer,
SegmentChange,
ComparisonResult,
ChangeType
)
class TestQualityAssessor:
    """Unit tests for ``QualityAssessor`` accuracy estimation and warnings."""

    @staticmethod
    def _transcript(*spans):
        """Build a transcript dict from ``(start, end, text)`` tuples."""
        return {
            "segments": [
                {"start": start, "end": end, "text": text}
                for start, end, text in spans
            ]
        }

    @pytest.fixture
    def quality_assessor(self):
        """Provide a freshly constructed ``QualityAssessor``."""
        return QualityAssessor()

    @pytest.fixture
    def sample_transcript(self):
        """Provide a three-segment transcript carrying per-segment confidences.

        The segments cover plain speech, technical terms, and filler words.
        """
        specs = [
            (0.0, 5.0, "Hello world, this is a test transcript.", 0.95),
            (5.0, 10.0, "It contains some technical terms like React.js and API.", 0.88),
            (10.0, 15.0, "Um, you know, there are some filler words here.", 0.75),
        ]
        return {
            "segments": [
                {"start": start, "end": end, "text": text, "confidence": score}
                for start, end, text, score in specs
            ]
        }

    def test_quality_assessor_initialization(self, quality_assessor):
        """The assessor exposes non-empty filler and tech-term pattern sets."""
        pattern_sets = (
            quality_assessor.filler_patterns,
            quality_assessor.tech_term_patterns,
        )
        for patterns in pattern_sets:
            assert patterns is not None
            assert len(patterns) > 0

    def test_estimate_accuracy_with_confidence(self, quality_assessor, sample_transcript):
        """High per-segment confidences produce a high accuracy estimate."""
        estimate = quality_assessor.estimate_accuracy(sample_transcript)
        assert 0.5 <= estimate <= 0.99
        # Confidences of 0.95 / 0.88 / 0.75 should keep the estimate above 0.8.
        assert estimate > 0.8

    def test_estimate_accuracy_without_confidence(self, quality_assessor):
        """Segments lacking a ``confidence`` key still give an in-range estimate."""
        unscored = self._transcript(
            (0.0, 5.0, "Hello world"),
            (5.0, 10.0, "Technical terms like React.js"),
        )
        estimate = quality_assessor.estimate_accuracy(unscored)
        # Presumably a default confidence (0.85) is substituted; only the
        # range is asserted here.
        assert 0.5 <= estimate <= 0.99

    def test_estimate_accuracy_empty_transcript(self, quality_assessor):
        """A transcript without segments scores exactly 0.0."""
        assert quality_assessor.estimate_accuracy({"segments": []}) == 0.0

    def test_generate_quality_warnings_low_accuracy(self, quality_assessor):
        """An accuracy of 0.75 triggers the low-accuracy warning."""
        low_conf = {
            "segments": [
                {"start": 0.0, "end": 5.0, "text": "Hello world", "confidence": 0.6},
            ]
        }
        issues = quality_assessor.generate_quality_warnings(low_conf, 0.75)
        assert "Low overall accuracy detected" in issues

    def test_generate_quality_warnings_inaudible_sections(self, quality_assessor):
        """``(inaudible)`` / ``(unintelligible)`` markers raise a warning."""
        transcript = self._transcript(
            (0.0, 5.0, "Hello world"),
            (5.0, 10.0, "(inaudible) some text"),
            (10.0, 15.0, "More text (unintelligible)"),
        )
        issues = quality_assessor.generate_quality_warnings(transcript, 0.85)
        assert "Inaudible or unintelligible sections detected" in issues

    def test_generate_quality_warnings_short_segments(self, quality_assessor):
        """Several one-word, one-second segments raise a short-segment warning."""
        transcript = self._transcript(
            (0.0, 1.0, "Hi"),
            (1.0, 2.0, "There"),
            (2.0, 3.0, "You"),
            (3.0, 10.0, "This is a longer segment with more words"),
        )
        issues = quality_assessor.generate_quality_warnings(transcript, 0.85)
        assert "High number of very short segments detected" in issues

    def test_generate_quality_warnings_repeated_words(self, quality_assessor):
        """Consecutive duplicate words raise a repeated-words warning."""
        transcript = self._transcript(
            (0.0, 5.0, "Hello hello world"),
            (5.0, 10.0, "Test test test"),
        )
        issues = quality_assessor.generate_quality_warnings(transcript, 0.85)
        assert any("Repeated words detected" in message for message in issues)

    def test_generate_quality_warnings_long_pauses(self, quality_assessor):
        """A 3-second gap between segments raises a long-pause warning."""
        transcript = self._transcript(
            (0.0, 5.0, "First segment"),
            (8.0, 13.0, "Second segment after long pause"),
        )
        issues = quality_assessor.generate_quality_warnings(transcript, 0.85)
        assert any("Long pause detected" in message for message in issues)
class TestConfidenceScorer:
    """Test confidence scoring functionality."""

    @pytest.fixture
    def confidence_scorer(self):
        """Create a confidence scorer instance."""
        return ConfidenceScorer()

    def test_confidence_scorer_initialization(self, confidence_scorer):
        """Default base, minimum and maximum confidence values are set."""
        assert confidence_scorer.base_confidence == 0.85
        assert confidence_scorer.min_confidence == 0.5
        assert confidence_scorer.max_confidence == 0.99

    def test_calculate_segment_confidence(self, confidence_scorer):
        """A single segment's confidence lands in the [0.5, 0.99] range."""
        segment = {
            "text": "Hello world with technical terms like React.js",
            "start": 0.0,
            "end": 5.0,
        }
        confidence = confidence_scorer.calculate_segment_confidence(segment)
        assert 0.5 <= confidence <= 0.99

    def test_calculate_overall_confidence(self, confidence_scorer):
        """Overall confidence is bounded by the segment confidences."""
        segments = [
            {"text": "First segment", "confidence": 0.9},
            {"text": "Second segment", "confidence": 0.8},
            {"text": "Third segment", "confidence": 0.95},
        ]
        overall_confidence = confidence_scorer.calculate_overall_confidence(segments)
        # Expected to be a weighted average of the segment confidences, which
        # cannot exceed the largest input (0.95).
        assert 0.5 <= overall_confidence <= 0.95

    def test_identify_low_confidence_segments(self, confidence_scorer):
        """Only segments below the threshold are reported (order not assumed)."""
        segments = [
            {"text": "High confidence", "confidence": 0.95},
            {"text": "Low confidence", "confidence": 0.6},
            {"text": "Medium confidence", "confidence": 0.8},
            {"text": "Very low confidence", "confidence": 0.4},
        ]
        low_confidence = confidence_scorer.identify_low_confidence_segments(
            segments, threshold=0.7
        )
        low_confidence_texts = [seg["text"] for seg in low_confidence]

        # 0.6 and 0.4 fall below the 0.7 threshold.
        assert len(low_confidence) >= 2
        assert "Low confidence" in low_confidence_texts
        assert "Very low confidence" in low_confidence_texts
        # Segments above the threshold must not be flagged; without these
        # checks an implementation returning every segment would still pass.
        assert "High confidence" not in low_confidence_texts
        assert "Medium confidence" not in low_confidence_texts
class TestTranscriptComparer:
    """Unit tests for ``TranscriptComparer``."""

    @pytest.fixture
    def transcript_comparer(self):
        """Provide a freshly constructed ``TranscriptComparer``."""
        return TranscriptComparer()

    @pytest.fixture
    def original_transcript(self):
        """Provide the baseline two-segment transcript."""
        return {
            "segments": [
                {"start": 0.0, "end": 5.0, "text": "Hello world"},
                {"start": 5.0, "end": 10.0, "text": "This is a test"},
            ]
        }

    @pytest.fixture
    def enhanced_transcript(self):
        """Provide a lightly edited version of the baseline transcript."""
        return {
            "segments": [
                {"start": 0.0, "end": 5.0, "text": "Hello world!"},
                {"start": 5.0, "end": 10.0, "text": "This is a test transcript."},
            ]
        }

    def test_transcript_comparer_initialization(self, transcript_comparer):
        """Similarity thresholds default to 0.9 / 0.7 / 0.5."""
        expected_thresholds = {
            "high_similarity_threshold": 0.9,
            "medium_similarity_threshold": 0.7,
            "low_similarity_threshold": 0.5,
        }
        for attribute, value in expected_thresholds.items():
            assert getattr(transcript_comparer, attribute) == value

    def test_compare_transcripts_content_preserved(
        self, transcript_comparer, original_transcript, enhanced_transcript
    ):
        """Comparing near-identical transcripts yields a populated result."""
        result = transcript_comparer.compare_transcripts(
            original_transcript["segments"],
            enhanced_transcript["segments"],
        )
        assert result.total_segments > 0
        assert result.overall_improvement_score >= 0.0
        serialized = result.to_dict()
        assert "summary_statistics" in serialized
        assert "quality_metrics" in serialized

    def test_compare_transcripts_content_not_preserved(self, transcript_comparer):
        """A heavily rewritten segment is reported as changed."""
        before = [{"start": 0.0, "end": 5.0, "text": "Short text"}]
        after = [
            {
                "start": 0.0,
                "end": 5.0,
                "text": "This is a much longer enhanced version of the original short text",
            }
        ]
        result = transcript_comparer.compare_transcripts(before, after)
        assert result.total_segments > 0
        assert result.segments_with_changes > 0
        assert len(result.segment_changes) > 0

    def test_calculate_similarity_score(
        self, transcript_comparer, original_transcript, enhanced_transcript
    ):
        """Texts differing only by punctuation score at least 0.5 similarity."""
        before_text, after_text = (
            transcript["segments"][0]["text"]
            for transcript in (original_transcript, enhanced_transcript)
        )
        similarity = transcript_comparer.calculate_similarity(before_text, after_text)
        assert 0.0 <= similarity <= 1.0
        assert similarity >= 0.5
class TestQualityMetrics:
    """Tests for the ``QualityMetrics`` value object."""

    def test_quality_metrics_creation(self):
        """Constructor arguments are exposed as attributes."""
        scalar_fields = {
            "overall_accuracy": 0.85,
            "segment_count": 10,
            "average_confidence": 0.88,
            "filler_word_count": 5,
            "tech_term_count": 15,
        }
        metrics = QualityMetrics(
            **scalar_fields,
            warnings=["Low confidence in segment 3"],
        )
        for name, expected in scalar_fields.items():
            assert getattr(metrics, name) == expected
        assert len(metrics.warnings) == 1

    def test_quality_metrics_to_dict(self):
        """``to_dict`` exposes every field under its attribute name."""
        metrics = QualityMetrics(
            overall_accuracy=0.85,
            segment_count=10,
            average_confidence=0.88,
            filler_word_count=5,
            tech_term_count=15,
            warnings=["Test warning"],
        )
        serialized = metrics.to_dict()
        expected_keys = (
            "overall_accuracy",
            "segment_count",
            "average_confidence",
            "filler_word_count",
            "tech_term_count",
            "warnings",
        )
        for key in expected_keys:
            assert key in serialized
class TestQualityWarning:
    """Tests for the ``QualityWarning`` value object."""

    def test_quality_warning_creation(self):
        """Constructor arguments are exposed as attributes."""
        fields = {
            "warning_type": "low_confidence",
            "message": "Low confidence detected in segment 3",
            "severity": "medium",
            "segment_index": 2,
        }
        warning = QualityWarning(**fields)
        for name, expected in fields.items():
            assert getattr(warning, name) == expected

    def test_quality_warning_to_dict(self):
        """``to_dict`` exposes every field under its attribute name."""
        warning = QualityWarning(
            warning_type="inaudible_section",
            message="Inaudible section detected",
            severity="high",
            segment_index=5,
        )
        serialized = warning.to_dict()
        for key in ("warning_type", "message", "severity", "segment_index"):
            assert key in serialized
class TestQualityAssessmentIntegration:
    """Integration tests for quality assessment system."""

    @pytest.fixture
    def quality_system(self):
        """Create a complete quality assessment system.

        The service classes are already imported at module level, so they are
        used directly instead of being re-imported inside the fixture.
        """
        return {
            "assessor": QualityAssessor(),
            "scorer": ConfidenceScorer(),
            "comparer": TranscriptComparer(),
        }

    def test_end_to_end_quality_assessment(self, quality_system):
        """Assessor and scorer run together and return well-formed results."""
        system = quality_system
        transcript = {
            "segments": [
                {"start": 0.0, "end": 5.0, "text": "Hello world with React.js", "confidence": 0.9},
                {"start": 5.0, "end": 10.0, "text": "Um, you know, some filler words", "confidence": 0.7},
                {"start": 10.0, "end": 15.0, "text": "Technical API documentation", "confidence": 0.95},
            ]
        }
        segments = transcript["segments"]

        # Assess quality
        accuracy = system["assessor"].estimate_accuracy(transcript)
        warning_messages = system["assessor"].generate_quality_warnings(transcript, accuracy)

        # Score confidence
        overall_confidence = system["scorer"].calculate_overall_confidence(segments)
        low_confidence_segments = system["scorer"].identify_low_confidence_segments(segments)

        # Verify results
        assert 0.5 <= accuracy <= 0.99
        assert 0.5 <= overall_confidence <= 0.99
        # The former ``len(...) >= 0`` checks could never fail. Whether any
        # warnings or low-confidence segments appear depends on the
        # implementation, so check their shape instead. Warnings are message
        # strings (other tests do substring checks on them).
        assert all(isinstance(message, str) for message in warning_messages)
        # Every flagged segment must come from this transcript.
        input_texts = {segment["text"] for segment in segments}
        assert len(low_confidence_segments) <= len(segments)
        assert all(segment["text"] in input_texts for segment in low_confidence_segments)

    def test_quality_assessment_with_tech_content(self, quality_system):
        """Tech-heavy, high-confidence content yields accuracy above 0.85."""
        system = quality_system
        transcript = {
            "segments": [
                {"start": 0.0, "end": 5.0, "text": "React.js component with useState hook", "confidence": 0.9},
                {"start": 5.0, "end": 10.0, "text": "API endpoint /api/users", "confidence": 0.95},
                {"start": 10.0, "end": 15.0, "text": "Database query with SQLAlchemy", "confidence": 0.88},
            ]
        }
        accuracy = system["assessor"].estimate_accuracy(transcript)
        # Technical content should have higher accuracy
        assert accuracy > 0.85
if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so a direct script run
    # (e.g. in CI) reports failures instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))