Skip to content

Commit 28d8b3e

Browse files
benjibcBenny Chen
andauthored
revert mcp gym issue (#148)
* revert mcp gym issue * also fix the type error --------- Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent caf93cf commit 28d8b3e

File tree

3 files changed

+108
-31
lines changed

3 files changed

+108
-31
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -441,34 +441,24 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
441441
# Extract data plane results (observation only)
442442
if tool_result.content and len(tool_result.content) > 0:
443443
content = tool_result.content[0]
444-
# Safely attempt to read a "text" attribute if present across content types
445-
text_attr = getattr(content, "text", None)
446-
if isinstance(text_attr, str):
447-
content_text = text_attr
448-
elif isinstance(text_attr, list):
449-
# text can also be an array of parts with optional .text fields
450-
content_text = "".join([getattr(p, "text", "") for p in text_attr])
451-
else:
452-
content_text = None
453-
454-
if isinstance(content_text, str):
444+
if hasattr(content, "text"):
455445
# Fix: Handle empty or invalid JSON responses gracefully
456-
if content_text.strip() == "":
446+
if not content.text or content.text.strip() == "":
457447
logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
458448
observation = {
459449
"observation": "empty_response",
460450
"session_id": session.session_id,
461451
}
462452
else:
463453
try:
464-
observation = json.loads(content_text)
454+
observation = json.loads(content.text)
465455
except json.JSONDecodeError as e:
466456
logger.warning(
467-
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content_text}. Error: {e}"
457+
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
468458
)
469459
# Create a structured response from the raw text
470460
observation = {
471-
"observation": content_text,
461+
"observation": content.text,
472462
"session_id": session.session_id,
473463
"error": "invalid_json_response",
474464
}

eval_protocol/mcp/mcpgym.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from concurrent.futures import ThreadPoolExecutor
2525
from datetime import date, datetime
2626
from enum import Enum
27-
from typing import Any, Callable, Dict, Optional, Tuple
27+
from typing import Any, Callable, Dict, Optional, Tuple, Literal, cast
2828

2929
import uvicorn
3030
from mcp.server.fastmcp import Context, FastMCP
@@ -146,22 +146,20 @@ def _get_session_id(self, ctx: Context) -> str:
146146
print(f"🔍 _get_session_id: hasattr(ctx, 'session'): {hasattr(ctx, 'session')}")
147147

148148
# Use stable session ID based on client info (following simulation_server.py pattern)
149-
if hasattr(ctx, "session") and hasattr(ctx.session, "client_params"):
150-
client_params = ctx.session.client_params
149+
if hasattr(ctx, "session"):
150+
client_params = getattr(ctx.session, "client_params", None)
151151
print(f"🔍 _get_session_id: client_params type: {type(client_params)}")
152-
print(f"🔍 _get_session_id: hasattr(client_params, 'clientInfo'): {hasattr(client_params, 'clientInfo')}")
153-
154-
if hasattr(client_params, "clientInfo"):
155-
client_info = client_params.clientInfo
152+
if client_params is not None and hasattr(client_params, "clientInfo"):
153+
client_info = getattr(client_params, "clientInfo", None)
156154
print(f"🔍 _get_session_id: client_info: {client_info}")
157-
print(f"🔍 _get_session_id: hasattr(client_info, '_extra'): {hasattr(client_info, '_extra')}")
158155

159-
if client_info and hasattr(client_info, "_extra"):
160-
extra_data = client_info._extra
156+
if client_info is not None:
157+
# Access private _extra with a cast to satisfy type checker
158+
extra_data = cast(Any, getattr(client_info, "_extra", None))
161159
print(f"🔍 _get_session_id: extra_data: {extra_data}")
162160
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")
163161

164-
if extra_data and isinstance(extra_data, dict):
162+
if isinstance(extra_data, dict):
165163
# use the client generated session id
166164
if "session_id" in extra_data:
167165
print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}")
@@ -181,8 +179,8 @@ def _get_session_id(self, ctx: Context) -> str:
181179
"config": config_value,
182180
"dataset_row_id": dataset_row_id_value,
183181
"model_id": model_id_value,
184-
"name": client_info.name,
185-
"version": client_info.version,
182+
"name": getattr(client_info, "name", None),
183+
"version": getattr(client_info, "version", None),
186184
}
187185

188186
print(f"🔍 _get_session_id: stable_data: {stable_data}")
@@ -205,6 +203,15 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
205203
"""
206204
session_id = self._get_session_id(ctx)
207205
print(f"🔍 _get_or_create_session: session_id: {session_id}")
206+
if session_id not in self.sessions:
207+
env, obs, info = self._new_env(seed=None)
208+
with self.session_lock:
209+
self.sessions[session_id] = {
210+
"env": env,
211+
"obs": obs,
212+
"session_data": {},
213+
"session_id": session_id,
214+
}
208215
return self.sessions[session_id]
209216

210217
def _register_session_reset_endpoint(self):
@@ -400,6 +407,15 @@ def _execute_session_environment_step(self, session_id: str, action: Any) -> Dic
400407
Returns:
401408
Data plane response (observation only, no rewards)
402409
"""
410+
if session_id not in self.sessions:
411+
env, obs, info = self._new_env(seed=None)
412+
with self.session_lock:
413+
self.sessions[session_id] = {
414+
"env": env,
415+
"obs": obs,
416+
"session_data": {},
417+
"session_id": session_id,
418+
}
403419
session_data = self.sessions[session_id]
404420
env = session_data["env"]
405421

@@ -558,7 +574,8 @@ async def run_with_high_concurrency():
558574
if not kwargs.get("redirect_slashes", True) and hasattr(starlette_app, "router"):
559575
starlette_app.router.redirect_slashes = False
560576

561-
starlette_app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
577+
# Add middleware with proper type cast to satisfy basedpyright
578+
starlette_app.add_middleware(cast(Any, ProxyHeadersMiddleware), trusted_hosts="*")
562579

563580
config = uvicorn.Config(
564581
starlette_app,
@@ -580,7 +597,15 @@ async def run_with_high_concurrency():
580597
asyncio.run(run_with_high_concurrency())
581598
else:
582599
# Use default FastMCP run for other transports
583-
self.mcp.run(transport=transport, **kwargs)
600+
# Constrain transport to the allowed literal values for type checker
601+
allowed_transport: Literal["stdio", "sse", "streamable-http"]
602+
if transport in ("stdio", "sse", "streamable-http"):
603+
allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], transport)
604+
else:
605+
# Default to streamable-http if unknown
606+
allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], "streamable-http")
607+
608+
self.mcp.run(transport=allowed_transport, **kwargs)
584609

585610
def _to_json_serializable(self, obj: Any) -> Any:
586611
"""Convert any object to JSON-serializable format.
@@ -607,7 +632,8 @@ def _to_json_serializable(self, obj: Any) -> Any:
607632

608633
# Handle dataclasses
609634
elif dataclasses.is_dataclass(obj):
610-
return dataclasses.asdict(obj)
635+
# Cast for type checker because protocol uses ClassVar on __dataclass_fields__
636+
return dataclasses.asdict(cast(Any, obj))
611637

612638
# Handle dictionaries
613639
elif isinstance(obj, dict):
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Regression test: ensure MCP-Gym auto-creates a session on first tool call
3+
without requiring a prior initial state fetch, and returns JSON.
4+
"""
5+
6+
import time
7+
from multiprocessing import Process
8+
9+
import httpx
10+
import pytest
11+
12+
from eval_protocol.mcp.client.connection import MCPConnectionManager
13+
from eval_protocol.types import MCPSession
14+
15+
16+
def _run_airline_server():
17+
import os
18+
19+
os.environ["PORT"] = "9780"
20+
from eval_protocol.mcp_servers.tau2.tau2_mcp import AirlineDomainMcp
21+
22+
server = AirlineDomainMcp(seed=None)
23+
server.run(transport="streamable-http")
24+
25+
26+
@pytest.mark.asyncio
27+
async def test_tool_call_returns_json_without_prior_initial_state():
28+
proc = Process(target=_run_airline_server, daemon=True)
29+
proc.start()
30+
31+
try:
32+
base_url = "http://127.0.0.1:9780/mcp"
33+
client = httpx.Client(timeout=1.0)
34+
deadline = time.time() + 20
35+
while time.time() < deadline:
36+
try:
37+
r = client.get(base_url)
38+
if r.status_code in (200, 307, 406):
39+
break
40+
except Exception:
41+
pass
42+
time.sleep(0.2)
43+
else:
44+
pytest.fail("Server did not start on port 9780 in time")
45+
46+
session = MCPSession(base_url=base_url, session_id="test-autocreate", seed=None, model_id="test-model")
47+
48+
mgr = MCPConnectionManager()
49+
await mgr.initialize_session(session)
50+
await mgr.discover_tools(session)
51+
52+
observation, reward, done, info = await mgr.call_tool(session, "list_all_airports", {})
53+
54+
assert isinstance(observation, dict), f"Expected JSON dict, got: {type(observation)} {observation}"
55+
assert observation.get("error") != "invalid_json_response"
56+
57+
await mgr.reset_session(session)
58+
await mgr.close_session(session)
59+
finally:
60+
proc.terminate()
61+
proc.join(timeout=5)

0 commit comments

Comments
 (0)