538 lines
18 KiB
Python
538 lines
18 KiB
Python
"""TokenVerifier implementations for FastMCP."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any, cast
|
|
|
|
import httpx
|
|
from authlib.jose import JsonWebKey, JsonWebToken
|
|
from authlib.jose.errors import JoseError
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
from pydantic import AnyHttpUrl, SecretStr
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
from typing_extensions import TypedDict
|
|
|
|
from fastmcp.server.auth import AccessToken, TokenVerifier
|
|
from fastmcp.server.auth.registry import register_provider
|
|
from fastmcp.utilities.logging import get_logger
|
|
from fastmcp.utilities.types import NotSet, NotSetT
|
|
|
|
# Module-level logger. StaticTokenVerifier logs through this; JWTVerifier
# instances keep their own `self.logger` created in __init__.
logger = get_logger(__name__)
|
|
|
|
|
|
class JWKData(TypedDict, total=False):
    """
    Typed view of a single JSON Web Key as it appears inside a JWKS document.

    Declared with ``total=False``, so every member is optional as far as the
    type checker is concerned; a well-formed key still always carries ``kty``.
    """

    # Key type, e.g. "RSA" -- required in practice
    kty: str
    # Key identifier; optional, but recommended so tokens can select a key
    kid: str
    # Intended usage, e.g. "sig"
    use: str
    # Algorithm the key is intended for, e.g. "RS256"
    alg: str
    # RSA modulus (RSA keys only)
    n: str
    # RSA exponent (RSA keys only)
    e: str
    # X.509 certificate chain
    x5c: list[str]
    # X.509 certificate thumbprint
    x5t: str
|
|
|
|
|
|
class JWKSData(TypedDict):
    """
    Typed view of a JSON Web Key Set document: a single ``keys`` member that
    holds the list of individual keys.
    """

    # The published keys; each entry follows the JWKData shape
    keys: list[JWKData]
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True, repr=False)
class RSAKeyPair:
    """
    RSA key pair for JWT testing.

    Attributes:
        private_key: PEM-encoded (PKCS#8, unencrypted) private key, wrapped in
            ``SecretStr``.
        public_key: PEM-encoded (SubjectPublicKeyInfo) public key, the format
            ``JWTVerifier(public_key=...)`` expects.
    """

    private_key: SecretStr
    public_key: str

    @classmethod
    def generate(cls, key_size: int = 2048) -> RSAKeyPair:
        """
        Generate an RSA key pair for testing.

        Args:
            key_size: RSA modulus size in bits (default 2048). Larger keys take
                longer to generate.

        Returns:
            RSAKeyPair: Generated key pair with both halves serialized to PEM.
        """
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
        )

        # Private half: PKCS#8 PEM without a passphrase.
        private_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")

        # Public half: SubjectPublicKeyInfo PEM.
        public_pem = (
            private_key.public_key()
            .public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            )
            .decode("utf-8")
        )

        return cls(
            private_key=SecretStr(private_pem),
            public_key=public_pem,
        )

    def create_token(
        self,
        subject: str = "fastmcp-user",
        issuer: str = "https://fastmcp.example.com",
        audience: str | list[str] | None = None,
        scopes: list[str] | None = None,
        expires_in_seconds: int = 3600,
        additional_claims: dict[str, Any] | None = None,
        kid: str | None = None,
    ) -> str:
        """
        Generate an RS256-signed JWT for testing purposes.

        Args:
            subject: Subject claim (usually user ID)
            issuer: Issuer claim
            audience: Audience claim - a string or list of strings; omitted
                from the payload when falsy
            scopes: Scopes to include, joined with spaces into the ``scope``
                claim; omitted when falsy
            expires_in_seconds: Lifetime of the token; ``exp`` is exactly
                ``iat + expires_in_seconds``
            additional_claims: Extra claims merged in last, so they can
                override any of the standard claims above
            kid: Key ID to include in the JOSE header (omitted when falsy)

        Returns:
            The compact-serialized token as a ``str``.
        """
        header: dict[str, str] = {"alg": "RS256"}
        if kid:
            header["kid"] = kid

        # Read the clock once so iat and exp cannot straddle a second boundary.
        issued_at = int(time.time())
        payload: dict[str, Any] = {
            "sub": subject,
            "iss": issuer,
            "iat": issued_at,
            "exp": issued_at + expires_in_seconds,
        }

        if audience:
            payload["aud"] = audience

        if scopes:
            payload["scope"] = " ".join(scopes)

        if additional_claims:
            # Applied last on purpose: callers may override sub/iss/iat/exp/etc.
            payload.update(additional_claims)

        jwt_lib = JsonWebToken(["RS256"])
        token_bytes = jwt_lib.encode(
            header, payload, self.private_key.get_secret_value()
        )

        return token_bytes.decode("utf-8")
|
|
|
|
|
|
class JWTVerifierSettings(BaseSettings):
    """
    Settings for JWT token verification.

    Values may come from explicit keyword arguments, from environment variables
    prefixed with ``FASTMCP_SERVER_AUTH_JWT_``, or from a ``.env`` file.
    Unrecognized keys are ignored rather than rejected.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        env_file=".env",
        env_prefix="FASTMCP_SERVER_AUTH_JWT_",
    )

    # PEM public key used for verification (mutually exclusive with jwks_uri)
    public_key: str | None = None
    # URL of a JSON Web Key Set to fetch verification keys from
    jwks_uri: str | None = None
    # Expected "iss" claim; issuer check is skipped when unset
    issuer: str | None = None
    # Signing algorithm name; the verifier falls back to RS256 when unset
    algorithm: str | None = None
    # Expected "aud" value(s); audience check is skipped when unset
    audience: str | list[str] | None = None
    # Scopes every accepted token must carry
    required_scopes: list[str] | None = None
    # Forwarded to the TokenVerifier base class
    resource_server_url: AnyHttpUrl | str | None = None
|
|
|
|
|
|
@register_provider("JWT")
class JWTVerifier(TokenVerifier):
    """
    JWT token verifier using public key or JWKS.

    This verifier validates JWT tokens signed by an external issuer. It's ideal for
    scenarios where you have a centralized identity provider (like Auth0, Okta, or
    your own OAuth server) that issues JWTs, and your FastMCP server acts as a
    resource server validating those tokens.

    Use this when:
    - You have JWT tokens issued by an external service
    - You want asymmetric key verification (public/private key pairs)
    - You need JWKS support for automatic key rotation
    - Your tokens contain standard OAuth scopes and claims
    """

    # Signing algorithms accepted by __init__; anything else is rejected up front.
    _SUPPORTED_ALGORITHMS = frozenset(
        {
            "HS256",
            "HS384",
            "HS512",
            "RS256",
            "RS384",
            "RS512",
            "ES256",
            "ES384",
            "ES512",
            "PS256",
            "PS384",
            "PS512",
        }
    )

    def __init__(
        self,
        *,
        public_key: str | None | NotSetT = NotSet,
        jwks_uri: str | None | NotSetT = NotSet,
        issuer: str | None | NotSetT = NotSet,
        audience: str | list[str] | None | NotSetT = NotSet,
        algorithm: str | None | NotSetT = NotSet,
        required_scopes: list[str] | None | NotSetT = NotSet,
        resource_server_url: AnyHttpUrl | str | None | NotSetT = NotSet,
    ):
        """
        Initialize the JWT token verifier.

        Arguments left as ``NotSet`` are filled from JWTVerifierSettings
        (environment / ``.env``); explicit values, including ``None``, win.

        Args:
            public_key: PEM-encoded public key for verification
            jwks_uri: URI to fetch JSON Web Key Set
            issuer: Expected issuer claim
            audience: Expected audience claim(s)
            algorithm: JWT signing algorithm (default: RS256)
            required_scopes: Required scopes for all tokens
            resource_server_url: Resource server URL for TokenVerifier protocol

        Raises:
            ValueError: If neither or both of public_key / jwks_uri are set, or
                the algorithm is not supported.
        """
        settings = JWTVerifierSettings.model_validate(
            {
                k: v
                for k, v in {
                    "public_key": public_key,
                    "jwks_uri": jwks_uri,
                    "issuer": issuer,
                    "audience": audience,
                    "algorithm": algorithm,
                    "required_scopes": required_scopes,
                    "resource_server_url": resource_server_url,
                }.items()
                if v is not NotSet
            }
        )

        if not settings.public_key and not settings.jwks_uri:
            raise ValueError("Either public_key or jwks_uri must be provided")

        if settings.public_key and settings.jwks_uri:
            raise ValueError("Provide either public_key or jwks_uri, not both")

        algorithm = settings.algorithm or "RS256"
        if algorithm not in self._SUPPORTED_ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}.")

        # Initialize parent TokenVerifier
        super().__init__(
            resource_server_url=settings.resource_server_url,
            required_scopes=settings.required_scopes,
        )

        self.algorithm = algorithm
        self.issuer = settings.issuer
        self.audience = settings.audience
        self.public_key = settings.public_key
        self.jwks_uri = settings.jwks_uri
        self.jwt = JsonWebToken([self.algorithm])
        self.logger = get_logger(__name__)

        # Simple JWKS cache: key ID ("_default" for keys published without one)
        # -> public key object returned by authlib's get_public_key().
        self._jwks_cache: dict[str, Any] = {}
        self._jwks_cache_time: float = 0
        self._cache_ttl = 3600  # seconds (1 hour)

    async def _get_verification_key(self, token: str) -> Any:
        """
        Get the verification key for the token.

        Returns the static ``public_key`` when configured; otherwise looks the
        key up in the JWKS using the ``kid`` from the token header.

        Raises:
            ValueError: If the header cannot be parsed or no JWKS key matches.
        """
        if self.public_key:
            return self.public_key

        kid = self._extract_kid(token)
        # JWKS errors carry their own messages; don't re-label them as
        # header-parsing failures.
        return await self._get_jwks_key(kid)

    @staticmethod
    def _extract_kid(token: str) -> Any:
        """Return the ``kid`` member of the token's (unverified) JOSE header, or None."""
        import base64
        import json

        try:
            header_b64 = token.split(".")[0]
            # JWT segments are unpadded base64url; restore 0-3 '=' characters.
            header_b64 += "=" * (-len(header_b64) % 4)
            header = json.loads(base64.urlsafe_b64decode(header_b64))
            return header.get("kid")
        except Exception as e:
            raise ValueError(f"Failed to extract key ID from token: {e}") from e

    async def _get_jwks_key(self, kid: str | None) -> Any:
        """
        Fetch key from JWKS with simple caching.

        A fresh cache is consulted first; on a miss (or an expired cache) the
        JWKS is re-fetched, which also picks up rotated keys.

        Raises:
            ValueError: If JWKS is not configured, cannot be fetched/parsed, or
                contains no key matching ``kid``.
        """
        if not self.jwks_uri:
            raise ValueError("JWKS URI not configured")

        current_time = time.time()

        # Check cache first
        if current_time - self._jwks_cache_time < self._cache_ttl:
            if kid and kid in self._jwks_cache:
                return self._jwks_cache[kid]
            elif not kid and len(self._jwks_cache) == 1:
                # If no kid but only one key cached, use it
                return next(iter(self._jwks_cache.values()))

        await self._refresh_jwks(self.jwks_uri, current_time)
        return self._select_jwks_key(kid)

    async def _refresh_jwks(self, jwks_uri: str, fetched_at: float) -> None:
        """Download the JWKS and replace the cache only if every key imports cleanly."""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(jwks_uri)
                response.raise_for_status()
                jwks_data = response.json()

            keys: dict[str, Any] = {}
            for key_data in jwks_data.get("keys", []):
                jwk = JsonWebKey.import_key(key_data)
                public_key = jwk.get_public_key()  # type: ignore
                # Keys without a kid share one default slot (last one wins).
                keys[key_data.get("kid") or "_default"] = public_key
        except httpx.HTTPError as e:
            raise ValueError(f"Failed to fetch JWKS: {e}") from e
        except Exception as e:
            self.logger.debug("JWKS fetch failed: %s", e)
            raise ValueError(f"Failed to fetch JWKS: {e}") from e

        # Swap atomically so a failed fetch never leaves a half-built cache.
        self._jwks_cache = keys
        self._jwks_cache_time = fetched_at

    def _select_jwks_key(self, kid: str | None) -> Any:
        """Pick the cached key for ``kid``; without a kid, only a single-key JWKS is usable."""
        if kid:
            if kid not in self._jwks_cache:
                self.logger.debug(
                    "JWKS key lookup failed: key ID '%s' not found", kid
                )
                raise ValueError(f"Key ID '{kid}' not found in JWKS")
            return self._jwks_cache[kid]

        # No kid in token - only allow if there's exactly one key
        if len(self._jwks_cache) == 1:
            return next(iter(self._jwks_cache.values()))
        if len(self._jwks_cache) > 1:
            raise ValueError("Multiple keys in JWKS but no key ID (kid) in token")
        raise ValueError("No keys found in JWKS")

    def _extract_scopes(self, claims: dict[str, Any]) -> list[str]:
        """
        Extract scopes from JWT claims. Supports both 'scope' and 'scp'
        claims.

        Checks the `scope` claim first (standard OAuth2 claim), then the `scp`
        claim (used by some Identity Providers). A string value is split on
        whitespace; a list is returned as-is; any other type is skipped.
        """
        for claim in ["scope", "scp"]:
            if claim in claims:
                if isinstance(claims[claim], str):
                    return claims[claim].split()
                elif isinstance(claims[claim], list):
                    return claims[claim]

        return []

    def _audience_matches(self, aud: Any) -> bool:
        """
        Return True if the token's ``aud`` claim shares at least one value with
        the configured audience.

        Either side may be a single value or a list; both are normalized to
        lists before checking for overlap.
        """
        expected = self.audience if isinstance(self.audience, list) else [self.audience]
        presented = aud if isinstance(aud, list) else [aud]
        return any(candidate in expected for candidate in presented)

    async def load_access_token(self, token: str) -> AccessToken | None:
        """
        Validates the provided JWT bearer token.

        Checks, in order: signature, ``exp``, ``nbf``, issuer (if configured),
        audience (if configured), and required scopes.

        Args:
            token: The JWT token string to validate

        Returns:
            AccessToken object if valid, None if invalid, expired, or not yet valid
        """
        try:
            # Get verification key (static or from JWKS)
            verification_key = await self._get_verification_key(token)

            # Decode and verify the JWT signature. Claims are not validated by
            # decode() here, so the time-based checks below are done by hand.
            claims = self.jwt.decode(token, verification_key)

            # Extract client ID early for logging
            client_id = claims.get("client_id") or claims.get("sub") or "unknown"

            # One clock reading so exp and nbf agree on "now".
            now = time.time()

            # Validate expiration
            exp = claims.get("exp")
            if exp and exp < now:
                self.logger.debug(
                    "Token validation failed: expired token for client %s", client_id
                )
                self.logger.info("Bearer token rejected for client %s", client_id)
                return None

            # RFC 7519 4.1.5: a token must not be accepted before its nbf time.
            nbf = claims.get("nbf")
            if nbf and nbf > now:
                self.logger.debug(
                    "Token validation failed: token not yet valid for client %s",
                    client_id,
                )
                self.logger.info("Bearer token rejected for client %s", client_id)
                return None

            # Validate issuer - note we use issuer instead of issuer_url here because
            # issuer is optional, allowing users to make this check optional
            if self.issuer and claims.get("iss") != self.issuer:
                self.logger.debug(
                    "Token validation failed: issuer mismatch for client %s",
                    client_id,
                )
                self.logger.info("Bearer token rejected for client %s", client_id)
                return None

            # Validate audience if configured
            if self.audience and not self._audience_matches(claims.get("aud")):
                self.logger.debug(
                    "Token validation failed: audience mismatch for client %s",
                    client_id,
                )
                self.logger.info("Bearer token rejected for client %s", client_id)
                return None

            scopes = self._extract_scopes(claims)

            # Check required scopes
            if self.required_scopes:
                token_scopes = set(scopes)
                required_scopes = set(self.required_scopes)
                if not required_scopes.issubset(token_scopes):
                    self.logger.debug(
                        "Token missing required scopes. Has: %s, Required: %s",
                        token_scopes,
                        required_scopes,
                    )
                    self.logger.info("Bearer token rejected for client %s", client_id)
                    return None

            return AccessToken(
                token=token,
                client_id=str(client_id),
                scopes=scopes,
                expires_at=int(exp) if exp else None,
                claims=claims,
            )

        except JoseError:
            self.logger.debug("Token validation failed: JWT signature/format invalid")
            return None
        except Exception as e:
            # Malformed claims (e.g. non-numeric exp), key lookup failures, etc.
            self.logger.debug("Token validation failed: %s", str(e))
            return None

    async def verify_token(self, token: str) -> AccessToken | None:
        """
        Verify a bearer token and return access info if valid.

        This method implements the TokenVerifier protocol by delegating
        to our existing load_access_token method.

        Args:
            token: The JWT token string to validate

        Returns:
            AccessToken object if valid, None if invalid or expired
        """
        return await self.load_access_token(token)
|
|
|
|
|
|
class StaticTokenVerifier(TokenVerifier):
    """
    Simple static token verifier for testing and development.

    This verifier validates tokens against a predefined dictionary of valid token
    strings and their associated claims. When a token string matches a key in the
    dictionary, the verifier returns the corresponding claims as if the token was
    validated by a real authorization server.

    Use this when:
    - You're developing or testing locally without a real OAuth server
    - You need predictable tokens for automated testing
    - You want to simulate different users/scopes without complex setup
    - You're prototyping and need simple API key-style authentication

    WARNING: Never use this in production - tokens are stored in plain text!
    """

    def __init__(
        self,
        tokens: dict[str, dict[str, Any]],
        required_scopes: list[str] | None = None,
    ):
        """
        Initialize the static token verifier.

        Args:
            tokens: Dict mapping token strings to token metadata.
                Each token should have: client_id, scopes, expires_at (optional).
                The metadata dict is also returned as the token's claims.
            required_scopes: Required scopes for all tokens
        """
        super().__init__(required_scopes=required_scopes)
        self.tokens = tokens

    async def verify_token(self, token: str) -> AccessToken | None:
        """
        Verify token against static token dictionary.

        Returns None for unknown tokens, tokens whose ``expires_at`` is in the
        past, and tokens lacking any of ``required_scopes``.

        Raises:
            KeyError: If a known token's metadata has no ``client_id``.
        """
        token_data = self.tokens.get(token)
        if not token_data:
            return None

        # Check expiration if present
        expires_at = token_data.get("expires_at")
        if expires_at is not None and expires_at < time.time():
            return None

        scopes = token_data.get("scopes", [])

        # Check required scopes
        if self.required_scopes:
            token_scopes = set(scopes)
            required_scopes = set(self.required_scopes)
            if not required_scopes.issubset(token_scopes):
                # Lazy %-formatting: the message is only built if DEBUG is enabled.
                logger.debug(
                    "Token missing required scopes. Has: %s, Required: %s",
                    token_scopes,
                    required_scopes,
                )
                return None

        return AccessToken(
            token=token,
            client_id=token_data["client_id"],
            scopes=scopes,
            expires_at=expires_at,
            claims=token_data,
        )
|