Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions fastsafetensors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
nelements = 1
for sh in t.shape:
nelements *= sh
nbytes = nelements * framework.get_dtype_size(t.dtype)
nbytes = int(nelements * framework.get_dtype_size(t.dtype))
if (e - s) != nbytes:
raise Exception(
f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}"
Expand Down Expand Up @@ -195,16 +195,25 @@ def get_tensors(
- copy_start_offset
)
disk_dtype = self.framework.as_workaround_dtype(t.dtype)
dl_shape, dl_strides = self.framework.get_storage_shape(
t.dtype, t.shape, t.strides
)
dl_tensor = from_cuda_buffer(
dst_dev_ptr,
t.shape,
t.strides,
dl_shape,
dl_strides,
disk_dtype,
device,
)
t2 = self.framework.from_dlpack(dl_tensor, device, disk_dtype)
if disk_dtype != t.dtype:
t2 = t2.view(t.dtype)
# For packed sub-byte dtypes, reshape to the framework-native shape.
# e.g. F4 (float4_e2m1fn_x2): safetensors shape counts FP4 values,
# but PyTorch shape counts packed pairs (2 FP4 per byte).
native_shape = self.framework.get_native_shape(t.dtype, t.shape)
if native_shape != t.shape:
t2 = t2.reshape(native_shape)

if dtype != DType.AUTO and dtype != t.dtype:
if self.framework.get_dtype_size(dtype) > self.framework.get_dtype_size(
Expand Down
8 changes: 8 additions & 0 deletions fastsafetensors/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ def __init__(self, dtype: DType):
DType.F32: (kDLFloat, 32, 1),
DType.F64: (kDLFloat, 64, 1),
DType.BF16: (kDLBfloat, 16, 1),
# F8_E8M0 is an unsigned 8-bit scale format (torch.float8_e8m0fnu).
# DLPack has no float8 type code, so we expose it as opaque uint8
# bytes — consistent with the U8 workaround used for NCCL ops.
DType.F8_E8M0: (kDLUInt, 8, 1),
# F4 is packed FP4 (torch.float4_e2m1fn_x2): two 4-bit values per byte.
# DLPack has no sub-byte type code, so we expose it as opaque uint8
# bytes — consistent with the U8 workaround used for NCCL ops.
DType.F4: (kDLUInt, 8, 1),
}


Expand Down
41 changes: 40 additions & 1 deletion fastsafetensors/frameworks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ def view(self, dtype: DType) -> "TensorBase":
def __getitem__(self, _val) -> "TensorBase":
pass

def reshape(self, shape: List[int]) -> "TensorBase":
    """Return this tensor reshaped to *shape*.

    The base class offers no generic reshape; frameworks that support
    packed sub-byte dtypes (e.g. F4) override this so the shape can be
    adjusted after a buffer has been loaded.
    """
    msg = f"{type(self).__name__} does not implement reshape()"
    raise NotImplementedError(msg)


T = TypeVar("T", bound=TensorBase)

Expand Down Expand Up @@ -127,7 +136,14 @@ def copy_tensor(self, dst: T, src: T) -> None:
pass

@abstractmethod
def get_dtype_size(self, dtype: DType) -> float:
    """Return the number of bytes per logical element for this dtype.

    For packed sub-byte types (e.g. F4 / float4_e2m1fn_x2, where two
    FP4 values share a single byte), this returns a fractional value
    (0.5 for F4). Callers should use ``int(nelements *
    get_dtype_size(dtype))`` when computing byte counts.
    """
    pass

@abstractmethod
Expand All @@ -146,6 +162,29 @@ def get_device_ptr_align(self) -> int:
def as_workaround_dtype(self, dtype: DType) -> DType:
pass

def get_storage_shape(
    self, dtype: DType, shape: List[int], strides: List[int]
) -> "tuple[List[int], List[int]]":
    """Shape/strides for the workaround-dtype DLPack tensor built over a
    device buffer.

    The default is a pass-through of ``(shape, strides)``. Frameworks with
    packed sub-byte dtypes (e.g. F4 / float4_e2m1fn_x2, multiple logical
    values per byte) override this to collapse the shape to a flat byte
    count, because safetensors shapes count logical values while DLPack /
    PyTorch address whole bytes.
    """
    return shape, strides

def get_native_shape(self, dtype: DType, st_shape: List[int]) -> List[int]:
    """Map a safetensors tensor shape to the framework-native shape.

    The default is the identity: for ordinary dtypes both shapes agree.
    Frameworks supporting packed sub-byte dtypes override this to compress
    the shape by the packing ratio, since safetensors counts logical
    (sub-byte) elements while the framework counts packed storage units
    (bytes).
    """
    return st_shape

@abstractmethod
def get_process_group(self, pg: Optional[Any]) -> ProcessGroupBase:
pass
Expand Down
49 changes: 45 additions & 4 deletions fastsafetensors/frameworks/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@
# DTypes that cannot be loaded directly through the DLPack path; buffers are
# exposed as the mapped integer dtype and view()-ed back to the real dtype
# after loading (see common.py get_tensors / dlpack.py type-code comments).
need_workaround_dtypes: Dict[DType, DType] = {
    DType.F8_E5M2: DType.I8,
    DType.F8_E4M3: DType.I8,
    # F8_E8M0 (exponent-only float8) is carried as opaque unsigned bytes.
    DType.F8_E8M0: DType.U8,
    # F4 (packed FP4, two 4-bit values per byte) is carried as opaque bytes.
    DType.F4: DType.U8,
}

if hasattr(torch, "float8_e5m2"):
dtype_convert[DType.F8_E5M2] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_convert[DType.F8_E4M3] = torch.float8_e4m3fn
if hasattr(torch, "float8_e8m0fnu"):
dtype_convert[DType.F8_E8M0] = torch.float8_e8m0fnu
if hasattr(torch, "float4_e2m1fn_x2"):
dtype_convert[DType.F4] = torch.float4_e2m1fn_x2
if hasattr(torch, "uint16"):
dtype_convert[DType.U16] = torch.uint16
if hasattr(torch, "uint32"):
Expand Down Expand Up @@ -87,6 +93,9 @@ def view(self, dtype: DType) -> "TorchTensor":
def __getitem__(self, _val) -> "TorchTensor":
    # Indexing/slicing delegates to the wrapped torch tensor and re-wraps
    # the result, preserving this wrapper's device and dtype tags.
    return TorchTensor(self.device, self.dtype, self.real_tensor[_val])

def reshape(self, shape: List[int]) -> "TorchTensor":
    """Reshape the wrapped torch tensor, keeping the device/dtype tags."""
    reshaped = self.real_tensor.reshape(shape)
    return TorchTensor(self.device, self.dtype, reshaped)


def _needs_fp8_cast() -> bool:
"""Check if FP8 NCCL ops need a bf16 workaround (pre-sm90 GPUs)."""
Expand All @@ -97,7 +106,7 @@ def _needs_fp8_cast() -> bool:


def _is_fp8(dtype: DType) -> bool:
    """Return True for float8 dtypes (E5M2, E4M3, E8M0), which are the
    dtypes checked by the NCCL FP8 workaround paths."""
    return dtype in (DType.F8_E5M2, DType.F8_E4M3, DType.F8_E8M0)


@dataclass
Expand Down Expand Up @@ -211,17 +220,23 @@ def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device):
def get_empty_tensor(
    self, shape: List[int], dtype: DType, device: Device
) -> TorchTensor:
    """Allocate an uninitialized tensor on *device*.

    *shape* is the safetensors (logical-element) shape; it is first mapped
    through get_native_shape() because packed sub-byte dtypes (e.g. F4)
    use a smaller framework-native shape than the safetensors one.
    """
    native_shape = self.get_native_shape(dtype, shape)
    dst = torch.empty(
        size=native_shape, dtype=dtype_convert[dtype], device=device.as_str()
    )
    return TorchTensor(device, dtype, dst)

def concat_tensors(self, tensors: List[TorchTensor], dim: int) -> TorchTensor:
    """Concatenate wrapped tensors along *dim*; the result carries the
    device and dtype tags of the first input."""
    raw = torch.cat([wrapped.real_tensor for wrapped in tensors], dim=dim)
    first = tensors[0]
    return TorchTensor(first.device, first.dtype, raw)

def get_dtype_size(self, dtype: DType) -> float:
    """Return the byte size of one logical element of *dtype*.

    Returns a fractional value for packed sub-byte dtypes; callers must
    compute byte counts as ``int(nelements * get_dtype_size(dtype))``.
    """
    if dtype == DType.F4:
        # float4_e2m1fn_x2 packs two 4-bit values into one byte.
        # safetensors stores shape in FP4-element count, so the byte
        # size per logical element is 0.5.
        return 0.5
    return float(dtype_convert[dtype].itemsize)

def from_dlpack(self, dl_tensor: Any, device: Device, dtype: DType) -> TorchTensor:
t = torch.from_dlpack(dl_tensor)
Expand Down Expand Up @@ -253,6 +268,32 @@ def as_workaround_dtype(self, dtype: DType) -> DType:
return need_workaround_dtypes[dtype]
return dtype

def get_storage_shape(
    self, dtype: DType, shape: List[int], strides: List[int]
) -> "tuple[List[int], List[int]]":
    """Return the (shape, strides) for the workaround-dtype DLPack tensor.

    Ordinary dtypes pass straight through. Packed sub-byte dtypes are
    collapsed to a flat, contiguous byte-count shape so the DLPack tensor
    built over the device buffer does not overread it.
    """
    elem_size = self.get_dtype_size(dtype)
    if elem_size >= 1.0:
        return shape, strides
    # Packed sub-byte dtype (fractional element size, e.g. 0.5 for F4).
    nelements = 1
    for dim in shape:
        nelements *= dim
    return [int(nelements * elem_size)], [1]

def get_native_shape(self, dtype: DType, st_shape: List[int]) -> List[int]:
    """Return the framework-native tensor shape for a safetensors tensor.

    For most dtypes this matches *st_shape* exactly. For packed sub-byte
    dtypes (e.g. F4, two FP4 values per byte) the safetensors shape counts
    logical elements while PyTorch counts packed storage units (bytes), so
    the last dimension is compressed by the packing ratio.

    Raises:
        Exception: if the safetensors shape is not divisible by the packing
            ratio (malformed file for a packed dtype).
    """
    size = self.get_dtype_size(dtype)
    if size < 1.0:
        # safetensors counts logical sub-byte elements; PyTorch counts
        # packed storage units (bytes). Compress the last dimension.
        import math

        ratio = int(round(1.0 / size))  # e.g. 2 for F4 (2 FP4 per byte)
        if len(st_shape) > 1:
            if st_shape[-1] % ratio != 0:
                raise Exception(
                    f"get_native_shape: malformed safetensors shape {st_shape} "
                    f"for packed dtype {dtype}: last dimension must be a "
                    f"multiple of {ratio}"
                )
            return list(st_shape[:-1]) + [st_shape[-1] // ratio]
        nelements = math.prod(st_shape)
        if nelements % ratio != 0:
            raise Exception(
                f"get_native_shape: malformed safetensors shape {st_shape} "
                f"for packed dtype {dtype}: element count must be a "
                f"multiple of {ratio}"
            )
        return [int(nelements * size)]
    return st_shape

def get_process_group(self, pg: Optional[Any]) -> TorchProcessGroup:
if pg is not None:
if isinstance(pg, SingleGroup):
Expand Down
2 changes: 2 additions & 0 deletions fastsafetensors/st_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,6 @@ class DType(Enum):
BF16 = "BF16"
F8_E5M2 = "F8_E5M2"
F8_E4M3 = "F8_E4M3"
F8_E8M0 = "F8_E8M0"
F4 = "F4"
AUTO = "AUTO"
59 changes: 59 additions & 0 deletions tests/test_fastsafetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,65 @@ def test_float8_e4m3fn_to_int8(fstcpp_log, tmp_dir, framework) -> None:
_test_type(tmp_dir, DType.F8_E4M3, device, framework, DType.I8)


def test_float4_e2m1fn_x2(fstcpp_log, tmp_dir, framework) -> None:
    """Test bit-exact round-trip for F4 (torch.float4_e2m1fn_x2).

    F4 is a packed FP4 format (two 4-bit values per byte, dtype string "F4" in
    safetensors) used for expert weight matrices in models like DeepSeek V4-Flash.

    The safetensors shape is in FP4-element count while torch.float4_e2m1fn_x2
    counts packed byte-pairs, so the PyTorch shape has half as many elements in
    the last dimension. fastsafetensors handles the shape conversion internally.
    """
    if framework.get_name() != "pytorch":
        pytest.skip("F4 is only available in PyTorch")
        return
    import torch

    if not hasattr(torch, "float4_e2m1fn_x2"):
        pytest.skip("torch.float4_e2m1fn_x2 requires PyTorch 2.10+")
        return
    device, _ = get_and_check_device(framework)
    filename = os.path.join(tmp_dir, "f4.safetensors")
    # F4 tensors cannot be created via randn/cast; create via uint8 view.
    # Shape [8, 16] in FP4-element terms = shape [8, 8] in float4_e2m1fn_x2.
    u8 = torch.arange(64, dtype=torch.uint8, device=device.as_str()).reshape(8, 8)
    t0 = u8.view(torch.float4_e2m1fn_x2)
    save_safetensors_file({"a": t0}, filename, {"fst": "sample"}, framework)
    t_ref = load_safetensors_file(filename, device, framework)
    with fastsafe_open(
        filenames=[filename],
        nogds=True,
        device=device.as_str(),
        framework=framework.get_name(),
        debug_log=True,
    ) as f:
        for key in f.keys():
            t1 = f.get_tensor_wrapped(key).clone().detach()
            assert framework.is_equal(t1, t_ref[key])
    assert framework.get_mem_used() == 0
    assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0


def test_float8_e8m0fnu(fstcpp_log, tmp_dir, framework) -> None:
    """Bit-exact round-trip test for F8_E8M0 (torch.float8_e8m0fnu).

    F8_E8M0 is an unsigned 8-bit exponent-only format used as per-tile
    quantization scales in models like DeepSeek V4-Flash. It has no mantissa
    bits, so an ordinary randn -> cast is safe for creating test tensors.
    """
    if framework.get_name() != "pytorch":
        pytest.skip("F8_E8M0 is only available in PyTorch")
        return
    import torch

    if hasattr(torch, "float8_e8m0fnu"):
        device, _ = get_and_check_device(framework)
        _test_type(tmp_dir, DType.F8_E8M0, device, framework)
    else:
        pytest.skip("torch.float8_e8m0fnu requires PyTorch 2.5+")


def test_cpp_metrics(fstcpp_log, framework) -> None:
device, _ = get_and_check_device(framework)
exp_length = 0
Expand Down
Loading