Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/tabpfn_common_utils/telemetry/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@
def download_config() -> Dict[str, Any]:
"""Download the configuration from server.

If the configuration cannot be fetched or parsed (network error, non-200
response, malformed JSON) the function returns a fail-closed default of
``{"enabled": False}`` so that telemetry is not emitted from clients that
cannot reach the public configuration endpoint.

Returns:
Dict[str, Any]: The configuration.
"""
# The default configuration
default = {"enabled": True}
# Fail-closed default: when we can't reach or parse the remote
# configuration, treat telemetry as disabled. Respects users on restrictive
# networks (firewalls, air-gapped environments) and avoids emitting events
# under any state other than an explicit, server-confirmed enable.
default = {"enabled": False}
Comment thread
noahho marked this conversation as resolved.

# This is a public URL anyone can and should read from
url = os.environ.get(
Expand All @@ -41,9 +49,23 @@ def download_config() -> Dict[str, Any]:
logger.debug(f"Failed to download telemetry config: {url}")
return default

# Disable telemetry by default
if resp.status_code != 200:
logger.debug(f"Failed to download telemetry config: {resp.status_code}")
return default

return resp.json()
try:
config = resp.json()
except ValueError:
logger.debug(f"Failed to parse telemetry config JSON from: {url}")
return default
Comment thread
noahho marked this conversation as resolved.

# Validate the shape: anything other than a dict containing "enabled" would
# cause a TypeError/KeyError in downstream `config["enabled"]` access. Fail
# closed so a malformed remote response cannot crash the host process.
if not isinstance(config, dict) or "enabled" not in config:
logger.debug(
f"Telemetry config from {url} is malformed or missing 'enabled' key"
)
return default

return config
94 changes: 94 additions & 0 deletions tests/telemetry/core/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Tests for the telemetry remote-config download behaviour."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest
import requests

from tabpfn_common_utils.telemetry.core import config as config_module


@pytest.fixture(autouse=True)
def _clear_ttl_cache():
    """Reset the cached ``download_config`` result before and after each test.

    ``download_config`` is decorated with ``@ttl_cache``, which is built on
    ``functools.lru_cache``. Because ``functools.wraps`` exposes the cached
    callable as ``__wrapped__``, its ``cache_clear()`` can be invoked directly
    so that every test triggers a fresh fetch instead of reusing a previous
    test's result.
    """

    def _reset() -> None:
        # Look the attribute up on every call rather than binding it once, so
        # the reset always targets whatever is currently installed on the module.
        config_module.download_config.__wrapped__.cache_clear()  # type: ignore[attr-defined]

    _reset()
    yield
    _reset()


def _mock_response(status_code: int, json_payload=None, raise_on_json: bool = False):
    """Build a ``MagicMock`` that stands in for a ``requests.Response``.

    Args:
        status_code: Value exposed as the mock's ``status_code`` attribute.
        json_payload: Object returned by the mock's ``json()`` method when
            ``raise_on_json`` is false.
        raise_on_json: When true, ``json()`` raises ``ValueError("not json")``
            instead of returning ``json_payload``.

    Returns:
        The configured mock response.
    """
    fake = MagicMock(spec=requests.Response)
    fake.status_code = status_code

    if not raise_on_json:
        fake.json.return_value = json_payload
        return fake

    # Simulate a body that cannot be decoded as JSON.
    fake.json.side_effect = ValueError("not json")
    return fake


def test_download_config_returns_payload_on_200():
    """A 200 response carrying a dict with ``enabled`` is returned unchanged."""
    expected = {"enabled": True, "sampling_rate": 0.5}
    ok_response = _mock_response(200, expected)

    with patch.object(config_module.requests, "get", return_value=ok_response):
        result = config_module.download_config()
        assert result == expected


def test_download_config_fails_closed_on_network_exception():
    """A connection error raised by ``requests.get`` yields the disabled default."""
    connection_failure = requests.exceptions.ConnectionError("boom")

    with patch.object(config_module.requests, "get", side_effect=connection_failure):
        result = config_module.download_config()
        assert result == {"enabled": False}


def test_download_config_fails_closed_on_timeout():
    """A timeout raised by ``requests.get`` yields the disabled default."""
    timeout_failure = requests.exceptions.Timeout("slow")

    with patch.object(config_module.requests, "get", side_effect=timeout_failure):
        result = config_module.download_config()
        assert result == {"enabled": False}


@pytest.mark.parametrize("status", [403, 404, 500, 503])
def test_download_config_fails_closed_on_non_200(status: int):
    """Any non-200 status yields the disabled default, whatever the body says."""
    # The body claims telemetry is enabled; it must be ignored because the
    # status code is not 200.
    error_response = _mock_response(status, {"enabled": True})

    with patch.object(config_module.requests, "get", return_value=error_response):
        result = config_module.download_config()
        assert result == {"enabled": False}


def test_download_config_fails_closed_on_invalid_json():
    """A 200 response whose body cannot be parsed as JSON yields the disabled default."""
    unparseable = _mock_response(200, raise_on_json=True)

    with patch.object(config_module.requests, "get", return_value=unparseable):
        result = config_module.download_config()
        assert result == {"enabled": False}


@pytest.mark.parametrize(
    "payload",
    [
        [1, 2, 3],  # list
        "enabled",  # string
        42,  # int
        None,  # null
        {"foo": "bar"},  # dict missing "enabled"
        {},  # empty dict
    ],
)
def test_download_config_fails_closed_on_malformed_shape(payload):
    """Parsed JSON that is not a dict containing ``enabled`` yields the disabled default."""
    wrong_shape = _mock_response(200, payload)

    with patch.object(config_module.requests, "get", return_value=wrong_shape):
        result = config_module.download_config()
        assert result == {"enabled": False}


def test_download_config_uses_env_override(monkeypatch):
    """The URL in ``TABPFN_TELEMETRY_CONFIG_URL`` is the one passed to ``requests.get``."""
    monkeypatch.setenv("TABPFN_TELEMETRY_CONFIG_URL", "https://example.test/cfg.json")
    enabled_response = _mock_response(200, {"enabled": True})

    with patch.object(
        config_module.requests, "get", return_value=enabled_response
    ) as mock_get:
        config_module.download_config()

    # Accept the URL whether it was passed positionally or as ``url=...``.
    recorded_call = mock_get.call_args
    if recorded_call.args:
        called_url = recorded_call.args[0]
    else:
        called_url = recorded_call.kwargs.get("url")
    assert called_url == "https://example.test/cfg.json"
Loading