Skip to content

Commit 073d99a

Browse files
committed
run on local to double check
1 parent 1165ff1 commit 073d99a

File tree

13 files changed

+269
-313
lines changed

13 files changed

+269
-313
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 93 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import hashlib
1010
import json
1111
import logging
12+
import time
1213
from contextlib import AsyncExitStack
1314
from typing import Any, Dict, List, Optional, Tuple
1415

@@ -27,9 +28,6 @@ class MCPConnectionManager:
2728
def __init__(self):
2829
self._tools_cache: Dict[str, List[Dict]] = {}
2930
self._tools_cache_lock = asyncio.Lock()
30-
# Shared HTTP client for control plane requests with high connection limits
31-
self._shared_client: Optional[httpx.AsyncClient] = None
32-
self._client_lock = asyncio.Lock()
3331

3432
async def initialize_session(self, session: MCPSession) -> None:
3533
"""
@@ -147,8 +145,6 @@ async def reset_session(self, session: MCPSession) -> None:
147145
"""
148146
Clean session data in remote mcp server for the given session
149147
"""
150-
import httpx
151-
152148
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
153149
url = f"{base_url}/control/reset_session"
154150

@@ -177,16 +173,23 @@ async def discover_tools(self, session: MCPSession) -> List[Dict]:
177173

178174
cache_key = session.base_url
179175

180-
# Check cache first (should be pre-warmed during initialization)
176+
# Fast path: Check cache first without lock (safe for reads)
177+
if cache_key in self._tools_cache:
178+
cached_tools = self._tools_cache[cache_key]
179+
logger.debug(f"Using cached tools for session {session.session_id} ({len(cached_tools)} tools)")
180+
return cached_tools
181+
182+
# Slow path: Cache miss - use lock only for writing
181183
async with self._tools_cache_lock:
184+
# Double-check pattern: another task might have cached it while we waited
182185
if cache_key in self._tools_cache:
183186
cached_tools = self._tools_cache[cache_key]
184187
logger.debug(f"Using cached tools for session {session.session_id} ({len(cached_tools)} tools)")
185188
return cached_tools
186189

187-
# Fallback: if cache miss (shouldn't happen with pre-warming), fetch directly
188-
logger.warning(f"Cache miss for {cache_key} - this shouldn't happen with pre-warming")
189-
mcp_session = session._mcp_session
190+
# Fallback: if cache miss (shouldn't happen with pre-warming), fetch directly
191+
logger.warning(f"Cache miss for {cache_key} - this shouldn't happen with pre-warming")
192+
mcp_session = session._mcp_session
190193

191194
tools_response = await mcp_session.list_tools()
192195
tools = tools_response.tools if hasattr(tools_response, "tools") else []
@@ -233,74 +236,118 @@ async def get_initial_state(self, session: MCPSession) -> Any:
233236
Returns:
234237
Initial observation/state
235238
"""
239+
method_start = time.time()
240+
session_id_short = session.session_id[:8] if len(session.session_id) > 8 else session.session_id
241+
logger.info(f"### 🌟 GET_INITIAL_STATE_START: timestamp: {method_start}, session_id: {session_id_short}...")
242+
236243
if not session._mcp_session:
244+
logger.error(f"### ❌ SESSION_NOT_INITIALIZED: session_id: {session_id_short}")
237245
raise RuntimeError("Session not initialized")
238246

239247
# Try to get initial state from control plane endpoint first
240248
initial_observation = None
241249

242250
try:
243-
import httpx
244-
245251
# Extract base URL and session ID from the MCP session
252+
url_extract_start = time.time()
253+
logger.info(
254+
f"### 🔍 URL_EXTRACT_START: timestamp: {url_extract_start}, elapsed: {url_extract_start - method_start:.6f}s, session_id: {session_id_short}..."
255+
)
256+
246257
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
247258
session_id = session.session_id
248259

260+
url_extract_end = time.time()
261+
logger.info(
262+
f"### 🔍 URL_EXTRACT_END: timestamp: {url_extract_end}, elapsed: {url_extract_end - method_start:.6f}s, duration: {url_extract_end - url_extract_start:.6f}s, base_url: {base_url}, session_id: {session_id_short}..."
263+
)
264+
249265
if session_id:
266+
headers_start = time.time()
267+
logger.info(
268+
f"### 🔍 HEADERS_CREATE_START: timestamp: {headers_start}, elapsed: {headers_start - method_start:.6f}s, session_id: {session_id_short}..."
269+
)
270+
250271
headers = {"mcp-session-id": session_id}
251272

273+
headers_end = time.time()
274+
logger.info(
275+
f"### 🔍 HEADERS_CREATE_END: timestamp: {headers_end}, elapsed: {headers_end - method_start:.6f}s, duration: {headers_end - headers_start:.6f}s, session_id: {session_id_short}..."
276+
)
277+
252278
# Query initial state endpoint
253279
try:
280+
timeout_start = time.time()
281+
logger.info(
282+
f"### 🔍 TIMEOUT_CONFIG_START: timestamp: {timeout_start}, elapsed: {timeout_start - method_start:.6f}s, session_id: {session_id_short}..."
283+
)
284+
254285
# Use shorter timeout for playback mode, longer timeout for high-concurrency initialization
255286
# (50+ concurrent sessions need more time for initial state setup)
256287
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
257288

258-
# TIMING: Get shared client
259-
client_start = __import__("time").time()
260-
client = await self._get_shared_client(timeout)
261-
client_time = __import__("time").time() - client_start
289+
timeout_end = time.time()
262290
logger.info(
263-
f"DEBUG_CLIENT: Getting shared client took {client_time:.3f}s for {session.session_id}"
291+
f"### 🔍 TIMEOUT_CONFIG_END: timestamp: {timeout_end}, elapsed: {timeout_end - method_start:.6f}s, duration: {timeout_end - timeout_start:.6f}s, timeout: {timeout}s, session_id: {session_id_short}..."
264292
)
265293

294+
# TIMING: Get shared client
295+
# client = await self._get_shared_client(timeout)
296+
266297
# TIMING: HTTP request with shared client
267-
request_start = __import__("time").time()
268-
initial_state_response = await client.get(
269-
f"{base_url}/control/initial_state",
270-
headers=headers,
271-
timeout=timeout,
298+
request_start = time.time()
299+
logger.info(
300+
f"### 🌐 HTTP_REQUEST_START: timestamp: {request_start}, elapsed: {request_start - method_start:.6f}s, url: {base_url}/control/initial_state, session_id: {session_id_short}..."
272301
)
273-
request_time = __import__("time").time() - request_start
274-
logger.info(f"DEBUG_REQUEST: HTTP request took {request_time:.3f}s for {session.session_id}")
275-
if initial_state_response.status_code == 200:
276-
initial_observation = initial_state_response.json()
277-
logger.info(
278-
f"Session {session.session_id}: ✅ Successfully fetched session-aware initial state from control plane endpoint"
302+
303+
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
304+
305+
async with httpx.AsyncClient(timeout=timeout) as client:
306+
initial_state_response = await client.get(
307+
f"{base_url}/control/initial_state",
308+
headers=headers,
309+
timeout=timeout,
279310
)
280-
else:
281-
logger.warning(
282-
f"Control plane initial state endpoint returned {initial_state_response.status_code}"
311+
request_time = time.time() - request_start
312+
313+
request_end = time.time()
314+
logger.info(
315+
f"### 🌐 HTTP_REQUEST_END: timestamp: {request_end}, elapsed: {request_end - method_start:.6f}s, duration: {request_time:.6f}s, status_code: {initial_state_response.status_code}, session_id: {session_id_short}..."
283316
)
317+
318+
if initial_state_response.status_code == 200:
319+
initial_observation = initial_state_response.json()
320+
success_end = time.time()
321+
logger.info(
322+
f"### ✅ RETURN: timestamp: {success_end}, total_duration: {success_end - method_start:.6f}s, session_id: {session_id_short}..."
323+
)
324+
# return initial_observation
325+
else:
326+
error_time = time.time()
327+
logger.warning(
328+
f"### ⚠️ HTTP_ERROR_RESPONSE: timestamp: {error_time}, elapsed: {error_time - method_start:.6f}s, status_code: {initial_state_response.status_code}, session_id: {session_id_short}"
329+
)
284330
except httpx.TimeoutException:
285-
logger.warning(f"Control plane initial state endpoint timed out after {timeout}s")
331+
timeout_error_time = time.time()
332+
logger.warning(
333+
f"### ⏰ HTTP_TIMEOUT: timestamp: {timeout_error_time}, elapsed: {timeout_error_time - method_start:.6f}s, timeout: {timeout}s, session_id: {session_id_short}"
334+
)
286335
except Exception as e:
287-
logger.warning(f"Failed to query initial state endpoint: {e}")
336+
http_error_time = time.time()
337+
logger.warning(
338+
f"### ❌ HTTP_ERROR: timestamp: {http_error_time}, elapsed: {http_error_time - method_start:.6f}s, error: {str(e)}, session_id: {session_id_short}"
339+
)
288340

289341
except Exception as e:
290-
logger.warning(f"Failed to query control plane initial state endpoint: {e}")
291-
292-
# Fallback to MCP resource if control plane endpoint fails (backward compatibility)
293-
if initial_observation is None:
294-
logger.debug(f"Session {session.session_id}: Falling back to MCP resource for initial state")
295-
initial_observation = await self._get_initial_state_from_mcp_resource(session)
296-
297-
# Ensure we have some observation
298-
if initial_observation is None:
299-
logger.debug(f"Session {session.session_id}: Using default initial state")
300-
initial_observation = {
301-
"observation": "default_initial_state",
302-
"session_id": session.session_id,
303-
}
342+
general_error_time = time.time()
343+
logger.warning(
344+
f"### ❌ GENERAL_ERROR: timestamp: {general_error_time}, elapsed: {general_error_time - method_start:.6f}s, error: {str(e)}, session_id: {session_id_short}"
345+
)
346+
347+
method_end = time.time()
348+
logger.info(
349+
f"### 🔴 GET_INITIAL_STATE_END: timestamp: {method_end}, total_duration: {method_end - method_start:.6f}s, session_id: {session_id_short}..."
350+
)
304351

305352
return initial_observation
306353

@@ -509,9 +556,6 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
509556
control_plane_info = {}
510557

511558
try:
512-
# Query control plane endpoints following the new architecture
513-
import httpx
514-
515559
# Extract base URL and session ID from the MCP session
516560
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
517561
# Use the session ID from the established MCP session
@@ -601,47 +645,3 @@ async def close_session(self, session: MCPSession) -> None:
601645
finally:
602646
session._exit_stack = None
603647
session._mcp_session = None
604-
605-
async def _get_shared_client(self, timeout: float) -> httpx.AsyncClient:
606-
"""
607-
Get or create a shared HTTP client with high connection limits for concurrent requests.
608-
609-
Args:
610-
timeout: Timeout for requests
611-
612-
Returns:
613-
Shared httpx.AsyncClient instance
614-
"""
615-
# Fast path: if client exists and is not closed, return it immediately
616-
if self._shared_client is not None and not self._shared_client.is_closed:
617-
return self._shared_client
618-
619-
# Slow path: need to create client (use lock only for creation)
620-
async with self._client_lock:
621-
# Double-check pattern: another task might have created it while we waited
622-
if self._shared_client is None or self._shared_client.is_closed:
623-
# Create HTTP client with high connection limits for concurrent initial state requests
624-
limits = httpx.Limits(
625-
max_keepalive_connections=None, # Unlimited keep-alive connections
626-
max_connections=None, # Unlimited total connection pool size
627-
keepalive_expiry=30.0, # Keep connections alive for 30s
628-
)
629-
630-
self._shared_client = httpx.AsyncClient(
631-
timeout=timeout,
632-
limits=limits,
633-
# Enable connection pooling and keep-alive
634-
http2=False, # Disable HTTP/2 for better connection pooling with many concurrent requests
635-
)
636-
logger.info(
637-
"Created shared HTTP client with unlimited connection limits for MCP control plane requests"
638-
)
639-
640-
return self._shared_client
641-
642-
async def close_shared_client(self):
643-
"""Close the shared HTTP client when shutting down."""
644-
async with self._client_lock:
645-
if self._shared_client and not self._shared_client.is_closed:
646-
await self._shared_client.aclose()
647-
self._shared_client = None

eval_protocol/mcp/execution/manager.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,7 @@ async def _execute_rollout(
191191
dataset_row = envs.dataset_rows[rollout_idx]
192192
rollout_start = time.time()
193193
elapsed_from_main_start = rollout_start - start_time
194-
logger.info(
195-
f"DEBUG4. Starting rollout {dataset_row.id} at {datetime.fromtimestamp(rollout_start).strftime('%H:%M:%S.%f')[:-3]} (+{elapsed_from_main_start:.3f}s from start)"
196-
)
194+
logger.info(f"DEBUG4. Starting rollout {dataset_row.id} at {rollout_start}")
197195

198196
# Initialize trajectory
199197
trajectory = Trajectory(
@@ -219,7 +217,7 @@ async def _execute_rollout(
219217
temp_start = time.time()
220218
current_observation, tool_schema = await envs.reset(session)
221219
logger.info(
222-
f"DEBUG6: User simulator get_init_state took {time.time() - temp_start:.3f}s for {session.session_id}"
220+
f"DEBUG6: User simulator get_init_state took {time.time() - temp_start:.3f}s for {session.session_id}, started at {temp_start}"
223221
)
224222
system_prompt = dataset_row.system_prompt
225223

@@ -240,7 +238,7 @@ async def _execute_rollout(
240238

241239
# Get initial messages in tau2-bench format for user simulator
242240
user_simulator_state = user_simulator.get_init_state()
243-
user_message, user_simulator_state = user_simulator.generate_next_message(
241+
user_message, user_simulator_state = await user_simulator.generate_next_message(
244242
AssistantMessage(role="assistant", content="Hi! How can I help you today?"),
245243
user_simulator_state,
246244
)
@@ -277,11 +275,11 @@ async def _execute_rollout(
277275
if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
278276
# Generate user response using the simulator
279277
temp_start1 = time.time()
280-
user_message, user_simulator_state = user_simulator.generate_next_message(
278+
user_message, user_simulator_state = await user_simulator.generate_next_message(
281279
user_simulator_messages[-1], user_simulator_state
282280
)
283281
logger.info(
284-
f"DEBUG8: User simulator generate_next_message took {time.time() - temp_start1:.3f}s for {dataset_row.id}"
282+
f"DEBUG8: User simulator generate_next_message took {time.time() - temp_start1:.3f}s for {dataset_row.id}, started at {temp_start1}"
285283
)
286284
user_content = user_message.content if user_message.content else ""
287285

@@ -297,7 +295,9 @@ async def _execute_rollout(
297295
while not turn_completed and not trajectory.terminated:
298296
temp_start2 = time.time()
299297
tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
300-
logger.info(f"DEBUG9: Policy took {time.time() - temp_start2:.3f}s for {dataset_row.id}")
298+
logger.info(
299+
f"DEBUG9: Policy took {time.time() - temp_start2:.3f}s for {dataset_row.id}, started at {temp_start2}"
300+
)
301301

302302
# If no tool call is generated, turn is finished
303303
if len(tool_calls) == 1:
@@ -316,7 +316,9 @@ async def _execute_rollout(
316316
# Execute tool call for this environment
317317
temp_start3 = time.time()
318318
observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call)
319-
logger.info(f"DEBUG10: Env step took {time.time() - temp_start3:.3f}s for {dataset_row.id}")
319+
logger.info(
320+
f"DEBUG10: Env step took {time.time() - temp_start3:.3f}s for {dataset_row.id}, started at {temp_start3}"
321+
)
320322

321323
tool_response = envs.format_tool_response(observation)
322324

@@ -464,9 +466,7 @@ async def _execute_rollout(
464466
logger.info(
465467
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
466468
)
467-
logger.info(
468-
f"DEBUG11: Rollout {dataset_row.id} completed at {datetime.fromtimestamp(time.time()).strftime('%H:%M:%S.%f')[:-3]} (+{time.time() - rollout_start:.3f}s from start)"
469-
)
469+
logger.info(f"DEBUG11: Rollout {dataset_row.id} completed at {time.time()}, started at {rollout_start}")
470470
return trajectory
471471

472472
async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]:

0 commit comments

Comments
 (0)