11from __future__ import annotations
22
33import logging
4+ from collections .abc import Mapping
5+ from types import TracebackType
46from typing import Any , Protocol
57
8+ import anyio
9+ import anyio .abc
610import anyio .lowlevel
7- from pydantic import TypeAdapter
11+ from pydantic import BaseModel , TypeAdapter , ValidationError
12+ from typing_extensions import Self , TypeVar
813
914from mcp import types
1015from mcp .client ._transport import ReadStream , WriteStream
16+ from mcp .shared ._compat import resync_tracer
1117from mcp .shared ._context import RequestContext
12- from mcp .shared .message import SessionMessage
13- from mcp .shared .session import BaseSession , ProgressFnT , RequestResponder
18+ from mcp .shared .dispatcher import CallOptions , DispatchContext , Dispatcher
19+ from mcp .shared .exceptions import MCPError
20+ from mcp .shared .jsonrpc_dispatcher import JSONRPCDispatcher
21+ from mcp .shared .message import ClientMessageMetadata , MessageMetadata , ServerMessageMetadata , SessionMessage
22+ from mcp .shared .session import ProgressFnT , RequestResponder
23+ from mcp .shared .transport_context import TransportContext
1424from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
1525from mcp .types ._types import RequestParamsMeta
1626
1727DEFAULT_CLIENT_INFO = types .Implementation (name = "mcp" , version = "0.1.0" )
1828
1929logger = logging .getLogger ("client" )
2030
31+ ReceiveResultT = TypeVar ("ReceiveResultT" , bound = BaseModel )
32+
2133
2234class SamplingFnT (Protocol ):
2335 async def __call__ (
@@ -96,15 +108,16 @@ async def _default_logging_callback(
96108ClientResponse : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (types .ClientResult | types .ErrorData )
97109
98110
99- class ClientSession (
100- BaseSession [
101- types .ClientRequest ,
102- types .ClientNotification ,
103- types .ClientResult ,
104- types .ServerRequest ,
105- types .ServerNotification ,
106- ]
107- ):
111+ class ClientSession :
112+ """Client half of an MCP connection, running on `JSONRPCDispatcher`.
113+
114+ Construct it over a transport's stream pair, enter it as an async context
115+ manager, then call `initialize()`. The receive loop, request correlation,
116+ and per-request concurrency live in the dispatcher; this class owns the
117+ MCP type layer: typed requests, the initialize handshake, and routing
118+ server-initiated traffic to the constructor callbacks.
119+ """
120+
108121 def __init__ (
109122 self ,
110123 read_stream : ReadStream [SessionMessage | Exception ],
@@ -119,7 +132,70 @@ def __init__(
119132 * ,
120133 sampling_capabilities : types .SamplingCapability | None = None ,
121134 ) -> None :
122- super ().__init__ (read_stream , write_stream , read_timeout_seconds = read_timeout_seconds )
135+ self ._init_state (
136+ read_timeout_seconds = read_timeout_seconds ,
137+ sampling_callback = sampling_callback ,
138+ elicitation_callback = elicitation_callback ,
139+ list_roots_callback = list_roots_callback ,
140+ logging_callback = logging_callback ,
141+ message_handler = message_handler ,
142+ client_info = client_info ,
143+ sampling_capabilities = sampling_capabilities ,
144+ )
145+ # Built here (inert until run() starts in __aenter__) so notifications
146+ # can be sent before entering the context manager, as before.
147+ self ._dispatcher : Dispatcher [Any ] = JSONRPCDispatcher (
148+ read_stream , write_stream , on_stream_exception = self ._on_stream_exception
149+ )
150+
151+ @classmethod
152+ def from_dispatcher (
153+ cls ,
154+ dispatcher : Dispatcher [Any ],
155+ * ,
156+ read_timeout_seconds : float | None = None ,
157+ sampling_callback : SamplingFnT | None = None ,
158+ elicitation_callback : ElicitationFnT | None = None ,
159+ list_roots_callback : ListRootsFnT | None = None ,
160+ logging_callback : LoggingFnT | None = None ,
161+ message_handler : MessageHandlerFnT | None = None ,
162+ client_info : types .Implementation | None = None ,
163+ sampling_capabilities : types .SamplingCapability | None = None ,
164+ ) -> Self :
165+ """Build a session over a pre-built dispatcher instead of a stream pair.
166+
167+ For embedding a server in-process (`DirectDispatcher`) or transports
168+ that construct their own dispatcher. Transport-level `Exception` items
169+ reach `message_handler` only on the stream constructor, where the
170+ session wires the dispatcher's `on_stream_exception` itself.
171+ """
172+ self = cls .__new__ (cls )
173+ self ._init_state (
174+ read_timeout_seconds = read_timeout_seconds ,
175+ sampling_callback = sampling_callback ,
176+ elicitation_callback = elicitation_callback ,
177+ list_roots_callback = list_roots_callback ,
178+ logging_callback = logging_callback ,
179+ message_handler = message_handler ,
180+ client_info = client_info ,
181+ sampling_capabilities = sampling_capabilities ,
182+ )
183+ self ._dispatcher = dispatcher
184+ return self
185+
186+ def _init_state (
187+ self ,
188+ * ,
189+ read_timeout_seconds : float | None ,
190+ sampling_callback : SamplingFnT | None ,
191+ elicitation_callback : ElicitationFnT | None ,
192+ list_roots_callback : ListRootsFnT | None ,
193+ logging_callback : LoggingFnT | None ,
194+ message_handler : MessageHandlerFnT | None ,
195+ client_info : types .Implementation | None ,
196+ sampling_capabilities : types .SamplingCapability | None ,
197+ ) -> None :
198+ self ._session_read_timeout_seconds = read_timeout_seconds
123199 self ._client_info = client_info or DEFAULT_CLIENT_INFO
124200 self ._sampling_callback = sampling_callback or _default_sampling_callback
125201 self ._sampling_capabilities = sampling_capabilities
@@ -129,14 +205,90 @@ def __init__(
129205 self ._message_handler = message_handler or _default_message_handler
130206 self ._tool_output_schemas : dict [str , dict [str , Any ] | None ] = {}
131207 self ._initialize_result : types .InitializeResult | None = None
208+ self ._task_group : anyio .abc .TaskGroup | None = None
132209
133- @property
134- def _receive_request_adapter (self ) -> TypeAdapter [types .ServerRequest ]:
135- return types .server_request_adapter
210+ async def __aenter__ (self ) -> Self :
211+ self ._task_group = anyio .create_task_group ()
212+ await self ._task_group .__aenter__ ()
213+ await self ._task_group .start (self ._dispatcher .run , self ._on_request , self ._on_notify )
214+ return self
136215
137- @property
138- def _receive_notification_adapter (self ) -> TypeAdapter [types .ServerNotification ]:
139- return types .server_notification_adapter
216+ async def __aexit__ (
217+ self ,
218+ exc_type : type [BaseException ] | None ,
219+ exc_val : BaseException | None ,
220+ exc_tb : TracebackType | None ,
221+ ) -> bool | None :
222+ # Exit must not block: cancel the dispatcher and any in-flight
223+ # callbacks rather than waiting for them.
224+ assert self ._task_group is not None
225+ self ._task_group .cancel_scope .cancel ()
226+ result = await self ._task_group .__aexit__ (exc_type , exc_val , exc_tb )
227+ await resync_tracer ()
228+ return result
229+
230+ async def send_request (
231+ self ,
232+ request : types .ClientRequest ,
233+ result_type : type [ReceiveResultT ],
234+ request_read_timeout_seconds : float | None = None ,
235+ metadata : MessageMetadata = None ,
236+ progress_callback : ProgressFnT | None = None ,
237+ ) -> ReceiveResultT :
238+ """Send a request and wait for its typed result.
239+
240+ A per-request read timeout takes precedence over the session-level
241+ one. `metadata` carries transport hints: `ClientMessageMetadata`
242+ resumption fields (streamable HTTP), or a
243+ `ServerMessageMetadata.related_request_id` to route the message onto
244+ an originating request's stream.
245+
246+ Raises:
247+ MCPError: The server responded with an error, or the read timeout
248+ elapsed, or the connection closed while waiting.
249+ RuntimeError: Called before entering the context manager.
250+ """
251+ data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
252+ method : str = data ["method" ]
253+ opts : CallOptions = {}
254+ timeout = request_read_timeout_seconds or self ._session_read_timeout_seconds
255+ if timeout is not None :
256+ opts ["timeout" ] = timeout
257+ if progress_callback is not None :
258+ opts ["on_progress" ] = progress_callback
259+ related_request_id : types .RequestId | None = None
260+ if isinstance (metadata , ClientMessageMetadata ):
261+ if metadata .resumption_token is not None :
262+ opts ["resumption_token" ] = metadata .resumption_token
263+ if metadata .on_resumption_token_update is not None :
264+ opts ["on_resumption_token" ] = metadata .on_resumption_token_update
265+ elif isinstance (metadata , ServerMessageMetadata ):
266+ related_request_id = metadata .related_request_id
267+ if method == "initialize" :
268+ # The spec forbids cancelling initialize; opt out of the
269+ # dispatcher's courtesy cancel-on-abandon.
270+ opts ["cancel_on_abandon" ] = False
271+ if related_request_id is not None and isinstance (self ._dispatcher , JSONRPCDispatcher ):
272+ # Related-request routing is JSON-RPC stream plumbing; other
273+ # dispatchers have no per-request streams to route onto.
274+ raw = await self ._dispatcher .send_raw_request (
275+ method , data .get ("params" ), opts , _related_request_id = related_request_id
276+ )
277+ else :
278+ raw = await self ._dispatcher .send_raw_request (method , data .get ("params" ), opts )
279+ return result_type .model_validate (raw , by_name = False )
280+
281+ async def send_notification (
282+ self ,
283+ notification : types .ClientNotification ,
284+ related_request_id : types .RequestId | None = None ,
285+ ) -> None :
286+ """Send a one-way notification. Usable before entering the context manager."""
287+ data = notification .model_dump (by_alias = True , mode = "json" , exclude_none = True )
288+ if related_request_id and isinstance (self ._dispatcher , JSONRPCDispatcher ):
289+ await self ._dispatcher .notify (data ["method" ], data .get ("params" ), _related_request_id = related_request_id )
290+ else :
291+ await self ._dispatcher .notify (data ["method" ], data .get ("params" ))
140292
141293 async def initialize (self ) -> types .InitializeResult :
142294 sampling = (
@@ -385,49 +537,59 @@ async def send_roots_list_changed(self) -> None:
385537 """Send a roots/list_changed notification."""
386538 await self .send_notification (types .RootsListChangedNotification ())
387539
388- async def _received_request (self , responder : RequestResponder [types .ServerRequest , types .ClientResult ]) -> None :
389- ctx = RequestContext [ClientSession ](request_id = responder .request_id , meta = responder .request_meta , session = self )
390-
391- match responder .request :
392- case types .CreateMessageRequest (params = params ):
393- with responder :
394- response = await self ._sampling_callback (ctx , params )
395- client_response = ClientResponse .validate_python (response )
396- await responder .respond (client_response )
540+ async def _on_request (
541+ self , dctx : DispatchContext [TransportContext ], method : str , params : Mapping [str , Any ] | None
542+ ) -> dict [str , Any ]:
543+ """Answer a server-initiated request via the registered callbacks.
397544
398- case types .ElicitRequest (params = params ):
399- with responder :
400- response = await self ._elicitation_callback (ctx , params )
401- client_response = ClientResponse .validate_python (response )
402- await responder .respond (client_response )
545+ Validation failures (unknown method or malformed params) raise
546+ `ValidationError`, which the dispatcher answers with INVALID_PARAMS;
547+ an `ErrorData` returned by a callback becomes the error response.
548+ """
549+ payload : dict [str , Any ] = {"method" : method }
550+ if params is not None :
551+ payload ["params" ] = dict (params )
552+ request = types .server_request_adapter .validate_python (payload , by_name = False )
403553
554+ ctx = RequestContext [ClientSession ](
555+ request_id = dctx .request_id , meta = request .params .meta if request .params else None , session = self
556+ )
557+ response : types .ClientResult | types .ErrorData
558+ match request :
559+ case types .CreateMessageRequest (params = sampling_params ):
560+ response = await self ._sampling_callback (ctx , sampling_params )
561+ case types .ElicitRequest (params = elicit_params ):
562+ response = await self ._elicitation_callback (ctx , elicit_params )
404563 case types .ListRootsRequest ():
405- with responder :
406- response = await self ._list_roots_callback (ctx )
407- client_response = ClientResponse .validate_python (response )
408- await responder .respond (client_response )
409-
564+ response = await self ._list_roots_callback (ctx )
410565 case types .PingRequest (): # pragma: no branch
411- with responder :
412- await responder .respond (types .EmptyResult ())
413-
414- async def _handle_incoming (
415- self ,
416- req : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
566+ response = types .EmptyResult ()
567+ client_response = ClientResponse .validate_python (response )
568+ if isinstance (client_response , types .ErrorData ):
569+ raise MCPError .from_error_data (client_response )
570+ return client_response .model_dump (by_alias = True , mode = "json" , exclude_none = True )
571+
572+ async def _on_notify (
573+ self , dctx : DispatchContext [TransportContext ], method : str , params : Mapping [str , Any ] | None
417574 ) -> None :
418- """Handle incoming messages by forwarding to the message handler."""
419- await self ._message_handler (req )
420-
421- async def _received_notification (self , notification : types .ServerNotification ) -> None :
422- """Handle notifications from the server."""
423- # Process specific notification types
424- match notification :
425- case types .LoggingMessageNotification (params = params ):
426- await self ._logging_callback (params )
427- case types .ElicitCompleteNotification (params = params ):
428- # Handle elicitation completion notification
429- # Clients MAY use this to retry requests or update UI
430- # The notification contains the elicitationId of the completed elicitation
431- pass
432- case _:
433- pass
575+ """Route a server notification: validate, run the typed callback, tee to message_handler."""
576+ payload : dict [str , Any ] = {"method" : method }
577+ if params is not None :
578+ payload ["params" ] = dict (params )
579+ try :
580+ notification = types .server_notification_adapter .validate_python (payload , by_name = False )
581+ except ValidationError :
582+ logger .warning ("Failed to validate notification: %s" , payload , exc_info = True )
583+ return
584+ if isinstance (notification , types .CancelledNotification ):
585+ # The dispatcher already applied the cancellation to the in-flight
586+ # request; message_handler never sees it, so handlers matching
587+ # exhaustively over ServerNotification need no arm for it.
588+ return
589+ if isinstance (notification , types .LoggingMessageNotification ):
590+ await self ._logging_callback (notification .params )
591+ await self ._message_handler (notification )
592+
593+ async def _on_stream_exception (self , exc : Exception ) -> None :
594+ """Forward transport-level faults (connection errors, parse errors) to message_handler."""
595+ await self ._message_handler (exc )
0 commit comments