""" Unit tests for transcription module. """ import pytest import json import numpy as np from pathlib import Path from unittest.mock import Mock, patch, MagicMock from src.core.transcription import ( WhisperModel, Word, TranscriptionSegment, TranscriptionResult, WhisperTranscriber ) class TestWhisperModel: """Test WhisperModel enum.""" def test_model_values(self): """Test model enum values.""" assert WhisperModel.TINY.value == "tiny" assert WhisperModel.BASE.value == "base" assert WhisperModel.SMALL.value == "small" assert WhisperModel.MEDIUM.value == "medium" assert WhisperModel.LARGE.value == "large" assert WhisperModel.LARGE_V2.value == "large-v2" assert WhisperModel.LARGE_V3.value == "large-v3" def test_parameters_property(self): """Test parameters property.""" assert WhisperModel.TINY.parameters == "39M" assert WhisperModel.BASE.parameters == "74M" assert WhisperModel.SMALL.parameters == "244M" assert WhisperModel.MEDIUM.parameters == "769M" assert WhisperModel.LARGE.parameters == "1550M" assert WhisperModel.LARGE_V2.parameters == "1550M" assert WhisperModel.LARGE_V3.parameters == "1550M" def test_relative_speed_property(self): """Test relative speed property.""" assert WhisperModel.TINY.relative_speed == 1 assert WhisperModel.BASE.relative_speed == 2 assert WhisperModel.SMALL.relative_speed == 3 assert WhisperModel.MEDIUM.relative_speed == 5 assert WhisperModel.LARGE.relative_speed == 8 assert WhisperModel.LARGE_V2.relative_speed == 8 assert WhisperModel.LARGE_V3.relative_speed == 8 def test_speed_ordering(self): """Test that speed increases with model size.""" speeds = [model.relative_speed for model in [ WhisperModel.TINY, WhisperModel.BASE, WhisperModel.SMALL, WhisperModel.MEDIUM, WhisperModel.LARGE ]] assert speeds == sorted(speeds) class TestWord: """Test Word dataclass.""" def test_word_creation(self): """Test creating a Word.""" word = Word( text="hello", start=1.5, end=2.0, confidence=0.95 ) assert word.text == "hello" assert word.start == 1.5 assert word.end == 2.0 assert word.confidence == 0.95 def test_word_duration(self): """Test duration calculation.""" word = Word("test", 2.0, 3.5, 0.9) assert word.duration == 1.5 def test_word_to_dict(self): """Test converting word to dictionary.""" word = Word("world", 0.5, 1.2, 0.88) data = word.to_dict() assert data['text'] == "world" assert data['start'] == 0.5 assert data['end'] == 1.2 assert data['confidence'] == 0.88 assert data['duration'] == 0.7 def test_word_default_confidence(self): """Test default confidence value.""" word = Word("test", 0.0, 1.0) assert word.confidence == 1.0 class TestTranscriptionSegment: """Test TranscriptionSegment dataclass.""" def test_segment_creation(self): """Test creating a TranscriptionSegment.""" words = [ Word("hello", 0.0, 0.5, 0.9), Word("world", 0.5, 1.0, 0.95) ] segment = TranscriptionSegment( id=0, text="hello world", start=0.0, end=1.0, words=words ) assert segment.id == 0 assert segment.text == "hello world" assert segment.start == 0.0 assert segment.end == 1.0 assert len(segment.words) == 2 def test_segment_duration(self): """Test segment duration calculation.""" segment = TranscriptionSegment( id=1, text="test segment", start=2.5, end=5.0 ) assert segment.duration == 2.5 def test_segment_to_dict(self): """Test converting segment to dictionary.""" words = [Word("test", 0.0, 1.0, 0.8)] segment = TranscriptionSegment( id=2, text="test", start=0.0, end=1.0, words=words ) data = segment.to_dict() assert data['id'] == 2 assert data['text'] == "test" assert data['start'] == 0.0 assert 

    def test_segment_empty_words(self):
        """Test segment with no words."""
        segment = TranscriptionSegment(
            id=0,
            text="empty",
            start=0.0,
            end=1.0
        )

        assert len(segment.words) == 0
        assert segment.to_dict()['words'] == []


class TestTranscriptionResult:
    """Test TranscriptionResult dataclass."""

    def test_result_creation(self):
        """Test creating a TranscriptionResult."""
        segments = [
            TranscriptionSegment(0, "hello world", 0.0, 2.0),
            TranscriptionSegment(1, "how are you", 2.0, 4.0)
        ]
        result = TranscriptionResult(
            text="hello world how are you",
            segments=segments,
            language="en",
            duration=4.0,
            model_used="base",
            processing_time=1.5
        )

        assert result.text == "hello world how are you"
        assert len(result.segments) == 2
        assert result.language == "en"
        assert result.duration == 4.0
        assert result.model_used == "base"
        assert result.processing_time == 1.5

    def test_word_count_property(self):
        """Test word count calculation."""
        words1 = [Word("hello", 0.0, 0.5), Word("world", 0.5, 1.0)]
        words2 = [Word("how", 1.0, 1.3), Word("are", 1.3, 1.6), Word("you", 1.6, 2.0)]
        segments = [
            TranscriptionSegment(0, "hello world", 0.0, 1.0, words1),
            TranscriptionSegment(1, "how are you", 1.0, 2.0, words2)
        ]
        result = TranscriptionResult(
            text="hello world how are you",
            segments=segments,
            language="en",
            duration=2.0,
            model_used="base"
        )

        assert result.word_count == 5

    def test_words_property(self):
        """Test getting all words from all segments."""
        words1 = [Word("hello", 0.0, 0.5), Word("world", 0.5, 1.0)]
        words2 = [Word("test", 1.0, 1.5)]
        segments = [
            TranscriptionSegment(0, "hello world", 0.0, 1.0, words1),
            TranscriptionSegment(1, "test", 1.0, 1.5, words2)
        ]
        result = TranscriptionResult(
            text="hello world test",
            segments=segments,
            language="en",
            duration=1.5,
            model_used="base"
        )

        all_words = result.words
        assert len(all_words) == 3
        assert all_words[0].text == "hello"
        assert all_words[1].text == "world"
        assert all_words[2].text == "test"

    def test_to_dict(self):
        """Test converting result to dictionary."""
        words = [Word("hello", 0.0, 0.5)]
        segments = [TranscriptionSegment(0, "hello", 0.0, 0.5, words)]
        result = TranscriptionResult(
            text="hello",
            segments=segments,
            language="en",
            duration=0.5,
            model_used="tiny",
            processing_time=0.8
        )
        data = result.to_dict()

        assert data['text'] == "hello"
        assert len(data['segments']) == 1
        assert data['language'] == "en"
        assert data['duration'] == 0.5
        assert data['model_used'] == "tiny"
        assert data['processing_time'] == 0.8
        assert data['word_count'] == 1

    def test_to_json(self):
        """Test converting result to JSON."""
        result = TranscriptionResult(
            text="test",
            segments=[],
            language="en",
            duration=1.0,
            model_used="base"
        )
        json_str = result.to_json()
        data = json.loads(json_str)

        assert data['text'] == "test"
        assert data['language'] == "en"
        assert data['model_used'] == "base"

    def test_save_to_file(self, temp_dir):
        """Test saving result to file."""
        result = TranscriptionResult(
            text="save test",
            segments=[],
            language="en",
            duration=2.0,
            model_used="small"
        )
        file_path = temp_dir / "transcription.json"
        result.save_to_file(file_path)

        assert file_path.exists()

        # Load and verify
        with open(file_path, 'r') as f:
            data = json.load(f)
        assert data['text'] == "save test"
        assert data['language'] == "en"
        assert data['model_used'] == "small"
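
# NOTE: test_save_to_file above relies on a `temp_dir` fixture that is not
# defined in this module; it is assumed to be provided by the project's
# conftest.py. If it is missing there, a minimal sketch built on pytest's
# built-in `tmp_path` fixture could look like the following (kept commented
# out so it does not shadow an existing conftest fixture):
#
# @pytest.fixture
# def temp_dir(tmp_path):
#     """Provide a temporary directory as a pathlib.Path."""
#     return tmp_path
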

class TestWhisperTranscriber:
    """Test WhisperTranscriber class."""

    @patch('src.core.transcription.torch')
    def test_initialization_default(self, mock_torch):
        """Test transcriber initialization with defaults."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()

        assert transcriber.model_size == WhisperModel.BASE
        assert transcriber.device == "cpu"
        assert transcriber.in_memory is True

    @patch('src.core.transcription.torch')
    def test_initialization_custom(self, mock_torch):
        """Test transcriber initialization with custom parameters."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(
                model_size=WhisperModel.SMALL,
                device="cpu",
                in_memory=False
            )

        assert transcriber.model_size == WhisperModel.SMALL
        assert transcriber.device == "cpu"
        assert transcriber.in_memory is False

    @patch('src.core.transcription.torch')
    def test_device_detection_cuda(self, mock_torch):
        """Test CUDA device detection."""
        mock_torch.cuda.is_available.return_value = True
        mock_torch.backends.mps.is_available.return_value = False

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()

        assert transcriber.device == "cuda"

    @patch('src.core.transcription.torch')
    def test_device_detection_mps(self, mock_torch):
        """Test MPS (Apple Silicon) device detection."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = True

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber()

        assert transcriber.device == "mps"

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_load_model(self, mock_whisper, mock_torch):
        """Test loading Whisper model."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model

        transcriber = WhisperTranscriber(in_memory=False)
        transcriber._load_model()

        assert transcriber.model == mock_model
        mock_whisper.load_model.assert_called_once_with(
            "base",
            device="cpu",
            download_root=None
        )

    @patch('src.core.transcription.torch')
    def test_unload_model(self, mock_torch):
        """Test unloading model from memory."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(in_memory=False)

        # Simulate loaded model
        transcriber.model = Mock()
        transcriber._unload_model()

        assert transcriber.model is None

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_file_path(self, mock_whisper, mock_torch):
        """Test transcribing from file path."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        # Mock Whisper model and results
        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'hello world',
            'language': 'en',
            'segments': [
                {
                    'id': 0,
                    'text': 'hello world',
                    'start': 0.0,
                    'end': 2.0,
                    'words': [
                        {'word': 'hello', 'start': 0.0, 'end': 1.0, 'probability': 0.9},
                        {'word': 'world', 'start': 1.0, 'end': 2.0, 'probability': 0.95}
                    ]
                }
            ]
        }
        mock_model.transcribe.return_value = mock_result

        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("/path/to/audio.mp3")

        assert isinstance(result, TranscriptionResult)
        assert result.text == 'hello world'
        assert result.language == 'en'
        assert result.model_used == 'base'
        assert len(result.segments) == 1
        assert result.word_count == 2
        assert result.processing_time > 0

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_array(self, mock_whisper, mock_torch):
        """Test transcribing from numpy array."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'test transcription',
            'language': 'en',
            'segments': []
        }
        mock_model.transcribe.return_value = mock_result

        transcriber = WhisperTranscriber(in_memory=False)
        audio_array = np.random.randn(16000)  # 1 second at 16kHz
        result = transcriber.transcribe(audio_array)

        assert result.text == 'test transcription'
        assert result.language == 'en'

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_with_language(self, mock_whisper, mock_torch):
        """Test transcribing with specified language."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'bonjour monde',
            'language': 'fr',
            'segments': []
        }
        mock_model.transcribe.return_value = mock_result

        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("audio.mp3", language="fr")

        # Verify language was passed to Whisper
        call_args = mock_model.transcribe.call_args
        assert call_args[1]['language'] == "fr"
        assert result.language == 'fr'

    @patch('src.core.transcription.torch')
    @patch('src.core.transcription.whisper')
    def test_transcribe_translate_task(self, mock_whisper, mock_torch):
        """Test transcribing with translate task."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        mock_model = Mock()
        mock_whisper.load_model.return_value = mock_model
        mock_result = {
            'text': 'hello world',  # Translated to English
            'language': 'fr',
            'segments': []
        }
        mock_model.transcribe.return_value = mock_result

        transcriber = WhisperTranscriber(in_memory=False)
        result = transcriber.transcribe("audio.mp3", task="translate")

        # Verify task was passed to Whisper
        call_args = mock_model.transcribe.call_args
        assert call_args[1]['task'] == "translate"

    @patch('src.core.transcription.torch')
    def test_process_results(self, mock_torch):
        """Test processing Whisper results."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        with patch.object(WhisperTranscriber, '_load_model'):
            transcriber = WhisperTranscriber(in_memory=False)

        raw_result = {
            'text': ' hello world ',
            'language': 'en',
            'segments': [
                {
                    'id': 0,
                    'text': ' hello world ',
                    'start': 0.0,
                    'end': 2.0,
                    'words': [
                        {'word': ' hello', 'start': 0.0, 'end': 1.0, 'probability': 0.9},
                        {'word': ' world', 'start': 1.0, 'end': 2.0, 'probability': 0.95}
                    ]
                }
            ]
        }

        result = transcriber._process_results(raw_result)

        assert result.text == 'hello world'  # Stripped
        assert result.language == 'en'
        assert result.duration == 2.0
        assert len(result.segments) == 1

        segment = result.segments[0]
        assert segment.text == 'hello world'  # Stripped
        assert len(segment.words) == 2
        assert segment.words[0].text == 'hello'  # Stripped
        assert segment.words[1].text == 'world'  # Stripped

    @patch('src.core.transcription.torch')
    def test_transcribe_with_chunks(self, mock_torch):
        """Test chunked transcription."""
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = False

        transcriber = WhisperTranscriber(in_memory=False)

        # Mock the regular transcribe method
        def mock_transcribe(audio_data, **kwargs):
            return TranscriptionResult(
                text="chunk text",
                segments=[
                    TranscriptionSegment(0, "chunk text", 0.0, 2.0, [
                        Word("chunk", 0.0, 1.0, 0.9),
                        Word("text", 1.0, 2.0, 0.95)
                    ])
                ],
                language="en",
                duration=2.0,
                model_used="base"
            )

        transcriber.transcribe = mock_transcribe

        # Create test audio (2 chunks worth)
        sample_rate = 16000
        audio_data = np.random.randn(70 * sample_rate)  # 70 seconds

        result = transcriber.transcribe_with_chunks(
            audio_data,
            sample_rate,
            chunk_duration=30,
            overlap=2
        )

        assert isinstance(result, TranscriptionResult)
        assert result.model_used == "base"
        assert len(result.segments) >= 2  # Should have multiple chunks


if __name__ == '__main__':
    pytest.main([__file__, '-v'])