Skip to content

Commit 79b75ec

Browse files
committed
test(client): cover session group cleanup branches
1 parent be95105 commit 79b75ec

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

src/mcp/client/session_group.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
self._session_exit_stacks = {}
148148
self._component_name_hook = component_name_hook
149149

150-
async def __aenter__(self) -> Self: # pragma: no cover
150+
async def __aenter__(self) -> Self:
151151
# Enter the exit stack only if we created it ourselves
152152
if self._owns_exit_stack:
153153
await self._exit_stack.__aenter__()
@@ -158,7 +158,7 @@ async def __aexit__(
158158
_exc_type: type[BaseException] | None,
159159
_exc_val: BaseException | None,
160160
_exc_tb: TracebackType | None,
161-
) -> bool | None: # pragma: no cover
161+
) -> bool | None:
162162
"""Closes session exit stacks and main exit stack upon completion."""
163163

164164
# Only close the main exit stack if we created it
@@ -323,7 +323,7 @@ async def _establish_session(
323323
await self._exit_stack.enter_async_context(session_stack)
324324

325325
return result.server_info, session
326-
except Exception: # pragma: no cover
326+
except Exception:
327327
# If anything during this setup fails, ensure the session-specific
328328
# stack is closed.
329329
await session_stack.aclose()

tests/client/test_session_group.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,24 @@ async def test_client_session_group_disconnect_non_existent_server():
278278
await group.disconnect_from_server(session)
279279

280280

281+
@pytest.mark.anyio
282+
async def test_client_session_group_context_manager_with_provided_exit_stack():
283+
"""Provided exit stacks are not entered or closed by the session group."""
284+
provided_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack)
285+
session_stack = mock.AsyncMock(spec=contextlib.AsyncExitStack)
286+
session = mock.Mock(spec=mcp.ClientSession)
287+
288+
group = ClientSessionGroup(exit_stack=provided_stack)
289+
group._session_exit_stacks[session] = session_stack
290+
291+
assert await group.__aenter__() is group
292+
await group.__aexit__(None, None, None)
293+
294+
provided_stack.__aenter__.assert_not_awaited()
295+
provided_stack.aclose.assert_not_awaited()
296+
session_stack.aclose.assert_awaited_once()
297+
298+
281299
# TODO(Marcelo): This is horrible. We should drop this test.
282300
@pytest.mark.anyio
283301
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)