Skip to content

Commit 16149d2

Browse files
authored
Debugging Async Slow Runs (#41)
* test * add error msg * current * MINIMAL REPRO * run on local to double check * debug * cleanup * small fix
1 parent e355931 commit 16149d2

File tree

17 files changed

+208
-245
lines changed

17 files changed

+208
-245
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
import hashlib
1010
import json
1111
import logging
12+
import time
1213
from contextlib import AsyncExitStack
1314
from typing import Any, Dict, List, Optional, Tuple
1415

16+
import httpx
1517
from mcp.client.session import ClientSession
1618
from mcp.client.streamable_http import streamablehttp_client
19+
from mcp.types import Implementation
1720

1821
from ...types import MCPSession
19-
from mcp.types import Implementation
2022

2123
logger = logging.getLogger(__name__)
2224

@@ -109,15 +111,13 @@ async def reset_session(self, session: MCPSession) -> None:
109111
"""
110112
Clean session data in remote mcp server for the given session
111113
"""
112-
import httpx
113-
114114
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
115115
url = f"{base_url}/control/reset_session"
116116

117117
headers = {"mcp-session-id": session.session_id}
118118
body = {"seed": session.seed}
119119

120-
timeout = httpx.Timeout(3.0)
120+
timeout = httpx.Timeout(15.0)
121121
async with httpx.AsyncClient(timeout=timeout) as client:
122122
resp = await client.post(url, headers=headers, json=body)
123123
resp.raise_for_status()
@@ -202,8 +202,6 @@ async def get_initial_state(self, session: MCPSession) -> Any:
202202
initial_observation = None
203203

204204
try:
205-
import httpx
206-
207205
# Extract base URL and session ID from the MCP session
208206
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
209207
session_id = session.session_id
@@ -459,9 +457,6 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
459457
control_plane_info = {}
460458

461459
try:
462-
# Query control plane endpoints following the new architecture
463-
import httpx
464-
465460
# Extract base URL and session ID from the MCP session
466461
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
467462
# Use the session ID from the established MCP session

eval_protocol/mcp/execution/manager.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
import os
1212
import threading
1313
import time
14-
from concurrent.futures import ThreadPoolExecutor, as_completed
15-
from dataclasses import asdict, dataclass
14+
from dataclasses import asdict
1615
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
1716

1817
from openai.types import CompletionUsage
@@ -248,7 +247,7 @@ async def _execute_rollout(
248247

249248
# Get initial messages in tau2-bench format for user simulator
250249
user_simulator_state = user_simulator.get_init_state()
251-
user_message, user_simulator_state = user_simulator.generate_next_message(
250+
user_message, user_simulator_state = await user_simulator.generate_next_message(
252251
AssistantMessage(role="assistant", content="Hi! How can I help you today?"),
253252
user_simulator_state,
254253
)
@@ -280,7 +279,7 @@ async def _execute_rollout(
280279
# Last message was agent, simulated user response
281280
if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
282281
# Generate user response using the simulator
283-
user_message, user_simulator_state = user_simulator.generate_next_message(
282+
user_message, user_simulator_state = await user_simulator.generate_next_message(
284283
user_simulator_messages[-1], user_simulator_state
285284
)
286285
user_content = user_message.content if user_message.content else ""

eval_protocol/mcp/mcpgym.py

Lines changed: 61 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,22 @@
1313
"""
1414

1515
import asyncio
16+
import dataclasses
1617
import hashlib
1718
import inspect
1819
import json
1920
import logging
2021
import os
2122
import threading
2223
from abc import ABC, abstractmethod
24+
from concurrent.futures import ThreadPoolExecutor
25+
from datetime import date, datetime
26+
from enum import Enum
2327
from typing import Any, Callable, Dict, Optional, Tuple
2428

2529
import uvicorn
2630
from mcp.server.fastmcp import Context, FastMCP
31+
from pydantic import BaseModel
2732
from starlette.requests import Request
2833
from starlette.responses import JSONResponse
2934
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
@@ -75,14 +80,23 @@ class McpGym(ABC):
7580
- Environment Implementation: Single-process MCP server per environment
7681
"""
7782

78-
def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional[int] = None):
83+
def __init__(
84+
self,
85+
server_name: str,
86+
adapter: EnvironmentAdapter,
87+
seed: Optional[int] = None,
88+
max_workers: Optional[int] = None,
89+
):
7990
"""
8091
Initialize the MCP-Gym environment.
8192
8293
Args:
8394
server_name: Name for the MCP server
8495
adapter: Environment adapter instance
8596
seed: Optional seed for reproducible environments
97+
max_workers: Optional maximum number of worker threads for ThreadPoolExecutor.
98+
If None, uses ThreadPoolExecutor default (min(32, (os.cpu_count() or 1) + 4))
99+
86100
"""
87101
self.adapter = adapter
88102

@@ -110,6 +124,8 @@ def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional
110124
"total_reward": 0.0,
111125
}
112126

127+
self.pool = ThreadPoolExecutor(max_workers=max_workers)
128+
113129
# Reset with seed if provided
114130
self.env, self.obs, _info = self._new_env(seed=seed)
115131

@@ -189,49 +205,7 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
189205
"""
190206
session_id = self._get_session_id(ctx)
191207
print(f"🔍 _get_or_create_session: session_id: {session_id}")
192-
193-
with self.session_lock:
194-
if session_id not in self.sessions:
195-
print(f"🔍 _get_or_create_session: Creating new session for {session_id}")
196-
# Extract seed from context using proper FastMCP pattern
197-
seed = None
198-
config = self._get_default_config()
199-
print(f"🔍 _get_or_create_session: default_config: {config}")
200-
201-
if hasattr(ctx, "session") and hasattr(ctx.session, "client_params"):
202-
client_params = ctx.session.client_params
203-
if hasattr(client_params, "clientInfo"):
204-
client_info = client_params.clientInfo
205-
if client_info and hasattr(client_info, "_extra"):
206-
extra_data = client_info._extra
207-
print(f"🔍 _get_or_create_session: extra_data in session creation: {extra_data}")
208-
if extra_data and isinstance(extra_data, dict):
209-
# Extract seed from client info
210-
seed = extra_data.get("seed")
211-
print(f"🌱 Extracted seed from client_info: {seed} (type: {type(seed)})")
212-
# Update config with any additional options
213-
if "config" in extra_data:
214-
config.update(extra_data["config"])
215-
print(f"🔍 _get_or_create_session: updated config: {config}")
216-
217-
print(f"🔍 _get_or_create_session: About to create environment with seed: {seed}")
218-
219-
env, obs, info = self._new_env(seed=seed)
220-
print(f"🔍 _get_or_create_session: environment created with obs: {obs}, info: {info}")
221-
222-
# Initialize session state
223-
self.sessions[session_id] = {
224-
"env": env,
225-
"obs": obs,
226-
"session_data": {}, # Subclasses can store additional data here
227-
"session_id": session_id,
228-
}
229-
230-
print(f"🎮 Created new session {session_id[:16]}... with seed {seed}, initial obs: {obs}")
231-
else:
232-
print(f"🔍 _get_or_create_session: Returning existing session {session_id}")
233-
234-
return self.sessions[session_id]
208+
return self.sessions[session_id]
235209

236210
def _register_session_reset_endpoint(self):
237211

@@ -243,16 +217,17 @@ async def reset_session_endpoint(request: Request) -> JSONResponse:
243217
print(f"🔍 _register_session_reset_endpoint: Resetting session, session_id: {session_id}, seed: {seed}")
244218
if not session_id:
245219
return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400)
246-
with self.session_lock:
247-
if session_id in self.sessions:
248-
env, obs, _ = self._new_env(seed=seed)
220+
if session_id in self.sessions:
221+
loop = asyncio.get_running_loop()
222+
env, obs, info = await loop.run_in_executor(self.pool, self._new_env, seed)
223+
with self.session_lock:
249224
self.sessions[session_id] = {
250225
"env": env,
251226
"obs": obs,
252227
"session_data": {},
253228
"session_id": session_id,
254229
}
255-
print(f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: {session_id}")
230+
print(f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: {session_id}")
256231
return JSONResponse({"message": "Session reset successfully"})
257232

258233
def _discover_and_register_control_plane_endpoints(self):
@@ -286,29 +261,27 @@ async def endpoint_handler(request: Request) -> JSONResponse:
286261
)
287262

288263
# Get or create session data
264+
session_data = self.sessions.get(session_id)
265+
if not session_data:
266+
if func.__name__ != "get_initial_state_endpoint":
267+
return JSONResponse(
268+
{"error": f"Session {session_id} not found"},
269+
status_code=404,
270+
)
271+
272+
loop = asyncio.get_running_loop()
273+
env, obs, info = await loop.run_in_executor(self.pool, self._new_env, None)
274+
275+
# Initialize session state with extracted seed from session ID
276+
session_data = {
277+
"env": env,
278+
"obs": obs,
279+
"session_data": {}, # Subclasses can store additional data here
280+
"session_id": session_id,
281+
}
289282
with self.session_lock:
290-
session_data = self.sessions.get(session_id)
291-
if not session_data:
292-
# For initial state endpoint, we need to create the session
293-
# based on the session ID and available information
294-
if func.__name__ == "get_initial_state_endpoint":
295-
env, obs, info = self._new_env(seed=None)
296-
# Initialize session state with extracted seed from session ID
297-
session_data = {
298-
"env": env,
299-
"obs": obs,
300-
"session_data": {}, # Subclasses can store additional data here
301-
"session_id": session_id,
302-
}
303-
# Store the session
304-
self.sessions[session_id] = session_data
305-
else:
306-
return JSONResponse(
307-
{"error": f"Session {session_id} not found"},
308-
status_code=404,
309-
)
310-
311-
# Call the endpoint function with session data
283+
self.sessions[session_id] = session_data
284+
312285
if inspect.iscoroutinefunction(func):
313286
result = await func(session_data=session_data)
314287
else:
@@ -356,22 +329,21 @@ def _update_control_plane(self, reward: float, terminated: bool, truncated: bool
356329

357330
def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]:
358331
"""Get or create control plane state for a specific session."""
359-
with self.session_lock:
360-
if session_id not in self.sessions:
361-
return {}
362-
363-
session_data = self.sessions[session_id]
364-
if "control_plane" not in session_data["session_data"]:
365-
session_data["session_data"]["control_plane"] = {
366-
"reward": 0.0,
367-
"terminated": False,
368-
"truncated": False,
369-
"info": {},
370-
"step_count": 0,
371-
"total_reward": 0.0,
372-
}
332+
if session_id not in self.sessions:
333+
raise Exception(f"Session {session_id} not found")
334+
335+
session_data = self.sessions[session_id]
336+
if "control_plane" not in session_data["session_data"]:
337+
session_data["session_data"]["control_plane"] = {
338+
"reward": 0.0,
339+
"terminated": False,
340+
"truncated": False,
341+
"info": {},
342+
"step_count": 0,
343+
"total_reward": 0.0,
344+
}
373345

374-
return session_data["session_data"]["control_plane"]
346+
return session_data["session_data"]["control_plane"]
375347

376348
def _update_session_control_plane(
377349
self,
@@ -396,13 +368,6 @@ def _update_session_control_plane(
396368
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}, total_reward={control_plane['total_reward']}"
397369
)
398370

399-
def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]:
400-
"""Get control plane state for a specific session (for rollout system)."""
401-
with self.session_lock:
402-
if session_id in self.sessions:
403-
return self._get_or_create_session_control_plane(session_id).copy()
404-
return None
405-
406371
def _execute_environment_step(self, action_int: int) -> Dict[str, Any]:
407372
"""
408373
Execute environment step and update control plane (single session).
@@ -510,11 +475,11 @@ def get_info_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
510475
return control_plane.get("info", {})
511476

512477
@control_plane_endpoint("/control/initial_state")
513-
def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
478+
async def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
514479
"""Get initial state for this session."""
480+
session_id = session_data.get("session_id", "unknown")
515481
env = session_data.get("env")
516482
obs = session_data.get("obs")
517-
518483
if env and obs is not None:
519484
try:
520485
formatted_obs = self.format_observation(obs, env)
@@ -604,8 +569,8 @@ async def run_with_high_concurrency():
604569
proxy_headers=True,
605570
forwarded_allow_ips="*",
606571
# HIGH CONCURRENCY SETTINGS
607-
limit_concurrency=200, # Increase for HTTP endpoints + MCP
608-
limit_max_requests=100000, # Higher request limit
572+
limit_concurrency=None, # Increase for HTTP endpoints + MCP
573+
limit_max_requests=None, # Higher request limit
609574
timeout_keep_alive=120, # Longer keep-alive for control plane
610575
timeout_notify=180,
611576
h11_max_incomplete_event_size=4 * 1024 * 1024, # Handle larger events
@@ -624,11 +589,6 @@ def _to_json_serializable(self, obj: Any) -> Any:
624589
Handles Pydantic models, dataclasses, lists, dicts, and primitive types.
625590
This is a utility method that can be used by format_observation implementations.
626591
"""
627-
import dataclasses
628-
from datetime import date, datetime
629-
from enum import Enum
630-
631-
from pydantic import BaseModel
632592

633593
# Handle None and primitive types
634594
if obj is None or isinstance(obj, (str, int, float, bool)):

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ def start(self) -> None:
4242
if self.process:
4343
return
4444

45+
try:
46+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
47+
s.settimeout(1)
48+
result = s.connect_ex(("localhost", self.port))
49+
if result == 0:
50+
raise RuntimeError(
51+
f"Port {self.port} is already in use! Please use a different port or kill the process using it."
52+
)
53+
except socket.error:
54+
pass
55+
4556
# Set environment for server
4657
env = os.environ.copy()
4758
env["PORT"] = str(self.port)

examples/blackjack_mcp/blackjack_mcp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ class BlackjackMcp(McpGym):
3939
- Multi-session support with session-based control plane state
4040
"""
4141

42-
def __init__(self, seed: Optional[int] = None):
42+
def __init__(self, seed: Optional[int] = None, **kwargs):
4343
"""Initialize Blackjack MCP-Gym environment."""
4444
adapter = BlackjackAdapter()
45-
super().__init__("Blackjack-v1", adapter, seed)
45+
super().__init__("Blackjack-v1", adapter, seed, **kwargs)
4646

4747
# Multi-session support is now handled by the base class
4848

examples/cliff_walking_mcp/cliff_walking_mcp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ class CliffWalkingMcp(McpGym):
3838
- Multi-session support with session-based control plane state
3939
"""
4040

41-
def __init__(self, seed: Optional[int] = None):
41+
def __init__(self, seed: Optional[int] = None, **kwargs):
4242
"""Initialize Cliff Walking MCP-Gym environment."""
4343
adapter = CliffWalkingAdapter()
44-
super().__init__("CliffWalking-v1", adapter, seed)
44+
super().__init__("CliffWalking-v1", adapter, seed, **kwargs)
4545

4646
# Multi-session support is now handled by the base class
4747

0 commit comments

Comments
 (0)