|
22 | 22 | _outbound_metadata, |
23 | 23 | _Pending, |
24 | 24 | ) |
25 | | -from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage |
| 25 | +from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage |
26 | 26 | from mcp.shared.transport_context import TransportContext |
27 | 27 | from mcp.types import ( |
28 | 28 | CONNECTION_CLOSED, |
@@ -274,6 +274,77 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> |
274 | 274 | s.close() |
275 | 275 |
|
276 | 276 |
|
| 277 | +@pytest.mark.anyio |
| 278 | +async def test_ctx_message_metadata_carries_inbound_request_metadata(): |
| 279 | + """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" |
| 280 | + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) |
| 281 | + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) |
| 282 | + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) |
| 283 | + metadata = ServerMessageMetadata(request_context="request-scoped-data") |
| 284 | + seen: list[MessageMetadata] = [] |
| 285 | + |
| 286 | + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: |
| 287 | + seen.append(ctx.message_metadata) |
| 288 | + return {} |
| 289 | + |
| 290 | + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: |
| 291 | + raise NotImplementedError |
| 292 | + |
| 293 | + try: |
| 294 | + async with anyio.create_task_group() as tg: |
| 295 | + await tg.start(server.run, on_request, on_notify) |
| 296 | + await c2s_send.send( |
| 297 | + SessionMessage( |
| 298 | + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params=None), |
| 299 | + metadata=metadata, |
| 300 | + ) |
| 301 | + ) |
| 302 | + with anyio.fail_after(5): |
| 303 | + await s2c_recv.receive() # response sent ⇒ the handler has run |
| 304 | + tg.cancel_scope.cancel() |
| 305 | + finally: |
| 306 | + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): |
| 307 | + s.close() |
| 308 | + assert len(seen) == 1 |
| 309 | + assert seen[0] is metadata # the exact object, passed through verbatim |
| 310 | + |
| 311 | + |
| 312 | +@pytest.mark.anyio |
| 313 | +async def test_ctx_message_metadata_carries_inbound_notification_metadata(): |
| 314 | + """Notifications get the same metadata pass-through as requests.""" |
| 315 | + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) |
| 316 | + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) |
| 317 | + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) |
| 318 | + metadata = ServerMessageMetadata(request_context="request-scoped-data") |
| 319 | + seen: list[MessageMetadata] = [] |
| 320 | + notified = anyio.Event() |
| 321 | + |
| 322 | + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: |
| 323 | + raise NotImplementedError |
| 324 | + |
| 325 | + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: |
| 326 | + seen.append(ctx.message_metadata) |
| 327 | + notified.set() |
| 328 | + |
| 329 | + try: |
| 330 | + async with anyio.create_task_group() as tg: |
| 331 | + await tg.start(server.run, on_request, on_notify) |
| 332 | + await c2s_send.send( |
| 333 | + SessionMessage( |
| 334 | + message=JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params=None), |
| 335 | + metadata=metadata, |
| 336 | + ) |
| 337 | + ) |
| 338 | + with anyio.fail_after(5): |
| 339 | + await notified.wait() |
| 340 | + tg.cancel_scope.cancel() |
| 341 | + finally: |
| 342 | + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): |
| 343 | + s.close() |
| 344 | + assert len(seen) == 1 |
| 345 | + assert seen[0] is metadata |
| 346 | + |
| 347 | + |
277 | 348 | @pytest.mark.anyio |
278 | 349 | async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): |
279 | 350 | received: list[tuple[float, float | None, str | None]] = [] |
|
0 commit comments