# clean-tracks/tests/unit/test_transcription.py
"""
Unit tests for transcription module.
"""
import pytest
import json
import numpy as np
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from src.core.transcription import (
    WhisperModel,
    Word,
    TranscriptionSegment,
    TranscriptionResult,
    WhisperTranscriber
)
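
# NOTE: the `temp_dir` fixture used in TestTranscriptionResult.test_save_to_file
# is assumed to be provided by a shared conftest.py rather than this file. A
# minimal sketch of such a fixture (an assumption about the project setup, not
# its actual conftest) could be:
#
#     @pytest.fixture
#     def temp_dir(tmp_path):
#         """Expose pytest's built-in tmp_path as the temp_dir fixture."""
#         return tmp_path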


class TestWhisperModel:
    """Test WhisperModel enum."""

    def test_model_values(self):
        """Test model enum values."""
        assert WhisperModel.TINY.value == "tiny"
        assert WhisperModel.BASE.value == "base"
        assert WhisperModel.SMALL.value == "small"
        assert WhisperModel.MEDIUM.value == "medium"
        assert WhisperModel.LARGE.value == "large"
        assert WhisperModel.LARGE_V2.value == "large-v2"
        assert WhisperModel.LARGE_V3.value == "large-v3"

    def test_parameters_property(self):
        """Test parameters property."""
        assert WhisperModel.TINY.parameters == "39M"
        assert WhisperModel.BASE.parameters == "74M"
        assert WhisperModel.SMALL.parameters == "244M"
        assert WhisperModel.MEDIUM.parameters == "769M"
        assert WhisperModel.LARGE.parameters == "1550M"
        assert WhisperModel.LARGE_V2.parameters == "1550M"
        assert WhisperModel.LARGE_V3.parameters == "1550M"

    def test_relative_speed_property(self):
        """Test relative speed property."""
        assert WhisperModel.TINY.relative_speed == 1
        assert WhisperModel.BASE.relative_speed == 2
        assert WhisperModel.SMALL.relative_speed == 3
        assert WhisperModel.MEDIUM.relative_speed == 5
        assert WhisperModel.LARGE.relative_speed == 8
        assert WhisperModel.LARGE_V2.relative_speed == 8
        assert WhisperModel.LARGE_V3.relative_speed == 8

    def test_speed_ordering(self):
        """Test that speed increases with model size."""
        speeds = [model.relative_speed for model in [
            WhisperModel.TINY, WhisperModel.BASE, WhisperModel.SMALL,
            WhisperModel.MEDIUM, WhisperModel.LARGE
        ]]
        assert speeds == sorted(speeds)


class TestWord:
    """Test Word dataclass."""

    def test_word_creation(self):
        """Test creating a Word."""
        word = Word(
            text="hello",
            start=1.5,
            end=2.0,
            confidence=0.95
        )
        assert word.text == "hello"
        assert word.start == 1.5
        assert word.end == 2.0
        assert word.confidence == 0.95

    def test_word_duration(self):
        """Test duration calculation."""
        word = Word("test", 2.0, 3.5, 0.9)
        assert word.duration == 1.5

    def test_word_to_dict(self):
        """Test converting word to dictionary."""
        word = Word("world", 0.5, 1.2, 0.88)
        data = word.to_dict()
        assert data['text'] == "world"
        assert data['start'] == 0.5
        assert data['end'] == 1.2
        assert data['confidence'] == 0.88
        assert data['duration'] == 0.7

    def test_word_default_confidence(self):
        """Test default confidence value."""
        word = Word("test", 0.0, 1.0)
        assert word.confidence == 1.0


class TestTranscriptionSegment:
    """Test TranscriptionSegment dataclass."""

    def test_segment_creation(self):
        """Test creating a TranscriptionSegment."""
        words = [
            Word("hello", 0.0, 0.5, 0.9),
            Word("world", 0.5, 1.0, 0.95)
        ]
        segment = TranscriptionSegment(
            id=0,
            text="hello world",
            start=0.0,
            end=1.0,
            words=words
        )
        assert segment.id == 0
        assert segment.text == "hello world"
        assert segment.start == 0.0
        assert segment.end == 1.0
        assert len(segment.words) == 2

    def test_segment_duration(self):
        """Test segment duration calculation."""
        segment = TranscriptionSegment(
            id=1,
            text="test segment",
            start=2.5,
            end=5.0
        )
        assert segment.duration == 2.5

    def test_segment_to_dict(self):
        """Test converting segment to dictionary."""
        words = [Word("test", 0.0, 1.0, 0.8)]
        segment = TranscriptionSegment(
            id=2,
            text="test",
            start=0.0,
            end=1.0,
            words=words
        )
        data = segment.to_dict()
        assert data['id'] == 2
        assert data['text'] == "test"
        assert data['start'] == 0.0
        assert data['end'] == 1.0
        assert data['duration'] == 1.0
        assert len(data['words']) == 1
        assert data['words'][0]['text'] == "test"

    def test_segment_empty_words(self):
        """Test segment with no words."""
        segment = TranscriptionSegment(
            id=0,
            text="empty",
            start=0.0,
            end=1.0
        )
        assert len(segment.words) == 0
        assert segment.to_dict()['words'] == []


class TestTranscriptionResult:
    """Test TranscriptionResult dataclass."""

    def test_result_creation(self):
        """Test creating a TranscriptionResult."""
        segments = [
            TranscriptionSegment(0, "hello world", 0.0, 2.0),
            TranscriptionSegment(1, "how are you", 2.0, 4.0)
        ]
        result = TranscriptionResult(
            text="hello world how are you",
            segments=segments,
            language="en",
            duration=4.0,
            model_used="base",
            processing_time=1.5
        )
        assert result.text == "hello world how are you"
        assert len(result.segments) == 2
        assert result.language == "en"
        assert result.duration == 4.0
        assert result.model_used == "base"
        assert result.processing_time == 1.5

    def test_word_count_property(self):
        """Test word count calculation."""
        words1 = [Word("hello", 0.0, 0.5), Word("world", 0.5, 1.0)]
        words2 = [Word("how", 1.0, 1.3), Word("are", 1.3, 1.6), Word("you", 1.6, 2.0)]
        segments = [
            TranscriptionSegment(0, "hello world", 0.0, 1.0, words1),
            TranscriptionSegment(1, "how are you", 1.0, 2.0, words2)
        ]
        result = TranscriptionResult(
            text="hello world how are you",
            segments=segments,
            language="en",
            duration=2.0,
            model_used="base"
        )
        assert result.word_count == 5

    def test_words_property(self):
        """Test getting all words from all segments."""
        words1 = [Word("hello", 0.0, 0.5), Word("world", 0.5, 1.0)]
        words2 = [Word("test", 1.0, 1.5)]
        segments = [
            TranscriptionSegment(0, "hello world", 0.0, 1.0, words1),
            TranscriptionSegment(1, "test", 1.0, 1.5, words2)
        ]
        result = TranscriptionResult(
            text="hello world test",
            segments=segments,
            language="en",
            duration=1.5,
            model_used="base"
        )
        all_words = result.words
        assert len(all_words) == 3
        assert all_words[0].text == "hello"
        assert all_words[1].text == "world"
        assert all_words[2].text == "test"

    def test_to_dict(self):
        """Test converting result to dictionary."""
        words = [Word("hello", 0.0, 0.5)]
        segments = [TranscriptionSegment(0, "hello", 0.0, 0.5, words)]
        result = TranscriptionResult(
            text="hello",
            segments=segments,
            language="en",
            duration=0.5,
            model_used="tiny",
            processing_time=0.8
        )
        data = result.to_dict()
        assert data['text'] == "hello"
        assert len(data['segments']) == 1
        assert data['language'] == "en"
        assert data['duration'] == 0.5
        assert data['model_used'] == "tiny"
        assert data['processing_time'] == 0.8
        assert data['word_count'] == 1

    def test_to_json(self):
        """Test converting result to JSON."""
        result = TranscriptionResult(
            text="test",
            segments=[],
            language="en",
            duration=1.0,
            model_used="base"
        )
        json_str = result.to_json()
        data = json.loads(json_str)
        assert data['text'] == "test"
        assert data['language'] == "en"
        assert data['model_used'] == "base"

    def test_save_to_file(self, temp_dir):
        """Test saving result to file."""
        result = TranscriptionResult(
            text="save test",
            segments=[],
            language="en",
            duration=2.0,
            model_used="small"
        )
        file_path = temp_dir / "transcription.json"
        result.save_to_file(file_path)
        assert file_path.exists()
        # Load and verify
        with open(file_path, 'r') as f:
            data = json.load(f)
        assert data['text'] == "save test"
        assert data['language'] == "en"
        assert data['model_used'] == "small"


class TestWhisperTranscriber:
    """Test WhisperTranscriber class."""

    @patch('src.core.transcription.torch')
    def test_initialization_default(self, mock_torch):
        """Test transcriber initialization with defaults."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()
            assert transcriber.model_size == WhisperModel.BASE
            assert transcriber.device == "cpu"
            assert transcriber.in_memory is True

    @patch('src.core.transcription.torch')
    def test_initialization_custom(self, mock_torch):
        """Test transcriber initialization with custom parameters."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(
                model_size=WhisperModel.SMALL,
                device="cpu",
                in_memory=False
            )
            assert transcriber.model_size == WhisperModel.SMALL
            assert transcriber.device == "cpu"
            assert transcriber.in_memory is False

    @patch('src.core.transcription.torch')
    def test_device_detection_cuda(self, mock_torch):
        """Test CUDA device detection."""
        mock_torch.cuda.is_available.return_value = True
        mock_torch.backends.mps.is_available.return_value = False
        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()
            assert transcriber.device == "cuda"

    @patch('src.core.transcription.torch')
    def test_device_detection_mps(self, mock_torch):
        """Test MPS (Apple Silicon) device detection."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = True
        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()
            assert transcriber.device == "mps"

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_load_model(self, mock_whisper, mock_torch):
        """Test loading Whisper model."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        transcriber = WhisperTranscriber(in_memory=False)
        transcriber._load_model()
        assert transcriber.model == mock_model
        mock_whisper.load_model.assert_called_once_with(
            "base",
            device="cpu",
            download_root=None
        )

    @patch('src.core.transcription.torch')
    def test_unload_model(self, mock_torch):
        """Test unloading model from memory."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(in_memory=False)
        # Simulate loaded model
        transcriber.model = Mock()
        transcriber._unload_model()
        assert transcriber.model is None

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_file_path(self, mock_whisper, mock_torch):
        """Test transcribing from file path."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        # Mock Whisper model and results
        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'hello world',
            'language': 'en',
            'segments': [
                {
                    'id': 0,
                    'text': 'hello world',
                    'start': 0.0,
                    'end': 2.0,
                    'words': [
                        {'word': 'hello', 'start': 0.0, 'end': 1.0, 'probability': 0.9},
                        {'word': 'world', 'start': 1.0, 'end': 2.0, 'probability': 0.95}
                    ]
                }
            ]
        }
        mock_model.transcribe.return_value = mock_result
        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("/path/to/audio.mp3")
        assert isinstance(result, TranscriptionResult)
        assert result.text == 'hello world'
        assert result.language == 'en'
        assert result.model_used == 'base'
        assert len(result.segments) == 1
        assert result.word_count == 2
        assert result.processing_time > 0

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_array(self, mock_whisper, mock_torch):
        """Test transcribing from numpy array."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'test transcription',
            'language': 'en',
            'segments': []
        }
        mock_model.transcribe.return_value = mock_result
        transcriber = WhisperTranscriber(in_memory=False)
        audio_array = np.random.randn(16000)  # 1 second at 16kHz
        result = transcriber.transcribe(audio_array)
        assert result.text == 'test transcription'
        assert result.language == 'en'

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_with_language(self, mock_whisper, mock_torch):
        """Test transcribing with specified language."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'bonjour monde',
            'language': 'fr',
            'segments': []
        }
        mock_model.transcribe.return_value = mock_result
        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("audio.mp3", language="fr")
        # Verify language was passed to Whisper
        call_args = mock_model.transcribe.call_args
        assert call_args[1]['language'] == "fr"
        assert result.language == 'fr'

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_translate_task(self, mock_whisper, mock_torch):
        """Test transcribing with the translate task."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'hello world',  # Translated to English
            'language': 'fr',
            'segments': []
        }
        mock_model.transcribe.return_value = mock_result
        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("audio.mp3", task="translate")
        # Verify task was passed to Whisper
        call_args = mock_model.transcribe.call_args
        assert call_args[1]['task'] == "translate"
        assert result.text == 'hello world'

    @patch('src.core.transcription.torch')
    def test_process_results(self, mock_torch):
        """Test processing Whisper results."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(in_memory=False)
        raw_result = {
            'text': ' hello world ',
            'language': 'en',
            'segments': [
                {
                    'id': 0,
                    'text': ' hello world ',
                    'start': 0.0,
                    'end': 2.0,
                    'words': [
                        {'word': ' hello', 'start': 0.0, 'end': 1.0, 'probability': 0.9},
                        {'word': ' world', 'start': 1.0, 'end': 2.0, 'probability': 0.95}
                    ]
                }
            ]
        }
        result = transcriber._process_results(raw_result)
        assert result.text == 'hello world'  # Stripped
        assert result.language == 'en'
        assert result.duration == 2.0
        assert len(result.segments) == 1
        segment = result.segments[0]
        assert segment.text == 'hello world'  # Stripped
        assert len(segment.words) == 2
        assert segment.words[0].text == 'hello'  # Stripped
        assert segment.words[1].text == 'world'  # Stripped

    @patch('src.core.transcription.torch')
    def test_transcribe_with_chunks(self, mock_torch):
        """Test chunked transcription."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False
        transcriber = WhisperTranscriber(in_memory=False)

        # Mock the regular transcribe method so each chunk returns a fixed result
        def mock_transcribe(audio_data, **kwargs):
            return TranscriptionResult(
                text="chunk text",
                segments=[
                    TranscriptionSegment(0, "chunk text", 0.0, 2.0, [
                        Word("chunk", 0.0, 1.0, 0.9),
                        Word("text", 1.0, 2.0, 0.95)
                    ])
                ],
                language="en",
                duration=2.0,
                model_used="base"
            )
        transcriber.transcribe = mock_transcribe

        # Create test audio long enough for multiple 30-second chunks
        sample_rate = 16000
        audio_data = np.random.randn(70 * sample_rate)  # 70 seconds
        result = transcriber.transcribe_with_chunks(
            audio_data,
            sample_rate,
            chunk_duration=30,
            overlap=2
        )
        assert isinstance(result, TranscriptionResult)
        assert result.model_used == "base"
        assert len(result.segments) >= 2  # Should have multiple chunks


if __name__ == '__main__':
    pytest.main([__file__, '-v'])
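
# Usage note: assuming the project layout implied by the path at the top of
# this file (clean-tracks/tests/unit/test_transcription.py), the suite can
# also be run from the project root with pytest directly, e.g.:
#
#     pytest tests/unit/test_transcription.py -v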