From fa87ad699c1139824bc8e055b1d35ceca11c3246 Mon Sep 17 00:00:00 2001 From: simo-prior Date: Thu, 30 Apr 2026 11:02:59 +0200 Subject: [PATCH 1/7] http2 for predict --- src/tabpfn_client/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index a399aa7..26a9b94 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -216,11 +216,12 @@ class ServiceClient(Singleton): or f"{server_config.protocol}://{server_config.host}:{server_config.port}" ) fit_path = SERVER_CONFIG["endpoints"]["fit"]["path"] + predict_path = SERVER_CONFIG["endpoints"]["predict"]["path"] httpx_client = httpx.Client( base_url=base_url, timeout=TABPFN_CLIENT_TIMEOUT, headers={"Prior-Client-Version": get_client_version()}, - transport=SelectiveHTTP2Transport(http2_paths=[fit_path]), + transport=SelectiveHTTP2Transport(http2_paths=[fit_path, predict_path]), follow_redirects=True, ) _access_token = None From 24c7c89373ba2da7d9bdaa4b485d565831d7d736 Mon Sep 17 00:00:00 2001 From: simo-prior Date: Thu, 30 Apr 2026 11:15:39 +0200 Subject: [PATCH 2/7] ClientOptions --- src/tabpfn_client/estimator.py | 74 ++++++++++++++-------------------- 1 file changed, 30 insertions(+), 44 deletions(-) diff --git a/src/tabpfn_client/estimator.py b/src/tabpfn_client/estimator.py index a5afc9e..d294dd6 100644 --- a/src/tabpfn_client/estimator.py +++ b/src/tabpfn_client/estimator.py @@ -169,6 +169,8 @@ def __init__( enhanced_fit_mode: bool = False, enhanced_fit_mode_metric: Optional[str] = None, enhanced_fit_mode_time_limit_s: Optional[float] = None, + force_refit: bool = False, + client_options: ClientOptions | None = None, ): """Construct a TabPFN classifier. @@ -250,6 +252,9 @@ def __init__( self.enhanced_fit_mode = enhanced_fit_mode self.enhanced_fit_mode_metric = enhanced_fit_mode_metric self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s + self.force_refit = force_refit + self.client_options = client_options or ClientOptions() + self.last_trace_id = None self.last_fitted_train_set_id = None self.last_train_X = None @@ -262,7 +267,6 @@ def fit( X: pd.DataFrame | np.ndarray, y: pd.Series | np.ndarray, description: str | None = None, - client_options: ClientOptions | None = None, ): # assert init() is called init() @@ -274,11 +278,10 @@ def fit( self._validate_targets_and_classes(y) if Config.use_server: - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = uuid4().hex + if "sentry-trace" not in self.client_options.headers: + self.client_options.headers["sentry-trace"] = uuid4().hex - self.last_trace_id = client_options.headers["sentry-trace"] + self.last_trace_id = self.client_options.headers["sentry-trace"] self.last_train_set_description = description def fit_task() -> UUID: @@ -288,7 +291,7 @@ def fit_task() -> UUID: task="classification", tabpfn_config=estimator_param, description=description, - client_options=client_options, + client_options=self.client_options, ) self.last_fitted_train_set_id = run_task(fit_task, "Fitting") @@ -301,11 +304,7 @@ def fit_task() -> UUID: ) return self - def predict( - self, - X, - client_options: ClientOptions | None = None, - ): + def predict(self, X): """Predict class labels for samples in X. Args: @@ -314,17 +313,9 @@ def predict( Returns: The predicted class labels. """ - return self._predict( - X, - output_type="preds", - client_options=client_options, - ) + return self._predict(X, output_type="preds") - def predict_proba( - self, - X, - client_options: ClientOptions | None = None, - ): + def predict_proba(self, X): """Predict class probabilities for X. Args: @@ -333,26 +324,20 @@ def predict_proba( Returns: The class probabilities of the input samples. """ - return self._predict( - X, - output_type="probas", - client_options=client_options, - ) + return self._predict(X, output_type="probas") def _predict( self, X, output_type, - client_options: ClientOptions | None = None, ) -> dict[str, np.ndarray]: check_is_fitted(self) estimator_param = self._get_estimator_params_with_model_path("classification") validate_test_set(X, output_type, estimator_param["model_path"]) X = _clean_text_features(X) - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = self.last_trace_id + if "sentry-trace" not in self.client_options.headers: + self.client_options.headers["sentry-trace"] = self.last_trace_id def predict_task() -> PredictionResult: last_exc = None @@ -369,7 +354,7 @@ def predict_task() -> PredictionResult: task="classification", tabpfn_config=estimator_param, predict_params={"output_type": output_type}, - client_options=client_options, + client_options=self.client_options, ) except NeedsRefittingError as exc: last_exc = exc @@ -380,7 +365,7 @@ def predict_task() -> PredictionResult: task="classification", tabpfn_config=estimator_param, description=self.last_train_set_description, - client_options=client_options, + client_options=self.client_options, is_refitting=True, ) @@ -451,6 +436,8 @@ def __init__( enhanced_fit_mode: bool = False, enhanced_fit_mode_metric: Optional[str] = None, enhanced_fit_mode_time_limit_s: Optional[float] = None, + force_refit: bool = False, + client_options: ClientOptions | None = None, ): """Construct a TabPFN regressor. @@ -522,6 +509,9 @@ def __init__( self.enhanced_fit_mode = enhanced_fit_mode self.enhanced_fit_mode_metric = enhanced_fit_mode_metric self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s + self.force_refit = force_refit + self.client_options = client_options or ClientOptions() + self.last_trace_id = None self.last_fitted_train_set_id = None self.last_train_X = None @@ -534,7 +524,6 @@ def fit( X: pd.DataFrame | np.ndarray, y: pd.Series | np.ndarray, description: str | None = None, - client_options: ClientOptions | None = None, ): # assert init() is called init() @@ -546,11 +535,10 @@ def fit( X = _clean_text_features(X) if Config.use_server: - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = uuid4().hex + if "sentry-trace" not in self.client_options.headers: + self.client_options.headers["sentry-trace"] = uuid4().hex - self.last_trace_id = client_options.headers["sentry-trace"] + self.last_trace_id = self.client_options.headers["sentry-trace"] self.last_train_set_description = description def fit_task() -> UUID: @@ -560,7 +548,7 @@ def fit_task() -> UUID: task="regression", tabpfn_config=estimator_param, description=description, - client_options=client_options, + client_options=self.client_options, ) self.last_fitted_train_set_id = run_task(fit_task, "Fitting") @@ -581,7 +569,6 @@ def predict( "mean", "median", "mode", "quantiles", "full", "main" ] = "mean", quantiles: Optional[list[float]] = None, - client_options: ClientOptions | None = None, ) -> Union[np.ndarray, list[np.ndarray], dict[str, np.ndarray]]: """Predict regression target for X. @@ -617,9 +604,8 @@ def predict( "quantiles": quantiles, } - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = self.last_trace_id + if "sentry-trace" not in self.client_options.headers: + self.client_options.headers["sentry-trace"] = self.last_trace_id def predict_task() -> PredictionResult: last_exc = None @@ -636,7 +622,7 @@ def predict_task() -> PredictionResult: task="regression", tabpfn_config=estimator_param, predict_params=predict_params, - client_options=client_options, + client_options=self.client_options, ) except NeedsRefittingError as exc: last_exc = exc @@ -647,7 +633,7 @@ def predict_task() -> PredictionResult: task="regression", tabpfn_config=estimator_param, description=self.last_train_set_description, - client_options=client_options, + client_options=self.client_options, is_refitting=True, ) From e9f5f7f0d0e3529042e7f8a9e441ef2fb47df96e Mon Sep 17 00:00:00 2001 From: simo-prior Date: Thu, 30 Apr 2026 11:22:18 +0200 Subject: [PATCH 3/7] force_refit --- src/tabpfn_client/client.py | 6 ++++-- src/tabpfn_client/estimator.py | 12 ++++++++++-- src/tabpfn_client/service_wrapper.py | 4 ++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index 26a9b94..fbedacb 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -274,8 +274,8 @@ def fit( task: Literal["classification", "regression"], tabpfn_config: Union[dict, None] = None, description: str | None = None, + force_refit: bool = False, client_options: ClientOptions | None = None, - is_refitting: bool = False, ) -> UUID: """ Upload a train set to server and return the train set UID if successful. @@ -293,6 +293,8 @@ def fit( `paper_version`. description: str, optional Description of the dataset and task for the server. + force_refit: bool, optional + Whether to force refit the model even if the model has already been fitted. client_options : ClientOptions, optional Per-request options (e.g. timeout, headers) for the fitting API call only. Does not apply to file uploads. Because uploads can run before fitting, @@ -432,7 +434,7 @@ def fit( train_set_upload_id=prepare_resp.train_set_upload_id, task=task, tabpfn_systems=tabpfn_systems, - force_refit=is_refitting or force_refit_enabled(), + force_refit=force_refit or force_refit_enabled(), tabpfn_config=server_tabpfn_config, enhanced_fit_mode_metric=enhanced_fit_mode_metric, enhanced_fit_time_limit_s=enhanced_fit_time_limit_s, diff --git a/src/tabpfn_client/estimator.py b/src/tabpfn_client/estimator.py index d294dd6..4764ba6 100644 --- a/src/tabpfn_client/estimator.py +++ b/src/tabpfn_client/estimator.py @@ -291,6 +291,7 @@ def fit_task() -> UUID: task="classification", tabpfn_config=estimator_param, description=description, + force_refit=self.force_refit, client_options=self.client_options, ) @@ -365,8 +366,8 @@ def predict_task() -> PredictionResult: task="classification", tabpfn_config=estimator_param, description=self.last_train_set_description, + force_refit=True, client_options=self.client_options, - is_refitting=True, ) result = run_task(predict_task, "Predicting") @@ -496,6 +497,12 @@ def __init__( the default ~5-minute sweep leaves performance on the table. None falls back to the server-side default (300s). Capped at 2400 seconds (40 minutes); higher values raise ValueError at fit. + force_refit: bool, default=False + Whether to force refit the model even if the model has already been fitted. + client_options : ClientOptions, default=None + Per-request options (e.g. timeout, headers) for the fitting API call + only. Does not apply to file uploads. Because uploads can run before fitting, + this method may return later than the timeout specified. """ self.model_path = model_path self.n_estimators = n_estimators @@ -548,6 +555,7 @@ def fit_task() -> UUID: task="regression", tabpfn_config=estimator_param, description=description, + force_refit=self.force_refit, client_options=self.client_options, ) @@ -633,8 +641,8 @@ def predict_task() -> PredictionResult: task="regression", tabpfn_config=estimator_param, description=self.last_train_set_description, + force_refit=True, client_options=self.client_options, - is_refitting=True, ) result = run_task(predict_task, "Predicting") diff --git a/src/tabpfn_client/service_wrapper.py b/src/tabpfn_client/service_wrapper.py index 53f9374..0f4afc5 100644 --- a/src/tabpfn_client/service_wrapper.py +++ b/src/tabpfn_client/service_wrapper.py @@ -261,8 +261,8 @@ def fit( task: Literal["classification", "regression"], tabpfn_config=None, description: str | None = None, + force_refit: bool = False, client_options: ClientOptions | None = None, - is_refitting: bool = False, ) -> UUID: return ServiceClient.fit( X, @@ -270,8 +270,8 @@ def fit( task=task, tabpfn_config=tabpfn_config, description=description, + force_refit=force_refit, client_options=client_options, - is_refitting=is_refitting, ) @classmethod From a5a10ffe94df7f907a8bc92feda3ed76c7d62b95 Mon Sep 17 00:00:00 2001 From: simo-prior Date: Thu, 30 Apr 2026 11:36:50 +0200 Subject: [PATCH 4/7] sentry-trace --- src/tabpfn_client/estimator.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/tabpfn_client/estimator.py b/src/tabpfn_client/estimator.py index 4764ba6..13da4f2 100644 --- a/src/tabpfn_client/estimator.py +++ b/src/tabpfn_client/estimator.py @@ -261,6 +261,7 @@ def __init__( self.last_train_y = None self.last_meta = {} self.last_train_set_description = None + self.fit_count = 0 def fit( self, @@ -278,7 +279,15 @@ def fit( self._validate_targets_and_classes(y) if Config.use_server: - if "sentry-trace" not in self.client_options.headers: + # Create a new sentry trace at every fit, provided that: + # - The user has not explicitly set a sentry-trace header. + # - In any case if we're going to refit. + # - In any case if we have already called .fit() on this instance. + if ( + self.force_refit + or self.fit_count > 0 + or "sentry-trace" not in self.client_options.headers + ): self.client_options.headers["sentry-trace"] = uuid4().hex self.last_trace_id = self.client_options.headers["sentry-trace"] @@ -299,6 +308,7 @@ def fit_task() -> UUID: self.last_train_X = X self.last_train_y = y self.fitted_ = True + self.fit_count += 1 else: raise NotImplementedError( "Only server mode is supported at the moment for init(use_server=False)" @@ -525,6 +535,7 @@ def __init__( self.last_train_y = None self.last_meta = {} self.last_train_set_description = None + self.fit_count = 0 def fit( self, @@ -542,10 +553,19 @@ def fit( X = _clean_text_features(X) if Config.use_server: - if "sentry-trace" not in self.client_options.headers: + # Create a new sentry trace at every fit, provided that: + # - The user has not explicitly set a sentry-trace header. + # - In any case if we're going to refit. + # - In any case if we have already called .fit() on this instance. + if ( + self.force_refit + or self.fit_count > 0 + or "sentry-trace" not in self.client_options.headers + ): self.client_options.headers["sentry-trace"] = uuid4().hex self.last_trace_id = self.client_options.headers["sentry-trace"] + self.last_train_set_description = description def fit_task() -> UUID: @@ -563,6 +583,7 @@ def fit_task() -> UUID: self.last_train_X = X self.last_train_y = y self.fitted_ = True + self.fit_count += 1 else: raise NotImplementedError( "Only server mode is supported at the moment for init(use_server=False)" From b8bed84e70d7c60f3017c5b1c598ec90e3199318 Mon Sep 17 00:00:00 2001 From: simo-prior Date: Thu, 30 Apr 2026 11:45:19 +0200 Subject: [PATCH 5/7] fix tests --- src/tabpfn_client/estimator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tabpfn_client/estimator.py b/src/tabpfn_client/estimator.py index 13da4f2..d836e5a 100644 --- a/src/tabpfn_client/estimator.py +++ b/src/tabpfn_client/estimator.py @@ -128,6 +128,9 @@ def _get_estimator_params_with_model_path( """ estimator_param = self.get_params() estimator_param["model_path"] = self._model_name_to_path(task, self.model_path) + # Client-side concerns — passed separately to InferenceClient, not part of server config. + estimator_param.pop("client_options", None) + estimator_param.pop("force_refit", None) return estimator_param From 055b40796682915412f5fb1c8026e55125509d5c Mon Sep 17 00:00:00 2001 From: simo-prior Date: Thu, 30 Apr 2026 12:04:13 +0200 Subject: [PATCH 6/7] adjust fit interface --- src/tabpfn_client/api_models.py | 38 +++++++++++++++++++++++---------- src/tabpfn_client/client.py | 26 +++++++++++++++++----- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/src/tabpfn_client/api_models.py b/src/tabpfn_client/api_models.py index b68f21b..5f99e16 100644 --- a/src/tabpfn_client/api_models.py +++ b/src/tabpfn_client/api_models.py @@ -94,23 +94,39 @@ class DuplicateTrainSetErrorResponse(ErrorResponse): # --------------------------------------------------------------------------- # /tabpfn/fit/ # --------------------------------------------------------------------------- -class FitRequest(BaseModel): - train_set_upload_id: UUID - task: str - tabpfn_systems: List[str] - force_refit: bool = False - # Estimator-side configuration (model_path, hyperparameters). Some - # `tabpfn_systems` values on the server need this at fit time; the - # server ignores it otherwise. - tabpfn_config: TabPFNConfig = None +class TabPFNSystem(BaseModel): + name: str + + +class TextSystem(TabPFNSystem): + name: str = "text" + + +class PreprocessingSystem(TabPFNSystem): + name: str = "preprocessing" + + +class EnhancedSystem(TabPFNSystem): + name: str = "enhanced" # Drives model selection + ensemble weighting during the enhanced-fit # sweep. Only consulted when `"enhanced"` is in `tabpfn_systems`. None # falls back to the sweep's default per problem type. - enhanced_fit_mode_metric: Optional[str] = None + metric: Optional[str] = None # Ceiling on the enhanced-fit sweep (seconds). Only consulted when # `"enhanced"` is in `tabpfn_systems`. None falls back to the server # default (300s). - enhanced_fit_time_limit_s: Optional[float] = None + time_limit_s: Optional[float] = None + # Estimator-side configuration (model_path, hyperparameters). Some + # `tabpfn_systems` values on the server need this at fit time; the + # server ignores it otherwise. + tabpfn_config: TabPFNConfig = None + + +class FitRequest(BaseModel): + train_set_upload_id: UUID + task: str + tabpfn_systems: List[TabPFNSystem] + force_refit: bool = False class FitResponse(BaseModel): diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index fbedacb..e98f780 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -18,7 +18,7 @@ import traceback import warnings from pydantic import BaseModel, ValidationError -from typing import Any, Callable, Dict, Literal, Union, cast +from typing import Any, Callable, Dict, Literal, Union, cast, List import google_crc32c @@ -56,6 +56,10 @@ PredictResponse, TaskConfig, ErrorResponse, + TabPFNSystem, + TextSystem, + PreprocessingSystem, + EnhancedSystem, ) logger = logging.getLogger(__name__) @@ -429,15 +433,27 @@ def fit( else None ) + systems = List[TabPFNSystem] = [] + for name in tabpfn_systems: + if name == "text": + systems.append(TextSystem()) + elif name == "preprocessing": + systems.append(PreprocessingSystem()) + elif name == "enhanced": + systems.append( + EnhancedSystem( + metric=enhanced_fit_mode_metric, + time_limit_s=enhanced_fit_time_limit_s, + tabpfn_config=server_tabpfn_config, + ) + ) + res = cls._fit( req=FitRequest( train_set_upload_id=prepare_resp.train_set_upload_id, task=task, - tabpfn_systems=tabpfn_systems, + tabpfn_systems=systems, force_refit=force_refit or force_refit_enabled(), - tabpfn_config=server_tabpfn_config, - enhanced_fit_mode_metric=enhanced_fit_mode_metric, - enhanced_fit_time_limit_s=enhanced_fit_time_limit_s, ), timeout=client_options.timeout, headers=client_options.headers, From e0b59669153b1e21bb0d5d5bff4688e0ad413a52 Mon Sep 17 00:00:00 2001 From: simo-prior Date: Thu, 30 Apr 2026 12:19:26 +0200 Subject: [PATCH 7/7] Update src/tabpfn_client/client.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/tabpfn_client/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index e98f780..c1c33bc 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -433,7 +433,7 @@ def fit( else None ) - systems = List[TabPFNSystem] = [] + systems: List[TabPFNSystem] = [] for name in tabpfn_systems: if name == "text": systems.append(TextSystem())