Add F4 (float4_e2m1fn_x2) and F8_E8M0 (float8_e8m0fnu) dtype support #69
Open
uasind wants to merge 1 commit into foundation-model-stack:main
Conversation
…ypes
Models like DeepSeek V4-Flash use two additional dtypes that were missing
from fastsafetensors's DType enum:
- F8_E8M0 (torch.float8_e8m0fnu, PyTorch 2.5+): unsigned 8-bit exponent-only
format used for per-tile quantization scales. One byte per element.
- F4 (torch.float4_e2m1fn_x2, PyTorch 2.10+): packed FP4 format, two 4-bit
values per byte. safetensors stores the shape in FP4-element count while
PyTorch float4_e2m1fn_x2 counts packed pairs, so the byte size per logical
element is 0.5 (see the worked example below).
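To make the packing arithmetic concrete, here is a small worked example (illustrative numbers only, not code from this commit; the same shape reappears in the PR description below):

```python
nelements = 2048 * 4096            # safetensors counts individual FP4 values
nbytes = int(nelements * 0.5)      # 0.5 bytes per logical F4 element
assert nbytes == 4 * 1024 * 1024   # 4 MiB on disk: two FP4 values per byte
```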
Without these variants, add_filenames() raises ValueError on any safetensors
file that contains F4 or F8_E8M0 tensors, making fastsafetensors unusable with
those models.
Changes:
- st_types.py: add F4 and F8_E8M0 to DType enum
- frameworks/_torch.py:
* map DType.F8_E8M0 -> torch.float8_e8m0fnu (guarded by hasattr)
* map DType.F4 -> torch.float4_e2m1fn_x2 (guarded by hasattr)
* add U8 workaround for both (no NCCL support)
* get_dtype_size() returns 0.5 for F4 (two FP4 values per byte)
* get_storage_shape() collapses packed sub-byte shapes to flat byte count
so the DLPack tensor does not overread the buffer
* get_native_shape() restores the PyTorch-native shape after DLPack import
(both shape hooks are sketched after this list)
* get_empty_tensor() uses get_native_shape() so shape is correct for alloc
* TorchTensor.reshape() added to support shape adjustment in common.py
- frameworks/__init__.py:
* get_dtype_size() return type changed from int to float
* get_storage_shape() added (default: identity)
* get_native_shape() added (default: identity)
* TensorBase.reshape() added (default: raises NotImplementedError)
- dlpack.py: add DLPack mappings for F4 and F8_E8M0 as opaque uint8 bytes
- common.py:
* validation: nbytes = int(nelements * get_dtype_size()) for sub-byte safety
* get_tensors: use get_storage_shape() and get_native_shape() for packed types
- tests: add test_float4_e2m1fn_x2 and test_float8_e8m0fnu
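A minimal sketch of how the two shape hooks behave for F4 (the _f4 suffix and the function bodies are illustrative, not this commit's exact code):

```python
import math

def get_storage_shape_f4(st_shape):
    # Collapse to a flat byte count so the uint8 DLPack view cannot
    # overread the buffer: two FP4 values per byte.
    return [math.prod(st_shape) // 2]

def get_native_shape_f4(st_shape):
    # torch.float4_e2m1fn_x2 counts packed pairs, so halve the last dim.
    return list(st_shape[:-1]) + [st_shape[-1] // 2]

# e.g. [2048, 4096] -> storage [4194304], native [2048, 2048]
```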
Validated against DeepSeek V4-Flash MP=2 safetensors shards (~82 GiB each).
All six dtypes (BF16, F32, F4, F8_E4M3, F8_E8M0, I64) load bit-exactly.
takeshi-yoshimura (Collaborator) requested changes on May 3, 2026

@uasind Can you please resolve the lint issues, DCO, and my comments? Thanks!
    ratio = int(round(1.0 / size))  # e.g. 2 for F4 (2 FP4 per byte)
    if len(st_shape) > 1:
        return list(st_shape[:-1]) + [st_shape[-1] // ratio]
takeshi-yoshimura (Collaborator): Can you check st_shape[-1] % ratio == 0 and raise an Exception on malformed safetensors files?
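One way to implement the requested check, sketched against the snippet above (the exception type and message wording are assumptions):

```python
ratio = int(round(1.0 / size))  # e.g. 2 for F4 (2 FP4 per byte)
if len(st_shape) > 1:
    if st_shape[-1] % ratio != 0:
        # Malformed file: the packed dimension must divide evenly.
        raise ValueError(
            f"last dim {st_shape[-1]} is not divisible by packing ratio "
            f"{ratio}; malformed safetensors file?"
        )
    return list(st_shape[:-1]) + [st_shape[-1] // ratio]
```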
    import torch

    if not hasattr(torch, "float4_e2m1fn_x2"):
        pytest.skip("torch.float4_e2m1fn_x2 requires PyTorch 2.10+")
takeshi-yoshimura (Collaborator): Our CI does not include PyTorch 2.10 currently. Can you please add a new workflow to run only the new tests by modifying .github/workflows/test-torch.yaml?
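Until the CI image includes PyTorch 2.10, the new tests can also self-skip through a module-level marker; a sketch (the marker and test names are illustrative):

```python
import pytest
import torch

# Collected everywhere, but only runs where the dtype exists, so a CI
# image without PyTorch 2.10 passes trivially.
requires_fp4 = pytest.mark.skipif(
    not hasattr(torch, "float4_e2m1fn_x2"),
    reason="torch.float4_e2m1fn_x2 requires PyTorch 2.10+",
)

@requires_fp4
def test_float4_e2m1fn_x2_smoke():
    assert torch.float4_e2m1fn_x2 is not None
```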
Summary
Adds support for two dtypes used by modern mixed-precision models (DeepSeek V4-Flash, DeepSeek V3, etc.) that currently cause add_filenames() to fail:

- F8_E8M0 (torch.float8_e8m0fnu, PyTorch 2.5+): unsigned 8-bit exponent-only format used for per-tile quantization scales. One byte per element. Straightforward addition — same storage size as F8_E4M3.
- F4 (torch.float4_e2m1fn_x2, PyTorch 2.10+): packed FP4 format, two 4-bit values per byte. safetensors stores the shape in FP4-element count while PyTorch float4_e2m1fn_x2 counts packed pairs (one byte each). This requires shape adjustment on load.

Without these, any safetensors file containing F4 or F8_E8M0 tensors raises a ValueError from add_filenames().
Changes
- fastsafetensors/st_types.py: add F4 and F8_E8M0 to the DType enum
- fastsafetensors/frameworks/_torch.py:
  - map DType.F8_E8M0 → torch.float8_e8m0fnu (guarded by hasattr)
  - map DType.F4 → torch.float4_e2m1fn_x2 (guarded by hasattr)
  - U8 NCCL workaround for both (NCCL has no float8_e8m0/float4 support)
  - get_dtype_size() returns 0.5 for F4 (two FP4 values per byte)
  - get_storage_shape(): collapses packed sub-byte shapes to a flat byte count for DLPack so the workaround-dtype view doesn't overread the buffer
  - get_native_shape(): converts the safetensors FP4-element shape to the PyTorch packed-pair shape
  - get_empty_tensor(): uses get_native_shape() for the correct allocation shape
  - TorchTensor.reshape(): needed by common.py after DLPack import
- fastsafetensors/frameworks/__init__.py:
  - get_dtype_size() return type: int → float (to accommodate 0.5 for F4)
  - get_storage_shape() added (default: identity — no change for existing dtypes)
  - get_native_shape() added (default: identity; both defaults are sketched after this list)
  - TensorBase.reshape() added (default: NotImplementedError)
- fastsafetensors/dlpack.py: map F4 and F8_E8M0 to (kDLUInt, 8, 1) — consistent with their U8 workaround
- fastsafetensors/common.py:
  - validation: nbytes = int(nelements * get_dtype_size()) — safe for fractional sizes
  - get_tensors(): uses get_storage_shape() / get_native_shape() for packed dtypes
- tests/test_fastsafetensors.py:
  - test_float8_e8m0fnu: bit-exact round-trip via _test_type
  - test_float4_e2m1fn_x2: bit-exact round-trip (uses a uint8 view, since randn/cast don't support float4)
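For reference, a minimal sketch of the identity defaults added to FrameworkOpBase (the method signatures here are assumed for illustration; only the class and method names follow the PR description):

```python
class FrameworkOpBase:
    def get_storage_shape(self, st_shape, dtype):
        # Default: identity, so existing dtypes and non-PyTorch frameworks
        # see no behavior change. The torch backend overrides this for F4
        # to return a flat byte count for the DLPack view.
        return st_shape

    def get_native_shape(self, st_shape, dtype):
        # Default: identity. The torch override restores the packed-pair
        # shape after the DLPack import.
        return st_shape
```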
F4 shape handling detail

safetensors stores an F4 weight of logical shape [2048, 4096] (8M FP4 values) as 4 MiB (0.5 bytes per FP4 value). torch.float4_e2m1fn_x2 represents paired values: [2048, 4096] in that dtype would be 8 MiB — wrong. The correct PyTorch shape is [2048, 2048] (4M packed pairs = 4 MiB). get_storage_shape() handles DLPack with a flat [4194304] uint8 shape (avoiding buffer overread); get_native_shape() then reshapes to [2048, 2048] after the view. get_empty_tensor() also uses get_native_shape() so broadcast/scatter allocations are correctly sized. The same numbers are worked through below.
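As a sketch (variable names are illustrative; the numbers are from the example above):

```python
st_shape = [2048, 4096]           # safetensors shape, in FP4 elements
nbytes = (2048 * 4096) // 2       # 4,194,304 bytes on disk (4 MiB)

storage_shape = [nbytes]          # flat uint8 shape handed to DLPack
native_shape = [2048, 4096 // 2]  # [2048, 2048]: packed-pair shape for
                                  # torch.float4_e2m1fn_x2, also 4 MiB
```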
Validation

Tested against DeepSeek V4-Flash MP=2 safetensors shards (~82 GiB each) containing all six dtypes (BF16, F32, F4, F8_E4M3, F8_E8M0, I64). With this PR:

- add_filenames() succeeds on V4-Flash shards
- tensors load bit-exactly, matching safetensors.torch.load_file (a minimal check in this spirit is sketched below)
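A minimal bit-exactness check could look like the following sketch. The loader calls (SafeTensorsFileLoader, SingleGroup, copy_files_to_device, get_tensor, close) follow the fastsafetensors README rather than this PR, the shard path is hypothetical, and the uint8 view is an assumption for dtypes without equality kernels:

```python
import torch
from safetensors.torch import load_file
from fastsafetensors import SafeTensorsFileLoader, SingleGroup

path = "model.safetensors"  # hypothetical shard path
loader = SafeTensorsFileLoader(SingleGroup(), "cpu")
loader.add_filenames({0: [path]})
fb = loader.copy_files_to_device()

for name, ref in load_file(path).items():
    # Compare raw bytes: packed / exponent-only dtypes may lack
    # equality kernels, so view both sides as uint8 first.
    assert torch.equal(fb.get_tensor(name).view(torch.uint8),
                       ref.view(torch.uint8))
loader.close()
```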
Notes

- get_storage_shape() and get_native_shape() default to identity on FrameworkOpBase, so non-PyTorch frameworks (e.g. Paddle) are unaffected unless they add their own F4/F8_E8M0 dtype mappings.
- The new torch dtype mappings are guarded by hasattr(), so older PyTorch versions still work for all currently-supported dtypes (pattern sketched below).
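The guard pattern, sketched (the mapping table below is an illustrative stand-in, not the actual structure in frameworks/_torch.py):

```python
import torch

# Register the new mappings only when the running PyTorch exposes the
# dtypes, so older versions keep working for all existing dtypes.
torch_dtype_map = {}
if hasattr(torch, "float8_e8m0fnu"):    # PyTorch 2.5+
    torch_dtype_map["F8_E8M0"] = torch.float8_e8m0fnu
if hasattr(torch, "float4_e2m1fn_x2"):  # PyTorch 2.10+
    torch_dtype_map["F4"] = torch.float4_e2m1fn_x2
```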