Skip to content

Commit a6adeb5

Browse files
committed
changed tests
1 parent d3c4007 commit a6adeb5

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

eval_protocol/mcp_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ async def rollout(
     execution_manager = ExecutionManager()

     return await execution_manager.execute_rollouts(
-        envs, policy, steps, openai_format_log_file, max_concurrent_rollouts
+        envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
     )


eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ async def default_mcp_gym_rollout_processor(rows: List[EvaluationRow], config: R
     evaluation_rows = await ep.rollout(
         envs,
         policy=policy,
+        evaluation_rows=rows,
         steps=config.steps,
         max_concurrent_rollouts=config.max_concurrent_rollouts
     )

tests/pytest/test_tau_bench_airline.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
 @evaluation_test(
     input_dataset=["tests/pytest/data/airline_dataset.jsonl"],
     dataset_adapter=tau_bench_airline_to_evaluation_row,
-    model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
-    rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
+    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
+    rollout_input_params=[{"temperature": 0.8, "max_tokens": 4096}],
     rollout_processor=default_mcp_gym_rollout_processor,
     threshold_of_success=0.4,
     num_runs=1,
     mode="pointwise",
-    max_concurrent_rollouts=32,
+    max_concurrent_rollouts=16,
     server_script_path="examples/tau2_mcp/server.py",
 )
 def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
@@ -80,12 +80,10 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
     extracts evaluation criteria from dataset entries. No wrapper needed!

     Args:
-        input_dataset: List of EvaluationRow objects from tau bench airline dataset
-        input_params: Model parameters (temperature, max_tokens, etc.)
-        model: Model identifier
+        row: EvaluationRow object from tau bench airline dataset after rollout

     Returns:
-        List of evaluated EvaluationRow objects with scores and feedback
+        EvaluationRow with tau2 evaluation results
     """
     messages = row.messages

@@ -131,9 +129,7 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
         communicate_info=communicate_info,
         actions=actions,
         reward_basis=[
-            RewardType.NL_ASSERTION,
             RewardType.DB,
-            RewardType.COMMUNICATE,
             RewardType.ACTION,
         ],
     )

0 commit comments

Comments (0)