426 lines
17 KiB
Python
426 lines
17 KiB
Python
"""Unit tests for the enhanced CLI interface."""
|
|
|
|
import pytest
|
|
import tempfile
|
|
import os
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
from click.testing import CliRunner
|
|
|
|
from src.cli.enhanced_cli import EnhancedCLI, EnhancedTranscribeCommand, EnhancedBatchCommand
|
|
from src.services.model_manager import ModelManager
|
|
from src.services.transcription_service import TranscriptionConfig
|
|
|
|
|
|
class TestEnhancedCLI:
    """Test the enhanced CLI interface structure."""

    @pytest.fixture
    def cli(self):
        """Create an enhanced CLI instance for testing."""
        return EnhancedCLI()

    @pytest.fixture
    def runner(self):
        """Create a Click test runner."""
        return CliRunner()

    def test_cli_initialization(self, cli):
        """Test that CLI initializes correctly with model manager."""
        assert cli.model_manager is not None
        assert cli.console is not None

    def test_help_display(self, runner):
        """Test that help documentation is displayed correctly."""
        from src.cli.enhanced_cli import cli

        result = runner.invoke(cli, ['--help'])
        assert result.exit_code == 0
        assert "Enhanced Audio Transcription Tool" in result.output
        # Note: The arguments are now under the subcommands, not the main CLI
        assert "transcribe" in result.output
        assert "batch" in result.output

    def test_single_file_transcription_arguments(self, runner):
        """Test single file transcription argument parsing."""
        from src.cli.enhanced_cli import cli

        # FIX: dropped the unused NamedTemporaryFile wrapper -- invoking
        # `--help` never touches the filesystem, so the temp file was dead code.
        result = runner.invoke(cli, ['transcribe', '--help'])

        # Should not fail on argument parsing
        assert result.exit_code == 0
        assert "INPUT" in result.output  # Positional argument
        assert "--output" in result.output
        assert "--format" in result.output
        assert "--model" in result.output
        assert "--device" in result.output
        assert "--domain" in result.output
        assert "--diarize" in result.output
        assert "--speakers" in result.output

    def test_batch_processing_arguments(self, runner):
        """Test batch processing argument parsing."""
        from src.cli.enhanced_cli import cli

        result = runner.invoke(cli, ['batch', '--help'])

        # Should not fail on argument parsing
        assert result.exit_code == 0
        assert "INPUT" in result.output  # Positional argument
        assert "--output" in result.output
        assert "--concurrency" in result.output
        assert "--format" in result.output
        assert "--model" in result.output
        assert "--device" in result.output
        assert "--domain" in result.output
        assert "--diarize" in result.output
        assert "--speakers" in result.output

    def test_invalid_arguments(self, runner):
        """Test that invalid arguments are properly rejected."""
        from src.cli.enhanced_cli import cli

        # FIX: the input file is a positional INPUT argument (as asserted in
        # the --help tests above), not an `--input` option.  The original
        # invocations passed `--input`, so exit code 2 came from Click
        # rejecting the unknown option rather than from validation of the
        # actually-invalid value under test.

        # Test invalid format
        result = runner.invoke(cli, [
            'transcribe',
            'test.wav',
            '--format', 'invalid_format'
        ])
        assert result.exit_code == 2  # Usage error

        # Test invalid device
        result = runner.invoke(cli, [
            'transcribe',
            'test.wav',
            '--device', 'invalid_device'
        ])
        assert result.exit_code == 2  # Usage error

        # Test invalid domain
        result = runner.invoke(cli, [
            'transcribe',
            'test.wav',
            '--domain', 'invalid_domain'
        ])
        assert result.exit_code == 2  # Usage error

    def test_model_manager_integration(self, cli):
        """Test that CLI properly integrates with ModelManager."""
        mock_manager = Mock()
        cli.model_manager = mock_manager

        # Test that CLI can access model manager methods
        cli.model_manager.get_available_models.return_value = ['tiny', 'base', 'small', 'medium', 'large']

        models = cli.model_manager.get_available_models()
        assert models == ['tiny', 'base', 'small', 'medium', 'large']
        mock_manager.get_available_models.assert_called_once()
|
|
|
|
|
|
class TestEnhancedTranscribeCommand:
    """Exercise the enhanced transcribe command in isolation."""

    @pytest.fixture
    def command(self):
        """Provide a fresh EnhancedTranscribeCommand per test."""
        return EnhancedTranscribeCommand()

    @pytest.fixture
    def runner(self):
        """Provide a Click test runner."""
        return CliRunner()

    @pytest.mark.asyncio
    async def test_single_file_transcription(self, command, runner):
        """A single file runs through the mocked service end to end."""
        with patch('src.cli.enhanced_cli.create_transcription_service') as service_factory, \
             patch.object(command, '_get_audio_duration') as duration_probe:

            service = AsyncMock()
            service.initialize = AsyncMock()
            service_factory.return_value = service
            duration_probe.return_value = 60.0  # pretend the clip is one minute long

            # Canned transcription outcome.
            fake_result = Mock()
            fake_result.text_content = "Test transcription result"
            fake_result.accuracy = 95.5
            fake_result.processing_time = 10.5
            fake_result.quality_warnings = []
            service.transcribe_file.return_value = fake_result

            with tempfile.NamedTemporaryFile(suffix='.wav') as audio:
                # Write real bytes so the path points at an actual file.
                audio.write(b'test audio data')
                audio.flush()

                outcome = await command.execute_transcription(
                    input_path=audio.name,
                    output_dir='/tmp/output',
                    format_type='json',
                    model='base',
                    device='cpu',
                    domain=None,
                    diarize=False,
                    speakers=None,
                )

                assert outcome is not None
                service.initialize.assert_called_once()
                service.transcribe_file.assert_called_once()

    def test_progress_callback_integration(self, command):
        """The Rich progress callback is callable and runs cleanly."""
        with patch('src.cli.enhanced_cli.Progress') as progress_cls:
            task_handle = Mock()
            progress_cls.return_value.__enter__.return_value.add_task.return_value = task_handle

            # Build the callback and make sure invoking it does not raise.
            callback = command._create_progress_callback(task_handle, 100.0)

            assert callable(callback)
            callback(50.0, 100.0)  # must execute without error

    def test_export_formats(self, command):
        """Each supported format writes a file with the matching suffix."""
        fake_result = Mock()
        fake_result.text_content = "Test transcription"
        fake_result.segments = [
            {"start": 0.0, "end": 2.0, "text": "Hello world"},
            {"start": 2.0, "end": 4.0, "text": "How are you"},
        ]

        with tempfile.TemporaryDirectory() as workdir:
            # json, txt, srt, vtt -- same checks for each format.
            for fmt in ("json", "txt", "srt", "vtt"):
                exported = command._export_result(fake_result, "test.wav", workdir, fmt)
                assert Path(exported).exists()
                assert exported.endswith('.' + fmt)
|
|
|
|
|
|
class TestEnhancedBatchCommand:
    """Test the enhanced batch command."""

    @pytest.fixture
    def command(self):
        """Create an enhanced batch command instance."""
        return EnhancedBatchCommand()

    @pytest.mark.asyncio
    async def test_batch_processing_setup(self, command):
        """Test batch processing setup and file discovery."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create test files
            test_files = [
                Path(temp_dir) / "small.wav",
                Path(temp_dir) / "medium.mp3",
                Path(temp_dir) / "large.m4a",
            ]

            # Give each file a distinct size for intelligent queuing.
            # FIX: use enumerate instead of calling list.index() inside the
            # loop -- index() is an O(n) scan per iteration and returns the
            # first match, which is wrong if two paths ever compare equal.
            # write_bytes() also creates the file, so the separate touch()
            # was redundant.
            for size_kb, file_path in enumerate(test_files, start=1):
                file_path.write_bytes(b'x' * size_kb * 1024)

            files = command._discover_files(temp_dir)
            assert len(files) == 3

            # Test intelligent queuing (smaller files first)
            sorted_files = command._sort_files_by_size(files)
            assert sorted_files[0].stat().st_size <= sorted_files[1].stat().st_size
            assert sorted_files[1].stat().st_size <= sorted_files[2].stat().st_size

    @pytest.mark.asyncio
    async def test_concurrent_processing(self, command):
        """Test concurrent processing with ThreadPoolExecutor."""
        with patch('concurrent.futures.ThreadPoolExecutor') as mock_executor, \
             patch('concurrent.futures.as_completed') as mock_as_completed:

            # Create mock futures that can be iterated
            mock_future1 = Mock()
            mock_future2 = Mock()
            mock_future3 = Mock()

            mock_executor.return_value.__enter__.return_value.submit.side_effect = [mock_future1, mock_future2, mock_future3]
            mock_as_completed.return_value = [mock_future1, mock_future2, mock_future3]

            test_files = ["file1.wav", "file2.wav", "file3.wav"]

            await command._process_concurrently(
                files=test_files,
                concurrency=2,
                transcription_func=Mock(),
                progress_callback=Mock()
            )

            # Verify ThreadPoolExecutor was used with correct max_workers
            mock_executor.assert_called_with(max_workers=2)

    def test_performance_monitoring(self, command):
        """Test performance monitoring functionality."""
        with patch('src.cli.enhanced_cli.psutil') as mock_psutil:
            mock_psutil.cpu_percent.return_value = 45.2
            mock_psutil.virtual_memory.return_value = Mock(
                used=2 * 1024**3,   # 2GB used
                total=8 * 1024**3,  # 8GB total
                percent=25.0
            )
            mock_psutil.sensors_temperatures.return_value = {
                'coretemp': [Mock(current=65.0)]
            }

            stats = command._get_performance_stats()

            assert 'cpu_percent' in stats
            assert 'memory_used_gb' in stats
            assert 'memory_total_gb' in stats
            assert 'memory_percent' in stats
            assert 'cpu_temperature' in stats

            assert stats['cpu_percent'] == 45.2
            assert stats['memory_used_gb'] == 2.0
            assert stats['memory_total_gb'] == 8.0
            assert stats['memory_percent'] == 25.0
            assert stats['cpu_temperature'] == 65.0
|
|
|
|
|
|
class TestErrorHandling:
    """Test error handling and user guidance."""

    @pytest.fixture
    def cli(self):
        """Provide an EnhancedCLI instance for the guidance tests."""
        return EnhancedCLI()

    def test_file_not_found_error_handling(self, cli):
        """Test handling of FileNotFoundError."""
        exc = FileNotFoundError("No such file or directory: 'nonexistent.wav'")
        advice = cli._get_error_guidance(type(exc).__name__, str(exc))

        for expected in ("Check that the input file path is correct", "file exists"):
            assert expected in advice

    def test_permission_error_handling(self, cli):
        """Test handling of PermissionError."""
        exc = PermissionError("Permission denied: 'protected.wav'")
        advice = cli._get_error_guidance(type(exc).__name__, str(exc))

        for expected in ("Check file permissions", "administrator privileges"):
            assert expected in advice

    def test_cuda_error_handling(self, cli):
        """Test handling of CUDA/GPU errors."""
        exc = RuntimeError("CUDA out of memory")
        advice = cli._get_error_guidance(type(exc).__name__, str(exc))

        for expected in ("GPU-related error", "--device cpu"):
            assert expected in advice

    def test_memory_error_handling(self, cli):
        """Test handling of memory errors."""
        exc = MemoryError("Not enough memory")
        advice = cli._get_error_guidance(type(exc).__name__, str(exc))

        for expected in ("Memory error", "--model small", "reduce concurrency"):
            assert expected in advice

    def test_generic_error_handling(self, cli):
        """Test handling of generic errors."""
        exc = ValueError("Invalid parameter")
        advice = cli._get_error_guidance(type(exc).__name__, str(exc))

        for expected in ("Check input parameters", "try again"):
            assert expected in advice
|
|
|
|
|
|
class TestIntegration:
    """Integration tests for the enhanced CLI."""

    @pytest.mark.asyncio
    async def test_full_transcription_workflow(self):
        """Drive a complete transcription from temp input to temp output."""
        cli = EnhancedCLI()
        command = EnhancedTranscribeCommand()

        with patch('src.cli.enhanced_cli.ModelManager') as manager_cls:
            manager_cls.return_value = Mock()

            with patch('src.cli.enhanced_cli.create_transcription_service') as service_factory, \
                 patch.object(command, '_get_audio_duration') as duration_probe:

                service = AsyncMock()
                service.initialize = AsyncMock()
                service_factory.return_value = service
                duration_probe.return_value = 60.0  # fake one-minute clip

                # Canned successful transcription.
                fake_result = Mock()
                fake_result.text_content = "Test transcription"
                fake_result.accuracy = 95.0
                fake_result.processing_time = 5.0
                service.transcribe_file.return_value = fake_result

                with tempfile.NamedTemporaryFile(suffix='.wav') as audio, \
                     tempfile.TemporaryDirectory() as out_dir:
                    # Write real bytes so the path points at an actual file.
                    audio.write(b'test audio data')
                    audio.flush()

                    outcome = await command.execute_transcription(
                        input_path=audio.name,
                        output_dir=out_dir,
                        format_type='json',
                        model='base',
                        device='cpu',
                    )

                    assert outcome is not None
                    service.transcribe_file.assert_called_once()

    def test_cli_command_registration(self):
        """Test that CLI commands are properly registered."""
        from src.cli.enhanced_cli import cli
        runner = CliRunner()

        # Top-level help must succeed.
        assert runner.invoke(cli, ['--help']).exit_code == 0

        # Each subcommand's help succeeds and mentions the command itself.
        for subcommand in ('transcribe', 'batch'):
            outcome = runner.invoke(cli, [subcommand, '--help'])
            assert outcome.exit_code == 0
            assert subcommand in outcome.output.lower()
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly (`python <file>`) by delegating
    # to pytest, instead of requiring a pytest invocation from the shell.
    pytest.main([__file__])
|