377 lines
14 KiB
Python
377 lines
14 KiB
Python
"""Base classes for FastMCP prompts."""
|
|
|
|
from __future__ import annotations as _annotations
|
|
|
|
import inspect
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Awaitable, Callable, Sequence
|
|
from typing import Any
|
|
|
|
import pydantic_core
|
|
from mcp.types import ContentBlock, PromptMessage, Role, TextContent
|
|
from mcp.types import Prompt as MCPPrompt
|
|
from mcp.types import PromptArgument as MCPPromptArgument
|
|
from pydantic import Field, TypeAdapter
|
|
|
|
from fastmcp.exceptions import PromptError
|
|
from fastmcp.server.dependencies import get_context
|
|
from fastmcp.utilities.components import FastMCPComponent
|
|
from fastmcp.utilities.json_schema import compress_schema
|
|
from fastmcp.utilities.logging import get_logger
|
|
from fastmcp.utilities.types import (
|
|
FastMCPBaseModel,
|
|
find_kwarg_by_type,
|
|
get_cached_typeadapter,
|
|
)
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def Message(
|
|
content: str | ContentBlock, role: Role | None = None, **kwargs: Any
|
|
) -> PromptMessage:
|
|
"""A user-friendly constructor for PromptMessage."""
|
|
if isinstance(content, str):
|
|
content = TextContent(type="text", text=content)
|
|
if role is None:
|
|
role = "user"
|
|
return PromptMessage(content=content, role=role, **kwargs)
|
|
|
|
|
|
message_validator = TypeAdapter[PromptMessage](PromptMessage)
|
|
|
|
SyncPromptResult = (
|
|
str
|
|
| PromptMessage
|
|
| dict[str, Any]
|
|
| Sequence[str | PromptMessage | dict[str, Any]]
|
|
)
|
|
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
|
|
|
|
|
|
class PromptArgument(FastMCPBaseModel):
|
|
"""An argument that can be passed to a prompt."""
|
|
|
|
name: str = Field(description="Name of the argument")
|
|
description: str | None = Field(
|
|
default=None, description="Description of what the argument does"
|
|
)
|
|
required: bool = Field(
|
|
default=False, description="Whether the argument is required"
|
|
)
|
|
|
|
|
|
class Prompt(FastMCPComponent, ABC):
|
|
"""A prompt template that can be rendered with parameters."""
|
|
|
|
arguments: list[PromptArgument] | None = Field(
|
|
default=None, description="Arguments that can be passed to the prompt"
|
|
)
|
|
|
|
def enable(self) -> None:
|
|
super().enable()
|
|
try:
|
|
context = get_context()
|
|
context._queue_prompt_list_changed() # type: ignore[private-use]
|
|
except RuntimeError:
|
|
pass # No context available
|
|
|
|
def disable(self) -> None:
|
|
super().disable()
|
|
try:
|
|
context = get_context()
|
|
context._queue_prompt_list_changed() # type: ignore[private-use]
|
|
except RuntimeError:
|
|
pass # No context available
|
|
|
|
def to_mcp_prompt(
|
|
self,
|
|
*,
|
|
include_fastmcp_meta: bool | None = None,
|
|
**overrides: Any,
|
|
) -> MCPPrompt:
|
|
"""Convert the prompt to an MCP prompt."""
|
|
arguments = [
|
|
MCPPromptArgument(
|
|
name=arg.name,
|
|
description=arg.description,
|
|
required=arg.required,
|
|
)
|
|
for arg in self.arguments or []
|
|
]
|
|
kwargs = {
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"arguments": arguments,
|
|
"title": self.title,
|
|
"_meta": self.get_meta(include_fastmcp_meta=include_fastmcp_meta),
|
|
}
|
|
return MCPPrompt(**kwargs | overrides)
|
|
|
|
@staticmethod
|
|
def from_function(
|
|
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
|
|
name: str | None = None,
|
|
title: str | None = None,
|
|
description: str | None = None,
|
|
tags: set[str] | None = None,
|
|
enabled: bool | None = None,
|
|
meta: dict[str, Any] | None = None,
|
|
) -> FunctionPrompt:
|
|
"""Create a Prompt from a function.
|
|
|
|
The function can return:
|
|
- A string (converted to a message)
|
|
- A Message object
|
|
- A dict (converted to a message)
|
|
- A sequence of any of the above
|
|
"""
|
|
return FunctionPrompt.from_function(
|
|
fn=fn,
|
|
name=name,
|
|
title=title,
|
|
description=description,
|
|
tags=tags,
|
|
enabled=enabled,
|
|
meta=meta,
|
|
)
|
|
|
|
@abstractmethod
|
|
async def render(
|
|
self,
|
|
arguments: dict[str, Any] | None = None,
|
|
) -> list[PromptMessage]:
|
|
"""Render the prompt with arguments."""
|
|
raise NotImplementedError("Prompt.render() must be implemented by subclasses")
|
|
|
|
|
|
class FunctionPrompt(Prompt):
|
|
"""A prompt that is a function."""
|
|
|
|
fn: Callable[..., PromptResult | Awaitable[PromptResult]]
|
|
|
|
@classmethod
|
|
def from_function(
|
|
cls,
|
|
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
|
|
name: str | None = None,
|
|
title: str | None = None,
|
|
description: str | None = None,
|
|
tags: set[str] | None = None,
|
|
enabled: bool | None = None,
|
|
meta: dict[str, Any] | None = None,
|
|
) -> FunctionPrompt:
|
|
"""Create a Prompt from a function.
|
|
|
|
The function can return:
|
|
- A string (converted to a message)
|
|
- A Message object
|
|
- A dict (converted to a message)
|
|
- A sequence of any of the above
|
|
"""
|
|
from fastmcp.server.context import Context
|
|
|
|
func_name = name or getattr(fn, "__name__", None) or fn.__class__.__name__
|
|
|
|
if func_name == "<lambda>":
|
|
raise ValueError("You must provide a name for lambda functions")
|
|
# Reject functions with *args or **kwargs
|
|
sig = inspect.signature(fn)
|
|
for param in sig.parameters.values():
|
|
if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
raise ValueError("Functions with *args are not supported as prompts")
|
|
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
|
raise ValueError("Functions with **kwargs are not supported as prompts")
|
|
|
|
description = description or inspect.getdoc(fn)
|
|
|
|
# if the fn is a callable class, we need to get the __call__ method from here out
|
|
if not inspect.isroutine(fn):
|
|
fn = fn.__call__
|
|
# if the fn is a staticmethod, we need to work with the underlying function
|
|
if isinstance(fn, staticmethod):
|
|
fn = fn.__func__
|
|
|
|
type_adapter = get_cached_typeadapter(fn)
|
|
parameters = type_adapter.json_schema()
|
|
|
|
# Auto-detect context parameter if not provided
|
|
|
|
context_kwarg = find_kwarg_by_type(fn, kwarg_type=Context)
|
|
if context_kwarg:
|
|
prune_params = [context_kwarg]
|
|
else:
|
|
prune_params = None
|
|
|
|
parameters = compress_schema(parameters, prune_params=prune_params)
|
|
|
|
# Convert parameters to PromptArguments
|
|
arguments: list[PromptArgument] = []
|
|
if "properties" in parameters:
|
|
for param_name, param in parameters["properties"].items():
|
|
arg_description = param.get("description")
|
|
|
|
# For non-string parameters, append JSON schema info to help users
|
|
# understand the expected format when passing as strings (MCP requirement)
|
|
if param_name in sig.parameters:
|
|
sig_param = sig.parameters[param_name]
|
|
if (
|
|
sig_param.annotation != inspect.Parameter.empty
|
|
and sig_param.annotation is not str
|
|
and param_name != context_kwarg
|
|
):
|
|
# Get the JSON schema for this specific parameter type
|
|
try:
|
|
param_adapter = get_cached_typeadapter(sig_param.annotation)
|
|
param_schema = param_adapter.json_schema()
|
|
|
|
# Create compact schema representation
|
|
schema_str = json.dumps(param_schema, separators=(",", ":"))
|
|
|
|
# Append schema info to description
|
|
schema_note = f"Provide as a JSON string matching the following schema: {schema_str}"
|
|
if arg_description:
|
|
arg_description = f"{arg_description}\n\n{schema_note}"
|
|
else:
|
|
arg_description = schema_note
|
|
except Exception:
|
|
# If schema generation fails, skip enhancement
|
|
pass
|
|
|
|
arguments.append(
|
|
PromptArgument(
|
|
name=param_name,
|
|
description=arg_description,
|
|
required=param_name in parameters.get("required", []),
|
|
)
|
|
)
|
|
|
|
return cls(
|
|
name=func_name,
|
|
title=title,
|
|
description=description,
|
|
arguments=arguments,
|
|
tags=tags or set(),
|
|
enabled=enabled if enabled is not None else True,
|
|
fn=fn,
|
|
meta=meta,
|
|
)
|
|
|
|
def _convert_string_arguments(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
"""Convert string arguments to expected types based on function signature."""
|
|
from fastmcp.server.context import Context
|
|
|
|
sig = inspect.signature(self.fn)
|
|
converted_kwargs = {}
|
|
|
|
# Find context parameter name if any
|
|
context_param_name = find_kwarg_by_type(self.fn, kwarg_type=Context)
|
|
|
|
for param_name, param_value in kwargs.items():
|
|
if param_name in sig.parameters:
|
|
param = sig.parameters[param_name]
|
|
|
|
# Skip Context parameters - they're handled separately
|
|
if param_name == context_param_name:
|
|
converted_kwargs[param_name] = param_value
|
|
continue
|
|
|
|
# If parameter has no annotation or annotation is str, pass as-is
|
|
if (
|
|
param.annotation == inspect.Parameter.empty
|
|
or param.annotation is str
|
|
):
|
|
converted_kwargs[param_name] = param_value
|
|
# If argument is not a string, pass as-is (already properly typed)
|
|
elif not isinstance(param_value, str):
|
|
converted_kwargs[param_name] = param_value
|
|
else:
|
|
# Try to convert string argument using type adapter
|
|
try:
|
|
adapter = get_cached_typeadapter(param.annotation)
|
|
# Try JSON parsing first for complex types
|
|
try:
|
|
converted_kwargs[param_name] = adapter.validate_json(
|
|
param_value
|
|
)
|
|
except (ValueError, TypeError, pydantic_core.ValidationError):
|
|
# Fallback to direct validation
|
|
converted_kwargs[param_name] = adapter.validate_python(
|
|
param_value
|
|
)
|
|
except (ValueError, TypeError, pydantic_core.ValidationError) as e:
|
|
# If conversion fails, provide informative error
|
|
raise PromptError(
|
|
f"Could not convert argument '{param_name}' with value '{param_value}' "
|
|
f"to expected type {param.annotation}. Error: {e}"
|
|
)
|
|
else:
|
|
# Parameter not in function signature, pass as-is
|
|
converted_kwargs[param_name] = param_value
|
|
|
|
return converted_kwargs
|
|
|
|
async def render(
|
|
self,
|
|
arguments: dict[str, Any] | None = None,
|
|
) -> list[PromptMessage]:
|
|
"""Render the prompt with arguments."""
|
|
from fastmcp.server.context import Context
|
|
|
|
# Validate required arguments
|
|
if self.arguments:
|
|
required = {arg.name for arg in self.arguments if arg.required}
|
|
provided = set(arguments or {})
|
|
missing = required - provided
|
|
if missing:
|
|
raise ValueError(f"Missing required arguments: {missing}")
|
|
|
|
try:
|
|
# Prepare arguments with context
|
|
kwargs = arguments.copy() if arguments else {}
|
|
context_kwarg = find_kwarg_by_type(self.fn, kwarg_type=Context)
|
|
if context_kwarg and context_kwarg not in kwargs:
|
|
kwargs[context_kwarg] = get_context()
|
|
|
|
# Convert string arguments to expected types when needed
|
|
kwargs = self._convert_string_arguments(kwargs)
|
|
|
|
# Call function and check if result is a coroutine
|
|
result = self.fn(**kwargs)
|
|
if inspect.isawaitable(result):
|
|
result = await result
|
|
|
|
# Validate messages
|
|
if not isinstance(result, list | tuple):
|
|
result = [result]
|
|
|
|
# Convert result to messages
|
|
messages: list[PromptMessage] = []
|
|
for msg in result:
|
|
try:
|
|
if isinstance(msg, PromptMessage):
|
|
messages.append(msg)
|
|
elif isinstance(msg, str):
|
|
messages.append(
|
|
PromptMessage(
|
|
role="user",
|
|
content=TextContent(type="text", text=msg),
|
|
)
|
|
)
|
|
else:
|
|
content = pydantic_core.to_json(msg, fallback=str).decode()
|
|
messages.append(
|
|
PromptMessage(
|
|
role="user",
|
|
content=TextContent(type="text", text=content),
|
|
)
|
|
)
|
|
except Exception:
|
|
raise PromptError("Could not convert prompt result to message.")
|
|
|
|
return messages
|
|
except Exception:
|
|
logger.exception(f"Error rendering prompt {self.name}")
|
|
raise PromptError(f"Error rendering prompt {self.name}.")
|