348 lines
14 KiB
Python
348 lines
14 KiB
Python
"""Unit tests for the error handling system.
|
|
|
|
This module tests the error classification, creation, and handling functionality.
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import Mock, patch
|
|
|
|
from src.errors import (
|
|
# Base error classes
|
|
TraxError, NetworkError, APIError, FileSystemError, ValidationError,
|
|
ProcessingError, ConfigurationError, ResourceError,
|
|
ConnectionError, TimeoutError, DNSResolutionError,
|
|
AuthenticationError, RateLimitError, QuotaExceededError,
|
|
ServiceUnavailableError, InvalidResponseError,
|
|
FileNotFoundError, PermissionError, DiskSpaceError, CorruptedFileError,
|
|
InvalidInputError, MissingRequiredFieldError, FormatError,
|
|
TranscriptionError, EnhancementError, MediaProcessingError, AudioConversionError,
|
|
MissingConfigError, InvalidConfigError, EnvironmentError,
|
|
MemoryError, CPUError,
|
|
|
|
# Error codes
|
|
ErrorCode, ErrorCategory, ErrorSeverity,
|
|
NETWORK_CONNECTION_FAILED, NETWORK_TIMEOUT, DNS_RESOLUTION_FAILED,
|
|
API_AUTHENTICATION_FAILED, API_RATE_LIMIT_EXCEEDED, API_QUOTA_EXCEEDED,
|
|
API_SERVICE_UNAVAILABLE, API_INVALID_RESPONSE,
|
|
FILE_NOT_FOUND, FILE_PERMISSION_DENIED, DISK_SPACE_INSUFFICIENT, FILE_CORRUPTED,
|
|
INVALID_INPUT, MISSING_REQUIRED_FIELD, INVALID_FORMAT,
|
|
TRANSCRIPTION_FAILED, ENHANCEMENT_FAILED, MEDIA_PROCESSING_FAILED, AUDIO_CONVERSION_FAILED,
|
|
MISSING_CONFIGURATION, INVALID_CONFIGURATION, ENVIRONMENT_ERROR,
|
|
MEMORY_INSUFFICIENT, CPU_OVERLOADED,
|
|
|
|
# Error utilities
|
|
create_network_error, create_api_error, create_filesystem_error, create_validation_error,
|
|
create_error_from_code, classify_error, extract_error_context, is_retryable_error,
|
|
get_error_severity, get_error_category, wrap_error, get_actionable_message,
|
|
error_handler, async_error_handler
|
|
)
|
|
|
|
|
|
class TestTraxError:
    """Tests for the base ``TraxError`` class.

    Covers the defaults of a bare instance, the attributes derived from an
    attached ``ErrorCode``, ``to_dict`` serialization and ``str()`` output.
    """

    def test_trax_error_creation(self):
        """A bare TraxError has no code, empty context, no cause and a UTC timestamp."""
        error = TraxError("Test error message")
        assert error.message == "Test error message"
        assert error.error_code is None
        assert error.context == {}
        assert error.original_error is None
        assert isinstance(error.timestamp, datetime)
        # The timestamp must be timezone-aware and pinned to UTC.
        assert error.timestamp.tzinfo == timezone.utc

    def test_trax_error_with_code(self):
        """retryable / severity / category are taken from the attached error code."""
        error = TraxError("Test error", NETWORK_CONNECTION_FAILED)
        assert error.error_code == NETWORK_CONNECTION_FAILED
        assert error.is_retryable == NETWORK_CONNECTION_FAILED.retryable
        assert error.severity == NETWORK_CONNECTION_FAILED.severity
        assert error.category == NETWORK_CONNECTION_FAILED.category

    def test_trax_error_with_context(self):
        """The ``context`` keyword argument is exposed as ``error.context``."""
        context = {"file": "test.mp3", "size": 1024}
        error = TraxError("Test error", context=context)
        assert error.context == context

    def test_trax_error_with_original_error(self):
        """The ``original_error`` keyword argument is kept as the very same object."""
        original = ValueError("Original error")
        error = TraxError("Test error", original_error=original)
        # Identity check: the wrapped cause must be stored as-is, not copied.
        assert error.original_error is original

    def test_trax_error_to_dict(self):
        """to_dict() exposes message, code metadata, context, timestamp, traceback and cause."""
        original = ValueError("Original error")
        context = {"file": "test.mp3"}
        # Positional order: message, error_code, context, original_error.
        error = TraxError("Test error", NETWORK_CONNECTION_FAILED, context, original)

        error_dict = error.to_dict()
        assert error_dict["error_type"] == "TraxError"
        assert error_dict["message"] == "Test error"
        assert error_dict["error_code"] == str(NETWORK_CONNECTION_FAILED)
        assert error_dict["category"] == NETWORK_CONNECTION_FAILED.category.value
        assert error_dict["severity"] == NETWORK_CONNECTION_FAILED.severity.value
        assert error_dict["retryable"] == NETWORK_CONNECTION_FAILED.retryable
        assert error_dict["context"] == context
        assert "timestamp" in error_dict
        assert "traceback" in error_dict
        # The cause is serialized as its string form.
        assert error_dict["original_error"] == "Original error"

    def test_trax_error_string_representation(self):
        """str() is '<code>: <message>' with a code, and just the message without one."""
        error = TraxError("Test error", NETWORK_CONNECTION_FAILED)
        assert str(error) == f"{NETWORK_CONNECTION_FAILED.code}: Test error"

        error_no_code = TraxError("Test error")
        assert str(error_no_code) == "Test error"
|
|
|
|
|
|
class TestNetworkErrors:
    """Tests for the network error family rooted at ``NetworkError``."""

    def test_network_error_inheritance(self):
        """Every concrete network error class derives from NetworkError."""
        for subclass in (ConnectionError, TimeoutError, DNSResolutionError):
            assert issubclass(subclass, NetworkError)

    def test_connection_error(self):
        """ConnectionError is a NetworkError and keeps its message."""
        text = "Connection failed"
        exc = ConnectionError(text)
        assert isinstance(exc, NetworkError)
        assert exc.message == text

    def test_timeout_error(self):
        """TimeoutError is a NetworkError and keeps its message."""
        text = "Request timed out"
        exc = TimeoutError(text)
        assert isinstance(exc, NetworkError)
        assert exc.message == text

    def test_dns_resolution_error(self):
        """DNSResolutionError is a NetworkError and keeps its message."""
        text = "DNS resolution failed"
        exc = DNSResolutionError(text)
        assert isinstance(exc, NetworkError)
        assert exc.message == text
|
|
|
|
|
|
class TestAPIErrors:
    """Tests for the API error family rooted at ``APIError``."""

    def test_api_error_inheritance(self):
        """Every concrete API error class derives from APIError."""
        api_subclasses = (
            AuthenticationError,
            RateLimitError,
            QuotaExceededError,
            ServiceUnavailableError,
            InvalidResponseError,
        )
        for subclass in api_subclasses:
            assert issubclass(subclass, APIError)

    def test_authentication_error(self):
        """AuthenticationError is an APIError and keeps its message."""
        text = "Invalid API key"
        exc = AuthenticationError(text)
        assert isinstance(exc, APIError)
        assert exc.message == text

    def test_rate_limit_error(self):
        """RateLimitError is an APIError and keeps its message."""
        text = "Rate limit exceeded"
        exc = RateLimitError(text)
        assert isinstance(exc, APIError)
        assert exc.message == text

    def test_quota_exceeded_error(self):
        """QuotaExceededError is an APIError and keeps its message."""
        text = "API quota exceeded"
        exc = QuotaExceededError(text)
        assert isinstance(exc, APIError)
        assert exc.message == text
|
|
|
|
|
|
class TestFileSystemErrors:
    """Tests for the file-system error family rooted at ``FileSystemError``.

    Note: ``FileNotFoundError`` and ``PermissionError`` here are the project
    classes imported from ``src.errors``, not the Python builtins.
    """

    def test_filesystem_error_inheritance(self):
        """Every concrete file-system error class derives from FileSystemError."""
        fs_subclasses = (
            FileNotFoundError,
            PermissionError,
            DiskSpaceError,
            CorruptedFileError,
        )
        for subclass in fs_subclasses:
            assert issubclass(subclass, FileSystemError)

    def test_file_not_found_error(self):
        """FileNotFoundError is a FileSystemError and keeps its message."""
        text = "File not found: test.mp3"
        exc = FileNotFoundError(text)
        assert isinstance(exc, FileSystemError)
        assert exc.message == text

    def test_permission_error(self):
        """PermissionError is a FileSystemError and keeps its message."""
        text = "Permission denied: test.mp3"
        exc = PermissionError(text)
        assert isinstance(exc, FileSystemError)
        assert exc.message == text
|
|
|
|
|
|
class TestValidationErrors:
    """Tests for the validation error family rooted at ``ValidationError``."""

    def test_validation_error_inheritance(self):
        """Every concrete validation error class derives from ValidationError."""
        for subclass in (InvalidInputError, MissingRequiredFieldError, FormatError):
            assert issubclass(subclass, ValidationError)

    def test_invalid_input_error(self):
        """InvalidInputError is a ValidationError and keeps its message."""
        text = "Invalid input format"
        exc = InvalidInputError(text)
        assert isinstance(exc, ValidationError)
        assert exc.message == text

    def test_missing_required_field_error(self):
        """MissingRequiredFieldError is a ValidationError and keeps its message."""
        text = "Missing required field: api_key"
        exc = MissingRequiredFieldError(text)
        assert isinstance(exc, ValidationError)
        assert exc.message == text
|
|
|
|
|
|
class TestProcessingErrors:
    """Tests for the processing error family rooted at ``ProcessingError``."""

    def test_processing_error_inheritance(self):
        """Every concrete processing error class derives from ProcessingError."""
        processing_subclasses = (
            TranscriptionError,
            EnhancementError,
            MediaProcessingError,
            AudioConversionError,
        )
        for subclass in processing_subclasses:
            assert issubclass(subclass, ProcessingError)

    def test_transcription_error(self):
        """TranscriptionError is a ProcessingError and keeps its message."""
        text = "Transcription failed"
        exc = TranscriptionError(text)
        assert isinstance(exc, ProcessingError)
        assert exc.message == text

    def test_enhancement_error(self):
        """EnhancementError is a ProcessingError and keeps its message."""
        text = "Enhancement failed"
        exc = EnhancementError(text)
        assert isinstance(exc, ProcessingError)
        assert exc.message == text
|
|
|
|
|
|
class TestErrorCreationUtilities:
    """Tests for the ``create_*_error`` factory helpers.

    Each helper is called as ``(message, error_code, context[, original])``
    and must return an instance of the matching family base class that
    carries those values.
    """

    def test_create_network_error(self):
        """create_network_error returns a NetworkError with code, context and cause."""
        # ``ConnectionError`` in this module is the project class imported from
        # src.errors (it shadows the builtin). The cause being wrapped should be
        # a raw low-level exception, so take the builtin explicitly.
        import builtins

        original = builtins.ConnectionError("Original connection error")
        error = create_network_error(
            "Network failed",
            NETWORK_CONNECTION_FAILED,
            {"url": "https://api.example.com"},
            original,
        )

        assert isinstance(error, NetworkError)
        assert error.message == "Network failed"
        assert error.error_code == NETWORK_CONNECTION_FAILED
        assert error.context["url"] == "https://api.example.com"
        # Identity check: the cause must be stored as-is, not copied.
        assert error.original_error is original

    def test_create_api_error(self):
        """create_api_error returns an APIError with the given code and context."""
        error = create_api_error(
            "API failed", API_AUTHENTICATION_FAILED, {"endpoint": "/transcribe"}
        )

        assert isinstance(error, APIError)
        assert error.message == "API failed"
        assert error.error_code == API_AUTHENTICATION_FAILED
        assert error.context["endpoint"] == "/transcribe"

    def test_create_filesystem_error(self):
        """create_filesystem_error returns a FileSystemError with the given code and context."""
        error = create_filesystem_error(
            "File operation failed", FILE_NOT_FOUND, {"path": "/tmp/test.mp3"}
        )

        assert isinstance(error, FileSystemError)
        assert error.message == "File operation failed"
        assert error.error_code == FILE_NOT_FOUND
        assert error.context["path"] == "/tmp/test.mp3"

    def test_create_validation_error(self):
        """create_validation_error returns a ValidationError with the given code and context."""
        error = create_validation_error(
            "Validation failed", INVALID_INPUT, {"field": "api_key"}
        )

        assert isinstance(error, ValidationError)
        assert error.message == "Validation failed"
        assert error.error_code == INVALID_INPUT
        assert error.context["field"] == "api_key"
|
|
|
|
|
|
class TestErrorClassification:
    """Tests for ``classify_error`` and ``is_retryable_error``."""

    def test_classify_error_network(self):
        """A network error subclass is classified as ErrorCategory.NETWORK."""
        error = ConnectionError("Connection failed")
        assert classify_error(error) == ErrorCategory.NETWORK

    def test_classify_error_api(self):
        """An API error subclass is classified as ErrorCategory.API."""
        error = AuthenticationError("Invalid API key")
        assert classify_error(error) == ErrorCategory.API

    def test_classify_error_filesystem(self):
        """A file-system error subclass is classified as ErrorCategory.FILESYSTEM."""
        error = FileNotFoundError("File not found")
        assert classify_error(error) == ErrorCategory.FILESYSTEM

    def test_classify_error_validation(self):
        """A validation error subclass is classified as ErrorCategory.VALIDATION."""
        error = InvalidInputError("Invalid input")
        assert classify_error(error) == ErrorCategory.VALIDATION

    def test_is_retryable_error(self):
        """is_retryable_error follows the attached error code's ``retryable`` flag.

        The fixtures pair a retryable code with a non-retryable one so the
        test can tell the two outcomes apart.
        """
        retryable_error = ConnectionError("Connection failed", NETWORK_CONNECTION_FAILED)
        non_retryable_error = ValidationError("Invalid input", INVALID_INPUT)

        # Preconditions matching the fixtures' intent. Without them, the checks
        # below would still pass if both codes shared the same flag, and the
        # test would no longer distinguish retryable from non-retryable errors.
        assert NETWORK_CONNECTION_FAILED.retryable
        assert not INVALID_INPUT.retryable

        assert is_retryable_error(retryable_error) == NETWORK_CONNECTION_FAILED.retryable
        assert is_retryable_error(non_retryable_error) == INVALID_INPUT.retryable
|
|
|
|
|
|
class TestErrorHandlingDecorators:
|
|
"""Test error handling decorators."""
|
|
|
|
def test_error_handler_sync(self):
|
|
"""Test synchronous error handler decorator."""
|
|
@error_handler
|
|
def test_function():
|
|
raise ValueError("Test error")
|
|
|
|
with pytest.raises(TraxError) as exc_info:
|
|
test_function()
|
|
|
|
assert "Test error" in str(exc_info.value)
|
|
assert isinstance(exc_info.value.original_error, ValueError)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_error_handler(self):
|
|
"""Test asynchronous error handler decorator."""
|
|
@async_error_handler
|
|
async def test_async_function():
|
|
raise ValueError("Test async error")
|
|
|
|
with pytest.raises(TraxError) as exc_info:
|
|
await test_async_function()
|
|
|
|
assert "Test async error" in str(exc_info.value)
|
|
assert isinstance(exc_info.value.original_error, ValueError)
|
|
|
|
def test_error_handler_with_context(self):
|
|
"""Test error handler with context."""
|
|
@error_handler(context={"operation": "test"})
|
|
def test_function():
|
|
raise ValueError("Test error")
|
|
|
|
with pytest.raises(TraxError) as exc_info:
|
|
test_function()
|
|
|
|
assert exc_info.value.context["operation"] == "test"
|