Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion alto/kernels/dispatch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@

@dataclass(unsafe_hash=True, kw_only=True, slots=True)
class TrainingOpConfig:
precision: Literal["mxfp4", "mxfp8_e4m3", "mxfp8_e5m2", "nvfp4"]
precision: Literal["mxfp4", "mxfp8_e4m3", "mxfp8_e5m2", "nvfp4", "amdfp4"]
"""
Quantization recipe family. ``"nvfp4"`` and ``"amdfp4"`` are peers with
the same micro-block layout but different inner-scale grids:

* ``"nvfp4"`` — E4M3 inner scale by default (NVFP4 spec ``s_block``).
* ``"amdfp4"`` — UE5M3 inner scale (GFXIPARCH-2067 §19.10 / OCP E5M3,
max normal 114688, NaN at 0xFF). See ``inner_scale_format`` below
for how this selection is forwarded to the kernels.
"""
use_2dblock_x: bool
use_2dblock_w: bool
use_hadamard: bool
Expand All @@ -34,5 +43,45 @@ class TrainingOpConfig:
* NVFP4: not implemented
"""

inner_scale_format: Literal["e4m3", "ue5m3"] = "e4m3"
"""
Per-block inner-scale dtype used by NVFP4 / AMD-FP4 (orthogonal to ``precision``).

* ``"e4m3"`` — NVFP4 default: signed FP8 E4M3, max 448.
* ``"ue5m3"`` — AMD-FP4: unsigned 8-bit (5 exp + 3 mant, NaN code 0xFF
per GFXIPARCH-2067 §19.10, no sign bit, no Inf encoding), max normal
114688 at code 0xFE; ~256× wider dynamic range than E4M3 at the same
mantissa precision.

Selected when ``precision in {"nvfp4", "amdfp4"}``; ignored for
MXFP4 / MXFP8. When ``precision == "amdfp4"`` this field is forced to
``"ue5m3"`` in :meth:`__post_init__` (AMD-FP4 = NVFP4 spec + UE5M3
inner scale).

Locked design choice (see ``docs/amd-fp4/agent-handoff.md`` D3): keep the
``precision`` literal carrying the recipe and let this orthogonal field
carry the inner-scale dtype, so additional dtypes (E8M3, UE4M4, …) can
be added by extending this Literal alone.
"""

def __post_init__(self) -> None:
# Pin the AMD-FP4 recipe to its UE5M3 inner-scale dtype. Allowed
# caller-side aliases:
# - ``inner_scale_format`` left at its dataclass default ``"e4m3"``
# (i.e. caller didn't touch it; we silently override),
# - ``inner_scale_format="ue5m3"`` explicitly (no-op).
# Any other value is rejected so AMD-FP4 cannot accidentally be
# routed to the E4M3 grid by an out-of-date caller.
if self.precision == "amdfp4":
if self.inner_scale_format == "e4m3":
# Not a frozen dataclass, so plain assignment is fine even with
# ``slots=True`` (the slot descriptor accepts writes).
self.inner_scale_format = "ue5m3"
elif self.inner_scale_format != "ue5m3":
raise ValueError(
f"precision='amdfp4' requires inner_scale_format='ue5m3'; "
f"got inner_scale_format={self.inner_scale_format!r}"
)


torch.serialization.add_safe_globals([TrainingOpConfig])
4 changes: 3 additions & 1 deletion alto/kernels/dispatch/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def _get_tensor_cls_for_config(config: TrainingOpConfig) -> Type[torch.Tensor]:
return MXFP4TrainingWeightWrapperTensor
elif config.precision in ("mxfp8_e4m3", "mxfp8_e5m2"):
return MXFP8TrainingWeightWrapperTensor
elif config.precision == "nvfp4":
elif config.precision in ("nvfp4", "amdfp4"):
# AMD-FP4 reuses the NVFP4 wrapper; it re-dispatches on
# ``config.precision`` internally to pin the UE5M3 inner grid.
return NVFP4TrainingWeightWrapperTensor
else:
raise ValueError(f"Unsupported training op config: {config}")
Expand Down
90 changes: 57 additions & 33 deletions alto/kernels/dispatch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from torchao.utils import TorchAOBaseTensor
from torchtitan.tools.logging import logger

from alto.kernels.fp4.amdfp4 import (
_quantize_then_amdfp4_scaled_grouped_mm,
_to_amdfp4_then_scaled_mm,
)
from alto.kernels.fp4.mxfp4.mxfp_linear import _to_mxfp4_then_scaled_mm
from alto.kernels.fp4.mxfp4.mxfp_grouped_gemm.functional import _quantize_then_mxfp_scaled_grouped_mm
from alto.kernels.fp4.nvfp4.nvfp_linear import _to_nvfp4_then_scaled_mm
Expand Down Expand Up @@ -295,31 +299,32 @@ def __torch_function__(cls, func, types, args, kwargs={}):


class NVFP4TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor):
"""Weight tensor subclass that routes F.linear calls through NVFP4LinearFunction.

The dispatch mechanism follows the same pattern as MXFP4TrainingWeightWrapperTensor:
when PyTorch sees a linear/mm operation whose *weight* (B) is wrapped in this
subclass, __torch_function__ is triggered and the NVFP4 QDQ path is used instead
of the standard BF16 matmul.
"""Weight tensor subclass that routes F.linear calls through the NVFP4-family
autograd functions (covering both NVFP4 and AMD-FP4 recipes).

The dispatch mechanism follows the same pattern as
``MXFP4TrainingWeightWrapperTensor``: when PyTorch sees a linear/mm
operation whose *weight* (B) is wrapped in this subclass,
``__torch_function__`` is triggered and a low-precision QDQ path is used
instead of the standard BF16 matmul.

Recipe selection follows ``config.precision``:

* ``"nvfp4"`` → ``_to_nvfp4_then_scaled_mm`` /
``_quantize_then_nvfp4_scaled_grouped_mm``, with the inner-scale
grid taken from ``config.inner_scale_format``. This keeps the
historic surface (E4M3 default, optional UE5M3 via the orthogonal
``inner_scale_format`` field) intact.
* ``"amdfp4"`` → ``_to_amdfp4_then_scaled_mm`` /
``_quantize_then_amdfp4_scaled_grouped_mm``, the AMD-FP4 thin
wrappers that hard-pin the inner grid to UE5M3. ``__post_init__``
on ``TrainingOpConfig`` already forces
``inner_scale_format='ue5m3'`` here, so the two paths are
numerically equivalent — the split exists so call traces /
profilers can identify the recipe by Python frame.

This means model code never needs to change — the training precision is
controlled purely by wrapping the weight parameters via swap_params().

Config fields used:
use_2dblock_x – 2D block scaling on activations
use_2dblock_w – 2D block scaling on weights (mirrors the axis-invariant view)
use_sr_grad – stochastic rounding on gradient quantization
use_hadamard – wgrad-path Hadamard rotation (mirrors MXFP4 behaviour)
use_dge – differentiable gradient estimator on wgrad (mirrors MXFP4)
two_level_scaling – two-level NVFP4 scaling (global × per-block)

Supported ops:
linear / mm / addmm / matmul – routed through ``_to_nvfp4_then_scaled_mm``
_grouped_mm – routed through
``_quantize_then_nvfp4_scaled_grouped_mm``
(2D × 3D with ``offs`` for MoE routed
experts)

controlled purely by wrapping the weight parameters via ``swap_params()``.
"""

@classmethod
Expand All @@ -338,13 +343,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
assert bias is None, "Bias is not supported for grouped_mm"

config = B.config
assert config.precision == "nvfp4", (
f"expected TrainingOpConfig with precision=nvfp4, got {config.precision}"
assert config.precision in ("nvfp4", "amdfp4"), (
"expected TrainingOpConfig with precision in {'nvfp4','amdfp4'}, "
f"got {config.precision}"
)

return _quantize_then_nvfp4_scaled_grouped_mm(
A,
B,
common_kwargs = dict(
offs=offs,
use_2dblock_x=config.use_2dblock_x,
use_2dblock_w=config.use_2dblock_w,
Expand All @@ -353,6 +357,17 @@ def __torch_function__(cls, func, types, args, kwargs={}):
use_hadamard=config.use_hadamard,
use_dge=config.use_dge,
)
if config.precision == "amdfp4":
# Inner grid is hard-pinned to UE5M3 inside the wrapper.
return _quantize_then_amdfp4_scaled_grouped_mm(A, B, **common_kwargs)
return _quantize_then_nvfp4_scaled_grouped_mm(
A,
B,
# ``inner_scale_format`` is the NVFP4 inner-grid selector
# (E4M3 default, optional UE5M3); orthogonal to ``precision``.
scale_format=config.inner_scale_format,
**common_kwargs,
)

# linear / mm overrides
elif func.__name__ in gemm_ops:
Expand All @@ -367,8 +382,10 @@ def __torch_function__(cls, func, types, args, kwargs={}):
assert isinstance(B, cls), (f"B should be a {cls.__name__} for func {func.__name__}")

config = B.config
assert config.precision == "nvfp4", (
f"expected TrainingOpConfig with precision=nvfp4, got {config.precision}")
assert config.precision in ("nvfp4", "amdfp4"), (
"expected TrainingOpConfig with precision in {'nvfp4','amdfp4'}, "
f"got {config.precision}"
)

# Pass the wrapper tensor itself into the autograd function —
# matching the MXFP4 path — so that any upstream subclass
Expand All @@ -379,16 +396,23 @@ def __torch_function__(cls, func, types, args, kwargs={}):
# entry, so the autograd tape and downstream QDQ ops still see
# plain tensors.
W = B if trans_b else B.T
Y = _to_nvfp4_then_scaled_mm(
A,
W,
common_kwargs = dict(
use_2dblock_x=config.use_2dblock_x,
use_2dblock_w=config.use_2dblock_w,
use_sr_grad=config.use_sr_grad,
use_outer_scale=config.two_level_scaling == "tensorwise",
use_hadamard=config.use_hadamard,
use_dge=config.use_dge,
)
if config.precision == "amdfp4":
Y = _to_amdfp4_then_scaled_mm(A, W, **common_kwargs)
else:
Y = _to_nvfp4_then_scaled_mm(
A,
W,
scale_format=config.inner_scale_format,
**common_kwargs,
)
if bias is not None:
Y = Y + bias
return Y
Expand Down
50 changes: 36 additions & 14 deletions alto/kernels/fp4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,57 @@

"""Unified public entrypoint for FP4 kernels.

This module exposes the most commonly used public APIs for both FP4 families:
This module exposes the most commonly used public APIs for the FP4 family:

- shared E2M1 encode/decode primitives from ``fp4_common``
- shared E2M1 encode/decode primitives from ``fp4_primitives``
- MXFP4 quantize/dequantize entrypoints
- NVFP4 quantize/dequantize entrypoints
- NVFP4 quantize/dequantize entrypoints (E4M3 inner scale)
- AMD-FP4 quantize/dequantize entrypoints (UE5M3 inner scale; same micro-block
layout as NVFP4, GFXIPARCH-2067 §19.10 / OCP E5M3 aligned)

Format-specific internals remain in their own subpackages:
Directory structure (two shared layers + three recipe packages)::

- ``alto.kernels.fp4.fp4_common``
- ``alto.kernels.fp4.mxfp4``
- ``alto.kernels.fp4.nvfp4``
fp4/
├── fp4_primitives/ # lowest layer, shared by ALL recipes:
│ # E2M1 pack/quant/dequant + philox, scale-dtype
│ # casts (e4m3_ops / ue5m3_ops), grouped/MoE
│ # helpers, unwrap_weight_wrapper.
├── outer_scaled_fp4/ # shared by NVFP4 + AMD-FP4 only:
│ # the two-level blockwise layout (per-tensor FP32
│ # outer scale + FP32 float inner-grid). MXFP4's
│ # macro-block scaling is a separate impl and does
│ # NOT depend on this layer.
├── mxfp4/ # MXFP4 recipe — independent (E8M0 inner, 1×32,
│ # native tl.dot_scaled); only uses fp4_primitives.
├── nvfp4/ # NVFP4 recipe (E4M3 inner) over outer_scaled_fp4.
└── amdfp4/ # AMD-FP4 recipe (UE5M3 inner) — thin wrappers over
# nvfp4 / outer_scaled_fp4 with scale_format pinned.
"""

from importlib import import_module


_LAZY_ATTRS = {
"fp4_common": (".fp4_common", None),
"fp4_primitives": (".fp4_primitives", None),
"outer_scaled_fp4": (".outer_scaled_fp4", None),
"mxfp4": (".mxfp4", None),
"nvfp4": (".nvfp4", None),
"dequantize_e2m1": (".fp4_common", "dequantize_e2m1"),
"generate_philox_randval_2x": (".fp4_common", "generate_philox_randval_2x"),
"quantize_e2m1": (".fp4_common", "quantize_e2m1"),
"amdfp4": (".amdfp4", None),
"dequantize_e2m1": (".fp4_primitives", "dequantize_e2m1"),
"generate_philox_randval_2x": (".fp4_primitives", "generate_philox_randval_2x"),
"quantize_e2m1": (".fp4_primitives", "quantize_e2m1"),
"MXFP4_BLOCK_SIZE_DEFAULT": (".mxfp4.mxfp_quantization", "BLOCK_SIZE_DEFAULT"),
"convert_from_mxfp4": (".mxfp4.mxfp_quantization", "convert_from_mxfp4"),
"convert_to_mxfp4": (".mxfp4.mxfp_quantization", "convert_to_mxfp4"),
"NVFP4_BLOCK_SIZE_DEFAULT": (".nvfp4.nvfp_quantization", "BLOCK_SIZE_DEFAULT"),
"NVFP4_BLOCK_SIZE_DEFAULT": (".outer_scaled_fp4", "BLOCK_SIZE_DEFAULT"),
"compute_dynamic_outer_scale": (
".nvfp4.nvfp_quantization",
".outer_scaled_fp4",
"compute_dynamic_outer_scale",
),
"convert_from_nvfp4": (".nvfp4.nvfp_quantization", "convert_from_nvfp4"),
"convert_to_nvfp4": (".nvfp4.nvfp_quantization", "convert_to_nvfp4"),
"convert_from_amdfp4": (".amdfp4.amdfp_quantization", "convert_from_amdfp4"),
"convert_to_amdfp4": (".amdfp4.amdfp_quantization", "convert_to_amdfp4"),
}


Expand All @@ -57,15 +75,19 @@ def __dir__():
__all__ = (
"MXFP4_BLOCK_SIZE_DEFAULT",
"NVFP4_BLOCK_SIZE_DEFAULT",
"amdfp4",
"compute_dynamic_outer_scale",
"convert_from_amdfp4",
"convert_from_mxfp4",
"convert_from_nvfp4",
"convert_to_amdfp4",
"convert_to_mxfp4",
"convert_to_nvfp4",
"dequantize_e2m1",
"fp4_common",
"fp4_primitives",
"generate_philox_randval_2x",
"mxfp4",
"nvfp4",
"outer_scaled_fp4",
"quantize_e2m1",
)
48 changes: 48 additions & 0 deletions alto/kernels/fp4/amdfp4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2026 Advanced Micro Devices, Inc.
#
# SPDX-License-Identifier: MIT

"""AMD-FP4 kernels.

The AMD-FP4 recipe is **NVFP4-style E2M1 storage with a UE5M3 inner-block
scale and an FP32 per-tensor outer scale**, aligned with GFXIPARCH-2067
§19.10 / OCP E5M3. In other words: same micro-block layout as NVFP4,
strictly wider inner-scale dynamic range (max normal 114688 vs 448 for
E4M3) at the same mantissa precision.

This sub-package is a peer of :mod:`alto.kernels.fp4.nvfp4`; both build
on the shared body in :mod:`alto.kernels.fp4.outer_scaled_fp4`. The split
exists so:

* recipe-level call sites (``convert_to_amdfp4``, ``AMDFP4LinearFunction``,
``_quantize_then_amdfp4_scaled_grouped_mm``) pin ``scale_format`` to
``"ue5m3"`` exactly once and never expose the dtype switch to user code;
* dispatch / tracing tooling that keys on the ATen op id sees AMD-FP4 as
a distinct operator (``alto::convert_to_amdfp4``) and can route /
benchmark it independently from NVFP4.
"""

from .amdfp_grouped_gemm import (
ALIGN_SIZE_M,
AMDFP4GroupedGEMM,
_quantize_then_amdfp4_scaled_grouped_mm,
amdfp4_grouped_gemm,
)
from .amdfp_linear import AMDFP4LinearFunction, _to_amdfp4_then_scaled_mm
from .amdfp_quantization import (
BLOCK_SIZE_DEFAULT,
convert_from_amdfp4,
convert_to_amdfp4,
)

__all__ = (
"ALIGN_SIZE_M",
"AMDFP4GroupedGEMM",
"AMDFP4LinearFunction",
"BLOCK_SIZE_DEFAULT",
"_quantize_then_amdfp4_scaled_grouped_mm",
"_to_amdfp4_then_scaled_mm",
"amdfp4_grouped_gemm",
"convert_from_amdfp4",
"convert_to_amdfp4",
)
38 changes: 38 additions & 0 deletions alto/kernels/fp4/amdfp4/amdfp_grouped_gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2026 Advanced Micro Devices, Inc.
#
# SPDX-License-Identifier: MIT

"""AMD-FP4 grouped GEMM surface (UE5M3 inner scale).

Like :mod:`alto.kernels.fp4.amdfp4.amdfp_linear`, this is a thin layer
on top of the NVFP4 grouped-GEMM machinery: the autograd function and
all loop / native-dispatch helpers are reused as-is and the
``scale_format`` literal is hard-pinned to ``"ue5m3"`` at the AMD-FP4
boundary.
"""

from alto.kernels.fp4.nvfp4.nvfp_grouped_gemm import (
ALIGN_SIZE_M,
NVFP4GroupedGEMM,
)

from .functional import (
_quantize_then_amdfp4_scaled_grouped_mm,
amdfp4_grouped_gemm,
)


# Re-export the autograd Function under an AMD-FP4 name; same alias
# rationale as ``AMDFP4LinearFunction`` -- the underlying
# ``torch.autograd.Function`` already routes ``scale_format``, so we
# avoid creating a parallel marker subclass that would only complicate
# isinstance checks at the dispatch layer.
AMDFP4GroupedGEMM = NVFP4GroupedGEMM


__all__ = (
"ALIGN_SIZE_M",
"AMDFP4GroupedGEMM",
"_quantize_then_amdfp4_scaled_grouped_mm",
"amdfp4_grouped_gemm",
)
Loading