diff --git a/pyproject.toml b/pyproject.toml
index 50da91b20..21efa3d35 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,6 +7,7 @@ name = "tabpfn"
 version = "7.1.1"
 dependencies = [
     "torch>=2.5",
+    "safetensors>=0.4.0",
     "numpy>=1.21.6",
     "scikit-learn>=1.2.0",
     "typing_extensions>=4.12.0",
diff --git a/scripts/convert_checkpoint_to_safetensors.py b/scripts/convert_checkpoint_to_safetensors.py
new file mode 100644
index 000000000..f91110a5c
--- /dev/null
+++ b/scripts/convert_checkpoint_to_safetensors.py
@@ -0,0 +1,111 @@
+"""Convert a TabPFN PyTorch checkpoint to SafeTensors plus sidecar metadata."""
+
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+from typing import Any
+
+import torch
+from safetensors.torch import save_file
+
+
+def _json_safe(value: Any) -> Any:
+    """Convert common checkpoint values into JSON-safe values."""
+    if isinstance(value, dict):
+        return {str(k): _json_safe(v) for k, v in value.items()}
+    if isinstance(value, list):
+        return [_json_safe(v) for v in value]
+    if isinstance(value, tuple):
+        return [_json_safe(v) for v in value]
+    if isinstance(value, set):
+        return sorted(_json_safe(v) for v in value)
+    if isinstance(value, Path):
+        return str(value)
+    if isinstance(value, torch.dtype):
+        return str(value)
+    if isinstance(value, torch.device):
+        return str(value)
+    if value is None or isinstance(value, (str, int, float, bool)):
+        return value
+
+    try:
+        json.dumps(value)
+        return value
+    except TypeError:
+        return {
+            "__unsupported_type__": type(value).__name__,
+            "__repr__": repr(value),
+        }
+
+
+def convert_checkpoint(
+    input_checkpoint: Path,
+    output_safetensors: Path,
+    output_metadata: Path,
+) -> None:
+    """Convert a TabPFN checkpoint into SafeTensors plus JSON metadata."""
+    checkpoint = torch.load(input_checkpoint, map_location="cpu", weights_only=False)
+
+    if not isinstance(checkpoint, dict):
+        raise TypeError(
+            f"Expected checkpoint to be a dict, got {type(checkpoint).__name__}."
+        )
+
+    state_dict = checkpoint.get("state_dict")
+
+    if not isinstance(state_dict, dict):
+        raise ValueError("Checkpoint does not contain a dict-valued 'state_dict'.")
+
+    tensors = {}
+
+    for key, value in state_dict.items():
+        if not isinstance(value, torch.Tensor):
+            raise TypeError(
+                f"Expected all state_dict values to be tensors. "
+                f"Key {key!r} has type {type(value).__name__}."
+            )
+        tensors[key] = value.detach().cpu().contiguous()
+
+    metadata = {
+        key: _json_safe(value)
+        for key, value in checkpoint.items()
+        if key != "state_dict"
+    }
+
+    output_safetensors.parent.mkdir(parents=True, exist_ok=True)
+    output_metadata.parent.mkdir(parents=True, exist_ok=True)
+
+    save_file(tensors, str(output_safetensors))
+
+    with output_metadata.open("w", encoding="utf-8") as f:
+        json.dump(metadata, f, indent=2, sort_keys=True)
+
+    print(f"Saved SafeTensors file: {output_safetensors}")
+    print(f"Saved metadata file: {output_metadata}")
+    print(f"Tensor count: {len(tensors)}")
+    print(f"Metadata keys: {sorted(metadata)}")
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Convert a TabPFN .ckpt file to .safetensors plus metadata JSON."
+    )
+    parser.add_argument("--input-checkpoint", required=True, type=Path)
+    parser.add_argument("--output-safetensors", required=True, type=Path)
+    parser.add_argument("--output-metadata", required=True, type=Path)
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    convert_checkpoint(
+        input_checkpoint=args.input_checkpoint,
+        output_safetensors=args.output_safetensors,
+        output_metadata=args.output_metadata,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/tabpfn/model_loading.py b/src/tabpfn/model_loading.py
index 36a7e71c1..dcb9d9921 100644
--- a/src/tabpfn/model_loading.py
+++ b/src/tabpfn/model_loading.py
@@ -905,14 +905,37 @@ def get_loss_criterion(
     return FullSupportBarDistribution(borders, ignore_nan_targets=True)
 
 
-def _file_identity(path: str) -> tuple[int, int]:
-    """Return a cheap identity tuple (mtime_ns, size) for cache-keying."""
-    st = Path(path).stat()
-    return (st.st_mtime_ns, st.st_size)
+def _file_identity(path: str) -> tuple[int, int, int | None, int | None]:
+    """Return a cheap identity tuple for cache-keying.
+
+    For SafeTensors checkpoints, include the sidecar metadata file identity so
+    cached loads are invalidated when either the tensor file or metadata changes.
+    """
+    checkpoint_path = Path(path)
+    checkpoint_stat = checkpoint_path.stat()
+
+    metadata_mtime_ns = None
+    metadata_size = None
+
+    if checkpoint_path.suffix == ".safetensors":
+        metadata_path = checkpoint_path.with_suffix(".non_tensor_metadata.json")
+        if metadata_path.exists():
+            metadata_stat = metadata_path.stat()
+            metadata_mtime_ns = metadata_stat.st_mtime_ns
+            metadata_size = metadata_stat.st_size
+
+    return (
+        checkpoint_stat.st_mtime_ns,
+        checkpoint_stat.st_size,
+        metadata_mtime_ns,
+        metadata_size,
+    )
 
 
 @functools.lru_cache(maxsize=1)
-def _load_checkpoint_cached(path: str, _identity: tuple[int, int]) -> dict:
+def _load_checkpoint_cached(
+    path: str, _identity: tuple[int, int, int | None, int | None]
+) -> dict:
     """Load and cache a checkpoint from disk.
 
     The ``_identity`` key ensures the cache is invalidated when the file at
@@ -924,6 +947,13 @@ def _load_checkpoint_cached(path: str, _identity: tuple[int, int]) -> dict:
 
 def _load_checkpoint(path: str) -> dict:
     """Load a checkpoint from disk."""
+    checkpoint_path = Path(path)
+
+    if checkpoint_path.suffix == ".safetensors":
+        from tabpfn.safetensors_checkpoint import load_safetensors_checkpoint
+
+        return load_safetensors_checkpoint(checkpoint_path)
+
     # Catch the `FutureWarning` that torch raises. This should be dealt with!
     # The warning is raised due to `torch.load`, which advises against ckpt
     # files that contain non-tensor data.
diff --git a/src/tabpfn/safetensors_checkpoint.py b/src/tabpfn/safetensors_checkpoint.py
new file mode 100644
index 000000000..fd62da3ef
--- /dev/null
+++ b/src/tabpfn/safetensors_checkpoint.py
@@ -0,0 +1,47 @@
+"""Utilities for loading TabPFN checkpoints stored as SafeTensors plus metadata."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+from safetensors.torch import load_file
+
+
+def _metadata_path_for_safetensors(path: Path) -> Path:
+    """Return the expected sidecar metadata path for a SafeTensors checkpoint."""
+    return path.with_suffix(".non_tensor_metadata.json")
+
+
+def load_safetensors_checkpoint(path: str | Path) -> dict[str, Any]:
+    """Load a TabPFN checkpoint from SafeTensors plus sidecar JSON metadata.
+
+    The SafeTensors file stores tensor values. The sidecar JSON file stores
+    non-tensor checkpoint metadata such as architecture name, model config,
+    and inference config.
+
+    Args:
+        path: Path to the ``.safetensors`` file.
+
+    Returns:
+        A checkpoint-like dictionary compatible with TabPFN model loading.
+    """
+    safetensors_path = Path(path)
+    metadata_path = _metadata_path_for_safetensors(safetensors_path)
+
+    if not metadata_path.exists():
+        raise FileNotFoundError(
+            "SafeTensors checkpoint metadata file not found. "
+            f"Expected sidecar file: {metadata_path}"
+        )
+
+    tensors = load_file(str(safetensors_path), device="cpu")
+
+    with metadata_path.open("r", encoding="utf-8") as f:
+        metadata = json.load(f)
+
+    checkpoint: dict[str, Any] = dict(metadata)
+    checkpoint["state_dict"] = tensors
+
+    return checkpoint