
Commit d957771

Package Update (#107)
* add examples
* move retail_dataset.jsonl
* add data
* change again
* moving tau2 to be inside package
* import error
* download db from internet
* Fix formatting and push database download changes
* Add policy files to download and clean up package data
* Add vendor.tau2 policy.md files to package data
* add airline
* Include all .md files in vendor.tau2 data to fix simulation_guidelines errors
* fixing num_runs
1 parent 871675e commit d957771

File tree

18 files changed: +3969 −15 lines


eval_protocol/benchmarks/data/airline_dataset.jsonl

Lines changed: 50 additions & 0 deletions
Large diffs are not rendered by default.
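
For orientation, each row of this file is consumed by the tau_bench_airline_to_evaluation_row adapter below, which reads id, environment_context (including domain), user_simulation, evaluation_criteria (with nl_assertions, communicate_info, and actions), and user_prompt_template. A hypothetical minimal row, with every value invented for illustration, could look like:

{"id": "airline_task_0", "environment_context": {"domain": "airline"}, "user_simulation": "...", "evaluation_criteria": {"nl_assertions": [], "communicate_info": [], "actions": []}, "user_prompt_template": "..."}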

eval_protocol/benchmarks/data/retail_dataset.jsonl

Lines changed: 114 additions & 0 deletions
Large diffs are not rendered by default.
eval_protocol/benchmarks/test_tau_bench_airline.py

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@

"""
Pytest test for tau bench airline evaluation using the evaluation_test decorator.

This test demonstrates how to use tau bench environments within the pytest framework,
similar to the test_entire_airline_dataset test but integrated with the pytest evaluation system.
"""

import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
from eval_protocol.pytest import evaluation_test, ExceptionHandlerConfig
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
import litellm
from vendor.tau2.data_model.message import (
    AssistantMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
    UserMessage,
)
from vendor.tau2.data_model.tasks import Action, EvaluationCriteria, RewardType, Task, UserScenario
from vendor.tau2.evaluator.evaluator import EnvironmentEvaluator
from vendor.tau2.evaluator.evaluator_action import ActionEvaluator
from vendor.tau2.evaluator.evaluator_communicate import CommunicateEvaluator
from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
from vendor.tau2.registry import registry
from eval_protocol.mcp_servers.tau2 import get_server_script_path, get_system_prompt


def _ensure_airline_database():
    """Ensure airline database exists, downloading if necessary."""
    import urllib.request
    from pathlib import Path

    # Get the vendor/tau2/data directory path
    try:
        from vendor.tau2.utils.utils import DATA_DIR

        domains_dir = DATA_DIR / "domains"
    except ImportError:
        # Fallback: find vendor/tau2 relative to this file
        vendor_tau2 = Path(__file__).parent.parent.parent / "vendor" / "tau2"
        domains_dir = vendor_tau2 / "data" / "domains"

    # Only download airline database for this test
    airline_db_path = domains_dir / "airline" / "db.json"
    if not airline_db_path.exists():
        print(f"📥 Downloading airline database to {airline_db_path}...")
        airline_db_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/airline/db.json"
            urllib.request.urlretrieve(url, airline_db_path)
            print(f"✅ Downloaded airline database ({airline_db_path.stat().st_size:,} bytes)")
        except Exception as e:
            print(f"❌ Failed to download airline database: {e}")
            raise


# Ensure airline database is available before test runs
_ensure_airline_database()


def _get_airline_dataset_path() -> str:
    """Get the airline dataset file path."""
    return str(Path(__file__).parent / "data" / "airline_dataset.jsonl")


def _get_server_script_path() -> str:
    """Get the tau2 mcp server script path."""
    from eval_protocol.mcp_servers.tau2 import get_server_script_path

    return get_server_script_path()


def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
    """
    Convert entries from airline dataset to EvaluationRow objects.
    """
    rows = []
    # Load system prompt from file so we can change it in one place
    from eval_protocol.mcp_servers.tau2 import get_system_prompt

    domain = data[0]["environment_context"]["domain"]
    system_prompt = get_system_prompt(domain)

    for row in data:
        eval_row = EvaluationRow(
            messages=[Message(role="system", content=system_prompt)],
            input_metadata=InputMetadata(
                row_id=row["id"],
                dataset_info={
                    "environment_context": row["environment_context"],
                    "user_simulation": row["user_simulation"],
                    "evaluation_criteria": row["evaluation_criteria"],
                    "user_prompt_template": row["user_prompt_template"],
                },
            ),
        )

        rows.append(eval_row)

    return rows


@evaluation_test(
    input_dataset=[_get_airline_dataset_path()],
    dataset_adapter=tau_bench_airline_to_evaluation_row,
    completion_params=[
        {
            "temperature": 0.8,
            "max_tokens": 4096,
            "extra_body": {"reasoning_effort": "medium"},
            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
        }
    ],
    rollout_processor=MCPGymRolloutProcessor(),
    rollout_processor_kwargs={"domain": "airline"},
    passed_threshold={"success": 0.4, "standard_error": 0.02},
    num_runs=8,
    mode="pointwise",
    max_concurrent_rollouts=50,
    server_script_path=_get_server_script_path(),
    exception_handler_config=ExceptionHandlerConfig(
        retryable_exceptions={
            litellm.RateLimitError,
            litellm.APIConnectionError,
        }
    ),
)
def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
    """
    Test tau bench airline evaluation using the pytest framework.

    This test now uses the tau_bench_airline_reward function which automatically
    extracts evaluation criteria from dataset entries. No wrapper needed!

    Args:
        row: EvaluationRow object from tau bench airline dataset after rollout

    Returns:
        EvaluationRow with tau2 evaluation results
    """
    messages = row.messages

    # Get evaluation criteria and user_simulation from input_metadata.dataset_info
    dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
    evaluation_criteria = dataset_info.get("evaluation_criteria", {})

    nl_assertions = evaluation_criteria.get("nl_assertions", [])
    communicate_info = evaluation_criteria.get("communicate_info", [])
    actions = evaluation_criteria.get("actions", [])

    # Convert Message objects directly to tau2-bench message objects
    trajectory_objects = []
    for msg in messages:
        role = msg.role
        content = msg.content

        if role == "system":
            trajectory_objects.append(SystemMessage(role=role, content=content))
        elif role == "assistant":
            tau2_tool_calls = []
            if msg.tool_calls:
                for tool_call in msg.tool_calls:
                    arguments = json.loads(tool_call.function.arguments)
                    tau2_tool_call = ToolCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=arguments,
                    )
                    tau2_tool_calls.append(tau2_tool_call)

            trajectory_objects.append(AssistantMessage(role=role, content=content, tool_calls=tau2_tool_calls))
        elif role == "user":
            trajectory_objects.append(UserMessage(role=role, content=content))
        elif role == "tool":
            tool_id = msg.tool_call_id
            trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))

    reward = 1.0

    evaluation_criteria = EvaluationCriteria(
        nl_assertions=nl_assertions,
        communicate_info=communicate_info,
        actions=actions,
        reward_basis=[  # Use this to adjust how to calculate reward. Tau2-bench uses DB and COMMUNICATE by default for airline tasks.
            RewardType.DB,
            RewardType.COMMUNICATE,
        ],
    )

    task = Task(
        id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler")
    )  # id and user_scenario are required for the Task type but not used in calculating reward

    if RewardType.DB in task.evaluation_criteria.reward_basis:
        env_reward_info = EnvironmentEvaluator.calculate_reward(
            environment_constructor=registry.get_env_constructor("airline"),
            task=task,
            full_trajectory=trajectory_objects,
        )
    if RewardType.ACTION in task.evaluation_criteria.reward_basis:
        action_reward_info = ActionEvaluator.calculate_reward(
            task=task,
            full_trajectory=trajectory_objects,
        )
    if RewardType.COMMUNICATE in task.evaluation_criteria.reward_basis:
        communicate_reward_info = CommunicateEvaluator.calculate_reward(
            task=task,
            full_trajectory=trajectory_objects,
        )
    if RewardType.NL_ASSERTION in task.evaluation_criteria.reward_basis:
        nl_reward_info = NLAssertionsEvaluator.calculate_reward(
            task=task,
            full_trajectory=trajectory_objects,
        )

    reward = 1.0
    env_bases = {RewardType.DB, RewardType.ENV_ASSERTION}
    action_bases = {RewardType.ACTION}
    nl_bases = {RewardType.NL_ASSERTION}
    comm_bases = {RewardType.COMMUNICATE}
    task_reward_basis = set(task.evaluation_criteria.reward_basis)

    reward_breakdown = {}
    if task_reward_basis & env_bases:
        if env_reward_info.reward_breakdown is not None:
            reward_breakdown.update(env_reward_info.reward_breakdown)
        reward *= env_reward_info.reward
    if task_reward_basis & action_bases:
        if action_reward_info.reward_breakdown is not None:
            reward_breakdown.update(action_reward_info.reward_breakdown)
        reward *= action_reward_info.reward
    if task_reward_basis & nl_bases:
        if nl_reward_info.reward_breakdown is not None:
            reward_breakdown.update(nl_reward_info.reward_breakdown)
        reward *= nl_reward_info.reward
    if task_reward_basis & comm_bases:
        if communicate_reward_info.reward_breakdown is not None:
            reward_breakdown.update(communicate_reward_info.reward_breakdown)
        reward *= communicate_reward_info.reward

    # Generate reason showing only failed components
    failed_reasons = []

    if task_reward_basis & env_bases and env_reward_info.reward == 0:
        failed_reasons.append("❌ Environment/DB check failed")

    if task_reward_basis & action_bases and action_reward_info.reward == 0:
        failed_actions = []
        if hasattr(action_reward_info, "action_checks") and action_reward_info.action_checks:
            failed_actions = [
                f"{ac.action.name}({ac.action.arguments})"
                for ac in action_reward_info.action_checks
                if not ac.action_match
            ]
        if failed_actions:
            failed_reasons.append(f"❌ Failed actions: {failed_actions}")
        else:
            failed_reasons.append("❌ Actions failed")

    if task_reward_basis & nl_bases and nl_reward_info.reward == 0:
        failed_nl = []
        if hasattr(nl_reward_info, "nl_assertions") and nl_reward_info.nl_assertions:
            failed_nl = [nla.nl_assertion for nla in nl_reward_info.nl_assertions if not nla.met]
        if failed_nl:
            failed_reasons.append(f"❌ Failed NL assertions: {failed_nl}")
        else:
            failed_reasons.append("❌ NL Assertions failed")

    if task_reward_basis & comm_bases and communicate_reward_info.reward == 0:
        failed_comm = []
        if hasattr(communicate_reward_info, "communicate_checks") and communicate_reward_info.communicate_checks:
            failed_comm = [cc.info for cc in communicate_reward_info.communicate_checks if not cc.met]
        if failed_comm:
            failed_reasons.append(f"❌ Failed communication: {failed_comm}")
        else:
            failed_reasons.append("❌ Communication failed")

    # If everything passed, show success
    reason = "\n".join(failed_reasons) if failed_reasons else "✅ All checks passed"

    row.evaluation_result = EvaluateResult(
        score=reward,
        reason=reason,
        metrics={},
    )
    return row
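
A note for running this locally: _ensure_airline_database() executes at import time, so the first run downloads db.json before any rollouts start. Assuming the module path mirrors its retail counterpart (a hypothetical invocation; the command is not part of this commit), the benchmark would be run as:

pytest eval_protocol/benchmarks/test_tau_bench_airline.py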

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 37 additions & 13 deletions
@@ -27,31 +27,55 @@
 from vendor.tau2.evaluator.evaluator_communicate import CommunicateEvaluator
 from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
 from vendor.tau2.registry import registry
+from eval_protocol.mcp_servers.tau2 import get_server_script_path, get_system_prompt
 
 
-def _get_retail_dataset_path() -> str:
-    """Get the retail dataset file path."""
-    return str(Path(__file__).parent.parent.parent / "tests" / "pytest" / "data" / "retail_dataset.jsonl")
+def _ensure_retail_database():
+    """Ensure retail database exists, downloading if necessary."""
+    import urllib.request
+    from pathlib import Path
+
+    # Get the vendor/tau2/data directory path
+    try:
+        from vendor.tau2.utils.utils import DATA_DIR
 
+        domains_dir = DATA_DIR / "domains"
+    except ImportError:
+        # Fallback: find vendor/tau2 relative to this file
+        vendor_tau2 = Path(__file__).parent.parent.parent / "vendor" / "tau2"
+        domains_dir = vendor_tau2 / "data" / "domains"
 
-def _get_server_script_path() -> str:
-    """Get the tau2 mcp server script path."""
-    return str(Path(__file__).parent.parent.parent / "examples" / "tau2_mcp" / "server.py")
+    # Only download retail database for this test
+    retail_db_path = domains_dir / "retail" / "db.json"
+    if not retail_db_path.exists():
+        print(f"📥 Downloading retail database to {retail_db_path}...")
+        retail_db_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/retail/db.json"
+            urllib.request.urlretrieve(url, retail_db_path)
+            print(f"✅ Downloaded retail database ({retail_db_path.stat().st_size:,} bytes)")
+        except Exception as e:
+            print(f"❌ Failed to download retail database: {e}")
+            raise
+
+
+# Ensure retail database is available before test runs
+_ensure_retail_database()
+
+
+def _get_retail_dataset_path() -> str:
+    """Get the retail dataset file path."""
+    return str(Path(__file__).parent / "data" / "retail_dataset.jsonl")
 
 
 def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
     """
     Convert entries from retail dataset to EvaluationRow objects.
     """
     rows = []
-    test_dir = Path(__file__).parent.parent.parent / "examples" / "tau2_mcp" / "tests"
-
     # Load system prompt from file so we can change it in one place
     domain = data[0]["environment_context"]["domain"]
-    prompt_file = test_dir / f"system_prompts/{domain}_agent_system_prompt.md"
-
-    with open(prompt_file, "r") as f:
-        system_prompt = f.read().strip()
+    system_prompt = get_system_prompt(domain)
 
     for row in data:
         eval_row = EvaluationRow(
@@ -87,7 +111,7 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
     num_runs=8,
     mode="pointwise",
     max_concurrent_rollouts=50,
-    server_script_path=_get_server_script_path(),
+    server_script_path=get_server_script_path(),
     exception_handler_config=ExceptionHandlerConfig(
         retryable_exceptions={
             litellm.RateLimitError,
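
Both tests now resolve the server script and the system prompt through eval_protocol.mcp_servers.tau2 rather than repo-relative paths under examples/tau2_mcp, which is what lets the benchmarks run from an installed wheel (the commit message's "moving tau2 to be inside package"). The helpers' actual implementation is not part of this diff; a minimal sketch of how they could be written, assuming the prompt files keep the system_prompts/{domain}_agent_system_prompt.md naming from the removed code, is:

# Hypothetical sketch only: the shipped eval_protocol.mcp_servers.tau2 module may differ.
from importlib import resources


def get_system_prompt(domain: str) -> str:
    """Load the packaged agent system prompt for a domain (assumed resource layout)."""
    prompt = resources.files("eval_protocol.mcp_servers.tau2").joinpath(
        f"system_prompts/{domain}_agent_system_prompt.md"  # naming taken from the removed repo-relative lookup
    )
    return prompt.read_text().strip()


def get_server_script_path() -> str:
    """Return the path of the packaged tau2 MCP server script (assumed file name)."""
    return str(resources.files("eval_protocol.mcp_servers.tau2").joinpath("server.py"))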

eval_protocol/mcp_servers/__init__.py

Whitespace-only changes.
