From 0eb249a36828afa4456b28e2e1657d66e0541f9c Mon Sep 17 00:00:00 2001 From: psinger-prior Date: Tue, 7 Apr 2026 15:02:21 +0200 Subject: [PATCH 1/7] W&B Logging support for Finetuning (#815) --- .gitignore | 3 + changelog/815.added.md | 1 + examples/finetune_classifier.py | 3 + examples/finetune_regressor.py | 3 + pyproject.toml | 3 + src/tabpfn/finetuning/__init__.py | 4 + src/tabpfn/finetuning/finetuned_base.py | 51 ++++++ src/tabpfn/finetuning/finetuned_classifier.py | 5 + src/tabpfn/finetuning/finetuned_regressor.py | 3 + src/tabpfn/finetuning/logging.py | 95 +++++++++++ tests/test_finetuning_logging.py | 158 ++++++++++++++++++ 11 files changed, 329 insertions(+) create mode 100644 changelog/815.added.md create mode 100644 src/tabpfn/finetuning/logging.py create mode 100644 tests/test_finetuning_logging.py diff --git a/.gitignore b/.gitignore index e5742f754..cc6ab0406 100644 --- a/.gitignore +++ b/.gitignore @@ -159,6 +159,9 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +# WandB +wandb + .DS_Store ./src/.DS_Store diff --git a/changelog/815.added.md b/changelog/815.added.md new file mode 100644 index 000000000..7e2438cdd --- /dev/null +++ b/changelog/815.added.md @@ -0,0 +1 @@ +Add modular experiment logging for finetuning with `experiment_logger` parameter, including `WandbLogger` for W&B tracking and a `FinetuningLogger` protocol for custom integrations. diff --git a/examples/finetune_classifier.py b/examples/finetune_classifier.py index a1a338eb4..6cc312385 100644 --- a/examples/finetune_classifier.py +++ b/examples/finetune_classifier.py @@ -115,6 +115,9 @@ def main() -> None: print("--- 2. Initializing and Fitting Model ---\n") # Instantiate the wrapper with your desired hyperparameters + # To enable WandB logging, pass an experiment_logger: + # . from tabpfn.finetuning.logging import WandbLogger + # experiment_logger=WandbLogger(project="my-project", run_name="my-run", entity="my-entity") finetuned_clf = FinetunedTabPFNClassifier( device="cuda", epochs=NUM_EPOCHS, diff --git a/examples/finetune_regressor.py b/examples/finetune_regressor.py index ae203da11..bc6bafbf8 100644 --- a/examples/finetune_regressor.py +++ b/examples/finetune_regressor.py @@ -101,6 +101,9 @@ def main() -> None: print("--- 2. Initializing and Fitting Model ---\n") # Instantiate the wrapper with your desired hyperparameters + # To enable WandB logging, pass an experiment_logger: + # . from tabpfn.finetuning.logging import WandbLogger + # experiment_logger=WandbLogger(project="my-project", run_name="my-run", entity="my-entity") finetuned_reg = FinetunedTabPFNRegressor( device="cuda", epochs=NUM_EPOCHS, diff --git a/pyproject.toml b/pyproject.toml index 30e84a5c8..80a445e68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,9 @@ classifiers = [ ] license = { file = "LICENSE" } +[project.optional-dependencies] +wandb = ["wandb>=0.25.1"] + [project.urls] documentation = "https://priorlabs.ai/docs" source = "https://github.com/priorlabs/tabpfn" diff --git a/src/tabpfn/finetuning/__init__.py b/src/tabpfn/finetuning/__init__.py index 4ebcbfed2..e2005b8b0 100644 --- a/src/tabpfn/finetuning/__init__.py +++ b/src/tabpfn/finetuning/__init__.py @@ -4,6 +4,7 @@ from tabpfn.finetuning.finetuned_base import EvalResult, FinetunedTabPFNBase from tabpfn.finetuning.finetuned_classifier import FinetunedTabPFNClassifier from tabpfn.finetuning.finetuned_regressor import FinetunedTabPFNRegressor +from tabpfn.finetuning.logging import FinetuningLogger, NullLogger, WandbLogger __all__ = [ "ClassifierBatch", @@ -11,5 +12,8 @@ "FinetunedTabPFNBase", "FinetunedTabPFNClassifier", "FinetunedTabPFNRegressor", + "FinetuningLogger", + "NullLogger", "RegressorBatch", + "WandbLogger", ] diff --git a/src/tabpfn/finetuning/finetuned_base.py b/src/tabpfn/finetuning/finetuned_base.py index 74359f54f..ab042427c 100644 --- a/src/tabpfn/finetuning/finetuned_base.py +++ b/src/tabpfn/finetuning/finetuned_base.py @@ -38,6 +38,7 @@ get_preprocessed_dataset_chunks, meta_dataset_collator, ) +from tabpfn.finetuning.logging import FinetuningLogger, NullLogger from tabpfn.finetuning.train_util import ( get_and_init_optimizer, get_checkpoint_path_and_epoch_from_output_dir, @@ -238,6 +239,9 @@ class FinetunedTabPFNBase(BaseEstimator, ABC): data batches. This is helpful in most cases because, e.g., the column order will stay the same across batches. If False, the preprocessing will use a different random seed for each batch. + experiment_logger: An optional logger implementing the ``FinetuningLogger`` + protocol (e.g., ``WandbLogger``) for experiment tracking. If None, + a no-op ``NullLogger`` is used. Defaults to None. """ def __init__( # noqa: PLR0913 @@ -265,8 +269,10 @@ def __init__( # noqa: PLR0913 use_activation_checkpointing: bool = True, save_checkpoint_interval: int | None = 10, use_fixed_preprocessing_seed: bool = True, + experiment_logger: FinetuningLogger | None = None, ): super().__init__() + self.experiment_logger = experiment_logger self.device = device self.epochs = epochs self.time_limit = time_limit @@ -549,6 +555,22 @@ def _fit( # noqa: C901,PLR0912 if using_ddp: self.device = device_str + _logger = self.experiment_logger or NullLogger() + global_step = 0 + + if is_main_process: + config = { + k: v for k, v in self.get_params().items() if k != "experiment_logger" + } + try: + _logger.setup(config) + except (OSError, ModuleNotFoundError): + logger.warning( + "Experiment logger setup failed, falling back to NullLogger.", + exc_info=True, + ) + _logger = NullLogger() + # Store the original training size for checkpoint naming train_size = X.shape[0] start_time = time.monotonic() @@ -870,6 +892,23 @@ def _ddp_broadcast_primary_metric(metric: float) -> float: epoch_loss_sum += loss_scalar epoch_batches += 1 + global_step += 1 + + if is_main_process: + current_lr = ( + scheduler.get_last_lr()[0] + if scheduler is not None + else self.learning_rate + ) + _logger.log_step( + { + "train/loss": loss_scalar, + "train/lr": current_lr, + "train/epoch": epoch, + "train/global_step": global_step, + }, + step=global_step, + ) progress_bar.set_postfix( loss=f"{loss_scalar:.4f}", @@ -900,6 +939,17 @@ def _ddp_broadcast_primary_metric(metric: float) -> float: y_val, # pyright: ignore[reportArgumentType] ) self._log_epoch_evaluation(epoch, eval_result, mean_train_loss) + + epoch_log_metrics: dict[str, float] = { + "train/epoch": epoch, + f"val/{self._metric_name}": eval_result.primary, + } + if mean_train_loss is not None: + epoch_log_metrics["train/mean_loss"] = mean_train_loss + for k, v in eval_result.secondary.items(): + epoch_log_metrics[f"val/{k}"] = v + _logger.log_epoch(epoch_log_metrics, step=global_step) + primary_metric = eval_result.primary else: primary_metric = self._get_initial_best_metric() @@ -992,6 +1042,7 @@ def _ddp_broadcast_primary_metric(metric: float) -> float: dist.destroy_process_group() if is_main_process: + _logger.finish() logger.info("--- ✅ Fine-tuning Finished ---") if is_main_process: diff --git a/src/tabpfn/finetuning/finetuned_classifier.py b/src/tabpfn/finetuning/finetuned_classifier.py index 597a06f48..b58ccc27b 100644 --- a/src/tabpfn/finetuning/finetuned_classifier.py +++ b/src/tabpfn/finetuning/finetuned_classifier.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from tabpfn.constants import XType, YType from tabpfn.finetuning.data_util import ClassifierBatch + from tabpfn.finetuning.logging import FinetuningLogger def _compute_classification_loss( @@ -151,6 +152,7 @@ def __init__( # noqa: PLR0913 use_activation_checkpointing: bool = True, save_checkpoint_interval: int | None = 10, use_fixed_preprocessing_seed: bool = True, + experiment_logger: FinetuningLogger | None = None, extra_classifier_kwargs: dict[str, Any] | None = None, eval_metric: Literal["roc_auc", "log_loss"] | None = None, ): @@ -177,6 +179,7 @@ def __init__( # noqa: PLR0913 use_activation_checkpointing=use_activation_checkpointing, save_checkpoint_interval=save_checkpoint_interval, use_fixed_preprocessing_seed=use_fixed_preprocessing_seed, + experiment_logger=experiment_logger, ) self.extra_classifier_kwargs = extra_classifier_kwargs self.eval_metric = eval_metric @@ -197,6 +200,8 @@ def _model_type(self) -> Literal["classifier", "regressor"]: @override def _metric_name(self) -> str: """Return the name of the primary metric.""" + if self.eval_metric == "log_loss": + return "log_loss" return "ROC AUC" @override diff --git a/src/tabpfn/finetuning/finetuned_regressor.py b/src/tabpfn/finetuning/finetuned_regressor.py index fa5970bcf..34acc40d8 100644 --- a/src/tabpfn/finetuning/finetuned_regressor.py +++ b/src/tabpfn/finetuning/finetuned_regressor.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from tabpfn.constants import XType, YType from tabpfn.finetuning.data_util import RegressorBatch + from tabpfn.finetuning.logging import FinetuningLogger from tabpfn.regressor import RegressionResultType @@ -333,6 +334,7 @@ def __init__( # noqa: PLR0913 use_activation_checkpointing: bool = True, save_checkpoint_interval: int | None = 10, use_fixed_preprocessing_seed: bool = True, + experiment_logger: FinetuningLogger | None = None, extra_regressor_kwargs: dict[str, Any] | None = None, ce_loss_weight: float = 0.0, crps_loss_weight: float = 1.0, @@ -366,6 +368,7 @@ def __init__( # noqa: PLR0913 use_activation_checkpointing=use_activation_checkpointing, save_checkpoint_interval=save_checkpoint_interval, use_fixed_preprocessing_seed=use_fixed_preprocessing_seed, + experiment_logger=experiment_logger, ) self.extra_regressor_kwargs = extra_regressor_kwargs self.eval_metric = eval_metric diff --git a/src/tabpfn/finetuning/logging.py b/src/tabpfn/finetuning/logging.py new file mode 100644 index 000000000..6085fc707 --- /dev/null +++ b/src/tabpfn/finetuning/logging.py @@ -0,0 +1,95 @@ +"""Protocol-based experiment logging for finetuning.""" + +from __future__ import annotations + +from typing import Any, Protocol + + +class FinetuningLogger(Protocol): + """Protocol for finetuning experiment loggers.""" + + def setup(self, config: dict[str, Any]) -> None: + """Initialize the logger with run configuration.""" + ... + + def log_step(self, metrics: dict[str, float], step: int) -> None: + """Log per-step metrics (e.g., batch loss, learning rate).""" + ... + + def log_epoch(self, metrics: dict[str, float], step: int) -> None: + """Log per-epoch metrics (e.g., val metrics, mean loss).""" + ... + + def finish(self) -> None: + """Finalize the logger (e.g., close wandb run).""" + ... + + +class NullLogger: + """No-op logger used when no experiment tracking is configured.""" + + def setup(self, config: dict[str, Any]) -> None: + """No-op.""" + + def log_step(self, metrics: dict[str, float], step: int) -> None: + """No-op.""" + + def log_epoch(self, metrics: dict[str, float], step: int) -> None: + """No-op.""" + + def finish(self) -> None: + """No-op.""" + + +class WandbLogger: + """WandB experiment logger.""" + + def __init__( + self, + project: str | None = None, + run_name: str | None = None, + entity: str | None = None, + **wandb_kwargs: Any, + ): + self.project = project + self.run_name = run_name + self.entity = entity + self.wandb_kwargs = wandb_kwargs + self._run = None + + def setup(self, config: dict[str, Any]) -> None: + """Initialize a new WandB run with the given config.""" + try: + import wandb # noqa: PLC0415 + except ModuleNotFoundError: + raise ModuleNotFoundError( + "WandbLogger requires the 'wandb' package. " + "Install it with: uv sync --extra wandb" + ) from None + + init_kwargs = dict(self.wandb_kwargs) + if self.project: + init_kwargs.setdefault("project", self.project) + if self.run_name: + init_kwargs.setdefault("name", self.run_name) + if self.entity: + init_kwargs.setdefault("entity", self.entity) + init_kwargs.setdefault("config", config) + self._run = wandb.init(**init_kwargs) + wandb.define_metric("val/*", step_metric="train/epoch") + + def log_step(self, metrics: dict[str, float], step: int) -> None: + """Log metrics for a single training step.""" + if self._run: + self._run.log(metrics, step=step) + + def log_epoch(self, metrics: dict[str, float], step: int) -> None: + """Log metrics for a completed epoch.""" + if self._run: + self._run.log(metrics, step=step) + + def finish(self) -> None: + """Finish the WandB run.""" + if self._run: + self._run.finish() + self._run = None diff --git a/tests/test_finetuning_logging.py b/tests/test_finetuning_logging.py new file mode 100644 index 000000000..1426af2fa --- /dev/null +++ b/tests/test_finetuning_logging.py @@ -0,0 +1,158 @@ +"""Tests for the finetuning experiment logging module.""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock + +import pytest + +from tabpfn.finetuning.finetuned_classifier import FinetunedTabPFNClassifier +from tabpfn.finetuning.logging import FinetuningLogger, NullLogger, WandbLogger + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_WANDB_SENTINEL = "_test_wandb_not_installed" + + +class TestNullLogger: + def test_implements_protocol(self): + logger: FinetuningLogger = NullLogger() + assert isinstance(logger, NullLogger) + + def test_all_methods_are_noop(self): + logger = NullLogger() + logger.setup({"lr": 0.01, "epochs": 10}) + logger.log_step({"train/loss": 0.5}, step=1) + logger.log_epoch({"val/accuracy": 0.9}, step=100) + logger.finish() + + +@pytest.fixture +def mock_wandb(monkeypatch): + """Inject a mock wandb module so WandbLogger.setup() picks it up.""" + mock = MagicMock() + mock_run = MagicMock() + mock.init.return_value = mock_run + monkeypatch.setitem(sys.modules, "wandb", mock) + return mock + # monkeypatch automatically restores sys.modules on teardown + + +class TestWandbLogger: + def test_init_stores_params(self): + logger = WandbLogger(project="test-proj", run_name="run-1", entity="team") + assert logger.project == "test-proj" + assert logger.run_name == "run-1" + assert logger.entity == "team" + assert logger.wandb_kwargs == {} + assert logger._run is None + + def test_setup_calls_wandb_init(self, mock_wandb): + mock_run = mock_wandb.init.return_value + + logger = WandbLogger(project="my-proj", run_name="my-run") + config = {"lr": 0.01, "epochs": 5} + logger.setup(config) + + mock_wandb.init.assert_called_once_with( + project="my-proj", name="my-run", config=config + ) + mock_wandb.define_metric.assert_called_once_with( + "val/*", step_metric="train/epoch" + ) + assert logger._run is mock_run + + def test_setup_passes_entity(self, mock_wandb): + logger = WandbLogger(project="p", entity="team") + logger.setup({}) + + call_kwargs = mock_wandb.init.call_args[1] + assert call_kwargs["entity"] == "team" + + def test_setup_does_not_override_explicit_kwargs(self, mock_wandb): + logger = WandbLogger( + project="default-proj", run_name="default-run", config={"custom": True} + ) + logger.setup({"lr": 0.01}) + + call_kwargs = mock_wandb.init.call_args[1] + assert call_kwargs["config"] == {"custom": True} + assert call_kwargs["project"] == "default-proj" + assert call_kwargs["name"] == "default-run" + + def test_log_step_delegates_to_run(self, mock_wandb): + mock_run = mock_wandb.init.return_value + + logger = WandbLogger() + logger.setup({}) + logger.log_step({"train/loss": 0.5}, step=42) + + mock_run.log.assert_called_once_with({"train/loss": 0.5}, step=42) + + def test_log_epoch_delegates_to_run(self, mock_wandb): + mock_run = mock_wandb.init.return_value + + logger = WandbLogger() + logger.setup({}) + logger.log_epoch({"val/accuracy": 0.95}, step=100) + + mock_run.log.assert_called_once_with({"val/accuracy": 0.95}, step=100) + + def test_log_step_noop_before_setup(self): + logger = WandbLogger() + logger.log_step({"train/loss": 0.5}, step=1) + + def test_log_epoch_noop_before_setup(self): + logger = WandbLogger() + logger.log_epoch({"val/acc": 0.9}, step=1) + + def test_finish_closes_run(self, mock_wandb): + mock_run = mock_wandb.init.return_value + + logger = WandbLogger() + logger.setup({}) + logger.finish() + + mock_run.finish.assert_called_once() + assert logger._run is None + + def test_finish_noop_before_setup(self): + logger = WandbLogger() + logger.finish() + + def test_double_finish_is_safe(self, mock_wandb): + mock_run = mock_wandb.init.return_value + + logger = WandbLogger() + logger.setup({}) + logger.finish() + logger.finish() + + mock_run.finish.assert_called_once() + + def test_setup_raises_readable_error_when_wandb_missing(self, monkeypatch): + """WandbLogger.setup() should raise when wandb is missing.""" + monkeypatch.setitem(sys.modules, "wandb", None) + logger = WandbLogger(project="p") + with pytest.raises(ModuleNotFoundError, match="wandb"): + logger.setup({}) + + +class TestClassifierMetricName: + """Verify _metric_name reflects the chosen eval_metric.""" + + def test_default_metric_is_roc_auc(self): + clf = FinetunedTabPFNClassifier() + # eval_metric defaults to None; _metric_name should return "ROC AUC" + assert clf._metric_name == "ROC AUC" + + def test_roc_auc_metric_name(self): + clf = FinetunedTabPFNClassifier(eval_metric="roc_auc") + assert clf._metric_name == "ROC AUC" + + def test_log_loss_metric_name(self): + clf = FinetunedTabPFNClassifier(eval_metric="log_loss") + assert clf._metric_name == "log_loss" From b8d3227fefc24e58f0263f67fd657664da7db7f0 Mon Sep 17 00:00:00 2001 From: ggprior Date: Wed, 8 Apr 2026 09:19:41 +0200 Subject: [PATCH 2/7] Georg/hf ungating flow improvements (#862) --- changelog/862.added.md | 1 + src/tabpfn/browser_auth.py | 217 +++++++++++++++++++++++--- src/tabpfn/errors.py | 22 ++- tests/test_browser_auth.py | 303 ++++++++++++++++++++++++++++++++++++- 4 files changed, 519 insertions(+), 24 deletions(-) create mode 100644 changelog/862.added.md diff --git a/changelog/862.added.md b/changelog/862.added.md new file mode 100644 index 000000000..8aea41f46 --- /dev/null +++ b/changelog/862.added.md @@ -0,0 +1 @@ +Add three-tier authentication flow: browser-based login for graphical environments, headless interactive login with clipboard copy for SSH/cluster sessions, and clear step-by-step instructions for fully non-interactive environments. diff --git a/src/tabpfn/browser_auth.py b/src/tabpfn/browser_auth.py index b5459e0a2..1b649efcc 100644 --- a/src/tabpfn/browser_auth.py +++ b/src/tabpfn/browser_auth.py @@ -35,6 +35,27 @@ # Short-circuits repeated calls within the same Python process. _accepted_repos: set[str] = set() + +# --------------------------------------------------------------------------- +# Environment detection +# --------------------------------------------------------------------------- + + +def _has_display() -> bool: + """Heuristic: is a graphical display likely available for opening a browser? + + Returns ``True`` when it is reasonable to call :func:`webbrowser.open`. + """ + if sys.platform == "win32": + return True + if sys.platform == "darwin": + # macOS has a display unless we are in a pure SSH session + # without X forwarding. + return not (os.environ.get("SSH_CONNECTION") and not os.environ.get("DISPLAY")) + # Linux / other Unix: require X11 or Wayland. + return bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")) + + # --------------------------------------------------------------------------- # Token cache helpers # --------------------------------------------------------------------------- @@ -168,6 +189,24 @@ def check_license_accepted(token: str, api_url: str, version: str) -> bool | Non return None +# --------------------------------------------------------------------------- +# Terminal helpers (headless-interactive flow) +# --------------------------------------------------------------------------- + + +def _copy_osc52(text: str) -> None: + """Copy *text* to the system clipboard via the OSC 52 terminal escape. + + Works over SSH when the terminal emulator supports it (iTerm2, kitty, + Windows Terminal, most modern terminals). + """ + import base64 # noqa: PLC0415 + + encoded = base64.b64encode(text.encode()).decode() + sys.stdout.write(f"\033]52;c;{encoded}\a") + sys.stdout.flush() + + # --------------------------------------------------------------------------- # Browser login flow # --------------------------------------------------------------------------- @@ -242,10 +281,10 @@ def do_GET(self) -> None: f"{page_style}
" "" "
" - "

No token received

" - "

Please paste your token in the terminal, or visit " + "

No API key received

" + "

Please paste your API key in the terminal, or visit " f'{gui_url}/account ' - "to copy your Access Token.

" + "to copy your API Key.

" "
" ) self.wfile.write(html.encode()) @@ -277,7 +316,7 @@ def _poll_for_token( auth_event: threading.Event, received_token: list[str | None] ) -> str | None: """Read token from stdin or wait for browser callback, whichever comes first.""" - sys.stdout.write("Token (or press Enter to keep waiting): ") + sys.stdout.write("API key (or press Enter to keep waiting): ") sys.stdout.flush() while not auth_event.is_set(): ready, _, _ = select.select([sys.stdin], [], [], 0.5) @@ -289,22 +328,156 @@ def _poll_for_token( token = line.strip() if token: return token - sys.stdout.write("Token (or press Enter to keep waiting): ") + sys.stdout.write("API key (or press Enter to keep waiting): ") sys.stdout.flush() return received_token[0] +def _headless_interactive_login( + gui_url: str, hf_repo_id: str | None = None +) -> str | None: + """Token acquisition for headless but interactive environments (e.g. SSH). + + Shows the login URL, offers single-keypress clipboard copy via OSC 52, + and waits for the user to paste a token. + + Returns the JWT on success, or ``None`` on abort / EOF. + """ + login_url = f"{gui_url}/login" + if hf_repo_id: + login_url += f"?hf_repo_id={urllib.parse.quote(hf_repo_id)}" + + print( # noqa: T201 + "\nTabPFN requires a one-time license acceptance to download" + " model weights for local inference.\n" + "\nNo display detected. Open this URL in a browser on another device:\n" + f"\n {login_url}\n" + f"\nAfter logging in, accept the license on the Licenses tab,\n" + f"then copy your API Key from\n" + f" {gui_url}/account\n" + ) + + try: + import termios # noqa: PLC0415 + except ImportError: + termios = None # type: ignore[assignment] + + if termios is not None: + return _headless_cbreak_loop(login_url) + + # Fallback when termios is unavailable (shouldn't happen on Unix, + # but be safe). + return _headless_readline_loop(login_url) + + +def _read_token_cbreak(first_char: str) -> str | None: + """Read token characters in cbreak mode, echoing manually. + + *first_char* is the character that was already read (and not ``c``). + Returns the completed token string, or ``None`` on EOF / Ctrl+C. + """ + chars = [first_char] + sys.stdout.write(first_char) + sys.stdout.flush() + while True: + ch = sys.stdin.read(1) + if not ch or ch == "\x03": + sys.stdout.write("\n") + return None + if ch in ("\r", "\n"): + sys.stdout.write("\n") + sys.stdout.flush() + return "".join(chars).strip() or None + if ch in ("\x7f", "\x08"): # Backspace / Delete + if chars: + chars.pop() + sys.stdout.write("\b \b") + sys.stdout.flush() + continue + chars.append(ch) + sys.stdout.write(ch) + sys.stdout.flush() + + +def _headless_cbreak_loop(login_url: str) -> str | None: + """Headless input loop using cbreak mode (single-keypress, no Enter).""" + import termios # noqa: PLC0415 + import tty # noqa: PLC0415 + + fd = sys.stdin.fileno() + old = termios.tcgetattr(fd) + try: + tty.setcbreak(fd) + while True: + sys.stdout.write( + " [c] Copy URL to clipboard Paste your API key to continue\n\n> " + ) + sys.stdout.flush() + + ch = sys.stdin.read(1) + if not ch or ch == "\x03": # EOF / Ctrl+C + sys.stdout.write("\n") + return None + # Safe to intercept 'c': JWTs always start with 'ey' (base64 of '{') + if ch in ("c", "C"): + _copy_osc52(login_url) + sys.stdout.write("\r> \u2713 Copied to clipboard\n\n") + sys.stdout.flush() + continue + + token = _read_token_cbreak(ch) + if token: + return token + except KeyboardInterrupt: + sys.stdout.write("\n") + return None + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old) + + +def _headless_readline_loop(login_url: str) -> str | None: + """Headless input loop using readline (Enter required, termios unavailable).""" + try: + while True: + sys.stdout.write( + " Type [c]+Enter to copy URL, or paste your API key:\n\n> " + ) + sys.stdout.flush() + line = sys.stdin.readline() + if not line: + return None + text = line.strip() + if text.lower() == "c": + _copy_osc52(login_url) + sys.stdout.write("\u2713 Copied to clipboard\n\n") + sys.stdout.flush() + continue + if text: + return text + except KeyboardInterrupt: + sys.stdout.write("\n") + return None + + def try_browser_login(gui_url: str, hf_repo_id: str | None = None) -> str | None: """Obtain a token via browser callback and/or manual paste concurrently. - Both the local callback server and the paste prompt run at the same time - so that neither blocks the other. + Chooses the right strategy based on the environment: + + * **Non-interactive** (no TTY): returns ``None`` immediately. + * **Headless interactive** (TTY but no display): shows the URL and waits + for the user to paste a token. + * **Graphical** (TTY + display): opens the browser and runs a local + callback server alongside a paste prompt. Returns the JWT on success, or ``None`` on failure / non-TTY environments. """ if not sys.stdin.isatty(): return None + if not _has_display(): + return _headless_interactive_login(gui_url, hf_repo_id=hf_repo_id) + auth_event = threading.Event() received_token: list[str | None] = [None] @@ -330,7 +503,8 @@ def try_browser_login(gui_url: str, hf_repo_id: str | None = None) -> str | None # --- print unified instructions --- print( # noqa: T201 - "\nTabPFN requires a one-time license acceptance." + "\nTabPFN requires a one-time license acceptance to download" + " model weights for local inference." "\nOpening your browser to complete login/registration…\n" f"\n {login_url}\n" "\nWaiting for login to complete…\n" @@ -339,8 +513,8 @@ def try_browser_login(gui_url: str, hf_repo_id: str | None = None) -> str | None " (log in or register if needed)\n" " 2. Accept the license at" f" {gui_url}/account/licenses\n" - " 3. Copy your Access Token\n" - " 4. Paste the token below\n" + " 3. Copy your API Key\n" + " 4. Paste the API key below\n" ) # --- main thread: poll stdin while waiting for callback --- @@ -416,26 +590,33 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 no_browser = os.environ.get("TABPFN_NO_BROWSER", "").strip() if no_browser and no_browser not in ("0", "false", "no", "off"): raise TabPFNLicenseError( - "TabPFN requires license acceptance, but browser login is\n" + "TabPFN requires a one-time license acceptance to download\n" + "model weights for local inference, but browser login is\n" "disabled (TABPFN_NO_BROWSER is set).\n\n" - "Set the TABPFN_TOKEN environment variable with a valid token\n" + "Set the TABPFN_TOKEN environment variable with a valid API key\n" "obtained from https://ux.priorlabs.ai" ) token = try_browser_login(gui_url, hf_repo_id=hf_repo_id) if token is None: raise TabPFNLicenseError( - "Browser login did not complete successfully.\n\n" - "If you are in a headless environment, set the TABPFN_TOKEN\n" - "environment variable with a valid token obtained from\n" - "https://ux.priorlabs.ai" + "TabPFN requires a one-time license acceptance to download\n" + "model weights for local inference, but no interactive terminal\n" + "is available.\n\n" + "To authenticate in a non-interactive environment:\n" + f" 1. Open {gui_url} in a browser and log in (or register)\n" + f" 2. Accept the license on the Licenses tab\n" + f" 3. Copy your API Key from {gui_url}/account\n" + ' 4. Set the environment variable: export TABPFN_TOKEN=""\n' + " or in Python (before calling .fit()):" + ' import os; os.environ["TABPFN_TOKEN"] = ""' ) # Verify the token we just received from the browser. status = verify_token(token, api_url) if status is False: raise TabPFNLicenseError( - "The token received from the browser login was rejected by the\n" + "The API key received from the browser login was rejected by the\n" "server. Please try again or contact support@priorlabs.ai" ) @@ -444,7 +625,7 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 license_status = check_license_accepted(token, api_url, license_version) if license_status is True: - print("License accepted — token cached for future sessions.\n") # noqa: T201 + print("License accepted — API key cached for future sessions.\n") # noqa: T201 _accepted_repos.add(hf_repo_id) return True if license_status is None: diff --git a/src/tabpfn/errors.py b/src/tabpfn/errors.py index 1f8c61d1a..dfcbbb7a2 100644 --- a/src/tabpfn/errors.py +++ b/src/tabpfn/errors.py @@ -10,6 +10,8 @@ import torch +from tabpfn.settings import settings + if TYPE_CHECKING: from tabpfn.constants import XType @@ -31,12 +33,22 @@ class TabPFNLicenseError(TabPFNError): def __init__(self, message: str | None = None): if message is None: + gui_url = settings.tabpfn.auth_gui_url message = ( - "TabPFN requires license acceptance before downloading.\n\n" - "To accept the license, run your script in an interactive terminal\n" - "so a browser window can open for login, or set the TABPFN_TOKEN\n" - "environment variable with a valid token obtained from\n" - "https://ux.priorlabs.ai" + "TabPFN requires a one-time license acceptance" + " to download model weights for local" + " inference.\n\n" + "To authenticate in a non-interactive" + " environment:\n" + f" 1. Open {gui_url} in a browser" + " and log in (or register)\n" + " 2. Accept the license on the Licenses tab\n" + " 3. Copy your API Key from" + f" {gui_url}/account\n" + " 4. Set the environment variable:" + ' export TABPFN_TOKEN=""\n' + " or in Python (before calling .fit()):" + ' import os; os.environ["TABPFN_TOKEN"] = ""' ) super().__init__(message) diff --git a/tests/test_browser_auth.py b/tests/test_browser_auth.py index be6def536..e3d232d25 100644 --- a/tests/test_browser_auth.py +++ b/tests/test_browser_auth.py @@ -2,6 +2,9 @@ from __future__ import annotations +import contextlib +import io +import sys import urllib.error import urllib.request from pathlib import Path @@ -10,6 +13,7 @@ import pytest from tabpfn.browser_auth import ( + _has_display, delete_cached_token, get_cached_token, save_token, @@ -318,7 +322,18 @@ def test_browser_login_returns_none_raises(self): "tabpfn.browser_auth.try_browser_login", return_value=None, ), - pytest.raises(TabPFNLicenseError, match="headless"), + pytest.raises(TabPFNLicenseError, match="no interactive terminal"), + ): + self._import_ensure()("tabpfn_2_6") + + def test_browser_login_returns_none_error_includes_steps(self): + """Non-interactive error should include step-by-step instructions.""" + with ( + patch( + "tabpfn.browser_auth.try_browser_login", + return_value=None, + ), + pytest.raises(TabPFNLicenseError, match="TABPFN_TOKEN"), ): self._import_ensure()("tabpfn_2_6") @@ -336,3 +351,289 @@ def test_browser_token_rejected_raises(self): pytest.raises(TabPFNLicenseError, match="rejected"), ): self._import_ensure()("tabpfn_2_6") + + +# --------------------------------------------------------------------------- +# _has_display +# --------------------------------------------------------------------------- + + +class TestHasDisplay: + def test_windows_always_true(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("tabpfn.browser_auth.sys.platform", "win32") + assert _has_display() is True + + def test_macos_local_session(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("tabpfn.browser_auth.sys.platform", "darwin") + monkeypatch.delenv("SSH_CONNECTION", raising=False) + monkeypatch.delenv("DISPLAY", raising=False) + assert _has_display() is True + + def test_macos_ssh_without_x_forwarding(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("tabpfn.browser_auth.sys.platform", "darwin") + monkeypatch.setenv("SSH_CONNECTION", "1.2.3.4 5678 5.6.7.8 22") + monkeypatch.delenv("DISPLAY", raising=False) + assert _has_display() is False + + def test_macos_ssh_with_x_forwarding(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("tabpfn.browser_auth.sys.platform", "darwin") + monkeypatch.setenv("SSH_CONNECTION", "1.2.3.4 5678 5.6.7.8 22") + monkeypatch.setenv("DISPLAY", "localhost:10.0") + assert _has_display() is True + + def test_linux_with_x11(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("tabpfn.browser_auth.sys.platform", "linux") + monkeypatch.setenv("DISPLAY", ":0") + monkeypatch.delenv("WAYLAND_DISPLAY", raising=False) + assert _has_display() is True + + def test_linux_with_wayland(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("tabpfn.browser_auth.sys.platform", "linux") + monkeypatch.delenv("DISPLAY", raising=False) + monkeypatch.setenv("WAYLAND_DISPLAY", "wayland-0") + assert _has_display() is True + + def test_linux_headless(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("tabpfn.browser_auth.sys.platform", "linux") + monkeypatch.delenv("DISPLAY", raising=False) + monkeypatch.delenv("WAYLAND_DISPLAY", raising=False) + assert _has_display() is False + + +# --------------------------------------------------------------------------- +# _headless_interactive_login +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif( + sys.platform == "win32", reason="termios/cbreak not available on Windows" +) +class TestHeadlessCbreakLoop: + """Tests for _headless_cbreak_loop (cbreak-mode input).""" + + def _import_cbreak_loop(self): # noqa: ANN202 + from tabpfn.browser_auth import _headless_cbreak_loop # noqa: PLC0415 + + return _headless_cbreak_loop + + def _fake_stdin(self, chars: str) -> io.StringIO: + """Create a fake stdin backed by a StringIO with a no-op fileno.""" + fake = io.StringIO(chars) + fake.fileno = lambda: 0 # type: ignore[assignment] + return fake + + @contextlib.contextmanager + def _patch_termios(self): # noqa: ANN202 + """Patch termios/tty so cbreak mode doesn't touch the real terminal.""" + import termios as _termios # noqa: PLC0415 + import tty as _tty # noqa: PLC0415 + + with ( + patch.object(_termios, "tcgetattr", return_value=[]), + patch.object(_termios, "tcsetattr"), + patch.object(_tty, "setcbreak"), + ): + yield + + def test_returns_pasted_token(self, monkeypatch: pytest.MonkeyPatch): + """Simulates pasting a full token followed by Enter.""" + cbreak_loop = self._import_cbreak_loop() + fake = self._fake_stdin("eyJhbGciOiJIUzI1NiJ9\r") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + + with self._patch_termios(): + result = cbreak_loop("https://ux.priorlabs.ai/login?hf_repo_id=tabpfn_2_6") + + assert result == "eyJhbGciOiJIUzI1NiJ9" + + def test_copy_then_paste(self, monkeypatch: pytest.MonkeyPatch): + """Press c to copy, then paste a token.""" + cbreak_loop = self._import_cbreak_loop() + # 'c' for copy, then token chars, then Enter + fake = self._fake_stdin("cmytok\r") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + + with ( + self._patch_termios(), + patch("tabpfn.browser_auth._copy_osc52") as mock_osc52, + ): + result = cbreak_loop("https://ux.priorlabs.ai/login") + + assert result == "mytok" + mock_osc52.assert_called_once_with("https://ux.priorlabs.ai/login") + + def test_eof_returns_none(self, monkeypatch: pytest.MonkeyPatch): + """EOF on first read returns None.""" + cbreak_loop = self._import_cbreak_loop() + fake = self._fake_stdin("") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + + with self._patch_termios(): + assert cbreak_loop("https://ux.priorlabs.ai/login") is None + + def test_ctrl_c_returns_none(self, monkeypatch: pytest.MonkeyPatch): + """Ctrl+C character returns None.""" + cbreak_loop = self._import_cbreak_loop() + fake = self._fake_stdin("\x03") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + + with self._patch_termios(): + assert cbreak_loop("https://ux.priorlabs.ai/login") is None + + def test_keyboard_interrupt_returns_none(self, monkeypatch: pytest.MonkeyPatch): + """KeyboardInterrupt during read returns None.""" + cbreak_loop = self._import_cbreak_loop() + fake = self._fake_stdin("") + fake.read = lambda _n: (_ for _ in ()).throw(KeyboardInterrupt) # type: ignore[assignment,method-assign] + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + + with self._patch_termios(): + assert cbreak_loop("https://ux.priorlabs.ai/login") is None + + def test_backspace_erases_char(self, monkeypatch: pytest.MonkeyPatch): + """Backspace removes the previous character.""" + cbreak_loop = self._import_cbreak_loop() + # Type 'ab', backspace, 'c', Enter → token is 'ac' + fake = self._fake_stdin("ab\x7fc\r") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + + with self._patch_termios(): + assert cbreak_loop("https://ux.priorlabs.ai/login") == "ac" + + +class TestHeadlessReadlineLoop: + """Tests for _headless_readline_loop (fallback when termios unavailable).""" + + def _import_readline_loop(self): # noqa: ANN202 + from tabpfn.browser_auth import _headless_readline_loop # noqa: PLC0415 + + return _headless_readline_loop + + def test_returns_token(self, monkeypatch: pytest.MonkeyPatch): + readline_loop = self._import_readline_loop() + fake = io.StringIO("my-tok-val\n") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + assert readline_loop("https://ux.priorlabs.ai/login") == "my-tok-val" + + def test_copy_then_token(self, monkeypatch: pytest.MonkeyPatch): + readline_loop = self._import_readline_loop() + fake = io.StringIO("c\nmy-tok-val\n") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + with patch("tabpfn.browser_auth._copy_osc52") as mock_osc52: + result = readline_loop("https://ux.priorlabs.ai/login") + assert result == "my-tok-val" + mock_osc52.assert_called_once() + + def test_eof_returns_none(self, monkeypatch: pytest.MonkeyPatch): + readline_loop = self._import_readline_loop() + fake = io.StringIO("") + monkeypatch.setattr("tabpfn.browser_auth.sys.stdin", fake) + assert readline_loop("https://ux.priorlabs.ai/login") is None + + +class TestHeadlessInteractiveLogin: + """Integration tests for _headless_interactive_login routing.""" + + def _import_headless(self): # noqa: ANN202 + from tabpfn.browser_auth import _headless_interactive_login # noqa: PLC0415 + + return _headless_interactive_login + + @pytest.mark.skipif( + sys.platform == "win32", reason="termios not available on Windows" + ) + def test_routes_to_cbreak_when_termios_available(self): + headless_login = self._import_headless() + with patch( + "tabpfn.browser_auth._headless_cbreak_loop", + return_value="jwt-val", + ) as mock_cbreak: + result = headless_login("https://ux.priorlabs.ai", hf_repo_id="tabpfn_2_6") + assert result == "jwt-val" + assert "tabpfn_2_6" in mock_cbreak.call_args[0][0] + + def test_routes_to_readline_without_termios(self): + headless_login = self._import_headless() + + import builtins # noqa: PLC0415 + + _real_import = builtins.__import__ + + def block_termios(name, *args, **kwargs): # noqa: ANN202 + if name == "termios": + raise ImportError("no termios") + return _real_import(name, *args, **kwargs) + + with ( + patch("builtins.__import__", side_effect=block_termios), + patch( + "tabpfn.browser_auth._headless_readline_loop", + return_value="jwt-val", + ) as mock_readline, + ): + result = headless_login("https://ux.priorlabs.ai") + assert result == "jwt-val" + mock_readline.assert_called_once() + + def test_login_url_includes_hf_repo_id(self, capsys: pytest.CaptureFixture[str]): + headless_login = self._import_headless() + with ( + patch("tabpfn.browser_auth._headless_cbreak_loop", return_value=None), + patch("tabpfn.browser_auth._headless_readline_loop", return_value=None), + ): + headless_login("https://ux.priorlabs.ai", hf_repo_id="tabpfn_2_6") + captured = capsys.readouterr() + assert "hf_repo_id=tabpfn_2_6" in captured.out + + +# --------------------------------------------------------------------------- +# try_browser_login routing +# --------------------------------------------------------------------------- + + +class TestTryBrowserLoginRouting: + def _import_try_login(self): # noqa: ANN202 + from tabpfn.browser_auth import try_browser_login # noqa: PLC0415 + + return try_browser_login + + def test_non_interactive_returns_none(self): + """Non-TTY stdin → returns None without attempting any login.""" + try_login = self._import_try_login() + with patch("tabpfn.browser_auth.sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = False + assert try_login("https://ux.priorlabs.ai") is None + + def test_headless_routes_to_headless_login(self): + """TTY + no display → delegates to _headless_interactive_login.""" + try_login = self._import_try_login() + with ( + patch("tabpfn.browser_auth.sys.stdin") as mock_stdin, + patch("tabpfn.browser_auth._has_display", return_value=False), + patch( + "tabpfn.browser_auth._headless_interactive_login", + return_value="headless-jwt", + ) as mock_headless, + ): + mock_stdin.isatty.return_value = True + result = try_login("https://ux.priorlabs.ai", hf_repo_id="tabpfn_2_6") + + assert result == "headless-jwt" + mock_headless.assert_called_once_with( + "https://ux.priorlabs.ai", hf_repo_id="tabpfn_2_6" + ) + + def test_graphical_opens_browser(self): + """TTY + display → opens browser (existing flow).""" + try_login = self._import_try_login() + with ( + patch("tabpfn.browser_auth.sys.stdin") as mock_stdin, + patch("tabpfn.browser_auth._has_display", return_value=True), + patch("tabpfn.browser_auth.webbrowser.open") as mock_browser, + patch("tabpfn.browser_auth._poll_for_token", return_value="browser-jwt"), + ): + mock_stdin.isatty.return_value = True + result = try_login("https://ux.priorlabs.ai") + + assert result == "browser-jwt" + mock_browser.assert_called_once() From f459046db162ee188b16d0d0f9c805c3bc753c54 Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Wed, 8 Apr 2026 09:38:56 +0200 Subject: [PATCH 3/7] feat: telemetry funnel to understand if users have trouble with new flow --- src/tabpfn/browser_auth.py | 22 ++++++++++++- src/tabpfn/license_telemetry.py | 58 +++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 src/tabpfn/license_telemetry.py diff --git a/src/tabpfn/browser_auth.py b/src/tabpfn/browser_auth.py index 1b649efcc..f2e125b80 100644 --- a/src/tabpfn/browser_auth.py +++ b/src/tabpfn/browser_auth.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING from tabpfn.errors import TabPFNLicenseError +from tabpfn.license_telemetry import track_license_event from tabpfn.settings import settings if TYPE_CHECKING: @@ -532,7 +533,7 @@ def try_browser_login(gui_url: str, hf_repo_id: str | None = None) -> str | None # --------------------------------------------------------------------------- -def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 +def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901, PLR0912 """Ensure the user has accepted the TabPFN license. Checks for a cached token, verifies it, and falls back to browser login @@ -565,8 +566,10 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 if license_status is True: save_token(token) _accepted_repos.add(hf_repo_id) + track_license_event("cached_token_valid") return True if license_status is None: + track_license_event("error", reason="server_unreachable") raise TabPFNLicenseError( "Could not reach the license server to verify acceptance.\n\n" "Please check your internet connection and try again." @@ -577,6 +580,7 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 "Token valid but license not accepted; opening browser for acceptance.", ) elif status is None: + track_license_event("error", reason="server_unreachable") raise TabPFNLicenseError( "Could not reach the license server to verify your token.\n\n" "Please check your internet connection and try again." @@ -587,8 +591,17 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 delete_cached_token() # No valid cached token — need browser login. + # Determine environment for telemetry. + if not sys.stdin.isatty(): + env = "non_interactive" + elif not _has_display(): + env = "headless_interactive" + else: + env = "graphical" + no_browser = os.environ.get("TABPFN_NO_BROWSER", "").strip() if no_browser and no_browser not in ("0", "false", "no", "off"): + track_license_event("error", environment=env, reason="no_browser_env") raise TabPFNLicenseError( "TabPFN requires a one-time license acceptance to download\n" "model weights for local inference, but browser login is\n" @@ -597,8 +610,11 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 "obtained from https://ux.priorlabs.ai" ) + track_license_event("started", environment=env) + token = try_browser_login(gui_url, hf_repo_id=hf_repo_id) if token is None: + track_license_event("error", environment=env, reason="aborted") raise TabPFNLicenseError( "TabPFN requires a one-time license acceptance to download\n" "model weights for local inference, but no interactive terminal\n" @@ -615,6 +631,7 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 # Verify the token we just received from the browser. status = verify_token(token, api_url) if status is False: + track_license_event("error", environment=env, reason="token_rejected") raise TabPFNLicenseError( "The API key received from the browser login was rejected by the\n" "server. Please try again or contact support@priorlabs.ai" @@ -627,13 +644,16 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 if license_status is True: print("License accepted — API key cached for future sessions.\n") # noqa: T201 _accepted_repos.add(hf_repo_id) + track_license_event("success", environment=env) return True if license_status is None: + track_license_event("error", environment=env, reason="server_unreachable") raise TabPFNLicenseError( "Could not reach the license server to verify acceptance.\n\n" "Please check your internet connection and try again." ) # license_status is False + track_license_event("error", environment=env, reason="license_not_accepted") encoded = urllib.parse.quote(hf_repo_id) raise TabPFNLicenseError( "License not yet accepted. Please complete the acceptance form at\n" diff --git a/src/tabpfn/license_telemetry.py b/src/tabpfn/license_telemetry.py new file mode 100644 index 000000000..c15cd4986 --- /dev/null +++ b/src/tabpfn/license_telemetry.py @@ -0,0 +1,58 @@ +"""Telemetry events for the license acceptance flow.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field + +from tabpfn_common_utils.telemetry import capture_event +from tabpfn_common_utils.telemetry.core.events import ( + BaseTelemetryEvent, + _get_install_id, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class LicenseFlowEvent(BaseTelemetryEvent): + """Event emitted at key points of the license acceptance flow. + + Used to build a funnel: started -> success vs error, broken down by + environment and failure reason. + """ + + outcome: str = "" + + environment: str | None = None + + method: str | None = None + + reason: str | None = None + + install_id: str = field(default_factory=_get_install_id, init=False) + + @property + def name(self) -> str: # noqa: D102 + return "license_flow" + + +def track_license_event( + outcome: str, + *, + environment: str | None = None, + method: str | None = None, + reason: str | None = None, +) -> None: + """Fire a license flow telemetry event, silently ignoring errors.""" + try: + capture_event( + LicenseFlowEvent( + outcome=outcome, + environment=environment, + method=method, + reason=reason, + ) + ) + except Exception: # noqa: BLE001 + logger.debug("Failed to capture license flow event", exc_info=True) From 874958090284b8213b6ea5342dfe8ba1e216a52a Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Wed, 8 Apr 2026 09:58:13 +0200 Subject: [PATCH 4/7] Add changelog entry for license flow telemetry Co-Authored-By: Claude Opus 4.6 (1M context) --- changelog/864.added.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/864.added.md diff --git a/changelog/864.added.md b/changelog/864.added.md new file mode 100644 index 000000000..951027845 --- /dev/null +++ b/changelog/864.added.md @@ -0,0 +1 @@ +Add telemetry funnel for the license acceptance flow to track user success rates and churn across graphical, headless, and non-interactive environments. From bf8298436a6341bb26e4a76ad8bb6708f8910b44 Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Wed, 8 Apr 2026 10:02:17 +0200 Subject: [PATCH 5/7] chore: gemini comments --- src/tabpfn/browser_auth.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/tabpfn/browser_auth.py b/src/tabpfn/browser_auth.py index f2e125b80..7af80ac92 100644 --- a/src/tabpfn/browser_auth.py +++ b/src/tabpfn/browser_auth.py @@ -57,6 +57,13 @@ def _has_display() -> bool: return bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")) +def _get_env_type() -> str: + """Classify the current environment for telemetry and flow selection.""" + if not sys.stdin.isatty(): + return "non_interactive" + return "headless_interactive" if not _has_display() else "graphical" + + # --------------------------------------------------------------------------- # Token cache helpers # --------------------------------------------------------------------------- @@ -533,7 +540,7 @@ def try_browser_login(gui_url: str, hf_repo_id: str | None = None) -> str | None # --------------------------------------------------------------------------- -def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901, PLR0912 +def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901 """Ensure the user has accepted the TabPFN license. Checks for a cached token, verifies it, and falls back to browser login @@ -556,6 +563,7 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901, PL # Resolve the canonical license version string from HF; fall back to repo ID. license_version = _get_license_name(hf_repo_id) or hf_repo_id + env = _get_env_type() token = get_cached_token() if token is not None: @@ -566,10 +574,12 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901, PL if license_status is True: save_token(token) _accepted_repos.add(hf_repo_id) - track_license_event("cached_token_valid") + track_license_event("cached_token_valid", environment=env) return True if license_status is None: - track_license_event("error", reason="server_unreachable") + track_license_event( + "error", environment=env, reason="server_unreachable" + ) raise TabPFNLicenseError( "Could not reach the license server to verify acceptance.\n\n" "Please check your internet connection and try again." @@ -580,7 +590,7 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901, PL "Token valid but license not accepted; opening browser for acceptance.", ) elif status is None: - track_license_event("error", reason="server_unreachable") + track_license_event("error", environment=env, reason="server_unreachable") raise TabPFNLicenseError( "Could not reach the license server to verify your token.\n\n" "Please check your internet connection and try again." @@ -591,13 +601,6 @@ def ensure_license_accepted(hf_repo_id: str) -> Literal[True]: # noqa: C901, PL delete_cached_token() # No valid cached token — need browser login. - # Determine environment for telemetry. - if not sys.stdin.isatty(): - env = "non_interactive" - elif not _has_display(): - env = "headless_interactive" - else: - env = "graphical" no_browser = os.environ.get("TABPFN_NO_BROWSER", "").strip() if no_browser and no_browser not in ("0", "false", "no", "off"): From da928c2b103100496b9357dfa31211220460eec7 Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Wed, 8 Apr 2026 10:07:40 +0200 Subject: [PATCH 6/7] Integrate license telemetry into tabpfn.telemetry module - Move LicenseFlowEvent + track_license_event into telemetry.py - Add LicenseFlowEvent to passthrough_events (fires before auth) - Delete license_telemetry.py - Make browser_auth import lazy in telemetry.py to avoid circular import - Fix test_telemetry.py to patch at source module Co-Authored-By: Claude Opus 4.6 (1M context) --- src/tabpfn/browser_auth.py | 2 +- src/tabpfn/license_telemetry.py | 58 ------------------------------- src/tabpfn/telemetry.py | 60 +++++++++++++++++++++++++++++++-- tests/test_telemetry.py | 8 ++--- 4 files changed, 61 insertions(+), 67 deletions(-) delete mode 100644 src/tabpfn/license_telemetry.py diff --git a/src/tabpfn/browser_auth.py b/src/tabpfn/browser_auth.py index 7af80ac92..e21b358be 100644 --- a/src/tabpfn/browser_auth.py +++ b/src/tabpfn/browser_auth.py @@ -24,8 +24,8 @@ from typing import TYPE_CHECKING from tabpfn.errors import TabPFNLicenseError -from tabpfn.license_telemetry import track_license_event from tabpfn.settings import settings +from tabpfn.telemetry import track_license_event if TYPE_CHECKING: from typing import Literal diff --git a/src/tabpfn/license_telemetry.py b/src/tabpfn/license_telemetry.py deleted file mode 100644 index c15cd4986..000000000 --- a/src/tabpfn/license_telemetry.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Telemetry events for the license acceptance flow.""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass, field - -from tabpfn_common_utils.telemetry import capture_event -from tabpfn_common_utils.telemetry.core.events import ( - BaseTelemetryEvent, - _get_install_id, -) - -logger = logging.getLogger(__name__) - - -@dataclass -class LicenseFlowEvent(BaseTelemetryEvent): - """Event emitted at key points of the license acceptance flow. - - Used to build a funnel: started -> success vs error, broken down by - environment and failure reason. - """ - - outcome: str = "" - - environment: str | None = None - - method: str | None = None - - reason: str | None = None - - install_id: str = field(default_factory=_get_install_id, init=False) - - @property - def name(self) -> str: # noqa: D102 - return "license_flow" - - -def track_license_event( - outcome: str, - *, - environment: str | None = None, - method: str | None = None, - reason: str | None = None, -) -> None: - """Fire a license flow telemetry event, silently ignoring errors.""" - try: - capture_event( - LicenseFlowEvent( - outcome=outcome, - environment=environment, - method=method, - reason=reason, - ) - ) - except Exception: # noqa: BLE001 - logger.debug("Failed to capture license flow event", exc_info=True) diff --git a/src/tabpfn/telemetry.py b/src/tabpfn/telemetry.py index 97df85926..1a4747ee7 100644 --- a/src/tabpfn/telemetry.py +++ b/src/tabpfn/telemetry.py @@ -10,7 +10,8 @@ from __future__ import annotations -from dataclasses import dataclass +import logging +from dataclasses import dataclass, field from typing import Any, cast import jwt @@ -30,6 +31,7 @@ ModelLoadEvent, PingEvent, SessionEvent, + _get_install_id, ) from tabpfn_common_utils.telemetry.core.service import ProductTelemetry from tabpfn_common_utils.telemetry.interactive import ( @@ -38,9 +40,58 @@ ping, ) -from tabpfn.browser_auth import get_cached_token as get_cached_auth_token +from tabpfn.auth_token import get_cached_token as get_cached_auth_token + +logger = logging.getLogger(__name__) + +__all__ = [ + "init", + "set_init_params", + "set_model_config", + "track_license_event", + "track_model_call", +] + + +@dataclass +class LicenseFlowEvent(BaseTelemetryEvent): + """Event emitted at key points of the license acceptance flow. + + Used to build a funnel: started -> success vs error, broken down by + environment and failure reason. + """ + + outcome: str = "" + environment: str | None = None + method: str | None = None + reason: str | None = None + install_id: str = field(default_factory=_get_install_id, init=False) + + @property + def name(self) -> str: + return "license_flow" + + +def track_license_event( + outcome: str, + *, + environment: str | None = None, + method: str | None = None, + reason: str | None = None, +) -> None: + """Fire a license flow telemetry event, silently ignoring errors.""" + try: + _capture_event_with_user_id( + LicenseFlowEvent( + outcome=outcome, + environment=environment, + method=method, + reason=reason, + ) + ) + except Exception: # noqa: BLE001 + logger.debug("Failed to capture license flow event", exc_info=True) -__all__ = ["init", "set_init_params", "set_model_config", "track_model_call"] _cached_user_id: str | None = None @@ -146,10 +197,13 @@ def _capture_event_with_user_id( # We passthrough the session and ping events anonymously. # These events still contain anonymous and valuable runtime metadata. + # LicenseFlowEvent must also pass through because it fires before/during + # authentication, when no user ID is available yet. passthrough_events = ( SessionEvent, PingEvent, ModelLoadEvent, + LicenseFlowEvent, ) if user_id is None and not isinstance(event, passthrough_events): return diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index dd8932364..de97d3874 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -61,15 +61,13 @@ def test_returns_config_when_complete( class TestGetUserId: def test_returns_none_when_no_token(self, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(telemetry_module, "get_cached_auth_token", lambda: None) + monkeypatch.setattr("tabpfn.browser_auth.get_cached_token", lambda: None) assert telemetry_module._get_user_id() is None def test_returns_none_when_token_invalid( self, monkeypatch: pytest.MonkeyPatch ) -> None: - monkeypatch.setattr( - telemetry_module, "get_cached_auth_token", lambda: "not-a-jwt" - ) + monkeypatch.setattr("tabpfn.browser_auth.get_cached_token", lambda: "not-a-jwt") assert telemetry_module._get_user_id() is None def test_returns_user_claim_when_token_valid( @@ -81,7 +79,7 @@ def test_returns_user_claim_when_token_valid( "x" * 32, algorithm="HS256", ) - monkeypatch.setattr(telemetry_module, "get_cached_auth_token", lambda: token) + monkeypatch.setattr("tabpfn.browser_auth.get_cached_token", lambda: token) assert telemetry_module._get_user_id() == "user-123" From b7eb3560e01f80799fe97194a8b41a80c8b31162 Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Wed, 8 Apr 2026 10:11:33 +0200 Subject: [PATCH 7/7] chore: resolve circular dep -> extract auth_token utilities into separate file --- src/tabpfn/auth_token.py | 54 ++++++++++++++++++++++++++++++++++++++ src/tabpfn/browser_auth.py | 47 +-------------------------------- tests/test_browser_auth.py | 10 +++---- tests/test_telemetry.py | 8 +++--- 4 files changed, 64 insertions(+), 55 deletions(-) create mode 100644 src/tabpfn/auth_token.py diff --git a/src/tabpfn/auth_token.py b/src/tabpfn/auth_token.py new file mode 100644 index 000000000..2a8f269fd --- /dev/null +++ b/src/tabpfn/auth_token.py @@ -0,0 +1,54 @@ +"""Token cache I/O for TabPFN authentication. + +Pure I/O helpers with no dependencies on other TabPFN modules, so they +can be imported from both ``browser_auth`` and ``telemetry`` without +creating a circular import. +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + +_CACHE_DIR = Path.home() / ".cache" / "tabpfn" +_TOKEN_FILE = _CACHE_DIR / "auth_token" + +# tabpfn-client stores its token here — we read it as a fallback. +_CLIENT_TOKEN_FILE = Path.home() / ".tabpfn" / "token" + + +def get_cached_token() -> str | None: + """Return a cached token. + + Checks (in priority order): + + 1. ``TABPFN_TOKEN`` environment variable + 2. ``~/.cache/tabpfn/auth_token`` + 3. ``~/.tabpfn/token`` (tabpfn-client's cache) + """ + env_token = os.environ.get("TABPFN_TOKEN") + if env_token: + return env_token.strip() or None + + for path in (_TOKEN_FILE, _CLIENT_TOKEN_FILE): + if path.is_file(): + token = path.read_text().strip() + if len(token) > 0: + return token + + return None + + +def save_token(token: str) -> None: + """Persist *token* to ``~/.cache/tabpfn/auth_token``.""" + _CACHE_DIR.mkdir(parents=True, exist_ok=True) + _TOKEN_FILE.write_text(token) + logger.debug("Token saved to %s", _TOKEN_FILE) + + +def delete_cached_token() -> None: + """Remove the cached token file (if it exists).""" + _TOKEN_FILE.unlink(missing_ok=True) diff --git a/src/tabpfn/browser_auth.py b/src/tabpfn/browser_auth.py index e21b358be..ac11c468b 100644 --- a/src/tabpfn/browser_auth.py +++ b/src/tabpfn/browser_auth.py @@ -20,9 +20,9 @@ import urllib.parse import urllib.request import webbrowser -from pathlib import Path from typing import TYPE_CHECKING +from tabpfn.auth_token import delete_cached_token, get_cached_token, save_token from tabpfn.errors import TabPFNLicenseError from tabpfn.settings import settings from tabpfn.telemetry import track_license_event @@ -64,51 +64,6 @@ def _get_env_type() -> str: return "headless_interactive" if not _has_display() else "graphical" -# --------------------------------------------------------------------------- -# Token cache helpers -# --------------------------------------------------------------------------- - -_CACHE_DIR = Path.home() / ".cache" / "tabpfn" -_TOKEN_FILE = _CACHE_DIR / "auth_token" - -# tabpfn-client stores its token here — we read it as a fallback. -_CLIENT_TOKEN_FILE = Path.home() / ".tabpfn" / "token" - - -def get_cached_token() -> str | None: - """Return a cached token. - - Checks (in priority order): - - 1. ``TABPFN_TOKEN`` environment variable - 2. ``~/.cache/tabpfn/auth_token`` - 3. ``~/.tabpfn/token`` (tabpfn-client's cache) - """ - env_token = os.environ.get("TABPFN_TOKEN") - if env_token: - return env_token.strip() or None - - for path in (_TOKEN_FILE, _CLIENT_TOKEN_FILE): - if path.is_file(): - token = path.read_text().strip() - if len(token) > 0: - return token - - return None - - -def save_token(token: str) -> None: - """Persist *token* to ``~/.cache/tabpfn/auth_token``.""" - _CACHE_DIR.mkdir(parents=True, exist_ok=True) - _TOKEN_FILE.write_text(token) - logger.debug("Token saved to %s", _TOKEN_FILE) - - -def delete_cached_token() -> None: - """Remove the cached token file (if it exists).""" - _TOKEN_FILE.unlink(missing_ok=True) - - # --------------------------------------------------------------------------- # Token verification # --------------------------------------------------------------------------- diff --git a/tests/test_browser_auth.py b/tests/test_browser_auth.py index e3d232d25..6b35865dc 100644 --- a/tests/test_browser_auth.py +++ b/tests/test_browser_auth.py @@ -12,11 +12,9 @@ import pytest +from tabpfn.auth_token import delete_cached_token, get_cached_token, save_token from tabpfn.browser_auth import ( _has_display, - delete_cached_token, - get_cached_token, - save_token, verify_token, ) from tabpfn.errors import TabPFNLicenseError @@ -33,9 +31,9 @@ def _isolate_token_paths(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Non token_file = cache_dir / "auth_token" client_file = tmp_path / ".tabpfn" / "token" - monkeypatch.setattr("tabpfn.browser_auth._CACHE_DIR", cache_dir) - monkeypatch.setattr("tabpfn.browser_auth._TOKEN_FILE", token_file) - monkeypatch.setattr("tabpfn.browser_auth._CLIENT_TOKEN_FILE", client_file) + monkeypatch.setattr("tabpfn.auth_token._CACHE_DIR", cache_dir) + monkeypatch.setattr("tabpfn.auth_token._TOKEN_FILE", token_file) + monkeypatch.setattr("tabpfn.auth_token._CLIENT_TOKEN_FILE", client_file) # Reset in-process cache so tests don't leak state. monkeypatch.setattr("tabpfn.browser_auth._accepted_repos", set()) diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index de97d3874..4e2293b22 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -61,13 +61,15 @@ def test_returns_config_when_complete( class TestGetUserId: def test_returns_none_when_no_token(self, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("tabpfn.browser_auth.get_cached_token", lambda: None) + monkeypatch.setattr("tabpfn.telemetry.get_cached_auth_token", lambda: None) assert telemetry_module._get_user_id() is None def test_returns_none_when_token_invalid( self, monkeypatch: pytest.MonkeyPatch ) -> None: - monkeypatch.setattr("tabpfn.browser_auth.get_cached_token", lambda: "not-a-jwt") + monkeypatch.setattr( + "tabpfn.telemetry.get_cached_auth_token", lambda: "not-a-jwt" + ) assert telemetry_module._get_user_id() is None def test_returns_user_claim_when_token_valid( @@ -79,7 +81,7 @@ def test_returns_user_claim_when_token_valid( "x" * 32, algorithm="HS256", ) - monkeypatch.setattr("tabpfn.browser_auth.get_cached_token", lambda: token) + monkeypatch.setattr("tabpfn.telemetry.get_cached_auth_token", lambda: token) assert telemetry_module._get_user_id() == "user-123"