120 lines
4.1 KiB
Python
120 lines
4.1 KiB
Python
from __future__ import annotations as _annotations
|
|
|
|
import functools
|
|
import inspect
|
|
from collections.abc import Callable
|
|
from functools import cached_property
|
|
from typing import TYPE_CHECKING, Any, get_origin
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from mcp.server.fastmcp.exceptions import ToolError
|
|
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
|
|
from mcp.types import ToolAnnotations
|
|
|
|
if TYPE_CHECKING:
|
|
from mcp.server.fastmcp.server import Context
|
|
from mcp.server.session import ServerSessionT
|
|
from mcp.shared.context import LifespanContextT, RequestT
|
|
|
|
|
|
class Tool(BaseModel):
    """Internal tool registration info."""

    # The wrapped callable; excluded from serialization (exclude=True).
    fn: Callable[..., Any] = Field(exclude=True)
    name: str = Field(description="Name of the tool")
    title: str | None = Field(None, description="Human-readable title of the tool")
    description: str = Field(description="Description of what the tool does")
    parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
    fn_metadata: FuncMetadata = Field(
        description="Metadata about the function including a pydantic model for tool arguments"
    )
    is_async: bool = Field(description="Whether the tool is async")
    context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
    annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")

    @cached_property
    def output_schema(self) -> dict[str, Any] | None:
        """Return the output schema recorded on ``fn_metadata`` (cached after first access)."""
        return self.fn_metadata.output_schema

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        context_kwarg: str | None = None,
        annotations: ToolAnnotations | None = None,
        structured_output: bool | None = None,
    ) -> Tool:
        """Create a Tool from a function.

        Args:
            fn: The callable to register.
            name: Tool name; falls back to ``fn.__name__`` when not given.
            title: Optional human-readable title, stored as-is.
            description: Tool description; falls back to ``fn.__doc__`` and then
                to the empty string.
            context_kwarg: Name of the parameter that should receive the context.
                When ``None``, the first parameter whose annotation is a
                ``Context`` subclass is used, if there is one.
            annotations: Optional tool annotations, stored as-is.
            structured_output: Forwarded unchanged to ``func_metadata``.

        Returns:
            A new ``Tool`` describing ``fn``.

        Raises:
            ValueError: If the resolved name is ``"<lambda>"``, i.e. a lambda was
                passed without an explicit ``name``.
        """
        # Imported here rather than at module level; the top-level import of
        # Context is guarded by TYPE_CHECKING.
        from mcp.server.fastmcp.server import Context

        func_name = name or fn.__name__

        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")

        func_doc = description or fn.__doc__ or ""
        is_async = _is_async_callable(fn)

        if context_kwarg is None:
            sig = inspect.signature(fn)
            for param_name, param in sig.parameters.items():
                annotation = param.annotation
                # Parameterized generics (list[int], Optional[X], ...) are skipped.
                if get_origin(annotation) is not None:
                    continue
                # issubclass() raises TypeError when its first argument is not a
                # class. Annotations can legitimately be non-classes (unresolved
                # string annotations, TypeVars, None, ...), so those are skipped
                # instead of aborting tool registration.
                if inspect.isclass(annotation) and issubclass(annotation, Context):
                    context_kwarg = param_name
                    break

        func_arg_metadata = func_metadata(
            fn,
            # The context parameter is injected at call time, so it is left out
            # of the argument model.
            skip_names=[context_kwarg] if context_kwarg is not None else [],
            structured_output=structured_output,
        )
        parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)

        return cls(
            fn=fn,
            name=func_name,
            title=title,
            description=func_doc,
            parameters=parameters,
            fn_metadata=func_arg_metadata,
            is_async=is_async,
            context_kwarg=context_kwarg,
            annotations=annotations,
        )

    async def run(
        self,
        arguments: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
        convert_result: bool = False,
    ) -> Any:
        """Run the tool with arguments.

        Args:
            arguments: Arguments handed to
                ``fn_metadata.call_fn_with_arg_validation`` together with ``fn``.
            context: Value passed to the function under ``context_kwarg``; only
                used when ``context_kwarg`` is set.
            convert_result: When true, the result is passed through
                ``fn_metadata.convert_result`` before being returned.

        Returns:
            The function's result, converted if ``convert_result`` is true.

        Raises:
            ToolError: Wraps any ``Exception`` raised while calling the function
                or converting its result; the original exception is chained.
        """
        try:
            result = await self.fn_metadata.call_fn_with_arg_validation(
                self.fn,
                self.is_async,
                arguments,
                {self.context_kwarg: context} if self.context_kwarg is not None else None,
            )

            if convert_result:
                result = self.fn_metadata.convert_result(result)

            return result
        except Exception as e:
            raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
|
|
|
|
|
def _is_async_callable(obj: Any) -> bool:
    """Report whether ``obj`` is a coroutine function or has a coroutine ``__call__``.

    ``functools.partial`` wrappers are peeled off first, so a partial of an
    async function is reported as async.
    """
    target = obj
    while isinstance(target, functools.partial):
        target = target.func

    # Plain ``async def`` functions (and bound async methods).
    if inspect.iscoroutinefunction(target):
        return True

    # Non-callables can never be async callables.
    if not callable(target):
        return False

    # Callable objects count as async when their ``__call__`` is a coroutine function.
    return inspect.iscoroutinefunction(getattr(target, "__call__", None))
|