Skip to content

Commit 501002e

Browse files
committed
Add an on_stream_exception observer to the dispatcher
Transports yield Exception items on the read stream for connection faults and parse errors; the dispatcher debug-logged and dropped them. An optional observer now receives them (awaited in the read loop, contained so a raising observer costs the item, not the connection). Unset keeps the old behavior.
1 parent 745ed69 commit 501002e

2 files changed

Lines changed: 76 additions & 5 deletions

File tree

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def __init__(
245245
peer_cancel_mode: PeerCancelMode = "interrupt",
246246
raise_handler_exceptions: bool = False,
247247
inline_methods: frozenset[str] = frozenset(),
248+
on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None,
248249
) -> None: ...
249250
@overload
250251
def __init__(
@@ -256,6 +257,7 @@ def __init__(
256257
peer_cancel_mode: PeerCancelMode = "interrupt",
257258
raise_handler_exceptions: bool = False,
258259
inline_methods: frozenset[str] = frozenset(),
260+
on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None,
259261
) -> None: ...
260262
def __init__(
261263
self,
@@ -266,6 +268,7 @@ def __init__(
266268
peer_cancel_mode: PeerCancelMode = "interrupt",
267269
raise_handler_exceptions: bool = False,
268270
inline_methods: frozenset[str] = frozenset(),
271+
on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None,
269272
) -> None:
270273
self._read_stream = read_stream
271274
self._write_stream = write_stream
@@ -287,6 +290,11 @@ def __init__(
287290
# while inline will deadlock because the parked read loop cannot dequeue
288291
# the response.
289292
self._inline_methods = inline_methods
293+
# Observer for Exception items the transport yields on the read stream
294+
# (SSE/streamable-HTTP connection faults, stdio parse errors). Without
295+
# it they are debug-logged and dropped. Awaited in the read loop and
296+
# contained: a raising observer costs the item, not the connection.
297+
self._on_stream_exception = on_stream_exception
290298

291299
self._next_id = 0
292300
self._pending: dict[RequestId, _Pending] = {}
@@ -482,13 +490,19 @@ async def _dispatch(
482490
) -> None:
483491
"""Route one inbound item.
484492
485-
Everything here is `send_nowait` or `_spawn`; the only `await` is for
486-
`inline_methods` requests, which deliberately block dequeuing until
487-
handled. Any other `await` would let one slow message head-of-line
488-
block the entire read loop.
493+
Everything here is `send_nowait` or `_spawn`; the only `await`s are
494+
`inline_methods` requests and the `on_stream_exception` observer,
495+
which deliberately block dequeuing until handled. Any other `await`
496+
would let one slow message head-of-line block the entire read loop.
489497
"""
490498
if isinstance(item, Exception):
491-
logger.debug("transport yielded exception: %r", item)
499+
if self._on_stream_exception is None:
500+
logger.debug("transport yielded exception: %r", item)
501+
return
502+
try:
503+
await self._on_stream_exception(item)
504+
except Exception:
505+
logger.exception("on_stream_exception observer raised")
492506
return
493507
metadata = item.metadata
494508
msg = item.message

tests/shared/test_jsonrpc_dispatcher.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,63 @@ async def test_transport_exception_in_read_stream_is_logged_and_dropped():
912912
s.close()
913913

914914

915+
@pytest.mark.anyio
916+
async def test_on_stream_exception_observes_transport_exceptions():
917+
"""With an observer set, Exception items reach it instead of being dropped; the loop stays healthy."""
918+
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4)
919+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4)
920+
921+
seen: list[Exception] = []
922+
923+
async def observe(exc: Exception) -> None:
924+
seen.append(exc)
925+
926+
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, on_stream_exception=observe)
927+
on_request, on_notify = echo_handlers(Recorder())
928+
hiccup = ValueError("transport hiccup")
929+
try:
930+
async with anyio.create_task_group() as tg:
931+
await tg.start(server.run, on_request, on_notify)
932+
await c2s_send.send(hiccup)
933+
await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)))
934+
with anyio.fail_after(5):
935+
resp = await s2c_recv.receive()
936+
assert isinstance(resp, SessionMessage)
937+
assert isinstance(resp.message, JSONRPCResponse)
938+
tg.cancel_scope.cancel()
939+
finally:
940+
for s in (c2s_send, c2s_recv, s2c_send, s2c_recv):
941+
s.close()
942+
assert seen == [hiccup]
943+
944+
945+
@pytest.mark.anyio
946+
async def test_on_stream_exception_observer_raising_is_contained(caplog: pytest.LogCaptureFixture):
947+
"""A raising observer costs the item, not the connection: it runs in the read loop itself."""
948+
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4)
949+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4)
950+
951+
async def observe(exc: Exception) -> None:
952+
raise RuntimeError("observer boom")
953+
954+
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, on_stream_exception=observe)
955+
on_request, on_notify = echo_handlers(Recorder())
956+
try:
957+
async with anyio.create_task_group() as tg:
958+
await tg.start(server.run, on_request, on_notify)
959+
await c2s_send.send(ValueError("transport hiccup"))
960+
await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)))
961+
with anyio.fail_after(5):
962+
resp = await s2c_recv.receive()
963+
assert isinstance(resp, SessionMessage)
964+
assert isinstance(resp.message, JSONRPCResponse)
965+
tg.cancel_scope.cancel()
966+
finally:
967+
for s in (c2s_send, c2s_recv, s2c_send, s2c_recv):
968+
s.close()
969+
assert "on_stream_exception observer raised" in caplog.text
970+
971+
915972
@pytest.mark.anyio
916973
async def test_progress_notification_for_unknown_token_falls_through_to_on_notify():
917974
async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec):

0 commit comments

Comments
 (0)