File tree Expand file tree Collapse file tree
eval_protocol/mcp_servers/tau2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2121
2222logger = logging .getLogger (__name__ )
2323
24- from vendor .tau2 .domains .airline .utils import AIRLINE_DB_PATH
24+
25+ def _get_airline_db_path ():
26+ """Get airline database path, downloading if necessary."""
27+ import os
28+ import tempfile
29+ import urllib .request
30+ from pathlib import Path
31+
32+ # Try local development path first
33+ try :
34+ from vendor .tau2 .domains .airline .utils import AIRLINE_DB_PATH
35+
36+ if Path (AIRLINE_DB_PATH ).exists ():
37+ return AIRLINE_DB_PATH
38+ except ImportError :
39+ pass
40+
41+ # Use a cache directory in user's temp/cache area
42+ cache_dir = Path (tempfile .gettempdir ()) / "tau2_bench_cache"
43+ cache_dir .mkdir (exist_ok = True )
44+ airline_db_path = cache_dir / "airline_db.json"
45+
46+ if not airline_db_path .exists ():
47+ print (f"📥 Downloading airline database to { airline_db_path } ..." )
48+ url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/airline/db.json"
49+ try :
50+ urllib .request .urlretrieve (url , airline_db_path )
51+ print ("✅ Download complete!" )
52+ except Exception as e :
53+ raise RuntimeError (f"Failed to download airline database: { e } " )
54+
55+ return airline_db_path
56+
57+
58+ AIRLINE_DB_PATH = _get_airline_db_path ()
2559
2660
2761class AirlineEnvironment :
Original file line number Diff line number Diff line change 2020
2121logger = logging .getLogger (__name__ )
2222
23- from vendor .tau2 .domains .mock .utils import MOCK_DB_PATH
23+
24+ def _get_mock_db_path ():
25+ """Get mock database path, downloading if necessary."""
26+ import os
27+ import tempfile
28+ import urllib .request
29+ from pathlib import Path
30+
31+ # Try local development path first
32+ try :
33+ from vendor .tau2 .domains .mock .utils import MOCK_DB_PATH
34+
35+ if Path (MOCK_DB_PATH ).exists ():
36+ return MOCK_DB_PATH
37+ except ImportError :
38+ pass
39+
40+ # Use a cache directory in user's temp/cache area
41+ cache_dir = Path (tempfile .gettempdir ()) / "tau2_bench_cache"
42+ cache_dir .mkdir (exist_ok = True )
43+ mock_db_path = cache_dir / "mock_db.json"
44+
45+ if not mock_db_path .exists ():
46+ print (f"📥 Downloading mock database to { mock_db_path } ..." )
47+ url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/main/data/tau2/domains/mock/db.json"
48+ try :
49+ urllib .request .urlretrieve (url , mock_db_path )
50+ print ("✅ Download complete!" )
51+ except Exception as e :
52+ raise RuntimeError (f"Failed to download mock database: { e } " )
53+
54+ return mock_db_path
55+
56+
57+ MOCK_DB_PATH = _get_mock_db_path ()
2458
2559
2660class MockEnvironment :
Original file line number Diff line number Diff line change 2020
2121logger = logging .getLogger (__name__ )
2222
23- from vendor .tau2 .domains .retail .utils import RETAIL_DB_PATH
23+
24+ def _get_retail_db_path ():
25+ """Get retail database path, downloading if necessary."""
26+ import os
27+ import tempfile
28+ import urllib .request
29+ from pathlib import Path
30+
31+ # Try local development path first
32+ try :
33+ from vendor .tau2 .domains .retail .utils import RETAIL_DB_PATH
34+
35+ if Path (RETAIL_DB_PATH ).exists ():
36+ return RETAIL_DB_PATH
37+ except ImportError :
38+ pass
39+
40+ # Use a cache directory in user's temp/cache area
41+ cache_dir = Path (tempfile .gettempdir ()) / "tau2_bench_cache"
42+ cache_dir .mkdir (exist_ok = True )
43+ retail_db_path = cache_dir / "retail_db.json"
44+
45+ if not retail_db_path .exists ():
46+ print (f"📥 Downloading retail database to { retail_db_path } ..." )
47+ url = "https://raw.githubusercontent.com/sierra-research/tau2-bench/40f46d3540dc95aca145ddecb0464fdd9a1e8c15/data/tau2/domains/retail/db.json"
48+ try :
49+ urllib .request .urlretrieve (url , retail_db_path )
50+ print ("✅ Download complete!" )
51+ except Exception as e :
52+ raise RuntimeError (f"Failed to download retail database: { e } " )
53+
54+ return retail_db_path
55+
56+
57+ RETAIL_DB_PATH = _get_retail_db_path ()
2458
2559
2660class RetailEnvironment :
You can’t perform that action at this time.
0 commit comments