diff --git a/src/tabpfn_client/api_models.py b/src/tabpfn_client/api_models.py index b68f21b..cf1a5a8 100644 --- a/src/tabpfn_client/api_models.py +++ b/src/tabpfn_client/api_models.py @@ -1,5 +1,5 @@ from uuid import UUID -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field # Classification output_type="preds" preserves the original label type, so @@ -103,14 +103,12 @@ class FitRequest(BaseModel): # `tabpfn_systems` values on the server need this at fit time; the # server ignores it otherwise. tabpfn_config: TabPFNConfig = None - # Drives model selection + ensemble weighting during the enhanced-fit - # sweep. Only consulted when `"enhanced"` is in `tabpfn_systems`. None - # falls back to the sweep's default per problem type. - enhanced_fit_mode_metric: Optional[str] = None - # Ceiling on the enhanced-fit sweep (seconds). Only consulted when - # `"enhanced"` is in `tabpfn_systems`. None falls back to the server - # default (300s). - enhanced_fit_time_limit_s: Optional[float] = None + # User-facing thinking-effort level. None disables it. + thinking_effort: Optional[Literal["medium", "high"]] = None + # Budget for the fit (seconds). Only consulted when `thinking_effort` is set. + thinking_timeout_s: Optional[float] = None + # Optimization metric for the fit. Only consulted when `thinking_effort` is set. + thinking_metric: Optional[str] = None class FitResponse(BaseModel): diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index c5448cc..8ab3a4c 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -388,28 +388,33 @@ def fit( raise tabpfn_systems = ["preprocessing", "text"] + # Thinking is enabled when either flag is set: explicit `thinking_mode=True`, + # or any non-None `thinking_effort`. Setting `thinking_effort` alone is + # enough — the server-side validator on FitRequest also normalises this, + # but doing it here means the request body itself is consistent. + thinking_enabled = bool(tabpfn_config) and ( + bool(tabpfn_config.get("thinking_mode")) + or tabpfn_config.get("thinking_effort") is not None + ) if tabpfn_config: if tabpfn_config.get("paper_version") is True: tabpfn_systems = [] - elif tabpfn_config.get("enhanced_fit_mode") is True: - # Enhanced mode runs on top of the base systems rather than + elif thinking_enabled: + # Thinking runs on top of the base systems rather than # replacing them — keep preprocessing + text alongside it. - tabpfn_systems = ["preprocessing", "text", "enhanced"] - - # `enhanced_fit_mode_metric` and `enhanced_fit_mode_time_limit_s` - # are top-level FitRequest fields on the server (siblings to - # `tabpfn_systems`), not part of `tabpfn_config`. Lift them out - # before stripping the rest of the client-only keys. The server - # field drops the `mode_` infix (`enhanced_fit_time_limit_s`); - # units are seconds on both sides, no conversion. - enhanced_fit_mode_metric = ( - tabpfn_config.get("enhanced_fit_mode_metric") if tabpfn_config else None - ) - enhanced_fit_time_limit_s = ( - tabpfn_config.get("enhanced_fit_mode_time_limit_s") - if tabpfn_config - else None - ) + tabpfn_systems = ["preprocessing", "text", "thinking"] + + # The client-side `thinking_*` knobs forward 1:1 to the server's + # top-level FitRequest fields. When the user enabled thinking via + # `thinking_mode=True` without picking a level, default to "medium". + if thinking_enabled and tabpfn_config: + thinking_effort = tabpfn_config.get("thinking_effort") or "medium" + thinking_timeout_s = tabpfn_config.get("thinking_timeout_s") + thinking_metric = tabpfn_config.get("thinking_metric") + else: + thinking_effort = None + thinking_timeout_s = None + thinking_metric = None # Strip client-only keys that the server does not expect (mirrors # the predict path's filter below). @@ -420,9 +425,10 @@ def fit( if k not in { "paper_version", - "enhanced_fit_mode", - "enhanced_fit_mode_metric", - "enhanced_fit_mode_time_limit_s", + "thinking_mode", + "thinking_effort", + "thinking_timeout_s", + "thinking_metric", } } if tabpfn_config is not None @@ -436,8 +442,9 @@ def fit( tabpfn_systems=tabpfn_systems, force_refit=force_refit or force_refit_enabled(), tabpfn_config=server_tabpfn_config, - enhanced_fit_mode_metric=enhanced_fit_mode_metric, - enhanced_fit_time_limit_s=enhanced_fit_time_limit_s, + thinking_effort=thinking_effort, + thinking_timeout_s=thinking_timeout_s, + thinking_metric=thinking_metric, ), timeout=client_options.timeout, headers=client_options.headers, @@ -584,9 +591,10 @@ def predict( if k not in { "paper_version", - "enhanced_fit_mode", - "enhanced_fit_mode_metric", - "enhanced_fit_mode_time_limit_s", + "thinking_mode", + "thinking_effort", + "thinking_timeout_s", + "thinking_metric", } } diff --git a/src/tabpfn_client/estimator.py b/src/tabpfn_client/estimator.py index 27d5b2f..27e2015 100644 --- a/src/tabpfn_client/estimator.py +++ b/src/tabpfn_client/estimator.py @@ -54,7 +54,10 @@ # is kept as a backward-compatible alias. _AUTO_MODEL_PATH_ALIASES = frozenset({"auto", "default"}) -ENHANCED_FIT_MODE_MAX_TIME_LIMIT_S = 40 * 60 +THINKING_TIMEOUT_MAX_S = 40 * 60 + +ThinkingEffort = Literal["medium", "high"] +_VALID_THINKING_EFFORT_LEVELS = frozenset({"medium", "high"}) class TabPFNModelSelection: @@ -180,9 +183,10 @@ def __init__( ] = 0, inference_config: Optional[Dict] = None, paper_version: bool = False, - enhanced_fit_mode: bool = False, - enhanced_fit_mode_metric: Optional[str] = None, - enhanced_fit_mode_time_limit_s: Optional[float] = None, + thinking_mode: bool = False, + thinking_effort: Optional[ThinkingEffort] = None, + thinking_timeout_s: Optional[float] = None, + thinking_metric: Optional[str] = None, force_refit: bool = False, client_options: ClientOptions | None = None, ): @@ -240,24 +244,39 @@ def __init__( paper_version: bool, default=False If True, will use the model described in the paper, instead of the newest version available on the API, which e.g handles text features better. - enhanced_fit_mode: bool, default=False - If True, trades off fit time for precision by running an - automated feature-engineering pipeline on top of TabPFN during - fit. - enhanced_fit_mode_metric: str or None, default=None - Only consulted when `enhanced_fit_mode=True`. Drives model - selection + ensemble weighting during the enhanced-fit sweep - (e.g. "accuracy"/"log_loss"/"roc_auc"/"balanced_accuracy"/ - "f1" for classification). None falls back to the sweep's - default for the problem type. Distinct from the local - `eval_metric`/`tuning_config` knobs used for decision-threshold - tuning on the standalone TabPFN classifier. - enhanced_fit_mode_time_limit_s: float or None, default=None - Only consulted when `enhanced_fit_mode=True`. Ceiling on the - enhanced-fit sweep (seconds). Raise for larger datasets where - the default ~5-minute sweep leaves performance on the table. - None falls back to the server-side default (300s). Capped at - 2400 seconds (40 minutes); higher values raise ValueError at fit. + thinking_mode: bool, default=False + If True, spend extra fit-time compute for higher precision. + Equivalent to passing `thinking_effort="medium"` — setting any + `thinking_effort` value also enables thinking, so this flag is + optional when you've set the level explicitly. + thinking_effort: {"medium", "high"} or None, default=None + Effort level for thinking mode. When set, thinking is enabled + (you don't also need `thinking_mode=True`). When None and + `thinking_mode=True`, defaults to "medium". + thinking_timeout_s: float or None, default=None + Budget for the fit, in seconds. Only consulted when thinking is + enabled. Capped at 2400. + thinking_metric: str or None, default=None + Optimization metric for the fit. Only consulted when thinking + is enabled. + + Binary classification: + "accuracy", "balanced_accuracy", "mcc", "log_loss", + "pac", "quadratic_kappa", "roc_auc", "average_precision", + "precision", "precision_macro", "precision_micro", + "precision_weighted", "recall", "recall_macro", + "recall_micro", "recall_weighted", "f1", "f1_macro", + "f1_micro", "f1_weighted". + Multiclass classification: + "accuracy", "balanced_accuracy", "mcc", "log_loss", + "pac", "quadratic_kappa", "precision_macro", + "precision_micro", "precision_weighted", "recall_macro", + "recall_micro", "recall_weighted", "f1_macro", + "f1_micro", "f1_weighted", "roc_auc_ovo", + "roc_auc_ovo_macro", "roc_auc_ovr", "roc_auc_ovr_macro", + "roc_auc_ovr_micro", "roc_auc_ovr_weighted". + + Aliases "acc", "nll", "pac_score" are also accepted. """ self.model_path = model_path self.n_estimators = n_estimators @@ -269,9 +288,10 @@ def __init__( self.random_state = random_state self.inference_config = inference_config self.paper_version = paper_version - self.enhanced_fit_mode = enhanced_fit_mode - self.enhanced_fit_mode_metric = enhanced_fit_mode_metric - self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s + self.thinking_mode = thinking_mode + self.thinking_effort = thinking_effort + self.thinking_timeout_s = thinking_timeout_s + self.thinking_metric = thinking_metric self.force_refit = force_refit self.client_options = client_options or ClientOptions() @@ -294,7 +314,12 @@ def fit( estimator_param = self._get_estimator_params_with_model_path("classification") validate_train_set(X, y) - validate_enhanced_fit_mode_time_limit(self.enhanced_fit_mode_time_limit_s) + validate_thinking_mode( + self.thinking_mode, + self.thinking_effort, + self.thinking_timeout_s, + self.thinking_metric, + ) X = _clean_text_features(X) self._validate_targets_and_classes(y) @@ -467,9 +492,10 @@ def __init__( ] = 0, inference_config: Optional[Dict] = None, paper_version: bool = False, - enhanced_fit_mode: bool = False, - enhanced_fit_mode_metric: Optional[str] = None, - enhanced_fit_mode_time_limit_s: Optional[float] = None, + thinking_mode: bool = False, + thinking_effort: Optional[ThinkingEffort] = None, + thinking_timeout_s: Optional[float] = None, + thinking_metric: Optional[str] = None, force_refit: bool = False, client_options: ClientOptions | None = None, ): @@ -519,21 +545,31 @@ def __init__( paper_version: bool, default=False If True, will use the model described in the paper, instead of the newest version available on the API, which e.g handles text features better. - enhanced_fit_mode: bool, default=False - If True, trades off fit time for precision by running an - automated feature-engineering pipeline on top of TabPFN during - fit. - enhanced_fit_mode_metric: str or None, default=None - Only consulted when `enhanced_fit_mode=True`. Drives model - selection + ensemble weighting during the enhanced-fit sweep - (e.g. "rmse"/"mae"/"r2"/"mape" for regression). None falls - back to the sweep's default for the problem type. - enhanced_fit_mode_time_limit_s: float or None, default=None - Only consulted when `enhanced_fit_mode=True`. Ceiling on the - enhanced-fit sweep (seconds). Raise for larger datasets where - the default ~5-minute sweep leaves performance on the table. - None falls back to the server-side default (300s). Capped at - 2400 seconds (40 minutes); higher values raise ValueError at fit. + thinking_mode: bool, default=False + If True, spend extra fit-time compute for higher precision. + Equivalent to passing `thinking_effort="medium"` — setting any + `thinking_effort` value also enables thinking, so this flag is + optional when you've set the level explicitly. + thinking_effort: {"medium", "high"} or None, default=None + Effort level for thinking mode. When set, thinking is enabled + (you don't also need `thinking_mode=True`). When None and + `thinking_mode=True`, defaults to "medium". + thinking_timeout_s: float or None, default=None + Budget for the fit, in seconds. Only consulted when thinking is + enabled. Capped at 2400. + thinking_metric: str or None, default=None + Optimization metric for the fit. Only consulted when thinking + is enabled. + + Regression: + "r2", "mean_squared_error", "root_mean_squared_error", + "mean_absolute_error", "median_absolute_error", + "mean_absolute_percentage_error", + "symmetric_mean_absolute_percentage_error", "spearmanr", + "pearsonr". + + Aliases "mse", "rmse", "mae", "mape", "smape" are also + accepted. force_refit: bool, default=False Whether to force refit the model even if the model has already been fitted. client_options : ClientOptions, default=None @@ -548,9 +584,10 @@ def __init__( self.random_state = random_state self.inference_config = inference_config self.paper_version = paper_version - self.enhanced_fit_mode = enhanced_fit_mode - self.enhanced_fit_mode_metric = enhanced_fit_mode_metric - self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s + self.thinking_mode = thinking_mode + self.thinking_effort = thinking_effort + self.thinking_timeout_s = thinking_timeout_s + self.thinking_metric = thinking_metric self.force_refit = force_refit self.client_options = client_options or ClientOptions() @@ -573,7 +610,12 @@ def fit( estimator_param = self._get_estimator_params_with_model_path("regression") validate_train_set(X, y) - validate_enhanced_fit_mode_time_limit(self.enhanced_fit_mode_time_limit_s) + validate_thinking_mode( + self.thinking_mode, + self.thinking_effort, + self.thinking_timeout_s, + self.thinking_metric, + ) self._validate_targets(y) X = _clean_text_features(X) @@ -718,14 +760,40 @@ def _validate_targets(self, y) -> np.ndarray: raise ValueError("Input y contains NaN.") -def validate_enhanced_fit_mode_time_limit(time_limit_s: Optional[float]) -> None: - if time_limit_s is None: - return - if time_limit_s > ENHANCED_FIT_MODE_MAX_TIME_LIMIT_S: +def validate_thinking_mode( + thinking_mode: bool, + thinking_effort: Optional[str], + thinking_timeout_s: Optional[float], + thinking_metric: Optional[str], +) -> None: + if ( + thinking_effort is not None + and thinking_effort not in _VALID_THINKING_EFFORT_LEVELS + ): + raise ValueError( + f"thinking_effort must be one of " + f"{sorted(_VALID_THINKING_EFFORT_LEVELS)}, got {thinking_effort!r}." + ) + # Setting `thinking_effort` is itself a way to enable thinking, so the + # effective state is "either flag set". Knobs that only make sense when + # thinking is on are rejected only when neither is set. + thinking_enabled = thinking_mode or thinking_effort is not None + if not thinking_enabled and ( + thinking_timeout_s is not None or thinking_metric is not None + ): + raise ValueError( + "thinking_timeout_s and thinking_metric are only " + "consulted when thinking is enabled; pass `thinking_mode=True` " + "or `thinking_effort=...` to use them." + ) + if ( + thinking_timeout_s is not None + and thinking_timeout_s > THINKING_TIMEOUT_MAX_S + ): raise ValueError( - f"enhanced_fit_mode_time_limit_s ({time_limit_s}) exceeds the " - f"maximum allowed of {ENHANCED_FIT_MODE_MAX_TIME_LIMIT_S} seconds " - f"({ENHANCED_FIT_MODE_MAX_TIME_LIMIT_S // 60} minutes)." + f"thinking_timeout_s ({thinking_timeout_s}) exceeds the " + f"maximum allowed of {THINKING_TIMEOUT_MAX_S} seconds " + f"({THINKING_TIMEOUT_MAX_S // 60} minutes)." ) diff --git a/tests/unit/test_tabpfn_classifier.py b/tests/unit/test_tabpfn_classifier.py index 1bc9b10..fd9f361 100644 --- a/tests/unit/test_tabpfn_classifier.py +++ b/tests/unit/test_tabpfn_classifier.py @@ -475,9 +475,10 @@ def test_only_allowed_parameters_passed_to_config(self): "model_path", "balance_probabilities", "paper_version", - "enhanced_fit_mode", - "enhanced_fit_mode_metric", - "enhanced_fit_mode_time_limit_s", + "thinking_mode", + "thinking_effort", + "thinking_timeout_s", + "thinking_metric", } OPTIONAL_PARAMS = { # These may be emitted by newer model versions, but are not required. diff --git a/tests/unit/test_tabpfn_regressor.py b/tests/unit/test_tabpfn_regressor.py index 912f9f2..da2bb0c 100644 --- a/tests/unit/test_tabpfn_regressor.py +++ b/tests/unit/test_tabpfn_regressor.py @@ -469,9 +469,10 @@ def test_only_allowed_parameters_passed_to_config(self): "inference_config", "model_path", "paper_version", - "enhanced_fit_mode", - "enhanced_fit_mode_metric", - "enhanced_fit_mode_time_limit_s", + "thinking_mode", + "thinking_effort", + "thinking_timeout_s", + "thinking_metric", } OPTIONAL_PARAMS = { "thinking", diff --git a/tests/unit/test_thinking_validation.py b/tests/unit/test_thinking_validation.py new file mode 100644 index 0000000..27e7604 --- /dev/null +++ b/tests/unit/test_thinking_validation.py @@ -0,0 +1,62 @@ +"""Validator contract for the thinking_* knobs on TabPFNClassifier/Regressor. + +Pins the rule that thinking is enabled when either `thinking_mode=True` OR +`thinking_effort` is set, so callers can pass either or both without surprise. +""" + +import pytest + +from tabpfn_client.estimator import ( + THINKING_TIMEOUT_MAX_S, + validate_thinking_mode, +) + + +def _v(**overrides): + args = dict( + thinking_mode=False, + thinking_effort=None, + thinking_timeout_s=None, + thinking_metric=None, + ) + args.update(overrides) + return validate_thinking_mode(**args) + + +class TestThinkingValidator: + def test_neither_flag_is_off(self): + # No flags: thinking off, no errors. + _v() + + def test_thinking_mode_alone_is_on(self): + # Just `thinking_mode=True` is enough; downstream defaults effort to "medium". + _v(thinking_mode=True) + + def test_thinking_effort_alone_implies_on(self): + # The whole point of this contract: setting thinking_effort enables + # thinking even without thinking_mode=True. + _v(thinking_effort="medium") + _v(thinking_effort="high") + + def test_extra_knobs_with_thinking_effort_set_are_allowed(self): + # If thinking is on (via either flag), the budget/metric knobs apply. + _v(thinking_effort="high", thinking_timeout_s=60.0, thinking_metric="rmse") + _v(thinking_mode=True, thinking_timeout_s=60.0, thinking_metric="rmse") + + def test_extra_knobs_without_thinking_are_rejected(self): + # Knobs that only matter when thinking is on must error if neither flag is set. + with pytest.raises(ValueError, match="thinking is enabled"): + _v(thinking_timeout_s=60.0) + with pytest.raises(ValueError, match="thinking is enabled"): + _v(thinking_metric="rmse") + + def test_invalid_effort_level_rejected(self): + with pytest.raises(ValueError, match="thinking_effort must be one of"): + _v(thinking_effort="extreme") + + def test_timeout_above_cap_rejected(self): + with pytest.raises(ValueError, match="exceeds the"): + _v(thinking_effort="high", thinking_timeout_s=THINKING_TIMEOUT_MAX_S + 1) + + def test_timeout_at_cap_allowed(self): + _v(thinking_effort="high", thinking_timeout_s=THINKING_TIMEOUT_MAX_S)