"""
|
|
StreamableHTTP Server Transport Module
|
|
|
|
This module implements an HTTP transport layer with Streamable HTTP.
|
|
|
|
The transport handles bidirectional communication using HTTP requests and
|
|
responses, with streaming support for long-running operations.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from http import HTTPStatus
|
|
|
|
import anyio
|
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
from pydantic import ValidationError
|
|
from sse_starlette import EventSourceResponse
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.types import Receive, Scope, Send
|
|
|
|
from mcp.server.transport_security import (
|
|
TransportSecurityMiddleware,
|
|
TransportSecuritySettings,
|
|
)
|
|
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
|
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
|
from mcp.types import (
|
|
DEFAULT_NEGOTIATED_VERSION,
|
|
INTERNAL_ERROR,
|
|
INVALID_PARAMS,
|
|
INVALID_REQUEST,
|
|
PARSE_ERROR,
|
|
ErrorData,
|
|
JSONRPCError,
|
|
JSONRPCMessage,
|
|
JSONRPCRequest,
|
|
JSONRPCResponse,
|
|
RequestId,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Header names
MCP_SESSION_ID_HEADER = "mcp-session-id"
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
LAST_EVENT_ID_HEADER = "last-event-id"

# Content types
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream"

# Special key for the standalone GET stream
GET_STREAM_KEY = "_GET_stream"

# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")

# Type aliases
StreamId = str
EventId = str


@dataclass
class EventMessage:
    """
    A JSONRPCMessage with an optional event ID for stream resumability.
    """

    message: JSONRPCMessage
    event_id: str | None = None


EventCallback = Callable[[EventMessage], Awaitable[None]]


class EventStore(ABC):
    """
    Interface for resumability support via event storage.
    """

    @abstractmethod
    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
        """
        Stores an event for later retrieval.

        Args:
            stream_id: ID of the stream the event belongs to
            message: The JSON-RPC message to store

        Returns:
            The generated event ID for the stored event
        """
        pass

    @abstractmethod
    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> StreamId | None:
        """
        Replays events that occurred after the specified event ID.

        Args:
            last_event_id: The ID of the last event the client received
            send_callback: A callback function to send events to the client

        Returns:
            The stream ID of the replayed events
        """
        pass
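

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): a minimal in-memory
# EventStore satisfying the interface above. The class name, the uuid-based
# event IDs, and the unbounded dict storage are assumptions made for this
# example; a production store would bound memory and persist events.
#
#   from uuid import uuid4
#
#   class InMemoryEventStore(EventStore):
#       def __init__(self) -> None:
#           self._events: dict[StreamId, list[tuple[EventId, JSONRPCMessage]]] = {}
#           self._event_to_stream: dict[EventId, StreamId] = {}
#
#       async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
#           event_id = uuid4().hex
#           self._events.setdefault(stream_id, []).append((event_id, message))
#           self._event_to_stream[event_id] = stream_id
#           return event_id
#
#       async def replay_events_after(
#           self, last_event_id: EventId, send_callback: EventCallback
#       ) -> StreamId | None:
#           stream_id = self._event_to_stream.get(last_event_id)
#           if stream_id is None:
#               return None
#           seen_last = False
#           for event_id, message in self._events.get(stream_id, []):
#               if seen_last:
#                   await send_callback(EventMessage(message, event_id))
#               elif event_id == last_event_id:
#                   seen_last = True
#           return stream_id
# ---------------------------------------------------------------------------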


class StreamableHTTPServerTransport:
    """
    HTTP server transport with event streaming support for MCP.

    Handles JSON-RPC messages in HTTP POST requests with SSE streaming.
    Supports optional JSON responses and session management.
    """

    # Server notification streams for POST requests as well as standalone SSE stream
    _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None
    _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
    _write_stream: MemoryObjectSendStream[SessionMessage] | None = None
    _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
    _security: TransportSecurityMiddleware

    def __init__(
        self,
        mcp_session_id: str | None,
        is_json_response_enabled: bool = False,
        event_store: EventStore | None = None,
        security_settings: TransportSecuritySettings | None = None,
    ) -> None:
        """
        Initialize a new StreamableHTTP server transport.

        Args:
            mcp_session_id: Optional session identifier for this connection.
                Must contain only visible ASCII characters (0x21-0x7E).
            is_json_response_enabled: If True, return JSON responses for requests
                instead of SSE streams. Default is False.
            event_store: Event store for resumability support. If provided,
                resumability will be enabled, allowing clients to
                reconnect and resume messages.
            security_settings: Optional security settings for DNS rebinding protection.

        Raises:
            ValueError: If the session ID contains invalid characters.
        """
        if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id):
            raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)")

        self.mcp_session_id = mcp_session_id
        self.is_json_response_enabled = is_json_response_enabled
        self._event_store = event_store
        self._security = TransportSecurityMiddleware(security_settings)
        self._request_streams: dict[
            RequestId,
            tuple[
                MemoryObjectSendStream[EventMessage],
                MemoryObjectReceiveStream[EventMessage],
            ],
        ] = {}
        self._terminated = False

    @property
    def is_terminated(self) -> bool:
        """Check if this transport has been explicitly terminated."""
        return self._terminated

    def _create_error_response(
        self,
        error_message: str,
        status_code: HTTPStatus,
        error_code: int = INVALID_REQUEST,
        headers: dict[str, str] | None = None,
    ) -> Response:
        """Create an error response with a simple string message."""
        response_headers = {"Content-Type": CONTENT_TYPE_JSON}
        if headers:
            response_headers.update(headers)

        if self.mcp_session_id:
            response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        # Return a properly formatted JSON error response
        error_response = JSONRPCError(
            jsonrpc="2.0",
            id="server-error",  # We don't have a request ID for general errors
            error=ErrorData(
                code=error_code,
                message=error_message,
            ),
        )

        return Response(
            error_response.model_dump_json(by_alias=True, exclude_none=True),
            status_code=status_code,
            headers=response_headers,
        )
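
    # For reference, the error body built above serializes to the standard
    # JSON-RPC error shape (values here are illustrative):
    #   {"jsonrpc": "2.0", "id": "server-error",
    #    "error": {"code": -32600, "message": "Bad Request: Missing session ID"}}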

    def _create_json_response(
        self,
        response_message: JSONRPCMessage | None,
        status_code: HTTPStatus = HTTPStatus.OK,
        headers: dict[str, str] | None = None,
    ) -> Response:
        """Create a JSON response from a JSONRPCMessage"""
        response_headers = {"Content-Type": CONTENT_TYPE_JSON}
        if headers:
            response_headers.update(headers)

        if self.mcp_session_id:
            response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        return Response(
            response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None,
            status_code=status_code,
            headers=response_headers,
        )

    def _get_session_id(self, request: Request) -> str | None:
        """Extract the session ID from request headers."""
        return request.headers.get(MCP_SESSION_ID_HEADER)

    def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
        """Create event data dictionary from an EventMessage."""
        event_data = {
            "event": "message",
            "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True),
        }

        # If an event ID was provided, include it
        if event_message.event_id:
            event_data["id"] = event_message.event_id

        return event_data
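
    # The dictionary above is what sse_starlette serializes onto the wire, so a
    # single SSE event looks roughly like this (illustrative payload):
    #
    #   id: 3f2a9c              (only present when an event store assigned an ID)
    #   event: message
    #   data: {"jsonrpc": "2.0", "id": 1, "result": {}}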

    async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
        """Clean up memory streams for a given request ID."""
        if request_id in self._request_streams:
            try:
                # Close the request stream
                await self._request_streams[request_id][0].aclose()
                await self._request_streams[request_id][1].aclose()
            except Exception:
                # During cleanup, we catch all exceptions since streams might be in various states
                logger.debug("Error closing memory streams - may already be closed")
            finally:
                # Remove the request stream from the mapping
                self._request_streams.pop(request_id, None)

    async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Application entry point that handles all HTTP requests"""
        request = Request(scope, receive)

        # Validate request headers for DNS rebinding protection
        is_post = request.method == "POST"
        error_response = await self._security.validate_request(request, is_post=is_post)
        if error_response:
            await error_response(scope, receive, send)
            return

        if self._terminated:
            # If the session has been terminated, return 404 Not Found
            response = self._create_error_response(
                "Not Found: Session has been terminated",
                HTTPStatus.NOT_FOUND,
            )
            await response(scope, receive, send)
            return

        if request.method == "POST":
            await self._handle_post_request(scope, request, receive, send)
        elif request.method == "GET":
            await self._handle_get_request(request, send)
        elif request.method == "DELETE":
            await self._handle_delete_request(request, send)
        else:
            await self._handle_unsupported_request(request, send)

    def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
        """Check if the request accepts the required media types."""
        accept_header = request.headers.get("accept", "")
        accept_types = [media_type.strip() for media_type in accept_header.split(",")]

        has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types)
        has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types)

        return has_json, has_sse

    def _check_content_type(self, request: Request) -> bool:
        """Check if the request has the correct Content-Type."""
        content_type = request.headers.get("content-type", "")
        content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")]

        return any(part == CONTENT_TYPE_JSON for part in content_type_parts)
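
    # For example (illustrative request headers), a POST passes both checks
    # above with:
    #   Accept: application/json, text/event-stream
    #   Content-Type: application/json; charset=utf-8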

    async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
        """Handle POST requests containing JSON-RPC messages."""
        writer = self._read_stream_writer
        if writer is None:
            raise ValueError("No read stream writer available. Ensure connect() is called first.")
        try:
            # Check Accept headers
            has_json, has_sse = self._check_accept_headers(request)
            if not (has_json and has_sse):
                response = self._create_error_response(
                    "Not Acceptable: Client must accept both application/json and text/event-stream",
                    HTTPStatus.NOT_ACCEPTABLE,
                )
                await response(scope, receive, send)
                return

            # Validate Content-Type
            if not self._check_content_type(request):
                response = self._create_error_response(
                    "Unsupported Media Type: Content-Type must be application/json",
                    HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
                )
                await response(scope, receive, send)
                return

            # Parse the body - only read it once
            body = await request.body()

            try:
                raw_message = json.loads(body)
            except json.JSONDecodeError as e:
                response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
                await response(scope, receive, send)
                return

            try:
                message = JSONRPCMessage.model_validate(raw_message)
            except ValidationError as e:
                response = self._create_error_response(
                    f"Validation error: {str(e)}",
                    HTTPStatus.BAD_REQUEST,
                    INVALID_PARAMS,
                )
                await response(scope, receive, send)
                return

            # Check if this is an initialization request
            is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"

            if is_initialization_request:
                # Check if the server already has an established session
                if self.mcp_session_id:
                    # Check if request has a session ID
                    request_session_id = self._get_session_id(request)

                    # If request has a session ID but doesn't match, return 404
                    if request_session_id and request_session_id != self.mcp_session_id:
                        response = self._create_error_response(
                            "Not Found: Invalid or expired session ID",
                            HTTPStatus.NOT_FOUND,
                        )
                        await response(scope, receive, send)
                        return
            elif not await self._validate_request_headers(request, send):
                return

            # For notifications and responses only, return 202 Accepted
            if not isinstance(message.root, JSONRPCRequest):
                # Create response object and send it
                response = self._create_json_response(
                    None,
                    HTTPStatus.ACCEPTED,
                )
                await response(scope, receive, send)

                # Process the message after sending the response
                metadata = ServerMessageMetadata(request_context=request)
                session_message = SessionMessage(message, metadata=metadata)
                await writer.send(session_message)

                return

            # Extract the request ID outside the try block for proper scope
            request_id = str(message.root.id)
            # Register this stream for the request ID
            self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
            request_stream_reader = self._request_streams[request_id][1]

            if self.is_json_response_enabled:
                # Process the message
                metadata = ServerMessageMetadata(request_context=request)
                session_message = SessionMessage(message, metadata=metadata)
                await writer.send(session_message)
                try:
                    # Process messages from the request-specific stream
                    # We need to collect all messages until we get a response
                    response_message = None

                    # Use similar approach to SSE writer for consistency
                    async for event_message in request_stream_reader:
                        # If it's a response, this is what we're waiting for
                        if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
                            response_message = event_message.message
                            break
                        # For notifications and requests, keep waiting
                        else:
                            logger.debug(f"received: {event_message.message.root.method}")

                    # At this point we should have a response
                    if response_message:
                        # Create JSON response
                        response = self._create_json_response(response_message)
                        await response(scope, receive, send)
                    else:
                        # This shouldn't happen in normal operation
                        logger.error("No response message received before stream closed")
                        response = self._create_error_response(
                            "Error processing request: No response received",
                            HTTPStatus.INTERNAL_SERVER_ERROR,
                        )
                        await response(scope, receive, send)
                except Exception:
                    logger.exception("Error processing JSON response")
                    response = self._create_error_response(
                        "Error processing request",
                        HTTPStatus.INTERNAL_SERVER_ERROR,
                        INTERNAL_ERROR,
                    )
                    await response(scope, receive, send)
                finally:
                    await self._clean_up_memory_streams(request_id)
            else:
                # Create SSE stream
                sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

                async def sse_writer():
                    # Get the request ID from the incoming request message
                    try:
                        async with sse_stream_writer, request_stream_reader:
                            # Process messages from the request-specific stream
                            async for event_message in request_stream_reader:
                                # Build the event data
                                event_data = self._create_event_data(event_message)
                                await sse_stream_writer.send(event_data)

                                # If response, remove from pending streams and close
                                if isinstance(
                                    event_message.message.root,
                                    JSONRPCResponse | JSONRPCError,
                                ):
                                    break
                    except Exception:
                        logger.exception("Error in SSE writer")
                    finally:
                        logger.debug("Closing SSE writer")
                        await self._clean_up_memory_streams(request_id)

                # Create and start EventSourceResponse
                # SSE stream mode (original behavior)
                # Set up headers
                headers = {
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "Content-Type": CONTENT_TYPE_SSE,
                    **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
                }
                response = EventSourceResponse(
                    content=sse_stream_reader,
                    data_sender_callable=sse_writer,
                    headers=headers,
                )

                # Start the SSE response (this will send headers immediately)
                try:
                    # First send the response to establish the SSE connection
                    async with anyio.create_task_group() as tg:
                        tg.start_soon(response, scope, receive, send)
                        # Then send the message to be processed by the server
                        metadata = ServerMessageMetadata(request_context=request)
                        session_message = SessionMessage(message, metadata=metadata)
                        await writer.send(session_message)
                except Exception:
                    logger.exception("SSE response error")
                    await sse_stream_writer.aclose()
                    await sse_stream_reader.aclose()
                    await self._clean_up_memory_streams(request_id)

        except Exception as err:
            logger.exception("Error handling POST request")
            response = self._create_error_response(
                f"Error handling POST request: {err}",
                HTTPStatus.INTERNAL_SERVER_ERROR,
                INTERNAL_ERROR,
            )
            await response(scope, receive, send)
            if writer:
                await writer.send(Exception(err))
            return

    async def _handle_get_request(self, request: Request, send: Send) -> None:
        """
        Handle GET request to establish SSE.

        This allows the server to communicate to the client without the client
        first sending data via HTTP POST. The server can send JSON-RPC requests
        and notifications on this stream.
        """
        writer = self._read_stream_writer
        if writer is None:
            raise ValueError("No read stream writer available. Ensure connect() is called first.")

        # Validate Accept header - must include text/event-stream
        _, has_sse = self._check_accept_headers(request)

        if not has_sse:
            response = self._create_error_response(
                "Not Acceptable: Client must accept text/event-stream",
                HTTPStatus.NOT_ACCEPTABLE,
            )
            await response(request.scope, request.receive, send)
            return

        if not await self._validate_request_headers(request, send):
            return

        # Handle resumability: check for Last-Event-ID header
        if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
            await self._replay_events(last_event_id, request, send)
            return

        headers = {
            "Cache-Control": "no-cache, no-transform",
            "Connection": "keep-alive",
            "Content-Type": CONTENT_TYPE_SSE,
        }

        if self.mcp_session_id:
            headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        # Check if we already have an active GET stream
        if GET_STREAM_KEY in self._request_streams:
            response = self._create_error_response(
                "Conflict: Only one SSE stream is allowed per session",
                HTTPStatus.CONFLICT,
            )
            await response(request.scope, request.receive, send)
            return

        # Create SSE stream
        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

        async def standalone_sse_writer():
            try:
                # Create a standalone message stream for server-initiated messages
                self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
                standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]

                async with sse_stream_writer, standalone_stream_reader:
                    # Process messages from the standalone stream
                    async for event_message in standalone_stream_reader:
                        # For the standalone stream, we handle:
                        # - JSONRPCNotification (server sends notifications to client)
                        # - JSONRPCRequest (server sends requests to client)
                        # We should NOT receive JSONRPCResponse

                        # Send the message via SSE
                        event_data = self._create_event_data(event_message)
                        await sse_stream_writer.send(event_data)
            except Exception:
                logger.exception("Error in standalone SSE writer")
            finally:
                logger.debug("Closing standalone SSE writer")
                await self._clean_up_memory_streams(GET_STREAM_KEY)

        # Create and start EventSourceResponse
        response = EventSourceResponse(
            content=sse_stream_reader,
            data_sender_callable=standalone_sse_writer,
            headers=headers,
        )

        try:
            # This will send headers immediately and establish the SSE connection
            await response(request.scope, request.receive, send)
        except Exception:
            logger.exception("Error in standalone SSE response")
            await sse_stream_writer.aclose()
            await sse_stream_reader.aclose()
            await self._clean_up_memory_streams(GET_STREAM_KEY)

    async def _handle_delete_request(self, request: Request, send: Send) -> None:
        """Handle DELETE requests for explicit session termination."""
        # Validate session ID
        if not self.mcp_session_id:
            # If no session ID set, return Method Not Allowed
            response = self._create_error_response(
                "Method Not Allowed: Session termination not supported",
                HTTPStatus.METHOD_NOT_ALLOWED,
            )
            await response(request.scope, request.receive, send)
            return

        if not await self._validate_request_headers(request, send):
            return

        await self.terminate()

        response = self._create_json_response(
            None,
            HTTPStatus.OK,
        )
        await response(request.scope, request.receive, send)

    async def terminate(self) -> None:
        """Terminate the current session, closing all streams.

        Once terminated, all requests with this session ID will receive 404 Not Found.
        """

        self._terminated = True
        logger.info(f"Terminating session: {self.mcp_session_id}")

        # We need a copy of the keys to avoid modification during iteration
        request_stream_keys = list(self._request_streams.keys())

        # Close all request streams asynchronously
        for key in request_stream_keys:
            await self._clean_up_memory_streams(key)

        # Clear the request streams dictionary immediately
        self._request_streams.clear()
        try:
            if self._read_stream_writer is not None:
                await self._read_stream_writer.aclose()
            if self._read_stream is not None:
                await self._read_stream.aclose()
            if self._write_stream_reader is not None:
                await self._write_stream_reader.aclose()
            if self._write_stream is not None:
                await self._write_stream.aclose()
        except Exception as e:
            # During cleanup, we catch all exceptions since streams might be in various states
            logger.debug(f"Error closing streams: {e}")

    async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
        """Handle unsupported HTTP methods."""
        headers = {
            "Content-Type": CONTENT_TYPE_JSON,
            "Allow": "GET, POST, DELETE",
        }
        if self.mcp_session_id:
            headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

        response = self._create_error_response(
            "Method Not Allowed",
            HTTPStatus.METHOD_NOT_ALLOWED,
            headers=headers,
        )
        await response(request.scope, request.receive, send)

    async def _validate_request_headers(self, request: Request, send: Send) -> bool:
        if not await self._validate_session(request, send):
            return False
        if not await self._validate_protocol_version(request, send):
            return False
        return True

    async def _validate_session(self, request: Request, send: Send) -> bool:
        """Validate the session ID in the request."""
        if not self.mcp_session_id:
            # If we're not using session IDs, return True
            return True

        # Get the session ID from the request headers
        request_session_id = self._get_session_id(request)

        # If no session ID provided but required, return error
        if not request_session_id:
            response = self._create_error_response(
                "Bad Request: Missing session ID",
                HTTPStatus.BAD_REQUEST,
            )
            await response(request.scope, request.receive, send)
            return False

        # If session ID doesn't match, return error
        if request_session_id != self.mcp_session_id:
            response = self._create_error_response(
                "Not Found: Invalid or expired session ID",
                HTTPStatus.NOT_FOUND,
            )
            await response(request.scope, request.receive, send)
            return False

        return True

    async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
        """Validate the protocol version header in the request."""
        # Get the protocol version from the request headers
        protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

        # If no protocol version provided, assume default version
        if protocol_version is None:
            protocol_version = DEFAULT_NEGOTIATED_VERSION

        # Check if the protocol version is supported
        if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
            supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
            response = self._create_error_response(
                f"Bad Request: Unsupported protocol version: {protocol_version}. "
                + f"Supported versions: {supported_versions}",
                HTTPStatus.BAD_REQUEST,
            )
            await response(request.scope, request.receive, send)
            return False

        return True

    async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
        """
        Replays events that would have been sent after the specified event ID.
        Only used when resumability is enabled.
        """
        event_store = self._event_store
        if not event_store:
            return

        try:
            headers = {
                "Cache-Control": "no-cache, no-transform",
                "Connection": "keep-alive",
                "Content-Type": CONTENT_TYPE_SSE,
            }

            if self.mcp_session_id:
                headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

            # Create SSE stream for replay
            sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

            async def replay_sender():
                try:
                    async with sse_stream_writer:
                        # Define an async callback for sending events
                        async def send_event(event_message: EventMessage) -> None:
                            event_data = self._create_event_data(event_message)
                            await sse_stream_writer.send(event_data)

                        # Replay past events and get the stream ID
                        stream_id = await event_store.replay_events_after(last_event_id, send_event)

                        # If stream ID not in mapping, create it
                        if stream_id and stream_id not in self._request_streams:
                            self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
                            msg_reader = self._request_streams[stream_id][1]

                            # Forward messages to SSE
                            async with msg_reader:
                                async for event_message in msg_reader:
                                    event_data = self._create_event_data(event_message)

                                    await sse_stream_writer.send(event_data)
                except Exception:
                    logger.exception("Error in replay sender")

            # Create and start EventSourceResponse
            response = EventSourceResponse(
                content=sse_stream_reader,
                data_sender_callable=replay_sender,
                headers=headers,
            )

            try:
                await response(request.scope, request.receive, send)
            except Exception:
                logger.exception("Error in replay response")
            finally:
                await sse_stream_writer.aclose()
                await sse_stream_reader.aclose()

        except Exception:
            logger.exception("Error replaying events")
            response = self._create_error_response(
                "Error replaying events",
                HTTPStatus.INTERNAL_SERVER_ERROR,
                INTERNAL_ERROR,
            )
            await response(request.scope, request.receive, send)

    @asynccontextmanager
    async def connect(
        self,
    ) -> AsyncGenerator[
        tuple[
            MemoryObjectReceiveStream[SessionMessage | Exception],
            MemoryObjectSendStream[SessionMessage],
        ],
        None,
    ]:
        """Context manager that provides read and write streams for a connection.

        Yields:
            Tuple of (read_stream, write_stream) for bidirectional communication
        """

        # Create the memory streams for this connection
        read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
        write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

        # Store the streams
        self._read_stream_writer = read_stream_writer
        self._read_stream = read_stream
        self._write_stream_reader = write_stream_reader
        self._write_stream = write_stream

        # Start a task group for message routing
        async with anyio.create_task_group() as tg:
            # Create a message router that distributes messages to request streams
            async def message_router():
                try:
                    async for session_message in write_stream_reader:
                        # Determine which request stream(s) should receive this message
                        message = session_message.message
                        target_request_id = None
                        # Check if this is a response
                        if isinstance(message.root, JSONRPCResponse | JSONRPCError):
                            response_id = str(message.root.id)
                            # If this response is for an existing request stream, send it there
                            target_request_id = response_id
                        else:
                            # Extract related_request_id from meta if it exists
                            if (
                                session_message.metadata is not None
                                and isinstance(
                                    session_message.metadata,
                                    ServerMessageMetadata,
                                )
                                and session_message.metadata.related_request_id is not None
                            ):
                                target_request_id = str(session_message.metadata.related_request_id)

                        request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY

                        # Store the event if we have an event store, regardless of
                        # whether a client is connected; messages will be replayed
                        # on the re-connect
                        event_id = None
                        if self._event_store:
                            event_id = await self._event_store.store_event(request_stream_id, message)
                            logger.debug(f"Stored {event_id} from {request_stream_id}")

                        if request_stream_id in self._request_streams:
                            try:
                                # Send both the message and the event ID
                                await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
                            except (
                                anyio.BrokenResourceError,
                                anyio.ClosedResourceError,
                            ):
                                # Stream might be closed, remove from registry
                                self._request_streams.pop(request_stream_id, None)
                        else:
                            logger.debug(
                                f"Request stream {request_stream_id} not found for message. "
                                "Still processing message as the client might reconnect and replay."
                            )
                except Exception:
                    logger.exception("Error in message router")

            # Start the message router
            tg.start_soon(message_router)

            try:
                # Yield the streams for the caller to use
                yield read_stream, write_stream
            finally:
                for stream_id in list(self._request_streams.keys()):
                    await self._clean_up_memory_streams(stream_id)
                self._request_streams.clear()

                # Clean up the read and write streams
                try:
                    await read_stream_writer.aclose()
                    await read_stream.aclose()
                    await write_stream_reader.aclose()
                    await write_stream.aclose()
                except Exception as e:
                    # During cleanup, we catch all exceptions since streams might be in various states
                    logger.debug(f"Error closing streams: {e}")
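

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the names below are assumptions for this
# example, not part of this module). A server typically keeps one transport
# per session: an ASGI endpoint forwards every HTTP request to
# ``handle_request``, while a background task holds ``connect()`` open and
# runs an MCP server over the yielded streams.
#
#   from uuid import uuid4
#
#   transport = StreamableHTTPServerTransport(mcp_session_id=uuid4().hex)
#
#   async def asgi_endpoint(scope: Scope, receive: Receive, send: Send) -> None:
#       await transport.handle_request(scope, receive, send)
#
#   async def run_session(mcp_server) -> None:  # ``mcp_server``: an mcp.server Server
#       async with transport.connect() as (read_stream, write_stream):
#           await mcp_server.run(
#               read_stream, write_stream, mcp_server.create_initialization_options()
#           )
# ---------------------------------------------------------------------------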