Skip to content

Commit 6f092d8

Browse files
committed
Fix formatting and push database download changes
1 parent 1a51eec commit 6f092d8

File tree

4 files changed

+43
-3
lines changed

4 files changed

+43
-3
lines changed

eval_protocol/benchmarks/test_tau_bench_retail.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,46 @@
3030
from eval_protocol.mcp_servers.tau2 import get_server_script_path, get_system_prompt
3131

3232

33+
def _ensure_tau2_databases():
    """Ensure the tau2 domain database files exist, downloading missing ones.

    The target directory is ``DATA_DIR / "domains"`` when ``DATA_DIR`` can be
    imported from ``vendor.tau2.utils.utils``; otherwise it falls back to
    ``vendor/tau2/data/domains`` located three directories above this file.

    Each missing ``db.json`` is fetched to a temporary ``.part`` file and only
    renamed to its final name after the download completes, so an interrupted
    download can never be mistaken for a valid database on the next run.
    Files that already exist are left untouched.

    Raises:
        Exception: whatever the download or the filesystem operations raise;
            the failure is reported on stdout and then re-raised unchanged.
    """
    import urllib.request
    from pathlib import Path

    # Get the vendor/tau2/data directory path
    try:
        from vendor.tau2.utils.utils import DATA_DIR

        domains_dir = Path(DATA_DIR) / "domains"
    except (ImportError, FileNotFoundError):
        # Fallback: find vendor/tau2 relative to this file.
        # NOTE(review): FileNotFoundError is caught in addition to ImportError
        # to match the tau2 environment db-path helpers; presumably importing
        # the vendored package can fail this way when its data is missing.
        vendor_tau2 = Path(__file__).parent.parent.parent / "vendor" / "tau2"
        domains_dir = vendor_tau2 / "data" / "domains"

    # Database files to download (path relative to domains_dir -> source URL).
    # NOTE(review): retail is pinned to a commit while airline and mock track
    # the mutable `main` branch; consider pinning all three for reproducibility.
    databases = {
        "retail/db.json": "https://raw.githubusercontent.com/sierra-research/tau2-bench/40f46d3540dc95aca145ddecb0464fdd9a1e8c15/data/tau2/domains/retail/db.json",
        "airline/db.json": "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/airline/db.json",
        "mock/db.json": "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/mock/db.json",
    }

    for rel_path, url in databases.items():
        file_path = domains_dir / rel_path
        if file_path.exists():
            continue

        print(f"📥 Downloading {rel_path} to {file_path}...")
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Download next to the destination so the final rename stays on the
        # same filesystem and is atomic.
        partial_path = file_path.with_name(file_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial_path)
            partial_path.replace(file_path)
            print(f"✅ Downloaded {rel_path} ({file_path.stat().st_size:,} bytes)")
        except Exception as e:
            print(f"❌ Failed to download {rel_path}: {e}")
            raise
        finally:
            # After a successful rename the partial file no longer exists;
            # after a failure or interrupt this removes the truncated leftover.
            if partial_path.exists():
                partial_path.unlink()
67+
68+
69+
# Ensure databases are available before test runs
70+
_ensure_tau2_databases()
71+
72+
3373
def _get_retail_dataset_path() -> str:
    """Return the path of ``data/retail_dataset.jsonl`` beside this file, as a string."""
    dataset_file = Path(__file__).parent / "data" / "retail_dataset.jsonl"
    return str(dataset_file)

eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _get_airline_db_path():
3535

3636
if Path(AIRLINE_DB_PATH).exists():
3737
return AIRLINE_DB_PATH
38-
except ImportError:
38+
except (ImportError, FileNotFoundError):
3939
pass
4040

4141
# Use a cache directory in user's temp/cache area

eval_protocol/mcp_servers/tau2/mock_environment/mock_environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_mock_db_path():
3434

3535
if Path(MOCK_DB_PATH).exists():
3636
return MOCK_DB_PATH
37-
except ImportError:
37+
except (ImportError, FileNotFoundError):
3838
pass
3939

4040
# Use a cache directory in user's temp/cache area

eval_protocol/mcp_servers/tau2/retail_environment/retail_environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_retail_db_path():
3434

3535
if Path(RETAIL_DB_PATH).exists():
3636
return RETAIL_DB_PATH
37-
except ImportError:
37+
except (ImportError, FileNotFoundError):
3838
pass
3939

4040
# Use a cache directory in user's temp/cache area

0 commit comments

Comments
 (0)