Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ switcher = Client.get_switcher()
| `throttle_max_workers` | `int` | Max workers for throttling feature checks | `None` |
| `regex_max_black_list` | `int` | Max cached entries for failed regex | `100` |
| `regex_max_time_limit` | `int` | Regex execution time limit (ms) | `3000` |
| `cert_path` | `str` | 🚧 TODO - Path to custom certificate for API connections | `None` |
| `cert_path` | `str` | Path to custom certificate for secure API connections | `None` |

#### Security Features

Expand Down
23 changes: 13 additions & 10 deletions switcher_client/lib/globals/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@ class ContextOptions:
:param regex_max_black_list: The maximum number of blacklisted regex inputs. If not set, it will use the default value of 100
:param regex_max_time_limit: The maximum time limit in milliseconds for regex matching. If not set, it will use the default value of 3000 ms
:param restrict_relay: When enabled it will restrict the use of relay when local is enabled. Default is True
:param cert_path: The path to the SSL certificate file for secure connections. If not set, it will use the default system certificates
"""

def __init__(self,
local: bool = DEFAULT_LOCAL,
logger: bool = DEFAULT_LOGGER,
freeze: bool = DEFAULT_FREEZE,
regex_max_black_list: int = DEFAULT_REGEX_MAX_BLACKLISTED,
regex_max_time_limit: int = DEFAULT_REGEX_MAX_TIME_LIMIT,
restrict_relay: bool = DEFAULT_RESTRICT_RELAY,
snapshot_location: Optional[str] = None,
snapshot_auto_update_interval: Optional[int] = None,
silent_mode: Optional[str] = None,
throttle_max_workers: Optional[int] = None):
local: bool = DEFAULT_LOCAL,
logger: bool = DEFAULT_LOGGER,
freeze: bool = DEFAULT_FREEZE,
regex_max_black_list: int = DEFAULT_REGEX_MAX_BLACKLISTED,
regex_max_time_limit: int = DEFAULT_REGEX_MAX_TIME_LIMIT,
restrict_relay: bool = DEFAULT_RESTRICT_RELAY,
snapshot_location: Optional[str] = None,
snapshot_auto_update_interval: Optional[int] = None,
silent_mode: Optional[str] = None,
throttle_max_workers: Optional[int] = None,
cert_path: Optional[str] = None):
self.local = local
self.logger = logger
self.freeze = freeze
Expand All @@ -43,6 +45,7 @@ def __init__(self,
self.throttle_max_workers = throttle_max_workers
self.regex_max_black_list = regex_max_black_list
self.regex_max_time_limit = regex_max_time_limit
self.cert_path = cert_path

class Context:
def __init__(self,
Expand Down
39 changes: 26 additions & 13 deletions switcher_client/lib/remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import ssl
import httpx

from typing import Optional
Expand All @@ -15,7 +16,7 @@ class Remote:
@staticmethod
def auth(context: Context):
url = f'{context.url}/criteria/auth'
response = Remote._do_post(url, {
response = Remote._do_post(context, url, {
'domain': context.domain,
'component': context.component,
'environment': context.environment,
Expand All @@ -32,15 +33,15 @@ def auth(context: Context):
@staticmethod
def check_api_health(context: Context) -> bool:
url = f'{context.url}/check'
response = Remote._do_get(url)
response = Remote._do_get(context, url)

return response.status_code == 200

@staticmethod
def check_criteria(token: Optional[str], context: Context, switcher: SwitcherData) -> ResultDetail:
url = f'{context.url}/criteria?showReason={str(switcher._show_details).lower()}&key={switcher._key}'
entry = get_entry(switcher._input)
response = Remote._do_post(url, { 'entry': [e.to_dict() for e in entry] }, Remote._get_header(token))
response = Remote._do_post(context, url, { 'entry': [e.to_dict() for e in entry] }, Remote._get_header(token))

if response.status_code == 200:
json_response = response.json()
Expand All @@ -55,7 +56,7 @@ def check_criteria(token: Optional[str], context: Context, switcher: SwitcherDat
@staticmethod
def check_snapshot_version(token: Optional[str], context: Context, snapshot_version: int) -> bool:
url = f'{context.url}/criteria/snapshot_check/{snapshot_version}'
response = Remote._do_get(url, Remote._get_header(token))
response = Remote._do_get(context, url, Remote._get_header(token))

if response.status_code == 200:
return response.json().get('status', False)
Expand Down Expand Up @@ -85,15 +86,15 @@ def resolve_snapshot(token: Optional[str], context: Context) -> str | None:
"""
}

response = Remote._do_post(f'{context.url}/graphql', data, Remote._get_header(token))
response = Remote._do_post(context, f'{context.url}/graphql', data, Remote._get_header(token))

if response.status_code == 200:
return json.dumps(response.json().get('data', {}), indent=4)

raise RemoteError(f'[resolve_snapshot] failed with status: {response.status_code}')

@classmethod
def _get_client(cls) -> httpx.Client:
def _get_client(cls, context: Context) -> httpx.Client:
if cls._client is None or cls._client.is_closed:
cls._client = httpx.Client(
timeout=30.0,
Expand All @@ -102,23 +103,35 @@ def _get_client(cls) -> httpx.Client:
max_connections=100,
keepalive_expiry=30.0
),
http2=True
http2=True,
verify=cls._get_context(context)
)
return cls._client

@staticmethod
def _do_post(url, data, headers) -> httpx.Response:
client = Remote._get_client()
def _do_post(context: Context, url: str, data: dict, headers: Optional[dict] = None) -> httpx.Response:
client = Remote._get_client(context)
return client.post(url, json=data, headers=headers)

@staticmethod
def _do_get(url, headers=None) -> httpx.Response:
client = Remote._get_client()
def _do_get(context: Context, url: str, headers: Optional[dict] = None) -> httpx.Response:
client = Remote._get_client(context)
return client.get(url, headers=headers)

@staticmethod
def _get_header(token: Optional[str]):
def _get_header(token: Optional[str]) -> dict:
return {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json',
}
}

@staticmethod
def _get_context(context: Context) -> bool | ssl.SSLContext:
cert_path = context.options.cert_path
if cert_path is None:
return True

ctx = ssl.create_default_context()
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ctx.load_cert_chain(certfile=cert_path)
return ctx
4 changes: 4 additions & 0 deletions tests/fixtures/dummy_cert.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-----BEGIN CERTIFICATE-----
-----END CERTIFICATE-----
-----BEGIN PRIVATE KEY-----
-----END PRIVATE KEY-----
23 changes: 23 additions & 0 deletions tests/test_switcher_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import Optional
from pytest_httpx import HTTPXMock
from unittest.mock import Mock, patch

from switcher_client.errors import RemoteAuthError
from switcher_client import Client
Expand Down Expand Up @@ -120,6 +121,28 @@ def test_remote_with_remote_required_request(httpx_mock):
# test
assert switcher.remote().is_on(key)

def test_remote_with_custom_cert(httpx_mock):
""" Should call the remote API with success using a custom certificate """

# Reset Remote client to ensure fresh SSL context creation
from switcher_client.lib.remote import Remote
Remote._client = None

# given
given_auth(httpx_mock)
given_check_criteria(httpx_mock, response={'result': True})
given_context(options=ContextOptions(cert_path='tests/fixtures/dummy_cert.pem'))

switcher = Client.get_switcher()

# test
mock_ssl_context = Mock()
with patch('ssl.create_default_context', return_value=mock_ssl_context):
assert switcher.is_on('MY_SWITCHER')
mock_ssl_context.load_cert_chain.assert_called_once_with(
certfile='tests/fixtures/dummy_cert.pem'
)

def test_remote_err_with_remote_reqquired_request_no_local():
""" Should raise an exception when local mode is not enabled and remote is forced """

Expand Down