Skip to content

Commit 7d507ca

Browse files
committed
fix: fail fast when session is not started
1 parent 47bbab3 commit 7d507ca

2 files changed

Lines changed: 32 additions & 3 deletions

File tree

src/mcp/client/session.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def __init__(
155155
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
156156
self._initialize_result: types.InitializeResult | None = None
157157
self._task_group: anyio.abc.TaskGroup | None = None
158+
self._entered = False
158159
if dispatcher is not None:
159160
if read_stream is not None or write_stream is not None:
160161
raise ValueError("pass read_stream/write_stream or dispatcher, not both")
@@ -168,6 +169,8 @@ def __init__(
168169
)
169170

170171
async def __aenter__(self) -> Self:
172+
if self._entered:
173+
raise RuntimeError("Session is already running")
171174
self._task_group = anyio.create_task_group()
172175
await self._task_group.__aenter__()
173176
try:
@@ -184,6 +187,7 @@ async def __aenter__(self) -> Self:
184187
task_group.cancel_scope.shield = True
185188
await task_group.__aexit__(None, None, None)
186189
raise
190+
self._entered = True
187191
return self
188192

189193
async def __aexit__(
@@ -195,9 +199,12 @@ async def __aexit__(
195199
# Exit must not block: cancel the dispatcher and in-flight callbacks.
196200
assert self._task_group is not None
197201
self._task_group.cancel_scope.cancel()
198-
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
199-
await resync_tracer()
200-
return result
202+
try:
203+
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
204+
await resync_tracer()
205+
return result
206+
finally:
207+
self._entered = False
201208

202209
async def send_request(
203210
self,
@@ -216,6 +223,11 @@ async def send_request(
216223
MCPError: Error response, read timeout, or connection closed.
217224
RuntimeError: Called before entering the context manager.
218225
"""
226+
if self._task_group is None:
227+
raise RuntimeError(
228+
"Session is not running. Use it as an async context manager "
229+
"(e.g. `async with ClientSession(...) as session:`)."
230+
)
219231
data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
220232
method: str = data["method"]
221233
opts: CallOptions = {}

tests/client/test_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,23 @@ async def message_handler( # pragma: no cover
149149
assert isinstance(initialized_notification, InitializedNotification)
150150

151151

152+
@pytest.mark.anyio
153+
async def test_client_session_requires_context_manager():
154+
client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
155+
_server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
156+
157+
async with (
158+
client_to_server_send,
159+
_client_to_server_receive,
160+
_server_to_client_send,
161+
server_to_client_receive,
162+
):
163+
session = ClientSession(server_to_client_receive, client_to_server_send)
164+
165+
with pytest.raises(RuntimeError, match="async context manager"):
166+
await session.initialize()
167+
168+
152169
@pytest.mark.anyio
153170
async def test_client_session_custom_client_info():
154171
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)

0 commit comments

Comments
 (0)