"""Elicitation utilities for MCP servers."""

from __future__ import annotations

import types
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo

from mcp.server.session import ServerSession
from mcp.types import RequestId

ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)


class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
    """Result when user accepts the elicitation."""

    action: Literal["accept"] = "accept"
    # The client's response, already validated against the requested schema.
    data: ElicitSchemaModelT


class DeclinedElicitation(BaseModel):
    """Result when user declines the elicitation."""

    action: Literal["decline"] = "decline"


class CancelledElicitation(BaseModel):
    """Result when user cancels the elicitation."""

    action: Literal["cancel"] = "cancel"


# Union of the three possible outcomes; callers can branch on ``.action``.
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation

# Primitive types allowed in elicitation schemas
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)


def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
    """Validate that a Pydantic model only contains primitive field types.

    Args:
        schema: The model class whose fields are checked.

    Raises:
        TypeError: If any field's annotation is not accepted by
            ``_is_primitive_field`` (a type from ``_ELICITATION_PRIMITIVE_TYPES``,
            ``None``, or a union made only of those).
    """
    for field_name, field_info in schema.model_fields.items():
        if not _is_primitive_field(field_info):
            raise TypeError(
                f"Elicitation schema field '{field_name}' must be a primitive type "
                f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
                f"Complex types like lists, dicts, or nested models are not allowed."
            )


def _is_primitive_field(field_info: FieldInfo) -> bool:
    """Check if a field is a primitive type allowed in elicitation schemas.

    Accepted annotations are ``None``, any type in ``_ELICITATION_PRIMITIVE_TYPES``,
    or a union (e.g. ``Optional[str]`` or ``int | None``) whose members are all
    drawn from those.

    Args:
        field_info: Pydantic field metadata; only its ``annotation`` is inspected.

    Returns:
        True if the annotation is allowed, False otherwise.
    """
    annotation = field_info.annotation

    # Handle None type
    if annotation is types.NoneType:
        return True

    # Handle basic primitive types
    if annotation in _ELICITATION_PRIMITIVE_TYPES:
        return True

    # Handle Union types. ``typing.Union``/``Optional`` and PEP 604 ``X | Y``
    # unions may report different origins, so both are checked.
    origin = get_origin(annotation)
    if origin is Union or origin is types.UnionType:
        args = get_args(annotation)
        # All args must be primitive types or None
        return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)

    return False


async def elicit_with_validation(
    session: ServerSession,
    message: str,
    schema: type[ElicitSchemaModelT],
    related_request_id: RequestId | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
    """Elicit information from the client/user with schema validation.

    This method can be used to interactively ask for additional information from the
    client within a tool's execution. The client might display the message to the
    user and collect a response according to the provided schema. Or in case a
    client is an agent, it might decide how to handle the elicitation -- either by
    asking the user or automatically generating a response.

    Args:
        session: Session used to send the elicitation request.
        message: Message presented to the client/user.
        schema: Pydantic model describing the requested data. It may only contain
            primitive fields; its JSON schema is sent as ``requestedSchema``.
        related_request_id: Optional ID of the request this elicitation relates
            to; passed through unchanged to ``session.elicit``.

    Returns:
        ``AcceptedElicitation`` wrapping the validated ``schema`` instance,
        ``DeclinedElicitation``, or ``CancelledElicitation``, according to the
        action the client reported.

    Raises:
        TypeError: If ``schema`` contains a non-primitive field.
        pydantic.ValidationError: If accepted content does not validate against
            ``schema``.
        ValueError: If the client accepted without returning any content, or
            reported an action other than accept/decline/cancel.
    """
    # Validate that schema only contains primitive types and fail loudly if not
    _validate_elicitation_schema(schema)

    json_schema = schema.model_json_schema()

    result = await session.elicit(
        message=message,
        requestedSchema=json_schema,
        related_request_id=related_request_id,
    )

    if result.action == "accept":
        # Use an explicit ``is None`` check rather than truthiness: ``{}`` is a
        # legitimate response when every schema field is optional or defaulted,
        # and must be validated rather than misreported as an unknown action.
        if result.content is None:
            raise ValueError("Elicitation was accepted but no content was returned")
        # Validate and parse the content using the schema
        validated_data = schema.model_validate(result.content)
        return AcceptedElicitation(data=validated_data)
    elif result.action == "decline":
        return DeclinedElicitation()
    elif result.action == "cancel":
        return CancelledElicitation()
    else:
        # This should never happen, but handle it just in case
        raise ValueError(f"Unexpected elicitation action: {result.action}")