Skip to content

Commit 13a8506

Browse files
committed
debug
1 parent 58eba6e commit 13a8506

File tree

11 files changed

+30
-341
lines changed

11 files changed

+30
-341
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,6 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None:
8888
"""
8989
cache_key = session.base_url
9090

91-
# Fast path: if cache already exists, return immediately (no lock)
92-
if cache_key in self._tools_cache:
93-
logger.debug(f"Tools cache already exists for {cache_key}")
94-
return
95-
96-
# Slow path: need to create cache (use lock only for creation)
9791
async with self._tools_cache_lock:
9892
# Only fetch tools if not already cached for this base_url
9993
if cache_key not in self._tools_cache:
@@ -123,7 +117,7 @@ async def reset_session(self, session: MCPSession) -> None:
123117
headers = {"mcp-session-id": session.session_id}
124118
body = {"seed": session.seed}
125119

126-
timeout = httpx.Timeout(3.0)
120+
timeout = httpx.Timeout(15.0)
127121
async with httpx.AsyncClient(timeout=timeout) as client:
128122
resp = await client.post(url, headers=headers, json=body)
129123
resp.raise_for_status()
@@ -145,23 +139,16 @@ async def discover_tools(self, session: MCPSession) -> List[Dict]:
145139

146140
cache_key = session.base_url
147141

148-
# Fast path: Check cache first without lock (safe for reads)
149-
if cache_key in self._tools_cache:
150-
cached_tools = self._tools_cache[cache_key]
151-
logger.debug(f"Using cached tools for session {session.session_id} ({len(cached_tools)} tools)")
152-
return cached_tools
153-
154-
# Slow path: Cache miss - use lock only for writing
142+
# Check cache first (should be pre-warmed during initialization)
155143
async with self._tools_cache_lock:
156-
# Double-check pattern: another task might have cached it while we waited
157144
if cache_key in self._tools_cache:
158145
cached_tools = self._tools_cache[cache_key]
159146
logger.debug(f"Using cached tools for session {session.session_id} ({len(cached_tools)} tools)")
160147
return cached_tools
161148

162-
# Fallback: if cache miss (shouldn't happen with pre-warming), fetch directly
163-
logger.warning(f"Cache miss for {cache_key} - this shouldn't happen with pre-warming")
164-
mcp_session = session._mcp_session
149+
# Fallback: if cache miss (shouldn't happen with pre-warming), fetch directly
150+
logger.warning(f"Cache miss for {cache_key} - this shouldn't happen with pre-warming")
151+
mcp_session = session._mcp_session
165152

166153
tools_response = await mcp_session.list_tools()
167154
tools = tools_response.tools if hasattr(tools_response, "tools") else []
@@ -213,7 +200,6 @@ async def get_initial_state(self, session: MCPSession) -> Any:
213200
logger.info(f"### 🌟 GET_INITIAL_STATE_START: timestamp: {method_start}, session_id: {session_id_short}...")
214201

215202
if not session._mcp_session:
216-
logger.error(f"### ❌ SESSION_NOT_INITIALIZED: session_id: {session_id_short}")
217203
raise RuntimeError("Session not initialized")
218204

219205
# Try to get initial state from control plane endpoint first

eval_protocol/mcp/execution/manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,13 @@ async def _execute_rollout(
288288
)
289289
user_content = user_message.content if user_message.content else ""
290290

291-
user_prompt = envs.format_user_prompt(rollout_idx, user_content)
292-
conversation_history.append({"role": "user", "content": user_prompt})
291+
user_prompt = envs.format_user_prompt(rollout_idx, user_content)
292+
conversation_history.append({"role": "user", "content": user_prompt})
293293

294-
# Check if user simulator signaled termination
295-
if UserSimulator.is_stop(user_message):
296-
trajectory.terminated = True
297-
trajectory.termination_reason = TerminationReason.USER_STOP
294+
# Check if user simulator signaled termination
295+
if UserSimulator.is_stop(user_message):
296+
trajectory.terminated = True
297+
trajectory.termination_reason = TerminationReason.USER_STOP
298298

299299
# In each turn: keep looping until assistant is ready to provide final response
300300
while not turn_completed and not trajectory.terminated:

eval_protocol/mcp/mcpgym.py

Lines changed: 16 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
from typing import Any, Callable, Dict, Optional, Tuple
2929

3030
import uvicorn
31-
32-
# from mcp.server.fastmcp import Context, FastMCP
33-
from fastmcp import Context, FastMCP
31+
from mcp.server.fastmcp import Context, FastMCP
3432
from pydantic import BaseModel
3533
from starlette.requests import Request
3634
from starlette.responses import JSONResponse
@@ -104,8 +102,11 @@ def __init__(
104102
self.adapter = adapter
105103

106104
# Create FastMCP server
107-
self.mcp = FastMCP(name=server_name)
108-
105+
self.mcp = FastMCP(
106+
server_name,
107+
host="0.0.0.0",
108+
port=int(os.environ.get("PORT", 8000)),
109+
)
109110
# Store host and port for later use in run() method
110111
self.host = "0.0.0.0"
111112
self.port = int(os.environ.get("PORT", 8000))
@@ -129,6 +130,7 @@ def __init__(
129130

130131
self.pool = ThreadPoolExecutor(max_workers=max_workers)
131132

133+
# Reset with seed if provided
132134
self.env, self.obs, _info = self._new_env(seed=seed)
133135

134136
# Register tools and control plane endpoints
@@ -220,7 +222,8 @@ async def reset_session_endpoint(request: Request) -> JSONResponse:
220222
if not session_id:
221223
return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400)
222224
if session_id in self.sessions:
223-
env, obs, _ = self._new_env(seed=seed)
225+
loop = asyncio.get_running_loop()
226+
env, obs, info = await loop.run_in_executor(self.pool, self._new_env, seed)
224227
with self.session_lock:
225228
self.sessions[session_id] = {
226229
"env": env,
@@ -269,17 +272,10 @@ async def endpoint_handler(request: Request) -> JSONResponse:
269272
{"error": f"Session {session_id} not found"},
270273
status_code=404,
271274
)
272-
start_time = time.time()
273-
logger.info(
274-
f"### 🔍 NEW_ENV_START: timestamp: {start_time}, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
275-
)
275+
276276
loop = asyncio.get_running_loop()
277277
env, obs, info = await loop.run_in_executor(self.pool, self._new_env, None)
278-
# env, obs, info = self._new_env(None)
279-
end_time = time.time()
280-
logger.info(
281-
f"### 🔍 NEW_ENV_END: timestamp: {end_time}, elapsed: {end_time - start_time:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
282-
)
278+
283279
# Initialize session state with extracted seed from session ID
284280
session_data = {
285281
"env": env,
@@ -294,6 +290,7 @@ async def endpoint_handler(request: Request) -> JSONResponse:
294290
result = await func(session_data=session_data)
295291
else:
296292
result = func(session_data=session_data)
293+
297294
return JSONResponse(result)
298295

299296
except Exception as e:
@@ -484,78 +481,26 @@ def get_info_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
484481
@control_plane_endpoint("/control/initial_state")
485482
async def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
486483
"""Get initial state for this session."""
487-
endpoint_start = time.time()
488484
session_id = session_data.get("session_id", "unknown")
489-
logger.info(
490-
f"### 🌟 ENDPOINT_START: get_initial_state_endpoint, timestamp: {endpoint_start}, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
491-
)
492-
493-
env_check_start = time.time()
494-
logger.info(
495-
f"### 🔍 ENV_CHECK_START: timestamp: {env_check_start}, elapsed: {env_check_start - endpoint_start:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
496-
)
497-
498485
env = session_data.get("env")
499486
obs = session_data.get("obs")
500-
501-
env_check_end = time.time()
502-
logger.info(
503-
f"### 🔍 ENV_CHECK_END: timestamp: {env_check_end}, elapsed: {env_check_end - endpoint_start:.6f}s, duration: {env_check_end - env_check_start:.6f}s, env: {env is not None}, obs: {obs is not None}, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
504-
)
505-
506487
if env and obs is not None:
507-
format_start = time.time()
508-
logger.info(
509-
f"### 🔄 FORMAT_OBS_START: timestamp: {format_start}, elapsed: {format_start - endpoint_start:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
510-
)
511-
512488
try:
513489
formatted_obs = self.format_observation(obs, env)
514-
515-
format_end = time.time()
516-
logger.info(
517-
f"### 🔄 FORMAT_OBS_END: timestamp: {format_end}, elapsed: {format_end - endpoint_start:.6f}s, duration: {format_end - format_start:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
518-
)
519-
520-
endpoint_end = time.time()
521-
logger.info(
522-
f"### ✅ ENDPOINT_SUCCESS_END: timestamp: {endpoint_end}, total_duration: {endpoint_end - endpoint_start:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
523-
)
524-
525490
return formatted_obs
526491
except Exception as e:
527-
error_time = time.time()
528-
logger.error(
529-
f"### ❌ FORMAT_OBS_ERROR: timestamp: {error_time}, elapsed: {error_time - endpoint_start:.6f}s, error: {str(e)}, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
530-
)
531-
492+
logger.error(f"❌ Error in format_observation: {e}")
532493
return {
533494
"error": f"Failed to format observation: {str(e)}",
534495
"observation_type": str(type(obs)),
535496
"session_id": session_data.get("session_id", "unknown"),
536497
}
537498
else:
538-
fallback_start = time.time()
539-
logger.info(
540-
f"### 🔄 FALLBACK_START: timestamp: {fallback_start}, elapsed: {fallback_start - endpoint_start:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
541-
)
542-
543499
# Fallback if session data is not available
544500
result = {
545501
"observation": "session_not_initialized",
546502
"session_id": session_data.get("session_id", "unknown"),
547503
}
548-
549-
fallback_end = time.time()
550-
logger.info(
551-
f"### 🔄 FALLBACK_END: timestamp: {fallback_end}, elapsed: {fallback_end - endpoint_start:.6f}s, duration: {fallback_end - fallback_start:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
552-
)
553-
554-
endpoint_end = time.time()
555-
logger.info(
556-
f"### ✅ ENDPOINT_FALLBACK_END: timestamp: {endpoint_end}, total_duration: {endpoint_end - endpoint_start:.6f}s, session_id: {session_id[:8] if len(session_id) > 8 else session_id}..."
557-
)
558-
559504
return result
560505

561506
def _get_session_control_plane_from_data(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -623,9 +568,9 @@ async def run_with_high_concurrency():
623568

624569
config = uvicorn.Config(
625570
starlette_app,
626-
host=self.host,
627-
port=self.port,
628-
log_level="info", # Use default log level instead of accessing settings
571+
host=self.mcp.settings.host,
572+
port=self.mcp.settings.port,
573+
log_level=self.mcp.settings.log_level.lower(), # Take the log level from the FastMCP settings, lowercased for uvicorn
629574
proxy_headers=True,
630575
forwarded_allow_ips="*",
631576
# HIGH CONCURRENCY SETTINGS

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,11 @@ async def default_mcp_gym_rollout_processor(
224224
)
225225

226226
# Create MCP environments directly from evaluation_rows
227-
print("DEBUG1", time.time())
228227
envs = await ep.make(
229228
"http://localhost:9700/mcp/",
230229
evaluation_rows=rows,
231230
model_id=policy.model_id,
232231
)
233-
print("DEBUG2", time.time())
234-
print("max_concurrent_rollouts", config.max_concurrent_rollouts)
235232

236233
# Run rollout with environments and policy
237234
evaluation_rows = await ep.rollout(

examples/frozen_lake_mcp/frozen_lake_mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
from typing import Any, Dict, Optional
2121

22-
from fastmcp import Context
2322
from frozen_lake_adapter import FrozenLakeAdapter
23+
from mcp.server.fastmcp import Context
2424

2525
from eval_protocol.mcp import McpGym
2626
from eval_protocol.mcp.mcpgym import control_plane_endpoint

examples/tau2_mcp/airplane_environment/airline_environment.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,9 @@ def __init__(self, config: Optional[Dict[str, Any]] = None):
3737
def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
3838
"""Reset the environment to initial state"""
3939
logger.info("🔄 Resetting airline environment - reloading database from disk")
40-
start_time = time.time()
4140
self.db = FlightDB.load(AIRLINE_DB_PATH)
4241
self.airline_tools = AirlineTools(self.db)
4342

44-
end_time = time.time()
45-
logger.info(f"11RESET TOOK {end_time - start_time:.2f} seconds, called at {start_time}")
46-
4743
return {}, {}
4844

4945
def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:

examples/tau2_mcp/tau2_mcp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
from typing import Annotated, Any, Dict, List, Optional
1313

1414
from airplane_environment.airline_environment import AirlineEnvironment
15-
16-
# from mcp.server.fastmcp import Context
17-
from fastmcp import Context
15+
from mcp.server.fastmcp import Context
1816
from mock_environment.mock_environment import MockEnvironment
1917
from pydantic import Field
2018
from retail_environment.retail_environment import RetailEnvironment

monitor_connections.sh

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)