Skip to content

Commit 3214be5

Browse files
committed
fix(stdio): drain responses after stdin EOF
1 parent 5d82649 commit 3214be5

5 files changed

Lines changed: 140 additions & 30 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ async def run(
416416
# the next request (spec says SHOULD NOT, not MUST NOT) sees
417417
# the initialized state instead of failing the init-gate.
418418
inline_methods=frozenset({"initialize"}),
419+
close_write_stream_on_read_close=False,
419420
)
420421
runner = ServerRunner(
421422
server=self,

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import contextvars
2424
import logging
2525
from collections.abc import Awaitable, Callable, Mapping
26+
from contextlib import AsyncExitStack
2627
from dataclasses import dataclass, field
2728
from typing import Any, Generic, Literal, TypeVar, cast, overload
2829

@@ -226,6 +227,7 @@ def __init__(
226227
peer_cancel_mode: PeerCancelMode = "interrupt",
227228
raise_handler_exceptions: bool = False,
228229
inline_methods: frozenset[str] = frozenset(),
230+
close_write_stream_on_read_close: bool = True,
229231
) -> None: ...
230232
@overload
231233
def __init__(
@@ -237,6 +239,7 @@ def __init__(
237239
peer_cancel_mode: PeerCancelMode = "interrupt",
238240
raise_handler_exceptions: bool = False,
239241
inline_methods: frozenset[str] = frozenset(),
242+
close_write_stream_on_read_close: bool = True,
240243
) -> None: ...
241244
def __init__(
242245
self,
@@ -247,6 +250,7 @@ def __init__(
247250
peer_cancel_mode: PeerCancelMode = "interrupt",
248251
raise_handler_exceptions: bool = False,
249252
inline_methods: frozenset[str] = frozenset(),
253+
close_write_stream_on_read_close: bool = True,
250254
) -> None:
251255
self._read_stream = read_stream
252256
self._write_stream = write_stream
@@ -259,6 +263,7 @@ def __init__(
259263
)
260264
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
261265
self._raise_handler_exceptions = raise_handler_exceptions
266+
self._close_write_stream_on_read_close = close_write_stream_on_read_close
262267
# Request methods handled inline in the read loop (awaited before the
263268
# next message is dequeued) instead of spawned concurrently. Use for
264269
# methods whose side effects must be observable to the next message,
@@ -400,13 +405,17 @@ async def run(
400405
`await tg.start(dispatcher.run, ...)` resumes when `send_raw_request`
401406
is usable.
402407
"""
408+
normal_eof = False
403409
try:
404410
async with anyio.create_task_group() as tg:
405411
self._tg = tg
406412
self._running = True
407413
task_status.started()
408414
try:
409-
async with self._read_stream, self._write_stream:
415+
async with AsyncExitStack() as stack:
416+
await stack.enter_async_context(self._read_stream)
417+
if self._close_write_stream_on_read_close:
418+
await stack.enter_async_context(self._write_stream)
410419
try:
411420
async for item in self._read_stream:
412421
# Duck-typed: `_context_streams.ContextReceiveStream`
@@ -425,14 +434,13 @@ async def run(
425434
# (callers outside this task group) with CONNECTION_CLOSED.
426435
self._running = False
427436
self._fan_out_closed()
437+
normal_eof = True
428438
finally:
429-
# Transport closed: cancel in-flight handlers. Without this
430-
# the task-group join waits for them, and a handler that
431-
# outlives its caller (its request timed out client-side, or
432-
# the client disconnected mid-call) would keep `run()` from
433-
# returning forever. Same behaviour as `Server.run()` before
434-
# the dispatcher rework.
435-
tg.cancel_scope.cancel()
439+
if not normal_eof:
440+
# Transport closed abnormally: cancel in-flight handlers.
441+
# On normal EOF, let already-received handlers drain
442+
# their responses before the task group exits.
443+
tg.cancel_scope.cancel()
436444
finally:
437445
# Covers the cancel/crash paths where the inline fan-out above is
438446
# never reached. Idempotent.

src/mcp/shared/session.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,24 @@ def __init__(
148148
write_stream: WriteStream[SessionMessage],
149149
# If none, reading will never time out
150150
read_timeout_seconds: float | None = None,
151+
# When True, closing/EOF on the read stream closes the write stream too.
152+
#
153+
# For full-duplex transports (e.g., stdio), an input EOF can be a
154+
# half-close: the peer is done sending requests but still expects
155+
# responses on the output stream. In that case, callers may opt out so
156+
# in-flight handlers can drain their responses before shutdown.
157+
close_write_stream_on_read_close: bool = True,
151158
) -> None:
152159
self._read_stream = read_stream
153160
self._write_stream = write_stream
154161
self._response_streams = {}
155162
self._request_id = 0
156163
self._session_read_timeout_seconds = read_timeout_seconds
164+
self._close_write_stream_on_read_close = close_write_stream_on_read_close
157165
self._progress_callbacks = {}
158166
self._exit_stack = AsyncExitStack()
167+
self._exit_stack.push_async_callback(self._read_stream.aclose)
168+
self._exit_stack.push_async_callback(self._write_stream.aclose)
159169

160170
async def __aenter__(self) -> Self:
161171
self._task_group = anyio.create_task_group()
@@ -291,7 +301,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
291301
raise NotImplementedError
292302

293303
async def _receive_loop(self) -> None:
294-
async with self._read_stream, self._write_stream:
304+
async with AsyncExitStack() as stack:
305+
await stack.enter_async_context(self._read_stream)
306+
if self._close_write_stream_on_read_close:
307+
await stack.enter_async_context(self._write_stream)
295308
try:
296309

297310
async def _handle_session_message(message: SessionMessage) -> None:

tests/server/test_cancel_handling.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
InitializeRequestParams,
2020
JSONRPCNotification,
2121
JSONRPCRequest,
22+
JSONRPCResponse,
2223
ListToolsResult,
2324
PaginatedRequestParams,
2425
TextContent,
@@ -100,29 +101,18 @@ async def first_request():
100101

101102

102103
@pytest.mark.anyio
103-
async def test_server_cancels_in_flight_handlers_on_transport_close():
104-
"""When the transport closes mid-request, server.run() must cancel in-flight
105-
handlers rather than join on them.
106-
107-
Without the cancel, the task group waits for the handler, which then tries
108-
to respond through a write stream that _receive_loop already closed,
109-
raising ClosedResourceError and crashing server.run() with exit code 1.
110-
111-
This drives server.run() with raw memory streams because InMemoryTransport
112-
wraps it in its own finally-cancel (_memory.py) which masks the bug.
113-
"""
104+
async def test_server_drains_in_flight_handlers_on_transport_read_eof():
105+
"""When the transport's read side hits EOF (e.g., stdio stdin closes), the
106+
server must drain already-started handlers so their responses reach the
107+
peer via the still-open write side."""
114108
handler_started = anyio.Event()
115-
handler_cancelled = anyio.Event()
109+
handler_allowed_to_finish = anyio.Event()
116110
server_run_returned = anyio.Event()
117111

118112
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
119113
handler_started.set()
120-
try:
121-
await anyio.sleep_forever()
122-
finally:
123-
handler_cancelled.set()
124-
# unreachable: sleep_forever only exits via cancellation
125-
raise AssertionError # pragma: no cover
114+
await handler_allowed_to_finish.wait()
115+
return CallToolResult(content=[TextContent(type="text", text="ok")])
126116

127117
server = Server("test", on_call_tool=handle_call_tool)
128118

@@ -167,9 +157,13 @@ async def run_server():
167157
# handler gets CancelledError, server.run() returns.
168158
await to_server.aclose()
169159

170-
await server_run_returned.wait()
160+
handler_allowed_to_finish.set()
161+
162+
response = await from_server.receive()
163+
assert isinstance(response.message, JSONRPCResponse)
164+
assert response.message.id == 2
171165

172-
assert handler_cancelled.is_set()
166+
await server_run_returned.wait()
173167

174168

175169
@pytest.mark.anyio

tests/server/test_stdio.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,28 @@
88
import anyio
99
import pytest
1010

11+
from mcp.server import Server, ServerRequestContext
1112
from mcp.server.mcpserver import MCPServer
1213
from mcp.server.stdio import stdio_server
1314
from mcp.shared.message import SessionMessage
14-
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
15+
from mcp.types import (
16+
LATEST_PROTOCOL_VERSION,
17+
CallToolRequestParams,
18+
CallToolResult,
19+
ClientCapabilities,
20+
Implementation,
21+
InitializeRequestParams,
22+
JSONRPCError,
23+
JSONRPCMessage,
24+
JSONRPCNotification,
25+
JSONRPCRequest,
26+
JSONRPCResponse,
27+
ListToolsResult,
28+
PaginatedRequestParams,
29+
TextContent,
30+
Tool,
31+
jsonrpc_message_adapter,
32+
)
1533

1634

1735
@pytest.mark.anyio
@@ -169,3 +187,79 @@ async def lifespan(server: MCPServer) -> AsyncIterator[None]:
169187
assert events == ["setup", "cleanup"]
170188
response = jsonrpc_message_adapter.validate_json(captured.getvalue().decode().strip())
171189
assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={})
190+
191+
192+
@pytest.mark.anyio
193+
async def test_stdio_server_drains_in_flight_responses_on_stdin_eof():
194+
"""When stdin reaches EOF (e.g., bash-redirected input), already-received
195+
requests must still be able to emit their responses on stdout."""
196+
stdin = io.StringIO()
197+
stdout = io.StringIO()
198+
199+
tool_started_count = 0
200+
both_tools_started = anyio.Event()
201+
allow_tools_to_finish = anyio.Event()
202+
203+
async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
204+
return ListToolsResult(tools=[Tool(name="slow", description="test", input_schema={})])
205+
206+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
207+
nonlocal tool_started_count
208+
tool_started_count += 1
209+
if tool_started_count == 2:
210+
both_tools_started.set()
211+
await allow_tools_to_finish.wait()
212+
return CallToolResult(content=[TextContent(type="text", text="ok")])
213+
214+
server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool)
215+
216+
init_req = JSONRPCRequest(
217+
jsonrpc="2.0",
218+
id=0,
219+
method="initialize",
220+
params=InitializeRequestParams(
221+
protocol_version=LATEST_PROTOCOL_VERSION,
222+
capabilities=ClientCapabilities(),
223+
client_info=Implementation(name="test", version="1.0"),
224+
).model_dump(by_alias=True, mode="json", exclude_none=True),
225+
)
226+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
227+
call_1 = JSONRPCRequest(
228+
jsonrpc="2.0",
229+
id=1,
230+
method="tools/call",
231+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
232+
)
233+
call_2 = JSONRPCRequest(
234+
jsonrpc="2.0",
235+
id=2,
236+
method="tools/call",
237+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
238+
)
239+
240+
for message in (init_req, initialized, call_1, call_2):
241+
stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
242+
stdin.seek(0)
243+
244+
async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
245+
read_stream,
246+
write_stream,
247+
):
248+
with anyio.fail_after(5):
249+
async with anyio.create_task_group() as tg:
250+
tg.start_soon(server.run, read_stream, write_stream, server.create_initialization_options())
251+
await both_tools_started.wait()
252+
allow_tools_to_finish.set()
253+
254+
stdout.seek(0)
255+
ids: set[int | str] = set()
256+
for line in stdout.readlines():
257+
line = line.strip()
258+
if not line:
259+
continue
260+
message = jsonrpc_message_adapter.validate_json(line)
261+
if isinstance(message, JSONRPCResponse | JSONRPCError):
262+
assert message.id is not None
263+
ids.add(message.id)
264+
assert 1 in ids
265+
assert 2 in ids

0 commit comments

Comments
 (0)