# trax/tests/test_enhanced_cli.py  (426 lines, 17 KiB, Python)
"""Unit tests for the enhanced CLI interface."""
import pytest
import tempfile
import os
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
from click.testing import CliRunner
from src.cli.enhanced_cli import EnhancedCLI, EnhancedTranscribeCommand, EnhancedBatchCommand
from src.services.model_manager import ModelManager
from src.services.transcription_service import TranscriptionConfig
class TestEnhancedCLI:
    """Test the enhanced CLI interface structure."""

    @pytest.fixture
    def cli(self):
        """Create an enhanced CLI instance for testing."""
        return EnhancedCLI()

    @pytest.fixture
    def runner(self):
        """Create a Click test runner."""
        return CliRunner()

    def test_cli_initialization(self, cli):
        """Test that CLI initializes correctly with model manager."""
        assert cli.model_manager is not None
        assert cli.console is not None

    def test_help_display(self, runner):
        """Test that help documentation is displayed correctly."""
        from src.cli.enhanced_cli import cli
        result = runner.invoke(cli, ['--help'])
        assert result.exit_code == 0
        assert "Enhanced Audio Transcription Tool" in result.output
        # Note: The arguments are now under the subcommands, not the main CLI
        assert "transcribe" in result.output
        assert "batch" in result.output

    def test_single_file_transcription_arguments(self, runner):
        """Test single file transcription argument parsing."""
        from src.cli.enhanced_cli import cli
        # (An unused NamedTemporaryFile context was removed here; --help does
        # not need any input file to exist.)
        result = runner.invoke(cli, [
            'transcribe',
            '--help'
        ])
        # Should not fail on argument parsing
        assert result.exit_code == 0
        assert "INPUT" in result.output  # Positional argument
        assert "--output" in result.output
        assert "--format" in result.output
        assert "--model" in result.output
        assert "--device" in result.output
        assert "--domain" in result.output
        assert "--diarize" in result.output
        assert "--speakers" in result.output

    def test_batch_processing_arguments(self, runner):
        """Test batch processing argument parsing."""
        from src.cli.enhanced_cli import cli
        result = runner.invoke(cli, [
            'batch',
            '--help'
        ])
        # Should not fail on argument parsing
        assert result.exit_code == 0
        assert "INPUT" in result.output  # Positional argument
        assert "--output" in result.output
        assert "--concurrency" in result.output
        assert "--format" in result.output
        assert "--model" in result.output
        assert "--device" in result.output
        assert "--domain" in result.output
        assert "--diarize" in result.output
        assert "--speakers" in result.output

    def test_invalid_arguments(self, runner):
        """Test that invalid arguments are properly rejected."""
        from src.cli.enhanced_cli import cli
        # INPUT is a positional argument (there is no --input option — see
        # test_single_file_transcription_arguments), so the path must be passed
        # positionally.  Previously '--input' was used, which made Click reject
        # the unknown option with exit code 2 before ever evaluating the
        # invalid choice actually under test.
        # Test invalid format
        result = runner.invoke(cli, [
            'transcribe',
            'test.wav',
            '--format', 'invalid_format'
        ])
        assert result.exit_code == 2  # Usage error
        # Test invalid device
        result = runner.invoke(cli, [
            'transcribe',
            'test.wav',
            '--device', 'invalid_device'
        ])
        assert result.exit_code == 2  # Usage error
        # Test invalid domain
        result = runner.invoke(cli, [
            'transcribe',
            'test.wav',
            '--domain', 'invalid_domain'
        ])
        assert result.exit_code == 2  # Usage error

    def test_model_manager_integration(self, cli):
        """Test that CLI properly integrates with ModelManager."""
        mock_manager = Mock()
        cli.model_manager = mock_manager
        # Test that CLI can access model manager methods
        cli.model_manager.get_available_models.return_value = ['tiny', 'base', 'small', 'medium', 'large']
        models = cli.model_manager.get_available_models()
        assert models == ['tiny', 'base', 'small', 'medium', 'large']
        mock_manager.get_available_models.assert_called_once()
class TestEnhancedTranscribeCommand:
    """Test the enhanced transcribe command."""

    @pytest.fixture
    def command(self):
        """Create an enhanced transcribe command instance."""
        return EnhancedTranscribeCommand()

    @pytest.fixture
    def runner(self):
        """Create a Click test runner."""
        return CliRunner()

    @pytest.mark.asyncio
    async def test_single_file_transcription(self, command, runner):
        """Test single file transcription execution."""
        with patch('src.cli.enhanced_cli.create_transcription_service') as factory_mock, \
                patch.object(command, '_get_audio_duration') as duration_mock:
            service_mock = AsyncMock()
            factory_mock.return_value = service_mock
            service_mock.initialize = AsyncMock()
            duration_mock.return_value = 60.0  # pretend the clip is 60 seconds long

            # Stub out a successful transcription outcome.
            fake_result = Mock()
            fake_result.text_content = "Test transcription result"
            fake_result.accuracy = 95.5
            fake_result.processing_time = 10.5
            fake_result.quality_warnings = []
            service_mock.transcribe_file.return_value = fake_result

            with tempfile.NamedTemporaryFile(suffix='.wav') as temp_file:
                # Write real bytes so the path refers to an actual file.
                temp_file.write(b'test audio data')
                temp_file.flush()
                outcome = await command.execute_transcription(
                    input_path=temp_file.name,
                    output_dir='/tmp/output',
                    format_type='json',
                    model='base',
                    device='cpu',
                    domain=None,
                    diarize=False,
                    speakers=None,
                )
                assert outcome is not None
                service_mock.initialize.assert_called_once()
                service_mock.transcribe_file.assert_called_once()

    def test_progress_callback_integration(self, command):
        """Test that progress callback integrates with Rich progress bars."""
        with patch('src.cli.enhanced_cli.Progress') as progress_mock:
            task_mock = Mock()
            progress_mock.return_value.__enter__.return_value.add_task.return_value = task_mock
            # Build the callback and confirm invoking it raises nothing.
            callback = command._create_progress_callback(task_mock, 100.0)
            assert callable(callback)
            callback(50.0, 100.0)  # Should execute without error

    def test_export_formats(self, command):
        """Test export functionality for different formats."""
        fake_result = Mock()
        fake_result.text_content = "Test transcription"
        fake_result.segments = [
            {"start": 0.0, "end": 2.0, "text": "Hello world"},
            {"start": 2.0, "end": 4.0, "text": "How are you"},
        ]
        with tempfile.TemporaryDirectory() as temp_dir:
            # Exercise every supported export format in turn: JSON, TXT,
            # SRT and VTT.
            for extension in ("json", "txt", "srt", "vtt"):
                exported = command._export_result(
                    fake_result, "test.wav", temp_dir, extension)
                assert Path(exported).exists()
                assert exported.endswith('.' + extension)
class TestEnhancedBatchCommand:
    """Test the enhanced batch command."""

    @pytest.fixture
    def command(self):
        """Create an enhanced batch command instance."""
        return EnhancedBatchCommand()

    @pytest.mark.asyncio
    async def test_batch_processing_setup(self, command):
        """Test batch processing setup and file discovery."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create test files
            test_files = [
                Path(temp_dir) / "small.wav",
                Path(temp_dir) / "medium.mp3",
                Path(temp_dir) / "large.m4a",
            ]
            # Give each file a distinct size (1 KiB, 2 KiB, 3 KiB) so the
            # intelligent-queuing sort below has something to order.
            # enumerate avoids the O(n) list.index lookup the old loop did,
            # and write_bytes creates the file, so no separate touch is needed.
            for index, file_path in enumerate(test_files):
                file_path.write_bytes(b'x' * (index + 1) * 1024)

            files = command._discover_files(temp_dir)
            assert len(files) == 3

            # Test intelligent queuing (smaller files first)
            sorted_files = command._sort_files_by_size(files)
            assert sorted_files[0].stat().st_size <= sorted_files[1].stat().st_size
            assert sorted_files[1].stat().st_size <= sorted_files[2].stat().st_size

    @pytest.mark.asyncio
    async def test_concurrent_processing(self, command):
        """Test concurrent processing with ThreadPoolExecutor."""
        with patch('concurrent.futures.ThreadPoolExecutor') as mock_executor, \
                patch('concurrent.futures.as_completed') as mock_as_completed:
            # Create mock futures that can be iterated
            mock_futures = [Mock(), Mock(), Mock()]
            mock_executor.return_value.__enter__.return_value.submit.side_effect = mock_futures
            mock_as_completed.return_value = mock_futures

            test_files = ["file1.wav", "file2.wav", "file3.wav"]
            await command._process_concurrently(
                files=test_files,
                concurrency=2,
                transcription_func=Mock(),
                progress_callback=Mock(),
            )
            # Verify ThreadPoolExecutor was used with correct max_workers
            mock_executor.assert_called_with(max_workers=2)

    def test_performance_monitoring(self, command):
        """Test performance monitoring functionality."""
        with patch('src.cli.enhanced_cli.psutil') as mock_psutil:
            mock_psutil.cpu_percent.return_value = 45.2
            mock_psutil.virtual_memory.return_value = Mock(
                used=2 * 1024**3,   # 2GB used
                total=8 * 1024**3,  # 8GB total
                percent=25.0,
            )
            mock_psutil.sensors_temperatures.return_value = {
                'coretemp': [Mock(current=65.0)]
            }
            stats = command._get_performance_stats()
            # All expected keys must be present...
            for key in ('cpu_percent', 'memory_used_gb', 'memory_total_gb',
                        'memory_percent', 'cpu_temperature'):
                assert key in stats
            # ...and must reflect the mocked psutil readings.
            assert stats['cpu_percent'] == 45.2
            assert stats['memory_used_gb'] == 2.0
            assert stats['memory_total_gb'] == 8.0
            assert stats['memory_percent'] == 25.0
            assert stats['cpu_temperature'] == 65.0
class TestErrorHandling:
    """Test error handling and user guidance."""

    @pytest.fixture
    def cli(self):
        """Create an enhanced CLI instance for testing."""
        return EnhancedCLI()

    @staticmethod
    def _guidance_for(cli, exc):
        """Ask the CLI for guidance text for the given exception instance."""
        return cli._get_error_guidance(type(exc).__name__, str(exc))

    def test_file_not_found_error_handling(self, cli):
        """Test handling of FileNotFoundError."""
        guidance = self._guidance_for(
            cli, FileNotFoundError("No such file or directory: 'nonexistent.wav'"))
        assert "Check that the input file path is correct" in guidance
        assert "file exists" in guidance

    def test_permission_error_handling(self, cli):
        """Test handling of PermissionError."""
        guidance = self._guidance_for(
            cli, PermissionError("Permission denied: 'protected.wav'"))
        assert "Check file permissions" in guidance
        assert "administrator privileges" in guidance

    def test_cuda_error_handling(self, cli):
        """Test handling of CUDA/GPU errors."""
        guidance = self._guidance_for(cli, RuntimeError("CUDA out of memory"))
        assert "GPU-related error" in guidance
        assert "--device cpu" in guidance

    def test_memory_error_handling(self, cli):
        """Test handling of memory errors."""
        guidance = self._guidance_for(cli, MemoryError("Not enough memory"))
        assert "Memory error" in guidance
        assert "--model small" in guidance
        assert "reduce concurrency" in guidance

    def test_generic_error_handling(self, cli):
        """Test handling of generic errors."""
        guidance = self._guidance_for(cli, ValueError("Invalid parameter"))
        assert "Check input parameters" in guidance
        assert "try again" in guidance
class TestIntegration:
    """Integration tests for the enhanced CLI."""

    @pytest.mark.asyncio
    async def test_full_transcription_workflow(self):
        """Test complete transcription workflow integration."""
        cli = EnhancedCLI()
        command = EnhancedTranscribeCommand()
        with patch('src.cli.enhanced_cli.ModelManager') as manager_cls_mock:
            manager_cls_mock.return_value = Mock()
            with patch('src.cli.enhanced_cli.create_transcription_service') as factory_mock, \
                    patch.object(command, '_get_audio_duration') as duration_mock:
                service_mock = AsyncMock()
                factory_mock.return_value = service_mock
                service_mock.initialize = AsyncMock()
                duration_mock.return_value = 60.0  # pretend a one-minute clip

                # Stub out a successful transcription result.
                fake_result = Mock()
                fake_result.text_content = "Test transcription"
                fake_result.accuracy = 95.0
                fake_result.processing_time = 5.0
                service_mock.transcribe_file.return_value = fake_result

                with tempfile.NamedTemporaryFile(suffix='.wav') as temp_file:
                    # Write real bytes so the path refers to an actual file.
                    temp_file.write(b'test audio data')
                    temp_file.flush()
                    with tempfile.TemporaryDirectory() as output_dir:
                        outcome = await command.execute_transcription(
                            input_path=temp_file.name,
                            output_dir=output_dir,
                            format_type='json',
                            model='base',
                            device='cpu',
                        )
                        assert outcome is not None
                        service_mock.transcribe_file.assert_called_once()

    def test_cli_command_registration(self):
        """Test that CLI commands are properly registered."""
        from src.cli.enhanced_cli import cli
        runner = CliRunner()
        # Top-level help must succeed...
        assert runner.invoke(cli, ['--help']).exit_code == 0
        # ...and each registered subcommand must render its own help and
        # mention itself in the output.
        for name in ('transcribe', 'batch'):
            result = runner.invoke(cli, [name, '--help'])
            assert result.exit_code == 0
            assert name in result.output.lower()
if __name__ == "__main__":
    # Propagate pytest's exit status so script-mode runs report failures to
    # the shell/CI (the return value was previously discarded).
    raise SystemExit(pytest.main([__file__]))