# trax/tests/test_secure_config.py

"""Unit tests for secure configuration management."""
import json
import os
import tempfile
from pathlib import Path
from unittest.mock import patch, mock_open
import pytest
from cryptography.fernet import Fernet
from src.security.secure_config import SecureConfig, validate_path, validate_youtube_url
class TestSecureConfig:
    """Test cases for the SecureConfig class.

    Every test gets a fresh temporary directory containing the ``config.json``
    and ``key.bin`` paths; ``teardown_method`` removes the whole directory.
    """

    def setup_method(self):
        """Create a fresh temporary directory and the config/key paths in it."""
        self.temp_dir = tempfile.mkdtemp()
        self.config_path = Path(self.temp_dir) / "config.json"
        # NOTE(review): the tests below assume SecureConfig keeps its key in
        # "key.bin" next to the config file — confirm against secure_config.py.
        self.key_path = Path(self.temp_dir) / "key.bin"

    def teardown_method(self):
        """Remove the temporary directory and everything created inside it."""
        import shutil

        # A single rmtree removes the config/key files as well as any
        # subdirectories a test created (e.g. "new_config" below). Cleanup is
        # best-effort so a teardown error cannot mask the real test result.
        if self.temp_dir and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_init_creates_config_directory(self):
        """Test that SecureConfig creates config directory if it doesn't exist."""
        config_dir = Path(self.temp_dir) / "new_config"
        config_path = config_dir / "config.json"
        # Directory shouldn't exist initially
        assert not config_dir.exists()

        SecureConfig(config_path)

        # is_dir() is False for a missing path, so this also checks existence.
        assert config_dir.is_dir()

    def test_init_generates_new_key_if_not_exists(self):
        """Test that SecureConfig generates a new encryption key if it doesn't exist."""
        # Key file shouldn't exist initially
        assert not self.key_path.exists()

        SecureConfig(self.config_path)

        # Key file should be created
        assert self.key_path.is_file()
        key = self.key_path.read_bytes()
        assert len(key) == 44  # url-safe base64 of a 32-byte Fernet key
        Fernet(key)  # raises ValueError if the key is malformed

    def test_init_loads_existing_key(self):
        """Test that SecureConfig uses an existing key instead of generating one."""
        # Create a key file first
        key = Fernet.generate_key()
        with open(self.key_path, "wb") as f:
            f.write(key)

        secure_config = SecureConfig(self.config_path)

        assert secure_config.fernet is not None
        test_data = b"test_data"
        # Cross-check against the pre-existing key in both directions: a
        # freshly generated key would raise InvalidToken here, whereas a
        # self round-trip (encrypt + decrypt with the same instance) would
        # pass regardless of which key was loaded.
        external = Fernet(key)
        assert secure_config.fernet.decrypt(external.encrypt(test_data)) == test_data
        assert external.decrypt(secure_config.fernet.encrypt(test_data)) == test_data

    def test_set_api_key_creates_new_config(self):
        """Test setting API key when no config exists."""
        secure_config = SecureConfig(self.config_path)
        # Config file shouldn't exist initially
        assert not self.config_path.exists()

        result = secure_config.set_api_key("test_service", "test_key")

        assert result is True
        # Config file should be created
        assert self.config_path.exists()

    def test_set_api_key_updates_existing_config(self):
        """Test setting API key when config already exists."""
        secure_config = SecureConfig(self.config_path)
        secure_config.set_api_key("service1", "key1")

        result = secure_config.set_api_key("service2", "key2")

        assert result is True
        # Both keys should be retrievable
        assert secure_config.get_api_key("service1") == "key1"
        assert secure_config.get_api_key("service2") == "key2"

    def test_get_api_key_returns_none_for_missing_config(self):
        """Test getting API key when config file doesn't exist."""
        secure_config = SecureConfig(self.config_path)
        # Config file doesn't exist
        assert not self.config_path.exists()

        assert secure_config.get_api_key("test_service") is None

    def test_get_api_key_returns_none_for_missing_service(self):
        """Test getting API key for service that doesn't exist in config."""
        secure_config = SecureConfig(self.config_path)
        secure_config.set_api_key("service1", "key1")

        assert secure_config.get_api_key("service2") is None

    def test_get_api_key_returns_correct_value(self):
        """Test getting API key returns the correct value."""
        secure_config = SecureConfig(self.config_path)
        secure_config.set_api_key("test_service", "test_key")

        assert secure_config.get_api_key("test_service") == "test_key"

    def test_set_api_key_encrypts_data(self):
        """Test that API keys are stored encrypted."""
        secure_config = SecureConfig(self.config_path)
        secure_config.set_api_key("test_service", "test_key")

        raw = self.config_path.read_bytes()

        # The secret itself must never appear in plaintext on disk.
        assert b"test_key" not in raw
        # Data should be encrypted (not JSON)
        with pytest.raises(json.JSONDecodeError):
            json.loads(raw.decode())

    def test_set_api_key_handles_encryption_errors(self):
        """Test that set_api_key handles encryption errors gracefully."""
        secure_config = SecureConfig(self.config_path)
        # Make the instance's Fernet fail so set_api_key must report failure.
        with patch.object(secure_config.fernet, 'encrypt', side_effect=Exception("Encryption failed")):
            result = secure_config.set_api_key("test_service", "test_key")
        assert result is False

    def test_get_api_key_handles_decryption_errors(self):
        """Test that get_api_key handles decryption errors gracefully."""
        secure_config = SecureConfig(self.config_path)
        # Create corrupted config file
        self.config_path.write_bytes(b"corrupted_data")

        # Should return None instead of raising exception
        assert secure_config.get_api_key("test_service") is None

    def test_config_file_permissions(self):
        """Test that config file has owner-only (0o600) permissions."""
        if os.name == "nt":
            pytest.skip("POSIX permission bits are not meaningful on Windows")
        secure_config = SecureConfig(self.config_path)
        # Set API key to create config file
        secure_config.set_api_key("test_service", "test_key")

        # Mask to the rwx bits for user/group/other; same check as comparing
        # the last three octal digits, without shadowing the `stat` module.
        mode = self.config_path.stat().st_mode & 0o777
        assert mode == 0o600, oct(mode)

    def test_key_file_permissions(self):
        """Test that key file has owner-only (0o600) permissions."""
        if os.name == "nt":
            pytest.skip("POSIX permission bits are not meaningful on Windows")
        SecureConfig(self.config_path)

        mode = self.key_path.stat().st_mode & 0o777
        assert mode == 0o600, oct(mode)
class TestPathValidation:
    """Test cases for the validate_path function.

    Table-driven tests collect every input whose result differs from the
    expectation, so one failing run lists all offending paths rather than
    stopping at the first.
    """

    def test_validate_path_allows_safe_paths(self):
        """Test that validate_path allows safe file paths."""
        safe_paths = [
            "~/Documents/test.txt",
            "~/Downloads/video.mp4",
            "~/.trax/config.json",
            "~/Desktop/file.txt",
            "~/Music/song.mp3",
            "~/Videos/video.mp4",
        ]
        rejected = [path for path in safe_paths if validate_path(path) is not True]
        assert rejected == [], f"safe paths were rejected: {rejected}"

    def test_validate_path_blocks_directory_traversal(self):
        """Test that validate_path blocks directory traversal attempts."""
        malicious_paths = [
            "../../../etc/passwd",
            "~/Documents/../../../etc/shadow",
            "/tmp/../../../root/.ssh/id_rsa",
            "~/Downloads/..//..//..//var/log/auth.log",
        ]
        accepted = [path for path in malicious_paths if validate_path(path) is not False]
        assert accepted == [], f"traversal paths were not blocked: {accepted}"

    def test_validate_path_blocks_system_directories(self):
        """Test that validate_path blocks access to system directories."""
        system_paths = [
            "/etc/passwd",
            "/var/log/auth.log",
            "/root/.ssh/id_rsa",
            "/tmp/malicious_file",
        ]
        accepted = [path for path in system_paths if validate_path(path) is not False]
        assert accepted == [], f"system paths were not blocked: {accepted}"

    def test_validate_path_handles_relative_paths(self):
        """Test that validate_path handles relative paths correctly."""
        # Should convert to absolute and validate.
        # NOTE(review): "./test.txt" presumably resolves against the current
        # working directory, so this outcome may depend on where pytest is
        # launched from — confirm against validate_path's implementation.
        assert validate_path("./test.txt") is True
        assert validate_path("../test.txt") is False

    def test_validate_path_handles_empty_path(self):
        """Test that validate_path rejects an empty path."""
        assert validate_path("") is False

    def test_validate_path_handles_none_path(self):
        """Test that validate_path rejects None."""
        assert validate_path(None) is False
class TestURLValidation:
    """Test cases for the validate_youtube_url function.

    Each table-driven test gathers all inputs with an unexpected result so a
    single failure message names every offending URL.
    """

    def test_validate_youtube_url_allows_valid_urls(self):
        """Test that validate_youtube_url allows valid YouTube URLs."""
        valid_urls = [
            "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
            "https://youtu.be/dQw4w9WgXcQ",
            "http://www.youtube.com/watch?v=dQw4w9WgXcQ",
            "https://youtube.com/watch?v=dQw4w9WgXcQ",
            "https://www.youtube.com/embed/dQw4w9WgXcQ",
        ]
        rejected = [url for url in valid_urls if validate_youtube_url(url) is not True]
        assert rejected == [], f"valid URLs were rejected: {rejected}"

    def test_validate_youtube_url_blocks_invalid_urls(self):
        """Test that validate_youtube_url blocks invalid or dangerous URLs."""
        invalid_urls = [
            "https://www.google.com",
            "https://malicious-site.com/fake-youtube",
            "ftp://youtube.com/video",
            "javascript:alert('xss')",
            "data:text/html,<script>alert('xss')</script>",
            "file:///etc/passwd",
        ]
        accepted = [url for url in invalid_urls if validate_youtube_url(url) is not False]
        assert accepted == [], f"invalid URLs were accepted: {accepted}"

    def test_validate_youtube_url_handles_edge_cases(self):
        """Test that validate_youtube_url rejects degenerate inputs."""
        edge_cases = [
            "",  # Empty string
            None,  # None value
            "not_a_url",  # Plain text
            "youtube.com",  # Missing protocol
        ]
        accepted = [url for url in edge_cases if validate_youtube_url(url) is not False]
        assert accepted == [], f"edge-case inputs were accepted: {accepted}"

    def test_validate_youtube_url_handles_complex_urls(self):
        """Test that validate_youtube_url accepts URLs with extra query parameters."""
        complex_urls = [
            "https://www.youtube.com/watch?v=dQw4w9WgXcQ&t=30s",
            "https://youtu.be/dQw4w9WgXcQ?t=30",
            "https://www.youtube.com/watch?v=dQw4w9WgXcQ&list=PL123456",
        ]
        rejected = [url for url in complex_urls if validate_youtube_url(url) is not True]
        assert rejected == [], f"complex URLs were rejected: {rejected}"