|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import Any, Mapping |
| 3 | +from pathlib import Path |
| 4 | +from typing import TYPE_CHECKING, Any |
| 5 | +from urllib.parse import urlencode, urljoin, urlparse |
4 | 6 |
|
5 | 7 | import requests |
6 | 8 | from requests import Response |
7 | 9 |
|
8 | 10 | from openml.__version__ import __version__ |
| 11 | +from openml._api.config import settings |
9 | 12 |
|
| 13 | +if TYPE_CHECKING: |
| 14 | + from openml._api.config import APIConfig |
10 | 15 |
|
11 | | -class HTTPClient: |
12 | | - def __init__(self, base_url: str) -> None: |
13 | | - self.base_url = base_url |
| 16 | + |
| 17 | +class CacheMixin: |
| 18 | + @property |
| 19 | + def dir(self) -> str: |
| 20 | + return settings.cache.dir |
| 21 | + |
| 22 | + @property |
| 23 | + def ttl(self) -> int: |
| 24 | + return settings.cache.ttl |
| 25 | + |
| 26 | + def _get_cache_dir(self, url: str, params: dict[str, Any]) -> Path: |
| 27 | + parsed_url = urlparse(url) |
| 28 | + netloc_parts = parsed_url.netloc.split(".")[::-1] # reverse domain |
| 29 | + path_parts = parsed_url.path.strip("/").split("/") |
| 30 | + |
| 31 | + # remove api_key and serialize params if any |
| 32 | + filtered_params = {k: v for k, v in params.items() if k != "api_key"} |
| 33 | + params_part = [urlencode(filtered_params)] if filtered_params else [] |
| 34 | + |
| 35 | + return Path(self.dir).joinpath(*netloc_parts, *path_parts, *params_part) |
| 36 | + |
| 37 | + def _get_cache_response(self, cache_dir: Path) -> Response: # noqa: ARG002 |
| 38 | + return Response() |
| 39 | + |
| 40 | + def _set_cache_response(self, cache_dir: Path, response: Response) -> None: # noqa: ARG002 |
| 41 | + return None |
| 42 | + |
| 43 | + |
| 44 | +class HTTPClient(CacheMixin): |
| 45 | + def __init__(self, config: APIConfig) -> None: |
| 46 | + self.config = config |
14 | 47 | self.headers: dict[str, str] = {"user-agent": f"openml-python/{__version__}"} |
15 | 48 |
|
| 49 | + @property |
| 50 | + def server(self) -> str: |
| 51 | + return self.config.server |
| 52 | + |
| 53 | + @property |
| 54 | + def base_url(self) -> str: |
| 55 | + return self.config.base_url |
| 56 | + |
| 57 | + @property |
| 58 | + def key(self) -> str: |
| 59 | + return self.config.key |
| 60 | + |
| 61 | + @property |
| 62 | + def timeout(self) -> int: |
| 63 | + return self.config.timeout |
| 64 | + |
| 65 | + def request( |
| 66 | + self, |
| 67 | + method: str, |
| 68 | + path: str, |
| 69 | + *, |
| 70 | + use_cache: bool = False, |
| 71 | + use_api_key: bool = False, |
| 72 | + **request_kwargs: Any, |
| 73 | + ) -> Response: |
| 74 | + url = urljoin(self.server, urljoin(self.base_url, path)) |
| 75 | + |
| 76 | + params = request_kwargs.pop("params", {}) |
| 77 | + params = params.copy() |
| 78 | + if use_api_key: |
| 79 | + params["api_key"] = self.key |
| 80 | + |
| 81 | + headers = request_kwargs.pop("headers", {}) |
| 82 | + headers = headers.copy() |
| 83 | + headers.update(self.headers) |
| 84 | + |
| 85 | + timeout = request_kwargs.pop("timeout", self.timeout) |
| 86 | + cache_dir = self._get_cache_dir(url, params) |
| 87 | + |
| 88 | + if use_cache: |
| 89 | + try: |
| 90 | + return self._get_cache_response(cache_dir) |
| 91 | + # TODO: handle ttl expired error |
| 92 | + except Exception: |
| 93 | + raise |
| 94 | + |
| 95 | + response = requests.request( |
| 96 | + method=method, |
| 97 | + url=url, |
| 98 | + params=params, |
| 99 | + headers=headers, |
| 100 | + timeout=timeout, |
| 101 | + **request_kwargs, |
| 102 | + ) |
| 103 | + |
| 104 | + if use_cache: |
| 105 | + self._set_cache_response(cache_dir, response) |
| 106 | + |
| 107 | + return response |
| 108 | + |
16 | 109 | def get( |
17 | 110 | self, |
18 | 111 | path: str, |
19 | | - params: Mapping[str, Any] | None = None, |
| 112 | + *, |
| 113 | + use_cache: bool = False, |
| 114 | + use_api_key: bool = False, |
| 115 | + **request_kwargs: Any, |
20 | 116 | ) -> Response: |
21 | | - url = f"{self.base_url}/{path}" |
22 | | - return requests.get(url, params=params, headers=self.headers, timeout=10) |
| 117 | + # TODO: remove override when cache is implemented |
| 118 | + use_cache = False |
| 119 | + return self.request( |
| 120 | + method="GET", |
| 121 | + path=path, |
| 122 | + use_cache=use_cache, |
| 123 | + use_api_key=use_api_key, |
| 124 | + **request_kwargs, |
| 125 | + ) |
23 | 126 |
|
24 | 127 | def post( |
25 | 128 | self, |
26 | 129 | path: str, |
27 | | - data: Mapping[str, Any] | None = None, |
28 | | - json: dict | None = None, |
29 | | - files: Any = None, |
| 130 | + **request_kwargs: Any, |
30 | 131 | ) -> Response: |
31 | | - url = f"{self.base_url}/{path}" |
32 | | - return requests.post( |
33 | | - url, data=data, json=json, files=files, headers=self.headers, timeout=10 |
| 132 | + return self.request( |
| 133 | + method="POST", |
| 134 | + path=path, |
| 135 | + use_cache=False, |
| 136 | + use_api_key=True, |
| 137 | + **request_kwargs, |
34 | 138 | ) |
35 | 139 |
|
36 | 140 | def delete( |
37 | 141 | self, |
38 | 142 | path: str, |
39 | | - params: Mapping[str, Any] | None = None, |
| 143 | + **request_kwargs: Any, |
40 | 144 | ) -> Response: |
41 | | - url = f"{self.base_url}/{path}" |
42 | | - return requests.delete(url, params=params, headers=self.headers, timeout=10) |
| 145 | + return self.request( |
| 146 | + method="DELETE", |
| 147 | + path=path, |
| 148 | + use_cache=False, |
| 149 | + use_api_key=True, |
| 150 | + **request_kwargs, |
| 151 | + ) |
0 commit comments