Skip to content

Commit ffee4e0

Browse files
author
Dylan Huang
committed
Merge branch 'main' into aggregated-metrics-part-2
# Conflicts: # vite-app/src/components/EvaluationTable.tsx
2 parents b793de7 + 33c52a8 commit ffee4e0

34 files changed

+309
-351
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
import hashlib
1010
import json
1111
import logging
12+
import time
1213
from contextlib import AsyncExitStack
1314
from typing import Any, Dict, List, Optional, Tuple
1415

16+
import httpx
1517
from mcp.client.session import ClientSession
1618
from mcp.client.streamable_http import streamablehttp_client
19+
from mcp.types import Implementation
1720

1821
from ...types import MCPSession
19-
from mcp.types import Implementation
2022

2123
logger = logging.getLogger(__name__)
2224

@@ -109,15 +111,13 @@ async def reset_session(self, session: MCPSession) -> None:
109111
"""
110112
Clean up the session data on the remote MCP server for the given session
111113
"""
112-
import httpx
113-
114114
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
115115
url = f"{base_url}/control/reset_session"
116116

117117
headers = {"mcp-session-id": session.session_id}
118118
body = {"seed": session.seed}
119119

120-
timeout = httpx.Timeout(3.0)
120+
timeout = httpx.Timeout(15.0)
121121
async with httpx.AsyncClient(timeout=timeout) as client:
122122
resp = await client.post(url, headers=headers, json=body)
123123
resp.raise_for_status()
@@ -202,8 +202,6 @@ async def get_initial_state(self, session: MCPSession) -> Any:
202202
initial_observation = None
203203

204204
try:
205-
import httpx
206-
207205
# Extract base URL and session ID from the MCP session
208206
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
209207
session_id = session.session_id
@@ -459,9 +457,6 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
459457
control_plane_info = {}
460458

461459
try:
462-
# Query control plane endpoints following the new architecture
463-
import httpx
464-
465460
# Extract base URL and session ID from the MCP session
466461
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
467462
# Use the session ID from the established MCP session
@@ -544,10 +539,10 @@ async def close_session(self, session: MCPSession) -> None:
544539
await session._exit_stack.aclose()
545540
except asyncio.CancelledError:
546541
# Handle cancellation gracefully (especially important for Python 3.12)
547-
logger.debug(f"Session {session.session_id} close was cancelled")
542+
logger.error(f"Session {session.session_id} close was cancelled")
548543
except Exception as e:
549544
# We hit this error here, probably because of the use of threads: "Attempted to exit cancel scope in a different task than it was entered in"
550-
logger.debug(f"Error closing session {session.session_id}: {e}")
545+
logger.error(f"Error closing session {session.session_id}: {e}")
551546
finally:
552547
session._exit_stack = None
553548
session._mcp_session = None

eval_protocol/mcp/execution/base_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
220220
return mcp_tool_calls, usage_stats
221221
else:
222222
# No tool calls in response - this is normal when episode ends or LLM provides only text
223-
logger.info(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
223+
logger.debug(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
224224
return [
225225
MCPToolCall(
226226
tool_name="_no_tool_call",

eval_protocol/mcp/execution/manager.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
import os
1212
import threading
1313
import time
14-
from concurrent.futures import ThreadPoolExecutor, as_completed
15-
from dataclasses import asdict, dataclass
14+
from dataclasses import asdict
1615
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
1716

1817
from openai.types import CompletionUsage
@@ -98,10 +97,12 @@ async def execute_rollouts(
9897

9998
async def _execute_with_semaphore(idx):
10099
async with semaphore:
101-
return await self._execute_rollout(
100+
result = await self._execute_rollout(
102101
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
103102
)
104103

104+
return result
105+
105106
tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
106107
# exceptions are caught inside each individual _execute_rollout call
107108
trajectories = await asyncio.gather(*tasks)
@@ -113,9 +114,6 @@ async def _execute_with_semaphore(idx):
113114

114115
shared_tool_schema = envs.tool_schemas
115116

116-
# Clean up
117-
await envs.close()
118-
119117
# Enhanced reporting with control plane info
120118
successful = sum(1 for traj in trajectories if traj.total_reward > 0)
121119
terminated_by_control_plane = sum(
@@ -176,8 +174,11 @@ async def _execute_with_semaphore(idx):
176174
TerminationReason.USER_STOP,
177175
}:
178176
evaluation_rows[idx].rollout_status.status = "finished"
179-
elif trajectory.termination_reason == TerminationReason.MAX_STEPS:
177+
elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}:
180178
evaluation_rows[idx].rollout_status.status = "stopped"
179+
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
180+
"termination_reason", trajectory.termination_reason
181+
)
181182
else:
182183
evaluation_rows[idx].rollout_status.status = "error"
183184
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
@@ -227,6 +228,7 @@ async def _execute_rollout(
227228
"total_tokens": 0,
228229
},
229230
)
231+
failure_reason = None
230232
try:
231233
current_observation, tool_schema = await envs.reset(session)
232234
system_prompt = dataset_row.system_prompt
@@ -248,7 +250,7 @@ async def _execute_rollout(
248250

249251
# Get initial messages in tau2-bench format for user simulator
250252
user_simulator_state = user_simulator.get_init_state()
251-
user_message, user_simulator_state = user_simulator.generate_next_message(
253+
user_message, user_simulator_state = await user_simulator.generate_next_message(
252254
AssistantMessage(role="assistant", content="Hi! How can I help you today?"),
253255
user_simulator_state,
254256
)
@@ -280,7 +282,7 @@ async def _execute_rollout(
280282
# Last message was from the agent; simulate the user response
281283
if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
282284
# Generate user response using the simulator
283-
user_message, user_simulator_state = user_simulator.generate_next_message(
285+
user_message, user_simulator_state = await user_simulator.generate_next_message(
284286
user_simulator_messages[-1], user_simulator_state
285287
)
286288
user_content = user_message.content if user_message.content else ""
@@ -312,8 +314,7 @@ async def _execute_rollout(
312314
# If there's no user simulator, no tool call means policy failed and we should terminate the rollout
313315
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
314316
trajectory.terminated = True
315-
trajectory.termination_reason = TerminationReason.ERROR
316-
trajectory.control_plane_summary.update({"error_message": "No expected tool call"})
317+
trajectory.termination_reason = TerminationReason.INTERRUPTED
317318
break
318319

319320
# Execute each tool call sequentially
@@ -467,11 +468,26 @@ async def _execute_rollout(
467468
logger.info(
468469
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
469470
)
471+
472+
except asyncio.CancelledError:
473+
logger.error(f"🚨 AsyncIO Cancel Error in roll out {rollout_idx}", exc_info=True)
474+
failure_reason = "asyncio context cancelled"
470475
except Exception as e:
471476
logger.error(f"🚨 Error in rollout {rollout_idx}: {e}", exc_info=True)
472-
trajectory.terminated = True
473-
trajectory.termination_reason = TerminationReason.ERROR
474-
trajectory.control_plane_summary.update({"error_message": str(e)})
477+
failure_reason = str(e)
478+
finally:
479+
if failure_reason:
480+
trajectory.terminated = True
481+
trajectory.termination_reason = TerminationReason.ERROR
482+
trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"})
483+
try:
484+
await envs.connection_manager.reset_session(session)
485+
except:
486+
logger.error(f"Error resetting session {session.session_id}")
487+
try:
488+
await envs.connection_manager.close_session(session)
489+
except:
490+
logger.error(f"Error closing session {session.session_id}")
475491
return trajectory
476492

477493
async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]:

0 commit comments

Comments
 (0)