diff --git a/README.md b/README.md
index 0fd45bf53..7f45855bd 100644
--- a/README.md
+++ b/README.md
@@ -496,15 +496,14 @@ pre-commit run --all-files
 pytest tests/
 ```
 
-## Anonymized Telemetry
+## Telemetry
 
-This project collects fully anonymous usage telemetry with an option to opt-out of any telemetry or opt-in to extended telemetry.
+This project collects usage telemetry with an option to opt out.
 The data is used exclusively to help us provide stability to the
 relevant products and compute environments and guide future improvements.
 
-- **No personal data is collected**
+- **Personal data is collected only if the user has provided consent and accepted the terms of service**
 - **No code, model inputs, or outputs are ever sent**
-- **Data is strictly anonymous and cannot be linked to individuals**
 
 For details on telemetry, please see our
 [Telemetry Reference](https://github.com/PriorLabs/TabPFN/blob/main/TELEMETRY.md)
 and our [Privacy Policy](https://priorlabs.ai/privacy-policy/).
diff --git a/TELEMETRY.md b/TELEMETRY.md
index c37241445..52480713c 100644
--- a/TELEMETRY.md
+++ b/TELEMETRY.md
@@ -1,71 +1,61 @@
-# 📊 Telemetry
+# Telemetry
 
-This project includes lightweight, anonymous telemetry to help us improve TabPFN.
-We've designed this with two goals in mind:
+TabPFN includes lightweight, optional telemetry that helps us understand how the
+library is used and where to focus development. This page explains exactly what
+is collected, how it's handled, and how to opt out.
 
-1. ✅ Be **fully GDPR-compliant** (no personal data, no sensitive data, no surprises)
-2. ✅ Be **OSS-friendly and transparent** about what we track and why
+## What we collect
 
-If you'd rather not send telemetry, you can always opt out (see **Opting out**).
+We gather high-level usage signals - enough to guide development, never enough
+to expose your data or code.
----
+**Events**
 
-## 🔍 What we collect
+- `session` - sent when a TabPFN estimator is initialized
+- `ping` - liveness check on model initialization
+- `model_load` - sent when a model is loaded from disk or cache
+- `fit_called` / `predict_called` - sent when you call `fit` or `predict`
 
-We only gather **very high-level usage signals** — enough to guide development, never enough to identify you or your data.
+**Metadata (all events)**
 
-Here's the full list:
+- `tabpfn_version`, `python_version`, `numpy_version`, `pandas_version` - software versions
+- `gpu_type` - GPU type TabPFN is running on
+- `timestamp` - time of the event
+- `install_date` - date TabPFN was first used (year-month-day)
+- `install_id` - random, locally generated installation identifier (see "Privacy" below)
 
-### Events
-- `ping` – sent when models initialize, used to check liveness
-- `fit_called` – sent when you call `fit`
-- `predict_called` – sent when you call `predict`
-- `session` - sent whenever a user initializes a TabPFN estimator.
+**Additional metadata (fit / predict only)**
 
-### Metadata (all events)
-- `python_version` – version of Python you're running
-- `tabpfn_version` – TabPFN package version
-- `timestamp` – time of the event
-- `numpy_vesion` - local Numpy version
-- `pandas_version` - local Pandas version
-- `gpu_type` - type of GPU TabPFN is running on.
-- `install_date` - `year-month-day` when TabPFN was used for the first time
-- `install_id` - unique, random and anonymous installation ID.
+- `task` - classification or regression
+- `num_rows`, `num_columns` - dataset shape, rounded into ranges (exact values are never recorded)
+- `duration_ms` - wall-clock time of the call
 
-### Extra metadata (`fit` / `predict` only)
-- `task` – whether the call was for **classification** or **regression**
-- `num_rows` – *rounded* number of rows in your dataset
-- `num_columns` – *rounded* number of columns in your dataset
-- `duration_ms` – time it took to complete the call
+## What we never collect
 
----
+Regardless of account status, we never collect:
 
-## 🛡️ How we protect your privacy
+- Training data, features, labels, or model outputs
+- File paths, environment variables, or hostnames
+- Exact dataset dimensions
+- Code of any kind
 
-- **No inputs, no outputs, no code** ever leave your machine.
-- **No personal data** is collected.
-- Dataset shapes are **rounded into ranges** (e.g. `(953, 17)` → `(1000, 20)`) so exact dimensionalities can't be linked back to you.
-- The data is strictly anonymous — it cannot be tied to individuals, projects, or datasets.
+No inputs, outputs, or model weights ever leave your machine.
 
-This approach lets us understand dataset *patterns* (e.g. "most users run with ~1k features") while ensuring no one's data is exposed.
+## Privacy
 
----
+TabPFN operates in two modes with different privacy properties:
 
-## 🤔 Why collect telemetry?
+**Without an account (anonymous).** Telemetry is tied only to a random `install_id` generated locally on first use. This ID is not linked to any personal information and cannot be traced back to you.
 
-Open-source projects don't get much feedback unless people file issues. Telemetry helps us:
-- See which parts of TabPFN are most used (fit vs predict, classification vs regression)
-- Detect performance bottlenecks and stability issues
-- Prioritize improvements that benefit the most users
+**With an account (pseudonymous).** If you create a TabPFN account, your `user_id` is included in telemetry events.
-This information goes directly into **making TabPFN better** for the community.
+For further details, see our [privacy policy](https://priorlabs.ai/privacy-policy).
 
----
+## Opting out
 
-## 🚫 Opting out
-
-Don't want to send telemetry? No problem — just set the environment variable:
+Set one environment variable to disable all telemetry:
 
 ```bash
 export TABPFN_DISABLE_TELEMETRY=1
-```
\ No newline at end of file
+```
+
+## Why collect telemetry?
+
+Open-source projects get limited feedback unless people file issues. Telemetry helps us see which parts of TabPFN are most used, detect performance bottlenecks, and prioritize improvements that benefit the most users.
diff --git a/pyproject.toml b/pyproject.toml
index cf5fcfc49..fb3d95ba6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,8 +21,9 @@ dependencies = [
     # Once Python 3.10 is the minimum version, this can be removed.
     "eval-type-backport>=0.2.2",
     "joblib>=1.2.0",
-    "tabpfn-common-utils[telemetry-interactive]>=0.2.13",
+    "tabpfn-common-utils[telemetry-interactive]>=0.2.19",
    "filelock>=3.11.0",
+    "pyjwt>=2.12.1",
 ]
 requires-python = ">=3.9"
 authors = [
diff --git a/src/tabpfn/base.py b/src/tabpfn/base.py
index b14311941..5a4282dba 100644
--- a/src/tabpfn/base.py
+++ b/src/tabpfn/base.py
@@ -14,7 +14,6 @@ from sklearn.base import (
     check_is_fitted,
 )
 
-from tabpfn_common_utils.telemetry.interactive import capture_session, ping
 
 # --- TabPFN imports ---
 from tabpfn.constants import (
@@ -418,16 +417,6 @@ def estimator_to_device(
     return byte_size
 
 
-def initialize_telemetry() -> None:
-    """Initialize telemetry and acknowledge anonymous session.
-
-    If user opted out of telemetry using `TABPFN_DISABLE_TELEMETRY`,
-    no action is taken.
- """ - ping() - capture_session() - - def get_embeddings( model: TabPFNClassifier | TabPFNRegressor, X: XType, diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index a12f61423..616b488da 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -30,7 +30,6 @@ import torch from sklearn import config_context from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted -from tabpfn_common_utils.telemetry import track_model_call from tabpfn.base import ( ClassifierModelSpecs, @@ -39,7 +38,6 @@ estimator_to_device, get_embeddings, initialize_model_variables_helper, - initialize_telemetry, ) from tabpfn.constants import ( PROBABILITY_EPSILON_ROUND_ZERO, @@ -81,6 +79,10 @@ from tabpfn.preprocessing.ensemble import TabPFNEnsemblePreprocessor from tabpfn.preprocessing.label_encoder import TabPFNLabelEncoder from tabpfn.preprocessing.modality_detection import detect_feature_modalities +from tabpfn.telemetry import ( + init as init_telemetry, + track_model_call, +) from tabpfn.utils import ( DevicesSpecification, balance_probas_by_class_counts, @@ -482,7 +484,7 @@ class in Fine-Tuning. 
The fit_from_preprocessed() function sets this self.n_preprocessing_jobs = n_preprocessing_jobs self.eval_metric = eval_metric self.tuning_config = tuning_config - initialize_telemetry() + init_telemetry() # Only anonymously record `fit_mode` usage log_model_init_params(self, {"fit_mode": self.fit_mode}) diff --git a/src/tabpfn/model_loading.py b/src/tabpfn/model_loading.py index b418f0b57..3dbfbc58d 100644 --- a/src/tabpfn/model_loading.py +++ b/src/tabpfn/model_loading.py @@ -26,7 +26,6 @@ import joblib import torch from filelock import FileLock -from tabpfn_common_utils.telemetry import set_model_config from torch import nn from tabpfn.architectures import ARCHITECTURES @@ -36,6 +35,7 @@ from tabpfn.inference import InferenceEngine from tabpfn.inference_config import InferenceConfig from tabpfn.settings import settings +from tabpfn.telemetry import set_model_config if TYPE_CHECKING: from sklearn.base import BaseEstimator @@ -767,7 +767,7 @@ def log_model_init_params( # We conditionally import here to avoid introducing breaking changes as # this interface was introduced in tabpfn_common_utils 0.2.13 and not all # users have upgraded to this version yet. 
- from tabpfn_common_utils.telemetry import set_init_params # noqa: PLC0415 + from tabpfn.telemetry import set_init_params # noqa: PLC0415 set_init_params(logged_params) except ImportError: diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 038df7b22..53b3e897d 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -35,7 +35,6 @@ TransformerMixin, check_is_fitted, ) -from tabpfn_common_utils.telemetry import track_model_call from tabpfn.architectures.base.bar_distribution import FullSupportBarDistribution from tabpfn.base import ( @@ -45,7 +44,6 @@ estimator_to_device, get_embeddings, initialize_model_variables_helper, - initialize_telemetry, ) from tabpfn.constants import REGRESSION_CONSTANT_TARGET_BORDER_EPSILON, ModelVersion from tabpfn.errors import TabPFNValidationError, handle_oom_errors @@ -70,6 +68,10 @@ from tabpfn.preprocessing.steps import ( get_all_reshape_feature_distribution_preprocessors, ) +from tabpfn.telemetry import ( + init as init_telemetry, + track_model_call, +) from tabpfn.utils import ( DevicesSpecification, convert_batch_of_cat_ix_to_schema, @@ -466,7 +468,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this ) self.n_jobs = n_jobs self.n_preprocessing_jobs = n_preprocessing_jobs - initialize_telemetry() + init_telemetry() # Only anonymously record `fit_mode` usage log_model_init_params(self, {"fit_mode": self.fit_mode}) diff --git a/src/tabpfn/telemetry.py b/src/tabpfn/telemetry.py new file mode 100644 index 000000000..ed1efcdff --- /dev/null +++ b/src/tabpfn/telemetry.py @@ -0,0 +1,174 @@ +"""Telemetry for usage analytics. + +Telemetry is opt-out and only collects aggregate usage statistics — never +any data passed to or returned from the model (inputs, outputs, or features). 
+ +To disable, set the environment variable:: + + TABPFN_DISABLE_TELEMETRY=1 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +import jwt +from posthog import Posthog +from tabpfn_common_utils.telemetry.core import ( + decorators as base_decorators, + service as telemetry_service, +) +from tabpfn_common_utils.telemetry.core.config import download_config +from tabpfn_common_utils.telemetry.core.decorators import ( + set_init_params, + set_model_config, + track_model_call, +) +from tabpfn_common_utils.telemetry.core.events import ( + BaseTelemetryEvent, + ModelLoadEvent, + PingEvent, + SessionEvent, +) +from tabpfn_common_utils.telemetry.core.service import ProductTelemetry +from tabpfn_common_utils.telemetry.interactive import ( + capture_session, + flows as base_flows, + ping, +) + +from tabpfn.browser_auth import get_cached_token as get_cached_auth_token + +__all__ = ["init", "set_init_params", "set_model_config", "track_model_call"] + +_cached_user_id: str | None = None + + +@dataclass +class _TelemetryConfig: + project_token: str + api_host: str + + +def init() -> None: + """Initialize telemetry and acknowledge anonymous session.""" + for func in ( + _patch_client, + ping, + capture_session, + ): + func() + + +def _patch_client() -> None: + """Patch the telemetry client with the custom configuration.""" + config = _get_telemetry_config() + instance = telemetry_service.ProductTelemetry() + + if config is None: + # This will block the telemetry service from sending events + instance._posthog_client = None + return + + # Patch the telemetry client with the custom configuration + instance._posthog_client = Posthog( + project_api_key=config.project_token, + host=config.api_host, + disable_geoip=True, + enable_exception_autocapture=False, + max_queue_size=10, + flush_at=10, + ) + + +def _get_telemetry_config() -> _TelemetryConfig | None: + """Load the telemetry configuration. 
The information we fetch includes
+    the public authentication token and the API host.
+
+    We do not cache the configuration in memory because download_config()
+    is already a cached function with a TTL of 1 hour.
+
+    Returns:
+        The telemetry configuration.
+    """
+    config = download_config()
+    if config is None:
+        return None
+
+    project_token = config.get("project_token")
+    api_host = config.get("api_host")
+
+    # Silently ignore if the configuration is not complete
+    if any(v is None for v in [project_token, api_host]):
+        return None
+
+    return _TelemetryConfig(
+        project_token=cast("str", project_token), api_host=cast("str", api_host)
+    )
+
+
+def _get_user_id() -> str | None:
+    global _cached_user_id  # noqa: PLW0603
+    if _cached_user_id is not None:
+        return _cached_user_id
+
+    token = get_cached_auth_token()
+    if token is None:
+        return None
+
+    try:
+        payload = jwt.decode(token, options={"verify_signature": False})
+    except Exception:  # noqa: BLE001
+        return None
+
+    user = payload.get("user")
+    if user is not None:
+        _cached_user_id = user
+
+    return user
+
+
+def _capture_event_with_user_id(
+    event: BaseTelemetryEvent, properties: dict[str, Any] | None = None
+) -> None:
+    """Capture an event with the user ID.
+
+    Args:
+        event: The event to capture.
+        properties: The properties to capture with the event.
+    """
+    config = _get_telemetry_config()
+    if config is None:
+        return
+
+    user_id = _get_user_id()
+
+    # We pass session and ping events through anonymously.
+    # These events still contain anonymous and valuable runtime metadata.
+ passthrough_events = ( + SessionEvent, + PingEvent, + ModelLoadEvent, + ) + if user_id is None and not isinstance(event, passthrough_events): + return + + kwargs: dict[str, Any] = { + "properties": properties, + } + + # Unless the user is authenticated, we capture the event anonymously + if user_id is not None: + kwargs["distinct_id"] = user_id + + service = ProductTelemetry(api_key=config.project_token) + service.capture(event, **kwargs) + + +# Replace the capture_event reference that _send_model_called_event holds. +# The capture_event_with_user_id function is a wrapper around the +# base_decorators.capture_event function so that we can capture the event with +# the user ID if user is authenticated. Otherwise, the event is captured anonymously. +base_decorators.capture_event = _capture_event_with_user_id +base_flows.capture_event = _capture_event_with_user_id diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py new file mode 100644 index 000000000..dd8932364 --- /dev/null +++ b/tests/test_telemetry.py @@ -0,0 +1,233 @@ +"""Unit tests for `tabpfn.telemetry`.""" + +from __future__ import annotations + +from collections.abc import Callable +from unittest.mock import MagicMock + +import jwt +import pytest +from tabpfn_common_utils.telemetry.core.events import DatasetEvent, PingEvent + +import tabpfn.telemetry as telemetry_module + + +def _product_telemetry_factory(mock_service: MagicMock) -> Callable[..., MagicMock]: + def _factory(*_args: object, **_kwargs: object) -> MagicMock: + return mock_service + + return _factory + + +@pytest.fixture +def mock_config() -> dict[str, str]: + return {"project_token": "test-project-token", "api_host": "https://example.com"} + + +class TestGetTelemetryConfig: + def test_returns_none_when_download_config_is_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr(telemetry_module, "download_config", lambda: None) + assert telemetry_module._get_telemetry_config() is None + + def 
test_returns_none_when_project_token_missing( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr( + "tabpfn.telemetry.download_config", + lambda: {"project_token": None, "api_host": "https://h.example"}, + ) + assert telemetry_module._get_telemetry_config() is None + + def test_returns_none_when_api_host_missing( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr( + "tabpfn.telemetry.download_config", + lambda: {"project_token": "tok", "api_host": None}, + ) + assert telemetry_module._get_telemetry_config() is None + + def test_returns_config_when_complete( + self, monkeypatch: pytest.MonkeyPatch, mock_config: dict[str, str] + ) -> None: + monkeypatch.setattr("tabpfn.telemetry.download_config", lambda: mock_config) + cfg = telemetry_module._get_telemetry_config() + assert cfg is not None + assert cfg.project_token == mock_config["project_token"] + assert cfg.api_host == mock_config["api_host"] + + +class TestGetUserId: + def test_returns_none_when_no_token(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(telemetry_module, "get_cached_auth_token", lambda: None) + assert telemetry_module._get_user_id() is None + + def test_returns_none_when_token_invalid( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr( + telemetry_module, "get_cached_auth_token", lambda: "not-a-jwt" + ) + assert telemetry_module._get_user_id() is None + + def test_returns_user_claim_when_token_valid( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + payload = {"user": "user-123", "sub": "ignored"} + token = jwt.encode( + payload, + "x" * 32, + algorithm="HS256", + ) + monkeypatch.setattr(telemetry_module, "get_cached_auth_token", lambda: token) + assert telemetry_module._get_user_id() == "user-123" + + +class TestPatchClient: + def test_clears_posthog_when_no_config( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr(telemetry_module, "_get_telemetry_config", lambda: None) + 
instance = MagicMock() + monkeypatch.setattr( + telemetry_module.telemetry_service, + "ProductTelemetry", + lambda: instance, + ) + telemetry_module._patch_client() + assert instance._posthog_client is None + + def test_sets_posthog_when_config_present( + self, + monkeypatch: pytest.MonkeyPatch, + mock_config: dict[str, str], + ) -> None: + monkeypatch.setattr( + telemetry_module, + "_get_telemetry_config", + lambda: telemetry_module._TelemetryConfig( + project_token=mock_config["project_token"], + api_host=mock_config["api_host"], + ), + ) + instance = MagicMock() + instance.HOST = "https://eu.i.posthog.com" + monkeypatch.setattr( + telemetry_module.telemetry_service, + "ProductTelemetry", + lambda: instance, + ) + fake_posthog = MagicMock() + monkeypatch.setattr(telemetry_module, "Posthog", fake_posthog) + telemetry_module._patch_client() + fake_posthog.assert_called_once() + call_kw = fake_posthog.call_args.kwargs + assert call_kw["project_api_key"] == mock_config["project_token"] + assert instance._posthog_client == fake_posthog.return_value + + +class TestCaptureEventWithUserId: + def test_returns_early_when_no_config( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr(telemetry_module, "_get_telemetry_config", lambda: None) + mock_pt = MagicMock() + monkeypatch.setattr(telemetry_module, "ProductTelemetry", mock_pt) + telemetry_module._capture_event_with_user_id(PingEvent()) + mock_pt.assert_not_called() + + def test_skips_non_passthrough_when_no_user( + self, + monkeypatch: pytest.MonkeyPatch, + mock_config: dict[str, str], + ) -> None: + monkeypatch.setattr( + telemetry_module, + "_get_telemetry_config", + lambda: telemetry_module._TelemetryConfig( + project_token=mock_config["project_token"], + api_host=mock_config["api_host"], + ), + ) + monkeypatch.setattr(telemetry_module, "_get_user_id", lambda: None) + mock_pt = MagicMock() + monkeypatch.setattr(telemetry_module, "ProductTelemetry", mock_pt) + 
telemetry_module._capture_event_with_user_id( + DatasetEvent(task="classification", role="train") + ) + mock_pt.assert_not_called() + + def test_allows_ping_without_user( + self, + monkeypatch: pytest.MonkeyPatch, + mock_config: dict[str, str], + ) -> None: + monkeypatch.setattr( + telemetry_module, + "_get_telemetry_config", + lambda: telemetry_module._TelemetryConfig( + project_token=mock_config["project_token"], + api_host=mock_config["api_host"], + ), + ) + monkeypatch.setattr(telemetry_module, "_get_user_id", lambda: None) + mock_service = MagicMock() + monkeypatch.setattr( + telemetry_module, + "ProductTelemetry", + _product_telemetry_factory(mock_service), + ) + telemetry_module._capture_event_with_user_id(PingEvent()) + mock_service.capture.assert_called_once() + call_kw = mock_service.capture.call_args.kwargs + assert "distinct_id" not in call_kw + + def test_passes_distinct_id_when_user_present( + self, + monkeypatch: pytest.MonkeyPatch, + mock_config: dict[str, str], + ) -> None: + monkeypatch.setattr( + telemetry_module, + "_get_telemetry_config", + lambda: telemetry_module._TelemetryConfig( + project_token=mock_config["project_token"], + api_host=mock_config["api_host"], + ), + ) + monkeypatch.setattr(telemetry_module, "_get_user_id", lambda: "uid-42") + mock_service = MagicMock() + monkeypatch.setattr( + telemetry_module, + "ProductTelemetry", + _product_telemetry_factory(mock_service), + ) + telemetry_module._capture_event_with_user_id(PingEvent()) + mock_service.capture.assert_called_once() + assert mock_service.capture.call_args.kwargs["distinct_id"] == "uid-42" + + +class TestInit: + def test_calls_patch_ping_and_capture_session( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + calls: list[str] = [] + + def mark(name: str) -> None: + calls.append(name) + + monkeypatch.setattr( + telemetry_module, + "_patch_client", + lambda: mark("_patch_client"), + ) + monkeypatch.setattr(telemetry_module, "ping", lambda: mark("ping")) + 
monkeypatch.setattr( + telemetry_module, + "capture_session", + lambda: mark("capture_session"), + ) + telemetry_module.init() + assert calls == ["_patch_client", "ping", "capture_session"]
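
Reviewer note: TELEMETRY.md says `num_rows`/`num_columns` are "rounded into ranges (exact values are never recorded)", but the rounding helper itself lives in tabpfn-common-utils and is not part of this diff. The sketch below is a hypothetical implementation of such a scheme, consistent with the `(953, 17)` → `(1000, 20)` example in the old TELEMETRY.md text: round each dimension up to one significant figure. The function names are illustrative, not the library's real API.

```python
import math


def round_up_to_one_sig_fig(n: int) -> int:
    """Round a positive count up to one significant figure.

    Hypothetical sketch of TabPFN's "rounded into ranges" telemetry
    behavior, e.g. 953 -> 1000, 17 -> 20.
    """
    if n <= 0:
        return 0
    # Magnitude of the leading digit: 953 -> 100, 17 -> 10.
    magnitude = 10 ** (len(str(n)) - 1)
    return math.ceil(n / magnitude) * magnitude


def round_shape(num_rows: int, num_columns: int) -> tuple[int, int]:
    """Coarsen a dataset shape so exact dimensions are never reported."""
    return round_up_to_one_sig_fig(num_rows), round_up_to_one_sig_fig(num_columns)


print(round_shape(953, 17))  # (1000, 20), matching the old TELEMETRY.md example
```

Rounding up, rather than to the nearest value, means the reported shape never understates the dataset size while still hiding the exact dimensions.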