# Python source listing: 317 lines, 11 KiB
"""Unit tests for secure configuration management."""
|
|
|
|
import json
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import patch, mock_open
|
|
import pytest
|
|
from cryptography.fernet import Fernet
|
|
|
|
from src.security.secure_config import SecureConfig, validate_path, validate_youtube_url
|
|
|
|
|
|
class TestSecureConfig:
    """Test cases for the SecureConfig class.

    Each test runs against a fresh temporary directory created in
    ``setup_method`` and removed in ``teardown_method``.
    """

    def setup_method(self):
        """Create a temporary directory and derive the config/key paths."""
        self.temp_dir = tempfile.mkdtemp()
        self.config_path = Path(self.temp_dir) / "config.json"
        # The tests expect the key file to live next to the config file.
        self.key_path = Path(self.temp_dir) / "key.bin"

    def teardown_method(self):
        """Remove the temporary directory and everything created inside it."""
        import shutil

        # rmtree also deletes the config file, the key file and any
        # subdirectories a test created, so no per-file cleanup is needed.
        if self.temp_dir and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def test_init_creates_config_directory(self):
        """Test that SecureConfig creates config directory if it doesn't exist."""
        config_dir = Path(self.temp_dir) / "new_config"
        config_path = config_dir / "config.json"

        # Directory shouldn't exist initially
        assert not config_dir.exists()

        # Constructing the instance is the action under test
        SecureConfig(config_path)

        # Directory should be created
        assert config_dir.exists()
        assert config_dir.is_dir()

    def test_init_generates_new_key_if_not_exists(self):
        """Test that SecureConfig generates a new encryption key if it doesn't exist."""
        # Key file shouldn't exist initially
        assert not self.key_path.exists()

        SecureConfig(self.config_path)

        # Key file should be created
        assert self.key_path.exists()
        assert self.key_path.is_file()

        # Key should be a valid Fernet key
        key = self.key_path.read_bytes()
        assert len(key) == 44  # Fernet key length (url-safe base64 of 32 bytes)
        Fernet(key)  # Should not raise exception

    def test_init_loads_existing_key(self):
        """Test that SecureConfig loads an existing encryption key.

        Verifies that the key written before construction is the one actually
        in use: the key file must be left untouched, and tokens must be
        interchangeable between ``secure_config.fernet`` and a Fernet built
        directly from that key.
        """
        # Create a key file first
        key = Fernet.generate_key()
        self.key_path.write_bytes(key)

        secure_config = SecureConfig(self.config_path)

        assert secure_config.fernet is not None
        # The existing key must not be regenerated or overwritten
        assert self.key_path.read_bytes() == key

        # A self round-trip would pass with any key, so check interoperability
        # with an independent Fernet built from the pre-written key instead.
        test_data = b"test_data"
        reference = Fernet(key)
        assert secure_config.fernet.decrypt(reference.encrypt(test_data)) == test_data
        assert reference.decrypt(secure_config.fernet.encrypt(test_data)) == test_data

    def test_set_api_key_creates_new_config(self):
        """Test setting API key when no config exists."""
        secure_config = SecureConfig(self.config_path)

        # Config file shouldn't exist initially
        assert not self.config_path.exists()

        result = secure_config.set_api_key("test_service", "test_key")

        # Should succeed
        assert result is True

        # Config file should be created
        assert self.config_path.exists()

    def test_set_api_key_updates_existing_config(self):
        """Test setting API key when config already exists."""
        secure_config = SecureConfig(self.config_path)

        # Set initial API key
        secure_config.set_api_key("service1", "key1")

        # Set another API key
        result = secure_config.set_api_key("service2", "key2")

        # Should succeed
        assert result is True

        # Both keys should be retrievable
        assert secure_config.get_api_key("service1") == "key1"
        assert secure_config.get_api_key("service2") == "key2"

    def test_get_api_key_returns_none_for_missing_config(self):
        """Test getting API key when config file doesn't exist."""
        secure_config = SecureConfig(self.config_path)

        # Config file doesn't exist
        assert not self.config_path.exists()

        # Should return None
        assert secure_config.get_api_key("test_service") is None

    def test_get_api_key_returns_none_for_missing_service(self):
        """Test getting API key for service that doesn't exist in config."""
        secure_config = SecureConfig(self.config_path)

        # Set one API key
        secure_config.set_api_key("service1", "key1")

        # Try to get non-existent service
        assert secure_config.get_api_key("service2") is None

    def test_get_api_key_returns_correct_value(self):
        """Test getting API key returns the correct value."""
        secure_config = SecureConfig(self.config_path)

        secure_config.set_api_key("test_service", "test_key")

        assert secure_config.get_api_key("test_service") == "test_key"

    def test_set_api_key_encrypts_data(self):
        """Test that API keys are stored encrypted."""
        secure_config = SecureConfig(self.config_path)

        secure_config.set_api_key("test_service", "test_key")

        # Read raw config file
        raw = self.config_path.read_bytes()

        # The secret must never appear verbatim on disk
        assert b"test_key" not in raw

        # Data should be encrypted (not JSON). Ciphertext that is not valid
        # UTF-8 fails at decode() rather than in the JSON parser, so both
        # error types count as "not plaintext JSON".
        with pytest.raises((json.JSONDecodeError, UnicodeDecodeError)):
            json.loads(raw.decode())

    def test_set_api_key_handles_encryption_errors(self):
        """Test that set_api_key handles encryption errors gracefully."""
        secure_config = SecureConfig(self.config_path)

        # Mock Fernet to raise exception
        with patch.object(
            secure_config.fernet,
            "encrypt",
            side_effect=Exception("Encryption failed"),
        ):
            result = secure_config.set_api_key("test_service", "test_key")
            assert result is False

    def test_get_api_key_handles_decryption_errors(self):
        """Test that get_api_key handles decryption errors gracefully."""
        secure_config = SecureConfig(self.config_path)

        # Create corrupted config file
        self.config_path.write_bytes(b"corrupted_data")

        # Should return None instead of raising exception
        assert secure_config.get_api_key("test_service") is None

    @pytest.mark.skipif(
        os.name == "nt",
        reason="POSIX permission bits are not meaningful on Windows",
    )
    def test_config_file_permissions(self):
        """Test that config file has correct permissions."""
        secure_config = SecureConfig(self.config_path)

        # Set API key to create config file
        secure_config.set_api_key("test_service", "test_key")

        # Mask off the file-type bits and compare the permission bits
        # numerically (owner read/write only).
        mode = self.config_path.stat().st_mode & 0o777
        assert mode == 0o600

    @pytest.mark.skipif(
        os.name == "nt",
        reason="POSIX permission bits are not meaningful on Windows",
    )
    def test_key_file_permissions(self):
        """Test that key file has correct permissions."""
        # Constructing the instance is what creates the key file
        SecureConfig(self.config_path)

        # Owner read/write only
        mode = self.key_path.stat().st_mode & 0o777
        assert mode == 0o600
|
|
|
|
|
|
class TestPathValidation:
    """Test cases for path validation functions.

    Multi-input checks are parametrized so each path is reported as its own
    test case; a single failing path no longer hides the ones after it.
    """

    @pytest.mark.parametrize(
        "path",
        [
            "~/Documents/test.txt",
            "~/Downloads/video.mp4",
            "~/.trax/config.json",
            "~/Desktop/file.txt",
            "~/Music/song.mp3",
            "~/Videos/video.mp4",
        ],
    )
    def test_validate_path_allows_safe_paths(self, path):
        """Test that validate_path allows safe file paths."""
        assert validate_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            "../../../etc/passwd",
            "~/Documents/../../../etc/shadow",
            "/tmp/../../../root/.ssh/id_rsa",
            "~/Downloads/..//..//..//var/log/auth.log",
        ],
    )
    def test_validate_path_blocks_directory_traversal(self, path):
        """Test that validate_path blocks directory traversal attempts."""
        assert validate_path(path) is False

    @pytest.mark.parametrize(
        "path",
        [
            "/etc/passwd",
            "/var/log/auth.log",
            "/root/.ssh/id_rsa",
            "/tmp/malicious_file",
        ],
    )
    def test_validate_path_blocks_system_directories(self, path):
        """Test that validate_path blocks access to system directories."""
        assert validate_path(path) is False

    def test_validate_path_handles_relative_paths(self):
        """Test that validate_path handles relative paths correctly."""
        # Should convert to absolute and validate.
        # NOTE(review): "./test.txt" is presumably resolved against the current
        # working directory, so this may depend on where pytest is launched
        # from - confirm against validate_path's implementation.
        assert validate_path("./test.txt") is True
        assert validate_path("../test.txt") is False

    def test_validate_path_handles_empty_path(self):
        """Test that validate_path handles empty path."""
        assert validate_path("") is False

    def test_validate_path_handles_none_path(self):
        """Test that validate_path handles None path."""
        assert validate_path(None) is False
|
|
|
|
|
|
class TestURLValidation:
    """Test cases for URL validation functions.

    Every URL is its own parametrized case, so all failing inputs are
    reported rather than only the first one hit in a loop.
    """

    @pytest.mark.parametrize(
        "url",
        [
            "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
            "https://youtu.be/dQw4w9WgXcQ",
            "http://www.youtube.com/watch?v=dQw4w9WgXcQ",
            "https://youtube.com/watch?v=dQw4w9WgXcQ",
            "https://www.youtube.com/embed/dQw4w9WgXcQ",
        ],
    )
    def test_validate_youtube_url_allows_valid_urls(self, url):
        """Test that validate_youtube_url allows valid YouTube URLs."""
        assert validate_youtube_url(url) is True

    @pytest.mark.parametrize(
        "url",
        [
            "https://www.google.com",
            "https://malicious-site.com/fake-youtube",
            "ftp://youtube.com/video",
            "javascript:alert('xss')",
            "data:text/html,<script>alert('xss')</script>",
            "file:///etc/passwd",
        ],
    )
    def test_validate_youtube_url_blocks_invalid_urls(self, url):
        """Test that validate_youtube_url blocks invalid URLs."""
        assert validate_youtube_url(url) is False

    @pytest.mark.parametrize(
        "url",
        [
            "",  # Empty string
            None,  # None value
            "not_a_url",  # Plain text
            "youtube.com",  # Missing protocol
        ],
    )
    def test_validate_youtube_url_handles_edge_cases(self, url):
        """Test that validate_youtube_url handles edge cases."""
        assert validate_youtube_url(url) is False

    @pytest.mark.parametrize(
        "url",
        [
            "https://www.youtube.com/watch?v=dQw4w9WgXcQ&t=30s",
            "https://youtu.be/dQw4w9WgXcQ?t=30",
            "https://www.youtube.com/watch?v=dQw4w9WgXcQ&list=PL123456",
        ],
    )
    def test_validate_youtube_url_handles_complex_urls(self, url):
        """Test that validate_youtube_url handles complex YouTube URLs."""
        assert validate_youtube_url(url) is True
|