128 lines
4.5 KiB
Python
128 lines
4.5 KiB
Python
"""DNS rebinding protection for MCP server transports."""
|
|
|
|
import logging
|
|
|
|
from pydantic import BaseModel, Field
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
# Module-level logger named after this module; the middleware below uses it
# to emit warnings when a request header fails validation.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TransportSecuritySettings(BaseModel):
    """Settings for MCP transport security features.

    These settings help protect against DNS rebinding attacks by validating
    incoming request headers.
    """

    # Master switch for the Host/Origin checks; defaults to enabled.
    enable_dns_rebinding_protection: bool = Field(
        description="Enable DNS rebinding protection (recommended for production)",
        default=True,
    )

    # Accepted Host header values; empty by default.
    allowed_hosts: list[str] = Field(
        description=(
            "List of allowed Host header values. Only applies when "
            "enable_dns_rebinding_protection is True."
        ),
        default=[],
    )

    # Accepted Origin header values; empty by default.
    allowed_origins: list[str] = Field(
        description=(
            "List of allowed Origin header values. Only applies when "
            "enable_dns_rebinding_protection is True."
        ),
        default=[],
    )
|
|
|
|
|
|
class TransportSecurityMiddleware:
    """Middleware to enforce DNS rebinding protection for MCP transport endpoints.

    Checks the Content-Type header of POST requests and, when
    ``settings.enable_dns_rebinding_protection`` is set, checks the Host and
    Origin headers against the configured allow-lists.
    """

    def __init__(self, settings: TransportSecuritySettings | None = None):
        """Store the settings to enforce.

        Args:
            settings: Security settings. When omitted, settings with DNS
                rebinding protection disabled are used, for backwards
                compatibility.
        """
        if settings is None:
            settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
        self.settings = settings

    @staticmethod
    def _matches_allowed(value: str, allowed_values: list[str]) -> bool:
        """Return True if ``value`` matches an entry of ``allowed_values``.

        An entry matches when it equals ``value`` exactly, or when it ends
        with ``":*"`` and ``value`` is the entry's base followed by ``":"``
        and a port made only of ASCII digits.
        """
        # Check exact match first
        if value in allowed_values:
            return True

        # Check wildcard port patterns
        for allowed in allowed_values:
            if not allowed.endswith(":*"):
                continue
            # Drop only the "*" so the prefix keeps its trailing ":"
            prefix = allowed[:-1]
            if not value.startswith(prefix):
                continue
            port = value[len(prefix):]
            # Require a real numeric port so that a suffix such as
            # "80.attacker.example" cannot ride on the wildcard.
            if port.isascii() and port.isdigit():
                return True

        return False

    def _validate_host(self, host: str | None) -> bool:
        """Validate the Host header against allowed values.

        Returns:
            True if ``host`` matches ``settings.allowed_hosts``. False, with
            a logged warning, if the header is missing, empty or not allowed.
        """
        if not host:
            logger.warning("Missing Host header in request")
            return False

        if self._matches_allowed(host, self.settings.allowed_hosts):
            return True

        # Lazy %-style arguments: the header value is untrusted input
        logger.warning("Invalid Host header: %s", host)
        return False

    def _validate_origin(self, origin: str | None) -> bool:
        """Validate the Origin header against allowed values.

        Returns:
            True if ``origin`` is absent or matches
            ``settings.allowed_origins``. False, with a logged warning,
            otherwise.
        """
        # Origin can be absent for same-origin requests
        if not origin:
            return True

        if self._matches_allowed(origin, self.settings.allowed_origins):
            return True

        logger.warning("Invalid Origin header: %s", origin)
        return False

    def _validate_content_type(self, content_type: str | None) -> bool:
        """Validate the Content-Type header for POST requests.

        Returns:
            True if the header is present and, compared case-insensitively,
            starts with ``application/json``. False, with a logged warning,
            otherwise.
        """
        if not content_type:
            logger.warning("Missing Content-Type header in POST request")
            return False

        # Content-Type must start with application/json
        if not content_type.lower().startswith("application/json"):
            logger.warning("Invalid Content-Type header: %s", content_type)
            return False

        return True

    async def validate_request(self, request: Request, is_post: bool = False) -> Response | None:
        """Validate request headers for DNS rebinding protection.

        Args:
            request: The incoming request whose headers are inspected.
            is_post: When True, the Content-Type header is validated as well.

        Returns:
            None if validation passes, or an error Response if validation
            fails: 400 for an invalid Content-Type or Origin header, 421 for
            an invalid Host header.
        """
        # Always validate Content-Type for POST requests
        if is_post:
            content_type = request.headers.get("content-type")
            if not self._validate_content_type(content_type):
                return Response("Invalid Content-Type header", status_code=400)

        # Skip remaining validation if DNS rebinding protection is disabled
        if not self.settings.enable_dns_rebinding_protection:
            return None

        # Validate Host header
        host = request.headers.get("host")
        if not self._validate_host(host):
            return Response("Invalid Host header", status_code=421)

        # Validate Origin header
        origin = request.headers.get("origin")
        if not self._validate_origin(origin):
            return Response("Invalid Origin header", status_code=400)

        return None
|