Skip to content

Commit aba3d3e

Browse files
committed
update _config.py
1 parent 3d86b18 commit aba3d3e

4 files changed

Lines changed: 110 additions & 41 deletions

File tree

openml/_api/clients/http.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,21 +212,21 @@ def __init__(
212212

213213
@property
214214
def server(self) -> str:
215-
server = openml.config.SERVERS[self.api_version]["server"]
215+
server = openml.config.servers[self.api_version]["server"]
216216
if server is None:
217+
servers_repr = {k.value: v for k, v in openml.config.servers}
217218
raise ValueError(
218-
f"server found to be None for api_version={self.api_version}"
219-
f" in {openml.config.SERVERS}"
219+
f'server found to be None for api_version="{self.api_version}" in {servers_repr}'
220220
)
221-
return server
221+
return cast("str", server)
222222

223223
@property
224224
def api_key(self) -> str | None:
225-
return openml.config.SERVERS[self.api_version]["apikey"]
225+
return cast("str | None", openml.config.SERVERS[self.api_version]["apikey"])
226226

227227
@property
228228
def retries(self) -> int:
229-
return openml.config.connection_n_retries
229+
return cast("int", openml.config.connection_n_retries)
230230

231231
@property
232232
def retry_policy(self) -> RetryPolicy:

openml/_config.py

Lines changed: 99 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,46 @@
1818
from typing import Any, ClassVar, Literal, cast
1919
from urllib.parse import urlparse
2020

21+
from openml.enums import APIVersion
22+
2123
logger = logging.getLogger(__name__)
2224
openml_logger = logging.getLogger("openml")
2325

2426

27+
SERVERS_REGISTRY: dict[str, dict[APIVersion, dict[str, str | None]]] = {
28+
"production": {
29+
APIVersion.V1: {
30+
"server": "https://www.openml.org/api/v1/xml/",
31+
"apikey": None,
32+
},
33+
APIVersion.V2: {
34+
"server": None,
35+
"apikey": None,
36+
},
37+
},
38+
"test": {
39+
APIVersion.V1: {
40+
"server": "https://test.openml.org/api/v1/xml/",
41+
"apikey": "normaluser",
42+
},
43+
APIVersion.V2: {
44+
"server": None,
45+
"apikey": None,
46+
},
47+
},
48+
"local": {
49+
APIVersion.V1: {
50+
"server": "http://localhost:8000/api/v1/xml/",
51+
"apikey": "normaluser",
52+
},
53+
APIVersion.V2: {
54+
"server": "http://localhost:8002/api/v1/xml/",
55+
"apikey": "normaluser",
56+
},
57+
},
58+
}
59+
60+
2561
def _resolve_default_cache_dir() -> Path:
2662
user_defined_cache_dir = os.environ.get("OPENML_CACHE_DIR")
2763
if user_defined_cache_dir is not None:
@@ -57,19 +93,38 @@ def _resolve_default_cache_dir() -> Path:
5793
class OpenMLConfig:
5894
"""Dataclass storing the OpenML configuration."""
5995

60-
apikey: str | None = ""
61-
server: str = "https://www.openml.org/api/v1/xml"
96+
servers: dict[APIVersion, dict[str, str | None]] = field(
97+
default_factory=lambda: SERVERS_REGISTRY["production"]
98+
)
99+
api_version: APIVersion = APIVersion.V1
100+
fallback_api_version: APIVersion | None = None
62101
cachedir: Path = field(default_factory=_resolve_default_cache_dir)
63102
avoid_duplicate_runs: bool = False
64103
retry_policy: Literal["human", "robot"] = "human"
65104
connection_n_retries: int = 5
66105
show_progress: bool = False
67106

68-
def __setattr__(self, name: str, value: Any) -> None:
69-
if name == "apikey" and value is not None and not isinstance(value, str):
70-
raise ValueError("apikey must be a string or None")
107+
@property
108+
def server(self) -> str:
109+
server = self.servers[self.api_version]["server"]
110+
if server is None:
111+
servers_repr = {k.value: v for k, v in self.servers.items()}
112+
raise ValueError(
113+
f'server found to be None for api_version="{self.api_version}" in {servers_repr}'
114+
)
115+
return server
116+
117+
@server.setter
118+
def server(self, value: str | None) -> None:
119+
self.servers[self.api_version]["server"] = value
120+
121+
@property
122+
def apikey(self) -> str | None:
123+
return self.servers[self.api_version]["apikey"]
71124

72-
super().__setattr__(name, value)
125+
@apikey.setter
126+
def apikey(self, value: str | None) -> None:
127+
self.servers[self.api_version]["apikey"] = value
73128

74129

75130
class OpenMLConfigManager:
@@ -79,11 +134,14 @@ def __init__(self) -> None:
79134
self.console_handler: logging.StreamHandler | None = None
80135
self.file_handler: logging.handlers.RotatingFileHandler | None = None
81136

137+
server_test_v1_apikey = SERVERS_REGISTRY["test"][APIVersion.V1]["apikey"]
138+
server_test_v1_server = SERVERS_REGISTRY["test"][APIVersion.V1]["server"]
139+
82140
self.OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
83141
self.OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET"
84-
self._TEST_SERVER_NORMAL_USER_KEY = "normaluser"
142+
self._TEST_SERVER_NORMAL_USER_KEY = server_test_v1_apikey
85143
self.OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR = "OPENML_TEST_SERVER_ADMIN_KEY"
86-
self.TEST_SERVER_URL = "https://test.openml.org"
144+
self.TEST_SERVER_URL = cast("str", server_test_v1_server).split("/api/v1/xml")[0]
87145

88146
self._config: OpenMLConfig = OpenMLConfig()
89147
# for legacy test `test_non_writable_home`
@@ -127,6 +185,10 @@ def __setattr__(self, name: str, value: Any) -> None:
127185
object.__setattr__(self, "_config", replace(self._config, **{name: value}))
128186
return None
129187

188+
if name in ["server", "apikey"]:
189+
setattr(self._config, name, value)
190+
return None
191+
130192
object.__setattr__(self, name, value)
131193
return None
132194

@@ -190,6 +252,21 @@ def get_server_base_url(self) -> str:
190252
domain, _ = self._config.server.split("/api", maxsplit=1)
191253
return domain.replace("api", "www")
192254

255+
def set_server_mode(self, mode: str) -> None:
256+
if mode not in SERVERS_REGISTRY:
257+
raise ValueError(
258+
f'invalid mode="{mode}" allowed modes: {", ".join(list(SERVERS_REGISTRY.keys()))}'
259+
)
260+
self._config = replace(self._config, servers=SERVERS_REGISTRY[mode])
261+
262+
def set_api_version(self, api_version: APIVersion) -> None:
263+
if api_version not in APIVersion:
264+
raise ValueError(
265+
f'invalid api_version="{api_version}" '
266+
f"allowed versions: {', '.join(list(APIVersion))}"
267+
)
268+
self._config = replace(self._config, api_version=api_version)
269+
193270
def set_retry_policy(
194271
self, value: Literal["human", "robot"], n_retries: int | None = None
195272
) -> None:
@@ -317,13 +394,18 @@ def _setup(self, config: dict[str, Any] | None = None) -> None:
317394

318395
self._config = replace(
319396
self._config,
320-
apikey=config["apikey"],
321-
server=config["server"],
397+
servers=config["servers"],
398+
api_version=config["api_version"],
399+
fallback_api_version=config["fallback_api_version"],
322400
show_progress=config["show_progress"],
323401
avoid_duplicate_runs=config["avoid_duplicate_runs"],
324402
retry_policy=config["retry_policy"],
325403
connection_n_retries=int(config["connection_n_retries"]),
326404
)
405+
if "server" in config:
406+
self._config.server = config["server"]
407+
if "apikey" in config:
408+
self._config.apikey = config["apikey"]
327409

328410
user_defined_cache_dir = os.environ.get(self.OPENML_CACHE_DIR_ENV_VAR)
329411
if user_defined_cache_dir is not None:
@@ -393,42 +475,34 @@ def overwrite_config_context(self, config: dict[str, Any]) -> Iterator[dict[str,
393475
class ConfigurationForExamples:
394476
"""Allows easy switching to and from a test configuration, used for examples."""
395477

396-
_last_used_server = None
397-
_last_used_key = None
478+
_last_used_servers = None
398479
_start_last_called = False
399480

400481
def __init__(self, manager: OpenMLConfigManager):
401482
self._manager = manager
402-
self._test_apikey = manager._TEST_SERVER_NORMAL_USER_KEY
403-
self._test_server = f"{manager.TEST_SERVER_URL}/api/v1/xml"
483+
self._test_servers = SERVERS_REGISTRY["test"]
404484

405485
def start_using_configuration_for_example(self) -> None:
406486
"""Sets the configuration to connect to the test server with valid apikey.
407487
408488
To configuration as was before this call is stored, and can be recovered
409489
by using the `stop_use_example_configuration` method.
410490
"""
411-
if (
412-
self._start_last_called
413-
and self._manager._config.server == self._test_server
414-
and self._manager._config.apikey == self._test_apikey
415-
):
491+
if self._start_last_called and self._manager._config.servers == self._test_servers:
416492
# Method is called more than once in a row without modifying the server or apikey.
417493
# We don't want to save the current test configuration as a last used configuration.
418494
return
419495

420-
self._last_used_server = self._manager._config.server
421-
self._last_used_key = self._manager._config.apikey
496+
self._last_used_servers = self._manager._config.servers
422497
type(self)._start_last_called = True
423498

424499
# Test server key for examples
425500
self._manager._config = replace(
426501
self._manager._config,
427-
server=self._test_server,
428-
apikey=self._test_apikey,
502+
servers=self._test_servers,
429503
)
430504
warnings.warn(
431-
f"Switching to the test server {self._test_server} to not upload results to "
505+
f"Switching to the test servers {self._test_servers} to not upload results to "
432506
"the live server. Using the test server may result in reduced performance of the "
433507
"API!",
434508
stacklevel=2,
@@ -446,8 +520,7 @@ def stop_using_configuration_for_example(self) -> None:
446520

447521
self._manager._config = replace(
448522
self._manager._config,
449-
server=cast("str", self._last_used_server),
450-
apikey=cast("str", self._last_used_key),
523+
servers=cast("dict[APIVersion, dict[str, str | None]]", self._last_used_servers),
451524
)
452525
type(self)._start_last_called = False
453526

openml/testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TestBase(unittest.TestCase):
4949
"user": [],
5050
}
5151
flow_name_tracker: ClassVar[list[str]] = []
52-
test_server = f"{openml.config.TEST_SERVER_URL}/api/v1/xml"
52+
test_server = f"{openml.config.TEST_SERVER_URL}/api/v1/xml/"
5353
admin_key = os.environ.get(openml.config.OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR)
5454
user_key = openml.config._TEST_SERVER_NORMAL_USER_KEY
5555

tests/test_openml/test_config.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import openml
1616
import openml.testing
1717
from openml.testing import TestBase
18-
from openml.enums import APIVersion, ServerType
18+
from openml.enums import APIVersion
1919

2020

2121
@contextmanager
@@ -80,26 +80,22 @@ def test_get_config_as_dict(self):
8080
_config = {}
8181
_config["api_version"] = APIVersion.V1
8282
_config["fallback_api_version"] = None
83-
_config["server_type"] = ServerType.PRODUCTION
84-
_config["apikey"] = TestBase.user_key
85-
_config["server"] = f"{openml.config.TEST_SERVER_URL}/api/v1/xml"
83+
_config["servers"] = openml._config.SERVERS_REGISTRY['production']
8684
_config["cachedir"] = self.workdir
8785
_config["avoid_duplicate_runs"] = False
8886
_config["connection_n_retries"] = 20
8987
_config["retry_policy"] = "robot"
9088
_config["show_progress"] = False
9189
assert isinstance(config, dict)
92-
assert len(config) == 10
90+
assert len(config) == 8
9391
self.assertDictEqual(config, _config)
9492

9593
def test_setup_with_config(self):
9694
"""Checks if the OpenML configuration can be updated using _setup()."""
9795
_config = {}
9896
_config["api_version"] = APIVersion.V1
9997
_config["fallback_api_version"] = None
100-
_config["server_type"] = ServerType.PRODUCTION
101-
_config["apikey"] = TestBase.user_key
102-
_config["server"] = "https://www.openml.org/api/v1/xml"
98+
_config["servers"] = openml._config.SERVERS_REGISTRY['production']
10399
_config["cachedir"] = self.workdir
104100
_config["avoid_duplicate_runs"] = True
105101
_config["retry_policy"] = "human"

0 commit comments

Comments
 (0)