import inspect from collections.abc import Awaitable, Callable from typing import TypeAlias import mcp.types from mcp import ClientSession, CreateMessageResult from mcp.client.session import SamplingFnT from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import CreateMessageRequestParams as SamplingParams from mcp.types import SamplingMessage __all__ = ["SamplingMessage", "SamplingParams", "SamplingHandler"] SamplingHandler: TypeAlias = Callable[ [ list[SamplingMessage], SamplingParams, RequestContext[ClientSession, LifespanContextT], ], str | CreateMessageResult | Awaitable[str | CreateMessageResult], ] def create_sampling_callback(sampling_handler: SamplingHandler) -> SamplingFnT: async def _sampling_handler( context: RequestContext[ClientSession, LifespanContextT], params: SamplingParams, ) -> CreateMessageResult | mcp.types.ErrorData: try: result = sampling_handler(params.messages, params, context) if inspect.isawaitable(result): result = await result if isinstance(result, str): result = CreateMessageResult( role="assistant", model="fastmcp-client", content=mcp.types.TextContent(type="text", text=result), ) return result except Exception as e: return mcp.types.ErrorData( code=mcp.types.INTERNAL_ERROR, message=str(e), ) return _sampling_handler