1818from typing import Any , ClassVar , Literal , cast
1919from urllib .parse import urlparse
2020
21+ from openml .enums import APIVersion
22+
2123logger = logging .getLogger (__name__ )
2224openml_logger = logging .getLogger ("openml" )
2325
2426
27+ SERVERS_REGISTRY : dict [str , dict [APIVersion , dict [str , str | None ]]] = {
28+ "production" : {
29+ APIVersion .V1 : {
30+ "server" : "https://www.openml.org/api/v1/xml/" ,
31+ "apikey" : None ,
32+ },
33+ APIVersion .V2 : {
34+ "server" : None ,
35+ "apikey" : None ,
36+ },
37+ },
38+ "test" : {
39+ APIVersion .V1 : {
40+ "server" : "https://test.openml.org/api/v1/xml/" ,
41+ "apikey" : "normaluser" ,
42+ },
43+ APIVersion .V2 : {
44+ "server" : None ,
45+ "apikey" : None ,
46+ },
47+ },
48+ "local" : {
49+ APIVersion .V1 : {
50+ "server" : "http://localhost:8000/api/v1/xml/" ,
51+ "apikey" : "normaluser" ,
52+ },
53+ APIVersion .V2 : {
54+ "server" : "http://localhost:8002/api/v1/xml/" ,
55+ "apikey" : "normaluser" ,
56+ },
57+ },
58+ }
59+
60+
2561def _resolve_default_cache_dir () -> Path :
2662 user_defined_cache_dir = os .environ .get ("OPENML_CACHE_DIR" )
2763 if user_defined_cache_dir is not None :
@@ -57,19 +93,38 @@ def _resolve_default_cache_dir() -> Path:
5793class OpenMLConfig :
5894 """Dataclass storing the OpenML configuration."""
5995
60- apikey : str | None = ""
61- server : str = "https://www.openml.org/api/v1/xml"
96+ servers : dict [APIVersion , dict [str , str | None ]] = field (
97+ default_factory = lambda : SERVERS_REGISTRY ["production" ]
98+ )
99+ api_version : APIVersion = APIVersion .V1
100+ fallback_api_version : APIVersion | None = None
62101 cachedir : Path = field (default_factory = _resolve_default_cache_dir )
63102 avoid_duplicate_runs : bool = False
64103 retry_policy : Literal ["human" , "robot" ] = "human"
65104 connection_n_retries : int = 5
66105 show_progress : bool = False
67106
68- def __setattr__ (self , name : str , value : Any ) -> None :
69- if name == "apikey" and value is not None and not isinstance (value , str ):
70- raise ValueError ("apikey must be a string or None" )
107+ @property
108+ def server (self ) -> str :
109+ server = self .servers [self .api_version ]["server" ]
110+ if server is None :
111+ servers_repr = {k .value : v for k , v in self .servers .items ()}
112+ raise ValueError (
113+ f'server found to be None for api_version="{ self .api_version } " in { servers_repr } '
114+ )
115+ return server
116+
117+ @server .setter
118+ def server (self , value : str | None ) -> None :
119+ self .servers [self .api_version ]["server" ] = value
120+
121+ @property
122+ def apikey (self ) -> str | None :
123+ return self .servers [self .api_version ]["apikey" ]
71124
72- super ().__setattr__ (name , value )
125+ @apikey .setter
126+ def apikey (self , value : str | None ) -> None :
127+ self .servers [self .api_version ]["apikey" ] = value
73128
74129
75130class OpenMLConfigManager :
@@ -79,11 +134,14 @@ def __init__(self) -> None:
79134 self .console_handler : logging .StreamHandler | None = None
80135 self .file_handler : logging .handlers .RotatingFileHandler | None = None
81136
137+ server_test_v1_apikey = SERVERS_REGISTRY ["test" ][APIVersion .V1 ]["apikey" ]
138+ server_test_v1_server = SERVERS_REGISTRY ["test" ][APIVersion .V1 ]["server" ]
139+
82140 self .OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
83141 self .OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET"
84- self ._TEST_SERVER_NORMAL_USER_KEY = "normaluser"
142+ self ._TEST_SERVER_NORMAL_USER_KEY = server_test_v1_apikey
85143 self .OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR = "OPENML_TEST_SERVER_ADMIN_KEY"
86- self .TEST_SERVER_URL = "https://test.openml.org"
144+ self .TEST_SERVER_URL = cast ( "str" , server_test_v1_server ). split ( "/api/v1/xml" )[ 0 ]
87145
88146 self ._config : OpenMLConfig = OpenMLConfig ()
89147 # for legacy test `test_non_writable_home`
@@ -127,6 +185,10 @@ def __setattr__(self, name: str, value: Any) -> None:
127185 object .__setattr__ (self , "_config" , replace (self ._config , ** {name : value }))
128186 return None
129187
188+ if name in ["server" , "apikey" ]:
189+ setattr (self ._config , name , value )
190+ return None
191+
130192 object .__setattr__ (self , name , value )
131193 return None
132194
@@ -190,6 +252,21 @@ def get_server_base_url(self) -> str:
190252 domain , _ = self ._config .server .split ("/api" , maxsplit = 1 )
191253 return domain .replace ("api" , "www" )
192254
255+ def set_server_mode (self , mode : str ) -> None :
256+ if mode not in SERVERS_REGISTRY :
257+ raise ValueError (
258+ f'invalid mode="{ mode } " allowed modes: { ", " .join (list (SERVERS_REGISTRY .keys ()))} '
259+ )
260+ self ._config = replace (self ._config , servers = SERVERS_REGISTRY [mode ])
261+
262+ def set_api_version (self , api_version : APIVersion ) -> None :
263+ if api_version not in APIVersion :
264+ raise ValueError (
265+ f'invalid api_version="{ api_version } " '
266+ f"allowed versions: { ', ' .join (list (APIVersion ))} "
267+ )
268+ self ._config = replace (self ._config , api_version = api_version )
269+
193270 def set_retry_policy (
194271 self , value : Literal ["human" , "robot" ], n_retries : int | None = None
195272 ) -> None :
@@ -317,13 +394,18 @@ def _setup(self, config: dict[str, Any] | None = None) -> None:
317394
318395 self ._config = replace (
319396 self ._config ,
320- apikey = config ["apikey" ],
321- server = config ["server" ],
397+ servers = config ["servers" ],
398+ api_version = config ["api_version" ],
399+ fallback_api_version = config ["fallback_api_version" ],
322400 show_progress = config ["show_progress" ],
323401 avoid_duplicate_runs = config ["avoid_duplicate_runs" ],
324402 retry_policy = config ["retry_policy" ],
325403 connection_n_retries = int (config ["connection_n_retries" ]),
326404 )
405+ if "server" in config :
406+ self ._config .server = config ["server" ]
407+ if "apikey" in config :
408+ self ._config .apikey = config ["apikey" ]
327409
328410 user_defined_cache_dir = os .environ .get (self .OPENML_CACHE_DIR_ENV_VAR )
329411 if user_defined_cache_dir is not None :
@@ -393,42 +475,34 @@ def overwrite_config_context(self, config: dict[str, Any]) -> Iterator[dict[str,
393475class ConfigurationForExamples :
394476 """Allows easy switching to and from a test configuration, used for examples."""
395477
396- _last_used_server = None
397- _last_used_key = None
478+ _last_used_servers = None
398479 _start_last_called = False
399480
400481 def __init__ (self , manager : OpenMLConfigManager ):
401482 self ._manager = manager
402- self ._test_apikey = manager ._TEST_SERVER_NORMAL_USER_KEY
403- self ._test_server = f"{ manager .TEST_SERVER_URL } /api/v1/xml"
483+ self ._test_servers = SERVERS_REGISTRY ["test" ]
404484
405485 def start_using_configuration_for_example (self ) -> None :
406486 """Sets the configuration to connect to the test server with valid apikey.
407487
408488 To configuration as was before this call is stored, and can be recovered
409489 by using the `stop_use_example_configuration` method.
410490 """
411- if (
412- self ._start_last_called
413- and self ._manager ._config .server == self ._test_server
414- and self ._manager ._config .apikey == self ._test_apikey
415- ):
491+ if self ._start_last_called and self ._manager ._config .servers == self ._test_servers :
416492 # Method is called more than once in a row without modifying the server or apikey.
417493 # We don't want to save the current test configuration as a last used configuration.
418494 return
419495
420- self ._last_used_server = self ._manager ._config .server
421- self ._last_used_key = self ._manager ._config .apikey
496+ self ._last_used_servers = self ._manager ._config .servers
422497 type(self )._start_last_called = True
423498
424499 # Test server key for examples
425500 self ._manager ._config = replace (
426501 self ._manager ._config ,
427- server = self ._test_server ,
428- apikey = self ._test_apikey ,
502+ servers = self ._test_servers ,
429503 )
430504 warnings .warn (
431- f"Switching to the test server { self ._test_server } to not upload results to "
505+ f"Switching to the test servers { self ._test_servers } to not upload results to "
432506 "the live server. Using the test server may result in reduced performance of the "
433507 "API!" ,
434508 stacklevel = 2 ,
@@ -446,8 +520,7 @@ def stop_using_configuration_for_example(self) -> None:
446520
447521 self ._manager ._config = replace (
448522 self ._manager ._config ,
449- server = cast ("str" , self ._last_used_server ),
450- apikey = cast ("str" , self ._last_used_key ),
523+ servers = cast ("dict[APIVersion, dict[str, str | None]]" , self ._last_used_servers ),
451524 )
452525 type(self )._start_last_called = False
453526
0 commit comments