Skip to content

Commit c0e3ed3

Browse files
Benny ChenBenny Chen
authored and committed
fix a few more
1 parent cdf92b5 commit c0e3ed3

11 files changed

Lines changed: 39 additions & 17 deletions

File tree

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
424424

425425
@evaluation_test(
426426
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
427-
# Provide a flat list per run (Sequence[InputMessagesParam]) to match signature
428-
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
427+
# Wrap dataset messages in an extra list to match Sequence[list[InputMessagesParam]]
428+
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
429429
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
430430
rollout_processor=SingleTurnRolloutProcessor(),
431431
aggregation_method="mean",
@@ -468,7 +468,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
468468

469469
@evaluation_test(
470470
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
471-
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
471+
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
472472
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
473473
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
474474
aggregation_method="mean",
@@ -511,7 +511,7 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
511511

512512
@evaluation_test(
513513
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
514-
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
514+
input_messages=[[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS]],
515515
rollout_processor_kwargs={"extra_body": {"reasoning_effort": "low"}},
516516
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEREFORMAT_ROWS),
517517
aggregation_method="mean",

eval_protocol/benchmarks/test_tau_bench_airline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
147147
messages = row.messages
148148

149149
# Get evaluation criteria and user_simulation from input_metadata.dataset_info
150-
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
150+
dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {}
151151
evaluation_criteria = dataset_info.get("evaluation_criteria", {})
152152

153153
nl_assertions = evaluation_criteria.get("nl_assertions", [])

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
137137
messages = row.messages
138138

139139
# Get evaluation criteria and user_simulation from input_metadata.dataset_info
140-
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
140+
dataset_info = (row.input_metadata.dataset_info or {}) if row.input_metadata else {}
141141
evaluation_criteria = dataset_info.get("evaluation_criteria", {})
142142

143143
nl_assertions = evaluation_criteria.get("nl_assertions", [])

eval_protocol/execution/pipeline.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
8787

8888
try:
8989
backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
90+
assert self.mcp_intermediary_client is not None
9091
init_response = await self.mcp_intermediary_client.initialize_session(backend_requests)
9192

9293
if init_response.get("error"):
@@ -109,6 +110,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
109110
current_instance_id = inst_info_dict.get("instance_id")
110111
if not current_instance_id:
111112
continue
113+
assert self.mcp_intermediary_client is not None
112114
list_tools_result = await self.mcp_intermediary_client.list_backend_tools(
113115
rk_session_id=rk_session_id,
114116
instance_id=current_instance_id,
@@ -130,6 +132,7 @@ async def _discover_tools_for_sample(self, sample_id: str, mcp_backend_ref: str)
130132
if rk_session_id and self.mcp_intermediary_client:
131133
logger.info(f"Sample {sample_id}: Cleaning up tool discovery session '{rk_session_id}'.")
132134
try:
135+
assert self.mcp_intermediary_client is not None
133136
await self.mcp_intermediary_client.cleanup_session(rk_session_id)
134137
except Exception as e_cl:
135138
logger.error(
@@ -276,6 +279,7 @@ async def _execute_mcp_agent_rollout(
276279

277280
try:
278281
backend_requests = [{"backend_name_ref": mcp_backend_ref, "num_instances": 1}]
282+
assert self.mcp_intermediary_client is not None
279283
init_response = await self.mcp_intermediary_client.initialize_session(backend_requests)
280284
if init_response.get("error"):
281285
raise RuntimeError(
@@ -331,6 +335,7 @@ async def _execute_mcp_agent_rollout(
331335
if not isinstance(tool_args_dict, dict):
332336
raise ValueError("Args not dict")
333337

338+
assert self.mcp_intermediary_client is not None
334339
exec_result = await self.mcp_intermediary_client.call_backend_tool(
335340
rk_session_id=rk_session_id,
336341
instance_id=primary_instance_id_for_agent_actions,
@@ -405,6 +410,7 @@ async def _execute_mcp_agent_rollout(
405410
state_capture_tool = self.cfg.agent.get("state_capture_tool")
406411
if state_capture_tool:
407412
state_capture_args = dict(self.cfg.agent.get("state_capture_args", OmegaConf.create({})))
413+
assert self.mcp_intermediary_client is not None
408414
final_filesystem_state_from_mcp = await self.mcp_intermediary_client.call_backend_tool(
409415
rk_session_id=rk_session_id,
410416
instance_id=primary_instance_id_for_agent_actions,
@@ -432,6 +438,7 @@ async def _execute_mcp_agent_rollout(
432438
}
433439
finally:
434440
if rk_session_id and self.mcp_intermediary_client:
441+
assert self.mcp_intermediary_client is not None
435442
await self.mcp_intermediary_client.cleanup_session(rk_session_id)
436443

437444
async def _process_single_sample(

eval_protocol/integrations/braintrust.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def scorer(input_val: Any, output: Any, expected: Any) -> float:
4848
ground_truth = None
4949
if expected is not None:
5050
ground_truth = [Message(role="assistant", content=str(expected))]
51-
result = reward_fn(messages=messages, ground_truth=ground_truth)
51+
result = reward_fn(messages, ground_truth)
5252
return float(result.score)
5353

5454
return scorer

eval_protocol/mcp/execution/manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,15 @@ def extract_text_content(msg_dict):
281281
# Generate user response using the simulator
282282
# Pass the assistant message content to drive the simulated user's next response
283283
last_assistant = user_simulator_messages[-1]
284+
# Convert last assistant message into a valid user input message for simulator
285+
from vendor.tau2.data_model.message import UserMessage as TauUserMessage
286+
287+
converted_user_prompt = (
288+
last_assistant.content if getattr(last_assistant, "content", None) else ""
289+
)
290+
converted_message = TauUserMessage(role="user", content=converted_user_prompt)
284291
user_message, user_simulator_state = await user_simulator.generate_next_message(
285-
last_assistant,
292+
converted_message,
286293
user_simulator_state,
287294
)
288295
user_content = user_message.content if user_message.content else ""

eval_protocol/mcp/simulation_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def _discover_and_register_resources(self):
288288
if discovered_resources:
289289

290290
@self.app.read_resource()
291-
async def read_resource(uri: str):
291+
async def read_resource(uri: AnyUrl):
292292
# Get the current request context
293293
ctx = self.app.request_context
294294

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,17 @@ def __init__(self, content: str):
5656

5757
# Resolve the appropriate async invoke function
5858
if hasattr(target, "graph") and hasattr(target.graph, "ainvoke"):
59-
invoke_fn = target.graph.ainvoke
59+
60+
async def _invoke_graph(payload):
61+
return await target.graph.ainvoke(payload) # type: ignore[attr-defined]
62+
63+
invoke_fn = _invoke_graph
6064
elif hasattr(target, "ainvoke"):
61-
invoke_fn = target.ainvoke
65+
66+
async def _invoke_direct(payload):
67+
return await target.ainvoke(payload) # type: ignore[attr-defined]
68+
69+
invoke_fn = _invoke_direct
6270
elif callable(target):
6371

6472
async def _invoke_wrapper(payload):

eval_protocol/rewards/function_calling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def schema_jaccard_reward(
451451
DeprecationWarning,
452452
stacklevel=2,
453453
)
454-
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
454+
return exact_tool_match_reward(messages, ground_truth, **kwargs)
455455

456456

457457
@reward_function
@@ -493,7 +493,7 @@ def llm_judge_reward(
493493
DeprecationWarning,
494494
stacklevel=2,
495495
)
496-
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
496+
return exact_tool_match_reward(messages, ground_truth, **kwargs)
497497

498498

499499
@reward_function
@@ -537,7 +537,7 @@ def composite_function_call_reward(
537537
DeprecationWarning,
538538
stacklevel=2,
539539
)
540-
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
540+
return exact_tool_match_reward(messages, ground_truth, **kwargs)
541541

542542

543543
# JSON schema reward functions have been moved to json_schema.py module

eval_protocol/rewards/json_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ def json_schema_reward_with_llm_judge(
290290
normalized_weights = {k: v / total_weight for k, v in weights.items()}
291291

292292
schema_result = json_schema_reward(
293-
messages=messages,
294-
ground_truth=ground_truth,
293+
messages,
294+
ground_truth,
295295
json_content=json_content,
296296
expected_schema=expected_schema,
297297
**kwargs,

0 commit comments

Comments (0)