diff --git a/.gitignore b/.gitignore index e5742f754..cc6ab0406 100644 --- a/.gitignore +++ b/.gitignore @@ -159,6 +159,9 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +# WandB +wandb + .DS_Store ./src/.DS_Store diff --git a/changelog/815.added.md b/changelog/815.added.md new file mode 100644 index 000000000..7e2438cdd --- /dev/null +++ b/changelog/815.added.md @@ -0,0 +1 @@ +Add modular experiment logging for finetuning with `experiment_logger` parameter, including `WandbLogger` for W&B tracking and a `FinetuningLogger` protocol for custom integrations. diff --git a/changelog/862.added.md b/changelog/862.added.md new file mode 100644 index 000000000..8aea41f46 --- /dev/null +++ b/changelog/862.added.md @@ -0,0 +1 @@ +Add three-tier authentication flow: browser-based login for graphical environments, headless interactive login with clipboard copy for SSH/cluster sessions, and clear step-by-step instructions for fully non-interactive environments. diff --git a/changelog/864.added.md b/changelog/864.added.md new file mode 100644 index 000000000..951027845 --- /dev/null +++ b/changelog/864.added.md @@ -0,0 +1 @@ +Add telemetry funnel for the license acceptance flow to track user success rates and churn across graphical, headless, and non-interactive environments. diff --git a/examples/finetune_classifier.py b/examples/finetune_classifier.py index a1a338eb4..6cc312385 100644 --- a/examples/finetune_classifier.py +++ b/examples/finetune_classifier.py @@ -115,6 +115,9 @@ def main() -> None: print("--- 2. Initializing and Fitting Model ---\n") # Instantiate the wrapper with your desired hyperparameters + # To enable WandB logging, pass an experiment_logger: + # . 
from tabpfn.finetuning.logging import WandbLogger + # experiment_logger=WandbLogger(project="my-project", run_name="my-run", entity="my-entity") finetuned_clf = FinetunedTabPFNClassifier( device="cuda", epochs=NUM_EPOCHS, diff --git a/examples/finetune_regressor.py b/examples/finetune_regressor.py index ae203da11..bc6bafbf8 100644 --- a/examples/finetune_regressor.py +++ b/examples/finetune_regressor.py @@ -101,6 +101,9 @@ def main() -> None: print("--- 2. Initializing and Fitting Model ---\n") # Instantiate the wrapper with your desired hyperparameters + # To enable WandB logging, pass an experiment_logger: + # . from tabpfn.finetuning.logging import WandbLogger + # experiment_logger=WandbLogger(project="my-project", run_name="my-run", entity="my-entity") finetuned_reg = FinetunedTabPFNRegressor( device="cuda", epochs=NUM_EPOCHS, diff --git a/pyproject.toml b/pyproject.toml index 30e84a5c8..80a445e68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,9 @@ classifiers = [ ] license = { file = "LICENSE" } +[project.optional-dependencies] +wandb = ["wandb>=0.25.1"] + [project.urls] documentation = "https://priorlabs.ai/docs" source = "https://github.com/priorlabs/tabpfn" diff --git a/src/tabpfn/auth_token.py b/src/tabpfn/auth_token.py new file mode 100644 index 000000000..2a8f269fd --- /dev/null +++ b/src/tabpfn/auth_token.py @@ -0,0 +1,54 @@ +"""Token cache I/O for TabPFN authentication. + +Pure I/O helpers with no dependencies on other TabPFN modules, so they +can be imported from both ``browser_auth`` and ``telemetry`` without +creating a circular import. +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + +_CACHE_DIR = Path.home() / ".cache" / "tabpfn" +_TOKEN_FILE = _CACHE_DIR / "auth_token" + +# tabpfn-client stores its token here — we read it as a fallback. 
+_CLIENT_TOKEN_FILE = Path.home() / ".tabpfn" / "token"
+
+
+def get_cached_token() -> str | None:
+    """Return the first non-blank cached token, or ``None`` if there is none.
+
+    Checks (in priority order); a blank source falls through to the next one:
+
+    1. ``TABPFN_TOKEN`` environment variable
+    2. ``~/.cache/tabpfn/auth_token``
+    3. ``~/.tabpfn/token`` (tabpfn-client's cache)
+    """
+    env_token = os.environ.get("TABPFN_TOKEN", "").strip()
+    if env_token:
+        return env_token
+    for path in (_TOKEN_FILE, _CLIENT_TOKEN_FILE):
+        if path.is_file():
+            token = path.read_text(encoding="utf-8").strip()
+            if token:
+                return token
+    return None
+
+
+def save_token(token: str) -> None:
+    """Persist *token* to ``~/.cache/tabpfn/auth_token`` with owner-only access."""
+    _CACHE_DIR.mkdir(parents=True, exist_ok=True)
+    _TOKEN_FILE.touch(mode=0o600)  # create owner-only before the secret lands
+    _TOKEN_FILE.chmod(0o600)  # tighten a file left over with looser permissions
+    _TOKEN_FILE.write_text(token, encoding="utf-8")
+    logger.debug("Token saved to %s", _TOKEN_FILE)
+
+
+def delete_cached_token() -> None:
+    """Remove the cached token file (if it exists)."""
+    _TOKEN_FILE.unlink(missing_ok=True)
diff --git a/src/tabpfn/browser_auth.py b/src/tabpfn/browser_auth.py
index b5459e0a2..ac11c468b 100644
--- a/src/tabpfn/browser_auth.py
+++ b/src/tabpfn/browser_auth.py
@@ -20,11 +20,12 @@
 import urllib.parse
 import urllib.request
 import webbrowser
-from pathlib import Path
 from typing import TYPE_CHECKING
 
+from tabpfn.auth_token import delete_cached_token, get_cached_token, save_token
 from tabpfn.errors import TabPFNLicenseError
 from tabpfn.settings import settings
+from tabpfn.telemetry import track_license_event
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -35,49 +36,32 @@
 # Short-circuits repeated calls within the same Python process.
 _accepted_repos: set[str] = set()
 
+
 # ---------------------------------------------------------------------------
-# Token cache helpers
+# Environment detection
 # ---------------------------------------------------------------------------
 
-_CACHE_DIR = Path.home() / ".cache" / "tabpfn"
-_TOKEN_FILE = _CACHE_DIR / "auth_token"
-
-# tabpfn-client stores its token here — we read it as a fallback.
-_CLIENT_TOKEN_FILE = Path.home() / ".tabpfn" / "token"
-
-
-def get_cached_token() -> str | None:
-    """Return a cached token.
 
-    Checks (in priority order):
+def _has_display() -> bool:
+    """Heuristic: is a graphical display likely available for opening a browser?
 
-    1. ``TABPFN_TOKEN`` environment variable
-    2. ``~/.cache/tabpfn/auth_token``
-    3. ``~/.tabpfn/token`` (tabpfn-client's cache)
+    Returns ``True`` when it is reasonable to call :func:`webbrowser.open`.
     """
-    env_token = os.environ.get("TABPFN_TOKEN")
-    if env_token:
-        return env_token.strip() or None
-
-    for path in (_TOKEN_FILE, _CLIENT_TOKEN_FILE):
-        if path.is_file():
-            token = path.read_text().strip()
-            if len(token) > 0:
-                return token
-
-    return None
-
-
-def save_token(token: str) -> None:
-    """Persist *token* to ``~/.cache/tabpfn/auth_token``."""
-    _CACHE_DIR.mkdir(parents=True, exist_ok=True)
-    _TOKEN_FILE.write_text(token)
-    logger.debug("Token saved to %s", _TOKEN_FILE)
+    if sys.platform == "win32":
+        return True
+    if sys.platform == "darwin":
+        # macOS has a display unless we are in a pure SSH session
+        # without X forwarding.
+        return not (os.environ.get("SSH_CONNECTION") and not os.environ.get("DISPLAY"))
+    # Linux / other Unix: require X11 or Wayland.
+    return bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
 
 
-def delete_cached_token() -> None:
-    """Remove the cached token file (if it exists)."""
-    _TOKEN_FILE.unlink(missing_ok=True)
+def _get_env_type() -> str:
+    """Classify the current environment for telemetry and flow selection."""
+    if sys.stdin is None or not sys.stdin.isatty():  # stdin is None under pythonw
+        return "non_interactive"
+    return "graphical" if _has_display() else "headless_interactive"
 
 
 # ---------------------------------------------------------------------------
@@ -168,6 +152,24 @@ def check_license_accepted(token: str, api_url: str, version: str) -> bool | Non
     return None
 
 
+# ---------------------------------------------------------------------------
+# Terminal helpers (headless-interactive flow)
+# ---------------------------------------------------------------------------
+
+
+def _copy_osc52(text: str) -> None:
+    """Copy *text* to the system clipboard via the OSC 52 terminal escape.
+
+    Works over SSH when the terminal emulator supports it (iTerm2, kitty,
+    Windows Terminal, most modern terminals).
+    """
+    import base64  # noqa: PLC0415
+
+    encoded = base64.b64encode(text.encode()).decode()
+    sys.stdout.write(f"\033]52;c;{encoded}\a")
+    sys.stdout.flush()
+
+
 # ---------------------------------------------------------------------------
 # Browser login flow
 # ---------------------------------------------------------------------------
@@ -242,10 +244,10 @@ def do_GET(self) -> None:
 f"{page_style}
Please paste your token in the terminal, or visit " + "
Please paste your API key in the terminal, or visit " f'{gui_url}/account ' - "to copy your Access Token.
" + "to copy your API Key." "