Skip to content

Commit 820440d

Browse files
committed
add airline
1 parent 02f834b commit 820440d

File tree

2 files changed

+309
-25
lines changed

2 files changed

+309
-25
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
"""
2+
Pytest test for tau bench airline evaluation using the evaluation_test decorator.
3+
4+
This test demonstrates how to use tau bench environments within the pytest framework,
5+
similar to the test_entire_airline_dataset test but integrated with the pytest evaluation system.
6+
"""
7+
8+
import json
9+
from datetime import datetime
10+
from pathlib import Path
11+
from typing import Any, Dict, List
12+
13+
from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
14+
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
15+
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
16+
import litellm
17+
from vendor.tau2.data_model.message import (
18+
AssistantMessage,
19+
SystemMessage,
20+
ToolCall,
21+
ToolMessage,
22+
UserMessage,
23+
)
24+
from vendor.tau2.data_model.tasks import Action, EvaluationCriteria, RewardType, Task, UserScenario
25+
from vendor.tau2.evaluator.evaluator import EnvironmentEvaluator
26+
from vendor.tau2.evaluator.evaluator_action import ActionEvaluator
27+
from vendor.tau2.evaluator.evaluator_communicate import CommunicateEvaluator
28+
from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
29+
from vendor.tau2.registry import registry
30+
from eval_protocol.mcp_servers.tau2 import get_server_script_path, get_system_prompt
31+
32+
33+
def _ensure_airline_database():
    """Ensure the airline domain database exists locally, downloading it if missing.

    Resolves the tau2 domains directory (preferring the packaged ``DATA_DIR``,
    falling back to a path relative to this file) and, if ``airline/db.json``
    is absent, fetches it from the tau2-bench repository. The download is
    staged through a temporary file and moved into place with an atomic
    ``os.replace`` so an interrupted transfer never leaves a truncated
    ``db.json`` that would pass the ``exists()`` check on the next run.

    Raises:
        Exception: re-raised from the download when the fetch fails.
    """
    import os
    import tempfile
    import urllib.request
    from pathlib import Path

    # Get the vendor/tau2/data directory path
    try:
        from vendor.tau2.utils.utils import DATA_DIR

        domains_dir = DATA_DIR / "domains"
    except ImportError:
        # Fallback: find vendor/tau2 relative to this file
        vendor_tau2 = Path(__file__).parent.parent.parent / "vendor" / "tau2"
        domains_dir = vendor_tau2 / "data" / "domains"

    # Only download airline database for this test
    airline_db_path = domains_dir / "airline" / "db.json"
    if not airline_db_path.exists():
        print(f"📥 Downloading airline database to {airline_db_path}...")
        airline_db_path.parent.mkdir(parents=True, exist_ok=True)
        url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/airline/db.json"
        # Stage in the destination directory so os.replace stays on one filesystem.
        fd, tmp_path = tempfile.mkstemp(dir=airline_db_path.parent, suffix=".tmp")
        os.close(fd)
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, airline_db_path)  # atomic: all-or-nothing
            print(f"✅ Downloaded airline database ({airline_db_path.stat().st_size:,} bytes)")
        except Exception as e:
            print(f"❌ Failed to download airline database: {e}")
            raise
        finally:
            # Remove the staging file if the download failed before the rename.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
60+
61+
62+
# Ensure airline database is available before test runs. This runs at module
# import time so the @evaluation_test decorator below can rely on the data.
_ensure_airline_database()
64+
65+
66+
def _get_airline_dataset_path() -> str:
67+
"""Get the airline dataset file path."""
68+
return str(Path(__file__).parent / "data" / "airline_dataset.jsonl")
69+
70+
71+
def _get_server_script_path() -> str:
    """Resolve the tau2 MCP server script path via the eval_protocol helper."""
    from eval_protocol.mcp_servers.tau2 import get_server_script_path as _resolve_path

    return _resolve_path()
76+
77+
78+
def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
    """Convert entries from the airline dataset to EvaluationRow objects.

    Each row starts with the shared domain system prompt as its only message
    and carries the dataset entry's context through
    ``input_metadata.dataset_info`` for use by the evaluation function.

    Args:
        data: Parsed dataset entries. Each entry must provide "id",
            "environment_context" (with a "domain" key), "user_simulation",
            "evaluation_criteria", and "user_prompt_template".

    Returns:
        One EvaluationRow per entry; an empty list when ``data`` is empty.
    """
    rows = []
    # Guard: an empty dataset previously raised IndexError on data[0].
    if not data:
        return rows

    # All entries in a dataset file share one domain; load the system prompt
    # once (from file, via the helper) so it can be changed in one place.
    # NOTE: get_system_prompt is imported at module top.
    domain = data[0]["environment_context"]["domain"]
    system_prompt = get_system_prompt(domain)

    for row in data:
        eval_row = EvaluationRow(
            messages=[Message(role="system", content=system_prompt)],
            input_metadata=InputMetadata(
                row_id=row["id"],
                dataset_info={
                    "environment_context": row["environment_context"],
                    "user_simulation": row["user_simulation"],
                    "evaluation_criteria": row["evaluation_criteria"],
                    "user_prompt_template": row["user_prompt_template"],
                },
            ),
        )

        rows.append(eval_row)

    return rows
106+
107+
108+
@evaluation_test(
    input_dataset=[_get_airline_dataset_path()],
    dataset_adapter=tau_bench_airline_to_evaluation_row,
    completion_params=[
        {
            "temperature": 0.8,
            "max_tokens": 4096,
            "extra_body": {"reasoning_effort": "medium"},
            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
        }
    ],
    rollout_processor=MCPGymRolloutProcessor(),
    rollout_processor_kwargs={"domain": "airline"},
    passed_threshold={"success": 0.4, "standard_error": 0.02},
    num_runs=8,
    mode="pointwise",
    max_concurrent_rollouts=50,
    server_script_path=_get_server_script_path(),
    exception_handler_config=ExceptionHandlerConfig(
        retryable_exceptions={
            litellm.RateLimitError,
            litellm.APIConnectionError,
        }
    ),
)
def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
    """Evaluate a rolled-out tau bench airline trajectory with tau2 evaluators.

    Converts the row's messages into tau2-bench message objects, builds the
    evaluation criteria from the dataset entry stored in
    ``input_metadata.dataset_info``, runs the evaluators selected by
    ``reward_basis`` (DB and COMMUNICATE here, matching tau2-bench's default
    for airline tasks), and multiplies their rewards into a single score.

    Args:
        row: EvaluationRow object from tau bench airline dataset after rollout

    Returns:
        EvaluationRow with tau2 evaluation results attached
    """
    messages = row.messages

    # Get evaluation criteria from input_metadata.dataset_info
    dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
    criteria_dict = dataset_info.get("evaluation_criteria", {})

    nl_assertions = criteria_dict.get("nl_assertions", [])
    communicate_info = criteria_dict.get("communicate_info", [])
    actions = criteria_dict.get("actions", [])

    # Convert Message objects directly to tau2-bench message objects
    trajectory_objects = []
    for msg in messages:
        role = msg.role
        content = msg.content

        if role == "system":
            trajectory_objects.append(SystemMessage(role=role, content=content))
        elif role == "assistant":
            # Translate OpenAI-style tool calls (JSON-string arguments) into
            # tau2 ToolCall objects (dict arguments).
            tau2_tool_calls = []
            if msg.tool_calls:
                for tool_call in msg.tool_calls:
                    arguments = json.loads(tool_call.function.arguments)
                    tau2_tool_call = ToolCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=arguments,
                    )
                    tau2_tool_calls.append(tau2_tool_call)

            trajectory_objects.append(AssistantMessage(role=role, content=content, tool_calls=tau2_tool_calls))
        elif role == "user":
            trajectory_objects.append(UserMessage(role=role, content=content))
        elif role == "tool":
            tool_id = msg.tool_call_id
            trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))

    evaluation_criteria = EvaluationCriteria(
        nl_assertions=nl_assertions,
        communicate_info=communicate_info,
        actions=actions,
        reward_basis=[  # Use this to adjust how to calculate reward. Tau2-bench uses DB and COMMUNICATE by default for airline tasks.
            RewardType.DB,
            RewardType.COMMUNICATE,
        ],
    )

    task = Task(
        id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler")
    )  # id and user_scenario are required for the Task type but not used in calculating reward

    # Run only the evaluators selected by reward_basis; each guard below
    # mirrors the corresponding guard in the reward-combination section.
    if RewardType.DB in task.evaluation_criteria.reward_basis:
        env_reward_info = EnvironmentEvaluator.calculate_reward(
            environment_constructor=registry.get_env_constructor("airline"),
            task=task,
            full_trajectory=trajectory_objects,
        )
    if RewardType.ACTION in task.evaluation_criteria.reward_basis:
        action_reward_info = ActionEvaluator.calculate_reward(
            task=task,
            full_trajectory=trajectory_objects,
        )
    if RewardType.COMMUNICATE in task.evaluation_criteria.reward_basis:
        communicate_reward_info = CommunicateEvaluator.calculate_reward(
            task=task,
            full_trajectory=trajectory_objects,
        )
    if RewardType.NL_ASSERTION in task.evaluation_criteria.reward_basis:
        nl_reward_info = NLAssertionsEvaluator.calculate_reward(
            task=task,
            full_trajectory=trajectory_objects,
        )

    # Combine component rewards multiplicatively: any failed component (0.0)
    # zeroes the final score.
    reward = 1.0
    env_bases = {RewardType.DB, RewardType.ENV_ASSERTION}
    action_bases = {RewardType.ACTION}
    nl_bases = {RewardType.NL_ASSERTION}
    comm_bases = {RewardType.COMMUNICATE}
    task_reward_basis = set(task.evaluation_criteria.reward_basis)

    reward_breakdown = {}
    if task_reward_basis & env_bases:
        if env_reward_info.reward_breakdown is not None:
            reward_breakdown.update(env_reward_info.reward_breakdown)
        reward *= env_reward_info.reward
    if task_reward_basis & action_bases:
        if action_reward_info.reward_breakdown is not None:
            reward_breakdown.update(action_reward_info.reward_breakdown)
        reward *= action_reward_info.reward
    if task_reward_basis & nl_bases:
        if nl_reward_info.reward_breakdown is not None:
            reward_breakdown.update(nl_reward_info.reward_breakdown)
        reward *= nl_reward_info.reward
    if task_reward_basis & comm_bases:
        if communicate_reward_info.reward_breakdown is not None:
            reward_breakdown.update(communicate_reward_info.reward_breakdown)
        reward *= communicate_reward_info.reward

    # Generate reason showing only failed components
    failed_reasons = []

    if task_reward_basis & env_bases and env_reward_info.reward == 0:
        failed_reasons.append("❌ Environment/DB check failed")

    if task_reward_basis & action_bases and action_reward_info.reward == 0:
        failed_actions = []
        if hasattr(action_reward_info, "action_checks") and action_reward_info.action_checks:
            failed_actions = [
                f"{ac.action.name}({ac.action.arguments})"
                for ac in action_reward_info.action_checks
                if not ac.action_match
            ]
        if failed_actions:
            failed_reasons.append(f"❌ Failed actions: {failed_actions}")
        else:
            failed_reasons.append("❌ Actions failed")

    if task_reward_basis & nl_bases and nl_reward_info.reward == 0:
        failed_nl = []
        if hasattr(nl_reward_info, "nl_assertions") and nl_reward_info.nl_assertions:
            failed_nl = [nla.nl_assertion for nla in nl_reward_info.nl_assertions if not nla.met]
        if failed_nl:
            failed_reasons.append(f"❌ Failed NL assertions: {failed_nl}")
        else:
            failed_reasons.append("❌ NL Assertions failed")

    if task_reward_basis & comm_bases and communicate_reward_info.reward == 0:
        failed_comm = []
        if hasattr(communicate_reward_info, "communicate_checks") and communicate_reward_info.communicate_checks:
            failed_comm = [cc.info for cc in communicate_reward_info.communicate_checks if not cc.met]
        if failed_comm:
            failed_reasons.append(f"❌ Failed communication: {failed_comm}")
        else:
            failed_reasons.append("❌ Communication failed")

    # If everything passed, show success
    reason = "\n".join(failed_reasons) if failed_reasons else "✅ All checks passed"

    row.evaluation_result = EvaluateResult(
        score=reward,
        reason=reason,
        metrics={},
    )
    return row

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@
3030
from eval_protocol.mcp_servers.tau2 import get_server_script_path, get_system_prompt
3131

3232

33-
def _ensure_tau2_databases():
34-
"""Ensure tau2 database files exist, downloading if necessary."""
35-
import os
33+
def _ensure_retail_database():
34+
"""Ensure retail database exists, downloading if necessary."""
3635
import urllib.request
3736
from pathlib import Path
3837

@@ -46,28 +45,22 @@ def _ensure_tau2_databases():
4645
vendor_tau2 = Path(__file__).parent.parent.parent / "vendor" / "tau2"
4746
domains_dir = vendor_tau2 / "data" / "domains"
4847

49-
# Database files to download
50-
databases = {
51-
"retail/db.json": "https://raw.githubusercontent.com/sierra-research/tau2-bench/40f46d3540dc95aca145ddecb0464fdd9a1e8c15/data/tau2/domains/retail/db.json",
52-
"airline/db.json": "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/airline/db.json",
53-
"mock/db.json": "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/mock/db.json",
54-
}
55-
56-
for rel_path, url in databases.items():
57-
file_path = domains_dir / rel_path
58-
if not file_path.exists():
59-
print(f"📥 Downloading {rel_path} to {file_path}...")
60-
file_path.parent.mkdir(parents=True, exist_ok=True)
61-
try:
62-
urllib.request.urlretrieve(url, file_path)
63-
print(f"✅ Downloaded {rel_path} ({file_path.stat().st_size:,} bytes)")
64-
except Exception as e:
65-
print(f"❌ Failed to download {rel_path}: {e}")
66-
raise
67-
68-
69-
# Ensure databases are available before test runs
70-
_ensure_tau2_databases()
48+
# Only download retail database for this test
49+
retail_db_path = domains_dir / "retail" / "db.json"
50+
if not retail_db_path.exists():
51+
print(f"📥 Downloading retail database to {retail_db_path}...")
52+
retail_db_path.parent.mkdir(parents=True, exist_ok=True)
53+
try:
54+
url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/retail/db.json"
55+
urllib.request.urlretrieve(url, retail_db_path)
56+
print(f"✅ Downloaded retail database ({retail_db_path.stat().st_size:,} bytes)")
57+
except Exception as e:
58+
print(f"❌ Failed to download retail database: {e}")
59+
raise
60+
61+
62+
# Ensure retail database is available before test runs
63+
_ensure_retail_database()
7164

7265

7366
def _get_retail_dataset_path() -> str:

0 commit comments

Comments
 (0)