239 lines
9.8 KiB
Python
239 lines
9.8 KiB
Python
import base64
|
|
import hashlib
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Annotated, Any, Literal
|
|
|
|
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
|
|
from starlette.requests import Request
|
|
|
|
from mcp.server.auth.errors import stringify_pydantic_error
|
|
from mcp.server.auth.json_response import PydanticJSONResponse
|
|
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
|
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
|
|
from mcp.shared.auth import OAuthToken
|
|
|
|
|
|
class AuthorizationCodeRequest(BaseModel):
|
|
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
|
grant_type: Literal["authorization_code"]
|
|
code: str = Field(..., description="The authorization code")
|
|
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
|
|
client_id: str
|
|
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
|
client_secret: str | None = None
|
|
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
|
|
code_verifier: str = Field(..., description="PKCE code verifier")
|
|
# RFC 8707 resource indicator
|
|
resource: str | None = Field(None, description="Resource indicator for the token")
|
|
|
|
|
|
class RefreshTokenRequest(BaseModel):
|
|
# See https://datatracker.ietf.org/doc/html/rfc6749#section-6
|
|
grant_type: Literal["refresh_token"]
|
|
refresh_token: str = Field(..., description="The refresh token")
|
|
scope: str | None = Field(None, description="Optional scope parameter")
|
|
client_id: str
|
|
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
|
client_secret: str | None = None
|
|
# RFC 8707 resource indicator
|
|
resource: str | None = Field(None, description="Resource indicator for the token")
|
|
|
|
|
|
class TokenRequest(
|
|
RootModel[
|
|
Annotated[
|
|
AuthorizationCodeRequest | RefreshTokenRequest,
|
|
Field(discriminator="grant_type"),
|
|
]
|
|
]
|
|
):
|
|
root: Annotated[
|
|
AuthorizationCodeRequest | RefreshTokenRequest,
|
|
Field(discriminator="grant_type"),
|
|
]
|
|
|
|
|
|
class TokenErrorResponse(BaseModel):
|
|
"""
|
|
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
|
"""
|
|
|
|
error: TokenErrorCode
|
|
error_description: str | None = None
|
|
error_uri: AnyHttpUrl | None = None
|
|
|
|
|
|
class TokenSuccessResponse(RootModel[OAuthToken]):
|
|
# this is just a wrapper over OAuthToken; the only reason we do this
|
|
# is to have some separation between the HTTP response type, and the
|
|
# type returned by the provider
|
|
root: OAuthToken
|
|
|
|
|
|
@dataclass
|
|
class TokenHandler:
|
|
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
|
client_authenticator: ClientAuthenticator
|
|
|
|
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
|
|
status_code = 200
|
|
if isinstance(obj, TokenErrorResponse):
|
|
status_code = 400
|
|
|
|
return PydanticJSONResponse(
|
|
content=obj,
|
|
status_code=status_code,
|
|
headers={
|
|
"Cache-Control": "no-store",
|
|
"Pragma": "no-cache",
|
|
},
|
|
)
|
|
|
|
async def handle(self, request: Request):
|
|
try:
|
|
form_data = await request.form()
|
|
token_request = TokenRequest.model_validate(dict(form_data)).root
|
|
except ValidationError as validation_error:
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_request",
|
|
error_description=stringify_pydantic_error(validation_error),
|
|
)
|
|
)
|
|
|
|
try:
|
|
client_info = await self.client_authenticator.authenticate(
|
|
client_id=token_request.client_id,
|
|
client_secret=token_request.client_secret,
|
|
)
|
|
except AuthenticationError as e:
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="unauthorized_client",
|
|
error_description=e.message,
|
|
)
|
|
)
|
|
|
|
if token_request.grant_type not in client_info.grant_types:
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="unsupported_grant_type",
|
|
error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
|
|
)
|
|
)
|
|
|
|
tokens: OAuthToken
|
|
|
|
match token_request:
|
|
case AuthorizationCodeRequest():
|
|
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
|
|
if auth_code is None or auth_code.client_id != token_request.client_id:
|
|
# if code belongs to different client, pretend it doesn't exist
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_grant",
|
|
error_description="authorization code does not exist",
|
|
)
|
|
)
|
|
|
|
# make auth codes expire after a deadline
|
|
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
|
|
if auth_code.expires_at < time.time():
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_grant",
|
|
error_description="authorization code has expired",
|
|
)
|
|
)
|
|
|
|
# verify redirect_uri doesn't change between /authorize and /tokens
|
|
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
|
|
if auth_code.redirect_uri_provided_explicitly:
|
|
authorize_request_redirect_uri = auth_code.redirect_uri
|
|
else:
|
|
authorize_request_redirect_uri = None
|
|
|
|
# Convert both sides to strings for comparison to handle AnyUrl vs string issues
|
|
token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
|
|
auth_redirect_str = (
|
|
str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
|
|
)
|
|
|
|
if token_redirect_str != auth_redirect_str:
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_request",
|
|
error_description=("redirect_uri did not match the one used when creating auth code"),
|
|
)
|
|
)
|
|
|
|
# Verify PKCE code verifier
|
|
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
|
|
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
|
|
|
if hashed_code_verifier != auth_code.code_challenge:
|
|
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_grant",
|
|
error_description="incorrect code_verifier",
|
|
)
|
|
)
|
|
|
|
try:
|
|
# Exchange authorization code for tokens
|
|
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
|
|
except TokenError as e:
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error=e.error,
|
|
error_description=e.error_description,
|
|
)
|
|
)
|
|
|
|
case RefreshTokenRequest():
|
|
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
|
|
if refresh_token is None or refresh_token.client_id != token_request.client_id:
|
|
# if token belongs to different client, pretend it doesn't exist
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_grant",
|
|
error_description="refresh token does not exist",
|
|
)
|
|
)
|
|
|
|
if refresh_token.expires_at and refresh_token.expires_at < time.time():
|
|
# if the refresh token has expired, pretend it doesn't exist
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_grant",
|
|
error_description="refresh token has expired",
|
|
)
|
|
)
|
|
|
|
# Parse scopes if provided
|
|
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
|
|
|
|
for scope in scopes:
|
|
if scope not in refresh_token.scopes:
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error="invalid_scope",
|
|
error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
|
|
)
|
|
)
|
|
|
|
try:
|
|
# Exchange refresh token for new tokens
|
|
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
|
|
except TokenError as e:
|
|
return self.response(
|
|
TokenErrorResponse(
|
|
error=e.error,
|
|
error_description=e.error_description,
|
|
)
|
|
)
|
|
|
|
return self.response(TokenSuccessResponse(root=tokens))
|