Skip to content

Commit ca9e87f

Browse files
committed
Add courtesy-cancel controls and bound shielded dispatcher writes
- New CallOptions key cancel_on_abandon (default true): abandoning a request (timeout or caller cancellation) sends notifications/cancelled unless the caller opted out or the request carries resumption hints
- Bound the two shielded cancellation-path writes with a 5s deadline so a wedged transport write cannot hang shutdown or a cancelled caller
- Capitalize the connection-closed fan-out message ("Connection closed")
- Pin the server-seat timeout contract in the interaction suite: a timed-out server-initiated request is followed by notifications/cancelled
1 parent 1e21814 commit ca9e87f

5 files changed

Lines changed: 309 additions & 12 deletions

File tree

src/mcp/shared/dispatcher.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ class CallOptions(TypedDict, total=False):
5555
timeout: float
5656
"""Seconds to wait for a result before raising and, unless suppressed (see `cancel_on_abandon`), sending `notifications/cancelled`."""
5757

58+
cancel_on_abandon: bool
59+
"""Whether abandoning this request sends `notifications/cancelled` to the peer.
60+
61+
A request is abandoned when its `timeout` elapses or the caller's scope is
62+
cancelled while awaiting the response. Defaults to `True`. Set `False` for
63+
requests the protocol forbids cancelling, such as `initialize`. The
64+
notification is also suppressed when resumption hints are present: the
65+
caller intends to resume the request, so the peer's work must keep running.
66+
"""
67+
5868
on_progress: ProgressFnT
5969
"""Receive `notifications/progress` updates for this request."""
6070

@@ -97,8 +107,8 @@ async def send_raw_request(
97107
) -> dict[str, Any]:
98108
"""Send a request and await its raw result dict.
99109
100-
`opts` carries per-call `timeout` / `on_progress` / resumption hints;
101-
see `CallOptions`.
110+
`opts` carries per-call `timeout` / `on_progress` / abandon-cancellation
111+
/ resumption hints; see `CallOptions`.
102112
103113
Raises:
104114
MCPError: If the peer responded with an error, or the handler

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@
6363

6464
logger = logging.getLogger(__name__)
6565

66+
_SHIELDED_WRITE_TIMEOUT: float = 5
67+
"""Bound for the shielded courtesy writes on the cancellation paths.
68+
69+
Those writes run inside a shield because the surrounding scope is already
70+
cancelled; without a bound, a wedged transport write would turn the shield
71+
into an uncancellable hang (and block shutdown indefinitely)."""
72+
6673
TransportT = TypeVar("TransportT", bound=TransportContext)
6774

6875
PeerCancelMode = Literal["interrupt", "signal"]
@@ -323,6 +330,18 @@ async def send_raw_request(
323330
pending = _Pending(send=send, receive=receive, on_progress=on_progress)
324331
self._pending[request_id] = pending
325332

333+
# An abandoned request (timeout elapsed, or the caller's scope was
334+
# cancelled while awaiting the response) sends a courtesy
335+
# `notifications/cancelled` so the peer can stop work - unless the
336+
# caller opted out (`initialize`, which the spec forbids cancelling),
337+
# or the request carries resumption hints (the caller intends to
338+
# resume it, so the peer's work must keep running).
339+
cancel_on_abandon = (
340+
opts.get("cancel_on_abandon", True)
341+
and opts.get("resumption_token") is None
342+
and opts.get("on_resumption_token") is None
343+
)
344+
326345
metadata = _outbound_metadata(_related_request_id, opts)
327346
target = out_params.get("name")
328347
span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}"
@@ -348,14 +367,16 @@ async def send_raw_request(
348367
# Spec-recommended courtesy: tell the peer we've given up so it can
349368
# stop work and free resources. v1's BaseSession.send_request does
350369
# NOT do this; it's new behaviour.
351-
await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id)
370+
if cancel_on_abandon:
371+
await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id)
352372
raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None
353373
except anyio.get_cancelled_exc_class():
354374
# Our caller's scope was cancelled. We're already inside a cancelled
355-
# scope, so any bare `await` here re-raises immediately - shield to
356-
# let the courtesy cancel notification go out before we propagate.
357-
with anyio.CancelScope(shield=True):
358-
await self._cancel_outbound(request_id, "caller cancelled", _related_request_id)
375+
# scope, so any bare `await` here re-raises immediately - shield
376+
# (bounded) to let the courtesy cancel go out before we propagate.
377+
if cancel_on_abandon:
378+
with anyio.move_on_after(_SHIELDED_WRITE_TIMEOUT, shield=True):
379+
await self._cancel_outbound(request_id, "caller cancelled", _related_request_id)
359380
raise
360381
finally:
361382
# Always remove the waiter, even on cancel/timeout, so a late
@@ -635,7 +656,7 @@ def _fan_out_closed(self) -> None:
635656
Synchronous (uses `send_nowait`) because it's called from `finally`
636657
which may be inside a cancelled scope. Idempotent.
637658
"""
638-
closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed")
659+
closed = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
639660
for pending in self._pending.values():
640661
try:
641662
pending.send.send_nowait(closed)
@@ -681,8 +702,9 @@ async def _handle_request(
681702
await self._write_error(req.id, ErrorData(code=0, message="Request cancelled"))
682703
except anyio.get_cancelled_exc_class():
683704
# Outer-cancel: run()'s task group is shutting down. Any bare
684-
# `await` here re-raises immediately, so shield the courtesy write.
685-
with anyio.CancelScope(shield=True):
705+
# `await` here re-raises immediately, so shield (bounded) the
706+
# courtesy write.
707+
with anyio.move_on_after(_SHIELDED_WRITE_TIMEOUT, shield=True):
686708
await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled"))
687709
raise
688710
except MCPError as e:

tests/interaction/_requirements.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,9 @@ def __post_init__(self) -> None:
468468
),
469469
divergence=Divergence(
470470
note=(
471-
"The client only raises locally and sends nothing on timeout, so the server keeps running the handler."
471+
"Client seat only: the client raises locally and sends nothing on timeout, so the server keeps "
472+
"running the handler. The server seat conforms: a timed-out server-initiated request is followed "
473+
"by notifications/cancelled on the wire."
472474
),
473475
),
474476
),

tests/interaction/lowlevel/test_timeouts.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
from trio.testing import MockClock
1414

1515
from mcp import MCPError, types
16+
from mcp.client import ClientRequestContext
17+
from mcp.client._memory import InMemoryTransport
1618
from mcp.client.client import Client
1719
from mcp.server import Server, ServerRequestContext
18-
from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent
20+
from mcp.shared.message import SessionMessage
21+
from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, JSONRPCNotification, TextContent
22+
from tests.interaction._helpers import RecordingTransport
1923
from tests.interaction._requirements import requirement
2024

2125
pytestmark = pytest.mark.anyio
@@ -56,6 +60,68 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
5660
)
5761

5862

63+
@requirement("protocol:timeout:basic")
64+
@requirement("protocol:timeout:sends-cancellation")
65+
async def test_server_request_timeout_sends_cancellation_to_the_client() -> None:
66+
"""A server-initiated request that times out fails server-side and cancels the client's work.
67+
68+
The server seat conforms to the spec's timeout guidance: the handler's timed-out sampling
69+
request is followed by notifications/cancelled on the wire. The client's sampling callback
70+
blocks until the server has already given up, then answers; the late response is discarded
71+
and the tool call still completes.
72+
"""
73+
release = anyio.Event()
74+
callback_started = anyio.Event()
75+
errors: list[ErrorData] = []
76+
77+
async def list_tools(
78+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
79+
) -> types.ListToolsResult:
80+
return types.ListToolsResult(tools=[types.Tool(name="impatient", input_schema={"type": "object"})])
81+
82+
async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
83+
assert params.name == "impatient"
84+
request = types.CreateMessageRequest(
85+
params=types.CreateMessageRequestParams(
86+
messages=[types.SamplingMessage(role="user", content=TextContent(text="Say hello."))],
87+
max_tokens=8,
88+
)
89+
)
90+
with pytest.raises(MCPError) as exc_info:
91+
await ctx.session.send_request(request, types.CreateMessageResult, request_read_timeout_seconds=0.000001)
92+
errors.append(exc_info.value.error)
93+
release.set()
94+
return CallToolResult(content=[TextContent(text="gave up")])
95+
96+
server = Server("impatient", on_list_tools=list_tools, on_call_tool=call_tool)
97+
recording = RecordingTransport(InMemoryTransport(server))
98+
99+
async def sampling_callback(
100+
context: ClientRequestContext, params: types.CreateMessageRequestParams
101+
) -> types.CreateMessageResult:
102+
callback_started.set()
103+
await release.wait()
104+
return types.CreateMessageResult(role="assistant", content=TextContent(text="too late"), model="test-model")
105+
106+
async with Client(recording, sampling_callback=sampling_callback) as client:
107+
result = await client.call_tool("impatient", {})
108+
109+
assert result == snapshot(CallToolResult(content=[TextContent(text="gave up")]))
110+
assert callback_started.is_set()
111+
assert errors == snapshot([ErrorData(code=REQUEST_TIMEOUT, message="Request 'sampling/createMessage' timed out")])
112+
cancellations = [
113+
item.message
114+
for item in recording.received
115+
if isinstance(item, SessionMessage)
116+
and isinstance(item.message, JSONRPCNotification)
117+
and item.message.method == "notifications/cancelled"
118+
]
119+
# The cancel names the sampling request (the server's first outbound request) and the reason.
120+
assert [notification.params for notification in cancellations] == snapshot(
121+
[{"requestId": 1, "reason": "timed out after 1e-06s"}]
122+
)
123+
124+
59125
@requirement("protocol:timeout:session-survives")
60126
async def test_session_serves_requests_after_timeout() -> None:
61127
"""A timed-out request does not poison the session: the next request succeeds."""

0 commit comments

Comments
 (0)