youtube-summarizer/venv311/lib/python3.11/site-packages/mcp/server/transport_security.py

128 lines
4.5 KiB
Python

"""DNS rebinding protection for MCP server transports."""
import logging
from pydantic import BaseModel, Field
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger(__name__)
class TransportSecuritySettings(BaseModel):
    """Configuration for the MCP transport security checks.

    The fields below control how incoming request headers are validated,
    which is used to guard against DNS rebinding attacks.
    """

    # Master switch for the Host/Origin header checks.
    enable_dns_rebinding_protection: bool = Field(
        description="Enable DNS rebinding protection (recommended for production)",
        default=True,
    )

    # Accepted values for the Host header.
    allowed_hosts: list[str] = Field(
        description=(
            "List of allowed Host header values. Only applies when "
            "enable_dns_rebinding_protection is True."
        ),
        default=[],
    )

    # Accepted values for the Origin header.
    allowed_origins: list[str] = Field(
        description=(
            "List of allowed Origin header values. Only applies when "
            "enable_dns_rebinding_protection is True."
        ),
        default=[],
    )
class TransportSecurityMiddleware:
    """Middleware to enforce DNS rebinding protection for MCP transport endpoints.

    Incoming requests are checked against the configured ``allowed_hosts`` and
    ``allowed_origins`` lists. An entry ending in ``:*`` (for example
    ``localhost:*`` or ``http://localhost:*``) matches the same value followed
    by any numeric port. POST requests additionally have their Content-Type
    checked, whether or not DNS rebinding protection is enabled.
    """

    def __init__(self, settings: TransportSecuritySettings | None = None):
        """Store the security settings used to validate requests.

        Args:
            settings: Security configuration. If not specified, DNS rebinding
                protection is disabled for backwards compatibility.
        """
        if settings is None:
            settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
        self.settings = settings

    @staticmethod
    def _matches_allowed(value: str, allowed_values: list[str]) -> bool:
        """Return True if ``value`` matches an entry of ``allowed_values``.

        An entry matches when it is equal to ``value``, or when it ends in
        ``:*`` and ``value`` is the entry's base followed by ``:`` and a
        non-empty, all-digit port.
        """
        # Check exact match first
        if value in allowed_values:
            return True

        # Check wildcard port patterns
        for allowed in allowed_values:
            if not allowed.endswith(":*"):
                continue
            # Drop only the "*" so the prefix still ends with the ":" separator.
            prefix = allowed[:-1]
            if not value.startswith(prefix):
                continue
            # Whatever follows the colon must really be a port. Without this
            # check "localhost:*" would also accept "localhost:8080.evil.test".
            # isascii() guards against non-ASCII characters that isdigit()
            # would otherwise treat as digits.
            port = value[len(prefix):]
            if port.isascii() and port.isdigit():
                return True

        return False

    def _validate_host(self, host: str | None) -> bool:
        """Validate the Host header against allowed values.

        Args:
            host: Value of the Host header, or None if it is absent.

        Returns:
            True if the host is allowed; False if it is missing or not allowed.
        """
        if not host:
            logger.warning("Missing Host header in request")
            return False

        if self._matches_allowed(host, self.settings.allowed_hosts):
            return True

        logger.warning("Invalid Host header: %s", host)
        return False

    def _validate_origin(self, origin: str | None) -> bool:
        """Validate the Origin header against allowed values.

        Args:
            origin: Value of the Origin header, or None if it is absent.

        Returns:
            True if the origin is absent or allowed; False otherwise.
        """
        # Origin can be absent for same-origin requests
        if not origin:
            return True

        if self._matches_allowed(origin, self.settings.allowed_origins):
            return True

        logger.warning("Invalid Origin header: %s", origin)
        return False

    def _validate_content_type(self, content_type: str | None) -> bool:
        """Validate the Content-Type header for POST requests.

        Args:
            content_type: Value of the Content-Type header, or None if absent.

        Returns:
            True if the value starts with ``application/json`` (compared
            case-insensitively); False if it is missing or anything else.
        """
        if not content_type:
            logger.warning("Missing Content-Type header in POST request")
            return False

        # Content-Type must start with application/json
        if not content_type.lower().startswith("application/json"):
            logger.warning("Invalid Content-Type header: %s", content_type)
            return False

        return True

    async def validate_request(self, request: Request, is_post: bool = False) -> Response | None:
        """Validate request headers for DNS rebinding protection.

        Args:
            request: The incoming request whose headers are inspected.
            is_post: Whether the request is a POST; if so, its Content-Type
                is validated as well.

        Returns:
            None if validation passes, or an error Response if validation
            fails: 400 for an invalid Content-Type or Origin, 421 for an
            invalid Host.
        """
        # Always validate Content-Type for POST requests
        if is_post:
            content_type = request.headers.get("content-type")
            if not self._validate_content_type(content_type):
                return Response("Invalid Content-Type header", status_code=400)

        # Skip remaining validation if DNS rebinding protection is disabled
        if not self.settings.enable_dns_rebinding_protection:
            return None

        # Validate Host header
        host = request.headers.get("host")
        if not self._validate_host(host):
            return Response("Invalid Host header", status_code=421)

        # Validate Origin header
        origin = request.headers.get("origin")
        if not self._validate_origin(origin):
            return Response("Invalid Origin header", status_code=400)

        return None