diff --git a/fastsafetensors/common.py b/fastsafetensors/common.py
index fb994e8..91c1f6a 100644
--- a/fastsafetensors/common.py
+++ b/fastsafetensors/common.py
@@ -100,7 +100,7 @@ def __init__(
             nelements = 1
             for sh in t.shape:
                 nelements *= sh
-            nbytes = nelements * framework.get_dtype_size(t.dtype)
+            nbytes = int(nelements * framework.get_dtype_size(t.dtype))
             if (e - s) != nbytes:
                 raise Exception(
                     f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}"
                 )
@@ -195,16 +195,25 @@ def get_tensors(
                         - copy_start_offset
                     )
                     disk_dtype = self.framework.as_workaround_dtype(t.dtype)
+                    dl_shape, dl_strides = self.framework.get_storage_shape(
+                        t.dtype, t.shape, t.strides
+                    )
                     dl_tensor = from_cuda_buffer(
                         dst_dev_ptr,
-                        t.shape,
-                        t.strides,
+                        dl_shape,
+                        dl_strides,
                         disk_dtype,
                         device,
                     )
                     t2 = self.framework.from_dlpack(dl_tensor, device, disk_dtype)
                     if disk_dtype != t.dtype:
                         t2 = t2.view(t.dtype)
+                    # For packed sub-byte dtypes, reshape to the framework-native shape,
+                    # e.g. F4 (float4_e2m1fn_x2): the safetensors shape counts FP4 values,
+                    # but the PyTorch shape counts packed pairs (2 FP4 per byte).
+                    native_shape = self.framework.get_native_shape(t.dtype, t.shape)
+                    if native_shape != t.shape:
+                        t2 = t2.reshape(native_shape)
                     if dtype != DType.AUTO and dtype != t.dtype:
                         if self.framework.get_dtype_size(dtype) > self.framework.get_dtype_size(
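
A standalone sketch (not part of the diff) of why the `int(...)` wrapper in the validation hunk matters once `get_dtype_size` can return a fractional byte size; `nbytes_for` is an illustrative helper, not a library function:

```python
# With F4's fractional per-element size (0.5 bytes), nelements * size is a
# float, but the byte count compared against the header offsets must be an
# exact integer.
def nbytes_for(shape, dtype_size):
    nelements = 1
    for sh in shape:  # same accumulation loop as validate() above
        nelements *= sh
    return int(nelements * dtype_size)

assert nbytes_for([8, 16], 0.5) == 64    # 128 FP4 values pack into 64 bytes
assert nbytes_for([8, 16], 2.0) == 256   # integer-sized dtypes are unaffected
```
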
+ """ pass @abstractmethod @@ -146,6 +162,29 @@ def get_device_ptr_align(self) -> int: def as_workaround_dtype(self, dtype: DType) -> DType: pass + def get_storage_shape( + self, dtype: DType, shape: List[int], strides: List[int] + ) -> "tuple[List[int], List[int]]": + """Return the (shape, strides) to use for the workaround-dtype DLPack + tensor when loading from a device buffer. + + For most dtypes this is just (shape, strides). Packed sub-byte dtypes + (e.g. F4 / float4_e2m1fn_x2) store multiple logical values per byte; + the safetensors shape counts logical values while DLPack / PyTorch work + in bytes, so the shape must be collapsed to a flat byte count. + """ + return shape, strides + + def get_native_shape(self, dtype: DType, st_shape: List[int]) -> List[int]: + """Return the framework-native tensor shape for a safetensors tensor. + + For most dtypes this matches *st_shape* exactly. For packed sub-byte + dtypes the safetensors shape counts logical (sub-byte) elements while + the framework counts packed storage units (bytes), so the shape is + compressed by the packing ratio. + """ + return st_shape + @abstractmethod def get_process_group(self, pg: Optional[Any]) -> ProcessGroupBase: pass diff --git a/fastsafetensors/frameworks/_torch.py b/fastsafetensors/frameworks/_torch.py index 1e22c70..00a61dc 100644 --- a/fastsafetensors/frameworks/_torch.py +++ b/fastsafetensors/frameworks/_torch.py @@ -29,12 +29,18 @@ need_workaround_dtypes: Dict[DType, DType] = { DType.F8_E5M2: DType.I8, DType.F8_E4M3: DType.I8, + DType.F8_E8M0: DType.U8, + DType.F4: DType.U8, } if hasattr(torch, "float8_e5m2"): dtype_convert[DType.F8_E5M2] = torch.float8_e5m2 if hasattr(torch, "float8_e4m3fn"): dtype_convert[DType.F8_E4M3] = torch.float8_e4m3fn +if hasattr(torch, "float8_e8m0fnu"): + dtype_convert[DType.F8_E8M0] = torch.float8_e8m0fnu +if hasattr(torch, "float4_e2m1fn_x2"): + dtype_convert[DType.F4] = torch.float4_e2m1fn_x2 if hasattr(torch, "uint16"): dtype_convert[DType.U16] = torch.uint16 if hasattr(torch, "uint32"): @@ -87,6 +93,9 @@ def view(self, dtype: DType) -> "TorchTensor": def __getitem__(self, _val) -> "TorchTensor": return TorchTensor(self.device, self.dtype, self.real_tensor[_val]) + def reshape(self, shape: List[int]) -> "TorchTensor": + return TorchTensor(self.device, self.dtype, self.real_tensor.reshape(shape)) + def _needs_fp8_cast() -> bool: """Check if FP8 NCCL ops need a bf16 workaround (pre-sm90 GPUs).""" @@ -97,7 +106,7 @@ def _needs_fp8_cast() -> bool: def _is_fp8(dtype: DType) -> bool: - return dtype in (DType.F8_E5M2, DType.F8_E4M3) + return dtype in (DType.F8_E5M2, DType.F8_E4M3, DType.F8_E8M0) @dataclass @@ -211,8 +220,9 @@ def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device): def get_empty_tensor( self, shape: List[int], dtype: DType, device: Device ) -> TorchTensor: + native_shape = self.get_native_shape(dtype, shape) dst = torch.empty( - size=shape, dtype=dtype_convert[dtype], device=device.as_str() + size=native_shape, dtype=dtype_convert[dtype], device=device.as_str() ) return TorchTensor(device, dtype, dst) @@ -220,8 +230,13 @@ def concat_tensors(self, tensors: List[TorchTensor], dim: int) -> TorchTensor: ts = [tensor.real_tensor for tensor in tensors] return TorchTensor(tensors[0].device, tensors[0].dtype, torch.cat(ts, dim=dim)) - def get_dtype_size(self, dtype: DType) -> int: - return dtype_convert[dtype].itemsize + def get_dtype_size(self, dtype: DType) -> float: + if dtype == DType.F4: + # float4_e2m1fn_x2 packs two 4-bit values into one byte. 
diff --git a/fastsafetensors/frameworks/_torch.py b/fastsafetensors/frameworks/_torch.py
index 1e22c70..00a61dc 100644
--- a/fastsafetensors/frameworks/_torch.py
+++ b/fastsafetensors/frameworks/_torch.py
@@ -29,12 +29,18 @@
 need_workaround_dtypes: Dict[DType, DType] = {
     DType.F8_E5M2: DType.I8,
     DType.F8_E4M3: DType.I8,
+    DType.F8_E8M0: DType.U8,
+    DType.F4: DType.U8,
 }
 
 if hasattr(torch, "float8_e5m2"):
     dtype_convert[DType.F8_E5M2] = torch.float8_e5m2
 if hasattr(torch, "float8_e4m3fn"):
     dtype_convert[DType.F8_E4M3] = torch.float8_e4m3fn
+if hasattr(torch, "float8_e8m0fnu"):
+    dtype_convert[DType.F8_E8M0] = torch.float8_e8m0fnu
+if hasattr(torch, "float4_e2m1fn_x2"):
+    dtype_convert[DType.F4] = torch.float4_e2m1fn_x2
 if hasattr(torch, "uint16"):
     dtype_convert[DType.U16] = torch.uint16
 if hasattr(torch, "uint32"):
@@ -87,6 +93,9 @@ def view(self, dtype: DType) -> "TorchTensor":
     def __getitem__(self, _val) -> "TorchTensor":
         return TorchTensor(self.device, self.dtype, self.real_tensor[_val])
 
+    def reshape(self, shape: List[int]) -> "TorchTensor":
+        return TorchTensor(self.device, self.dtype, self.real_tensor.reshape(shape))
+
 
 def _needs_fp8_cast() -> bool:
     """Check if FP8 NCCL ops need a bf16 workaround (pre-sm90 GPUs)."""
@@ -97,7 +106,7 @@ def _needs_fp8_cast() -> bool:
 
 
 def _is_fp8(dtype: DType) -> bool:
-    return dtype in (DType.F8_E5M2, DType.F8_E4M3)
+    return dtype in (DType.F8_E5M2, DType.F8_E4M3, DType.F8_E8M0)
 
 
 @dataclass
@@ -211,8 +220,9 @@ def free_tensor_memory(self, gbuf: gds_device_buffer, dev: Device):
     def get_empty_tensor(
         self, shape: List[int], dtype: DType, device: Device
     ) -> TorchTensor:
+        native_shape = self.get_native_shape(dtype, shape)
         dst = torch.empty(
-            size=shape, dtype=dtype_convert[dtype], device=device.as_str()
+            size=native_shape, dtype=dtype_convert[dtype], device=device.as_str()
         )
         return TorchTensor(device, dtype, dst)
 
@@ -220,8 +230,13 @@ def concat_tensors(self, tensors: List[TorchTensor], dim: int) -> TorchTensor:
         ts = [tensor.real_tensor for tensor in tensors]
         return TorchTensor(tensors[0].device, tensors[0].dtype, torch.cat(ts, dim=dim))
 
-    def get_dtype_size(self, dtype: DType) -> int:
-        return dtype_convert[dtype].itemsize
+    def get_dtype_size(self, dtype: DType) -> float:
+        if dtype == DType.F4:
+            # float4_e2m1fn_x2 packs two 4-bit values into one byte, and
+            # safetensors stores shape in FP4-element count, so the byte
+            # size per logical element is 0.5.
+            return 0.5
+        return float(dtype_convert[dtype].itemsize)
 
     def from_dlpack(self, dl_tensor: Any, device: Device, dtype: DType) -> TorchTensor:
         t = torch.from_dlpack(dl_tensor)
@@ -253,6 +268,32 @@ def as_workaround_dtype(self, dtype: DType) -> DType:
             return need_workaround_dtypes[dtype]
         return dtype
 
+    def get_storage_shape(
+        self, dtype: DType, shape: List[int], strides: List[int]
+    ) -> "tuple[List[int], List[int]]":
+        size = self.get_dtype_size(dtype)
+        if size < 1.0:
+            # Packed sub-byte dtype: collapse to a flat byte count so the
+            # workaround-dtype DLPack tensor doesn't overread the buffer.
+            import math
+
+            nbytes = int(math.prod(shape) * size)
+            return [nbytes], [1]
+        return shape, strides
+
+    def get_native_shape(self, dtype: DType, st_shape: List[int]) -> List[int]:
+        size = self.get_dtype_size(dtype)
+        if size < 1.0:
+            # safetensors counts logical sub-byte elements; PyTorch counts
+            # packed storage units (bytes). Compress the last dimension.
+            import math
+
+            ratio = int(round(1.0 / size))  # e.g. 2 for F4 (2 FP4 per byte)
+            if len(st_shape) > 1:
+                return list(st_shape[:-1]) + [st_shape[-1] // ratio]
+            return [int(math.prod(st_shape) * size)]
+        return st_shape
+
     def get_process_group(self, pg: Optional[Any]) -> TorchProcessGroup:
         if pg is not None:
             if isinstance(pg, SingleGroup):
diff --git a/fastsafetensors/st_types.py b/fastsafetensors/st_types.py
index 82ae74f..3a2dc6c 100644
--- a/fastsafetensors/st_types.py
+++ b/fastsafetensors/st_types.py
@@ -58,4 +58,6 @@ class DType(Enum):
     BF16 = "BF16"
     F8_E5M2 = "F8_E5M2"
     F8_E4M3 = "F8_E4M3"
+    F8_E8M0 = "F8_E8M0"
+    F4 = "F4"
     AUTO = "AUTO"
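
The load path relies on PyTorch's zero-copy dtype reinterpretation. A rough sketch of the round trip (only runs on PyTorch builds that define `torch.float4_e2m1fn_x2`; shapes follow the [8, 16] example used in the tests):

```python
import torch

# Raw bytes surface as uint8 via the DLPack U8 workaround, are viewed as the
# packed FP4 dtype (element size is still one byte, so the view is zero-copy
# and shape-preserving), then reshaped to the framework-native shape.
if hasattr(torch, "float4_e2m1fn_x2"):
    raw = torch.arange(64, dtype=torch.uint8)  # flat 64-byte buffer
    t = raw.view(torch.float4_e2m1fn_x2)       # 64 packed FP4 pairs
    t = t.reshape(8, 8)                        # native shape for [8, 16] FP4
    assert t.shape == (8, 8)
```
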
diff --git a/tests/test_fastsafetensors.py b/tests/test_fastsafetensors.py
index d5ff509..bef85a3 100644
--- a/tests/test_fastsafetensors.py
+++ b/tests/test_fastsafetensors.py
@@ -662,6 +662,65 @@ def test_float8_e4m3fn_to_int8(fstcpp_log, tmp_dir, framework) -> None:
         _test_type(tmp_dir, DType.F8_E4M3, device, framework, DType.I8)
 
 
+def test_float4_e2m1fn_x2(fstcpp_log, tmp_dir, framework) -> None:
+    """Test bit-exact round-trip for F4 (torch.float4_e2m1fn_x2).
+
+    F4 is a packed FP4 format (two 4-bit values per byte, dtype string "F4" in
+    safetensors) used for expert weight matrices in models like DeepSeek V4-Flash.
+
+    The safetensors shape is in FP4-element count while torch.float4_e2m1fn_x2
+    counts packed byte-pairs, so the PyTorch shape has half as many elements in
+    the last dimension. fastsafetensors handles the shape conversion internally.
+    """
+    if framework.get_name() != "pytorch":
+        pytest.skip("F4 is only available in PyTorch")
+        return
+    import torch
+
+    if not hasattr(torch, "float4_e2m1fn_x2"):
+        pytest.skip("torch.float4_e2m1fn_x2 requires PyTorch 2.10+")
+        return
+    device, _ = get_and_check_device(framework)
+    filename = os.path.join(tmp_dir, "f4.safetensors")
+    # F4 tensors cannot be created via randn/cast; create via a uint8 view.
+    # Shape [8, 16] in FP4-element terms = shape [8, 8] in float4_e2m1fn_x2.
+    u8 = torch.arange(64, dtype=torch.uint8, device=device.as_str()).reshape(8, 8)
+    t0 = u8.view(torch.float4_e2m1fn_x2)
+    save_safetensors_file({"a": t0}, filename, {"fst": "sample"}, framework)
+    t_ref = load_safetensors_file(filename, device, framework)
+    with fastsafe_open(
+        filenames=[filename],
+        nogds=True,
+        device=device.as_str(),
+        framework=framework.get_name(),
+        debug_log=True,
+    ) as f:
+        for key in f.keys():
+            t1 = f.get_tensor_wrapped(key).clone().detach()
+            assert framework.is_equal(t1, t_ref[key])
+    assert framework.get_mem_used() == 0
+    assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0
+
+
+def test_float8_e8m0fnu(fstcpp_log, tmp_dir, framework) -> None:
+    """Test bit-exact round-trip for F8_E8M0 (torch.float8_e8m0fnu).
+
+    F8_E8M0 is an unsigned 8-bit exponent-only format used as per-tile
+    quantization scales in models like DeepSeek V4-Flash. It has no mantissa
+    bits, so an ordinary randn -> cast is safe for creating test tensors.
+    """
+    if framework.get_name() != "pytorch":
+        pytest.skip("F8_E8M0 is only available in PyTorch")
+        return
+    import torch
+
+    if not hasattr(torch, "float8_e8m0fnu"):
+        pytest.skip("torch.float8_e8m0fnu requires PyTorch 2.5+")
+        return
+    device, _ = get_and_check_device(framework)
+    _test_type(tmp_dir, DType.F8_E8M0, device, framework)
+
+
 def test_cpp_metrics(fstcpp_log, framework) -> None:
     device, _ = get_and_check_device(framework)
     exp_length = 0
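
End-to-end, loading mirrors the tests above. A usage sketch (the checkpoint path is hypothetical; `fastsafe_open` and `get_tensor_wrapped` are used exactly as in the tests):

```python
from fastsafetensors import fastsafe_open

# "model-fp4.safetensors" stands in for a checkpoint containing F4 weights
# and F8_E8M0 scales. With this change, both dtypes load like any other;
# F4 tensors come back with the framework-native packed shape (last
# dimension halved relative to the safetensors header).
with fastsafe_open(
    filenames=["model-fp4.safetensors"],
    nogds=True,
    device="cpu",
    framework="pytorch",
) as f:
    tensors = {key: f.get_tensor_wrapped(key) for key in f.keys()}
```
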