Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/tabpfn_client/api_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from uuid import UUID
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field

# Classification output_type="preds" preserves the original label type, so
Expand Down Expand Up @@ -103,14 +103,12 @@ class FitRequest(BaseModel):
# `tabpfn_systems` values on the server need this at fit time; the
# server ignores it otherwise.
tabpfn_config: TabPFNConfig = None
# Drives model selection + ensemble weighting during the enhanced-fit
# sweep. Only consulted when `"enhanced"` is in `tabpfn_systems`. None
# falls back to the sweep's default per problem type.
enhanced_fit_mode_metric: Optional[str] = None
# Ceiling on the enhanced-fit sweep (seconds). Only consulted when
# `"enhanced"` is in `tabpfn_systems`. None falls back to the server
# default (300s).
enhanced_fit_time_limit_s: Optional[float] = None
# User-facing thinking-effort level. None disables it.
thinking_effort: Optional[Literal["medium", "high"]] = None
# Budget for the fit (seconds). Only consulted when `thinking_effort` is set.
thinking_timeout_s: Optional[float] = None
# Optimization metric for the fit. Only consulted when `thinking_effort` is set.
thinking_metric: Optional[str] = None


class FitResponse(BaseModel):
Expand Down
60 changes: 34 additions & 26 deletions src/tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,28 +388,33 @@ def fit(
raise

tabpfn_systems = ["preprocessing", "text"]
# Thinking is enabled when either flag is set: explicit `thinking_mode=True`,
# or any non-None `thinking_effort`. Setting `thinking_effort` alone is
# enough — the server-side validator on FitRequest also normalises this,
# but doing it here means the request body itself is consistent.
thinking_enabled = bool(tabpfn_config) and (
bool(tabpfn_config.get("thinking_mode"))
or tabpfn_config.get("thinking_effort") is not None
)
if tabpfn_config:
if tabpfn_config.get("paper_version") is True:
tabpfn_systems = []
elif tabpfn_config.get("enhanced_fit_mode") is True:
# Enhanced mode runs on top of the base systems rather than
elif thinking_enabled:
# Thinking runs on top of the base systems rather than
# replacing them — keep preprocessing + text alongside it.
tabpfn_systems = ["preprocessing", "text", "enhanced"]

# `enhanced_fit_mode_metric` and `enhanced_fit_mode_time_limit_s`
# are top-level FitRequest fields on the server (siblings to
# `tabpfn_systems`), not part of `tabpfn_config`. Lift them out
# before stripping the rest of the client-only keys. The server
# field drops the `mode_` infix (`enhanced_fit_time_limit_s`);
# units are seconds on both sides, no conversion.
enhanced_fit_mode_metric = (
tabpfn_config.get("enhanced_fit_mode_metric") if tabpfn_config else None
)
enhanced_fit_time_limit_s = (
tabpfn_config.get("enhanced_fit_mode_time_limit_s")
if tabpfn_config
else None
)
tabpfn_systems = ["preprocessing", "text", "thinking"]

# The client-side `thinking_*` knobs forward 1:1 to the server's
# top-level FitRequest fields. When the user enabled thinking via
# `thinking_mode=True` without picking a level, default to "medium".
if thinking_enabled and tabpfn_config:
thinking_effort = tabpfn_config.get("thinking_effort") or "medium"
thinking_timeout_s = tabpfn_config.get("thinking_timeout_s")
thinking_metric = tabpfn_config.get("thinking_metric")
else:
thinking_effort = None
thinking_timeout_s = None
thinking_metric = None

# Strip client-only keys that the server does not expect (mirrors
# the predict path's filter below).
Expand All @@ -420,9 +425,10 @@ def fit(
if k
not in {
"paper_version",
"enhanced_fit_mode",
"enhanced_fit_mode_metric",
"enhanced_fit_mode_time_limit_s",
"thinking_mode",
"thinking_effort",
"thinking_timeout_s",
"thinking_metric",
}
}
if tabpfn_config is not None
Expand All @@ -436,8 +442,9 @@ def fit(
tabpfn_systems=tabpfn_systems,
force_refit=force_refit or force_refit_enabled(),
tabpfn_config=server_tabpfn_config,
enhanced_fit_mode_metric=enhanced_fit_mode_metric,
enhanced_fit_time_limit_s=enhanced_fit_time_limit_s,
thinking_effort=thinking_effort,
thinking_timeout_s=thinking_timeout_s,
thinking_metric=thinking_metric,
),
timeout=client_options.timeout,
headers=client_options.headers,
Expand Down Expand Up @@ -584,9 +591,10 @@ def predict(
if k
not in {
"paper_version",
"enhanced_fit_mode",
"enhanced_fit_mode_metric",
"enhanced_fit_mode_time_limit_s",
"thinking_mode",
"thinking_effort",
"thinking_timeout_s",
"thinking_metric",
}
}

Expand Down
178 changes: 123 additions & 55 deletions src/tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@
# is kept as a backward-compatible alias.
_AUTO_MODEL_PATH_ALIASES = frozenset({"auto", "default"})

ENHANCED_FIT_MODE_MAX_TIME_LIMIT_S = 40 * 60
THINKING_TIMEOUT_MAX_S = 40 * 60

ThinkingEffort = Literal["medium", "high"]
_VALID_THINKING_EFFORT_LEVELS = frozenset({"medium", "high"})


class TabPFNModelSelection:
Expand Down Expand Up @@ -180,9 +183,10 @@ def __init__(
] = 0,
inference_config: Optional[Dict] = None,
paper_version: bool = False,
enhanced_fit_mode: bool = False,
enhanced_fit_mode_metric: Optional[str] = None,
enhanced_fit_mode_time_limit_s: Optional[float] = None,
thinking_mode: bool = False,
thinking_effort: Optional[ThinkingEffort] = None,
thinking_timeout_s: Optional[float] = None,
thinking_metric: Optional[str] = None,
force_refit: bool = False,
client_options: ClientOptions | None = None,
):
Expand Down Expand Up @@ -240,24 +244,39 @@ def __init__(
paper_version: bool, default=False
If True, will use the model described in the paper, instead of the newest
version available on the API, which e.g handles text features better.
enhanced_fit_mode: bool, default=False
If True, trades off fit time for precision by running an
automated feature-engineering pipeline on top of TabPFN during
fit.
enhanced_fit_mode_metric: str or None, default=None
Only consulted when `enhanced_fit_mode=True`. Drives model
selection + ensemble weighting during the enhanced-fit sweep
(e.g. "accuracy"/"log_loss"/"roc_auc"/"balanced_accuracy"/
"f1" for classification). None falls back to the sweep's
default for the problem type. Distinct from the local
`eval_metric`/`tuning_config` knobs used for decision-threshold
tuning on the standalone TabPFN classifier.
enhanced_fit_mode_time_limit_s: float or None, default=None
Only consulted when `enhanced_fit_mode=True`. Ceiling on the
enhanced-fit sweep (seconds). Raise for larger datasets where
the default ~5-minute sweep leaves performance on the table.
None falls back to the server-side default (300s). Capped at
2400 seconds (40 minutes); higher values raise ValueError at fit.
thinking_mode: bool, default=False
If True, spend extra fit-time compute for higher precision.
Equivalent to passing `thinking_effort="medium"` — setting any
`thinking_effort` value also enables thinking, so this flag is
optional when you've set the level explicitly.
thinking_effort: {"medium", "high"} or None, default=None
Effort level for thinking mode. When set, thinking is enabled
(you don't also need `thinking_mode=True`). When None and
`thinking_mode=True`, defaults to "medium".
thinking_timeout_s: float or None, default=None
Budget for the fit, in seconds. Only consulted when thinking is
enabled. Capped at 2400.
thinking_metric: str or None, default=None
Optimization metric for the fit. Only consulted when thinking
is enabled.

Binary classification:
"accuracy", "balanced_accuracy", "mcc", "log_loss",
"pac", "quadratic_kappa", "roc_auc", "average_precision",
"precision", "precision_macro", "precision_micro",
"precision_weighted", "recall", "recall_macro",
"recall_micro", "recall_weighted", "f1", "f1_macro",
"f1_micro", "f1_weighted".
Multiclass classification:
"accuracy", "balanced_accuracy", "mcc", "log_loss",
"pac", "quadratic_kappa", "precision_macro",
"precision_micro", "precision_weighted", "recall_macro",
"recall_micro", "recall_weighted", "f1_macro",
"f1_micro", "f1_weighted", "roc_auc_ovo",
"roc_auc_ovo_macro", "roc_auc_ovr", "roc_auc_ovr_macro",
"roc_auc_ovr_micro", "roc_auc_ovr_weighted".

Aliases "acc", "nll", "pac_score" are also accepted.
"""
self.model_path = model_path
self.n_estimators = n_estimators
Expand All @@ -269,9 +288,10 @@ def __init__(
self.random_state = random_state
self.inference_config = inference_config
self.paper_version = paper_version
self.enhanced_fit_mode = enhanced_fit_mode
self.enhanced_fit_mode_metric = enhanced_fit_mode_metric
self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s
self.thinking_mode = thinking_mode
self.thinking_effort = thinking_effort
self.thinking_timeout_s = thinking_timeout_s
self.thinking_metric = thinking_metric
self.force_refit = force_refit
self.client_options = client_options or ClientOptions()

Expand All @@ -294,7 +314,12 @@ def fit(

estimator_param = self._get_estimator_params_with_model_path("classification")
validate_train_set(X, y)
validate_enhanced_fit_mode_time_limit(self.enhanced_fit_mode_time_limit_s)
validate_thinking_mode(
self.thinking_mode,
self.thinking_effort,
self.thinking_timeout_s,
self.thinking_metric,
)
X = _clean_text_features(X)
self._validate_targets_and_classes(y)

Expand Down Expand Up @@ -467,9 +492,10 @@ def __init__(
] = 0,
inference_config: Optional[Dict] = None,
paper_version: bool = False,
enhanced_fit_mode: bool = False,
enhanced_fit_mode_metric: Optional[str] = None,
enhanced_fit_mode_time_limit_s: Optional[float] = None,
thinking_mode: bool = False,
thinking_effort: Optional[ThinkingEffort] = None,
thinking_timeout_s: Optional[float] = None,
thinking_metric: Optional[str] = None,
force_refit: bool = False,
client_options: ClientOptions | None = None,
):
Expand Down Expand Up @@ -519,21 +545,31 @@ def __init__(
paper_version: bool, default=False
If True, will use the model described in the paper, instead of the newest
version available on the API, which e.g handles text features better.
enhanced_fit_mode: bool, default=False
If True, trades off fit time for precision by running an
automated feature-engineering pipeline on top of TabPFN during
fit.
enhanced_fit_mode_metric: str or None, default=None
Only consulted when `enhanced_fit_mode=True`. Drives model
selection + ensemble weighting during the enhanced-fit sweep
(e.g. "rmse"/"mae"/"r2"/"mape" for regression). None falls
back to the sweep's default for the problem type.
enhanced_fit_mode_time_limit_s: float or None, default=None
Only consulted when `enhanced_fit_mode=True`. Ceiling on the
enhanced-fit sweep (seconds). Raise for larger datasets where
the default ~5-minute sweep leaves performance on the table.
None falls back to the server-side default (300s). Capped at
2400 seconds (40 minutes); higher values raise ValueError at fit.
thinking_mode: bool, default=False
If True, spend extra fit-time compute for higher precision.
Equivalent to passing `thinking_effort="medium"` — setting any
`thinking_effort` value also enables thinking, so this flag is
optional when you've set the level explicitly.
thinking_effort: {"medium", "high"} or None, default=None
Effort level for thinking mode. When set, thinking is enabled
(you don't also need `thinking_mode=True`). When None and
`thinking_mode=True`, defaults to "medium".
thinking_timeout_s: float or None, default=None
Budget for the fit, in seconds. Only consulted when thinking is
enabled. Capped at 2400.
thinking_metric: str or None, default=None
Optimization metric for the fit. Only consulted when thinking
is enabled.

Regression:
"r2", "mean_squared_error", "root_mean_squared_error",
"mean_absolute_error", "median_absolute_error",
"mean_absolute_percentage_error",
"symmetric_mean_absolute_percentage_error", "spearmanr",
"pearsonr".

Aliases "mse", "rmse", "mae", "mape", "smape" are also
accepted.
force_refit: bool, default=False
Whether to force refit the model even if the model has already been fitted.
client_options : ClientOptions, default=None
Expand All @@ -548,9 +584,10 @@ def __init__(
self.random_state = random_state
self.inference_config = inference_config
self.paper_version = paper_version
self.enhanced_fit_mode = enhanced_fit_mode
self.enhanced_fit_mode_metric = enhanced_fit_mode_metric
self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s
self.thinking_mode = thinking_mode
self.thinking_effort = thinking_effort
self.thinking_timeout_s = thinking_timeout_s
self.thinking_metric = thinking_metric
self.force_refit = force_refit
self.client_options = client_options or ClientOptions()

Expand All @@ -573,7 +610,12 @@ def fit(

estimator_param = self._get_estimator_params_with_model_path("regression")
validate_train_set(X, y)
validate_enhanced_fit_mode_time_limit(self.enhanced_fit_mode_time_limit_s)
validate_thinking_mode(
self.thinking_mode,
self.thinking_effort,
self.thinking_timeout_s,
self.thinking_metric,
)
self._validate_targets(y)
X = _clean_text_features(X)

Expand Down Expand Up @@ -718,14 +760,40 @@ def _validate_targets(self, y) -> np.ndarray:
raise ValueError("Input y contains NaN.")


def validate_enhanced_fit_mode_time_limit(time_limit_s: Optional[float]) -> None:
if time_limit_s is None:
return
if time_limit_s > ENHANCED_FIT_MODE_MAX_TIME_LIMIT_S:
def validate_thinking_mode(
    thinking_mode: bool,
    thinking_effort: Optional[str],
    thinking_timeout_s: Optional[float],
    thinking_metric: Optional[str],
) -> None:
    """Validate the client-side thinking-mode knobs before a fit request.

    Parameters
    ----------
    thinking_mode:
        Explicit on/off switch for thinking.
    thinking_effort:
        Effort level; any non-None value also enables thinking. Must be one
        of ``_VALID_THINKING_EFFORT_LEVELS`` when set.
    thinking_timeout_s:
        Fit budget in seconds; only meaningful when thinking is enabled and
        capped at ``THINKING_TIMEOUT_MAX_S``.
    thinking_metric:
        Optimization metric; only meaningful when thinking is enabled.

    Raises
    ------
    ValueError
        If ``thinking_effort`` is not a recognised level, if
        ``thinking_timeout_s``/``thinking_metric`` are passed while thinking
        is disabled, or if ``thinking_timeout_s`` exceeds the cap.
    """
    if (
        thinking_effort is not None
        and thinking_effort not in _VALID_THINKING_EFFORT_LEVELS
    ):
        raise ValueError(
            f"thinking_effort must be one of "
            f"{sorted(_VALID_THINKING_EFFORT_LEVELS)}, got {thinking_effort!r}."
        )
    # Setting `thinking_effort` is itself a way to enable thinking, so the
    # effective state is "either flag set". Knobs that only make sense when
    # thinking is on are rejected only when neither is set.
    thinking_enabled = thinking_mode or thinking_effort is not None
    if not thinking_enabled and (
        thinking_timeout_s is not None or thinking_metric is not None
    ):
        raise ValueError(
            "thinking_timeout_s and thinking_metric are only "
            "consulted when thinking is enabled; pass `thinking_mode=True` "
            "or `thinking_effort=...` to use them."
        )
    if thinking_timeout_s is not None and thinking_timeout_s > THINKING_TIMEOUT_MAX_S:
        # The old enhanced-fit error lines were left spliced into this raise
        # by the diff rendering; only the thinking_* message belongs here.
        raise ValueError(
            f"thinking_timeout_s ({thinking_timeout_s}) exceeds the "
            f"maximum allowed of {THINKING_TIMEOUT_MAX_S} seconds "
            f"({THINKING_TIMEOUT_MAX_S // 60} minutes)."
        )


Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,10 @@ def test_only_allowed_parameters_passed_to_config(self):
"model_path",
"balance_probabilities",
"paper_version",
"enhanced_fit_mode",
"enhanced_fit_mode_metric",
"enhanced_fit_mode_time_limit_s",
"thinking_mode",
"thinking_effort",
"thinking_timeout_s",
"thinking_metric",
}
OPTIONAL_PARAMS = {
# These may be emitted by newer model versions, but are not required.
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,10 @@ def test_only_allowed_parameters_passed_to_config(self):
"inference_config",
"model_path",
"paper_version",
"enhanced_fit_mode",
"enhanced_fit_mode_metric",
"enhanced_fit_mode_time_limit_s",
"thinking_mode",
"thinking_effort",
"thinking_timeout_s",
"thinking_metric",
}
OPTIONAL_PARAMS = {
"thinking",
Expand Down
Loading
Loading