import base64
import hashlib
import time
from dataclasses import dataclass
from typing import Annotated, Any, Literal

from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.requests import Request

from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
from mcp.shared.auth import OAuthToken


class AuthorizationCodeRequest(BaseModel):
    # Form body of an authorization-code exchange.
    # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
    grant_type: Literal["authorization_code"]
    code: str = Field(..., description="The authorization code")
    redirect_uri: AnyUrl | None = Field(
        default=None,
        description="Must be the same as redirect URI provided in /authorize",
    )
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
    code_verifier: str = Field(..., description="PKCE code verifier")
    # RFC 8707 resource indicator
    resource: str | None = Field(default=None, description="Resource indicator for the token")


class RefreshTokenRequest(BaseModel):
    # Form body of a refresh-token exchange.
    # See https://datatracker.ietf.org/doc/html/rfc6749#section-6
    grant_type: Literal["refresh_token"]
    refresh_token: str = Field(..., description="The refresh token")
    scope: str | None = Field(default=None, description="Optional scope parameter")
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # RFC 8707 resource indicator
    resource: str | None = Field(default=None, description="Resource indicator for the token")


# Union of the supported request bodies; `grant_type` selects which model validates the form.
_GrantRequest = Annotated[
    AuthorizationCodeRequest | RefreshTokenRequest,
    Field(discriminator="grant_type"),
]


class TokenRequest(RootModel[_GrantRequest]):
    # Wrapper used only to validate the incoming form; callers read `.root`.
    root: _GrantRequest


class TokenErrorResponse(BaseModel):
    """
    See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
    """

    error: TokenErrorCode
    error_description: str | None = None
    error_uri: AnyHttpUrl | None = None


class TokenSuccessResponse(RootModel[OAuthToken]):
    # this is just a wrapper over OAuthToken; the only reason we do this
    # is to have some separation between the HTTP response type, and the
    # type returned by the provider
    root: OAuthToken


@dataclass
class TokenHandler:
    """Request handler for the token endpoint.

    Validates the form body, authenticates the client, then redeems either an
    authorization code or a refresh token through ``provider``.
    """

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    client_authenticator: ClientAuthenticator

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        """Wrap ``obj`` in a JSON response with caching disabled.

        A ``TokenErrorResponse`` is sent with status 400; anything else with 200.
        """
        return PydanticJSONResponse(
            content=obj,
            status_code=400 if isinstance(obj, TokenErrorResponse) else 200,
            headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
        )

    def _error(self, code: TokenErrorCode, description: str | None):
        """Build the HTTP response for a token error with the given code and description."""
        return self.response(TokenErrorResponse(error=code, error_description=description))

    async def handle(self, request: Request):
        """Process one token request and return the JSON response to send.

        Steps, in order: parse and validate the form, authenticate the client,
        check the client is registered for the requested grant type, then
        dispatch to the grant-specific exchange.
        """
        try:
            form_data = await request.form()
            token_request = TokenRequest.model_validate(dict(form_data)).root
        except ValidationError as validation_error:
            return self._error("invalid_request", stringify_pydantic_error(validation_error))

        try:
            client_info = await self.client_authenticator.authenticate(
                client_id=token_request.client_id,
                client_secret=token_request.client_secret,
            )
        except AuthenticationError as e:
            # NOTE(review): RFC 6749 section 5.2 describes `invalid_client` for failed
            # client authentication -- confirm `unauthorized_client` is intended here.
            return self._error("unauthorized_client", e.message)

        if token_request.grant_type not in client_info.grant_types:
            return self._error(
                "unsupported_grant_type",
                f"Unsupported grant type (supported grant types are {client_info.grant_types})",
            )

        if isinstance(token_request, AuthorizationCodeRequest):
            return await self._redeem_authorization_code(client_info, token_request)
        return await self._redeem_refresh_token(client_info, token_request)

    async def _redeem_authorization_code(self, client_info: Any, token_request: AuthorizationCodeRequest):
        """Validate an authorization-code request and exchange the code for tokens."""
        auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
        if auth_code is None or auth_code.client_id != token_request.client_id:
            # if code belongs to different client, pretend it doesn't exist
            return self._error("invalid_grant", "authorization code does not exist")

        # make auth codes expire after a deadline
        # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
        if auth_code.expires_at < time.time():
            return self._error("invalid_grant", "authorization code has expired")

        # verify redirect_uri doesn't change between /authorize and /tokens
        # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
        # Only a redirect URI that was explicitly sent to /authorize is expected back.
        expected_uri = auth_code.redirect_uri if auth_code.redirect_uri_provided_explicitly else None

        # Compare as strings so an AnyUrl on one side and a plain string on the other still match.
        presented = None if token_request.redirect_uri is None else str(token_request.redirect_uri)
        expected = None if expected_uri is None else str(expected_uri)
        if presented != expected:
            return self._error(
                "invalid_request",
                "redirect_uri did not match the one used when creating auth code",
            )

        # Verify PKCE code verifier: unpadded base64url of SHA-256(verifier) must equal the challenge.
        digest = hashlib.sha256(token_request.code_verifier.encode()).digest()
        derived_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        if derived_challenge != auth_code.code_challenge:
            # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
            return self._error("invalid_grant", "incorrect code_verifier")

        try:
            # Exchange authorization code for tokens
            tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
        except TokenError as e:
            return self._error(e.error, e.error_description)

        return self.response(TokenSuccessResponse(root=tokens))

    async def _redeem_refresh_token(self, client_info: Any, token_request: RefreshTokenRequest):
        """Validate a refresh-token request and exchange the token for new tokens."""
        refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
        if refresh_token is None or refresh_token.client_id != token_request.client_id:
            # if token belongs to different client, pretend it doesn't exist
            return self._error("invalid_grant", "refresh token does not exist")

        # A falsy expires_at skips the expiry check.
        if refresh_token.expires_at and refresh_token.expires_at < time.time():
            return self._error("invalid_grant", "refresh token has expired")

        # Parse scopes if provided; otherwise reuse the scopes of the refresh token.
        scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes

        # Reject on the first requested scope the refresh token does not carry.
        rejected = next((item for item in scopes if item not in refresh_token.scopes), None)
        if rejected is not None:
            return self._error(
                "invalid_scope",
                f"cannot request scope `{rejected}` not provided by refresh token",
            )

        try:
            # Exchange refresh token for new tokens
            tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
        except TokenError as e:
            return self._error(e.error, e.error_description)

        return self.response(TokenSuccessResponse(root=tokens))