"""Unit tests for secure configuration management.""" import json import os import tempfile from pathlib import Path from unittest.mock import patch, mock_open import pytest from cryptography.fernet import Fernet from src.security.secure_config import SecureConfig, validate_path, validate_youtube_url class TestSecureConfig: """Test cases for SecureConfig class.""" def setup_method(self): """Set up test fixtures.""" self.temp_dir = tempfile.mkdtemp() self.config_path = Path(self.temp_dir) / "config.json" self.key_path = Path(self.temp_dir) / "key.bin" def teardown_method(self): """Clean up test fixtures.""" # Clean up temporary files if self.config_path.exists(): self.config_path.unlink() if self.key_path.exists(): self.key_path.unlink() # Clean up any subdirectories that might have been created if self.temp_dir and os.path.exists(self.temp_dir): import shutil shutil.rmtree(self.temp_dir) def test_init_creates_config_directory(self): """Test that SecureConfig creates config directory if it doesn't exist.""" config_dir = Path(self.temp_dir) / "new_config" config_path = config_dir / "config.json" # Directory shouldn't exist initially assert not config_dir.exists() # Create SecureConfig instance secure_config = SecureConfig(config_path) # Directory should be created assert config_dir.exists() assert config_dir.is_dir() def test_init_generates_new_key_if_not_exists(self): """Test that SecureConfig generates a new encryption key if it doesn't exist.""" # Key file shouldn't exist initially assert not self.key_path.exists() # Create SecureConfig instance secure_config = SecureConfig(self.config_path) # Key file should be created assert self.key_path.exists() assert self.key_path.is_file() # Key should be valid Fernet key with open(self.key_path, "rb") as f: key = f.read() assert len(key) == 44 # Fernet key length Fernet(key) # Should not raise exception def test_init_loads_existing_key(self): """Test that SecureConfig loads existing encryption key.""" # Create a key file first key = Fernet.generate_key() with open(self.key_path, "wb") as f: f.write(key) # Create SecureConfig instance secure_config = SecureConfig(self.config_path) # Should load the existing key assert secure_config.fernet is not None # Test that the key works by encrypting/decrypting test_data = b"test_data" encrypted = secure_config.fernet.encrypt(test_data) decrypted = secure_config.fernet.decrypt(encrypted) assert decrypted == test_data def test_set_api_key_creates_new_config(self): """Test setting API key when no config exists.""" secure_config = SecureConfig(self.config_path) # Config file shouldn't exist initially assert not self.config_path.exists() # Set API key result = secure_config.set_api_key("test_service", "test_key") # Should succeed assert result is True # Config file should be created assert self.config_path.exists() def test_set_api_key_updates_existing_config(self): """Test setting API key when config already exists.""" secure_config = SecureConfig(self.config_path) # Set initial API key secure_config.set_api_key("service1", "key1") # Set another API key result = secure_config.set_api_key("service2", "key2") # Should succeed assert result is True # Both keys should be retrievable assert secure_config.get_api_key("service1") == "key1" assert secure_config.get_api_key("service2") == "key2" def test_get_api_key_returns_none_for_missing_config(self): """Test getting API key when config file doesn't exist.""" secure_config = SecureConfig(self.config_path) # Config file doesn't exist assert not self.config_path.exists() # Should return None result = secure_config.get_api_key("test_service") assert result is None def test_get_api_key_returns_none_for_missing_service(self): """Test getting API key for service that doesn't exist in config.""" secure_config = SecureConfig(self.config_path) # Set one API key secure_config.set_api_key("service1", "key1") # Try to get non-existent service result = secure_config.get_api_key("service2") assert result is None def test_get_api_key_returns_correct_value(self): """Test getting API key returns the correct value.""" secure_config = SecureConfig(self.config_path) # Set API key secure_config.set_api_key("test_service", "test_key") # Get API key result = secure_config.get_api_key("test_service") assert result == "test_key" def test_set_api_key_encrypts_data(self): """Test that API keys are stored encrypted.""" secure_config = SecureConfig(self.config_path) # Set API key secure_config.set_api_key("test_service", "test_key") # Read raw config file with open(self.config_path, "rb") as f: encrypted_data = f.read() # Data should be encrypted (not JSON) with pytest.raises(json.JSONDecodeError): json.loads(encrypted_data.decode()) def test_set_api_key_handles_encryption_errors(self): """Test that set_api_key handles encryption errors gracefully.""" secure_config = SecureConfig(self.config_path) # Mock Fernet to raise exception with patch.object(secure_config.fernet, 'encrypt', side_effect=Exception("Encryption failed")): result = secure_config.set_api_key("test_service", "test_key") assert result is False def test_get_api_key_handles_decryption_errors(self): """Test that get_api_key handles decryption errors gracefully.""" secure_config = SecureConfig(self.config_path) # Create corrupted config file with open(self.config_path, "wb") as f: f.write(b"corrupted_data") # Should return None instead of raising exception result = secure_config.get_api_key("test_service") assert result is None def test_config_file_permissions(self): """Test that config file has correct permissions.""" secure_config = SecureConfig(self.config_path) # Set API key to create config file secure_config.set_api_key("test_service", "test_key") # Check file permissions (should be owner-only) stat = self.config_path.stat() assert oct(stat.st_mode)[-3:] == "600" def test_key_file_permissions(self): """Test that key file has correct permissions.""" secure_config = SecureConfig(self.config_path) # Check key file permissions (should be owner-only) stat = self.key_path.stat() assert oct(stat.st_mode)[-3:] == "600" class TestPathValidation: """Test cases for path validation functions.""" def test_validate_path_allows_safe_paths(self): """Test that validate_path allows safe file paths.""" safe_paths = [ "~/Documents/test.txt", "~/Downloads/video.mp4", "~/.trax/config.json", "~/Desktop/file.txt", "~/Music/song.mp3", "~/Videos/video.mp4", ] for path in safe_paths: assert validate_path(path) is True def test_validate_path_blocks_directory_traversal(self): """Test that validate_path blocks directory traversal attempts.""" malicious_paths = [ "../../../etc/passwd", "~/Documents/../../../etc/shadow", "/tmp/../../../root/.ssh/id_rsa", "~/Downloads/..//..//..//var/log/auth.log", ] for path in malicious_paths: assert validate_path(path) is False def test_validate_path_blocks_system_directories(self): """Test that validate_path blocks access to system directories.""" system_paths = [ "/etc/passwd", "/var/log/auth.log", "/root/.ssh/id_rsa", "/tmp/malicious_file", ] for path in system_paths: assert validate_path(path) is False def test_validate_path_handles_relative_paths(self): """Test that validate_path handles relative paths correctly.""" # Should convert to absolute and validate assert validate_path("./test.txt") is True assert validate_path("../test.txt") is False def test_validate_path_handles_empty_path(self): """Test that validate_path handles empty path.""" assert validate_path("") is False def test_validate_path_handles_none_path(self): """Test that validate_path handles None path.""" assert validate_path(None) is False class TestURLValidation: """Test cases for URL validation functions.""" def test_validate_youtube_url_allows_valid_urls(self): """Test that validate_youtube_url allows valid YouTube URLs.""" valid_urls = [ "https://www.youtube.com/watch?v=dQw4w9WgXcQ", "https://youtu.be/dQw4w9WgXcQ", "http://www.youtube.com/watch?v=dQw4w9WgXcQ", "https://youtube.com/watch?v=dQw4w9WgXcQ", "https://www.youtube.com/embed/dQw4w9WgXcQ", ] for url in valid_urls: assert validate_youtube_url(url) is True def test_validate_youtube_url_blocks_invalid_urls(self): """Test that validate_youtube_url blocks invalid URLs.""" invalid_urls = [ "https://www.google.com", "https://malicious-site.com/fake-youtube", "ftp://youtube.com/video", "javascript:alert('xss')", "data:text/html,", "file:///etc/passwd", ] for url in invalid_urls: assert validate_youtube_url(url) is False def test_validate_youtube_url_handles_edge_cases(self): """Test that validate_youtube_url handles edge cases.""" edge_cases = [ "", # Empty string None, # None value "not_a_url", # Plain text "youtube.com", # Missing protocol ] for url in edge_cases: assert validate_youtube_url(url) is False def test_validate_youtube_url_handles_complex_urls(self): """Test that validate_youtube_url handles complex YouTube URLs.""" complex_urls = [ "https://www.youtube.com/watch?v=dQw4w9WgXcQ&t=30s", "https://youtu.be/dQw4w9WgXcQ?t=30", "https://www.youtube.com/watch?v=dQw4w9WgXcQ&list=PL123456", ] for url in complex_urls: assert validate_youtube_url(url) is True