Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions src/tabpfn_client/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,39 @@ class DuplicateTrainSetErrorResponse(ErrorResponse):
# ---------------------------------------------------------------------------
# /tabpfn/fit/
# ---------------------------------------------------------------------------
class FitRequest(BaseModel):
    """Request model for `/tabpfn/fit/`.

    NOTE(review): a second `FitRequest` class is defined further down in
    this file and would shadow this one at import time -- confirm which
    definition is meant to be current.
    """

    train_set_upload_id: UUID
    task: str
    tabpfn_systems: List[str]
    force_refit: bool = False
    # Estimator-side configuration (model_path, hyperparameters). Some
    # `tabpfn_systems` values on the server need this at fit time; the
    # server ignores it otherwise.
    # Declared Optional because the default is None: pydantic v2 does not
    # infer Optional from a None default, so a bare `TabPFNConfig`
    # annotation would reject an explicitly passed None.
    tabpfn_config: Optional[TabPFNConfig] = None
class TabPFNSystem(BaseModel):
    """Base model for one entry of a fit request's `tabpfn_systems` list."""

    # Identifier of the system. Required here; each concrete subclass in
    # this file supplies its own default value.
    name: str


class TextSystem(TabPFNSystem):
    """System entry whose `name` defaults to "text"; adds no other fields."""

    name: str = "text"


class PreprocessingSystem(TabPFNSystem):
    """System entry whose `name` defaults to "preprocessing"; adds no other fields."""

    name: str = "preprocessing"


class EnhancedSystem(TabPFNSystem):
    """System entry named "enhanced", carrying the enhanced-fit options."""

    name: str = "enhanced"
    # Drives model selection + ensemble weighting during the enhanced-fit
    # sweep. Only consulted when the "enhanced" system is requested. None
    # falls back to the sweep's default per problem type.
    # NOTE(review): `enhanced_fit_mode_metric` and `metric` look like the
    # old and new spelling of the same option (the client passes
    # `metric=`) -- confirm both are meant to exist on this model.
    enhanced_fit_mode_metric: Optional[str] = None
    metric: Optional[str] = None
    # Ceiling on the enhanced-fit sweep (seconds). Only consulted when the
    # "enhanced" system is requested. None falls back to the server
    # default (300s).
    # NOTE(review): same question for `enhanced_fit_time_limit_s` versus
    # `time_limit_s` (the client passes `time_limit_s=`).
    enhanced_fit_time_limit_s: Optional[float] = None
    time_limit_s: Optional[float] = None
    # Estimator-side configuration (model_path, hyperparameters). Some
    # `tabpfn_systems` values on the server need this at fit time; the
    # server ignores it otherwise.
    # Declared Optional because the default is None and the client may
    # pass None explicitly: pydantic v2 does not infer Optional from a
    # None default, so a bare `TabPFNConfig` annotation would fail
    # validation in that case.
    tabpfn_config: Optional[TabPFNConfig] = None


class FitRequest(BaseModel):
    """Request model for `/tabpfn/fit/`."""

    # The client fills this from the `train_set_upload_id` of its
    # upload-prepare response.
    train_set_upload_id: UUID
    # The client passes "classification" or "regression".
    task: str
    # One entry per system to apply.
    # NOTE(review): items are annotated with the base `TabPFNSystem`.
    # Under pydantic v2, serialization follows the annotated type, so
    # subclass-only fields (e.g. `EnhancedSystem.metric`) may be left out
    # of the dumped payload unless `SerializeAsAny` or
    # `serialize_as_any=True` is used -- confirm how the client
    # serializes this model.
    tabpfn_systems: List[TabPFNSystem]
    # Whether to force a refit even if the model has already been fitted.
    force_refit: bool = False


class FitResponse(BaseModel):
Expand Down
35 changes: 27 additions & 8 deletions src/tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import traceback
import warnings
from pydantic import BaseModel, ValidationError
from typing import Any, Callable, Dict, Literal, Union, cast
from typing import Any, Callable, Dict, Literal, Union, cast, List

import google_crc32c

Expand Down Expand Up @@ -56,6 +56,10 @@
PredictResponse,
TaskConfig,
ErrorResponse,
TabPFNSystem,
TextSystem,
PreprocessingSystem,
EnhancedSystem,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -216,11 +220,12 @@ class ServiceClient(Singleton):
or f"{server_config.protocol}://{server_config.host}:{server_config.port}"
)
fit_path = SERVER_CONFIG["endpoints"]["fit"]["path"]
predict_path = SERVER_CONFIG["endpoints"]["predict"]["path"]
httpx_client = httpx.Client(
base_url=base_url,
timeout=TABPFN_CLIENT_TIMEOUT,
headers={"Prior-Client-Version": get_client_version()},
transport=SelectiveHTTP2Transport(http2_paths=[fit_path]),
transport=SelectiveHTTP2Transport(http2_paths=[fit_path, predict_path]),
follow_redirects=True,
)
_access_token = None
Expand Down Expand Up @@ -273,8 +278,8 @@ def fit(
task: Literal["classification", "regression"],
tabpfn_config: Union[dict, None] = None,
description: str | None = None,
force_refit: bool = False,
client_options: ClientOptions | None = None,
is_refitting: bool = False,
) -> UUID:
"""
Upload a train set to server and return the train set UID if successful.
Expand All @@ -292,6 +297,8 @@ def fit(
`paper_version`.
description: str, optional
Description of the dataset and task for the server.
force_refit: bool, optional
Whether to force refit the model even if the model has already been fitted.
client_options : ClientOptions, optional
Per-request options (e.g. timeout, headers) for the fitting API call
only. Does not apply to file uploads. Because uploads can run before fitting,
Expand Down Expand Up @@ -426,15 +433,27 @@ def fit(
else None
)

systems: List[TabPFNSystem] = []
for name in tabpfn_systems:
if name == "text":
systems.append(TextSystem())
elif name == "preprocessing":
systems.append(PreprocessingSystem())
elif name == "enhanced":
systems.append(
EnhancedSystem(
metric=enhanced_fit_mode_metric,
time_limit_s=enhanced_fit_time_limit_s,
tabpfn_config=server_tabpfn_config,
)
)

res = cls._fit(
req=FitRequest(
train_set_upload_id=prepare_resp.train_set_upload_id,
task=task,
tabpfn_systems=tabpfn_systems,
force_refit=is_refitting or force_refit_enabled(),
tabpfn_config=server_tabpfn_config,
enhanced_fit_mode_metric=enhanced_fit_mode_metric,
enhanced_fit_time_limit_s=enhanced_fit_time_limit_s,
tabpfn_systems=systems,
force_refit=force_refit or force_refit_enabled(),
),
timeout=client_options.timeout,
headers=client_options.headers,
Expand Down
Loading