From 5bcfd92307828858e072ebf2c776f813fb23934b Mon Sep 17 00:00:00 2001 From: petruki <31597636+petruki@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:46:39 -0800 Subject: [PATCH] feat: added cert_path option for custom secure API communication --- README.md | 2 +- switcher_client/lib/globals/global_context.py | 23 ++++++----- switcher_client/lib/remote.py | 39 ++++++++++++------- tests/fixtures/dummy_cert.pem | 4 ++ tests/test_switcher_remote.py | 23 +++++++++++ 5 files changed, 67 insertions(+), 24 deletions(-) create mode 100644 tests/fixtures/dummy_cert.pem diff --git a/README.md b/README.md index 2ded98e..53416cd 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ switcher = Client.get_switcher() | `throttle_max_workers` | `int` | Max workers for throttling feature checks | `None` | | `regex_max_black_list` | `int` | Max cached entries for failed regex | `100` | | `regex_max_time_limit` | `int` | Regex execution time limit (ms) | `3000` | -| `cert_path` | `str` | 🚧 TODO - Path to custom certificate for API connections | `None` | +| `cert_path` | `str` | Path to custom certificate for secure API connections | `None` | #### Security Features diff --git a/switcher_client/lib/globals/global_context.py b/switcher_client/lib/globals/global_context.py index c789ae2..e383724 100644 --- a/switcher_client/lib/globals/global_context.py +++ b/switcher_client/lib/globals/global_context.py @@ -20,19 +20,21 @@ class ContextOptions: :param regex_max_black_list: The maximum number of blacklisted regex inputs. If not set, it will use the default value of 100 :param regex_max_time_limit: The maximum time limit in milliseconds for regex matching. If not set, it will use the default value of 3000 ms :param restrict_relay: When enabled it will restrict the use of relay when local is enabled. Default is True + :param cert_path: The path to the SSL certificate file for secure connections. If not set, it will use the default system certificates """ def __init__(self, - local: bool = DEFAULT_LOCAL, - logger: bool = DEFAULT_LOGGER, - freeze: bool = DEFAULT_FREEZE, - regex_max_black_list: int = DEFAULT_REGEX_MAX_BLACKLISTED, - regex_max_time_limit: int = DEFAULT_REGEX_MAX_TIME_LIMIT, - restrict_relay: bool = DEFAULT_RESTRICT_RELAY, - snapshot_location: Optional[str] = None, - snapshot_auto_update_interval: Optional[int] = None, - silent_mode: Optional[str] = None, - throttle_max_workers: Optional[int] = None): + local: bool = DEFAULT_LOCAL, + logger: bool = DEFAULT_LOGGER, + freeze: bool = DEFAULT_FREEZE, + regex_max_black_list: int = DEFAULT_REGEX_MAX_BLACKLISTED, + regex_max_time_limit: int = DEFAULT_REGEX_MAX_TIME_LIMIT, + restrict_relay: bool = DEFAULT_RESTRICT_RELAY, + snapshot_location: Optional[str] = None, + snapshot_auto_update_interval: Optional[int] = None, + silent_mode: Optional[str] = None, + throttle_max_workers: Optional[int] = None, + cert_path: Optional[str] = None): self.local = local self.logger = logger self.freeze = freeze @@ -43,6 +45,7 @@ def __init__(self, self.throttle_max_workers = throttle_max_workers self.regex_max_black_list = regex_max_black_list self.regex_max_time_limit = regex_max_time_limit + self.cert_path = cert_path class Context: def __init__(self, diff --git a/switcher_client/lib/remote.py b/switcher_client/lib/remote.py index 765c1c3..ed4266a 100644 --- a/switcher_client/lib/remote.py +++ b/switcher_client/lib/remote.py @@ -1,4 +1,5 @@ import json +import ssl import httpx from typing import Optional @@ -15,7 +16,7 @@ class Remote: @staticmethod def auth(context: Context): url = f'{context.url}/criteria/auth' - response = Remote._do_post(url, { + response = Remote._do_post(context, url, { 'domain': context.domain, 'component': context.component, 'environment': context.environment, @@ -32,7 +33,7 @@ def auth(context: Context): @staticmethod def check_api_health(context: Context) -> bool: url = f'{context.url}/check' - response = Remote._do_get(url) + response = Remote._do_get(context, url) return response.status_code == 200 @@ -40,7 +41,7 @@ def check_api_health(context: Context) -> bool: def check_criteria(token: Optional[str], context: Context, switcher: SwitcherData) -> ResultDetail: url = f'{context.url}/criteria?showReason={str(switcher._show_details).lower()}&key={switcher._key}' entry = get_entry(switcher._input) - response = Remote._do_post(url, { 'entry': [e.to_dict() for e in entry] }, Remote._get_header(token)) + response = Remote._do_post(context, url, { 'entry': [e.to_dict() for e in entry] }, Remote._get_header(token)) if response.status_code == 200: json_response = response.json() @@ -55,7 +56,7 @@ def check_criteria(token: Optional[str], context: Context, switcher: SwitcherDat @staticmethod def check_snapshot_version(token: Optional[str], context: Context, snapshot_version: int) -> bool: url = f'{context.url}/criteria/snapshot_check/{snapshot_version}' - response = Remote._do_get(url, Remote._get_header(token)) + response = Remote._do_get(context, url, Remote._get_header(token)) if response.status_code == 200: return response.json().get('status', False) @@ -85,7 +86,7 @@ def resolve_snapshot(token: Optional[str], context: Context) -> str | None: """ } - response = Remote._do_post(f'{context.url}/graphql', data, Remote._get_header(token)) + response = Remote._do_post(context, f'{context.url}/graphql', data, Remote._get_header(token)) if response.status_code == 200: return json.dumps(response.json().get('data', {}), indent=4) @@ -93,7 +94,7 @@ def resolve_snapshot(token: Optional[str], context: Context) -> str | None: raise RemoteError(f'[resolve_snapshot] failed with status: {response.status_code}') @classmethod - def _get_client(cls) -> httpx.Client: + def _get_client(cls, context: Context) -> httpx.Client: if cls._client is None or cls._client.is_closed: cls._client = httpx.Client( timeout=30.0, @@ -102,23 +103,35 @@ def _get_client(cls) -> httpx.Client: max_connections=100, keepalive_expiry=30.0 ), - http2=True + http2=True, + verify=cls._get_context(context) ) return cls._client @staticmethod - def _do_post(url, data, headers) -> httpx.Response: - client = Remote._get_client() + def _do_post(context: Context, url: str, data: dict, headers: Optional[dict] = None) -> httpx.Response: + client = Remote._get_client(context) return client.post(url, json=data, headers=headers) @staticmethod - def _do_get(url, headers=None) -> httpx.Response: - client = Remote._get_client() + def _do_get(context: Context, url: str, headers: Optional[dict] = None) -> httpx.Response: + client = Remote._get_client(context) return client.get(url, headers=headers) @staticmethod - def _get_header(token: Optional[str]): + def _get_header(token: Optional[str]) -> dict: return { 'Authorization': f'Bearer {token}', 'Content-Type': 'application/json', - } \ No newline at end of file + } + + @staticmethod + def _get_context(context: Context) -> bool | ssl.SSLContext: + cert_path = context.options.cert_path + if cert_path is None: + return True + + ctx = ssl.create_default_context() + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + ctx.load_cert_chain(certfile=cert_path) + return ctx \ No newline at end of file diff --git a/tests/fixtures/dummy_cert.pem b/tests/fixtures/dummy_cert.pem new file mode 100644 index 0000000..2d392dd --- /dev/null +++ b/tests/fixtures/dummy_cert.pem @@ -0,0 +1,4 @@ +-----BEGIN CERTIFICATE----- +-----END CERTIFICATE----- +-----BEGIN PRIVATE KEY----- +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/tests/test_switcher_remote.py b/tests/test_switcher_remote.py index eaed494..13c6af5 100644 --- a/tests/test_switcher_remote.py +++ b/tests/test_switcher_remote.py @@ -3,6 +3,7 @@ from typing import Optional from pytest_httpx import HTTPXMock +from unittest.mock import Mock, patch from switcher_client.errors import RemoteAuthError from switcher_client import Client @@ -120,6 +121,28 @@ def test_remote_with_remote_required_request(httpx_mock): # test assert switcher.remote().is_on(key) +def test_remote_with_custom_cert(httpx_mock): + """ Should call the remote API with success using a custom certificate """ + + # Reset Remote client to ensure fresh SSL context creation + from switcher_client.lib.remote import Remote + Remote._client = None + + # given + given_auth(httpx_mock) + given_check_criteria(httpx_mock, response={'result': True}) + given_context(options=ContextOptions(cert_path='tests/fixtures/dummy_cert.pem')) + + switcher = Client.get_switcher() + + # test + mock_ssl_context = Mock() + with patch('ssl.create_default_context', return_value=mock_ssl_context): + assert switcher.is_on('MY_SWITCHER') + mock_ssl_context.load_cert_chain.assert_called_once_with( + certfile='tests/fixtures/dummy_cert.pem' + ) + def test_remote_err_with_remote_reqquired_request_no_local(): """ Should raise an exception when local mode is not enabled and remote is forced """