# youtube-summarizer/venv311/lib/python3.11/site-packages/fastmcp/server/middleware/rate_limiting.py
#
# 232 lines
# 7.7 KiB
# Python

"""Rate limiting middleware for protecting FastMCP servers from abuse."""
import asyncio
import time
from collections import defaultdict, deque
from collections.abc import Callable
from typing import Any
from mcp import McpError
from mcp.types import ErrorData
from .middleware import CallNext, Middleware, MiddlewareContext
class RateLimitError(McpError):
    """Signals that a request was rejected because a rate limit was hit.

    The message is wrapped in an ``ErrorData`` payload with code ``-32000``
    and handed to the ``McpError`` constructor.
    """

    def __init__(self, message: str = "Rate limit exceeded"):
        error_payload = ErrorData(code=-32000, message=message)
        super().__init__(error_payload)
class TokenBucketRateLimiter:
"""Token bucket implementation for rate limiting."""
def __init__(self, capacity: int, refill_rate: float):
"""Initialize token bucket.
Args:
capacity: Maximum number of tokens in the bucket
refill_rate: Tokens added per second
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.tokens = capacity
self.last_refill = time.time()
self._lock = asyncio.Lock()
async def consume(self, tokens: int = 1) -> bool:
"""Try to consume tokens from the bucket.
Args:
tokens: Number of tokens to consume
Returns:
True if tokens were available and consumed, False otherwise
"""
async with self._lock:
now = time.time()
elapsed = now - self.last_refill
# Add tokens based on elapsed time
self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
self.last_refill = now
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
class SlidingWindowRateLimiter:
"""Sliding window rate limiter implementation."""
def __init__(self, max_requests: int, window_seconds: int):
"""Initialize sliding window rate limiter.
Args:
max_requests: Maximum requests allowed in the time window
window_seconds: Time window in seconds
"""
self.max_requests = max_requests
self.window_seconds = window_seconds
self.requests = deque()
self._lock = asyncio.Lock()
async def is_allowed(self) -> bool:
"""Check if a request is allowed."""
async with self._lock:
now = time.time()
cutoff = now - self.window_seconds
# Remove old requests outside the window
while self.requests and self.requests[0] < cutoff:
self.requests.popleft()
if len(self.requests) < self.max_requests:
self.requests.append(now)
return True
return False
class RateLimitingMiddleware(Middleware):
    """Middleware that implements rate limiting to prevent server abuse.

    Uses a token bucket algorithm by default, allowing for burst traffic
    while maintaining a sustainable long-term rate.

    Example:
        ```python
        from fastmcp.server.middleware.rate_limiting import RateLimitingMiddleware

        # Allow 10 requests per second with bursts up to 20
        rate_limiter = RateLimitingMiddleware(
            max_requests_per_second=10,
            burst_capacity=20
        )

        mcp = FastMCP("MyServer")
        mcp.add_middleware(rate_limiter)
        ```
    """

    def __init__(
        self,
        max_requests_per_second: float = 10.0,
        burst_capacity: int | None = None,
        get_client_id: Callable[[MiddlewareContext], str] | None = None,
        global_limit: bool = False,
    ):
        """Initialize rate limiting middleware.

        Args:
            max_requests_per_second: Sustained requests per second allowed
            burst_capacity: Maximum burst capacity. If None (or 0), defaults to
                2x max_requests_per_second, but never less than 1
            get_client_id: Function to extract client ID from context. If None,
                all requests share a single bucket keyed "global"
            global_limit: If True, apply limit globally; if False, per-client
        """
        self.max_requests_per_second = max_requests_per_second
        if not burst_capacity:
            # int() truncates, so rates below 0.5 req/s would otherwise yield a
            # zero-capacity bucket that rejects every request forever.
            burst_capacity = max(1, int(max_requests_per_second * 2))
        self.burst_capacity = burst_capacity
        self.get_client_id = get_client_id
        self.global_limit = global_limit

        # Storage for rate limiters per client; a bucket is created lazily the
        # first time a client ID is looked up.
        self.limiters: dict[str, TokenBucketRateLimiter] = defaultdict(
            lambda: TokenBucketRateLimiter(
                self.burst_capacity, self.max_requests_per_second
            )
        )

        # Global rate limiter (only created when global limiting is enabled)
        if self.global_limit:
            self.global_limiter = TokenBucketRateLimiter(
                self.burst_capacity, self.max_requests_per_second
            )

    def _get_client_identifier(self, context: MiddlewareContext) -> str:
        """Get client identifier for rate limiting.

        Returns the result of ``get_client_id(context)`` when a callback was
        supplied, otherwise the shared key "global".
        """
        if self.get_client_id:
            return self.get_client_id(context)
        return "global"

    async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any:
        """Apply rate limiting to requests.

        Raises:
            RateLimitError: If the applicable bucket has no token available.
        """
        if self.global_limit:
            # Global rate limiting
            allowed = await self.global_limiter.consume()
            if not allowed:
                raise RateLimitError("Global rate limit exceeded")
        else:
            # Per-client rate limiting
            client_id = self._get_client_identifier(context)
            limiter = self.limiters[client_id]
            allowed = await limiter.consume()
            if not allowed:
                raise RateLimitError(f"Rate limit exceeded for client: {client_id}")

        return await call_next(context)
class SlidingWindowRateLimitingMiddleware(Middleware):
    """Rate limiting middleware backed by per-client sliding windows.

    Every admitted request's timestamp is tracked, which gives a more precise
    limit than a token bucket at the cost of extra memory per client.

    Example:
        ```python
        from fastmcp.server.middleware.rate_limiting import SlidingWindowRateLimitingMiddleware

        # Allow 100 requests per minute
        rate_limiter = SlidingWindowRateLimitingMiddleware(
            max_requests=100,
            window_minutes=1
        )

        mcp = FastMCP("MyServer")
        mcp.add_middleware(rate_limiter)
        ```
    """

    def __init__(
        self,
        max_requests: int,
        window_minutes: int = 1,
        get_client_id: Callable[[MiddlewareContext], str] | None = None,
    ):
        """Set up the middleware.

        Args:
            max_requests: Number of requests permitted per window
            window_minutes: Length of the window, in minutes
            get_client_id: Optional callback mapping a context to a client ID
        """
        self.get_client_id = get_client_id
        self.max_requests = max_requests
        self.window_seconds = window_minutes * 60

        def _new_limiter() -> SlidingWindowRateLimiter:
            # Reads the attributes at call time, exactly like a lambda would.
            return SlidingWindowRateLimiter(self.max_requests, self.window_seconds)

        # One limiter per client ID, created on first lookup.
        self.limiters: dict[str, SlidingWindowRateLimiter] = defaultdict(_new_limiter)

    def _get_client_identifier(self, context: MiddlewareContext) -> str:
        """Resolve the key used to pick a limiter for this request."""
        return self.get_client_id(context) if self.get_client_id else "global"

    async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any:
        """Admit the request if its client's window has room, else raise.

        Raises:
            RateLimitError: If the client's window is already full.
        """
        client_id = self._get_client_identifier(context)
        window = self.limiters[client_id]

        if await window.is_allowed():
            return await call_next(context)

        raise RateLimitError(
            f"Rate limit exceeded: {self.max_requests} requests per "
            f"{self.window_seconds // 60} minutes for client: {client_id}"
        )