diff --git a/src/tabpfn_client/api_models.py b/src/tabpfn_client/api_models.py index b68f21b..5f99e16 100644 --- a/src/tabpfn_client/api_models.py +++ b/src/tabpfn_client/api_models.py @@ -94,23 +94,39 @@ class DuplicateTrainSetErrorResponse(ErrorResponse): # --------------------------------------------------------------------------- # /tabpfn/fit/ # --------------------------------------------------------------------------- -class FitRequest(BaseModel): - train_set_upload_id: UUID - task: str - tabpfn_systems: List[str] - force_refit: bool = False - # Estimator-side configuration (model_path, hyperparameters). Some - # `tabpfn_systems` values on the server need this at fit time; the - # server ignores it otherwise. - tabpfn_config: TabPFNConfig = None +class TabPFNSystem(BaseModel): + name: str + + +class TextSystem(TabPFNSystem): + name: str = "text" + + +class PreprocessingSystem(TabPFNSystem): + name: str = "preprocessing" + + +class EnhancedSystem(TabPFNSystem): + name: str = "enhanced" # Drives model selection + ensemble weighting during the enhanced-fit # sweep. Only consulted when `"enhanced"` is in `tabpfn_systems`. None # falls back to the sweep's default per problem type. - enhanced_fit_mode_metric: Optional[str] = None + metric: Optional[str] = None # Ceiling on the enhanced-fit sweep (seconds). Only consulted when # `"enhanced"` is in `tabpfn_systems`. None falls back to the server # default (300s). - enhanced_fit_time_limit_s: Optional[float] = None + time_limit_s: Optional[float] = None + # Estimator-side configuration (model_path, hyperparameters). Some + # `tabpfn_systems` values on the server need this at fit time; the + # server ignores it otherwise. 
+ tabpfn_config: TabPFNConfig = None + + +class FitRequest(BaseModel): + train_set_upload_id: UUID + task: str + tabpfn_systems: List[TabPFNSystem] + force_refit: bool = False class FitResponse(BaseModel): diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index a399aa7..c1c33bc 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -18,7 +18,7 @@ import traceback import warnings from pydantic import BaseModel, ValidationError -from typing import Any, Callable, Dict, Literal, Union, cast +from typing import Any, Callable, Dict, Literal, Union, cast, List import google_crc32c @@ -56,6 +56,10 @@ PredictResponse, TaskConfig, ErrorResponse, + TabPFNSystem, + TextSystem, + PreprocessingSystem, + EnhancedSystem, ) logger = logging.getLogger(__name__) @@ -216,11 +220,12 @@ class ServiceClient(Singleton): or f"{server_config.protocol}://{server_config.host}:{server_config.port}" ) fit_path = SERVER_CONFIG["endpoints"]["fit"]["path"] + predict_path = SERVER_CONFIG["endpoints"]["predict"]["path"] httpx_client = httpx.Client( base_url=base_url, timeout=TABPFN_CLIENT_TIMEOUT, headers={"Prior-Client-Version": get_client_version()}, - transport=SelectiveHTTP2Transport(http2_paths=[fit_path]), + transport=SelectiveHTTP2Transport(http2_paths=[fit_path, predict_path]), follow_redirects=True, ) _access_token = None @@ -273,8 +278,8 @@ def fit( task: Literal["classification", "regression"], tabpfn_config: Union[dict, None] = None, description: str | None = None, + force_refit: bool = False, client_options: ClientOptions | None = None, - is_refitting: bool = False, ) -> UUID: """ Upload a train set to server and return the train set UID if successful. @@ -292,6 +297,8 @@ def fit( `paper_version`. description: str, optional Description of the dataset and task for the server. + force_refit: bool, optional + Whether to force refit the model even if the model has already been fitted. 
client_options : ClientOptions, optional Per-request options (e.g. timeout, headers) for the fitting API call only. Does not apply to file uploads. Because uploads can run before fitting, @@ -426,15 +433,27 @@ def fit( else None ) + systems: List[TabPFNSystem] = [] + for name in tabpfn_systems: + if name == "text": + systems.append(TextSystem()) + elif name == "preprocessing": + systems.append(PreprocessingSystem()) + elif name == "enhanced": + systems.append( + EnhancedSystem( + metric=enhanced_fit_mode_metric, + time_limit_s=enhanced_fit_time_limit_s, + tabpfn_config=server_tabpfn_config, + ) + ) + res = cls._fit( req=FitRequest( train_set_upload_id=prepare_resp.train_set_upload_id, task=task, - tabpfn_systems=tabpfn_systems, - force_refit=is_refitting or force_refit_enabled(), - tabpfn_config=server_tabpfn_config, - enhanced_fit_mode_metric=enhanced_fit_mode_metric, - enhanced_fit_time_limit_s=enhanced_fit_time_limit_s, + tabpfn_systems=systems, + force_refit=force_refit or force_refit_enabled(), ), timeout=client_options.timeout, headers=client_options.headers, diff --git a/src/tabpfn_client/estimator.py b/src/tabpfn_client/estimator.py index a5afc9e..d836e5a 100644 --- a/src/tabpfn_client/estimator.py +++ b/src/tabpfn_client/estimator.py @@ -128,6 +128,9 @@ def _get_estimator_params_with_model_path( """ estimator_param = self.get_params() estimator_param["model_path"] = self._model_name_to_path(task, self.model_path) + # Client-side concerns — passed separately to InferenceClient, not part of server config. + estimator_param.pop("client_options", None) + estimator_param.pop("force_refit", None) return estimator_param @@ -169,6 +172,8 @@ def __init__( enhanced_fit_mode: bool = False, enhanced_fit_mode_metric: Optional[str] = None, enhanced_fit_mode_time_limit_s: Optional[float] = None, + force_refit: bool = False, + client_options: ClientOptions | None = None, ): """Construct a TabPFN classifier. 
@@ -250,19 +255,22 @@ def __init__( self.enhanced_fit_mode = enhanced_fit_mode self.enhanced_fit_mode_metric = enhanced_fit_mode_metric self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s + self.force_refit = force_refit + self.client_options = client_options or ClientOptions() + self.last_trace_id = None self.last_fitted_train_set_id = None self.last_train_X = None self.last_train_y = None self.last_meta = {} self.last_train_set_description = None + self.fit_count = 0 def fit( self, X: pd.DataFrame | np.ndarray, y: pd.Series | np.ndarray, description: str | None = None, - client_options: ClientOptions | None = None, ): # assert init() is called init() @@ -274,11 +282,18 @@ def fit( self._validate_targets_and_classes(y) if Config.use_server: - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = uuid4().hex - - self.last_trace_id = client_options.headers["sentry-trace"] + # Generate a fresh sentry-trace header for this fit when any of these holds: + # - no sentry-trace header is set on self.client_options yet; + # - force_refit is enabled; + # - .fit() has already been called on this instance (fit_count > 0). 
+ if ( + self.force_refit + or self.fit_count > 0 + or "sentry-trace" not in self.client_options.headers + ): + self.client_options.headers["sentry-trace"] = uuid4().hex + + self.last_trace_id = self.client_options.headers["sentry-trace"] self.last_train_set_description = description def fit_task() -> UUID: @@ -288,24 +303,22 @@ def fit_task() -> UUID: task="classification", tabpfn_config=estimator_param, description=description, - client_options=client_options, + force_refit=self.force_refit, + client_options=self.client_options, ) self.last_fitted_train_set_id = run_task(fit_task, "Fitting") self.last_train_X = X self.last_train_y = y self.fitted_ = True + self.fit_count += 1 else: raise NotImplementedError( "Only server mode is supported at the moment for init(use_server=False)" ) return self - def predict( - self, - X, - client_options: ClientOptions | None = None, - ): + def predict(self, X): """Predict class labels for samples in X. Args: @@ -314,17 +327,9 @@ def predict( Returns: The predicted class labels. """ - return self._predict( - X, - output_type="preds", - client_options=client_options, - ) + return self._predict(X, output_type="preds") - def predict_proba( - self, - X, - client_options: ClientOptions | None = None, - ): + def predict_proba(self, X): """Predict class probabilities for X. Args: @@ -333,26 +338,20 @@ def predict_proba( Returns: The class probabilities of the input samples. 
""" - return self._predict( - X, - output_type="probas", - client_options=client_options, - ) + return self._predict(X, output_type="probas") def _predict( self, X, output_type, - client_options: ClientOptions | None = None, ) -> dict[str, np.ndarray]: check_is_fitted(self) estimator_param = self._get_estimator_params_with_model_path("classification") validate_test_set(X, output_type, estimator_param["model_path"]) X = _clean_text_features(X) - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = self.last_trace_id + if "sentry-trace" not in self.client_options.headers: + self.client_options.headers["sentry-trace"] = self.last_trace_id def predict_task() -> PredictionResult: last_exc = None @@ -369,7 +368,7 @@ def predict_task() -> PredictionResult: task="classification", tabpfn_config=estimator_param, predict_params={"output_type": output_type}, - client_options=client_options, + client_options=self.client_options, ) except NeedsRefittingError as exc: last_exc = exc @@ -380,8 +379,8 @@ def predict_task() -> PredictionResult: task="classification", tabpfn_config=estimator_param, description=self.last_train_set_description, - client_options=client_options, - is_refitting=True, + force_refit=True, + client_options=self.client_options, ) result = run_task(predict_task, "Predicting") @@ -451,6 +450,8 @@ def __init__( enhanced_fit_mode: bool = False, enhanced_fit_mode_metric: Optional[str] = None, enhanced_fit_mode_time_limit_s: Optional[float] = None, + force_refit: bool = False, + client_options: ClientOptions | None = None, ): """Construct a TabPFN regressor. @@ -509,6 +510,12 @@ def __init__( the default ~5-minute sweep leaves performance on the table. None falls back to the server-side default (300s). Capped at 2400 seconds (40 minutes); higher values raise ValueError at fit. 
+ force_refit: bool, default=False + Whether to force refit the model even if the model has already been fitted. + client_options : ClientOptions, default=None + Per-request options (e.g. timeout, headers) applied to the fit and predict + API calls. Does not apply to file uploads. Because uploads can run before fitting, + `fit` may return later than the timeout specified. """ self.model_path = model_path self.n_estimators = n_estimators @@ -522,19 +529,22 @@ def __init__( self.enhanced_fit_mode = enhanced_fit_mode self.enhanced_fit_mode_metric = enhanced_fit_mode_metric self.enhanced_fit_mode_time_limit_s = enhanced_fit_mode_time_limit_s + self.force_refit = force_refit + self.client_options = client_options or ClientOptions() + self.last_trace_id = None self.last_fitted_train_set_id = None self.last_train_X = None self.last_train_y = None self.last_meta = {} self.last_train_set_description = None + self.fit_count = 0 def fit( self, X: pd.DataFrame | np.ndarray, y: pd.Series | np.ndarray, description: str | None = None, - client_options: ClientOptions | None = None, ): # assert init() is called init() @@ -546,11 +556,19 @@ def fit( X = _clean_text_features(X) if Config.use_server: - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = uuid4().hex + # Generate a fresh sentry-trace header for this fit when any of these holds: + # - no sentry-trace header is set on self.client_options yet; + # - force_refit is enabled; + # - .fit() has already been called on this instance (fit_count > 0). 
+ if ( + self.force_refit + or self.fit_count > 0 + or "sentry-trace" not in self.client_options.headers + ): + self.client_options.headers["sentry-trace"] = uuid4().hex + + self.last_trace_id = self.client_options.headers["sentry-trace"] - self.last_trace_id = client_options.headers["sentry-trace"] self.last_train_set_description = description def fit_task() -> UUID: @@ -560,13 +578,15 @@ def fit_task() -> UUID: task="regression", tabpfn_config=estimator_param, description=description, - client_options=client_options, + force_refit=self.force_refit, + client_options=self.client_options, ) self.last_fitted_train_set_id = run_task(fit_task, "Fitting") self.last_train_X = X self.last_train_y = y self.fitted_ = True + self.fit_count += 1 else: raise NotImplementedError( "Only server mode is supported at the moment for init(use_server=False)" @@ -581,7 +601,6 @@ def predict( "mean", "median", "mode", "quantiles", "full", "main" ] = "mean", quantiles: Optional[list[float]] = None, - client_options: ClientOptions | None = None, ) -> Union[np.ndarray, list[np.ndarray], dict[str, np.ndarray]]: """Predict regression target for X. 
@@ -617,9 +636,8 @@ def predict( "quantiles": quantiles, } - client_options = client_options or ClientOptions() - if "sentry-trace" not in client_options.headers: - client_options.headers["sentry-trace"] = self.last_trace_id + if "sentry-trace" not in self.client_options.headers: + self.client_options.headers["sentry-trace"] = self.last_trace_id def predict_task() -> PredictionResult: last_exc = None @@ -636,7 +654,7 @@ def predict_task() -> PredictionResult: task="regression", tabpfn_config=estimator_param, predict_params=predict_params, - client_options=client_options, + client_options=self.client_options, ) except NeedsRefittingError as exc: last_exc = exc @@ -647,8 +665,8 @@ def predict_task() -> PredictionResult: task="regression", tabpfn_config=estimator_param, description=self.last_train_set_description, - client_options=client_options, - is_refitting=True, + force_refit=True, + client_options=self.client_options, ) result = run_task(predict_task, "Predicting") diff --git a/src/tabpfn_client/service_wrapper.py b/src/tabpfn_client/service_wrapper.py index 53f9374..0f4afc5 100644 --- a/src/tabpfn_client/service_wrapper.py +++ b/src/tabpfn_client/service_wrapper.py @@ -261,8 +261,8 @@ def fit( task: Literal["classification", "regression"], tabpfn_config=None, description: str | None = None, + force_refit: bool = False, client_options: ClientOptions | None = None, - is_refitting: bool = False, ) -> UUID: return ServiceClient.fit( X, @@ -270,8 +270,8 @@ def fit( task=task, tabpfn_config=tabpfn_config, description=description, + force_refit=force_refit, client_options=client_options, - is_refitting=is_refitting, ) @classmethod