Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook

async def __aenter__(self) -> Self: # pragma: no cover
async def __aenter__(self) -> Self:
# Enter the exit stack only if we created it ourselves
if self._owns_exit_stack:
await self._exit_stack.__aenter__()
Expand All @@ -158,7 +158,7 @@ async def __aexit__(
_exc_type: type[BaseException] | None,
_exc_val: BaseException | None,
_exc_tb: TracebackType | None,
) -> bool | None: # pragma: no cover
) -> bool | None:
"""Closes session exit stacks and main exit stack upon completion."""

# Only close the main exit stack if we created it
Expand Down Expand Up @@ -323,7 +323,7 @@ async def _establish_session(
await self._exit_stack.enter_async_context(session_stack)

return result.server_info, session
except Exception: # pragma: no cover
except Exception:
# If anything during this setup fails, ensure the session-specific
# stack is closed.
await session_stack.aclose()
Expand Down
30 changes: 23 additions & 7 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,17 +468,33 @@ async def _handle_message(session_message: SessionMessage) -> None:
read_stream_writer=read_stream_writer,
)

async def handle_request_async():
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)

# If this is a request, start a new task to handle it
if isinstance(message, JSONRPCRequest):
request_id = message.id

async def handle_request_async() -> None:
try:
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)
except httpx.HTTPError as exc:
logger.exception("Transport error handling request")
error_data = ErrorData(code=INTERNAL_ERROR, message=f"Transport error: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data))
with contextlib.suppress(anyio.BrokenResourceError, anyio.ClosedResourceError):
await read_stream_writer.send(error_msg)

tg.start_soon(handle_request_async)
else:
await handle_request_async()

async def handle_notification_async() -> None:
try:
await self._handle_post_request(ctx)
except httpx.HTTPError:
logger.debug("Transport error handling notification", exc_info=True)

await handle_notification_async()

async for session_message in write_stream_reader:
sender_ctx = write_stream_reader.last_context
Expand Down
57 changes: 57 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,36 @@ async def test_client_session_group_disconnect_non_existent_server():
await group.disconnect_from_server(session)


@pytest.mark.anyio
async def test_client_session_group_context_manager_with_provided_exit_stack():
"""Provided exit stacks are not entered or closed by the session group."""
provided_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack)
session_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack)
session = mock.Mock(spec=mcp.ClientSession)

group = ClientSessionGroup(exit_stack=provided_stack)
group._session_exit_stacks[session] = session_stack

assert await group.__aenter__() is group
await group.__aexit__(None, None, None)

provided_stack.__aenter__.assert_not_awaited()
provided_stack.aclose.assert_not_awaited()
session_stack.aclose.assert_awaited_once()


@pytest.mark.anyio
async def test_client_session_group_context_manager_with_owned_exit_stack():
"""Owned exit stacks are entered and closed by the session group."""
session_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack)
session = mock.Mock(spec=mcp.ClientSession)

async with ClientSessionGroup() as group:
group._session_exit_stacks[session] = session_stack

session_stack.aclose.assert_awaited_once()


# TODO(Marcelo): This is horrible. We should drop this test.
@pytest.mark.anyio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -385,3 +415,30 @@ async def test_client_session_group_establish_session_parameterized(
# 3. Assert returned values
assert returned_server_info is mock_initialize_result.server_info
assert returned_session is mock_entered_session


@pytest.mark.anyio
async def test_client_session_group_establish_session_closes_stack_on_initialize_error():
group_exit_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack)
session_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack)
mock_read_stream = mock.AsyncMock(name="Read")
mock_write_stream = mock.AsyncMock(name="Write")
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
mock_session.initialize.side_effect = RuntimeError("initialize failed")
session_stack.enter_async_context.side_effect = [
(mock_read_stream, mock_write_stream),
mock_session,
]

group = ClientSessionGroup(exit_stack=group_exit_stack)

with (
mock.patch("mcp.client.session_group.contextlib.AsyncExitStack", return_value=session_stack),
mock.patch("mcp.client.session_group.mcp.stdio_client", return_value=mock.AsyncMock()),
mock.patch("mcp.client.session_group.mcp.ClientSession", return_value=mock.AsyncMock()),
pytest.raises(RuntimeError, match="initialize failed"),
):
await group._establish_session(StdioServerParameters(command="test"), ClientSessionParameters())

session_stack.aclose.assert_awaited_once()
group_exit_stack.enter_async_context.assert_not_awaited()
102 changes: 102 additions & 0 deletions tests/issues/test_915_streamable_http_unreachable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json
from typing import cast

import anyio
import httpx
import pytest

from mcp import ClientSession
from mcp.client.session_group import ClientSessionGroup, StreamableHttpParameters
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.exceptions import MCPError
from mcp.types import LATEST_PROTOCOL_VERSION, RootsListChangedNotification

pytestmark = pytest.mark.anyio


def _contains_cancel_scope_error(exc: BaseException) -> bool:
if isinstance(exc, RuntimeError) and "Attempted to exit cancel scope" in str(exc):
return True

raw_grouped_exceptions = getattr(exc, "exceptions", ())
if isinstance(raw_grouped_exceptions, tuple) and raw_grouped_exceptions:
grouped_exceptions = cast(tuple[BaseException, ...], raw_grouped_exceptions)
return any(_contains_cancel_scope_error(inner) for inner in grouped_exceptions)

return any(_contains_cancel_scope_error(inner) for inner in (exc.__cause__, exc.__context__) if inner is not None)


def test_contains_cancel_scope_error_follows_exception_tree() -> None:
cancel_scope_error = RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
wrapped = RuntimeError("wrapped")
wrapped.__cause__ = cancel_scope_error

assert _contains_cancel_scope_error(wrapped)


def test_contains_cancel_scope_error_follows_grouped_exceptions() -> None:
cancel_scope_error = RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")

class DummyGroup(Exception):
def __init__(self) -> None:
self.exceptions = (cancel_scope_error,)

assert _contains_cancel_scope_error(DummyGroup())


async def test_session_group_streamable_http_connect_error_is_catchable(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def raise_connect_error(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("server unavailable", request=request)

def mock_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
return httpx.AsyncClient(
auth=auth,
headers=headers,
timeout=timeout,
transport=httpx.MockTransport(raise_connect_error),
)

monkeypatch.setattr("mcp.client.session_group.create_mcp_http_client", mock_http_client)

async with ClientSessionGroup() as group:
with anyio.fail_after(5), pytest.raises(MCPError) as exc_info:
await group.connect_to_server(StreamableHttpParameters(url="http://example.invalid/mcp"))

assert "Transport error: server unavailable" in exc_info.value.error.message
assert not _contains_cancel_scope_error(exc_info.value)


async def test_streamable_http_notification_transport_error_does_not_crash() -> None:
async def handle_request(request: httpx.Request) -> httpx.Response:
data = json.loads(request.content)
if data.get("method") == "initialize":
return httpx.Response(
200,
headers={"content-type": "application/json"},
json={
"jsonrpc": "2.0",
"id": data["id"],
"result": {
"protocolVersion": LATEST_PROTOCOL_VERSION,
"capabilities": {},
"serverInfo": {"name": "mock-server", "version": "1.0.0"},
},
},
)

raise httpx.ConnectError("notification failed", request=request)

async with (
httpx.AsyncClient(transport=httpx.MockTransport(handle_request)) as http_client,
streamable_http_client("http://example.invalid/mcp", http_client=http_client) as (read_stream, write_stream),
ClientSession(read_stream, write_stream) as session,
):
await session.initialize()
await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed"))
await anyio.sleep(0)
Loading