1 change: 1 addition & 0 deletions pyproject.toml
@@ -7,6 +7,7 @@ name = "tabpfn"
version = "7.1.1"
dependencies = [
"torch>=2.5",
"safetensors>=0.4.0",
"numpy>=1.21.6",
"scikit-learn>=1.2.0",
"typing_extensions>=4.12.0",
111 changes: 111 additions & 0 deletions scripts/convert_checkpoint_to_safetensors.py
@@ -0,0 +1,111 @@
"""Convert a TabPFN PyTorch checkpoint to SafeTensors plus sidecar metadata."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import torch
from safetensors.torch import save_file


def _json_safe(value: Any) -> Any:
"""Convert common checkpoint values into JSON-safe values."""
if isinstance(value, dict):
return {str(k): _json_safe(v) for k, v in value.items()}
if isinstance(value, list):
return [_json_safe(v) for v in value]
if isinstance(value, tuple):
return [_json_safe(v) for v in value]
if isinstance(value, set):
return sorted(_json_safe(v) for v in value)
if isinstance(value, Path):
return str(value)
if isinstance(value, torch.dtype):
return str(value)
if isinstance(value, torch.device):
return str(value)
if value is None or isinstance(value, (str, int, float, bool)):
return value

try:
json.dumps(value)
return value
except TypeError:
return {
"__unsupported_type__": type(value).__name__,
"__repr__": repr(value),
}


def convert_checkpoint(
input_checkpoint: Path,
output_safetensors: Path,
output_metadata: Path,
) -> None:
"""Convert a TabPFN checkpoint into SafeTensors plus JSON metadata."""
checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=False)

if not isinstance(checkpoint, dict):
raise TypeError(
f"Expected checkpoint to be a dict, got {type(checkpoint).__name__}."
)

state_dict = checkpoint.get("state_dict")

if not isinstance(state_dict, dict):
raise ValueError("Checkpoint does not contain a dict-valued 'state_dict'.")

tensors = {}

for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
raise TypeError(
f"Expected all state_dict values to be tensors. "
f"Key {key!r} has type {type(value).__name__}."
)
tensors[key] = value.detach().cpu().contiguous()

metadata = {
key: _json_safe(value)
for key, value in checkpoint.items()
if key != "state_dict"
}

output_safetensors.parent.mkdir(parents=True, exist_ok=True)
output_metadata.parent.mkdir(parents=True, exist_ok=True)

save_file(tensors, str(output_safetensors))

with output_metadata.open("w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, sort_keys=True)

print(f"Saved SafeTensors file: {output_safetensors}")
print(f"Saved metadata file: {output_metadata}")
print(f"Tensor count: {len(tensors)}")
print(f"Metadata keys: {sorted(metadata)}")


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert a TabPFN .ckpt file to .safetensors plus metadata JSON."
)
parser.add_argument("--input-checkpoint", required=True, type=Path)
parser.add_argument("--output-safetensors", required=True, type=Path)
parser.add_argument("--output-metadata", required=True, type=Path)
return parser.parse_args()


def main() -> None:
args = parse_args()
convert_checkpoint(
input_checkpoint=args.input_checkpoint,
output_safetensors=args.output_safetensors,
output_metadata=args.output_metadata,
)


if __name__ == "__main__":
main()
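
A minimal usage sketch for the converter above, with hypothetical paths; the same conversion runs from the command line via the --input-checkpoint, --output-safetensors, and --output-metadata flags.

from pathlib import Path

# Assumes the script is importable, e.g. run from the repository root.
from scripts.convert_checkpoint_to_safetensors import convert_checkpoint

# Hypothetical paths; point these at a real TabPFN .ckpt checkpoint.
convert_checkpoint(
    input_checkpoint=Path("checkpoints/tabpfn-v2-classifier.ckpt"),
    output_safetensors=Path("checkpoints/tabpfn-v2-classifier.safetensors"),
    output_metadata=Path(
        "checkpoints/tabpfn-v2-classifier.non_tensor_metadata.json"
    ),
)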
40 changes: 35 additions & 5 deletions src/tabpfn/model_loading.py
@@ -905,14 +905,37 @@ def get_loss_criterion(
return FullSupportBarDistribution(borders, ignore_nan_targets=True)


def _file_identity(path: str) -> tuple[int, int]:
"""Return a cheap identity tuple (mtime_ns, size) for cache-keying."""
st = Path(path).stat()
return (st.st_mtime_ns, st.st_size)
def _file_identity(path: str) -> tuple[int, int, int | None, int | None]:
"""Return a cheap identity tuple for cache-keying.

For SafeTensors checkpoints, include the sidecar metadata file identity so
cached loads are invalidated when either the tensor file or metadata changes.
"""
checkpoint_path = Path(path)
checkpoint_stat = checkpoint_path.stat()

metadata_mtime_ns = None
metadata_size = None

if checkpoint_path.suffix == ".safetensors":
metadata_path = checkpoint_path.with_suffix(".non_tensor_metadata.json")
if metadata_path.exists():
metadata_stat = metadata_path.stat()
metadata_mtime_ns = metadata_stat.st_mtime_ns
metadata_size = metadata_stat.st_size

return (
checkpoint_stat.st_mtime_ns,
checkpoint_stat.st_size,
metadata_mtime_ns,
metadata_size,
)


@functools.lru_cache(maxsize=1)
def _load_checkpoint_cached(path: str, _identity: tuple[int, int]) -> dict:
def _load_checkpoint_cached(
path: str, _identity: tuple[int, int, int | None, int | None]
) -> dict:
"""Load and cache a checkpoint from disk.

The ``_identity`` key ensures the cache is invalidated when the file at
@@ -924,6 +947,13 @@ def _load_checkpoint_cached(path: str, _identity: tuple[int, int]) -> dict:

def _load_checkpoint(path: str) -> dict:
"""Load a checkpoint from disk."""
checkpoint_path = Path(path)

if checkpoint_path.suffix == ".safetensors":
from tabpfn.safetensors_checkpoint import load_safetensors_checkpoint

return load_safetensors_checkpoint(checkpoint_path)
Comment on lines +952 to +955
medium

The current caching logic in _load_checkpoint_cached (which uses _file_identity) only tracks the primary checkpoint file. For SafeTensors checkpoints, this means that if the sidecar .non_tensor_metadata.json file is updated but the .safetensors file remains unchanged, the cache will not be invalidated, and stale metadata will be returned from the LRU cache.

While _file_identity is not modified in this PR, its implementation should be updated to include the metadata file's stats when a .safetensors path is provided to ensure cache consistency for this new loading mechanism.


# Catch the `FutureWarning` that torch raises. This should be dealt with!
# The warning is raised due to `torch.load`, which advises against ckpt
# files that contain non-tensor data.
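
A minimal sketch of the cache-invalidation concern raised in the review comment above, using a hypothetical checkpoint path: with the widened _file_identity tuple, touching only the sidecar JSON changes the cache key, so _load_checkpoint_cached reloads instead of serving stale metadata.

from pathlib import Path

from tabpfn.model_loading import _file_identity, _load_checkpoint_cached

ckpt = "checkpoints/tabpfn-v2-classifier.safetensors"  # hypothetical path

# First call populates the single-entry LRU cache, keyed on (path, identity).
first = _load_checkpoint_cached(ckpt, _file_identity(ckpt))

# Updating only the sidecar metadata changes its mtime, which changes the
# identity tuple, so this call is a cache miss and re-reads both files.
Path(ckpt).with_suffix(".non_tensor_metadata.json").touch()
second = _load_checkpoint_cached(ckpt, _file_identity(ckpt))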
47 changes: 47 additions & 0 deletions src/tabpfn/safetensors_checkpoint.py
@@ -0,0 +1,47 @@
"""Utilities for loading TabPFN checkpoints stored as SafeTensors plus metadata."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from safetensors.torch import load_file


def _metadata_path_for_safetensors(path: Path) -> Path:
"""Return the expected sidecar metadata path for a SafeTensors checkpoint."""
return path.with_suffix(".non_tensor_metadata.json")


def load_safetensors_checkpoint(path: str | Path) -> dict[str, Any]:
"""Load a TabPFN checkpoint from SafeTensors plus sidecar JSON metadata.

The SafeTensors file stores tensor values. The sidecar JSON file stores
non-tensor checkpoint metadata such as architecture name, model config,
and inference config.

Args:
path: Path to the ``.safetensors`` file.

Returns:
A checkpoint-like dictionary compatible with TabPFN model loading.
"""
safetensors_path = Path(path)
metadata_path = _metadata_path_for_safetensors(safetensors_path)

if not metadata_path.exists():
raise FileNotFoundError(
"SafeTensors checkpoint metadata file not found. "
f"Expected sidecar file: {metadata_path}"
)

tensors = load_file(str(safetensors_path), device="cpu")

with metadata_path.open("r", encoding="utf-8") as f:
metadata = json.load(f)

checkpoint: dict[str, Any] = dict(metadata)
checkpoint["state_dict"] = tensors

return checkpoint
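
A minimal loading sketch for the helper above, with a hypothetical path; the sidecar .non_tensor_metadata.json produced by the conversion script must sit next to the .safetensors file.

from tabpfn.safetensors_checkpoint import load_safetensors_checkpoint

checkpoint = load_safetensors_checkpoint(
    "checkpoints/tabpfn-v2-classifier.safetensors"  # hypothetical path
)

# "state_dict" holds the tensors from the .safetensors file; every other key
# (architecture name, model config, inference config, ...) comes from the
# sidecar JSON metadata.
state_dict = checkpoint["state_dict"]
print(len(state_dict), sorted(k for k in checkpoint if k != "state_dict"))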