Skip to content

Commit 8cd0a64

Browse files
committed
Move ClientSession onto JSONRPCDispatcher and delete BaseSession
ClientSession keeps its public surface (constructor, typed methods, manual initialize, context-manager lifecycle) but now owns a JSONRPCDispatcher instead of inheriting the v1 BaseSession receive loop. Server-initiated requests are answered through the existing callbacks via the closed-union parse; notifications validate-or-drop and tee to message_handler; transport exceptions reach message_handler through the dispatcher's stream-exception observer. A from_dispatcher constructor accepts a pre-built dispatcher for in-process embedding. mcp.shared.session shrinks to the surviving names: the ProgressFnT re-export and a typing-only RequestResponder stub for MessageHandlerFnT annotations. Behavior changes (deliberate, to be covered in the migration guide): - request ids count from 1; the progress token follows - timeouts use the dispatcher error text and send notifications/cancelled, so a timed-out server handler is interrupted instead of running on - responses with unknown ids are ignored per spec instead of surfacing a RuntimeError to message_handler - a raising request callback is answered with code 0 and the exception text - notification callbacks run concurrently (no completion-before-response) Three interaction-suite divergence entries are resolved and deleted, and the server-to-client cancellation requirement is now pinned by a passing test.
1 parent ab46e5d commit 8cd0a64

14 files changed

Lines changed: 438 additions & 1132 deletions

File tree

src/mcp/client/session.py

Lines changed: 228 additions & 64 deletions
Large diffs are not rendered by default.

src/mcp/shared/_context.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
"""Request context for MCP handlers."""
1+
"""Request context for MCP client handlers."""
22

33
from dataclasses import dataclass
44
from typing import Any, Generic
55

66
from typing_extensions import TypeVar
77

8-
from mcp.shared.session import BaseSession
98
from mcp.types import RequestId, RequestParamsMeta
109

11-
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
10+
SessionT = TypeVar("SessionT", default=Any)
1211

1312

1413
@dataclass(kw_only=True)

src/mcp/shared/session.py

Lines changed: 20 additions & 475 deletions
Large diffs are not rendered by default.

tests/client/test_resource_cleanup.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

tests/client/test_session.py

Lines changed: 88 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mcp.types import (
1818
INVALID_PARAMS,
1919
LATEST_PROTOCOL_VERSION,
20+
METHOD_NOT_FOUND,
2021
CallToolResult,
2122
Implementation,
2223
InitializedNotification,
@@ -751,8 +752,34 @@ async def test_receive_loop_answers_malformed_inbound_request_with_invalid_param
751752

752753

753754
@pytest.mark.anyio
754-
async def test_receive_loop_answers_invalid_params_when_sampling_callback_raises():
755-
"""Same boundary catches exceptions from the request handler itself."""
755+
async def test_receive_loop_answers_unknown_request_method_with_method_not_found():
756+
"""A server request whose method is not in the ServerRequest union gets -32601
757+
(METHOD_NOT_FOUND) on the wire, not a validation failure (-32602)."""
758+
async with raw_client_session() as (_session, to_client, from_client):
759+
await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="x/unknown")))
760+
out = await from_client.receive()
761+
assert isinstance(out.message, JSONRPCError)
762+
assert out.message.id == 7
763+
assert out.message.error == types.ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="x/unknown")
764+
765+
766+
@pytest.mark.anyio
767+
async def test_receive_loop_drops_unknown_notification_method_without_response():
768+
"""An unknown notification method is dropped silently: JSON-RPC forbids
769+
responses to notifications, and the receive loop keeps serving."""
770+
async with raw_client_session() as (_session, to_client, from_client):
771+
await to_client.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="x/unknown")))
772+
# The next wire output must be the answer to this follow-up ping,
773+
# proving the notification produced no response and the loop survived.
774+
await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")))
775+
out = await from_client.receive()
776+
assert isinstance(out.message, JSONRPCResponse)
777+
assert out.message.id == 1
778+
779+
780+
@pytest.mark.anyio
781+
async def test_raising_sampling_callback_answers_with_code_zero():
782+
"""A raising request callback is answered through the dispatcher's exception boundary."""
756783

757784
async def boom(ctx: object, params: object) -> types.CreateMessageResult:
758785
raise RuntimeError("sampling boom")
@@ -767,7 +794,7 @@ async def boom(ctx: object, params: object) -> types.CreateMessageResult:
767794
)
768795
out = await from_client.receive()
769796
assert isinstance(out.message, JSONRPCError)
770-
assert out.message.error.code == INVALID_PARAMS
797+
assert out.message.error == types.ErrorData(code=0, message="sampling boom")
771798

772799

773800
@pytest.mark.anyio
@@ -841,23 +868,68 @@ async def handler(msg: object) -> None:
841868

842869

843870
@pytest.mark.anyio
844-
async def test_receive_loop_swallows_progress_callback_exception(caplog: pytest.LogCaptureFixture):
871+
async def test_progress_callback_exception_is_swallowed(caplog: pytest.LogCaptureFixture):
845872
delivered = anyio.Event()
846873

847874
async def boom(progress: float, total: float | None, message: str | None) -> None:
848875
raise RuntimeError("progress boom")
849876

850877
async def handler(msg: object) -> None:
851-
delivered.set()
878+
if isinstance(msg, types.ProgressNotification):
879+
delivered.set()
880+
881+
async with raw_client_session(message_handler=handler) as (session, to_client, from_client):
882+
async with anyio.create_task_group() as tg:
883+
884+
async def call() -> None:
885+
await session.send_request(types.PingRequest(), types.EmptyResult, progress_callback=boom)
886+
887+
tg.start_soon(call)
888+
request = await from_client.receive()
889+
assert isinstance(request.message, JSONRPCRequest)
890+
# The request id doubles as the progress token.
891+
params = {"progressToken": request.message.id, "progress": 0.5}
892+
await to_client.send(
893+
SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params))
894+
)
895+
# The progress notification also reaches the message handler; the
896+
# raising callback was swallowed and logged.
897+
await delivered.wait()
898+
await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request.message.id, result={})))
899+
assert "progress callback raised" in caplog.text
852900

853-
async with raw_client_session(message_handler=handler) as (session, to_client, _):
854-
# Register the callback under a known token without sending a request.
855-
session._progress_callbacks[42] = boom # pyright: ignore[reportPrivateUsage]
856-
params = {"progressToken": 42, "progress": 0.5}
857-
await to_client.send(
858-
SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params))
859-
)
860-
# The progress notification also reaches the message handler after the
861-
# callback runs, so this fires once the callback's exception is handled.
862-
await delivered.wait()
863-
assert "Progress callback raised an exception" in caplog.text
901+
902+
@pytest.mark.anyio
903+
async def test_from_dispatcher_runs_over_direct_dispatch():
904+
"""A session built with from_dispatcher works without a stream pair (in-process embedding)."""
905+
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
906+
from mcp.shared.dispatcher import DispatchContext
907+
from mcp.shared.transport_context import TransportContext
908+
909+
client_side, server_side = create_direct_dispatcher_pair()
910+
911+
async def server_on_request(
912+
ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None
913+
) -> dict[str, object]:
914+
assert method == "ping"
915+
return {}
916+
917+
notified: list[str] = []
918+
919+
async def server_on_notify(
920+
ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None
921+
) -> None:
922+
notified.append(method)
923+
924+
session = ClientSession.from_dispatcher(client_side)
925+
results: list[types.EmptyResult] = []
926+
async with anyio.create_task_group() as tg:
927+
await tg.start(server_side.run, server_on_request, server_on_notify)
928+
async with session:
929+
results.append(await session.send_ping(meta=None))
930+
# related_request_id routing is JSON-RPC plumbing; on other
931+
# dispatchers the notification is sent without it.
932+
await session.send_notification(types.RootsListChangedNotification(), related_request_id=7)
933+
server_side.close()
934+
assert results == [types.EmptyResult()]
935+
assert notified == ["notifications/roots/list_changed"]

tests/interaction/_requirements.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,6 @@ def __post_init__(self) -> None:
287287
"A response that arrives after the sender issued notifications/cancelled is ignored; the "
288288
"request stays failed and no error is raised."
289289
),
290-
divergence=Divergence(
291-
note=(
292-
"A response whose id matches no in-flight request is delivered to the message handler "
293-
"as a RuntimeError rather than being silently ignored. The post-cancellation case is the "
294-
"same code path; tested in its unknown-id form because that is deterministic without the "
295-
"client-side cancellation API the SDK does not yet provide."
296-
),
297-
),
298290
),
299291
"protocol:cancel:server-survives": Requirement(
300292
source="sdk",
@@ -306,19 +298,6 @@ def __post_init__(self) -> None:
306298
"A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) "
307299
"cancels it, and the client stops processing the cancelled request."
308300
),
309-
divergence=Divergence(
310-
note=(
311-
"Abandoning a server-side send_request emits no cancellation notification, and the client "
312-
"could not act on one anyway: client callbacks run inline in the receive loop, so a "
313-
"cancellation is not even read until the callback has finished."
314-
),
315-
),
316-
deferred=(
317-
"Not implemented in the SDK: abandoning a server-side send_request emits no cancellation "
318-
"notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and "
319-
"the client could not act on one anyway because client callbacks run inline in the receive "
320-
"loop, so a cancellation would not even be read until the callback had already finished."
321-
),
322301
),
323302
"protocol:cancel:unknown-id-ignored": Requirement(
324303
source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling",
@@ -466,13 +445,6 @@ def __post_init__(self) -> None:
466445
"When a request times out, the sender issues notifications/cancelled for that request before "
467446
"failing the local call."
468447
),
469-
divergence=Divergence(
470-
note=(
471-
"Client seat only: the client raises locally and sends nothing on timeout, so the server keeps "
472-
"running the handler. The server seat conforms: a timed-out server-initiated request is followed "
473-
"by notifications/cancelled on the wire."
474-
),
475-
),
476448
),
477449
"protocol:timeout:session-survives": Requirement(
478450
source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts",

tests/interaction/lowlevel/test_cancellation.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from inline_snapshot import snapshot
1212

1313
from mcp import MCPError, types
14-
from mcp.client import ClientSession
14+
from mcp.client import ClientRequestContext, ClientSession
1515
from mcp.server import Server, ServerRequestContext
1616
from mcp.shared.memory import MessageStream, create_client_server_memory_streams
1717
from mcp.shared.message import SessionMessage
@@ -155,14 +155,70 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
155155
assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")]))
156156

157157

158+
@requirement("protocol:cancel:server-to-client")
159+
async def test_abandoned_server_request_cancels_the_client_callback(connect: Connect) -> None:
160+
"""A server that abandons a sampling request cancels it, interrupting the client's callback.
161+
162+
The handler gives up on its sampling request by cancelling the scope around it; the courtesy
163+
notifications/cancelled that follows interrupts the client's sampling callback mid-await.
164+
"""
165+
callback_started = anyio.Event()
166+
callback_cancelled = anyio.Event()
167+
168+
async def sampling_callback(
169+
context: ClientRequestContext, params: types.CreateMessageRequestParams
170+
) -> types.CreateMessageResult:
171+
callback_started.set()
172+
try:
173+
await anyio.Event().wait() # blocks until the cancellation interrupts it
174+
except anyio.get_cancelled_exc_class():
175+
callback_cancelled.set()
176+
raise
177+
raise NotImplementedError # unreachable
178+
179+
async def list_tools(
180+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
181+
) -> types.ListToolsResult:
182+
return types.ListToolsResult(tools=[types.Tool(name="impatient", input_schema={"type": "object"})])
183+
184+
async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
185+
assert params.name == "impatient"
186+
request = types.CreateMessageRequest(
187+
params=types.CreateMessageRequestParams(
188+
messages=[types.SamplingMessage(role="user", content=TextContent(text="Say hello."))],
189+
max_tokens=8,
190+
)
191+
)
192+
async with anyio.create_task_group() as abandon_scope:
193+
194+
async def sample() -> None:
195+
await ctx.session.send_request(request, types.CreateMessageResult)
196+
raise NotImplementedError # unreachable: the scope is cancelled
197+
198+
abandon_scope.start_soon(sample)
199+
await callback_started.wait()
200+
abandon_scope.cancel_scope.cancel()
201+
with anyio.fail_after(5):
202+
await callback_cancelled.wait()
203+
return CallToolResult(content=[TextContent(text="abandoned")])
204+
205+
server = Server("abandoner", on_list_tools=list_tools, on_call_tool=call_tool)
206+
207+
async with connect(server, sampling_callback=sampling_callback) as client:
208+
result = await client.call_tool("impatient", {})
209+
210+
assert result == snapshot(CallToolResult(content=[TextContent(text="abandoned")]))
211+
assert callback_cancelled.is_set()
212+
213+
158214
@requirement("protocol:cancel:late-response-ignored")
159-
async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None:
160-
"""A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError.
215+
async def test_a_response_for_an_unknown_request_id_is_ignored() -> None:
216+
"""A response whose id matches no in-flight request is ignored, as the spec asks.
161217
162218
The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation;
163219
that is the same client-side code path as any response with an unknown id, and that form is
164-
deterministic to test without depending on the cancellation API the SDK does not yet provide.
165-
See the divergence note on the requirement.
220+
deterministic to test without depending on a client-side cancellation API. Nothing reaches
221+
the message handler and the session keeps serving.
166222
167223
A real Server cannot be made to answer with a fabricated id, so the test plays the server's
168224
side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The
@@ -228,7 +284,6 @@ async def message_handler(message: IncomingMessage) -> None:
228284
pong = await session.send_request(PingRequest(), EmptyResult)
229285

230286
assert pong == snapshot(EmptyResult())
231-
assert len(incoming) == 1
232-
assert isinstance(incoming[0], RuntimeError)
233-
# The full message embeds the response object's repr; only the prefix is stable.
234-
assert str(incoming[0]).startswith("Received response with an unknown request ID:")
287+
# The fabricated response was dropped silently: the ping after it still
288+
# round-tripped, and nothing was surfaced to the message handler.
289+
assert incoming == []

tests/interaction/lowlevel/test_logging.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""Logging interactions against the low-level Server, driven through the public Client API.
22
33
Notification ordering: the in-memory transport delivers every server-to-client message on one
4-
ordered stream, and the client's receive loop dispatches each incoming message to completion
5-
before reading the next one. Over streamable HTTP that ordered single-stream guarantee holds
6-
only for messages that carry a ``related_request_id`` (they ride the originating request's POST
7-
stream); without it the message routes to the standalone GET stream and may arrive after the
8-
response. These tests pass ``related_request_id`` so they can collect into a plain list and
9-
assert after the request completes on every transport leg -- no events, no waiting.
4+
ordered stream, and the client starts notification callbacks in arrival order. Callbacks run
5+
concurrently with the rest of the session (no completion-before-response guarantee), but a
6+
callback with no internal awaits runs to completion as soon as it starts, which keeps
7+
plain-list collection deterministic here. Over streamable HTTP the ordered single-stream
8+
guarantee holds only for messages that carry a ``related_request_id`` (they ride the
9+
originating request's POST stream); without it the message routes to the standalone GET stream
10+
and may arrive after the response. These tests pass ``related_request_id`` and use await-free
11+
callbacks so they can collect into a plain list and assert after the request completes on
12+
every transport leg -- no events, no waiting.
1013
"""
1114

1215
import pytest

tests/interaction/lowlevel/test_progress.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ async def ignore(progress: float, total: float | None, message: str | None) -> N
8787
async with connect(server) as client:
8888
result = await client.call_tool("inspect", {}, progress_callback=ignore)
8989

90-
# The token is the request id of the tools/call request itself (initialize is request 0).
91-
assert result == snapshot(CallToolResult(content=[TextContent(text="1")]))
90+
# The token is the request id of the tools/call request itself (initialize is request 1).
91+
assert result == snapshot(CallToolResult(content=[TextContent(text="2")]))
9292

9393

9494
@requirement("protocol:progress:no-token")

0 commit comments

Comments
 (0)