59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
from collections.abc import Generator
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from typing import Generic
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from mcp.shared.context import LifespanContextT, RequestContext
|
|
from mcp.shared.session import (
|
|
BaseSession,
|
|
ReceiveNotificationT,
|
|
ReceiveRequestT,
|
|
SendNotificationT,
|
|
SendRequestT,
|
|
SendResultT,
|
|
)
|
|
from mcp.types import ProgressToken
|
|
|
|
|
|
class Progress(BaseModel):
    # Pydantic model pairing a ``progress`` value with an optional ``total``.
    # NOTE(review): documented with comments rather than a class docstring on
    # purpose — pydantic copies a class docstring into the generated JSON
    # schema's "description", so adding one would change the schema.

    # Progress amount being reported.
    progress: float
    # Overall amount; None is permitted (presumably "total unknown" — confirm
    # against callers).
    total: float | None
|
|
|
|
|
@dataclass
class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]):
    """Running progress counter bound to one progress token on a session.

    ``current`` starts at ``0.0``, is not accepted by ``__init__``, and is
    advanced only through :meth:`progress`.
    """

    # Session through which progress notifications are sent.
    session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
    # Token passed as the first argument of every notification from this context.
    progress_token: ProgressToken
    # Forwarded unchanged as ``total=`` on each notification; may be None.
    total: float | None
    # Accumulated progress; excluded from the generated __init__ (init=False).
    current: float = field(default=0.0, init=False)

    async def progress(self, amount: float, message: str | None = None) -> None:
        """Add *amount* to ``current``, then send a progress notification.

        Args:
            amount: Increment added to the running ``current`` value.
            message: Optional text forwarded with the notification.
        """
        # Update local state before awaiting so ``current`` already reflects
        # this increment while the notification is in flight.
        updated = self.current + amount
        self.current = updated
        await self.session.send_progress_notification(
            self.progress_token,
            updated,
            total=self.total,
            message=message,
        )
|
|
|
|
|
|
@contextmanager
def progress(
    ctx: RequestContext[
        BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
        LifespanContextT,
    ],
    total: float | None = None,
) -> Generator[
    ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
    None,
]:
    """Yield a ``ProgressContext`` bound to the progress token in ``ctx.meta``.

    Args:
        ctx: Request context; its ``meta.progressToken`` and ``session`` are used
            to build the yielded context.
        total: Passed through as the context's ``total``; defaults to None.

    Yields:
        A new ``ProgressContext`` for ``ctx.session`` and the request's token.

    Raises:
        ValueError: If ``ctx.meta`` is None or has no ``progressToken``.
    """
    meta = ctx.meta
    # An explicit ``is None`` check so falsy but valid tokens (e.g. 0 or "")
    # are still accepted.
    if meta is None or meta.progressToken is None:
        raise ValueError("No progress token provided")

    # No cleanup is needed when the ``with`` block exits, so the yield is not
    # wrapped in try/finally. Exceptions raised in the caller's ``with`` body
    # propagate unchanged.
    yield ProgressContext(ctx.session, meta.progressToken, total)
|