Skip to content

Commit 1a51eec

Browse files
committed
download db from internet
1 parent f972484 commit 1a51eec

3 files changed

Lines changed: 105 additions & 3 deletions

File tree

eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,41 @@
2121

2222
logger = logging.getLogger(__name__)
2323

24-
from vendor.tau2.domains.airline.utils import AIRLINE_DB_PATH
24+
25+
def _get_airline_db_path():
26+
"""Get airline database path, downloading if necessary."""
27+
import os
28+
import tempfile
29+
import urllib.request
30+
from pathlib import Path
31+
32+
# Try local development path first
33+
try:
34+
from vendor.tau2.domains.airline.utils import AIRLINE_DB_PATH
35+
36+
if Path(AIRLINE_DB_PATH).exists():
37+
return AIRLINE_DB_PATH
38+
except ImportError:
39+
pass
40+
41+
# Use a cache directory in user's temp/cache area
42+
cache_dir = Path(tempfile.gettempdir()) / "tau2_bench_cache"
43+
cache_dir.mkdir(exist_ok=True)
44+
airline_db_path = cache_dir / "airline_db.json"
45+
46+
if not airline_db_path.exists():
47+
print(f"📥 Downloading airline database to {airline_db_path}...")
48+
url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/airline/db.json"
49+
try:
50+
urllib.request.urlretrieve(url, airline_db_path)
51+
print("✅ Download complete!")
52+
except Exception as e:
53+
raise RuntimeError(f"Failed to download airline database: {e}")
54+
55+
return airline_db_path
56+
57+
58+
AIRLINE_DB_PATH = _get_airline_db_path()
2559

2660

2761
class AirlineEnvironment:

eval_protocol/mcp_servers/tau2/mock_environment/mock_environment.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,41 @@
2020

2121
logger = logging.getLogger(__name__)
2222

23-
from vendor.tau2.domains.mock.utils import MOCK_DB_PATH
23+
24+
def _get_mock_db_path():
25+
"""Get mock database path, downloading if necessary."""
26+
import os
27+
import tempfile
28+
import urllib.request
29+
from pathlib import Path
30+
31+
# Try local development path first
32+
try:
33+
from vendor.tau2.domains.mock.utils import MOCK_DB_PATH
34+
35+
if Path(MOCK_DB_PATH).exists():
36+
return MOCK_DB_PATH
37+
except ImportError:
38+
pass
39+
40+
# Use a cache directory in user's temp/cache area
41+
cache_dir = Path(tempfile.gettempdir()) / "tau2_bench_cache"
42+
cache_dir.mkdir(exist_ok=True)
43+
mock_db_path = cache_dir / "mock_db.json"
44+
45+
if not mock_db_path.exists():
46+
print(f"📥 Downloading mock database to {mock_db_path}...")
47+
url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/mock/db.json"
48+
try:
49+
urllib.request.urlretrieve(url, mock_db_path)
50+
print("✅ Download complete!")
51+
except Exception as e:
52+
raise RuntimeError(f"Failed to download mock database: {e}")
53+
54+
return mock_db_path
55+
56+
57+
MOCK_DB_PATH = _get_mock_db_path()
2458

2559

2660
class MockEnvironment:

eval_protocol/mcp_servers/tau2/retail_environment/retail_environment.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,41 @@
2020

2121
logger = logging.getLogger(__name__)
2222

23-
from vendor.tau2.domains.retail.utils import RETAIL_DB_PATH
23+
24+
def _get_retail_db_path():
25+
"""Get retail database path, downloading if necessary."""
26+
import os
27+
import tempfile
28+
import urllib.request
29+
from pathlib import Path
30+
31+
# Try local development path first
32+
try:
33+
from vendor.tau2.domains.retail.utils import RETAIL_DB_PATH
34+
35+
if Path(RETAIL_DB_PATH).exists():
36+
return RETAIL_DB_PATH
37+
except ImportError:
38+
pass
39+
40+
# Use a cache directory in user's temp/cache area
41+
cache_dir = Path(tempfile.gettempdir()) / "tau2_bench_cache"
42+
cache_dir.mkdir(exist_ok=True)
43+
retail_db_path = cache_dir / "retail_db.json"
44+
45+
if not retail_db_path.exists():
46+
print(f"📥 Downloading retail database to {retail_db_path}...")
47+
url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/40f46d3540dc95aca145ddecb0464fdd9a1e8c15/data/tau2/domains/retail/db.json"
48+
try:
49+
urllib.request.urlretrieve(url, retail_db_path)
50+
print("✅ Download complete!")
51+
except Exception as e:
52+
raise RuntimeError(f"Failed to download retail database: {e}")
53+
54+
return retail_db_path
55+
56+
57+
RETAIL_DB_PATH = _get_retail_db_path()
2458

2559

2660
class RetailEnvironment:

0 commit comments

Comments
 (0)