Add F4 (float4_e2m1fn_x2) and F8_E8M0 (float8_e8m0fnu) dtype support #69

Open
uasind wants to merge 1 commit into foundation-model-stack:main from uasind:add-f4-and-f8-e8m0-dtypes

Conversation


@uasind uasind commented Apr 26, 2026

Summary

Adds support for two dtypes used by modern mixed-precision models (DeepSeek V4-Flash, DeepSeek V3, etc.) that currently cause add_filenames() to fail:

  • F8_E8M0 (torch.float8_e8m0fnu, PyTorch 2.5+): unsigned 8-bit exponent-only format used for per-tile quantization scales. One byte per element. Straightforward addition — same storage size as F8_E4M3.

  • F4 (torch.float4_e2m1fn_x2, PyTorch 2.10+): packed FP4 format, two 4-bit values per byte. safetensors stores the shape in FP4-element count while PyTorch float4_e2m1fn_x2 counts packed pairs (one byte each). This requires shape adjustment on load.

Without these, any safetensors file containing F4 or F8_E8M0 tensors raises:

ValueError: 'F8_E8M0' is not a valid DType
ValueError: 'F4' is not a valid DType
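
This is ordinary Enum value lookup failing; a self-contained sketch of the failure mode (the enum below is an illustrative stand-in, not the library's actual DType definition):

    from enum import Enum

    class DType(Enum):          # stand-in for fastsafetensors.st_types.DType
        F32 = "F32"
        BF16 = "BF16"
        F8_E4M3 = "F8_E4M3"     # F4 and F8_E8M0 are absent before this PR

    DType("F4")                 # raises ValueError: 'F4' is not a valid DType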

Changes

fastsafetensors/st_types.py

  • Add F4 and F8_E8M0 to DType enum

fastsafetensors/frameworks/_torch.py

  • Map DType.F8_E8M0 -> torch.float8_e8m0fnu (guarded by hasattr)
  • Map DType.F4 -> torch.float4_e2m1fn_x2 (guarded by hasattr; see the sketch after this list)
  • Add U8 NCCL workaround for both (NCCL has no float8_e8m0/float4 support)
  • get_dtype_size() returns 0.5 for F4 (two FP4 values per byte)
  • get_storage_shape(): collapses packed sub-byte shapes to flat byte count for DLPack so the workaround-dtype view doesn't overread the buffer
  • get_native_shape(): converts safetensors FP4-element shape to PyTorch packed-pair shape
  • get_empty_tensor(): uses get_native_shape() for correct allocation shape
  • TorchTensor.reshape(): needed by common.py after DLPack import
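
A minimal sketch of the guarded mappings, assuming illustrative table names (the module's real identifiers may differ; the import path follows this PR's file layout):

    import torch
    from fastsafetensors.st_types import DType

    dtype_map = {}               # DType -> torch.dtype (illustrative name)
    nccl_u8_workaround = set()   # dtypes NCCL cannot transfer natively

    if hasattr(torch, "float8_e8m0fnu"):      # PyTorch >= 2.5
        dtype_map[DType.F8_E8M0] = torch.float8_e8m0fnu
        nccl_u8_workaround.add(DType.F8_E8M0)
    if hasattr(torch, "float4_e2m1fn_x2"):    # PyTorch >= 2.10
        dtype_map[DType.F4] = torch.float4_e2m1fn_x2
        nccl_u8_workaround.add(DType.F4)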

fastsafetensors/frameworks/__init__.py

  • get_dtype_size() return type: int -> float (to accommodate 0.5 for F4)
  • get_storage_shape() added (default: identity — no change for existing dtypes)
  • get_native_shape() added (default: identity — no change for existing dtypes)
  • TensorBase.reshape() added (default: raises NotImplementedError); a sketch of these defaults follows this list
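
Sketch of those defaults (class and method names follow the PR description; the exact signatures are assumed):

    from typing import List

    class FrameworkOpBase:
        def get_storage_shape(self, dtype, shape: List[int]) -> List[int]:
            return list(shape)  # identity: existing dtypes unaffected

        def get_native_shape(self, dtype, shape: List[int]) -> List[int]:
            return list(shape)  # identity

    class TensorBase:
        def reshape(self, shape: List[int]) -> "TensorBase":
            raise NotImplementedError  # frameworks opt in per dtype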

fastsafetensors/dlpack.py

  • F4 and F8_E8M0: mapped as opaque (kDLUInt, 8, 1) — consistent with their U8 workaround

fastsafetensors/common.py

  • Validation: nbytes = int(nelements * get_dtype_size()), which stays exact for fractional element sizes such as F4's 0.5 (worked example after this list)
  • get_tensors(): uses get_storage_shape() / get_native_shape() for packed dtypes
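
Worked example of the validation arithmetic (hypothetical local variables; per this PR, get_dtype_size(DType.F4) returns 0.5):

    nelements = 2048 * 4096          # 8,388,608 logical FP4 values
    nbytes = int(nelements * 0.5)    # 4,194,304 bytes = 4 MiB, exact for even counts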

tests/test_fastsafetensors.py

  • test_float8_e8m0fnu: bit-exact round-trip via _test_type
  • test_float4_e2m1fn_x2: bit-exact round-trip (uses uint8 view, since randn/cast don't support float4)

F4 shape handling detail

safetensors stores an F4 weight of logical shape [2048, 4096] (8M FP4 values) as 4 MiB (0.5 bytes per FP4 value). torch.float4_e2m1fn_x2 represents paired values: [2048, 4096] in that dtype would be 8 MiB — wrong. The correct PyTorch shape is [2048, 2048] (4M packed pairs = 4 MiB).

get_storage_shape() handles DLPack by using a flat [4194304] uint8 shape (avoids buffer overread). get_native_shape() then reshapes to [2048, 2048] after the view. get_empty_tensor() also uses get_native_shape() so broadcast/scatter allocations are correctly sized.
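
The same example as code (helper names mirror the PR description; their signatures are assumed):

    st_shape = [2048, 4096]           # safetensors shape, counted in FP4 elements
    nbytes = (2048 * 4096) // 2       # 0.5 B per value -> 4,194,304 bytes

    storage_shape = [nbytes]          # get_storage_shape(): flat uint8 view for DLPack
    native_shape = [2048, 4096 // 2]  # get_native_shape(): [2048, 2048] packed pairs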

Validation

Tested against DeepSeek V4-Flash MP=2 safetensors shards (~82 GiB each) containing all six dtypes (BF16, F32, F4, F8_E4M3, F8_E8M0, I64). With this PR:

  • add_filenames() succeeds on V4-Flash shards
  • All tensors load bit-exactly vs safetensors.torch.load_file
  • 41 tests pass (39 existing + 2 new), zero regressions

Notes

  • get_storage_shape() and get_native_shape() default to identity on FrameworkOpBase, so non-PyTorch frameworks (e.g. Paddle) are unaffected unless they add their own F4/F8_E8M0 dtype mappings.
  • F8_E8M0 requires PyTorch ≥ 2.5; F4 requires PyTorch ≥ 2.10. Both additions are guarded by hasattr() so older PyTorch versions still work for all currently-supported dtypes.

…ypes

Models like DeepSeek V4-Flash use two additional dtypes that were missing
from fastsafetensors's DType enum:

  - F8_E8M0 (torch.float8_e8m0fnu, PyTorch 2.5+): unsigned 8-bit exponent-only
    format used for per-tile quantization scales.  One byte per element.

  - F4 (torch.float4_e2m1fn_x2, PyTorch 2.10+): packed FP4 format, two 4-bit
    values per byte.  safetensors stores the shape in FP4-element count while
    PyTorch float4_e2m1fn_x2 counts packed pairs, so the byte size per logical
    element is 0.5.

Without these variants, add_filenames() raises ValueError on any safetensors
file that contains F4 or F8_E8M0 tensors, making fastsafetensors unusable with
those models.

Changes:
- st_types.py: add F4 and F8_E8M0 to DType enum
- frameworks/_torch.py:
    * map DType.F8_E8M0 -> torch.float8_e8m0fnu (guarded by hasattr)
    * map DType.F4     -> torch.float4_e2m1fn_x2 (guarded by hasattr)
    * add U8 workaround for both (no NCCL support)
    * get_dtype_size() returns 0.5 for F4 (two FP4 values per byte)
    * get_storage_shape() collapses packed sub-byte shapes to flat byte count
      so the DLPack tensor does not overread the buffer
    * get_native_shape()  restores the PyTorch-native shape after DLPack import
    * get_empty_tensor() uses get_native_shape() so shape is correct for alloc
    * TorchTensor.reshape() added to support shape adjustment in common.py
- frameworks/__init__.py:
    * get_dtype_size() return type changed from int to float
    * get_storage_shape() added (default: identity)
    * get_native_shape()  added (default: identity)
    * TensorBase.reshape() added (default: raises NotImplementedError)
- dlpack.py: add DLPack mappings for F4 and F8_E8M0 as opaque uint8 bytes
- common.py:
    * validation: nbytes = int(nelements * get_dtype_size()) for sub-byte safety
    * get_tensors: use get_storage_shape() and get_native_shape() for packed types
- tests: add test_float4_e2m1fn_x2 and test_float8_e8m0fnu

Validated against DeepSeek V4-Flash MP=2 safetensors shards (~82 GiB each).
All six dtypes (BF16, F32, F4, F8_E4M3, F8_E8M0, I64) load bit-exactly.
@takeshi-yoshimura self-requested a review on April 27, 2026 at 10:37
@takeshi-yoshimura (Collaborator)

@uasind
Thanks for your contribution! While I review and test your code, please fix the lint issue and add a Signed-off-by line like the other commits.


@takeshi-yoshimura (Collaborator) left a comment


@uasind
Can you please resolve the lint issues, DCO, and my comments? Thanks!


    ratio = int(round(1.0 / size))  # e.g. 2 for F4 (2 FP4 values per byte)
    if len(st_shape) > 1:
        return list(st_shape[:-1]) + [st_shape[-1] // ratio]

Can you check st_shape[-1] % ratio == 0 and raise Exception on malformed safetensors files?
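
A minimal form of the suggested check (the exception type and message are illustrative):

    if st_shape[-1] % ratio != 0:
        raise ValueError(
            f"malformed safetensors shape {st_shape}: "
            f"last dimension not divisible by {ratio}"
        )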

    import pytest  # import added so the excerpt is self-contained
    import torch

    if not hasattr(torch, "float4_e2m1fn_x2"):
        pytest.skip("torch.float4_e2m1fn_x2 requires PyTorch 2.10+")

Our CI does not include PyTorch 2.10 currently. Can you please add a new workflow to run only the new tests by modifying .github/workflows/test-torch.yaml?
