2 changes: 0 additions & 2 deletions src/tabpfn/architectures/tabpfn_v2_5.py
@@ -568,8 +568,6 @@ def __init__(
self.input_size // 4, self.input_size
)
self._do_encoder_nan_check = True
# TODO(Phil): This is here to not fail the memory computation. We should make
# this a proper API.
self.emsize = config.emsize

@property
2 changes: 0 additions & 2 deletions src/tabpfn/architectures/tabpfn_v2_6.py
@@ -579,8 +579,6 @@ def __init__(
self.input_size // 4, self.input_size
)
self._do_encoder_nan_check = True
# TODO(Phil): This is here to not fail the memory computation. We should make
# this a proper API.
self.emsize = config.emsize

@property
3 changes: 2 additions & 1 deletion src/tabpfn/architectures/tabpfn_v3.py
@@ -269,8 +269,9 @@ def cache_size_mb(self) -> int:
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
SDPBackend.MATH, # fallback for older GPUs or unsupported configurations
]
_SDPA_BACKENDS_CPU = [*_SDPA_BACKENDS, SDPBackend.MATH]
_SDPA_BACKENDS_CPU = [*_SDPA_BACKENDS]


# ---------------------------------------------------------------------------
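For context, a backend allow-list like _SDPA_BACKENDS is typically consumed through torch's sdpa_kernel context manager. The sketch below is illustrative only (not the TabPFN call site) and assumes PyTorch 2.3+, where sdpa_kernel accepts a list of SDPBackend values:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# toy query/key/value tensors: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 4, 8, 16)

backends = [
    SDPBackend.FLASH_ATTENTION,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.CUDNN_ATTENTION,
    SDPBackend.MATH,  # pure-PyTorch fallback for older GPUs or CPU-only runs
]

# restrict SDPA to the allow-listed backends; torch then dispatches to one that
# supports the given inputs and device
with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v)

Keeping SDPBackend.MATH in the shared list is what removes the need for a separate CPU-only list, which is why _SDPA_BACKENDS_CPU can now simply copy _SDPA_BACKENDS.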
41 changes: 23 additions & 18 deletions src/tabpfn/model_loading.py
@@ -169,22 +169,22 @@ def get_regressor_v2_6(cls) -> ModelSource: # noqa: D102
@classmethod
def get_classifier_v3(cls) -> ModelSource: # noqa: D102
filenames = [
"tabpfn-v3-classifier-20260506.ckpt",
"tabpfn-v3-classifier-v3_default.ckpt",
]
return cls(
repo_id="Prior-Labs/tabpfn_3",
default_filename="tabpfn-v3-classifier-20260506.ckpt",
default_filename="tabpfn-v3-classifier-v3_default.ckpt",
filenames=filenames,
)

@classmethod
def get_regressor_v3(cls) -> ModelSource: # noqa: D102
filenames = [
"tabpfn-v3-regressor-20260506.ckpt",
"tabpfn-v3-regressor-v3_default.ckpt",
]
return cls(
repo_id="Prior-Labs/tabpfn_3",
default_filename="tabpfn-v3-regressor-20260506.ckpt",
default_filename="tabpfn-v3-regressor-v3_default.ckpt",
filenames=filenames,
)

@@ -499,13 +499,14 @@ def _download_model(
ModelVersion.V2_6: "tabpfn_2_6",
ModelVersion.V3: "tabpfn_3",
}
if version in _HF_REPOS:
try:
from tabpfn.browser_auth import ensure_license_accepted # noqa: PLC0415
# Skip license check for now until tabpfn_3 is public
# if version in _HF_REPOS:
# try:
# from tabpfn.browser_auth import ensure_license_accepted

ensure_license_accepted(hf_repo_id=_HF_REPOS[version])
except Exception as e: # noqa: BLE001
return [e]
# ensure_license_accepted(hf_repo_id=_HF_REPOS[version])
# except Exception as e:
# return [e]

try:
model_source = _get_model_source(version, ModelType(which))
@@ -975,10 +976,10 @@ def load_model(
)

if "test_targets_MB" in inspect.signature(model.forward).parameters:
# The model computes the loss internally. Support for this was only added after
# v2.5, so we can safely assume that the inference config is stored in the
# checkpoint.
model.load_state_dict(full_state)
# The model computes the loss internally. Strip criterion keys that
# save_tabpfn_model may have written so load_state_dict doesn't reject them.
model_state = {k: v for k, v in full_state.items() if "criterion." not in k}
Contributor comment (medium):

The filtering condition "criterion." not in k is too broad and might accidentally strip valid model parameters whose names contain the string "criterion." anywhere, e.g. a nested submodule key such as head.criterion.weight. Since save_tabpfn_model specifically prefixes criterion keys with "criterion.", it is safer and more precise to use not k.startswith("criterion.").

Suggested change
model_state = {k: v for k, v in full_state.items() if "criterion." not in k}
model_state = {k: v for k, v in full_state.items() if not k.startswith("criterion.")}
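A tiny illustration of the difference, using made-up state-dict keys (head.criterion.weight is hypothetical, not part of the TabPFN checkpoint format):

full_state = {
    "criterion.weight": 0,       # loss-head key written by save_tabpfn_model
    "head.criterion.weight": 1,  # hypothetical parameter of a nested submodule named "criterion"
    "encoder.bias": 2,
}

# substring filter also drops the nested model parameter
sorted(k for k in full_state if "criterion." not in k)
# -> ['encoder.bias']

# prefix filter only drops the top-level criterion keys
sorted(k for k in full_state if not k.startswith("criterion."))
# -> ['encoder.bias', 'head.criterion.weight']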

model.load_state_dict(model_state)
model.eval()
inference_config = InferenceConfig(
**_rename_old_inference_config_keys(checkpoint["inference_config"])
@@ -1102,6 +1103,7 @@ def save_tabpfn_model(
"state_dict": state_dict,
"config": asdict(config),
"architecture_name": architecture_name,
"inference_config": asdict(model.inference_config_),
Contributor comment on lines 1105 to +1106 (high):

This added line is redundant and potentially problematic. The inference_config is already retrieved using getattr(model, "inference_config_", None) at line 1083 and added to the checkpoint dict conditionally at lines 1108-1109. Including it here unconditionally will cause a TypeError if model.inference_config_ is None (as asdict(None) is invalid). It is better to rely on the existing conditional block below.

Suggested change
"architecture_name": architecture_name,
"inference_config": asdict(model.inference_config_),
"architecture_name": architecture_name,

}
if inference_config is not None:
checkpoint["inference_config"] = asdict(inference_config)
@@ -1227,11 +1229,14 @@ def load_fitted_tabpfn_model(

def _resolve_architecture_name(config: ArchitectureConfig) -> str:
"""Resolve the architecture name from the config."""
name = getattr(config, "name", "")
if "v3" in name:
from tabpfn.architectures.tabpfn_v2_5 import TabPFNV2p5Config # noqa: PLC0415
from tabpfn.architectures.tabpfn_v2_6 import TabPFNV2p6Config # noqa: PLC0415
from tabpfn.architectures.tabpfn_v3 import TabPFNV3Config # noqa: PLC0415

if isinstance(config, TabPFNV3Config):
return "tabpfn_v3"
if "2.6" in name:
if isinstance(config, TabPFNV2p6Config):
return "tabpfn_v2_6"
if "2.5" in name:
if isinstance(config, TabPFNV2p5Config):
return "tabpfn_v2_5"
return "base"
2 changes: 1 addition & 1 deletion src/tabpfn/settings.py
@@ -35,7 +35,7 @@ class TabPFNSettings(BaseSettings):
"If not set, uses platform-specific user cache directory.",
)
model_version: ModelVersion = Field(
default=ModelVersion.V2_6,
default=ModelVersion.V3,
description="The version of the TabPFN model to use by default.",
)

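With the default flipped to V3, a minimal sketch of how the setting resolves (import path taken from this file; how the default is overridden via environment variables depends on the settings class's env configuration, which is not shown here):

from tabpfn.settings import TabPFNSettings

settings = TabPFNSettings()
settings.model_version  # defaults to ModelVersion.V3 unless overridden by env vars or constructor args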
12 changes: 0 additions & 12 deletions tests/conftest.py
@@ -16,8 +16,6 @@
import pytest
import torch

from tabpfn.model_loading import ModelSource, get_cache_dir


def pytest_configure(config: pytest.Config) -> None: # noqa: ARG001
"""Configure pytest with global settings."""
@@ -31,13 +29,3 @@ def set_global_seed() -> None:
torch.manual_seed(seed)
np.random.seed(seed) # noqa: NPY002
random.seed(seed)


def is_v3_classifier_in_cache() -> bool:
cache_dir = get_cache_dir()
return (cache_dir / ModelSource.get_classifier_v3().default_filename).exists()


def is_v3_regressor_in_cache() -> bool:
cache_dir = get_cache_dir()
return (cache_dir / ModelSource.get_regressor_v3().default_filename).exists()
@@ -0,0 +1,17 @@
[
[
0.9999734163284302,
1.3569095244747587e-05,
1.2983076885575429e-05
],
[
0.00036331909359432757,
0.9991976618766785,
0.0004389923997223377
],
[
2.672282062121667e-05,
0.0005141739966347814,
0.9994590878486633
]
]
@@ -0,0 +1,17 @@
[
[
0.9999734163284302,
1.3569095244747587e-05,
1.2983076885575429e-05
],
[
0.00036331909359432757,
0.9991976618766785,
0.0004389923997223377
],
[
2.672282062121667e-05,
0.0005141739966347814,
0.9994590878486633
]
]
@@ -0,0 +1,14 @@
[
[
0.745369017124176,
0.2546309530735016
],
[
0.42508718371391296,
0.5749127864837646
],
[
0.6312955617904663,
0.3687044084072113
]
]
@@ -0,0 +1,14 @@
[
[
0.6735507845878601,
0.3264492154121399
],
[
0.44950684905052185,
0.5504931807518005
],
[
0.6495934128761292,
0.35040655732154846
]
]

This file was deleted.

@@ -0,0 +1,14 @@
[
[
0.7227111458778381,
0.27728888392448425
],
[
0.47503453493118286,
0.5249655246734619
],
[
0.4784924387931824,
0.5215075612068176
]
]
@@ -0,0 +1,14 @@
[
[
0.678566575050354,
0.3214334547519684
],
[
0.46058961749076843,
0.5394103527069092
],
[
0.6697312593460083,
0.3302687108516693
]
]
@@ -0,0 +1,5 @@
[
5.081058502197266,
4.23805046081543,
4.6899261474609375
]
@@ -0,0 +1,5 @@
[
5.081058502197266,
4.23805046081543,
4.6899261474609375
]
17 changes: 8 additions & 9 deletions tests/test_classifier_interface.py
@@ -36,7 +36,6 @@
from tabpfn.preprocessing import PreprocessorConfig
from tabpfn.utils import infer_devices

from .conftest import is_v3_classifier_in_cache
from .utils import (
get_pytest_devices,
is_cpu_float16_supported,
@@ -64,7 +63,11 @@ def X_y() -> tuple[np.ndarray, np.ndarray]:
)


model_sources = [ModelSource.get_classifier_v2(), ModelSource.get_classifier_v2_5()]
model_sources = [
ModelSource.get_classifier_v2(),
ModelSource.get_classifier_v2_5(),
ModelSource.get_classifier_v3(),
]
fit_modes = ["low_memory", "fit_preprocessors"]


@@ -525,8 +528,9 @@ def test_balance_probabilities_alters_proba_output() -> None:
)


# Only v2 and 2.5 support the KV cache at the moment.
@pytest.mark.parametrize("model_version", [ModelVersion.V2, ModelVersion.V2_5])
@pytest.mark.parametrize(
"model_version", [ModelVersion.V2, ModelVersion.V2_5, ModelVersion.V3]
)
# Disable MPS as it doesn't support float64.
@pytest.mark.parametrize("device", [d for d in get_pytest_devices() if d != "mps"])
def test__fit_preprocessors_and_with_cache_produce_equal_results(
@@ -564,8 +568,6 @@ def test__fit_preprocessors_and_with_cache_produce_equal_results(
def test__fit_preprocessors_and_low_memory_produce_equal_results(
X_y: tuple[np.ndarray, np.ndarray], model_version: ModelVersion, device: str
) -> None:
if model_version == ModelVersion.V3 and not is_v3_classifier_in_cache():
pytest.skip("V3 classifier model not in cache; skipping V3-specific test.")
kwargs = {
"version": model_version,
"n_estimators": 2,
@@ -596,9 +598,6 @@ def test__fit_preprocessors_and_low_memory_produce_equal_results(
def test__fit_and_predict__on_demo_dataset__accuracy_reasonable(
model_version: ModelVersion,
) -> None:
if model_version == ModelVersion.V3 and not is_v3_classifier_in_cache():
pytest.skip("V3 classifier model not in cache.")

X, y = sklearn.datasets.load_iris(return_X_y=True)
model = TabPFNClassifier.create_default_for_version(
version=model_version, random_state=0