Add SafeTensors checkpoint loading support #950
Conversation
Code Review
This pull request introduces support for SafeTensors checkpoints, including a script to convert PyTorch checkpoints into .safetensors files with sidecar JSON metadata and updates to the model loading logic. The reviewer provided feedback on ensuring tensor contiguity during conversion to avoid serialization errors, explicitly setting loading parameters for legacy compatibility, and addressing potential cache invalidation issues when sidecar metadata is updated independently.
```python
            f"Expected all state_dict values to be tensors. "
            f"Key {key!r} has type {type(value).__name__}."
        )
    tensors[key] = value.detach().cpu()
```
The safetensors.torch.save_file function requires tensors to be contiguous. If the state_dict contains any views or non-contiguous tensors (which can happen depending on how the model was saved or manipulated), the conversion will fail with a RuntimeError. Adding .contiguous() ensures the tensors are in the correct format for serialization.
```diff
-    tensors[key] = value.detach().cpu()
+    tensors[key] = value.detach().cpu().contiguous()
```
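The failure mode is easy to reproduce: transposing a 2-D tensor yields a non-contiguous view, which `safetensors.torch.save_file` rejects, while `.contiguous()` materializes it into a fresh buffer. A minimal sketch (tensor names are illustrative, not from the PR):

```python
import torch

# A transposed 2-D tensor is a view with a non-contiguous memory layout,
# exactly the kind of value safetensors.torch.save_file refuses to write.
weight = torch.randn(4, 8)
state_dict = {"linear.weight": weight.t()}  # non-contiguous view
assert not state_dict["linear.weight"].is_contiguous()

# Appending .contiguous() to the existing .detach().cpu() chain copies
# each view into a serialization-friendly buffer without changing values.
tensors = {
    key: value.detach().cpu().contiguous()
    for key, value in state_dict.items()
}

assert tensors["linear.weight"].is_contiguous()
assert torch.equal(tensors["linear.weight"], weight.t())
```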
```python
    output_metadata: Path,
) -> None:
    """Convert a TabPFN checkpoint into SafeTensors plus JSON metadata."""
    checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=None)
```
Using weights_only=None (the current default) triggers a FutureWarning in recent PyTorch versions and will eventually default to True. While weights_only=True is safer, it can fail when loading legacy checkpoints that contain complex metadata types. Since this script is specifically intended to convert existing pickle-based checkpoints, it is better to explicitly set weights_only=False to ensure compatibility and silence the warning.
```diff
-    checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=None)
+    checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=False)
```
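The distinction matters in practice: `weights_only=True` restricts unpickling to an allowlist of tensors and primitive containers, so legacy checkpoints carrying arbitrary Python objects fail to load. A small sketch with a hypothetical `LegacyConfig` class standing in for such metadata:

```python
import os
import tempfile

import torch


class LegacyConfig:
    """Stand-in for a complex metadata type found in old pickle checkpoints."""

    def __init__(self, name: str) -> None:
        self.name = name


tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "legacy.ckpt")
torch.save({"w": torch.zeros(2, 2), "config": LegacyConfig("tabpfn")}, path)

# Explicit weights_only=False: full pickle loading, and no FutureWarning
# from relying on the (changing) default.
loaded = torch.load(path, map_location="cpu", weights_only=False)
assert loaded["config"].name == "tabpfn"

# weights_only=True rejects arbitrary classes such as LegacyConfig.
try:
    torch.load(path, map_location="cpu", weights_only=True)
    safe_load_failed = False
except Exception:  # torch raises an unpickling error for disallowed globals
    safe_load_failed = True
assert safe_load_failed
```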
```python
    if checkpoint_path.suffix == ".safetensors":
        from tabpfn.safetensors_checkpoint import load_safetensors_checkpoint

        return load_safetensors_checkpoint(checkpoint_path)
```
The current caching logic in `_load_checkpoint_cached` (which uses `_file_identity`) only tracks the primary checkpoint file. For SafeTensors checkpoints, this means that if the sidecar `.non_tensor_metadata.json` file is updated but the `.safetensors` file remains unchanged, the cache will not be invalidated, and stale metadata will be returned from the LRU cache.
While `_file_identity` is not modified in this PR, its implementation should be updated to include the metadata file's stats when a `.safetensors` path is provided, to ensure cache consistency for this new loading mechanism.
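One way this could look (a hypothetical sketch, not the PR's actual `_file_identity` implementation, and the sidecar naming is assumed from the PR description):

```python
import json
import tempfile
from pathlib import Path


def _file_identity(path: Path) -> tuple:
    # Sketch of a cache key that also covers the sidecar metadata file,
    # so that editing only the JSON invalidates the LRU cache entry.
    def stat_key(p: Path) -> tuple:
        st = p.stat()
        return (str(p), st.st_size, st.st_mtime_ns)

    key = stat_key(path)
    if path.suffix == ".safetensors":
        sidecar = path.with_suffix(".non_tensor_metadata.json")
        if sidecar.exists():
            key += stat_key(sidecar)
    return key


tmp = Path(tempfile.mkdtemp())
ckpt = tmp / "model.safetensors"
sidecar = tmp / "model.non_tensor_metadata.json"
ckpt.write_bytes(b"\x00" * 8)                    # stand-in tensor file
sidecar.write_text(json.dumps({"version": 1}))

before = _file_identity(ckpt)
sidecar.write_text(json.dumps({"version": 2, "extra": True}))  # metadata-only edit
after = _file_identity(ckpt)

# The .safetensors file is untouched, yet the identity changes because
# the sidecar's size changed, so the cached entry would be refreshed.
assert before != after
```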
Summary
This PR adds optional SafeTensors checkpoint loading support for TabPFN model checkpoints.
It introduces:

- Loading `.safetensors` checkpoints with sidecar non-tensor metadata
- A conversion script that splits `.ckpt` checkpoints into:
  - `.safetensors` tensor weights
  - `.non_tensor_metadata.json` sidecar metadata
- `safetensors` as a runtime dependency

Motivation
Current checkpoint loading relies on `torch.load(..., weights_only=None)`, which requires pickle-based loading for checkpoint metadata. SafeTensors provides a safer tensor serialization format, but TabPFN checkpoints also contain non-tensor fields such as architecture config and inference config. This implementation keeps tensor weights in SafeTensors and stores the required non-tensor checkpoint metadata in a sidecar JSON file.
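The tensor/non-tensor split can be sketched as follows (field names and values are illustrative, not taken from actual TabPFN checkpoints; the tensors dict would then be passed to `safetensors.torch.save_file`):

```python
import json

import torch

# Illustrative checkpoint shape: weights plus non-tensor config fields.
checkpoint = {
    "state_dict": {"encoder.weight": torch.randn(3, 3)},
    "architecture_config": {"n_layers": 12, "emsize": 512},  # assumed values
    "inference_config": {"softmax_temperature": 0.9},        # assumed values
}

# Tensors: detached, on CPU, contiguous — ready for save_file.
tensors = {
    key: value.detach().cpu().contiguous()
    for key, value in checkpoint["state_dict"].items()
}

# Non-tensor metadata: everything else, serialized as plain JSON for the
# .non_tensor_metadata.json sidecar.
non_tensor = {k: v for k, v in checkpoint.items() if k != "state_dict"}
sidecar_json = json.dumps(non_tensor, indent=2)

assert set(tensors) == {"encoder.weight"}
assert json.loads(sidecar_json)["architecture_config"]["n_layers"] == 12
```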
Notes
This PR does not include model weights, converted `.safetensors` files, `.ckpt` files, or generated metadata files. The conversion utility is intended for developer or maintainer use and preserves the existing checkpoint structure expected by the model-loading code.