Add SafeTensors checkpoint loading support #950

Open
ryannichols827 wants to merge 2 commits into PriorLabs:main from ryannichols827:add-safetensors-support
Conversation

@ryannichols827

Summary

This PR adds optional SafeTensors checkpoint loading support for TabPFN model checkpoints.

It introduces:

  • Runtime loading for .safetensors checkpoints with sidecar non-tensor metadata
  • A helper module for reconstructing TabPFN checkpoint dictionaries from SafeTensors + JSON metadata
  • A conversion utility script for converting existing .ckpt checkpoints into:
    • .safetensors tensor weights
    • .non_tensor_metadata.json sidecar metadata
  • safetensors as a runtime dependency

Motivation

Current checkpoint loading relies on torch.load(..., weights_only=None), which requires pickle-based loading for checkpoint metadata. SafeTensors provides a safer tensor serialization format, but TabPFN checkpoints also contain non-tensor fields such as architecture config and inference config.

This implementation keeps tensor weights in SafeTensors and stores required non-tensor checkpoint metadata in a sidecar JSON file.
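The split could look roughly like this. Note this is a hypothetical sketch, not the PR's actual code: the helper names `split_checkpoint` and `write_sidecar` are invented, `is_tensor_like` stands in for an `isinstance(value, torch.Tensor)` check, and the real conversion script writes the tensor half with `safetensors.torch.save_file`.

```python
import json
from pathlib import Path


def split_checkpoint(checkpoint: dict, is_tensor_like) -> tuple[dict, dict]:
    # Partition the checkpoint into tensor weights (destined for the
    # .safetensors file) and non-tensor sidecar metadata.
    tensors, metadata = {}, {}
    for key, value in checkpoint.items():
        if is_tensor_like(value):
            tensors[key] = value   # would be saved via safetensors.torch.save_file
        else:
            metadata[key] = value  # written to the *.non_tensor_metadata.json sidecar
    return tensors, metadata


def write_sidecar(metadata: dict, output_metadata: Path) -> None:
    # Non-tensor fields (architecture config, inference config, ...) must be
    # JSON-serializable; json.dumps raises TypeError otherwise, which surfaces
    # unconvertible metadata early.
    output_metadata.write_text(json.dumps(metadata, indent=2))
```

The key design point is that nothing is pickled: weights go through SafeTensors and everything else must survive a round trip through JSON.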

Notes

This PR does not include model weights, converted .safetensors files, .ckpt files, or generated metadata files.

The conversion utility is intended for developer or maintainer use and preserves the existing checkpoint structure expected by the model-loading code.
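For illustration, the loading side could be reconstructed along these lines. This is a sketch under assumptions, not the PR's helper module: the sidecar is assumed to sit next to the `.safetensors` file and share its stem, and `load_tensors` is a stand-in for `safetensors.torch.load_file`.

```python
import json
from pathlib import Path


def load_safetensors_checkpoint(checkpoint_path: Path, load_tensors) -> dict:
    # Rebuild the original checkpoint dict: non-tensor fields come from the
    # JSON sidecar, tensor weights from the .safetensors file itself.
    sidecar = checkpoint_path.with_suffix(".non_tensor_metadata.json")
    checkpoint = json.loads(sidecar.read_text())
    checkpoint["state_dict"] = load_tensors(checkpoint_path)
    return checkpoint
```

A missing sidecar would raise `FileNotFoundError` here, which is arguably the right behavior: a `.safetensors` checkpoint without its metadata cannot reproduce the dictionary the model-loading code expects.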

@ryannichols827 ryannichols827 requested a review from a team as a code owner May 11, 2026 22:28
@ryannichols827 ryannichols827 requested review from anuragg1209 and removed request for a team May 11, 2026 22:28

@CLAassistant

CLAassistant commented May 11, 2026

CLA assistant check
All committers have signed the CLA.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for SafeTensors checkpoints, including a script that converts PyTorch checkpoints into .safetensors files with sidecar JSON metadata, plus updates to the model-loading logic. The review feedback below covers ensuring tensor contiguity during conversion to avoid serialization errors, explicitly setting the weights_only loading parameter for legacy compatibility, and a potential cache-invalidation issue when the sidecar metadata is updated independently of the tensor file.

f"Expected all state_dict values to be tensors. "
f"Key {key!r} has type {type(value).__name__}."
)
tensors[key] = value.detach().cpu()

high

The safetensors.torch.save_file function requires tensors to be contiguous. If the state_dict contains any views or non-contiguous tensors (which can happen depending on how the model was saved or manipulated), the conversion will fail with a RuntimeError. Adding .contiguous() ensures the tensors are in the correct format for serialization.

Suggested change
tensors[key] = value.detach().cpu()
tensors[key] = value.detach().cpu().contiguous()

output_metadata: Path,
) -> None:
"""Convert a TabPFN checkpoint into SafeTensors plus JSON metadata."""
checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=None)

medium

Using weights_only=None (the current default) triggers a FutureWarning in recent PyTorch versions and will eventually default to True. While weights_only=True is safer, it can fail when loading legacy checkpoints that contain complex metadata types. Since this script is specifically intended to convert existing pickle-based checkpoints, it is better to explicitly set weights_only=False to ensure compatibility and silence the warning.

Suggested change
checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=None)
checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=False)

Comment on lines +929 to +932
if checkpoint_path.suffix == ".safetensors":
from tabpfn.safetensors_checkpoint import load_safetensors_checkpoint

return load_safetensors_checkpoint(checkpoint_path)

medium

The current caching logic in _load_checkpoint_cached (which uses _file_identity) only tracks the primary checkpoint file. For SafeTensors checkpoints, this means that if the sidecar .non_tensor_metadata.json file is updated but the .safetensors file remains unchanged, the cache will not be invalidated, and stale metadata will be returned from the LRU cache.

While _file_identity is not modified in this PR, its implementation should be updated to include the metadata file's stats when a .safetensors path is provided to ensure cache consistency for this new loading mechanism.
