# File info: 591 lines, 20 KiB, Python
"""
|
|
Unit tests for transcription module.
|
|
"""
|
|
|
|
import pytest
|
|
import json
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
|
|
from src.core.transcription import (
|
|
WhisperModel,
|
|
Word,
|
|
TranscriptionSegment,
|
|
TranscriptionResult,
|
|
WhisperTranscriber
|
|
)
|
|
|
|
|
|
class TestWhisperModel:
    """Tests for the WhisperModel enum and its derived properties."""

    def test_model_values(self):
        """Every member maps to its expected Whisper model name."""
        expected_names = (
            (WhisperModel.TINY, "tiny"),
            (WhisperModel.BASE, "base"),
            (WhisperModel.SMALL, "small"),
            (WhisperModel.MEDIUM, "medium"),
            (WhisperModel.LARGE, "large"),
            (WhisperModel.LARGE_V2, "large-v2"),
            (WhisperModel.LARGE_V3, "large-v3"),
        )
        for member, name in expected_names:
            assert member.value == name

    def test_parameters_property(self):
        """Every member reports its expected parameter-count label."""
        expected_params = (
            (WhisperModel.TINY, "39M"),
            (WhisperModel.BASE, "74M"),
            (WhisperModel.SMALL, "244M"),
            (WhisperModel.MEDIUM, "769M"),
            (WhisperModel.LARGE, "1550M"),
            (WhisperModel.LARGE_V2, "1550M"),
            (WhisperModel.LARGE_V3, "1550M"),
        )
        for member, label in expected_params:
            assert member.parameters == label

    def test_relative_speed_property(self):
        """Every member reports its expected relative_speed value."""
        expected_speeds = (
            (WhisperModel.TINY, 1),
            (WhisperModel.BASE, 2),
            (WhisperModel.SMALL, 3),
            (WhisperModel.MEDIUM, 5),
            (WhisperModel.LARGE, 8),
            (WhisperModel.LARGE_V2, 8),
            (WhisperModel.LARGE_V3, 8),
        )
        for member, speed in expected_speeds:
            assert member.relative_speed == speed

    def test_speed_ordering(self):
        """relative_speed is non-decreasing from TINY up to LARGE."""
        by_size = (
            WhisperModel.TINY,
            WhisperModel.BASE,
            WhisperModel.SMALL,
            WhisperModel.MEDIUM,
            WhisperModel.LARGE,
        )
        speeds = [member.relative_speed for member in by_size]
        # Each neighbouring pair must be in order, i.e. the list is sorted.
        assert all(lower <= higher for lower, higher in zip(speeds, speeds[1:]))
|
|
|
|
|
|
class TestWord:
    """Tests for the Word dataclass."""

    def test_word_creation(self):
        """The constructor stores text, timing and confidence unchanged."""
        word = Word(
            text="hello",
            start=1.5,
            end=2.0,
            confidence=0.95,
        )

        assert word.text == "hello"
        assert word.start == 1.5
        assert word.end == 2.0
        assert word.confidence == 0.95

    def test_word_duration(self):
        """duration equals end - start."""
        word = Word("test", 2.0, 3.5, 0.9)
        # duration is derived by float arithmetic; compare with a tolerance.
        assert word.duration == pytest.approx(1.5)

    def test_word_to_dict(self):
        """to_dict exposes the stored fields plus the derived duration."""
        word = Word("world", 0.5, 1.2, 0.88)
        data = word.to_dict()

        # Stored values round-trip exactly.
        assert data['text'] == "world"
        assert data['start'] == 0.5
        assert data['end'] == 1.2
        assert data['confidence'] == 0.88
        # duration is computed (1.2 - 0.5), so exact equality with the
        # literal 0.7 is brittle; compare with a tolerance instead.
        assert data['duration'] == pytest.approx(0.7)

    def test_word_default_confidence(self):
        """confidence defaults to 1.0 when omitted."""
        word = Word("test", 0.0, 1.0)
        assert word.confidence == 1.0
|
|
|
|
|
|
class TestTranscriptionSegment:
    """Tests for the TranscriptionSegment dataclass."""

    def test_segment_creation(self):
        """The constructor stores id, text, timing and the given words."""
        seg = TranscriptionSegment(
            id=0,
            text="hello world",
            start=0.0,
            end=1.0,
            words=[
                Word("hello", 0.0, 0.5, 0.9),
                Word("world", 0.5, 1.0, 0.95),
            ],
        )

        assert seg.id == 0
        assert seg.text == "hello world"
        assert seg.start == 0.0
        assert seg.end == 1.0
        assert len(seg.words) == 2

    def test_segment_duration(self):
        """duration equals end - start."""
        seg = TranscriptionSegment(id=1, text="test segment", start=2.5, end=5.0)

        assert seg.duration == 2.5

    def test_segment_to_dict(self):
        """to_dict includes scalar fields, duration and serialized words."""
        seg = TranscriptionSegment(
            id=2,
            text="test",
            start=0.0,
            end=1.0,
            words=[Word("test", 0.0, 1.0, 0.8)],
        )

        payload = seg.to_dict()

        expected_fields = (
            ('id', 2),
            ('text', "test"),
            ('start', 0.0),
            ('end', 1.0),
            ('duration', 1.0),
        )
        for key, expected in expected_fields:
            assert payload[key] == expected
        assert len(payload['words']) == 1
        assert payload['words'][0]['text'] == "test"

    def test_segment_empty_words(self):
        """Omitting words yields an empty word list, also in to_dict."""
        seg = TranscriptionSegment(id=0, text="empty", start=0.0, end=1.0)

        assert len(seg.words) == 0
        assert seg.to_dict()['words'] == []
|
|
|
|
|
|
class TestTranscriptionResult:
    """Tests for the TranscriptionResult dataclass."""

    def test_result_creation(self):
        """The constructor stores every field it is given."""
        res = TranscriptionResult(
            text="hello world how are you",
            segments=[
                TranscriptionSegment(0, "hello world", 0.0, 2.0),
                TranscriptionSegment(1, "how are you", 2.0, 4.0),
            ],
            language="en",
            duration=4.0,
            model_used="base",
            processing_time=1.5,
        )

        assert res.text == "hello world how are you"
        assert len(res.segments) == 2
        assert res.language == "en"
        assert res.duration == 4.0
        assert res.model_used == "base"
        assert res.processing_time == 1.5

    def test_word_count_property(self):
        """word_count totals the words across all segments (2 + 3 here)."""
        first_words = [Word("hello", 0.0, 0.5), Word("world", 0.5, 1.0)]
        second_words = [
            Word("how", 1.0, 1.3),
            Word("are", 1.3, 1.6),
            Word("you", 1.6, 2.0),
        ]

        res = TranscriptionResult(
            text="hello world how are you",
            segments=[
                TranscriptionSegment(0, "hello world", 0.0, 1.0, first_words),
                TranscriptionSegment(1, "how are you", 1.0, 2.0, second_words),
            ],
            language="en",
            duration=2.0,
            model_used="base",
        )

        assert res.word_count == 5

    def test_words_property(self):
        """words flattens all segments' words, preserving order."""
        first_words = [Word("hello", 0.0, 0.5), Word("world", 0.5, 1.0)]
        second_words = [Word("test", 1.0, 1.5)]

        res = TranscriptionResult(
            text="hello world test",
            segments=[
                TranscriptionSegment(0, "hello world", 0.0, 1.0, first_words),
                TranscriptionSegment(1, "test", 1.0, 1.5, second_words),
            ],
            language="en",
            duration=1.5,
            model_used="base",
        )

        flattened = res.words
        assert len(flattened) == 3
        assert [w.text for w in flattened] == ["hello", "world", "test"]

    def test_to_dict(self):
        """to_dict exposes the result fields plus the derived word_count."""
        res = TranscriptionResult(
            text="hello",
            segments=[
                TranscriptionSegment(0, "hello", 0.0, 0.5, [Word("hello", 0.0, 0.5)]),
            ],
            language="en",
            duration=0.5,
            model_used="tiny",
            processing_time=0.8,
        )

        payload = res.to_dict()

        assert payload['text'] == "hello"
        assert len(payload['segments']) == 1
        expected_fields = (
            ('language', "en"),
            ('duration', 0.5),
            ('model_used', "tiny"),
            ('processing_time', 0.8),
            ('word_count', 1),
        )
        for key, expected in expected_fields:
            assert payload[key] == expected

    def test_to_json(self):
        """to_json produces a JSON string that decodes to the result fields."""
        res = TranscriptionResult(
            text="test",
            segments=[],
            language="en",
            duration=1.0,
            model_used="base",
        )

        decoded = json.loads(res.to_json())

        assert decoded['text'] == "test"
        assert decoded['language'] == "en"
        assert decoded['model_used'] == "base"

    def test_save_to_file(self, temp_dir):
        """save_to_file writes a JSON file that reloads to the same fields."""
        res = TranscriptionResult(
            text="save test",
            segments=[],
            language="en",
            duration=2.0,
            model_used="small",
        )

        target = temp_dir / "transcription.json"
        res.save_to_file(target)

        assert target.exists()

        # Read the file back and check the round-tripped fields.
        with open(target) as handle:
            saved = json.load(handle)

        assert saved['text'] == "save test"
        assert saved['language'] == "en"
        assert saved['model_used'] == "small"
|
|
|
|
|
|
class TestWhisperTranscriber:
    """Tests for the WhisperTranscriber class.

    ``src.core.transcription.torch`` is patched in every test so device
    detection is controlled. Tests that construct a transcriber also patch
    either ``src.core.transcription.whisper`` or
    ``WhisperTranscriber._load_model`` so no real model is loaded.
    """

    @staticmethod
    def _configure_devices(mock_torch, *, cuda=False, mps=False):
        """Set the CUDA / MPS availability reported by the patched ``torch``."""
        mock_torch.cuda.is_available.return_value = cuda
        mock_torch.backends.mps.is_available.return_value = mps

    @patch('src.core.transcription.torch')
    def test_initialization_default(self, mock_torch):
        """Defaults: BASE model, CPU when no accelerator, in_memory=True."""
        self._configure_devices(mock_torch)

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()

        assert transcriber.model_size == WhisperModel.BASE
        assert transcriber.device == "cpu"
        assert transcriber.in_memory is True

    @patch('src.core.transcription.torch')
    def test_initialization_custom(self, mock_torch):
        """Explicit constructor arguments are stored as given."""
        self._configure_devices(mock_torch)

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(
                model_size=WhisperModel.SMALL,
                device="cpu",
                in_memory=False,
            )

        assert transcriber.model_size == WhisperModel.SMALL
        assert transcriber.device == "cpu"
        assert transcriber.in_memory is False

    @patch('src.core.transcription.torch')
    def test_device_detection_cuda(self, mock_torch):
        """CUDA is selected when torch reports it available."""
        self._configure_devices(mock_torch, cuda=True)

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()

        assert transcriber.device == "cuda"

    @patch('src.core.transcription.torch')
    def test_device_detection_mps(self, mock_torch):
        """MPS (Apple Silicon) is selected when only MPS is available."""
        self._configure_devices(mock_torch, mps=True)

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()

        assert transcriber.device == "mps"

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_load_model(self, mock_whisper, mock_torch):
        """_load_model calls whisper.load_model and stores the returned model."""
        self._configure_devices(mock_torch)

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model

        transcriber = WhisperTranscriber(in_memory=False)
        transcriber._load_model()

        # The exact object returned by whisper.load_model must be kept.
        assert transcriber.model is mock_model
        mock_whisper.load_model.assert_called_once_with(
            "base",
            device="cpu",
            download_root=None,
        )

    @patch('src.core.transcription.torch')
    def test_unload_model(self, mock_torch):
        """_unload_model clears the model reference."""
        self._configure_devices(mock_torch)

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(in_memory=False)

        # Simulate a loaded model.
        transcriber.model = Mock()

        transcriber._unload_model()

        assert transcriber.model is None

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_file_path(self, mock_whisper, mock_torch):
        """Transcribing a file path returns a populated TranscriptionResult."""
        self._configure_devices(mock_torch)

        # Mock Whisper model and its raw output.
        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_model.transcribe.return_value = {
            'text': 'hello world',
            'language': 'en',
            'segments': [
                {
                    'id': 0,
                    'text': 'hello world',
                    'start': 0.0,
                    'end': 2.0,
                    'words': [
                        {'word': 'hello', 'start': 0.0, 'end': 1.0, 'probability': 0.9},
                        {'word': 'world', 'start': 1.0, 'end': 2.0, 'probability': 0.95},
                    ],
                },
            ],
        }

        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("/path/to/audio.mp3")

        assert isinstance(result, TranscriptionResult)
        assert result.text == 'hello world'
        assert result.language == 'en'
        assert result.model_used == 'base'
        assert len(result.segments) == 1
        assert result.word_count == 2
        # The mocked transcription is near-instant, so with a coarse clock the
        # measured time can round to 0.0; ">= 0" avoids a flaky failure while
        # still checking that a processing time was recorded.
        assert result.processing_time is not None
        assert result.processing_time >= 0

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_array(self, mock_whisper, mock_torch):
        """Transcribing a numpy array returns the model's text and language."""
        self._configure_devices(mock_torch)

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_model.transcribe.return_value = {
            'text': 'test transcription',
            'language': 'en',
            'segments': [],
        }

        transcriber = WhisperTranscriber(in_memory=False)
        # 1 second at 16 kHz; seeded so the test input is reproducible.
        audio_array = np.random.default_rng(0).standard_normal(16000)

        result = transcriber.transcribe(audio_array)

        assert result.text == 'test transcription'
        assert result.language == 'en'

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_with_language(self, mock_whisper, mock_torch):
        """An explicit language is forwarded to Whisper as a keyword argument."""
        self._configure_devices(mock_torch)

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_model.transcribe.return_value = {
            'text': 'bonjour monde',
            'language': 'fr',
            'segments': [],
        }

        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("audio.mp3", language="fr")

        # Verify the language was passed to Whisper.
        _, call_kwargs = mock_model.transcribe.call_args
        assert call_kwargs['language'] == "fr"
        assert result.language == 'fr'

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_translate_task(self, mock_whisper, mock_torch):
        """task='translate' is forwarded to Whisper as a keyword argument."""
        self._configure_devices(mock_torch)

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_model.transcribe.return_value = {
            'text': 'hello world',  # Translated to English
            'language': 'fr',
            'segments': [],
        }

        transcriber = WhisperTranscriber(in_memory=False)
        transcriber.transcribe("audio.mp3", task="translate")

        # Verify the task was passed to Whisper.
        _, call_kwargs = mock_model.transcribe.call_args
        assert call_kwargs['task'] == "translate"

    @patch('src.core.transcription.torch')
    def test_process_results(self, mock_torch):
        """_process_results strips whitespace and builds segments and words."""
        self._configure_devices(mock_torch)

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(in_memory=False)

        raw_result = {
            'text': ' hello world ',
            'language': 'en',
            'segments': [
                {
                    'id': 0,
                    'text': ' hello world ',
                    'start': 0.0,
                    'end': 2.0,
                    'words': [
                        {'word': ' hello', 'start': 0.0, 'end': 1.0, 'probability': 0.9},
                        {'word': ' world', 'start': 1.0, 'end': 2.0, 'probability': 0.95},
                    ],
                },
            ],
        }

        result = transcriber._process_results(raw_result)

        assert result.text == 'hello world'  # Stripped
        assert result.language == 'en'
        assert result.duration == 2.0
        assert len(result.segments) == 1

        segment = result.segments[0]
        assert segment.text == 'hello world'  # Stripped
        assert len(segment.words) == 2
        assert segment.words[0].text == 'hello'  # Stripped
        assert segment.words[1].text == 'world'  # Stripped

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_with_chunks(self, mock_whisper, mock_torch):
        """Chunked transcription of long audio yields segments from several chunks.

        ``whisper`` is patched (like the other tests that build a transcriber
        without patching ``_load_model``) so no real model can be loaded.
        """
        self._configure_devices(mock_torch)

        transcriber = WhisperTranscriber(in_memory=False)

        def fake_transcribe(audio_data, **kwargs):
            """Stand-in per-chunk transcription returning a fixed 2-word result."""
            return TranscriptionResult(
                text="chunk text",
                segments=[
                    TranscriptionSegment(0, "chunk text", 0.0, 2.0, [
                        Word("chunk", 0.0, 1.0, 0.9),
                        Word("text", 1.0, 2.0, 0.95),
                    ]),
                ],
                language="en",
                duration=2.0,
                model_used="base",
            )

        # Shadow the regular transcribe method on this instance only.
        transcriber.transcribe = fake_transcribe

        # 70 s of audio is more than two 30 s chunks' worth, so chunking must
        # produce multiple chunks. Seeded for reproducibility.
        sample_rate = 16000
        audio_data = np.random.default_rng(0).standard_normal(70 * sample_rate)

        result = transcriber.transcribe_with_chunks(
            audio_data,
            sample_rate,
            chunk_duration=30,
            overlap=2,
        )

        assert isinstance(result, TranscriptionResult)
        assert result.model_used == "base"
        assert len(result.segments) >= 2  # Should have multiple chunks
|
|
|
|
|
|
if __name__ == '__main__':
    # Propagate pytest's exit status so running this file directly (e.g. in
    # CI) reports failures via a non-zero exit code instead of always 0.
    raise SystemExit(pytest.main([__file__, '-v']))