Skip to content

Commit d985c55

Browse files
committed
Move ClientSession onto JSONRPCDispatcher and delete BaseSession
ClientSession keeps its public surface (constructor, typed methods, manual initialize, context-manager lifecycle) but now owns a JSONRPCDispatcher instead of inheriting the v1 BaseSession receive loop. Server-initiated requests are answered through the existing callbacks via the closed-union parse; notifications validate-or-drop and tee to message_handler; transport exceptions reach message_handler through the dispatcher's stream-exception observer. A from_dispatcher constructor accepts a pre-built dispatcher for in-process embedding. mcp.shared.session shrinks to the surviving names: the ProgressFnT re-export and a typing-only RequestResponder stub for MessageHandlerFnT annotations. Behavior changes (deliberate, to be covered in the migration guide): - request ids count from 1; the progress token follows - timeouts use the dispatcher error text and send notifications/cancelled, so a timed-out server handler is interrupted instead of running on - responses with unknown ids are ignored per spec instead of surfacing a RuntimeError to message_handler - a raising request callback is answered with code 0 and the exception text - notification callbacks run concurrently (no completion-before-response) Three interaction-suite divergence entries are resolved and deleted, and the server-to-client cancellation requirement is now pinned by a passing test.
1 parent 2fe3682 commit d985c55

14 files changed

Lines changed: 405 additions & 1067 deletions

File tree

src/mcp/client/session.py

Lines changed: 222 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,35 @@
11
from __future__ import annotations
22

33
import logging
4+
from collections.abc import Mapping
5+
from types import TracebackType
46
from typing import Any, Protocol
57

8+
import anyio
9+
import anyio.abc
610
import anyio.lowlevel
7-
from pydantic import TypeAdapter
11+
from pydantic import BaseModel, TypeAdapter, ValidationError
12+
from typing_extensions import Self, TypeVar
813

914
from mcp import types
1015
from mcp.client._transport import ReadStream, WriteStream
16+
from mcp.shared._compat import resync_tracer
1117
from mcp.shared._context import RequestContext
12-
from mcp.shared.message import SessionMessage
13-
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
18+
from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher
19+
from mcp.shared.exceptions import MCPError
20+
from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher
21+
from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage
22+
from mcp.shared.session import ProgressFnT, RequestResponder
23+
from mcp.shared.transport_context import TransportContext
1424
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1525
from mcp.types._types import RequestParamsMeta
1626

1727
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
1828

1929
logger = logging.getLogger("client")
2030

31+
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
32+
2133

2234
class SamplingFnT(Protocol):
2335
async def __call__(
@@ -96,15 +108,16 @@ async def _default_logging_callback(
96108
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
97109

98110

99-
class ClientSession(
100-
BaseSession[
101-
types.ClientRequest,
102-
types.ClientNotification,
103-
types.ClientResult,
104-
types.ServerRequest,
105-
types.ServerNotification,
106-
]
107-
):
111+
class ClientSession:
112+
"""Client half of an MCP connection, running on `JSONRPCDispatcher`.
113+
114+
Construct it over a transport's stream pair, enter it as an async context
115+
manager, then call `initialize()`. The receive loop, request correlation,
116+
and per-request concurrency live in the dispatcher; this class owns the
117+
MCP type layer: typed requests, the initialize handshake, and routing
118+
server-initiated traffic to the constructor callbacks.
119+
"""
120+
108121
def __init__(
109122
self,
110123
read_stream: ReadStream[SessionMessage | Exception],
@@ -119,7 +132,70 @@ def __init__(
119132
*,
120133
sampling_capabilities: types.SamplingCapability | None = None,
121134
) -> None:
122-
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
135+
self._init_state(
136+
read_timeout_seconds=read_timeout_seconds,
137+
sampling_callback=sampling_callback,
138+
elicitation_callback=elicitation_callback,
139+
list_roots_callback=list_roots_callback,
140+
logging_callback=logging_callback,
141+
message_handler=message_handler,
142+
client_info=client_info,
143+
sampling_capabilities=sampling_capabilities,
144+
)
145+
# Built here (inert until run() starts in __aenter__) so notifications
146+
# can be sent before entering the context manager, as before.
147+
self._dispatcher: Dispatcher[Any] = JSONRPCDispatcher(
148+
read_stream, write_stream, on_stream_exception=self._on_stream_exception
149+
)
150+
151+
@classmethod
152+
def from_dispatcher(
153+
cls,
154+
dispatcher: Dispatcher[Any],
155+
*,
156+
read_timeout_seconds: float | None = None,
157+
sampling_callback: SamplingFnT | None = None,
158+
elicitation_callback: ElicitationFnT | None = None,
159+
list_roots_callback: ListRootsFnT | None = None,
160+
logging_callback: LoggingFnT | None = None,
161+
message_handler: MessageHandlerFnT | None = None,
162+
client_info: types.Implementation | None = None,
163+
sampling_capabilities: types.SamplingCapability | None = None,
164+
) -> Self:
165+
"""Build a session over a pre-built dispatcher instead of a stream pair.
166+
167+
For embedding a server in-process (`DirectDispatcher`) or transports
168+
that construct their own dispatcher. Transport-level `Exception` items
169+
reach `message_handler` only on the stream constructor, where the
170+
session wires the dispatcher's `on_stream_exception` itself.
171+
"""
172+
self = cls.__new__(cls)
173+
self._init_state(
174+
read_timeout_seconds=read_timeout_seconds,
175+
sampling_callback=sampling_callback,
176+
elicitation_callback=elicitation_callback,
177+
list_roots_callback=list_roots_callback,
178+
logging_callback=logging_callback,
179+
message_handler=message_handler,
180+
client_info=client_info,
181+
sampling_capabilities=sampling_capabilities,
182+
)
183+
self._dispatcher = dispatcher
184+
return self
185+
186+
def _init_state(
187+
self,
188+
*,
189+
read_timeout_seconds: float | None,
190+
sampling_callback: SamplingFnT | None,
191+
elicitation_callback: ElicitationFnT | None,
192+
list_roots_callback: ListRootsFnT | None,
193+
logging_callback: LoggingFnT | None,
194+
message_handler: MessageHandlerFnT | None,
195+
client_info: types.Implementation | None,
196+
sampling_capabilities: types.SamplingCapability | None,
197+
) -> None:
198+
self._session_read_timeout_seconds = read_timeout_seconds
123199
self._client_info = client_info or DEFAULT_CLIENT_INFO
124200
self._sampling_callback = sampling_callback or _default_sampling_callback
125201
self._sampling_capabilities = sampling_capabilities
@@ -129,14 +205,90 @@ def __init__(
129205
self._message_handler = message_handler or _default_message_handler
130206
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
131207
self._initialize_result: types.InitializeResult | None = None
208+
self._task_group: anyio.abc.TaskGroup | None = None
132209

133-
@property
134-
def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]:
135-
return types.server_request_adapter
210+
async def __aenter__(self) -> Self:
211+
self._task_group = anyio.create_task_group()
212+
await self._task_group.__aenter__()
213+
await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify)
214+
return self
136215

137-
@property
138-
def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]:
139-
return types.server_notification_adapter
216+
async def __aexit__(
217+
self,
218+
exc_type: type[BaseException] | None,
219+
exc_val: BaseException | None,
220+
exc_tb: TracebackType | None,
221+
) -> bool | None:
222+
# Exit must not block: cancel the dispatcher and any in-flight
223+
# callbacks rather than waiting for them.
224+
assert self._task_group is not None
225+
self._task_group.cancel_scope.cancel()
226+
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
227+
await resync_tracer()
228+
return result
229+
230+
async def send_request(
231+
self,
232+
request: types.ClientRequest,
233+
result_type: type[ReceiveResultT],
234+
request_read_timeout_seconds: float | None = None,
235+
metadata: MessageMetadata = None,
236+
progress_callback: ProgressFnT | None = None,
237+
) -> ReceiveResultT:
238+
"""Send a request and wait for its typed result.
239+
240+
A per-request read timeout takes precedence over the session-level
241+
one. `metadata` carries transport hints: `ClientMessageMetadata`
242+
resumption fields (streamable HTTP), or a
243+
`ServerMessageMetadata.related_request_id` to route the message onto
244+
an originating request's stream.
245+
246+
Raises:
247+
MCPError: The server responded with an error, or the read timeout
248+
elapsed, or the connection closed while waiting.
249+
RuntimeError: Called before entering the context manager.
250+
"""
251+
data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
252+
method: str = data["method"]
253+
opts: CallOptions = {}
254+
timeout = request_read_timeout_seconds or self._session_read_timeout_seconds
255+
if timeout is not None:
256+
opts["timeout"] = timeout
257+
if progress_callback is not None:
258+
opts["on_progress"] = progress_callback
259+
related_request_id: types.RequestId | None = None
260+
if isinstance(metadata, ClientMessageMetadata):
261+
if metadata.resumption_token is not None:
262+
opts["resumption_token"] = metadata.resumption_token
263+
if metadata.on_resumption_token_update is not None:
264+
opts["on_resumption_token"] = metadata.on_resumption_token_update
265+
elif isinstance(metadata, ServerMessageMetadata):
266+
related_request_id = metadata.related_request_id
267+
if method == "initialize":
268+
# The spec forbids cancelling initialize; opt out of the
269+
# dispatcher's courtesy cancel-on-abandon.
270+
opts["cancel_on_abandon"] = False
271+
if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher):
272+
# Related-request routing is JSON-RPC stream plumbing; other
273+
# dispatchers have no per-request streams to route onto.
274+
raw = await self._dispatcher.send_raw_request(
275+
method, data.get("params"), opts, _related_request_id=related_request_id
276+
)
277+
else:
278+
raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts)
279+
return result_type.model_validate(raw, by_name=False)
280+
281+
async def send_notification(
282+
self,
283+
notification: types.ClientNotification,
284+
related_request_id: types.RequestId | None = None,
285+
) -> None:
286+
"""Send a one-way notification. Usable before entering the context manager."""
287+
data = notification.model_dump(by_alias=True, mode="json", exclude_none=True)
288+
if related_request_id and isinstance(self._dispatcher, JSONRPCDispatcher):
289+
await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id)
290+
else:
291+
await self._dispatcher.notify(data["method"], data.get("params"))
140292

141293
async def initialize(self) -> types.InitializeResult:
142294
sampling = (
@@ -385,49 +537,59 @@ async def send_roots_list_changed(self) -> None:
385537
"""Send a roots/list_changed notification."""
386538
await self.send_notification(types.RootsListChangedNotification())
387539

388-
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
389-
ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self)
390-
391-
match responder.request:
392-
case types.CreateMessageRequest(params=params):
393-
with responder:
394-
response = await self._sampling_callback(ctx, params)
395-
client_response = ClientResponse.validate_python(response)
396-
await responder.respond(client_response)
540+
async def _on_request(
541+
self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None
542+
) -> dict[str, Any]:
543+
"""Answer a server-initiated request via the registered callbacks.
397544
398-
case types.ElicitRequest(params=params):
399-
with responder:
400-
response = await self._elicitation_callback(ctx, params)
401-
client_response = ClientResponse.validate_python(response)
402-
await responder.respond(client_response)
545+
Validation failures (unknown method or malformed params) raise
546+
`ValidationError`, which the dispatcher answers with INVALID_PARAMS;
547+
an `ErrorData` returned by a callback becomes the error response.
548+
"""
549+
payload: dict[str, Any] = {"method": method}
550+
if params is not None:
551+
payload["params"] = dict(params)
552+
request = types.server_request_adapter.validate_python(payload, by_name=False)
403553

554+
ctx = RequestContext[ClientSession](
555+
request_id=dctx.request_id, meta=request.params.meta if request.params else None, session=self
556+
)
557+
response: types.ClientResult | types.ErrorData
558+
match request:
559+
case types.CreateMessageRequest(params=sampling_params):
560+
response = await self._sampling_callback(ctx, sampling_params)
561+
case types.ElicitRequest(params=elicit_params):
562+
response = await self._elicitation_callback(ctx, elicit_params)
404563
case types.ListRootsRequest():
405-
with responder:
406-
response = await self._list_roots_callback(ctx)
407-
client_response = ClientResponse.validate_python(response)
408-
await responder.respond(client_response)
409-
564+
response = await self._list_roots_callback(ctx)
410565
case types.PingRequest(): # pragma: no branch
411-
with responder:
412-
await responder.respond(types.EmptyResult())
413-
414-
async def _handle_incoming(
415-
self,
416-
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
566+
response = types.EmptyResult()
567+
client_response = ClientResponse.validate_python(response)
568+
if isinstance(client_response, types.ErrorData):
569+
raise MCPError.from_error_data(client_response)
570+
return client_response.model_dump(by_alias=True, mode="json", exclude_none=True)
571+
572+
async def _on_notify(
573+
self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None
417574
) -> None:
418-
"""Handle incoming messages by forwarding to the message handler."""
419-
await self._message_handler(req)
420-
421-
async def _received_notification(self, notification: types.ServerNotification) -> None:
422-
"""Handle notifications from the server."""
423-
# Process specific notification types
424-
match notification:
425-
case types.LoggingMessageNotification(params=params):
426-
await self._logging_callback(params)
427-
case types.ElicitCompleteNotification(params=params):
428-
# Handle elicitation completion notification
429-
# Clients MAY use this to retry requests or update UI
430-
# The notification contains the elicitationId of the completed elicitation
431-
pass
432-
case _:
433-
pass
575+
"""Route a server notification: validate, run the typed callback, tee to message_handler."""
576+
payload: dict[str, Any] = {"method": method}
577+
if params is not None:
578+
payload["params"] = dict(params)
579+
try:
580+
notification = types.server_notification_adapter.validate_python(payload, by_name=False)
581+
except ValidationError:
582+
logger.warning("Failed to validate notification: %s", payload, exc_info=True)
583+
return
584+
if isinstance(notification, types.CancelledNotification):
585+
# The dispatcher already applied the cancellation to the in-flight
586+
# request; message_handler never sees it, so handlers matching
587+
# exhaustively over ServerNotification need no arm for it.
588+
return
589+
if isinstance(notification, types.LoggingMessageNotification):
590+
await self._logging_callback(notification.params)
591+
await self._message_handler(notification)
592+
593+
async def _on_stream_exception(self, exc: Exception) -> None:
594+
"""Forward transport-level faults (connection errors, parse errors) to message_handler."""
595+
await self._message_handler(exc)

src/mcp/shared/_context.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
"""Request context for MCP handlers."""
1+
"""Request context for MCP client handlers."""
22

33
from dataclasses import dataclass
44
from typing import Any, Generic
55

66
from typing_extensions import TypeVar
77

8-
from mcp.shared.session import BaseSession
98
from mcp.types import RequestId, RequestParamsMeta
109

11-
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
10+
SessionT = TypeVar("SessionT", default=Any)
1211

1312

1413
@dataclass(kw_only=True)

0 commit comments

Comments
 (0)