Skip to content

Commit 28932e4

Browse files
committed
cleanup
1 parent 13a8506 commit 28932e4

5 files changed

Lines changed: 71 additions & 131 deletions

File tree

eval_protocol/mcp/client/connection.py

Lines changed: 18 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -195,10 +195,6 @@ async def get_initial_state(self, session: MCPSession) -> Any:
195195
Returns:
196196
Initial observation/state
197197
"""
198-
method_start = time.time()
199-
session_id_short = session.session_id[:8] if len(session.session_id) > 8 else session.session_id
200-
logger.info(f"### 🌟 GET_INITIAL_STATE_START: timestamp: {method_start}, session_id: {session_id_short}...")
201-
202198
if not session._mcp_session:
203199
raise RuntimeError("Session not initialized")
204200

@@ -207,105 +203,53 @@ async def get_initial_state(self, session: MCPSession) -> Any:
207203

208204
try:
209205
# Extract base URL and session ID from the MCP session
210-
url_extract_start = time.time()
211-
logger.info(
212-
f"### 🔍 URL_EXTRACT_START: timestamp: {url_extract_start}, elapsed: {url_extract_start - method_start:.6f}s, session_id: {session_id_short}..."
213-
)
214-
215206
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
216207
session_id = session.session_id
217208

218-
url_extract_end = time.time()
219-
logger.info(
220-
f"### 🔍 URL_EXTRACT_END: timestamp: {url_extract_end}, elapsed: {url_extract_end - method_start:.6f}s, duration: {url_extract_end - url_extract_start:.6f}s, base_url: {base_url}, session_id: {session_id_short}..."
221-
)
222-
223209
if session_id:
224-
headers_start = time.time()
225-
logger.info(
226-
f"### 🔍 HEADERS_CREATE_START: timestamp: {headers_start}, elapsed: {headers_start - method_start:.6f}s, session_id: {session_id_short}..."
227-
)
228-
229210
headers = {"mcp-session-id": session_id}
230211

231-
headers_end = time.time()
232-
logger.info(
233-
f"### 🔍 HEADERS_CREATE_END: timestamp: {headers_end}, elapsed: {headers_end - method_start:.6f}s, duration: {headers_end - headers_start:.6f}s, session_id: {session_id_short}..."
234-
)
235-
236212
# Query initial state endpoint
237213
try:
238-
timeout_start = time.time()
239-
logger.info(
240-
f"### 🔍 TIMEOUT_CONFIG_START: timestamp: {timeout_start}, elapsed: {timeout_start - method_start:.6f}s, session_id: {session_id_short}..."
241-
)
242-
243214
# Use shorter timeout for playback mode, longer timeout for high-concurrency initialization
244215
# (50+ concurrent sessions need more time for initial state setup)
245216
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
246217

247-
timeout_end = time.time()
248-
logger.info(
249-
f"### 🔍 TIMEOUT_CONFIG_END: timestamp: {timeout_end}, elapsed: {timeout_end - method_start:.6f}s, duration: {timeout_end - timeout_start:.6f}s, timeout: {timeout}s, session_id: {session_id_short}..."
250-
)
251-
252-
# TIMING: Get shared client
253-
# client = await self._get_shared_client(timeout)
254-
255-
# TIMING: HTTP request with shared client
256-
request_start = time.time()
257-
logger.info(
258-
f"### 🌐 HTTP_REQUEST_START: timestamp: {request_start}, elapsed: {request_start - method_start:.6f}s, url: {base_url}/control/initial_state, session_id: {session_id_short}..."
259-
)
260-
261-
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
262-
263218
async with httpx.AsyncClient(timeout=timeout) as client:
264219
initial_state_response = await client.get(
265220
f"{base_url}/control/initial_state",
266221
headers=headers,
267222
timeout=timeout,
268223
)
269-
request_time = time.time() - request_start
270-
271-
request_end = time.time()
272-
logger.info(
273-
f"### 🌐 HTTP_REQUEST_END: timestamp: {request_end}, elapsed: {request_end - method_start:.6f}s, duration: {request_time:.6f}s, status_code: {initial_state_response.status_code}, session_id: {session_id_short}..."
274-
)
275224

276225
if initial_state_response.status_code == 200:
277226
initial_observation = initial_state_response.json()
278-
success_end = time.time()
279227
logger.info(
280-
f"### ✅ RETURN: timestamp: {success_end}, total_duration: {success_end - method_start:.6f}s, session_id: {session_id_short}..."
228+
f"Session {session.session_id}: ✅ Successfully fetched session-aware initial state from control plane endpoint"
281229
)
282-
# return initial_observation
283230
else:
284-
error_time = time.time()
285231
logger.warning(
286-
f"### ⚠️ HTTP_ERROR_RESPONSE: timestamp: {error_time}, elapsed: {error_time - method_start:.6f}s, status_code: {initial_state_response.status_code}, session_id: {session_id_short}"
232+
f"Control plane initial state endpoint returned {initial_state_response.status_code}"
287233
)
288234
except httpx.TimeoutException:
289-
timeout_error_time = time.time()
290-
logger.warning(
291-
f"### ⏰ HTTP_TIMEOUT: timestamp: {timeout_error_time}, elapsed: {timeout_error_time - method_start:.6f}s, timeout: {timeout}s, session_id: {session_id_short}"
292-
)
235+
logger.warning(f"Control plane initial state endpoint timed out after {timeout}s")
293236
except Exception as e:
294-
http_error_time = time.time()
295-
logger.warning(
296-
f"### ❌ HTTP_ERROR: timestamp: {http_error_time}, elapsed: {http_error_time - method_start:.6f}s, error: {str(e)}, session_id: {session_id_short}"
297-
)
298-
237+
logger.warning(f"Failed to query control plane initial state endpoint: {e}")
299238
except Exception as e:
300-
general_error_time = time.time()
301-
logger.warning(
302-
f"### ❌ GENERAL_ERROR: timestamp: {general_error_time}, elapsed: {general_error_time - method_start:.6f}s, error: {str(e)}, session_id: {session_id_short}"
303-
)
304-
305-
method_end = time.time()
306-
logger.info(
307-
f"### 🔴 GET_INITIAL_STATE_END: timestamp: {method_end}, total_duration: {method_end - method_start:.6f}s, session_id: {session_id_short}..."
308-
)
239+
logger.warning(f"Failed to query control plane initial state endpoint: {e}")
240+
241+
# Fallback to MCP resource if control plane endpoint fails (backward compatibility)
242+
if initial_observation is None:
243+
logger.debug(f"Session {session.session_id}: Falling back to MCP resource for initial state")
244+
initial_observation = await self._get_initial_state_from_mcp_resource(session)
245+
246+
# Ensure we have some observation
247+
if initial_observation is None:
248+
logger.debug(f"Session {session.session_id}: Using default initial state")
249+
initial_observation = {
250+
"observation": "default_initial_state",
251+
"session_id": session.session_id,
252+
}
309253

310254
return initial_observation
311255

eval_protocol/mcp/execution/manager.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -207,9 +207,6 @@ async def _execute_rollout(
207207
"""
208208
session = envs.sessions[rollout_idx]
209209
dataset_row = envs.dataset_rows[rollout_idx]
210-
rollout_start = time.time()
211-
elapsed_from_main_start = rollout_start - start_time
212-
logger.info(f"DEBUG4. Starting rollout {dataset_row.id} at {rollout_start}")
213210

214211
# Initialize trajectory
215212
trajectory = Trajectory(

eval_protocol/mcp/execution/policy.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from .base_policy import LLMBasePolicy
2020

2121
logger = logging.getLogger(__name__)
22-
litellm._turn_on_debug()
2322

2423

2524
class LiteLLMPolicy(LLMBasePolicy):

eval_protocol/mcp/mcpgym.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -107,9 +107,6 @@ def __init__(
107107
host="0.0.0.0",
108108
port=int(os.environ.get("PORT", 8000)),
109109
)
110-
# Store host and port for later use in run() method
111-
self.host = "0.0.0.0"
112-
self.port = int(os.environ.get("PORT", 8000))
113110

114111
# Multi-session support
115112
self.sessions = {} # session_id -> {"env": env, "obs": obs, "session_data": data}
@@ -497,11 +494,10 @@ async def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict
497494
}
498495
else:
499496
# Fallback if session data is not available
500-
result = {
497+
return {
501498
"observation": "session_not_initialized",
502499
"session_id": session_data.get("session_id", "unknown"),
503500
}
504-
return result
505501

506502
def _get_session_control_plane_from_data(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
507503
"""Extract control plane state from session data."""
@@ -570,7 +566,7 @@ async def run_with_high_concurrency():
570566
starlette_app,
571567
host=self.mcp.settings.host,
572568
port=self.mcp.settings.port,
573-
log_level=self.mcp.settings.log_level.lower(), # Use default log level instead of accessing settings
569+
log_level=self.mcp.settings.log_level.lower(),
574570
proxy_headers=True,
575571
forwarded_allow_ips="*",
576572
# HIGH CONCURRENCY SETTINGS

tests/pytest/test_tau_bench_airline.py

Lines changed: 51 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
5858

5959
rows.append(eval_row)
6060

61-
return rows
61+
return rows[0:3]
6262

6363

6464
@evaluation_test(
@@ -139,23 +139,27 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
139139
id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler")
140140
) # id and user_scenario are required for the Task type but not used in calculating reward
141141

142-
env_reward_info = EnvironmentEvaluator.calculate_reward(
143-
environment_constructor=registry.get_env_constructor("airline"),
144-
task=task,
145-
full_trajectory=trajectory_objects,
146-
)
147-
# action_reward_info = ActionEvaluator.calculate_reward(
148-
# task=task,
149-
# full_trajectory=trajectory_objects,
150-
# )
151-
communicate_reward_info = CommunicateEvaluator.calculate_reward(
152-
task=task,
153-
full_trajectory=trajectory_objects,
154-
)
155-
# nl_reward_info = NLAssertionsEvaluator.calculate_reward(
156-
# task=task,
157-
# full_trajectory=trajectory_objects,
158-
# )
142+
if RewardType.DB in task.evaluation_criteria.reward_basis:
143+
env_reward_info = EnvironmentEvaluator.calculate_reward(
144+
environment_constructor=registry.get_env_constructor("airline"),
145+
task=task,
146+
full_trajectory=trajectory_objects,
147+
)
148+
if RewardType.ACTION in task.evaluation_criteria.reward_basis:
149+
action_reward_info = ActionEvaluator.calculate_reward(
150+
task=task,
151+
full_trajectory=trajectory_objects,
152+
)
153+
if RewardType.COMMUNICATE in task.evaluation_criteria.reward_basis:
154+
communicate_reward_info = CommunicateEvaluator.calculate_reward(
155+
task=task,
156+
full_trajectory=trajectory_objects,
157+
)
158+
if RewardType.NL_ASSERTION in task.evaluation_criteria.reward_basis:
159+
nl_reward_info = NLAssertionsEvaluator.calculate_reward(
160+
task=task,
161+
full_trajectory=trajectory_objects,
162+
)
159163

160164
reward = 1.0
161165
env_bases = {RewardType.DB, RewardType.ENV_ASSERTION}
@@ -169,14 +173,14 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
169173
if env_reward_info.reward_breakdown is not None:
170174
reward_breakdown.update(env_reward_info.reward_breakdown)
171175
reward *= env_reward_info.reward
172-
# if task_reward_basis & action_bases:
173-
# if action_reward_info.reward_breakdown is not None:
174-
# reward_breakdown.update(action_reward_info.reward_breakdown)
175-
# reward *= action_reward_info.reward
176-
# if task_reward_basis & nl_bases:
177-
# if nl_reward_info.reward_breakdown is not None:
178-
# reward_breakdown.update(nl_reward_info.reward_breakdown)
179-
# reward *= nl_reward_info.reward
176+
if task_reward_basis & action_bases:
177+
if action_reward_info.reward_breakdown is not None:
178+
reward_breakdown.update(action_reward_info.reward_breakdown)
179+
reward *= action_reward_info.reward
180+
if task_reward_basis & nl_bases:
181+
if nl_reward_info.reward_breakdown is not None:
182+
reward_breakdown.update(nl_reward_info.reward_breakdown)
183+
reward *= nl_reward_info.reward
180184
if task_reward_basis & comm_bases:
181185
if communicate_reward_info.reward_breakdown is not None:
182186
reward_breakdown.update(communicate_reward_info.reward_breakdown)
@@ -188,27 +192,27 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
188192
if task_reward_basis & env_bases and env_reward_info.reward == 0:
189193
failed_reasons.append("❌ Environment/DB check failed")
190194

191-
# if task_reward_basis & action_bases and action_reward_info.reward == 0:
192-
# failed_actions = []
193-
# if hasattr(action_reward_info, "action_checks") and action_reward_info.action_checks:
194-
# failed_actions = [
195-
# f"{ac.action.name}({ac.action.arguments})"
196-
# for ac in action_reward_info.action_checks
197-
# if not ac.action_match
198-
# ]
199-
# if failed_actions:
200-
# failed_reasons.append(f"❌ Failed actions: {failed_actions}")
201-
# else:
202-
# failed_reasons.append("❌ Actions failed")
203-
204-
# if task_reward_basis & nl_bases and nl_reward_info.reward == 0:
205-
# failed_nl = []
206-
# if hasattr(nl_reward_info, "nl_assertions") and nl_reward_info.nl_assertions:
207-
# failed_nl = [nla.nl_assertion for nla in nl_reward_info.nl_assertions if not nla.met]
208-
# if failed_nl:
209-
# failed_reasons.append(f"❌ Failed NL assertions: {failed_nl}")
210-
# else:
211-
# failed_reasons.append("❌ NL Assertions failed")
195+
if task_reward_basis & action_bases and action_reward_info.reward == 0:
196+
failed_actions = []
197+
if hasattr(action_reward_info, "action_checks") and action_reward_info.action_checks:
198+
failed_actions = [
199+
f"{ac.action.name}({ac.action.arguments})"
200+
for ac in action_reward_info.action_checks
201+
if not ac.action_match
202+
]
203+
if failed_actions:
204+
failed_reasons.append(f"❌ Failed actions: {failed_actions}")
205+
else:
206+
failed_reasons.append("❌ Actions failed")
207+
208+
if task_reward_basis & nl_bases and nl_reward_info.reward == 0:
209+
failed_nl = []
210+
if hasattr(nl_reward_info, "nl_assertions") and nl_reward_info.nl_assertions:
211+
failed_nl = [nla.nl_assertion for nla in nl_reward_info.nl_assertions if not nla.met]
212+
if failed_nl:
213+
failed_reasons.append(f"❌ Failed NL assertions: {failed_nl}")
214+
else:
215+
failed_reasons.append("❌ NL Assertions failed")
212216

213217
if task_reward_basis & comm_bases and communicate_reward_info.reward == 0:
214218
failed_comm = []

0 commit comments

Comments
 (0)