193 lines
6.0 KiB
Python
193 lines
6.0 KiB
Python
from __future__ import annotations
|
|
|
|
import copy
|
|
import logging
|
|
import multiprocessing
|
|
import socket
|
|
import time
|
|
from collections.abc import Callable, Generator
|
|
from contextlib import contextmanager
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
import httpx
|
|
import uvicorn
|
|
|
|
from fastmcp import settings
|
|
from fastmcp.client.auth.oauth import OAuth
|
|
from fastmcp.utilities.http import find_available_port
|
|
|
|
if TYPE_CHECKING:
|
|
from fastmcp.server.server import FastMCP
|
|
|
|
|
|
@contextmanager
def temporary_settings(**kwargs: Any):
    """
    Override FastMCP settings for the duration of a ``with`` block.

    A deep copy of the global ``settings`` object is taken on entry. Each
    keyword argument is then applied through ``settings.set_setting``. On
    exit, whether normal or via an exception, every overridden name is set
    back to the value recorded in that copy.

    Args:
        **kwargs: Setting names mapped to their temporary values, including
            nested settings.

    Example:
        ```python
        import fastmcp
        from fastmcp.utilities.tests import temporary_settings

        with temporary_settings(log_level='DEBUG'):
            assert fastmcp.settings.log_level == 'DEBUG'
        assert fastmcp.settings.log_level == 'INFO'
        ```
    """
    # Snapshot taken before any override so the original values can be read back.
    snapshot = copy.deepcopy(settings)

    try:
        for name, override in kwargs.items():
            settings.set_setting(name, override)
        yield
    finally:
        # Only the names that were passed in are restored.
        for name in kwargs:
            original_value = snapshot.get_setting(name)
            settings.set_setting(name, original_value)
|
|
|
|
|
|
def _run_server(mcp_server: FastMCP, transport: Literal["sse"], port: int) -> None:
|
|
# Some Starlette apps are not pickleable, so we need to create them here based on the indicated transport
|
|
if transport == "sse":
|
|
app = mcp_server.http_app(transport="sse")
|
|
else:
|
|
raise ValueError(f"Invalid transport: {transport}")
|
|
uvicorn_server = uvicorn.Server(
|
|
config=uvicorn.Config(
|
|
app=app,
|
|
host="127.0.0.1",
|
|
port=port,
|
|
log_level="error",
|
|
)
|
|
)
|
|
uvicorn_server.run()
|
|
|
|
|
|
@contextmanager
def run_server_in_process(
    server_fn: Callable[..., None],
    *args,
    provide_host_and_port: bool = True,
    **kwargs,
) -> Generator[str, None, None]:
    """
    Context manager that runs a FastMCP server in a separate process and
    yields the server URL.

    The server process is shut down when the context exits. Shutdown runs in
    a ``finally`` clause, so it also happens when the ``with`` body raises or
    when the server never becomes reachable.

    Args:
        server_fn: The function that runs a FastMCP server. FastMCP servers are
            not pickleable, so we need a function that creates and runs one.
        *args: Arguments to pass to the server function.
        provide_host_and_port: Whether to provide the host and port to the
            server function as kwargs.
        **kwargs: Keyword arguments to pass to the server function.

    Yields:
        The server URL, ``http://127.0.0.1:<port>``.

    Raises:
        RuntimeError: If the server does not accept a connection within the
            retry budget (or the process exits first), or if the process is
            still alive after being killed.
    """
    host = "127.0.0.1"
    port = find_available_port()

    if provide_host_and_port:
        kwargs |= {"host": host, "port": port}

    proc = multiprocessing.Process(
        target=server_fn, args=args, kwargs=kwargs, daemon=True
    )
    proc.start()

    try:
        # Wait for the server to accept TCP connections. The loop also stops
        # early if the child process has already exited.
        max_attempts = 10
        attempt = 0
        while attempt < max_attempts and proc.is_alive():
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.connect((host, port))
                    break
            except ConnectionRefusedError:
                # Short sleeps for the first few retries, longer ones after.
                time.sleep(0.01 if attempt < 3 else 0.1)
                attempt += 1
        else:
            # Reached only when the loop ended without a successful connect.
            raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

        yield f"http://{host}:{port}"
    finally:
        # Always tear the child down, even if startup failed or the body raised.
        proc.terminate()
        proc.join(timeout=5)
        if proc.is_alive():
            # If it's still alive, then force kill it
            proc.kill()
            proc.join(timeout=2)
            if proc.is_alive():
                raise RuntimeError("Server process failed to terminate even after kill")
|
|
|
|
|
|
@contextmanager
def caplog_for_fastmcp(caplog):
    """
    Capture records from the "FastMCP" logger even when propagation is disabled.

    Clears ``caplog``, attaches ``caplog.handler`` directly to the "FastMCP"
    logger for the duration of the ``with`` block, and detaches it again on
    exit (including when the block raises).

    Args:
        caplog: An object exposing ``clear()`` and a ``handler`` attribute.
    """
    caplog.clear()

    fastmcp_logger = logging.getLogger("FastMCP")
    fastmcp_logger.addHandler(caplog.handler)
    try:
        yield
    finally:
        fastmcp_logger.removeHandler(caplog.handler)
|
|
|
|
|
|
class HeadlessOAuth(OAuth):
    """
    OAuth provider that bypasses browser interaction for testing.

    This simulates the complete OAuth flow programmatically by making HTTP requests
    instead of opening a browser and running a callback server. Useful for automated testing.
    """

    def __init__(self, mcp_url: str, **kwargs):
        """
        Initialize HeadlessOAuth with stored response tracking.

        Args:
            mcp_url: Passed through to the parent ``OAuth`` initializer.
            **kwargs: Passed through to the parent ``OAuth`` initializer.
        """
        # Assigned before calling the parent initializer so the attribute
        # already exists while that initializer runs.
        self._stored_response: httpx.Response | None = None
        super().__init__(mcp_url, **kwargs)

    async def redirect_handler(self, authorization_url: str) -> None:
        """
        Make HTTP request to authorization URL and store response for callback handler.

        Redirects are not followed, so the stored response is the
        authorization endpoint's own reply.
        """
        async with httpx.AsyncClient() as client:
            self._stored_response = await client.get(
                authorization_url, follow_redirects=False
            )

    async def callback_handler(self) -> tuple[str, str | None]:
        """
        Parse stored response and return (auth_code, state).

        Returns:
            The first ``code`` query parameter of the redirect URL and the
            first ``state`` parameter, or ``None`` when no state is present.

        Raises:
            RuntimeError: If no response has been stored, the response is not
                a 302, the redirect has no ``Location`` header, the redirect
                URL carries an ``error`` parameter, or it carries no ``code``.
        """
        response = self._stored_response
        # Explicit None check: "nothing stored yet" must not depend on the
        # truthiness of a response object.
        if response is None:
            raise RuntimeError(
                "No authorization response stored. redirect_handler must be called first."
            )

        if response.status_code != 302:
            raise RuntimeError(f"Authorization failed: {response.status_code}")

        # Extract auth code from redirect location
        redirect_url = response.headers.get("location")
        if not redirect_url:
            raise RuntimeError(
                "Authorization failed: redirect response has no Location header"
            )

        query_params = parse_qs(urlparse(redirect_url).query)

        if "error" in query_params:
            error = query_params["error"][0]
            error_desc = query_params.get("error_description", ["Unknown error"])[0]
            raise RuntimeError(f"OAuth authorization failed: {error} - {error_desc}")

        # parse_qs omits parameters with blank values, so a missing or empty
        # code both end up here instead of surfacing as a bare KeyError.
        codes = query_params.get("code")
        if not codes:
            raise RuntimeError(
                "Authorization failed: redirect URL contains no authorization code"
            )

        state = query_params.get("state", [None])[0]
        return codes[0], state
|