diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..57577c662 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -147,7 +147,7 @@ def __init__( self._session_exit_stacks = {} self._component_name_hook = component_name_hook - async def __aenter__(self) -> Self: # pragma: no cover + async def __aenter__(self) -> Self: # Enter the exit stack only if we created it ourselves if self._owns_exit_stack: await self._exit_stack.__aenter__() @@ -158,7 +158,7 @@ async def __aexit__( _exc_type: type[BaseException] | None, _exc_val: BaseException | None, _exc_tb: TracebackType | None, - ) -> bool | None: # pragma: no cover + ) -> bool | None: """Closes session exit stacks and main exit stack upon completion.""" # Only close the main exit stack if we created it @@ -323,7 +323,7 @@ async def _establish_session( await self._exit_stack.enter_async_context(session_stack) return result.server_info, session - except Exception: # pragma: no cover + except Exception: # If anything during this setup fails, ensure the session-specific # stack is closed. await session_stack.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index b5950d3b5..8d84e8b1b 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -468,17 +468,33 @@ async def _handle_message(session_message: SessionMessage) -> None: read_stream_writer=read_stream_writer, ) - async def handle_request_async(): - if is_resumption: - await self._handle_resumption_request(ctx) - else: - await self._handle_post_request(ctx) - # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): + request_id = message.id + + async def handle_request_async() -> None: + try: + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) + except httpx.HTTPError as exc: + logger.exception("Transport error handling request") + error_data = ErrorData(code=INTERNAL_ERROR, message=f"Transport error: {exc}") + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data)) + with contextlib.suppress(anyio.BrokenResourceError, anyio.ClosedResourceError): + await read_stream_writer.send(error_msg) + tg.start_soon(handle_request_async) else: - await handle_request_async() + + async def handle_notification_async() -> None: + try: + await self._handle_post_request(ctx) + except httpx.HTTPError: + logger.debug("Transport error handling notification", exc_info=True) + + await handle_notification_async() async for session_message in write_stream_reader: sender_ctx = write_stream_reader.last_context diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..0e74c036f 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -278,6 +278,36 @@ async def test_client_session_group_disconnect_non_existent_server(): await group.disconnect_from_server(session) +@pytest.mark.anyio +async def test_client_session_group_context_manager_with_provided_exit_stack(): + """Provided exit stacks are not entered or closed by the session group.""" + provided_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack) + session_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack) + session = mock.Mock(spec=mcp.ClientSession) + + group = ClientSessionGroup(exit_stack=provided_stack) + group._session_exit_stacks[session] = session_stack + + assert await group.__aenter__() is group + await group.__aexit__(None, None, None) + + provided_stack.__aenter__.assert_not_awaited() + provided_stack.aclose.assert_not_awaited() + session_stack.aclose.assert_awaited_once() + + +@pytest.mark.anyio +async def test_client_session_group_context_manager_with_owned_exit_stack(): + """Owned exit stacks are entered and closed by the session group.""" + session_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack) + session = mock.Mock(spec=mcp.ClientSession) + + async with ClientSessionGroup() as group: + group._session_exit_stacks[session] = session_stack + + session_stack.aclose.assert_awaited_once() + + # TODO(Marcelo): This is horrible. We should drop this test. @pytest.mark.anyio @pytest.mark.parametrize( @@ -385,3 +415,30 @@ async def test_client_session_group_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.server_info assert returned_session is mock_entered_session + + +@pytest.mark.anyio +async def test_client_session_group_establish_session_closes_stack_on_initialize_error(): + group_exit_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack) + session_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack) + mock_read_stream = mock.AsyncMock(name="Read") + mock_write_stream = mock.AsyncMock(name="Write") + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_session.initialize.side_effect = RuntimeError("initialize failed") + session_stack.enter_async_context.side_effect = [ + (mock_read_stream, mock_write_stream), + mock_session, + ] + + group = ClientSessionGroup(exit_stack=group_exit_stack) + + with ( + mock.patch("mcp.client.session_group.contextlib.AsyncExitStack", return_value=session_stack), + mock.patch("mcp.client.session_group.mcp.stdio_client", return_value=mock.AsyncMock()), + mock.patch("mcp.client.session_group.mcp.ClientSession", return_value=mock.AsyncMock()), + pytest.raises(RuntimeError, match="initialize failed"), + ): + await group._establish_session(StdioServerParameters(command="test"), ClientSessionParameters()) + + session_stack.aclose.assert_awaited_once() + group_exit_stack.enter_async_context.assert_not_awaited() diff --git a/tests/issues/test_915_streamable_http_unreachable.py b/tests/issues/test_915_streamable_http_unreachable.py new file mode 100644 index 000000000..80170ba9d --- /dev/null +++ b/tests/issues/test_915_streamable_http_unreachable.py @@ -0,0 +1,102 @@ +import json +from typing import cast + +import anyio +import httpx +import pytest + +from mcp import ClientSession +from mcp.client.session_group import ClientSessionGroup, StreamableHttpParameters +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.exceptions import MCPError +from mcp.types import LATEST_PROTOCOL_VERSION, RootsListChangedNotification + +pytestmark = pytest.mark.anyio + + +def _contains_cancel_scope_error(exc: BaseException) -> bool: + if isinstance(exc, RuntimeError) and "Attempted to exit cancel scope" in str(exc): + return True + + raw_grouped_exceptions = getattr(exc, "exceptions", ()) + if isinstance(raw_grouped_exceptions, tuple) and raw_grouped_exceptions: + grouped_exceptions = cast(tuple[BaseException, ...], raw_grouped_exceptions) + return any(_contains_cancel_scope_error(inner) for inner in grouped_exceptions) + + return any(_contains_cancel_scope_error(inner) for inner in (exc.__cause__, exc.__context__) if inner is not None) + + +def test_contains_cancel_scope_error_follows_exception_tree() -> None: + cancel_scope_error = RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") + wrapped = RuntimeError("wrapped") + wrapped.__cause__ = cancel_scope_error + + assert _contains_cancel_scope_error(wrapped) + + +def test_contains_cancel_scope_error_follows_grouped_exceptions() -> None: + cancel_scope_error = RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") + + class DummyGroup(Exception): + def __init__(self) -> None: + self.exceptions = (cancel_scope_error,) + + assert _contains_cancel_scope_error(DummyGroup()) + + +async def test_session_group_streamable_http_connect_error_is_catchable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def raise_connect_error(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("server unavailable", request=request) + + def mock_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + auth=auth, + headers=headers, + timeout=timeout, + transport=httpx.MockTransport(raise_connect_error), + ) + + monkeypatch.setattr("mcp.client.session_group.create_mcp_http_client", mock_http_client) + + async with ClientSessionGroup() as group: + with anyio.fail_after(5), pytest.raises(MCPError) as exc_info: + await group.connect_to_server(StreamableHttpParameters(url="http://example.invalid/mcp")) + + assert "Transport error: server unavailable" in exc_info.value.error.message + assert not _contains_cancel_scope_error(exc_info.value) + + +async def test_streamable_http_notification_transport_error_does_not_crash() -> None: + async def handle_request(request: httpx.Request) -> httpx.Response: + data = json.loads(request.content) + if data.get("method") == "initialize": + return httpx.Response( + 200, + headers={"content-type": "application/json"}, + json={ + "jsonrpc": "2.0", + "id": data["id"], + "result": { + "protocolVersion": LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "serverInfo": {"name": "mock-server", "version": "1.0.0"}, + }, + }, + ) + + raise httpx.ConnectError("notification failed", request=request) + + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handle_request)) as http_client, + streamable_http_client("http://example.invalid/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) + await anyio.sleep(0)