1616from mcp .types import (
1717 CONNECTION_CLOSED ,
1818 INVALID_PARAMS ,
19+ CallToolResult ,
1920 CancelledNotification ,
2021 ClientNotification ,
2122 ClientRequest ,
2223 ClientResult ,
2324 ErrorData ,
25+ GetOperationPayloadRequest ,
26+ GetOperationPayloadResult ,
2427 JSONRPCError ,
2528 JSONRPCMessage ,
2629 JSONRPCNotification ,
@@ -177,6 +180,7 @@ class BaseSession(
177180 _request_id : int
178181 _in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
179182 _progress_callbacks : dict [RequestId , ProgressFnT ]
183+ _operation_requests : dict [str , RequestId ]
180184
181185 def __init__ (
182186 self ,
@@ -196,6 +200,7 @@ def __init__(
196200 self ._session_read_timeout_seconds = read_timeout_seconds
197201 self ._in_flight = {}
198202 self ._progress_callbacks = {}
203+ self ._operation_requests = {}
199204 self ._exit_stack = AsyncExitStack ()
200205
201206 async def __aenter__ (self ) -> Self :
@@ -251,6 +256,7 @@ async def send_request(
251256 # Store the callback for this request
252257 self ._progress_callbacks [request_id ] = progress_callback
253258
259+ pop_progress : RequestId | None = request_id
254260 try :
255261 jsonrpc_request = JSONRPCRequest (
256262 jsonrpc = "2.0" ,
@@ -285,11 +291,28 @@ async def send_request(
285291 if isinstance (response_or_error , JSONRPCError ):
286292 raise McpError (response_or_error .error )
287293 else :
288- return result_type .model_validate (response_or_error .result )
294+ result = result_type .model_validate (response_or_error .result )
295+ if isinstance (result , CallToolResult ) and result .operation is not None :
296+ # Store mapping of operation token to request ID for async operations
297+ self ._operation_requests [result .operation .token ] = request_id
298+
299+ # Don't pop the progress function if we were given one
300+ pop_progress = None
301+ elif isinstance (request , GetOperationPayloadRequest ) and isinstance (result , GetOperationPayloadResult ):
302+ # Checked request and result to ensure no error
303+ operation_token = request .params .token
304+
305+ # Pop the progress function for the original request
306+ pop_progress = self ._operation_requests [operation_token ]
307+
308+ # Pop the token mapping since we know we won't need it anymore
309+ self ._operation_requests .pop (operation_token , None )
310+ return result
289311
290312 finally :
291313 self ._response_streams .pop (request_id , None )
292- self ._progress_callbacks .pop (request_id , None )
314+ if pop_progress :
315+ self ._progress_callbacks .pop (pop_progress , None )
293316 await response_stream .aclose ()
294317 await response_stream_reader .aclose ()
295318
0 commit comments