Skip to content

Commit f11ca75

Browse files
committed
latest
2 parents b5eba2b + 65472ed commit f11ca75

8 files changed

Lines changed: 224 additions & 81 deletions

File tree

openml/_config.py

Lines changed: 140 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,68 @@
1212
import warnings
1313
from collections.abc import Iterator
1414
from contextlib import contextmanager
15+
from copy import deepcopy
1516
from dataclasses import dataclass, field, fields, replace
1617
from io import StringIO
1718
from pathlib import Path
1819
from typing import Any, ClassVar, Literal, cast
1920
from urllib.parse import urlparse
2021

22+
from openml.enums import APIVersion, ServerMode
23+
24+
from .__version__ import __version__
25+
2126
logger = logging.getLogger(__name__)
2227
openml_logger = logging.getLogger("openml")
2328

2429

30+
_PROD_SERVERS: dict[APIVersion, dict[str, str | None]] = {
31+
APIVersion.V1: {
32+
"server": "https://www.openml.org/api/v1/xml/",
33+
"apikey": None,
34+
},
35+
APIVersion.V2: {
36+
"server": None,
37+
"apikey": None,
38+
},
39+
}
40+
41+
_TEST_SERVERS: dict[APIVersion, dict[str, str | None]] = {
42+
APIVersion.V1: {
43+
"server": "https://test.openml.org/api/v1/xml/",
44+
"apikey": "normaluser",
45+
},
46+
APIVersion.V2: {
47+
"server": None,
48+
"apikey": None,
49+
},
50+
}
51+
52+
_TEST_SERVERS_LOCAL: dict[APIVersion, dict[str, str | None]] = {
53+
APIVersion.V1: {
54+
"server": "http://localhost:8000/api/v1/xml/",
55+
"apikey": "normaluser",
56+
},
57+
APIVersion.V2: {
58+
"server": "http://localhost:8082/",
59+
"apikey": "AD000000000000000000000000000000",
60+
},
61+
}
62+
63+
_SERVERS_REGISTRY: dict[ServerMode, dict[APIVersion, dict[str, str | None]]] = {
64+
ServerMode.PRODUCTION: _PROD_SERVERS,
65+
ServerMode.TEST: (
66+
_TEST_SERVERS_LOCAL if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true" else _TEST_SERVERS
67+
),
68+
}
69+
70+
71+
def _get_servers(mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]:
72+
if mode not in ServerMode:
73+
raise ValueError(f'invalid mode="{mode}" allowed modes: {", ".join(list(ServerMode))}')
74+
return deepcopy(_SERVERS_REGISTRY[mode])
75+
76+
2577
def _resolve_default_cache_dir() -> Path:
2678
user_defined_cache_dir = os.environ.get("OPENML_CACHE_DIR")
2779
if user_defined_cache_dir is not None:
@@ -57,19 +109,38 @@ def _resolve_default_cache_dir() -> Path:
57109
class OpenMLConfig:
58110
"""Dataclass storing the OpenML configuration."""
59111

60-
apikey: str | None = ""
61-
server: str = "https://www.openml.org/api/v1/xml"
112+
servers: dict[APIVersion, dict[str, str | None]] = field(
113+
default_factory=lambda: _get_servers(ServerMode.PRODUCTION)
114+
)
115+
api_version: APIVersion = APIVersion.V1
116+
fallback_api_version: APIVersion | None = None
62117
cachedir: Path = field(default_factory=_resolve_default_cache_dir)
63118
avoid_duplicate_runs: bool = False
64119
retry_policy: Literal["human", "robot"] = "human"
65120
connection_n_retries: int = 5
66121
show_progress: bool = False
67122

68-
def __setattr__(self, name: str, value: Any) -> None:
69-
if name == "apikey" and not isinstance(value, (type(None), str)):
70-
raise TypeError("apikey must be a string or None")
123+
@property
124+
def server(self) -> str:
125+
server = self.servers[self.api_version]["server"]
126+
if server is None:
127+
servers_repr = {k.value: v for k, v in self.servers.items()}
128+
raise ValueError(
129+
f'server found to be None for api_version="{self.api_version}" in {servers_repr}'
130+
)
131+
return server
71132

72-
super().__setattr__(name, value)
133+
@server.setter
134+
def server(self, value: str | None) -> None:
135+
self.servers[self.api_version]["server"] = value
136+
137+
@property
138+
def apikey(self) -> str | None:
139+
return self.servers[self.api_version]["apikey"]
140+
141+
@apikey.setter
142+
def apikey(self, value: str | None) -> None:
143+
self.servers[self.api_version]["apikey"] = value
73144

74145

75146
class OpenMLConfigManager:
@@ -81,9 +152,8 @@ def __init__(self) -> None:
81152

82153
self.OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
83154
self.OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET"
84-
self._TEST_SERVER_NORMAL_USER_KEY = "normaluser"
85155
self.OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR = "OPENML_TEST_SERVER_ADMIN_KEY"
86-
self.TEST_SERVER_URL = "https://test.openml.org"
156+
self._HEADERS: dict[str, str] = {"user-agent": f"openml-python/{__version__}"}
87157

88158
self._config: OpenMLConfig = OpenMLConfig()
89159
# for legacy test `test_non_writable_home`
@@ -116,7 +186,7 @@ def __setattr__(self, name: str, value: Any) -> None:
116186
"_examples",
117187
"OPENML_CACHE_DIR_ENV_VAR",
118188
"OPENML_SKIP_PARQUET_ENV_VAR",
119-
"_TEST_SERVER_NORMAL_USER_KEY",
189+
"_HEADERS",
120190
}:
121191
return object.__setattr__(self, name, value)
122192

@@ -127,6 +197,10 @@ def __setattr__(self, name: str, value: Any) -> None:
127197
object.__setattr__(self, "_config", replace(self._config, **{name: value}))
128198
return None
129199

200+
if name in ["server", "apikey"]:
201+
setattr(self._config, name, value)
202+
return None
203+
130204
object.__setattr__(self, name, value)
131205
return None
132206

@@ -190,6 +264,48 @@ def get_server_base_url(self) -> str:
190264
domain, _ = self._config.server.split("/api", maxsplit=1)
191265
return domain.replace("api", "www")
192266

267+
def _get_servers(self, mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]:
268+
return _get_servers(mode)
269+
270+
def _set_servers(self, mode: ServerMode) -> None:
271+
servers = self._get_servers(mode)
272+
self._config = replace(self._config, servers=servers)
273+
274+
def get_production_servers(self) -> dict[APIVersion, dict[str, str | None]]:
275+
return self._get_servers(mode=ServerMode.PRODUCTION)
276+
277+
def get_test_servers(self) -> dict[APIVersion, dict[str, str | None]]:
278+
return self._get_servers(mode=ServerMode.TEST)
279+
280+
def use_production_servers(self) -> None:
281+
self._set_servers(mode=ServerMode.PRODUCTION)
282+
283+
def use_test_servers(self) -> None:
284+
self._set_servers(mode=ServerMode.TEST)
285+
286+
def set_api_version(
287+
self,
288+
api_version: APIVersion,
289+
fallback_api_version: APIVersion | None = None,
290+
) -> None:
291+
if api_version not in APIVersion:
292+
raise ValueError(
293+
f'invalid api_version="{api_version}" '
294+
f"allowed versions: {', '.join(list(APIVersion))}"
295+
)
296+
297+
if fallback_api_version is not None and fallback_api_version not in APIVersion:
298+
raise ValueError(
299+
f'invalid fallback_api_version="{fallback_api_version}" '
300+
f"allowed versions: {', '.join(list(APIVersion))}"
301+
)
302+
303+
self._config = replace(
304+
self._config,
305+
api_version=api_version,
306+
fallback_api_version=fallback_api_version,
307+
)
308+
193309
def set_retry_policy(
194310
self, value: Literal["human", "robot"], n_retries: int | None = None
195311
) -> None:
@@ -317,13 +433,18 @@ def _setup(self, config: dict[str, Any] | None = None) -> None:
317433

318434
self._config = replace(
319435
self._config,
320-
apikey=config["apikey"],
321-
server=config["server"],
436+
servers=config["servers"],
437+
api_version=config["api_version"],
438+
fallback_api_version=config["fallback_api_version"],
322439
show_progress=config["show_progress"],
323440
avoid_duplicate_runs=config["avoid_duplicate_runs"],
324441
retry_policy=config["retry_policy"],
325442
connection_n_retries=int(config["connection_n_retries"]),
326443
)
444+
if "server" in config:
445+
self._config.server = config["server"]
446+
if "apikey" in config:
447+
self._config.apikey = config["apikey"]
327448

328449
user_defined_cache_dir = os.environ.get(self.OPENML_CACHE_DIR_ENV_VAR)
329450
if user_defined_cache_dir is not None:
@@ -393,42 +514,35 @@ def overwrite_config_context(self, config: dict[str, Any]) -> Iterator[dict[str,
393514
class ConfigurationForExamples:
394515
"""Allows easy switching to and from a test configuration, used for examples."""
395516

396-
_last_used_server = None
397-
_last_used_key = None
517+
_last_used_servers = None
398518
_start_last_called = False
399519

400520
def __init__(self, manager: OpenMLConfigManager):
401521
self._manager = manager
402-
self._test_apikey = manager._TEST_SERVER_NORMAL_USER_KEY
403-
self._test_server = f"{manager.TEST_SERVER_URL}/api/v1/xml"
522+
self._test_servers = manager.get_test_servers()
404523

405524
def start_using_configuration_for_example(self) -> None:
406525
"""Sets the configuration to connect to the test server with valid apikey.
407526
408527
To configuration as was before this call is stored, and can be recovered
409528
by using the `stop_use_example_configuration` method.
410529
"""
411-
if (
412-
self._start_last_called
413-
and self._manager._config.server == self._test_server
414-
and self._manager._config.apikey == self._test_apikey
415-
):
530+
if self._start_last_called and self._manager._config.servers == self._test_servers:
416531
# Method is called more than once in a row without modifying the server or apikey.
417532
# We don't want to save the current test configuration as a last used configuration.
418533
return
419534

420-
self._last_used_server = self._manager._config.server
421-
self._last_used_key = self._manager._config.apikey
535+
self._last_used_servers = self._manager._config.servers
422536
type(self)._start_last_called = True
423537

424538
# Test server key for examples
425539
self._manager._config = replace(
426540
self._manager._config,
427-
server=self._test_server,
428-
apikey=self._test_apikey,
541+
servers=self._test_servers,
429542
)
543+
test_server = self._test_servers[self._manager._config.api_version]["server"]
430544
warnings.warn(
431-
f"Switching to the test server {self._test_server} to not upload results to "
545+
f"Switching to the test server {test_server} to not upload results to "
432546
"the live server. Using the test server may result in reduced performance of the "
433547
"API!",
434548
stacklevel=2,
@@ -446,8 +560,7 @@ def stop_using_configuration_for_example(self) -> None:
446560

447561
self._manager._config = replace(
448562
self._manager._config,
449-
server=cast("str", self._last_used_server),
450-
apikey=cast("str", self._last_used_key),
563+
servers=cast("dict[APIVersion, dict[str, str | None]]", self._last_used_servers),
451564
)
452565
type(self)._start_last_called = False
453566

openml/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import xmltodict
1010

11-
import openml
1211
import openml._api_calls
1312

1413
from .utils import _get_rest_api_type_alias, _tag_openml_base

openml/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def check_server(server: str) -> str:
112112

113113
def replace_shorthand(server: str) -> str:
114114
if server == "test":
115-
return f"{openml.config.TEST_SERVER_URL}/api/v1/xml"
115+
return cast("str", openml.config.get_test_servers()[APIVersion.V1]["server"])
116116
if server == "production_server":
117-
return cast("str", openml.config.get_servers("production")[APIVersion.V1]["server"])
117+
return cast("str", openml.config.get_production_servers()[APIVersion.V1]["server"])
118118
return server
119119

120120
configure_field(

openml/enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
from enum import Enum
44

55

6+
class ServerMode(str, Enum):
7+
"""Supported modes in server."""
8+
9+
PRODUCTION = "production"
10+
TEST = "test"
11+
12+
613
class APIVersion(str, Enum):
714
"""Supported OpenML API versions."""
815

openml/study/functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pandas as pd
99
import xmltodict
1010

11-
import openml
1211
import openml._api_calls
1312
import openml.utils
1413
from openml.study.study import OpenMLBenchmarkSuite, OpenMLStudy

openml/testing.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
import requests
1616

1717
import openml
18-
from openml._api import API_REGISTRY, HTTPCache, HTTPClient, MinIOClient, ResourceAPI
19-
from openml.enums import APIVersion, ResourceType
2018
from openml.exceptions import OpenMLServerException
2119
from openml.tasks import TaskType
2220

@@ -55,11 +53,6 @@ class TestBase(unittest.TestCase):
5553
logger = logging.getLogger("unit_tests_published_entities")
5654
logger.setLevel(logging.DEBUG)
5755

58-
# migration-specific attributes
59-
cache: HTTPCache
60-
http_clients: dict[APIVersion, HTTPClient]
61-
minio_client: MinIOClient
62-
6356
def setUp(self, n_levels: int = 1, tmpdir_suffix: str = "") -> None:
6457
"""Setup variables and temporary directories.
6558
@@ -111,20 +104,13 @@ def setUp(self, n_levels: int = 1, tmpdir_suffix: str = "") -> None:
111104
self.connection_n_retries = openml.config.connection_n_retries
112105
openml.config.set_retry_policy("robot", n_retries=20)
113106

114-
self.cache = HTTPCache()
115-
self.http_clients = {
116-
APIVersion.V1: HTTPClient(api_version=APIVersion.V1),
117-
APIVersion.V2: HTTPClient(api_version=APIVersion.V2),
118-
}
119-
self.minio_client = MinIOClient()
120-
121107
def use_production_server(self) -> None:
122108
"""
123109
Use the production server for the OpenML API calls.
124110
125111
Please use this sparingly - it is better to use the test server.
126112
"""
127-
openml.config.set_servers("production")
113+
openml.config.use_production_servers()
128114

129115
def tearDown(self) -> None:
130116
"""Tear down the test"""
@@ -284,11 +270,6 @@ def _check_fold_timing_evaluations( # noqa: PLR0913
284270
assert evaluation >= min_val
285271
assert evaluation <= max_val
286272

287-
def _create_resource(self, api_version: APIVersion, resource_type: ResourceType) -> ResourceAPI:
288-
http_client = self.http_clients[api_version]
289-
resource_cls = API_REGISTRY[api_version][resource_type]
290-
return resource_cls(http=http_client, minio=self.minio_client)
291-
292273

293274
def check_task_existence(
294275
task_type: TaskType,

0 commit comments

Comments
 (0)