1212import warnings
1313from collections .abc import Iterator
1414from contextlib import contextmanager
15+ from copy import deepcopy
1516from dataclasses import dataclass , field , fields , replace
1617from io import StringIO
1718from pathlib import Path
1819from typing import Any , ClassVar , Literal , cast
1920from urllib .parse import urlparse
2021
22+ from openml .enums import APIVersion , ServerMode
23+
24+ from .__version__ import __version__
25+
2126logger = logging .getLogger (__name__ )
2227openml_logger = logging .getLogger ("openml" )
2328
2429
30+ _PROD_SERVERS : dict [APIVersion , dict [str , str | None ]] = {
31+ APIVersion .V1 : {
32+ "server" : "https://www.openml.org/api/v1/xml/" ,
33+ "apikey" : None ,
34+ },
35+ APIVersion .V2 : {
36+ "server" : None ,
37+ "apikey" : None ,
38+ },
39+ }
40+
41+ _TEST_SERVERS : dict [APIVersion , dict [str , str | None ]] = {
42+ APIVersion .V1 : {
43+ "server" : "https://test.openml.org/api/v1/xml/" ,
44+ "apikey" : "normaluser" ,
45+ },
46+ APIVersion .V2 : {
47+ "server" : None ,
48+ "apikey" : None ,
49+ },
50+ }
51+
52+ _TEST_SERVERS_LOCAL : dict [APIVersion , dict [str , str | None ]] = {
53+ APIVersion .V1 : {
54+ "server" : "http://localhost:8000/api/v1/xml/" ,
55+ "apikey" : "normaluser" ,
56+ },
57+ APIVersion .V2 : {
58+ "server" : "http://localhost:8082/" ,
59+ "apikey" : "AD000000000000000000000000000000" ,
60+ },
61+ }
62+
63+ _SERVERS_REGISTRY : dict [ServerMode , dict [APIVersion , dict [str , str | None ]]] = {
64+ ServerMode .PRODUCTION : _PROD_SERVERS ,
65+ ServerMode .TEST : (
66+ _TEST_SERVERS_LOCAL if os .getenv ("OPENML_USE_LOCAL_SERVICES" ) == "true" else _TEST_SERVERS
67+ ),
68+ }
69+
70+
71+ def _get_servers (mode : ServerMode ) -> dict [APIVersion , dict [str , str | None ]]:
72+ if mode not in ServerMode :
73+ raise ValueError (f'invalid mode="{ mode } " allowed modes: { ", " .join (list (ServerMode ))} ' )
74+ return deepcopy (_SERVERS_REGISTRY [mode ])
75+
76+
2577def _resolve_default_cache_dir () -> Path :
2678 user_defined_cache_dir = os .environ .get ("OPENML_CACHE_DIR" )
2779 if user_defined_cache_dir is not None :
@@ -57,19 +109,38 @@ def _resolve_default_cache_dir() -> Path:
57109class OpenMLConfig :
58110 """Dataclass storing the OpenML configuration."""
59111
60- apikey : str | None = ""
61- server : str = "https://www.openml.org/api/v1/xml"
112+ servers : dict [APIVersion , dict [str , str | None ]] = field (
113+ default_factory = lambda : _get_servers (ServerMode .PRODUCTION )
114+ )
115+ api_version : APIVersion = APIVersion .V1
116+ fallback_api_version : APIVersion | None = None
62117 cachedir : Path = field (default_factory = _resolve_default_cache_dir )
63118 avoid_duplicate_runs : bool = False
64119 retry_policy : Literal ["human" , "robot" ] = "human"
65120 connection_n_retries : int = 5
66121 show_progress : bool = False
67122
68- def __setattr__ (self , name : str , value : Any ) -> None :
69- if name == "apikey" and not isinstance (value , (type (None ), str )):
70- raise TypeError ("apikey must be a string or None" )
123+ @property
124+ def server (self ) -> str :
125+ server = self .servers [self .api_version ]["server" ]
126+ if server is None :
127+ servers_repr = {k .value : v for k , v in self .servers .items ()}
128+ raise ValueError (
129+ f'server found to be None for api_version="{ self .api_version } " in { servers_repr } '
130+ )
131+ return server
71132
72- super ().__setattr__ (name , value )
133+ @server .setter
134+ def server (self , value : str | None ) -> None :
135+ self .servers [self .api_version ]["server" ] = value
136+
137+ @property
138+ def apikey (self ) -> str | None :
139+ return self .servers [self .api_version ]["apikey" ]
140+
141+ @apikey .setter
142+ def apikey (self , value : str | None ) -> None :
143+ self .servers [self .api_version ]["apikey" ] = value
73144
74145
75146class OpenMLConfigManager :
@@ -81,9 +152,8 @@ def __init__(self) -> None:
81152
82153 self .OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
83154 self .OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET"
84- self ._TEST_SERVER_NORMAL_USER_KEY = "normaluser"
85155 self .OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR = "OPENML_TEST_SERVER_ADMIN_KEY"
86- self .TEST_SERVER_URL = "https://test. openml.org"
156+ self ._HEADERS : dict [ str , str ] = { "user-agent" : f" openml-python/ { __version__ } " }
87157
88158 self ._config : OpenMLConfig = OpenMLConfig ()
89159 # for legacy test `test_non_writable_home`
@@ -116,7 +186,7 @@ def __setattr__(self, name: str, value: Any) -> None:
116186 "_examples" ,
117187 "OPENML_CACHE_DIR_ENV_VAR" ,
118188 "OPENML_SKIP_PARQUET_ENV_VAR" ,
119- "_TEST_SERVER_NORMAL_USER_KEY " ,
189+ "_HEADERS " ,
120190 }:
121191 return object .__setattr__ (self , name , value )
122192
@@ -127,6 +197,10 @@ def __setattr__(self, name: str, value: Any) -> None:
127197 object .__setattr__ (self , "_config" , replace (self ._config , ** {name : value }))
128198 return None
129199
200+ if name in ["server" , "apikey" ]:
201+ setattr (self ._config , name , value )
202+ return None
203+
130204 object .__setattr__ (self , name , value )
131205 return None
132206
@@ -190,6 +264,48 @@ def get_server_base_url(self) -> str:
190264 domain , _ = self ._config .server .split ("/api" , maxsplit = 1 )
191265 return domain .replace ("api" , "www" )
192266
267+ def _get_servers (self , mode : ServerMode ) -> dict [APIVersion , dict [str , str | None ]]:
268+ return _get_servers (mode )
269+
270+ def _set_servers (self , mode : ServerMode ) -> None :
271+ servers = self ._get_servers (mode )
272+ self ._config = replace (self ._config , servers = servers )
273+
274+ def get_production_servers (self ) -> dict [APIVersion , dict [str , str | None ]]:
275+ return self ._get_servers (mode = ServerMode .PRODUCTION )
276+
277+ def get_test_servers (self ) -> dict [APIVersion , dict [str , str | None ]]:
278+ return self ._get_servers (mode = ServerMode .TEST )
279+
280+ def use_production_servers (self ) -> None :
281+ self ._set_servers (mode = ServerMode .PRODUCTION )
282+
283+ def use_test_servers (self ) -> None :
284+ self ._set_servers (mode = ServerMode .TEST )
285+
286+ def set_api_version (
287+ self ,
288+ api_version : APIVersion ,
289+ fallback_api_version : APIVersion | None = None ,
290+ ) -> None :
291+ if api_version not in APIVersion :
292+ raise ValueError (
293+ f'invalid api_version="{ api_version } " '
294+ f"allowed versions: { ', ' .join (list (APIVersion ))} "
295+ )
296+
297+ if fallback_api_version is not None and fallback_api_version not in APIVersion :
298+ raise ValueError (
299+ f'invalid fallback_api_version="{ fallback_api_version } " '
300+ f"allowed versions: { ', ' .join (list (APIVersion ))} "
301+ )
302+
303+ self ._config = replace (
304+ self ._config ,
305+ api_version = api_version ,
306+ fallback_api_version = fallback_api_version ,
307+ )
308+
193309 def set_retry_policy (
194310 self , value : Literal ["human" , "robot" ], n_retries : int | None = None
195311 ) -> None :
@@ -317,13 +433,18 @@ def _setup(self, config: dict[str, Any] | None = None) -> None:
317433
318434 self ._config = replace (
319435 self ._config ,
320- apikey = config ["apikey" ],
321- server = config ["server" ],
436+ servers = config ["servers" ],
437+ api_version = config ["api_version" ],
438+ fallback_api_version = config ["fallback_api_version" ],
322439 show_progress = config ["show_progress" ],
323440 avoid_duplicate_runs = config ["avoid_duplicate_runs" ],
324441 retry_policy = config ["retry_policy" ],
325442 connection_n_retries = int (config ["connection_n_retries" ]),
326443 )
444+ if "server" in config :
445+ self ._config .server = config ["server" ]
446+ if "apikey" in config :
447+ self ._config .apikey = config ["apikey" ]
327448
328449 user_defined_cache_dir = os .environ .get (self .OPENML_CACHE_DIR_ENV_VAR )
329450 if user_defined_cache_dir is not None :
@@ -393,42 +514,35 @@ def overwrite_config_context(self, config: dict[str, Any]) -> Iterator[dict[str,
393514class ConfigurationForExamples :
394515 """Allows easy switching to and from a test configuration, used for examples."""
395516
396- _last_used_server = None
397- _last_used_key = None
517+ _last_used_servers = None
398518 _start_last_called = False
399519
400520 def __init__ (self , manager : OpenMLConfigManager ):
401521 self ._manager = manager
402- self ._test_apikey = manager ._TEST_SERVER_NORMAL_USER_KEY
403- self ._test_server = f"{ manager .TEST_SERVER_URL } /api/v1/xml"
522+ self ._test_servers = manager .get_test_servers ()
404523
405524 def start_using_configuration_for_example (self ) -> None :
406525 """Sets the configuration to connect to the test server with valid apikey.
407526
408527 To configuration as was before this call is stored, and can be recovered
409528 by using the `stop_use_example_configuration` method.
410529 """
411- if (
412- self ._start_last_called
413- and self ._manager ._config .server == self ._test_server
414- and self ._manager ._config .apikey == self ._test_apikey
415- ):
530+ if self ._start_last_called and self ._manager ._config .servers == self ._test_servers :
416531 # Method is called more than once in a row without modifying the server or apikey.
417532 # We don't want to save the current test configuration as a last used configuration.
418533 return
419534
420- self ._last_used_server = self ._manager ._config .server
421- self ._last_used_key = self ._manager ._config .apikey
535+ self ._last_used_servers = self ._manager ._config .servers
422536 type(self )._start_last_called = True
423537
424538 # Test server key for examples
425539 self ._manager ._config = replace (
426540 self ._manager ._config ,
427- server = self ._test_server ,
428- apikey = self ._test_apikey ,
541+ servers = self ._test_servers ,
429542 )
543+ test_server = self ._test_servers [self ._manager ._config .api_version ]["server" ]
430544 warnings .warn (
431- f"Switching to the test server { self . _test_server } to not upload results to "
545+ f"Switching to the test server { test_server } to not upload results to "
432546 "the live server. Using the test server may result in reduced performance of the "
433547 "API!" ,
434548 stacklevel = 2 ,
@@ -446,8 +560,7 @@ def stop_using_configuration_for_example(self) -> None:
446560
447561 self ._manager ._config = replace (
448562 self ._manager ._config ,
449- server = cast ("str" , self ._last_used_server ),
450- apikey = cast ("str" , self ._last_used_key ),
563+ servers = cast ("dict[APIVersion, dict[str, str | None]]" , self ._last_used_servers ),
451564 )
452565 type(self )._start_last_called = False
453566
0 commit comments