diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index dda241035..20de2b8a4 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -155,6 +155,7 @@ def __init__( self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None self._task_group: anyio.abc.TaskGroup | None = None + self._entered = False if dispatcher is not None: if read_stream is not None or write_stream is not None: raise ValueError("pass read_stream/write_stream or dispatcher, not both") @@ -168,6 +169,8 @@ def __init__( ) async def __aenter__(self) -> Self: + if self._entered: + raise RuntimeError("Session is already running") self._task_group = anyio.create_task_group() await self._task_group.__aenter__() try: @@ -184,6 +187,7 @@ async def __aenter__(self) -> Self: task_group.cancel_scope.shield = True await task_group.__aexit__(None, None, None) raise + self._entered = True return self async def __aexit__( @@ -195,9 +199,12 @@ async def __aexit__( # Exit must not block: cancel the dispatcher and in-flight callbacks. assert self._task_group is not None self._task_group.cancel_scope.cancel() - result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - await resync_tracer() - return result + try: + result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + await resync_tracer() + return result + finally: + self._entered = False async def send_request( self, @@ -216,6 +223,11 @@ async def send_request( MCPError: Error response, read timeout, or connection closed. RuntimeError: Called before entering the context manager. """ + if self._task_group is None: + raise RuntimeError( + "Session is not running. Use it as an async context manager " + "(e.g. `async with ClientSession(...) as session:`)." + ) data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] opts: CallOptions = {} diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 0f68a066f..91b23e0b3 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -149,6 +149,43 @@ async def message_handler( # pragma: no cover assert isinstance(initialized_notification, InitializedNotification) +@pytest.mark.anyio +async def test_client_session_requires_context_manager(): + client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + _server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_to_server_send, + _client_to_server_receive, + _server_to_client_send, + server_to_client_receive, + ): + session = ClientSession(server_to_client_receive, client_to_server_send) + + with pytest.raises(RuntimeError, match="async context manager"): + await session.initialize() + + +@pytest.mark.anyio +async def test_client_session_reentry_raises_runtime_error(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + session = ClientSession(server_to_client_receive, client_to_server_send) + await session.__aenter__() + try: + with pytest.raises(RuntimeError, match="already running"): + await session.__aenter__() + finally: + await session.__aexit__(None, None, None) + + @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -1252,12 +1289,12 @@ async def server_on_notify( @pytest.mark.anyio async def test_dispatcher_keyword_send_request_before_enter_raises_runtimeerror(): - """The documented pre-enter RuntimeError holds for dispatcher= sessions too.""" + """The documented pre-enter RuntimeError holds before any dispatcher call.""" client_side, _server_side = create_direct_dispatcher_pair() session = ClientSession(dispatcher=client_side) with anyio.fail_after(5), pytest.raises(RuntimeError) as exc: await session.send_ping() - assert str(exc.value) == "DirectDispatcher.send_raw_request called before run()" + assert "async context manager" in str(exc.value) @pytest.mark.anyio