"""
Unit tests for WordDetector and related classes.
"""
import pytest
import json
from pathlib import Path
from unittest.mock import Mock, patch, mock_open
from src.core.word_detector import (
    Severity,
    DetectedWord,
    WordList,
    WordDetector
)
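
# Note: several tests below use a `temp_dir` fixture, which is expected to be
# provided by the suite's conftest.py. A minimal sketch of such a fixture
# (an assumption for illustration, not part of this module) could be:
#
#     @pytest.fixture
#     def temp_dir(tmp_path):
#         return tmp_path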


class TestSeverity:
    """Test Severity enum."""

    def test_severity_values(self):
        """Test severity values are correct."""
        assert Severity.LOW.value == 1
        assert Severity.MEDIUM.value == 2
        assert Severity.HIGH.value == 3
        assert Severity.EXTREME.value == 4

    def test_from_string(self):
        """Test creating severity from string."""
        assert Severity.from_string('low') == Severity.LOW
        assert Severity.from_string('LOW') == Severity.LOW
        assert Severity.from_string('medium') == Severity.MEDIUM
        assert Severity.from_string('high') == Severity.HIGH
        assert Severity.from_string('extreme') == Severity.EXTREME
        # Unknown values should default to MEDIUM
        assert Severity.from_string('unknown') == Severity.MEDIUM
        assert Severity.from_string('') == Severity.MEDIUM

    def test_severity_ordering(self):
        """Test severity levels can be compared."""
        assert Severity.LOW.value < Severity.MEDIUM.value
        assert Severity.MEDIUM.value < Severity.HIGH.value
        assert Severity.HIGH.value < Severity.EXTREME.value


class TestDetectedWord:
    """Test DetectedWord dataclass."""

    def test_basic_creation(self):
        """Test creating a DetectedWord."""
        word = DetectedWord(
            word="badword",
            original="BadWord",
            start=5.0,
            end=6.0,
            severity=Severity.HIGH,
            confidence=0.95
        )
        assert word.word == "badword"
        assert word.original == "BadWord"
        assert word.start == 5.0
        assert word.end == 6.0
        assert word.severity == Severity.HIGH
        assert word.confidence == 0.95
        assert word.context == ""

    def test_duration_property(self):
        """Test duration calculation."""
        word = DetectedWord(
            word="test",
            original="test",
            start=2.5,
            end=4.0,
            severity=Severity.LOW,
            confidence=1.0
        )
        assert word.duration == 1.5

    def test_to_dict(self):
        """Test converting to dictionary."""
        word = DetectedWord(
            word="test",
            original="TEST",
            start=1.0,
            end=2.5,
            severity=Severity.MEDIUM,
            confidence=0.85,
            context="this is a [test] word"
        )
        data = word.to_dict()
        assert data['word'] == "test"
        assert data['original'] == "TEST"
        assert data['start'] == 1.0
        assert data['end'] == 2.5
        assert data['duration'] == 1.5
        assert data['severity'] == "MEDIUM"
        assert data['confidence'] == 0.85
        assert data['context'] == "this is a [test] word"


class TestWordList:
    """Test WordList class."""

    def test_initialization(self):
        """Test WordList initialization."""
        word_list = WordList()
        # Should have some default words loaded
        assert len(word_list) > 0
        assert isinstance(word_list.words, dict)
        assert isinstance(word_list.patterns, dict)
        assert isinstance(word_list.variations, dict)

    def test_add_word(self):
        """Test adding words to the list."""
        word_list = WordList()
        initial_count = len(word_list)
        # Add word with string severity
        word_list.add_word("testword", "high")
        assert "testword" in word_list.words
        assert word_list.words["testword"] == Severity.HIGH
        # Add word with Severity enum
        word_list.add_word("another", Severity.LOW)
        assert "another" in word_list.words
        assert word_list.words["another"] == Severity.LOW
        assert len(word_list) == initial_count + 2

    def test_add_word_variations(self):
        """Test that adding a word creates variations."""
        word_list = WordList()
        word_list.add_word("test", Severity.MEDIUM)
        # Should create plural variation
        assert "tests" in word_list.variations
        assert word_list.variations["tests"] == "test"

    def test_remove_word(self):
        """Test removing words from the list."""
        word_list = WordList()
        word_list.add_word("removeme", Severity.LOW)
        # Verify word was added
        assert "removeme" in word_list.words
        # Remove the word
        removed = word_list.remove_word("removeme")
        assert removed is True
        assert "removeme" not in word_list.words
        # Try removing a non-existent word
        removed = word_list.remove_word("nonexistent")
        assert removed is False

    def test_contains(self):
        """Test checking if word is in list."""
        word_list = WordList()
        word_list.add_word("contained", Severity.MEDIUM)
        assert "contained" in word_list
        assert "CONTAINED" in word_list  # Case insensitive
        assert " contained " in word_list  # Strips whitespace
        assert "notcontained" not in word_list
    def test_load_from_json_file(self, temp_dir):
        """Test loading word list from JSON file."""
        # Create test JSON file
        test_data = {
            "word1": "LOW",
            "word2": "HIGH",
            "word3": "EXTREME"
        }
        json_file = temp_dir / "test_words.json"
        with open(json_file, 'w') as f:
            json.dump(test_data, f)
        word_list = WordList()
        initial_count = len(word_list)
        word_list.load_from_file(json_file)
        assert "word1" in word_list.words
        assert word_list.words["word1"] == Severity.LOW
        assert "word2" in word_list.words
        assert word_list.words["word2"] == Severity.HIGH
        assert "word3" in word_list.words
        assert word_list.words["word3"] == Severity.EXTREME
        assert len(word_list) == initial_count + 3

    def test_load_from_csv_file(self, temp_dir):
        """Test loading word list from CSV file."""
        # Create test CSV file
        csv_content = """word,severity
testword1,low
testword2,medium
testword3,high"""
        csv_file = temp_dir / "test_words.csv"
        csv_file.write_text(csv_content)
        word_list = WordList()
        initial_count = len(word_list)
        word_list.load_from_file(csv_file)
        assert "testword1" in word_list.words
        assert word_list.words["testword1"] == Severity.LOW
        assert "testword2" in word_list.words
        assert word_list.words["testword2"] == Severity.MEDIUM
        assert "testword3" in word_list.words
        assert word_list.words["testword3"] == Severity.HIGH
        assert len(word_list) == initial_count + 3

    def test_load_from_text_file(self, temp_dir):
        """Test loading word list from plain text file."""
        # Create test text file
        text_content = """word1
word2
# This is a comment
word3
"""
        text_file = temp_dir / "test_words.txt"
        text_file.write_text(text_content)
        word_list = WordList()
        initial_count = len(word_list)
        word_list.load_from_file(text_file)
        assert "word1" in word_list.words
        assert "word2" in word_list.words
        assert "word3" in word_list.words
        # Comment should be ignored
        assert "# This is a comment" not in word_list.words
        assert len(word_list) == initial_count + 3

    def test_load_nonexistent_file(self):
        """Test loading from non-existent file."""
        word_list = WordList()
        with pytest.raises(FileNotFoundError):
            word_list.load_from_file("nonexistent.json")
    def test_save_to_json_file(self, temp_dir):
        """Test saving word list to JSON file."""
        word_list = WordList()
        word_list.add_word("save1", Severity.LOW)
        word_list.add_word("save2", Severity.HIGH)
        json_file = temp_dir / "saved_words.json"
        word_list.save_to_file(json_file)
        assert json_file.exists()
        # Load and verify
        with open(json_file, 'r') as f:
            data = json.load(f)
        assert "save1" in data
        assert "save2" in data
        assert data["save1"] == "LOW"
        assert data["save2"] == "HIGH"

    def test_save_to_csv_file(self, temp_dir):
        """Test saving word list to CSV file."""
        word_list = WordList()
        word_list.add_word("csv1", Severity.MEDIUM)
        word_list.add_word("csv2", Severity.EXTREME)
        csv_file = temp_dir / "saved_words.csv"
        word_list.save_to_file(csv_file)
        assert csv_file.exists()
        # Verify content
        content = csv_file.read_text()
        assert "csv1,medium" in content
        assert "csv2,extreme" in content
        assert "word,severity" in content  # Header


class TestWordDetector:
    """Test WordDetector class."""

    def test_initialization_default(self):
        """Test detector initialization with defaults."""
        detector = WordDetector()
        assert detector.word_list is not None
        assert detector.min_confidence == 0.7
        assert detector.check_variations is True
        assert detector.context_window == 5

    def test_initialization_custom(self):
        """Test detector initialization with custom parameters."""
        word_list = WordList()
        detector = WordDetector(
            word_list=word_list,
            min_confidence=0.8,
            check_variations=False,
            context_window=3
        )
        assert detector.word_list == word_list
        assert detector.min_confidence == 0.8
        assert detector.check_variations is False
        assert detector.context_window == 3

    def test_detect_direct_match(self):
        """Test detecting direct word matches."""
        word_list = WordList()
        word_list.add_word("badword", Severity.HIGH)
        detector = WordDetector(word_list=word_list)
        # Mock transcription result
        mock_word = Mock()
        mock_word.text = "badword"
        mock_word.start = 5.0
        mock_word.end = 6.0
        mock_transcription = Mock()
        mock_transcription.words = [mock_word]
        detected = detector.detect(mock_transcription, include_context=False)
        assert len(detected) == 1
        assert detected[0].word == "badword"
        assert detected[0].original == "badword"
        assert detected[0].start == 5.0
        assert detected[0].end == 6.0
        assert detected[0].severity == Severity.HIGH
        assert detected[0].confidence == 1.0

    def test_detect_case_insensitive(self):
        """Test case-insensitive detection."""
        word_list = WordList()
        word_list.add_word("badword", Severity.MEDIUM)
        detector = WordDetector(word_list=word_list)
        # Mock transcription with uppercase word
        mock_word = Mock()
        mock_word.text = "BADWORD"
        mock_word.start = 2.0
        mock_word.end = 3.0
        mock_transcription = Mock()
        mock_transcription.words = [mock_word]
        detected = detector.detect(mock_transcription, include_context=False)
        assert len(detected) == 1
        assert detected[0].word == "badword"  # Normalized
        assert detected[0].original == "BADWORD"  # Original preserved
    def test_detect_with_context(self):
        """Test detection with context extraction."""
        word_list = WordList()
        word_list.add_word("explicit", Severity.MEDIUM)
        detector = WordDetector(word_list=word_list, context_window=2)
        # Mock transcription with multiple words
        words = []
        word_texts = ["this", "is", "explicit", "content", "here"]
        for i, text in enumerate(word_texts):
            word = Mock()
            word.text = text
            word.start = float(i)
            word.end = float(i + 1)
            words.append(word)
        mock_transcription = Mock()
        mock_transcription.words = words
        detected = detector.detect(mock_transcription, include_context=True)
        assert len(detected) == 1
        assert detected[0].word == "explicit"
        assert detected[0].context == "this is [explicit] content here"

    def test_detect_variations(self):
        """Test detection of word variations."""
        word_list = WordList()
        word_list.add_word("test", Severity.LOW)
        # This should create the "tests" variation
        detector = WordDetector(word_list=word_list, check_variations=True)
        # Mock transcription with variation
        mock_word = Mock()
        mock_word.text = "tests"
        mock_word.start = 1.0
        mock_word.end = 2.0
        mock_transcription = Mock()
        mock_transcription.words = [mock_word]
        detected = detector.detect(mock_transcription, include_context=False)
        assert len(detected) == 1
        assert detected[0].word == "test"  # Base word
        assert detected[0].original == "tests"  # Original variation
        assert detected[0].confidence == 0.95  # Variation confidence

    def test_detect_no_variations(self):
        """Test detection with variations disabled."""
        word_list = WordList()
        word_list.add_word("test", Severity.LOW)
        detector = WordDetector(word_list=word_list, check_variations=False)
        # Mock transcription with a variation that shouldn't match
        mock_word = Mock()
        mock_word.text = "tests"
        mock_word.start = 1.0
        mock_word.end = 2.0
        mock_transcription = Mock()
        mock_transcription.words = [mock_word]
        detected = detector.detect(mock_transcription, include_context=False)
        assert len(detected) == 0
    def test_check_variations_known(self):
        """Test checking known variations."""
        word_list = WordList()
        word_list.add_word("base", Severity.MEDIUM)
        word_list.variations["bases"] = "base"  # Manually add variation
        detector = WordDetector(word_list=word_list)
        match, confidence = detector._check_variations("bases")
        assert match == "bases"
        assert confidence == 0.95

    def test_check_variations_fuzzy(self):
        """Test fuzzy matching for variations."""
        word_list = WordList()
        word_list.add_word("hello", Severity.LOW)
        detector = WordDetector(word_list=word_list, min_confidence=0.8)
        # Test similar word (missing an 'l')
        match, confidence = detector._check_variations("helo")
        if match:  # Fuzzy matching might or might not match depending on similarity
            assert confidence >= 0.8

    def test_get_context_boundary(self):
        """Test context extraction at boundaries."""
        detector = WordDetector(context_window=2)
        # Create mock words
        word_texts = ["a", "b", "target", "d", "e"]
        words = []
        for text in word_texts:
            word = Mock()
            word.text = text
            words.append(word)
        # Test target at beginning
        context = detector._get_context(words, 0)
        assert context == "[a] b target"
        # Test target at end
        context = detector._get_context(words, 4)
        assert context == "target d [e]"
        # Test target in middle
        context = detector._get_context(words, 2)
        assert context == "a b [target] d e"
    def test_filter_by_severity(self):
        """Test filtering detected words by severity."""
        detector = WordDetector()
        # Create detected words with different severities
        detected_words = [
            DetectedWord("low", "low", 1.0, 2.0, Severity.LOW, 1.0),
            DetectedWord("med", "med", 3.0, 4.0, Severity.MEDIUM, 1.0),
            DetectedWord("high", "high", 5.0, 6.0, Severity.HIGH, 1.0),
            DetectedWord("extreme", "extreme", 7.0, 8.0, Severity.EXTREME, 1.0)
        ]
        # Filter by MEDIUM and above
        filtered = detector.filter_by_severity(detected_words, Severity.MEDIUM)
        assert len(filtered) == 3  # MEDIUM, HIGH, EXTREME
        severities = [w.severity for w in filtered]
        assert Severity.LOW not in severities
        assert Severity.MEDIUM in severities
        assert Severity.HIGH in severities
        assert Severity.EXTREME in severities

    def test_get_statistics_empty(self):
        """Test statistics for empty detection results."""
        detector = WordDetector()
        stats = detector.get_statistics([])
        assert stats['total_count'] == 0
        assert stats['unique_words'] == 0
        assert stats['by_severity'] == {}
        assert stats['most_common'] == []

    def test_get_statistics_with_words(self):
        """Test statistics for detection results."""
        detector = WordDetector()
        detected_words = [
            DetectedWord("word1", "word1", 1.0, 2.0, Severity.HIGH, 0.9),
            DetectedWord("word1", "word1", 3.0, 4.0, Severity.HIGH, 0.8),
            DetectedWord("word2", "word2", 5.0, 6.0, Severity.MEDIUM, 0.95),
            DetectedWord("word3", "word3", 7.0, 8.0, Severity.LOW, 1.0)
        ]
        stats = detector.get_statistics(detected_words)
        assert stats['total_count'] == 4
        assert stats['unique_words'] == 3
        assert stats['by_severity']['HIGH'] == 2
        assert stats['by_severity']['MEDIUM'] == 1
        assert stats['by_severity']['LOW'] == 1
        assert stats['most_common'][0] == ('word1', 2)  # Most frequent
        # Avoid strict float equality when comparing the mean confidence
        assert stats['average_confidence'] == pytest.approx((0.9 + 0.8 + 0.95 + 1.0) / 4)


if __name__ == '__main__':
    pytest.main([__file__, '-v'])