"""Unit tests for the error handling system.

This module tests the error classification, creation, and handling
functionality provided by ``src.errors``.
"""

from datetime import datetime, timezone
from unittest.mock import Mock, patch

import pytest

# NOTE: ConnectionError, TimeoutError, FileNotFoundError, PermissionError,
# EnvironmentError and MemoryError imported below shadow the Python builtins
# of the same name for the rest of this module. Every use of those names in
# these tests refers to the project classes from ``src.errors``.
from src.errors import (
    # Base error classes
    TraxError,
    NetworkError,
    APIError,
    FileSystemError,
    ValidationError,
    ProcessingError,
    ConfigurationError,
    ResourceError,
    ConnectionError,
    TimeoutError,
    DNSResolutionError,
    AuthenticationError,
    RateLimitError,
    QuotaExceededError,
    ServiceUnavailableError,
    InvalidResponseError,
    FileNotFoundError,
    PermissionError,
    DiskSpaceError,
    CorruptedFileError,
    InvalidInputError,
    MissingRequiredFieldError,
    FormatError,
    TranscriptionError,
    EnhancementError,
    MediaProcessingError,
    AudioConversionError,
    MissingConfigError,
    InvalidConfigError,
    EnvironmentError,
    MemoryError,
    CPUError,
    # Error codes
    ErrorCode,
    ErrorCategory,
    ErrorSeverity,
    NETWORK_CONNECTION_FAILED,
    NETWORK_TIMEOUT,
    DNS_RESOLUTION_FAILED,
    API_AUTHENTICATION_FAILED,
    API_RATE_LIMIT_EXCEEDED,
    API_QUOTA_EXCEEDED,
    API_SERVICE_UNAVAILABLE,
    API_INVALID_RESPONSE,
    FILE_NOT_FOUND,
    FILE_PERMISSION_DENIED,
    DISK_SPACE_INSUFFICIENT,
    FILE_CORRUPTED,
    INVALID_INPUT,
    MISSING_REQUIRED_FIELD,
    INVALID_FORMAT,
    TRANSCRIPTION_FAILED,
    ENHANCEMENT_FAILED,
    MEDIA_PROCESSING_FAILED,
    AUDIO_CONVERSION_FAILED,
    MISSING_CONFIGURATION,
    INVALID_CONFIGURATION,
    ENVIRONMENT_ERROR,
    MEMORY_INSUFFICIENT,
    CPU_OVERLOADED,
    # Error utilities
    create_network_error,
    create_api_error,
    create_filesystem_error,
    create_validation_error,
    create_error_from_code,
    classify_error,
    extract_error_context,
    is_retryable_error,
    get_error_severity,
    get_error_category,
    wrap_error,
    get_actionable_message,
    error_handler,
    async_error_handler,
)


class TestTraxError:
    """Test the base TraxError class."""

    def test_trax_error_creation(self):
        """Test basic TraxError creation."""
        error = TraxError("Test error message")

        assert error.message == "Test error message"
        assert error.error_code is None
        assert error.context == {}
        assert error.original_error is None
        assert isinstance(error.timestamp, datetime)
        # The timestamp must be timezone-aware and pinned to UTC.
        assert error.timestamp.tzinfo == timezone.utc

    def test_trax_error_with_code(self):
        """Test TraxError creation with error code."""
        error = TraxError("Test error", NETWORK_CONNECTION_FAILED)

        assert error.error_code == NETWORK_CONNECTION_FAILED
        # Retryability, severity and category are derived from the error code.
        assert error.is_retryable == NETWORK_CONNECTION_FAILED.retryable
        assert error.severity == NETWORK_CONNECTION_FAILED.severity
        assert error.category == NETWORK_CONNECTION_FAILED.category

    def test_trax_error_with_context(self):
        """Test TraxError creation with context."""
        context = {"file": "test.mp3", "size": 1024}
        error = TraxError("Test error", context=context)

        assert error.context == context

    def test_trax_error_with_original_error(self):
        """Test TraxError creation with original error."""
        original = ValueError("Original error")
        error = TraxError("Test error", original_error=original)

        assert error.original_error == original

    def test_trax_error_to_dict(self):
        """Test TraxError serialization to dictionary."""
        original = ValueError("Original error")
        context = {"file": "test.mp3"}
        # Positional order: message, error_code, context, original_error.
        error = TraxError("Test error", NETWORK_CONNECTION_FAILED, context, original)

        error_dict = error.to_dict()

        assert error_dict["error_type"] == "TraxError"
        assert error_dict["message"] == "Test error"
        assert error_dict["error_code"] == str(NETWORK_CONNECTION_FAILED)
        assert error_dict["category"] == NETWORK_CONNECTION_FAILED.category.value
        assert error_dict["severity"] == NETWORK_CONNECTION_FAILED.severity.value
        assert error_dict["retryable"] == NETWORK_CONNECTION_FAILED.retryable
        assert error_dict["context"] == context
        assert "timestamp" in error_dict
        assert "traceback" in error_dict
        # The original exception is serialized as its string form.
        assert error_dict["original_error"] == "Original error"

    def test_trax_error_string_representation(self):
        """Test TraxError string representation."""
        error = TraxError("Test error", NETWORK_CONNECTION_FAILED)
        assert str(error) == f"{NETWORK_CONNECTION_FAILED.code}: Test error"

        # Without an error code, str() is just the message.
        error_no_code = TraxError("Test error")
        assert str(error_no_code) == "Test error"


class TestNetworkErrors:
    """Test network-related error classes."""

    def test_network_error_inheritance(self):
        """Test that network errors inherit from NetworkError."""
        assert issubclass(ConnectionError, NetworkError)
        assert issubclass(TimeoutError, NetworkError)
        assert issubclass(DNSResolutionError, NetworkError)

    def test_connection_error(self):
        """Test ConnectionError creation."""
        error = ConnectionError("Connection failed")

        assert isinstance(error, NetworkError)
        assert error.message == "Connection failed"

    def test_timeout_error(self):
        """Test TimeoutError creation."""
        error = TimeoutError("Request timed out")

        assert isinstance(error, NetworkError)
        assert error.message == "Request timed out"

    def test_dns_resolution_error(self):
        """Test DNSResolutionError creation."""
        error = DNSResolutionError("DNS resolution failed")

        assert isinstance(error, NetworkError)
        assert error.message == "DNS resolution failed"


class TestAPIErrors:
    """Test API-related error classes."""

    def test_api_error_inheritance(self):
        """Test that API errors inherit from APIError."""
        assert issubclass(AuthenticationError, APIError)
        assert issubclass(RateLimitError, APIError)
        assert issubclass(QuotaExceededError, APIError)
        assert issubclass(ServiceUnavailableError, APIError)
        assert issubclass(InvalidResponseError, APIError)

    def test_authentication_error(self):
        """Test AuthenticationError creation."""
        error = AuthenticationError("Invalid API key")

        assert isinstance(error, APIError)
        assert error.message == "Invalid API key"

    def test_rate_limit_error(self):
        """Test RateLimitError creation."""
        error = RateLimitError("Rate limit exceeded")

        assert isinstance(error, APIError)
        assert error.message == "Rate limit exceeded"

    def test_quota_exceeded_error(self):
        """Test QuotaExceededError creation."""
        error = QuotaExceededError("API quota exceeded")

        assert isinstance(error, APIError)
        assert error.message == "API quota exceeded"


class TestFileSystemErrors:
    """Test file system-related error classes."""

    def test_filesystem_error_inheritance(self):
        """Test that file system errors inherit from FileSystemError."""
        assert issubclass(FileNotFoundError, FileSystemError)
        assert issubclass(PermissionError, FileSystemError)
        assert issubclass(DiskSpaceError, FileSystemError)
        assert issubclass(CorruptedFileError, FileSystemError)

    def test_file_not_found_error(self):
        """Test FileNotFoundError creation."""
        # Project FileNotFoundError (shadows the builtin), not builtins.FileNotFoundError.
        error = FileNotFoundError("File not found: test.mp3")

        assert isinstance(error, FileSystemError)
        assert error.message == "File not found: test.mp3"

    def test_permission_error(self):
        """Test PermissionError creation."""
        # Project PermissionError (shadows the builtin), not builtins.PermissionError.
        error = PermissionError("Permission denied: test.mp3")

        assert isinstance(error, FileSystemError)
        assert error.message == "Permission denied: test.mp3"


class TestValidationErrors:
    """Test validation-related error classes."""

    def test_validation_error_inheritance(self):
        """Test that validation errors inherit from ValidationError."""
        assert issubclass(InvalidInputError, ValidationError)
        assert issubclass(MissingRequiredFieldError, ValidationError)
        assert issubclass(FormatError, ValidationError)

    def test_invalid_input_error(self):
        """Test InvalidInputError creation."""
        error = InvalidInputError("Invalid input format")

        assert isinstance(error, ValidationError)
        assert error.message == "Invalid input format"

    def test_missing_required_field_error(self):
        """Test MissingRequiredFieldError creation."""
        error = MissingRequiredFieldError("Missing required field: api_key")

        assert isinstance(error, ValidationError)
        assert error.message == "Missing required field: api_key"


class TestProcessingErrors:
    """Test processing-related error classes."""

    def test_processing_error_inheritance(self):
        """Test that processing errors inherit from ProcessingError."""
        assert issubclass(TranscriptionError, ProcessingError)
        assert issubclass(EnhancementError, ProcessingError)
        assert issubclass(MediaProcessingError, ProcessingError)
        assert issubclass(AudioConversionError, ProcessingError)

    def test_transcription_error(self):
        """Test TranscriptionError creation."""
        error = TranscriptionError("Transcription failed")

        assert isinstance(error, ProcessingError)
        assert error.message == "Transcription failed"

    def test_enhancement_error(self):
        """Test EnhancementError creation."""
        error = EnhancementError("Enhancement failed")

        assert isinstance(error, ProcessingError)
        assert error.message == "Enhancement failed"


class TestErrorCreationUtilities:
    """Test error creation utility functions."""

    def test_create_network_error(self):
        """Test create_network_error utility."""
        # NOTE: this is the project ConnectionError from src.errors (it shadows
        # the builtin), so the "original" error is itself a NetworkError.
        original = ConnectionError("Original connection error")
        error = create_network_error(
            "Network failed",
            NETWORK_CONNECTION_FAILED,
            {"url": "https://api.example.com"},
            original,
        )

        assert isinstance(error, NetworkError)
        assert error.message == "Network failed"
        assert error.error_code == NETWORK_CONNECTION_FAILED
        assert error.context["url"] == "https://api.example.com"
        assert error.original_error == original

    def test_create_api_error(self):
        """Test create_api_error utility."""
        error = create_api_error(
            "API failed",
            API_AUTHENTICATION_FAILED,
            {"endpoint": "/transcribe"},
        )

        assert isinstance(error, APIError)
        assert error.message == "API failed"
        assert error.error_code == API_AUTHENTICATION_FAILED
        assert error.context["endpoint"] == "/transcribe"

    def test_create_filesystem_error(self):
        """Test create_filesystem_error utility."""
        error = create_filesystem_error(
            "File operation failed",
            FILE_NOT_FOUND,
            {"path": "/tmp/test.mp3"},
        )

        assert isinstance(error, FileSystemError)
        assert error.message == "File operation failed"
        assert error.error_code == FILE_NOT_FOUND
        assert error.context["path"] == "/tmp/test.mp3"

    def test_create_validation_error(self):
        """Test create_validation_error utility."""
        error = create_validation_error(
            "Validation failed",
            INVALID_INPUT,
            {"field": "api_key"},
        )

        assert isinstance(error, ValidationError)
        assert error.message == "Validation failed"
        assert error.error_code == INVALID_INPUT
        assert error.context["field"] == "api_key"


class TestErrorClassification:
    """Test error classification functionality."""

    def test_classify_error_network(self):
        """Test error classification for network errors."""
        error = ConnectionError("Connection failed")
        category = classify_error(error)

        assert category == ErrorCategory.NETWORK

    def test_classify_error_api(self):
        """Test error classification for API errors."""
        error = AuthenticationError("Invalid API key")
        category = classify_error(error)

        assert category == ErrorCategory.API

    def test_classify_error_filesystem(self):
        """Test error classification for file system errors."""
        error = FileNotFoundError("File not found")
        category = classify_error(error)

        assert category == ErrorCategory.FILESYSTEM

    def test_classify_error_validation(self):
        """Test error classification for validation errors."""
        error = InvalidInputError("Invalid input")
        category = classify_error(error)

        assert category == ErrorCategory.VALIDATION

    def test_is_retryable_error(self):
        """Test retryable error detection."""
        retryable_error = ConnectionError("Connection failed", NETWORK_CONNECTION_FAILED)
        non_retryable_error = ValidationError("Invalid input", INVALID_INPUT)

        # Retryability is expected to follow the attached error code's flag.
        assert is_retryable_error(retryable_error) == NETWORK_CONNECTION_FAILED.retryable
        assert is_retryable_error(non_retryable_error) == INVALID_INPUT.retryable


class TestErrorHandlingDecorators:
    """Test error handling decorators."""

    def test_error_handler_sync(self):
        """Test synchronous error handler decorator."""

        @error_handler
        def test_function():
            raise ValueError("Test error")

        with pytest.raises(TraxError) as exc_info:
            test_function()

        # The decorator wraps the raised exception in a TraxError and keeps
        # the original exception for inspection.
        assert "Test error" in str(exc_info.value)
        assert isinstance(exc_info.value.original_error, ValueError)

    # NOTE(review): pytest.mark.asyncio presumably relies on the pytest-asyncio
    # plugin being installed and enabled — confirm in the project's test deps.
    @pytest.mark.asyncio
    async def test_async_error_handler(self):
        """Test asynchronous error handler decorator."""

        @async_error_handler
        async def test_async_function():
            raise ValueError("Test async error")

        with pytest.raises(TraxError) as exc_info:
            await test_async_function()

        assert "Test async error" in str(exc_info.value)
        assert isinstance(exc_info.value.original_error, ValueError)

    def test_error_handler_with_context(self):
        """Test error handler with context."""

        # Parameterized form: error_handler(...) returns the actual decorator.
        @error_handler(context={"operation": "test"})
        def test_function():
            raise ValueError("Test error")

        with pytest.raises(TraxError) as exc_info:
            test_function()

        assert exc_info.value.context["operation"] == "test"
        # Same wrapping guarantee as the bare-decorator form.
        assert isinstance(exc_info.value.original_error, ValueError)