"""OpenAPI component implementations: Tool, Resource, and ResourceTemplate classes.""" import json import re from collections.abc import Callable from typing import TYPE_CHECKING, Any import httpx from mcp.types import ToolAnnotations from pydantic.networks import AnyUrl # Import from our new utilities from fastmcp.experimental.utilities.openapi import HTTPRoute from fastmcp.experimental.utilities.openapi.director import RequestDirector from fastmcp.resources import Resource, ResourceTemplate from fastmcp.server.dependencies import get_http_headers from fastmcp.tools.tool import Tool, ToolResult from fastmcp.utilities.logging import get_logger if TYPE_CHECKING: from fastmcp.server import Context logger = get_logger(__name__) class OpenAPITool(Tool): """Tool implementation for OpenAPI endpoints.""" def __init__( self, client: httpx.AsyncClient, route: HTTPRoute, director: RequestDirector, name: str, description: str, parameters: dict[str, Any], output_schema: dict[str, Any] | None = None, tags: set[str] | None = None, timeout: float | None = None, annotations: ToolAnnotations | None = None, serializer: Callable[[Any], str] | None = None, ): super().__init__( name=name, description=description, parameters=parameters, output_schema=output_schema, tags=tags or set(), annotations=annotations, serializer=serializer, ) self._client = client self._route = route self._director = director self._timeout = timeout def __repr__(self) -> str: """Custom representation to prevent recursion errors when printing.""" return f"OpenAPITool(name={self.name!r}, method={self._route.method}, path={self._route.path})" async def run(self, arguments: dict[str, Any]) -> ToolResult: """Execute the HTTP request using RequestDirector for simplified parameter handling.""" try: # Get base URL from client base_url = ( str(self._client.base_url) if hasattr(self._client, "base_url") and self._client.base_url else "http://localhost" ) # Get Headers from client cli_headers = ( self._client.headers if hasattr(self._client, "headers") and self._client.headers else {} ) # Build the request using RequestDirector request = self._director.build(self._route, arguments, base_url) # First add server headers (lowest precedence) if cli_headers: # Merge with existing headers, _client headers as base if request.headers: # Start with request headers, then add client headers for key, value in cli_headers.items(): if key not in request.headers: request.headers[key] = value else: # Create new headers from cli_headers for key, value in cli_headers.items(): request.headers[key] = value # Then add MCP client transport headers (highest precedence) mcp_headers = get_http_headers() if mcp_headers: # Merge with existing headers, MCP headers take precedence over all if request.headers: request.headers.update(mcp_headers) else: # Create new headers from mcp_headers for key, value in mcp_headers.items(): request.headers[key] = value # print logger logger.debug(f"run - sending request; headers: {request.headers}") # Execute the request # Note: httpx.AsyncClient.send() doesn't accept timeout parameter # The timeout should be configured on the client itself response = await self._client.send(request) # Raise for 4xx/5xx responses response.raise_for_status() # Try to parse as JSON first try: result = response.json() # Handle structured content based on output schema, if any structured_output = None if self.output_schema is not None: if self.output_schema.get("x-fastmcp-wrap-result"): # Schema says wrap - always wrap in result key structured_output = {"result": result} else: structured_output = result # If no output schema, use fallback logic for backward compatibility elif not isinstance(result, dict): structured_output = {"result": result} else: structured_output = result return ToolResult(structured_content=structured_output) except json.JSONDecodeError: return ToolResult(content=response.text) except httpx.HTTPStatusError as e: # Handle HTTP errors (4xx, 5xx) error_message = ( f"HTTP error {e.response.status_code}: {e.response.reason_phrase}" ) try: error_data = e.response.json() error_message += f" - {error_data}" except (json.JSONDecodeError, ValueError): if e.response.text: error_message += f" - {e.response.text}" raise ValueError(error_message) except httpx.RequestError as e: # Handle request errors (connection, timeout, etc.) raise ValueError(f"Request error: {str(e)}") class OpenAPIResource(Resource): """Resource implementation for OpenAPI endpoints.""" def __init__( self, client: httpx.AsyncClient, route: HTTPRoute, director: RequestDirector, uri: str, name: str, description: str, mime_type: str = "application/json", tags: set[str] = set(), timeout: float | None = None, ): super().__init__( uri=AnyUrl(uri), # Convert string to AnyUrl name=name, description=description, mime_type=mime_type, tags=tags, ) self._client = client self._route = route self._director = director self._timeout = timeout def __repr__(self) -> str: """Custom representation to prevent recursion errors when printing.""" return f"OpenAPIResource(name={self.name!r}, uri={self.uri!r}, path={self._route.path})" async def read(self) -> str | bytes: """Fetch the resource data by making an HTTP request.""" try: # Extract path parameters from the URI if present path = self._route.path resource_uri = str(self.uri) # If this is a templated resource, extract path parameters from the URI if "{" in path and "}" in path: # Extract the resource ID from the URI (the last part after the last slash) parts = resource_uri.split("/") if len(parts) > 1: # Find all path parameters in the route path path_params = {} # Find the path parameter names from the route path param_matches = re.findall(r"\{([^}]+)\}", path) if param_matches: # Reverse sorting from creation order (traversal is backwards) param_matches.sort(reverse=True) # Number of sent parameters is number of parts -1 (assuming first part is resource identifier) expected_param_count = len(parts) - 1 # Map parameters from the end of the URI to the parameters in the path # Last parameter in URI (parts[-1]) maps to last parameter in path, and so on for i, param_name in enumerate(param_matches): # Ensure we don't use resource identifier as parameter if i < expected_param_count: # Get values from the end of parts param_value = parts[-1 - i] path_params[param_name] = param_value # Replace path parameters with their values for param_name, param_value in path_params.items(): path = path.replace(f"{{{param_name}}}", str(param_value)) # Filter any query parameters - get query parameters and filter out None/empty values query_params = {} for param in self._route.parameters: if param.location == "query" and hasattr(self, f"_{param.name}"): value = getattr(self, f"_{param.name}") if value is not None and value != "": query_params[param.name] = value # Prepare headers with correct precedence: server < client transport headers = {} # Start with server headers (lowest precedence) cli_headers = ( self._client.headers if hasattr(self._client, "headers") and self._client.headers else {} ) headers.update(cli_headers) # Add MCP client transport headers (highest precedence) mcp_headers = get_http_headers() headers.update(mcp_headers) response = await self._client.request( method=self._route.method, url=path, params=query_params, headers=headers, timeout=self._timeout, ) # Raise for 4xx/5xx responses response.raise_for_status() # Determine content type and return appropriate format content_type = response.headers.get("content-type", "").lower() if "application/json" in content_type: result = response.json() return json.dumps(result) elif any(ct in content_type for ct in ["text/", "application/xml"]): return response.text else: return response.content except httpx.HTTPStatusError as e: # Handle HTTP errors (4xx, 5xx) error_message = ( f"HTTP error {e.response.status_code}: {e.response.reason_phrase}" ) try: error_data = e.response.json() error_message += f" - {error_data}" except (json.JSONDecodeError, ValueError): if e.response.text: error_message += f" - {e.response.text}" raise ValueError(error_message) except httpx.RequestError as e: # Handle request errors (connection, timeout, etc.) raise ValueError(f"Request error: {str(e)}") class OpenAPIResourceTemplate(ResourceTemplate): """Resource template implementation for OpenAPI endpoints.""" def __init__( self, client: httpx.AsyncClient, route: HTTPRoute, director: RequestDirector, uri_template: str, name: str, description: str, parameters: dict[str, Any], tags: set[str] = set(), timeout: float | None = None, ): super().__init__( uri_template=uri_template, name=name, description=description, parameters=parameters, tags=tags, ) self._client = client self._route = route self._director = director self._timeout = timeout def __repr__(self) -> str: """Custom representation to prevent recursion errors when printing.""" return f"OpenAPIResourceTemplate(name={self.name!r}, uri_template={self.uri_template!r}, path={self._route.path})" async def create_resource( self, uri: str, params: dict[str, Any], context: "Context | None" = None, ) -> Resource: """Create a resource with the given parameters.""" # Generate a URI for this resource instance uri_parts = [] for key, value in params.items(): uri_parts.append(f"{key}={value}") # Create and return a resource return OpenAPIResource( client=self._client, route=self._route, director=self._director, uri=uri, name=f"{self.name}-{'-'.join(uri_parts)}", description=self.description or f"Resource for {self._route.path}", mime_type="application/json", tags=set(self._route.tags or []), timeout=self._timeout, ) # Export public symbols __all__ = [ "OpenAPITool", "OpenAPIResource", "OpenAPIResourceTemplate", ]