112 lines
3.7 KiB
Python
112 lines
3.7 KiB
Python
"""Elicitation utilities for MCP servers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import types
|
|
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
|
|
|
|
from pydantic import BaseModel
|
|
from pydantic.fields import FieldInfo
|
|
|
|
from mcp.server.session import ServerSession
|
|
from mcp.types import RequestId
|
|
|
|
# Type variable for the caller-supplied schema model; bound to BaseModel so
# generic results and elicit_with_validation keep the concrete model type.
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
|
|
|
|
|
|
class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
    """Result when user accepts the elicitation."""

    # NOTE: the docstring text above is intentionally left as-is; pydantic
    # exposes a model's class docstring as its JSON-schema description.

    # Tag identifying this outcome; the only permitted value is "accept".
    action: Literal["accept"] = "accept"

    # The accepted response, typed as the schema model this result is
    # parameterised with.
    data: ElicitSchemaModelT
|
|
|
|
|
|
class DeclinedElicitation(BaseModel):
    """Result when user declines the elicitation."""

    # NOTE: the docstring text above is intentionally left as-is; pydantic
    # exposes a model's class docstring as its JSON-schema description.

    # Tag identifying this outcome; the only permitted value is "decline".
    # No payload accompanies a declined elicitation.
    action: Literal["decline"] = "decline"
|
|
|
|
|
|
class CancelledElicitation(BaseModel):
    """Result when user cancels the elicitation."""

    # NOTE: the docstring text above is intentionally left as-is; pydantic
    # exposes a model's class docstring as its JSON-schema description.

    # Tag identifying this outcome; the only permitted value is "cancel".
    # No payload accompanies a cancelled elicitation.
    action: Literal["cancel"] = "cancel"
|
|
|
|
|
|
# Union of the three possible elicitation outcomes. Only the accepted variant
# carries data, so it is the only member parameterised by the schema model.
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation
|
|
|
|
|
|
# Primitive types allowed in elicitation schemas
|
|
# Field annotations are membership-tested against this tuple in
# _is_primitive_field, and the tuple itself is interpolated into the TypeError
# message raised by _validate_elicitation_schema.
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
|
|
|
|
|
|
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
    """Ensure every field of ``schema`` is acceptable for elicitation.

    Fields are checked in ``schema.model_fields`` order using
    ``_is_primitive_field``; checking stops at the first offending field.

    Args:
        schema: The Pydantic model class whose fields are inspected.

    Raises:
        TypeError: If a field's annotation is not one of the allowed primitive
            types (or a union of them, optionally with ``None``).
    """
    for name, info in schema.model_fields.items():
        if _is_primitive_field(info):
            continue

        # Fail loudly on the first unsupported field, naming it for the caller.
        problem = (
            f"Elicitation schema field '{name}' must be a primitive type "
            f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
            "Complex types like lists, dicts, or nested models are not allowed."
        )
        raise TypeError(problem)
|
|
|
|
|
|
def _is_primitive_field(field_info: FieldInfo) -> bool:
    """Report whether a field's annotation is allowed in an elicitation schema.

    An annotation is accepted when it is ``NoneType``, one of
    ``_ELICITATION_PRIMITIVE_TYPES``, or a union (``typing.Union`` or the
    ``X | Y`` form) whose members are all ``NoneType`` or primitive types.

    Args:
        field_info: Pydantic field metadata; only ``annotation`` is read.

    Returns:
        ``True`` if the annotation is acceptable, ``False`` otherwise.
    """
    declared = field_info.annotation

    def _allowed(candidate: object) -> bool:
        # NoneType is tested by identity first, then membership in the tuple.
        return candidate is types.NoneType or candidate in _ELICITATION_PRIMITIVE_TYPES

    # A bare None or primitive annotation needs no further inspection.
    if _allowed(declared):
        return True

    # Anything else is only acceptable as a union of allowed members.
    origin = get_origin(declared)
    is_union = origin is Union or origin is types.UnionType
    if not is_union:
        return False

    return all(_allowed(member) for member in get_args(declared))
|
|
|
|
|
|
async def elicit_with_validation(
    session: ServerSession,
    message: str,
    schema: type[ElicitSchemaModelT],
    related_request_id: RequestId | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
    """Elicit information from the client/user with schema validation.

    This method can be used to interactively ask for additional information from the
    client within a tool's execution. The client might display the message to the
    user and collect a response according to the provided schema. Or in case a
    client is an agent, it might decide how to handle the elicitation -- either by asking
    the user or automatically generating a response.

    Args:
        session: Server session used to send the elicitation request.
        message: Text sent to the client as the elicitation message.
        schema: Pydantic model class describing the requested data. Every field
            must be a primitive type (or a union of primitives / ``None``).
        related_request_id: Passed through unchanged to ``session.elicit``.

    Returns:
        ``AcceptedElicitation`` wrapping the validated ``schema`` instance when
        the action is ``"accept"``, ``DeclinedElicitation`` for ``"decline"``,
        or ``CancelledElicitation`` for ``"cancel"``.

    Raises:
        TypeError: If ``schema`` contains a non-primitive field.
        ValueError: If the action is ``"accept"`` but no content was returned,
            or if the action is not one of the three known values. Content that
            does not match ``schema`` is rejected by ``schema.model_validate``.
    """
    # Validate that schema only contains primitive types and fail loudly if not
    _validate_elicitation_schema(schema)

    result = await session.elicit(
        message=message,
        requestedSchema=schema.model_json_schema(),
        related_request_id=related_request_id,
    )

    if result.action == "accept":
        # Compare against None rather than truthiness: an empty mapping is a
        # legitimate accepted reply when every schema field is optional or has
        # a default, and must not be reported as an unexpected action.
        if result.content is None:
            raise ValueError("Elicitation was accepted but no content was returned")
        # Validate and parse the content using the schema
        validated_data = schema.model_validate(result.content)
        return AcceptedElicitation(data=validated_data)

    if result.action == "decline":
        return DeclinedElicitation()

    if result.action == "cancel":
        return CancelledElicitation()

    # This should never happen, but handle it just in case
    raise ValueError(f"Unexpected elicitation action: {result.action}")
|