diff --git a/alto/kernels/dispatch/config.py b/alto/kernels/dispatch/config.py index 314f46d..15f249b 100644 --- a/alto/kernels/dispatch/config.py +++ b/alto/kernels/dispatch/config.py @@ -9,7 +9,16 @@ @dataclass(unsafe_hash=True, kw_only=True, slots=True) class TrainingOpConfig: - precision: Literal["mxfp4", "mxfp8_e4m3", "mxfp8_e5m2", "nvfp4"] + precision: Literal["mxfp4", "mxfp8_e4m3", "mxfp8_e5m2", "nvfp4", "amdfp4"] + """ + Quantization recipe family. ``"nvfp4"`` and ``"amdfp4"`` are peers with + the same micro-block layout but different inner-scale grids: + + * ``"nvfp4"`` — E4M3 inner scale by default (NVFP4 spec ``s_block``). + * ``"amdfp4"`` — UE5M3 inner scale (GFXIPARCH-2067 §19.10 / OCP E5M3, + max normal 114688, NaN at 0xFF). See ``inner_scale_format`` below + for how this selection is forwarded to the kernels. + """ use_2dblock_x: bool use_2dblock_w: bool use_hadamard: bool @@ -34,5 +43,45 @@ class TrainingOpConfig: * NVFP4: not implemented """ + inner_scale_format: Literal["e4m3", "ue5m3"] = "e4m3" + """ + Per-block inner-scale dtype used by NVFP4 / AMD-FP4 (orthogonal to ``precision``). + + * ``"e4m3"`` — NVFP4 default: signed FP8 E4M3, max 448. + * ``"ue5m3"`` — AMD-FP4: unsigned 8-bit (5 exp + 3 mant, NaN code 0xFF + per GFXIPARCH-2067 §19.10, no sign bit, no Inf encoding), max normal + 114688 at code 0xFE; ~256× wider dynamic range than E4M3 at the same + mantissa precision. + + Selected when ``precision in {"nvfp4", "amdfp4"}``; ignored for + MXFP4 / MXFP8. When ``precision == "amdfp4"`` this field is forced to + ``"ue5m3"`` in :meth:`__post_init__` (AMD-FP4 = NVFP4 spec + UE5M3 + inner scale). + + Locked design choice (see ``docs/amd-fp4/agent-handoff.md`` D3): keep the + ``precision`` literal carrying the recipe and let this orthogonal field + carry the inner-scale dtype, so additional dtypes (E8M3, UE4M4, …) can + be added by extending this Literal alone. + """ + + def __post_init__(self) -> None: + # Pin the AMD-FP4 recipe to its UE5M3 inner-scale dtype. Allowed + # caller-side aliases: + # - ``inner_scale_format`` left at its dataclass default ``"e4m3"`` + # (i.e. caller didn't touch it; we silently override), + # - ``inner_scale_format="ue5m3"`` explicitly (no-op). + # Any other value is rejected so AMD-FP4 cannot accidentally be + # routed to the E4M3 grid by an out-of-date caller. + if self.precision == "amdfp4": + if self.inner_scale_format == "e4m3": + # Not a frozen dataclass, so plain assignment is fine even with + # ``slots=True`` (the slot descriptor accepts writes). + self.inner_scale_format = "ue5m3" + elif self.inner_scale_format != "ue5m3": + raise ValueError( + f"precision='amdfp4' requires inner_scale_format='ue5m3'; " + f"got inner_scale_format={self.inner_scale_format!r}" + ) + torch.serialization.add_safe_globals([TrainingOpConfig]) diff --git a/alto/kernels/dispatch/conversion.py b/alto/kernels/dispatch/conversion.py index 6127ce8..1ff1ff0 100644 --- a/alto/kernels/dispatch/conversion.py +++ b/alto/kernels/dispatch/conversion.py @@ -31,7 +31,9 @@ def _get_tensor_cls_for_config(config: TrainingOpConfig) -> Type[torch.Tensor]: return MXFP4TrainingWeightWrapperTensor elif config.precision in ("mxfp8_e4m3", "mxfp8_e5m2"): return MXFP8TrainingWeightWrapperTensor - elif config.precision == "nvfp4": + elif config.precision in ("nvfp4", "amdfp4"): + # AMD-FP4 reuses the NVFP4 wrapper; it re-dispatches on + # ``config.precision`` internally to pin the UE5M3 inner grid. return NVFP4TrainingWeightWrapperTensor else: raise ValueError(f"Unsupported training op config: {config}") diff --git a/alto/kernels/dispatch/tensor.py b/alto/kernels/dispatch/tensor.py index a13d8a5..8b5871e 100644 --- a/alto/kernels/dispatch/tensor.py +++ b/alto/kernels/dispatch/tensor.py @@ -15,6 +15,10 @@ from torchao.utils import TorchAOBaseTensor from torchtitan.tools.logging import logger +from alto.kernels.fp4.amdfp4 import ( + _quantize_then_amdfp4_scaled_grouped_mm, + _to_amdfp4_then_scaled_mm, +) from alto.kernels.fp4.mxfp4.mxfp_linear import _to_mxfp4_then_scaled_mm from alto.kernels.fp4.mxfp4.mxfp_grouped_gemm.functional import _quantize_then_mxfp_scaled_grouped_mm from alto.kernels.fp4.nvfp4.nvfp_linear import _to_nvfp4_then_scaled_mm @@ -295,31 +299,32 @@ def __torch_function__(cls, func, types, args, kwargs={}): class NVFP4TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor): - """Weight tensor subclass that routes F.linear calls through NVFP4LinearFunction. - - The dispatch mechanism follows the same pattern as MXFP4TrainingWeightWrapperTensor: - when PyTorch sees a linear/mm operation whose *weight* (B) is wrapped in this - subclass, __torch_function__ is triggered and the NVFP4 QDQ path is used instead - of the standard BF16 matmul. + """Weight tensor subclass that routes F.linear calls through the NVFP4-family + autograd functions (covering both NVFP4 and AMD-FP4 recipes). + + The dispatch mechanism follows the same pattern as + ``MXFP4TrainingWeightWrapperTensor``: when PyTorch sees a linear/mm + operation whose *weight* (B) is wrapped in this subclass, + ``__torch_function__`` is triggered and a low-precision QDQ path is used + instead of the standard BF16 matmul. + + Recipe selection follows ``config.precision``: + + * ``"nvfp4"`` → ``_to_nvfp4_then_scaled_mm`` / + ``_quantize_then_nvfp4_scaled_grouped_mm``, with the inner-scale + grid taken from ``config.inner_scale_format``. This keeps the + historic surface (E4M3 default, optional UE5M3 via the orthogonal + ``inner_scale_format`` field) intact. + * ``"amdfp4"`` → ``_to_amdfp4_then_scaled_mm`` / + ``_quantize_then_amdfp4_scaled_grouped_mm``, the AMD-FP4 thin + wrappers that hard-pin the inner grid to UE5M3. ``__post_init__`` + on ``TrainingOpConfig`` already forces + ``inner_scale_format='ue5m3'`` here, so the two paths are + numerically equivalent — the split exists so call traces / + profilers can identify the recipe by Python frame. This means model code never needs to change — the training precision is - controlled purely by wrapping the weight parameters via swap_params(). - - Config fields used: - use_2dblock_x – 2D block scaling on activations - use_2dblock_w – 2D block scaling on weights (mirrors the axis-invariant view) - use_sr_grad – stochastic rounding on gradient quantization - use_hadamard – wgrad-path Hadamard rotation (mirrors MXFP4 behaviour) - use_dge – differentiable gradient estimator on wgrad (mirrors MXFP4) - two_level_scaling – two-level NVFP4 scaling (global × per-block) - - Supported ops: - linear / mm / addmm / matmul – routed through ``_to_nvfp4_then_scaled_mm`` - _grouped_mm – routed through - ``_quantize_then_nvfp4_scaled_grouped_mm`` - (2D × 3D with ``offs`` for MoE routed - experts) - + controlled purely by wrapping the weight parameters via ``swap_params()``. """ @classmethod @@ -338,13 +343,12 @@ def __torch_function__(cls, func, types, args, kwargs={}): assert bias is None, "Bias is not supported for grouped_mm" config = B.config - assert config.precision == "nvfp4", ( - f"expected TrainingOpConfig with precision=nvfp4, got {config.precision}" + assert config.precision in ("nvfp4", "amdfp4"), ( + "expected TrainingOpConfig with precision in {'nvfp4','amdfp4'}, " + f"got {config.precision}" ) - return _quantize_then_nvfp4_scaled_grouped_mm( - A, - B, + common_kwargs = dict( offs=offs, use_2dblock_x=config.use_2dblock_x, use_2dblock_w=config.use_2dblock_w, @@ -353,6 +357,17 @@ def __torch_function__(cls, func, types, args, kwargs={}): use_hadamard=config.use_hadamard, use_dge=config.use_dge, ) + if config.precision == "amdfp4": + # Inner grid is hard-pinned to UE5M3 inside the wrapper. + return _quantize_then_amdfp4_scaled_grouped_mm(A, B, **common_kwargs) + return _quantize_then_nvfp4_scaled_grouped_mm( + A, + B, + # ``inner_scale_format`` is the NVFP4 inner-grid selector + # (E4M3 default, optional UE5M3); orthogonal to ``precision``. + scale_format=config.inner_scale_format, + **common_kwargs, + ) # linear / mm overrides elif func.__name__ in gemm_ops: @@ -367,8 +382,10 @@ def __torch_function__(cls, func, types, args, kwargs={}): assert isinstance(B, cls), (f"B should be a {cls.__name__} for func {func.__name__}") config = B.config - assert config.precision == "nvfp4", ( - f"expected TrainingOpConfig with precision=nvfp4, got {config.precision}") + assert config.precision in ("nvfp4", "amdfp4"), ( + "expected TrainingOpConfig with precision in {'nvfp4','amdfp4'}, " + f"got {config.precision}" + ) # Pass the wrapper tensor itself into the autograd function — # matching the MXFP4 path — so that any upstream subclass @@ -379,9 +396,7 @@ def __torch_function__(cls, func, types, args, kwargs={}): # entry, so the autograd tape and downstream QDQ ops still see # plain tensors. W = B if trans_b else B.T - Y = _to_nvfp4_then_scaled_mm( - A, - W, + common_kwargs = dict( use_2dblock_x=config.use_2dblock_x, use_2dblock_w=config.use_2dblock_w, use_sr_grad=config.use_sr_grad, @@ -389,6 +404,15 @@ def __torch_function__(cls, func, types, args, kwargs={}): use_hadamard=config.use_hadamard, use_dge=config.use_dge, ) + if config.precision == "amdfp4": + Y = _to_amdfp4_then_scaled_mm(A, W, **common_kwargs) + else: + Y = _to_nvfp4_then_scaled_mm( + A, + W, + scale_format=config.inner_scale_format, + **common_kwargs, + ) if bias is not None: Y = Y + bias return Y diff --git a/alto/kernels/fp4/__init__.py b/alto/kernels/fp4/__init__.py index 1c4720f..04e4e5c 100644 --- a/alto/kernels/fp4/__init__.py +++ b/alto/kernels/fp4/__init__.py @@ -4,39 +4,57 @@ """Unified public entrypoint for FP4 kernels. -This module exposes the most commonly used public APIs for both FP4 families: +This module exposes the most commonly used public APIs for the FP4 family: -- shared E2M1 encode/decode primitives from ``fp4_common`` +- shared E2M1 encode/decode primitives from ``fp4_primitives`` - MXFP4 quantize/dequantize entrypoints -- NVFP4 quantize/dequantize entrypoints +- NVFP4 quantize/dequantize entrypoints (E4M3 inner scale) +- AMD-FP4 quantize/dequantize entrypoints (UE5M3 inner scale; same micro-block + layout as NVFP4, GFXIPARCH-2067 §19.10 / OCP E5M3 aligned) -Format-specific internals remain in their own subpackages: +Directory structure (two shared layers + three recipe packages):: -- ``alto.kernels.fp4.fp4_common`` -- ``alto.kernels.fp4.mxfp4`` -- ``alto.kernels.fp4.nvfp4`` + fp4/ + ├── fp4_primitives/ # lowest layer, shared by ALL recipes: + │ # E2M1 pack/quant/dequant + philox, scale-dtype + │ # casts (e4m3_ops / ue5m3_ops), grouped/MoE + │ # helpers, unwrap_weight_wrapper. + ├── outer_scaled_fp4/ # shared by NVFP4 + AMD-FP4 only: + │ # the two-level blockwise layout (per-tensor FP32 + │ # outer scale + FP32 float inner-grid). MXFP4's + │ # macro-block scaling is a separate impl and does + │ # NOT depend on this layer. + ├── mxfp4/ # MXFP4 recipe — independent (E8M0 inner, 1×32, + │ # native tl.dot_scaled); only uses fp4_primitives. + ├── nvfp4/ # NVFP4 recipe (E4M3 inner) over outer_scaled_fp4. + └── amdfp4/ # AMD-FP4 recipe (UE5M3 inner) — thin wrappers over + # nvfp4 / outer_scaled_fp4 with scale_format pinned. """ from importlib import import_module _LAZY_ATTRS = { - "fp4_common": (".fp4_common", None), + "fp4_primitives": (".fp4_primitives", None), + "outer_scaled_fp4": (".outer_scaled_fp4", None), "mxfp4": (".mxfp4", None), "nvfp4": (".nvfp4", None), - "dequantize_e2m1": (".fp4_common", "dequantize_e2m1"), - "generate_philox_randval_2x": (".fp4_common", "generate_philox_randval_2x"), - "quantize_e2m1": (".fp4_common", "quantize_e2m1"), + "amdfp4": (".amdfp4", None), + "dequantize_e2m1": (".fp4_primitives", "dequantize_e2m1"), + "generate_philox_randval_2x": (".fp4_primitives", "generate_philox_randval_2x"), + "quantize_e2m1": (".fp4_primitives", "quantize_e2m1"), "MXFP4_BLOCK_SIZE_DEFAULT": (".mxfp4.mxfp_quantization", "BLOCK_SIZE_DEFAULT"), "convert_from_mxfp4": (".mxfp4.mxfp_quantization", "convert_from_mxfp4"), "convert_to_mxfp4": (".mxfp4.mxfp_quantization", "convert_to_mxfp4"), - "NVFP4_BLOCK_SIZE_DEFAULT": (".nvfp4.nvfp_quantization", "BLOCK_SIZE_DEFAULT"), + "NVFP4_BLOCK_SIZE_DEFAULT": (".outer_scaled_fp4", "BLOCK_SIZE_DEFAULT"), "compute_dynamic_outer_scale": ( - ".nvfp4.nvfp_quantization", + ".outer_scaled_fp4", "compute_dynamic_outer_scale", ), "convert_from_nvfp4": (".nvfp4.nvfp_quantization", "convert_from_nvfp4"), "convert_to_nvfp4": (".nvfp4.nvfp_quantization", "convert_to_nvfp4"), + "convert_from_amdfp4": (".amdfp4.amdfp_quantization", "convert_from_amdfp4"), + "convert_to_amdfp4": (".amdfp4.amdfp_quantization", "convert_to_amdfp4"), } @@ -57,15 +75,19 @@ def __dir__(): __all__ = ( "MXFP4_BLOCK_SIZE_DEFAULT", "NVFP4_BLOCK_SIZE_DEFAULT", + "amdfp4", "compute_dynamic_outer_scale", + "convert_from_amdfp4", "convert_from_mxfp4", "convert_from_nvfp4", + "convert_to_amdfp4", "convert_to_mxfp4", "convert_to_nvfp4", "dequantize_e2m1", - "fp4_common", + "fp4_primitives", "generate_philox_randval_2x", "mxfp4", "nvfp4", + "outer_scaled_fp4", "quantize_e2m1", ) diff --git a/alto/kernels/fp4/amdfp4/__init__.py b/alto/kernels/fp4/amdfp4/__init__.py new file mode 100644 index 0000000..097dcbb --- /dev/null +++ b/alto/kernels/fp4/amdfp4/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""AMD-FP4 kernels. + +The AMD-FP4 recipe is **NVFP4-style E2M1 storage with a UE5M3 inner-block +scale and an FP32 per-tensor outer scale**, aligned with GFXIPARCH-2067 +§19.10 / OCP E5M3. In other words: same micro-block layout as NVFP4, +strictly wider inner-scale dynamic range (max normal 114688 vs 448 for +E4M3) at the same mantissa precision. + +This sub-package is a peer of :mod:`alto.kernels.fp4.nvfp4`; both build +on the shared body in :mod:`alto.kernels.fp4.outer_scaled_fp4`. The split +exists so: + +* recipe-level call sites (``convert_to_amdfp4``, ``AMDFP4LinearFunction``, + ``_quantize_then_amdfp4_scaled_grouped_mm``) pin ``scale_format`` to + ``"ue5m3"`` exactly once and never expose the dtype switch to user code; +* dispatch / tracing tooling that keys on the ATen op id sees AMD-FP4 as + a distinct operator (``alto::convert_to_amdfp4``) and can route / + benchmark it independently from NVFP4. +""" + +from .amdfp_grouped_gemm import ( + ALIGN_SIZE_M, + AMDFP4GroupedGEMM, + _quantize_then_amdfp4_scaled_grouped_mm, + amdfp4_grouped_gemm, +) +from .amdfp_linear import AMDFP4LinearFunction, _to_amdfp4_then_scaled_mm +from .amdfp_quantization import ( + BLOCK_SIZE_DEFAULT, + convert_from_amdfp4, + convert_to_amdfp4, +) + +__all__ = ( + "ALIGN_SIZE_M", + "AMDFP4GroupedGEMM", + "AMDFP4LinearFunction", + "BLOCK_SIZE_DEFAULT", + "_quantize_then_amdfp4_scaled_grouped_mm", + "_to_amdfp4_then_scaled_mm", + "amdfp4_grouped_gemm", + "convert_from_amdfp4", + "convert_to_amdfp4", +) diff --git a/alto/kernels/fp4/amdfp4/amdfp_grouped_gemm/__init__.py b/alto/kernels/fp4/amdfp4/amdfp_grouped_gemm/__init__.py new file mode 100644 index 0000000..616559a --- /dev/null +++ b/alto/kernels/fp4/amdfp4/amdfp_grouped_gemm/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""AMD-FP4 grouped GEMM surface (UE5M3 inner scale). + +Like :mod:`alto.kernels.fp4.amdfp4.amdfp_linear`, this is a thin layer +on top of the NVFP4 grouped-GEMM machinery: the autograd function and +all loop / native-dispatch helpers are reused as-is and the +``scale_format`` literal is hard-pinned to ``"ue5m3"`` at the AMD-FP4 +boundary. +""" + +from alto.kernels.fp4.nvfp4.nvfp_grouped_gemm import ( + ALIGN_SIZE_M, + NVFP4GroupedGEMM, +) + +from .functional import ( + _quantize_then_amdfp4_scaled_grouped_mm, + amdfp4_grouped_gemm, +) + + +# Re-export the autograd Function under an AMD-FP4 name; same alias +# rationale as ``AMDFP4LinearFunction`` -- the underlying +# ``torch.autograd.Function`` already routes ``scale_format``, so we +# avoid creating a parallel marker subclass that would only complicate +# isinstance checks at the dispatch layer. +AMDFP4GroupedGEMM = NVFP4GroupedGEMM + + +__all__ = ( + "ALIGN_SIZE_M", + "AMDFP4GroupedGEMM", + "_quantize_then_amdfp4_scaled_grouped_mm", + "amdfp4_grouped_gemm", +) diff --git a/alto/kernels/fp4/amdfp4/amdfp_grouped_gemm/functional.py b/alto/kernels/fp4/amdfp4/amdfp_grouped_gemm/functional.py new file mode 100644 index 0000000..883d576 --- /dev/null +++ b/alto/kernels/fp4/amdfp4/amdfp_grouped_gemm/functional.py @@ -0,0 +1,84 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""AMD-FP4 grouped-GEMM Python entrypoints. + +Mirrors :mod:`alto.kernels.fp4.nvfp4.nvfp_grouped_gemm.functional` but +pins ``scale_format='ue5m3'`` so neither the AMD-FP4 dispatch layer nor +training recipes need to thread the dtype literal through call sites. +""" + +from __future__ import annotations + +import torch + +from alto.kernels.fp4.nvfp4.nvfp_grouped_gemm import ( + _quantize_then_nvfp4_scaled_grouped_mm, + nvfp4_grouped_gemm, +) + + +def amdfp4_grouped_gemm( + inputs: torch.Tensor, # [M_total, K] + expert_weights: torch.Tensor, # [E, N, K] if trans_weights else [E, K, N] + expert_indices: torch.Tensor, # [M_total] int32 + *, + trans_weights: bool = True, + use_2dblock_x: bool = False, + use_2dblock_w: bool = False, + use_sr_grad: bool = True, + use_outer_scale: bool = False, + use_hadamard: bool = False, + use_dge: bool = False, +) -> torch.Tensor: + """AMD-FP4 (UE5M3 inner scale) Grouped GEMM with full autograd support. + + Same parameter contract as + :func:`alto.kernels.fp4.nvfp4.nvfp_grouped_gemm.nvfp4_grouped_gemm` + minus ``scale_format``, which is hard-pinned to ``"ue5m3"``. + """ + return nvfp4_grouped_gemm( + inputs, + expert_weights, + expert_indices, + trans_weights=trans_weights, + use_2dblock_x=use_2dblock_x, + use_2dblock_w=use_2dblock_w, + use_sr_grad=use_sr_grad, + use_outer_scale=use_outer_scale, + use_hadamard=use_hadamard, + use_dge=use_dge, + scale_format="ue5m3", + ) + + +def _quantize_then_amdfp4_scaled_grouped_mm( + A: torch.Tensor, + B: torch.Tensor, # [E, K, N] per dispatch convention + offs: torch.Tensor, + use_2dblock_x: bool, + use_2dblock_w: bool, + use_sr_grad: bool, + use_outer_scale: bool = False, + use_hadamard: bool = False, + use_dge: bool = False, +) -> torch.Tensor: + """AMD-FP4 dispatch-side variant of + :func:`alto.kernels.fp4.nvfp4.nvfp_grouped_gemm._quantize_then_nvfp4_scaled_grouped_mm`. + + ``scale_format`` is hard-pinned to ``"ue5m3"`` so the MoE dispatch + code path can route AMD-FP4 traffic by op surface alone. + """ + return _quantize_then_nvfp4_scaled_grouped_mm( + A, + B, + offs=offs, + use_2dblock_x=use_2dblock_x, + use_2dblock_w=use_2dblock_w, + use_sr_grad=use_sr_grad, + use_outer_scale=use_outer_scale, + use_hadamard=use_hadamard, + use_dge=use_dge, + scale_format="ue5m3", + ) diff --git a/alto/kernels/fp4/amdfp4/amdfp_linear.py b/alto/kernels/fp4/amdfp4/amdfp_linear.py new file mode 100644 index 0000000..5e4721b --- /dev/null +++ b/alto/kernels/fp4/amdfp4/amdfp_linear.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""AMD-FP4 linear autograd surface. + +The autograd function itself is identical to the NVFP4 path — +``NVFP4LinearFunction`` already accepts a ``scale_format`` parameter that +selects E4M3 vs UE5M3 inner scales. This module re-exposes it under an +AMD-FP4 name and provides a thin ``_to_amdfp4_then_scaled_mm`` helper +that pins ``scale_format='ue5m3'`` so user code never spells the dtype +literal. +""" + +from __future__ import annotations + +import torch + +from alto.kernels.fp4.nvfp4.nvfp_linear import ( + NVFP4LinearFunction, + _to_nvfp4_then_scaled_mm, +) + + +# Re-export the autograd function under an AMD-FP4 name. We deliberately +# do NOT subclass: ``torch.autograd.Function`` works through a custom +# metaclass + class methods, and a marker subclass would create a +# *separate* autograd record key without changing behaviour. Keep this +# as a simple alias so isinstance-checks against the NVFP4 type still +# match (the dispatch layer benefits from this single source of truth). +AMDFP4LinearFunction = NVFP4LinearFunction + + +def _to_amdfp4_then_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + use_2dblock_x: bool, + use_2dblock_w: bool, + use_sr_grad: bool, + use_outer_scale: bool = False, + use_hadamard: bool = False, + use_dge: bool = False, +) -> torch.Tensor: + """AMD-FP4 (UE5M3 inner scale) variant of ``_to_nvfp4_then_scaled_mm``. + + Mirrors the NVFP4 helper exactly, except ``scale_format`` is hard-pinned + to ``"ue5m3"`` so the AMD-FP4 dispatch layer / training recipes can + address the AMD-FP4 path by op surface alone. + """ + return _to_nvfp4_then_scaled_mm( + a, + b, + use_2dblock_x=use_2dblock_x, + use_2dblock_w=use_2dblock_w, + use_sr_grad=use_sr_grad, + use_outer_scale=use_outer_scale, + use_hadamard=use_hadamard, + use_dge=use_dge, + scale_format="ue5m3", + ) + + +__all__ = ( + "AMDFP4LinearFunction", + "_to_amdfp4_then_scaled_mm", +) diff --git a/alto/kernels/fp4/amdfp4/amdfp_quantization.py b/alto/kernels/fp4/amdfp4/amdfp_quantization.py new file mode 100644 index 0000000..1118fb5 --- /dev/null +++ b/alto/kernels/fp4/amdfp4/amdfp_quantization.py @@ -0,0 +1,135 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""AMD-FP4 (UE5M3 inner scale) public quantize / dequantize ops. + +AMD-FP4 = NVFP4 spec + UE5M3 inner-block scale + FP32 per-tensor outer +scale (GFXIPARCH-2067 §19.10 / OCP E5M3 aligned). The block-quant +implementation is identical to NVFP4 modulo the inner-grid choice and +is shared via :mod:`alto.kernels.fp4.outer_scaled_fp4`. This module only +adds the AMD-FP4 ATen op surface: + +* ``alto::convert_to_amdfp4`` -- pinned ``scale_format='ue5m3'`` +* ``alto::convert_from_amdfp4`` -- pinned ``scale_format='ue5m3'`` + +These ops are exposed alongside (not in place of) the NVFP4 ops. They +exist so dispatch / tracing / benchmarking tools can identify AMD-FP4 +calls by op id without inspecting ``scale_format`` kwargs. +""" + +from typing import Optional, Tuple + +import torch +from torch.library import triton_op + +from alto.kernels.fp4.outer_scaled_fp4 import BLOCK_SIZE_DEFAULT +from alto.kernels.fp4.outer_scaled_fp4.api import ( + _convert_from_fp4_blockwise, + _convert_to_fp4_blockwise, + _fake_convert_from_fp4_blockwise, + _fake_convert_to_fp4_blockwise, +) + + +# ``outer_scale`` is mutated in place (``copy_``) when ``update_outer_scale`` +# is True; declare it so Dynamo/functionalization invalidate stale caches. +@triton_op("alto::convert_to_amdfp4", mutates_args={"outer_scale"}) +def convert_to_amdfp4( + data_hp: torch.Tensor, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[torch.Tensor] = None, + update_outer_scale: bool = True, + use_sr: bool = False, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + use_asm: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize a high-precision tensor to AMD-FP4 (E2M1 + UE5M3 inner scale). + + The signature mirrors :func:`alto.kernels.fp4.nvfp4.convert_to_nvfp4` + minus the ``scale_format`` parameter, which is hard-pinned to + ``"ue5m3"`` for this op. See + :func:`alto.kernels.fp4.outer_scaled_fp4.api._convert_to_fp4_blockwise` + for the full parameter contract. + """ + return _convert_to_fp4_blockwise( + data_hp, + block_size=block_size, + axis=axis, + is_2d_block=is_2d_block, + outer_scale=outer_scale, + update_outer_scale=update_outer_scale, + scale_format="ue5m3", + use_sr=use_sr, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_asm=use_asm, + ) + + +@triton_op("alto::convert_from_amdfp4", mutates_args={}) +def convert_from_amdfp4( + data_lp: torch.Tensor, + scales: torch.Tensor, + output_dtype: torch.dtype = torch.float32, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[torch.Tensor] = None, + use_asm: Optional[bool] = None, +) -> torch.Tensor: + """Dequantize AMD-FP4 (E2M1 + UE5M3 inner scale) back to high precision. + + Like :func:`convert_to_amdfp4`, this op pins ``scale_format='ue5m3'``; + the dequantization path itself only multiplies by the stored FP32 + scale, so the format pin is a contract / op-id signal rather than a + numerical switch. + """ + return _convert_from_fp4_blockwise( + data_lp, + scales, + output_dtype=output_dtype, + block_size=block_size, + axis=axis, + is_2d_block=is_2d_block, + outer_scale=outer_scale, + scale_format="ue5m3", + use_asm=use_asm, + ) + + +@convert_to_amdfp4.register_fake +def _fake_convert_to_amdfp4( + data_hp: torch.Tensor, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[torch.Tensor] = None, + update_outer_scale: bool = True, + use_sr: bool = False, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + use_asm: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _fake_convert_to_fp4_blockwise( + data_hp, block_size=block_size, axis=axis, is_2d_block=is_2d_block, + ) + + +@convert_from_amdfp4.register_fake +def _fake_convert_from_amdfp4( + data_lp: torch.Tensor, + scales: torch.Tensor, + output_dtype: torch.dtype = torch.float32, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[torch.Tensor] = None, + use_asm: Optional[bool] = None, +) -> torch.Tensor: + return _fake_convert_from_fp4_blockwise( + data_lp, output_dtype=output_dtype, axis=axis, + ) diff --git a/alto/kernels/fp4/fp4_common/__init__.py b/alto/kernels/fp4/fp4_common/__init__.py deleted file mode 100644 index 7e089a6..0000000 --- a/alto/kernels/fp4/fp4_common/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2026 Advanced Micro Devices, Inc. -# -# SPDX-License-Identifier: MIT - -from .grouped_utils import ( - DEFAULT_ALIGN_SIZE_M, - build_hadamard_transform_if_needed, - check_grouped_loop_contract, - group_ids_from_expert_indices, - reset_cdna4_grouped_backend_cache, - resolve_expert_indices, - use_cdna4_grouped_backend, -) -from .triton_fp4_ops import ( - dequantize_e2m1, - generate_philox_randval_2x, - make_dequantize_e2m1, - make_generate_philox_randval_2x, - make_quantize_e2m1, - quantize_e2m1, -) -from .tensor_wrappers import unwrap_weight_wrapper - -__all__ = ( - "DEFAULT_ALIGN_SIZE_M", - "build_hadamard_transform_if_needed", - "check_grouped_loop_contract", - "dequantize_e2m1", - "generate_philox_randval_2x", - "group_ids_from_expert_indices", - "make_dequantize_e2m1", - "make_generate_philox_randval_2x", - "make_quantize_e2m1", - "quantize_e2m1", - "reset_cdna4_grouped_backend_cache", - "resolve_expert_indices", - "unwrap_weight_wrapper", - "use_cdna4_grouped_backend", -) diff --git a/alto/kernels/fp4/fp4_primitives/__init__.py b/alto/kernels/fp4/fp4_primitives/__init__.py new file mode 100644 index 0000000..09f26bc --- /dev/null +++ b/alto/kernels/fp4/fp4_primitives/__init__.py @@ -0,0 +1,102 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""``fp4_primitives`` — lowest shared layer for ALL FP4 recipes. + +This package holds the format-agnostic building blocks every FP4 recipe +(MXFP4, NVFP4, AMD-FP4) reuses, with no assumption about block size, +inner-scale grid, or outer-scale policy: + +* E2M1 payload primitives (``triton_fp4_ops``): ``make_quantize_e2m1`` / + ``make_dequantize_e2m1`` + Philox RNG for stochastic rounding. +* Scale-dtype casts (``e4m3_ops`` / ``ue5m3_ops``): the FP8 inner-scale + encodings used by the NVFP4 family (NVFP4 → E4M3, AMD-FP4 → UE5M3). +* MoE / grouped-GEMM helpers (``grouped_utils``): routing, CDNA4 backend + probe, alignment contracts, Hadamard construction. +* ``unwrap_weight_wrapper`` (``tensor_wrappers``): strips dispatch weight + wrapper subclasses to a plain tensor at autograd boundaries. + +The NVFP4/AMD-FP4 two-level blockwise layout lives one layer up in +``alto.kernels.fp4.outer_scaled_fp4``; MXFP4 builds directly on this layer. +""" + +from .grouped_utils import ( + DEFAULT_ALIGN_SIZE_M, + build_hadamard_transform_if_needed, + check_grouped_loop_contract, + group_ids_from_expert_indices, + reset_cdna4_grouped_backend_cache, + resolve_expert_indices, + use_cdna4_grouped_backend, +) +from .e4m3_ops import ( + E4M3_EPS, + E4M3_MAX, + F8E4M3_MAX, + make_quantize_e4m3, + quantize_e4m3, + quantize_to_e4m3, + triton_quantize_to_e4m3, +) +from .triton_fp4_ops import ( + dequantize_e2m1, + generate_philox_randval_2x, + make_dequantize_e2m1, + make_generate_philox_randval_2x, + make_quantize_e2m1, + quantize_e2m1, +) +from .ue5m3_ops import ( + UE5M3_EPS, + UE5M3_EXP_BIAS, + UE5M3_MAX, + UE5M3_MIN_NORMAL, + UE5M3_NAN_CODE, + UE5M3_NUM_EXP_BITS, + UE5M3_NUM_MAN_BITS, + f32_to_ue5m3_uint8, + make_quantize_ue5m3, + quantize_to_ue5m3, + quantize_ue5m3, + triton_quantize_to_ue5m3, + ue5m3_uint8_to_f32, +) +from .tensor_wrappers import unwrap_weight_wrapper + +__all__ = ( + "DEFAULT_ALIGN_SIZE_M", + "E4M3_EPS", + "E4M3_MAX", + "F8E4M3_MAX", + "UE5M3_EPS", + "UE5M3_EXP_BIAS", + "UE5M3_MAX", + "UE5M3_MIN_NORMAL", + "UE5M3_NAN_CODE", + "UE5M3_NUM_EXP_BITS", + "UE5M3_NUM_MAN_BITS", + "build_hadamard_transform_if_needed", + "check_grouped_loop_contract", + "dequantize_e2m1", + "f32_to_ue5m3_uint8", + "generate_philox_randval_2x", + "group_ids_from_expert_indices", + "make_dequantize_e2m1", + "make_generate_philox_randval_2x", + "make_quantize_e4m3", + "make_quantize_e2m1", + "make_quantize_ue5m3", + "quantize_e2m1", + "quantize_e4m3", + "quantize_to_e4m3", + "quantize_to_ue5m3", + "quantize_ue5m3", + "triton_quantize_to_e4m3", + "triton_quantize_to_ue5m3", + "reset_cdna4_grouped_backend_cache", + "resolve_expert_indices", + "ue5m3_uint8_to_f32", + "unwrap_weight_wrapper", + "use_cdna4_grouped_backend", +) diff --git a/alto/kernels/fp4/fp4_primitives/e4m3_ops.py b/alto/kernels/fp4/fp4_primitives/e4m3_ops.py new file mode 100644 index 0000000..8d5aedb --- /dev/null +++ b/alto/kernels/fp4/fp4_primitives/e4m3_ops.py @@ -0,0 +1,112 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""E4M3 (float8_e4m3fn) inner-block scale primitives for NVFP4. + +NVFP4 stores per-block inner scales on the IEEE-style E4M3 grid (finite-only +``float8_e4m3fn``). Unlike UE5M3, PyTorch and Triton expose this format +natively, so the implementation is a thin clamp + cast wrapper rather than +custom bit-twiddling. + +Surface mirrors :mod:`ue5m3_ops` so ``scale_format`` dispatch can treat both +inner dtypes uniformly in Phase 1+. +""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +# Naming mirrors ``ue5m3_ops``; keep the historical ``F8E4M3_MAX`` alias used +# throughout ``nvfp_quantization.py``. +E4M3_NUM_EXP_BITS = 4 +E4M3_NUM_MAN_BITS = 3 +E4M3_EXP_BIAS = 7 # float8_e4m3fn exponent bias + +F8E4M3_MAX = 448.0 +E4M3_MAX = F8E4M3_MAX +E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny +E4M3_MIN_NORMAL = float(torch.finfo(torch.float8_e4m3fn).tiny) + + +def quantize_to_e4m3(x: torch.Tensor) -> torch.Tensor: + """Round FP32 to the E4M3 grid and return FP32 (idempotent on-grid values). + + NVFP4 inner scales are non-negative; values are clamped to + ``[E4M3_EPS, E4M3_MAX]`` before the ``float8_e4m3fn`` round-trip. + """ + assert x.dtype == torch.float32, f"quantize_to_e4m3 expects float32, got {x.dtype}" + return x.clamp(min=E4M3_EPS, max=E4M3_MAX).to(torch.float8_e4m3fn).to(torch.float32) + + +def f32_to_e4m3_uint8(x: torch.Tensor) -> torch.Tensor: + """Encode FP32 to raw ``float8_e4m3fn`` storage (uint8 view).""" + return quantize_to_e4m3(x).to(torch.float8_e4m3fn).view(torch.uint8) + + +def e4m3_uint8_to_f32(qx: torch.Tensor) -> torch.Tensor: + """Decode ``float8_e4m3fn`` bytes to FP32.""" + if qx.dtype == torch.uint8: + qx = qx.view(torch.float8_e4m3fn) + return qx.to(torch.float32) + + +def make_quantize_e4m3(): + """Return a ``@triton.jit`` function that snaps FP32 to the E4M3 grid.""" + + _E4M3_EPS_PY = E4M3_EPS + _E4M3_MAX_PY = E4M3_MAX + + @triton.jit + def _quantize_e4m3(x): + E4M3_EPS: tl.constexpr = _E4M3_EPS_PY + E4M3_MAX: tl.constexpr = _E4M3_MAX_PY + x = tl.minimum(tl.maximum(x, E4M3_EPS), E4M3_MAX) + return x.to(tl.float8e4nv).to(tl.float32) + + return _quantize_e4m3 + + +quantize_e4m3 = make_quantize_e4m3() + + +@triton.jit +def _e4m3_snap_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask, other=0.0) + y = quantize_e4m3(x) + tl.store(out_ptr + offsets, y, mask=mask) + + +def triton_quantize_to_e4m3(x: torch.Tensor) -> torch.Tensor: + """Snap FP32 to the E4M3 grid on GPU (bit-identical to :func:`quantize_to_e4m3`).""" + assert x.is_cuda and x.dtype == torch.float32 + flat = x.contiguous().view(-1) + out = torch.empty_like(flat) + n = flat.numel() + block = 256 + grid = (triton.cdiv(n, block),) + _e4m3_snap_kernel[grid](flat, out, n, BLOCK_SIZE=block) + return out.view_as(x) + + +__all__ = ( + "E4M3_EPS", + "E4M3_EXP_BIAS", + "E4M3_MAX", + "E4M3_MIN_NORMAL", + "E4M3_NUM_EXP_BITS", + "E4M3_NUM_MAN_BITS", + "F8E4M3_MAX", + "e4m3_uint8_to_f32", + "f32_to_e4m3_uint8", + "make_quantize_e4m3", + "quantize_e4m3", + "quantize_to_e4m3", + "triton_quantize_to_e4m3", +) diff --git a/alto/kernels/fp4/fp4_common/grouped_utils.py b/alto/kernels/fp4/fp4_primitives/grouped_utils.py similarity index 98% rename from alto/kernels/fp4/fp4_common/grouped_utils.py rename to alto/kernels/fp4/fp4_primitives/grouped_utils.py index 26a53cb..870f2a4 100644 --- a/alto/kernels/fp4/fp4_common/grouped_utils.py +++ b/alto/kernels/fp4/fp4_primitives/grouped_utils.py @@ -150,7 +150,7 @@ def build_hadamard_transform_if_needed( assert not use_2dblock_x, ( "Hadamard transform can only be applied when use_2dblock_x=False." ) - # Local import to avoid a module-level cycle between fp4_common and the + # Local import to avoid a module-level cycle between fp4_primitives and the # hadamard package (which itself may pull in fp4-specific utilities). from alto.kernels.hadamard_transform import HadamardFactory with torch.no_grad(): diff --git a/alto/kernels/fp4/fp4_common/tensor_wrappers.py b/alto/kernels/fp4/fp4_primitives/tensor_wrappers.py similarity index 100% rename from alto/kernels/fp4/fp4_common/tensor_wrappers.py rename to alto/kernels/fp4/fp4_primitives/tensor_wrappers.py diff --git a/alto/kernels/fp4/fp4_common/triton_fp4_ops.py b/alto/kernels/fp4/fp4_primitives/triton_fp4_ops.py similarity index 100% rename from alto/kernels/fp4/fp4_common/triton_fp4_ops.py rename to alto/kernels/fp4/fp4_primitives/triton_fp4_ops.py diff --git a/alto/kernels/fp4/fp4_primitives/ue5m3_ops.py b/alto/kernels/fp4/fp4_primitives/ue5m3_ops.py new file mode 100644 index 0000000..ad8f9bf --- /dev/null +++ b/alto/kernels/fp4/fp4_primitives/ue5m3_ops.py @@ -0,0 +1,482 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""UE5M3 (unsigned 5-exp / 3-mant, NaN-aware) dtype primitives. + +UE5M3 is the inner-block scale format used by the AMD-FP4 recipe (NVFP4 +1x16 micro-block layout + UE5M3 inner scale + FP32 per-tensor outer +scale). It trades the sign bit of E4M3 (which is unused for scales, +which are always >= 0) for one extra exponent bit, **doubling the +exponent range at zero precision cost** (mantissa width stays at 3). + +Specification (committed, GFXIPARCH-2067 §19.10 OCP E5M3 aligned): + +* bit layout: ``[7:3] = exponent (5b)``, ``[2:0] = mantissa (3b)``; + no sign bit. +* exponent bias: ``2^(5-1) - 1 = 15``. +* code 0x00 -> +0.0 +* codes 0x01..0x07 -> subnormals: ``value = (mant/8) * 2^-14`` +* codes 0x08..0xFE -> normals: ``value = (1 + mant/8) * 2^(exp-15)``; + ``0xFE -> (1 + 6/8) * 2^16 = 114688.0`` (UE5M3_MAX, max normal). +* code 0xFF -> NaN (reserved). +* Inf encoding: **none**. ±Inf, NaN, and any value strictly greater + than ``UE5M3_MAX`` encode to ``0xFF`` (NaN code). Finite negative + inputs clamp to 0 (UE5M3 is unsigned). + +Decision history: + +* D2 (original, finite-only): ``0xFF -> 122880`` saturate-on-overflow. + Superseded; broke bit-level alignment with GFXIPARCH-2067 §19.10. +* D2' (current): ``0xFE`` is max normal, ``0xFF`` is NaN; NaN + propagation into downstream GEMM scales is prevented by a defense + layer in :mod:`alto.kernels.fp4.nvfp4.nvfp_quantization` and + :mod:`tests.unittest.nvfp4.utils._quantize_inner_scale` (mirrors + TransformerEngine / vLLM / TRT-LLM "caller-side sanitize" pattern). + +This module exposes two surfaces: + +* PyTorch reference (this is the bit-level oracle): + :func:`f32_to_ue5m3_uint8`, :func:`ue5m3_uint8_to_f32`, + :func:`quantize_to_ue5m3` (round-trip convenience). +* Triton ``@triton.jit`` factories used by the NVFP4 quantization + kernels: :func:`make_quantize_ue5m3` (returns an in-kernel + function that snaps an fp32 to the UE5M3 grid). + +The PyTorch and Triton paths are required to be bit-identical on the +fp32 snapped value across the full 256-code grid; see +``tests/unittest/ue5m3/`` for the oracle. +""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +UE5M3_NUM_EXP_BITS = 5 +UE5M3_NUM_MAN_BITS = 3 +UE5M3_EXP_BIAS = (1 << (UE5M3_NUM_EXP_BITS - 1)) - 1 # 15 + +# Max normal (code 0xFE per GFXIPARCH-2067 §19.10): +# (1 + 6/8) * 2^(31-15) = 1.75 * 65536 = 114688.0 +UE5M3_MAX = 1.75 * (2.0 ** (((1 << UE5M3_NUM_EXP_BITS) - 1) - UE5M3_EXP_BIAS)) + +# Smallest positive normal (code 0x08): 2^(1-15) = 2^-14 +UE5M3_MIN_NORMAL = 2.0 ** (1 - UE5M3_EXP_BIAS) + +# Smallest positive subnormal (code 0x01): (1/8) * 2^-14 = 2^-17 +UE5M3_EPS = (1.0 / (1 << UE5M3_NUM_MAN_BITS)) * UE5M3_MIN_NORMAL + +# Reserved NaN code (per GFXIPARCH-2067 §19.10). +UE5M3_NAN_CODE = 0xFF +# Quiet-NaN bit pattern used by the decoder when emitting fp32 NaN. +_FP32_QUIET_NAN_BITS = 0x7FC00000 + +# Sanity: the chosen FP32 div-by-zero floor used by NVFP4 outer-scale code +# must keep ``outer_scale * UE5M3_EPS`` inside FP32 normal range so that the +# effective per-block divisor never flushes under FTZ. We re-verify here +# at import time so any future floor change breaks loudly. +_FP32_MIN_NORMAL = torch.finfo(torch.float32).tiny +_OUTER_SCALE_DIVZERO_FLOOR = 1.0e-30 # mirrors nvfp_quantization._OUTER_SCALE_DIVZERO_FLOOR +assert _OUTER_SCALE_DIVZERO_FLOOR * UE5M3_EPS > _FP32_MIN_NORMAL, ( + "UE5M3_EPS * _OUTER_SCALE_DIVZERO_FLOOR falls below FP32 min normal; " + "the NVFP4 outer-scale floor in nvfp_quantization.py must be raised." +) +del _FP32_MIN_NORMAL + + +# Internal float32 IEEE-754 layout constants +_F32_EXP_BITS = 8 +_F32_MAN_BITS = 23 +_F32_EXP_BIAS = (1 << (_F32_EXP_BITS - 1)) - 1 # 127 + + +# --------------------------------------------------------------------------- +# PyTorch reference -- bit-level oracle. +# Inputs/outputs match the Triton path bit-for-bit on the fp32 snapped value. +# --------------------------------------------------------------------------- + +def f32_to_ue5m3_uint8(x: torch.Tensor) -> torch.Tensor: + """Encode a non-negative FP32 tensor to packed UE5M3 (uint8). + + Behavior (D2', GFXIPARCH-2067 §19.10): + * round-to-nearest-even (ties to even via magic-adder trick) + * ``UE5M3_MAX`` (114688) encodes to ``0xFE`` (max normal) + * NaN, ±Inf, and finite values strictly greater than ``UE5M3_MAX`` + encode to ``0xFF`` (NaN code, per spec) + * finite negative inputs clamp to 0 (unsigned format); -Inf + counts as a special and encodes to ``0xFF`` + + Args: + x: float32 tensor (CPU or CUDA). + + Returns: + uint8 tensor of identical shape; each element holds a UE5M3 code. + """ + assert x.dtype == torch.float32, f"f32_to_ue5m3_uint8 expects float32, got {x.dtype}" + + # ---- Identify "special" inputs that must encode to the reserved NaN code + # 0xFF: NaN, ±Inf, and finite values that overflow the normal range. + # We short-circuit them to 0 so the RTNE bit-magic only sees well-formed + # non-negative finite values in [0, UE5M3_MAX], then re-tag them at the end. + is_nan_or_inf = torch.isnan(x) | torch.isinf(x) + is_overflow = x > UE5M3_MAX + is_special = is_nan_or_inf | is_overflow + + x = torch.where(is_special, torch.zeros_like(x), x) + x = torch.clamp(x, min=0.0, max=UE5M3_MAX) + + # Saturate mask AFTER clamp: items pinned exactly at UE5M3_MAX (= 0xFE) + # must encode to the max-normal code, NOT to the NaN code. + saturate_mask = x >= UE5M3_MAX + denormal_mask = (~saturate_mask) & (x < UE5M3_MIN_NORMAL) + normal_mask = ~(saturate_mask | denormal_mask) + + # ---- Normal path: magic-adder RTNE on the lower (23-3) = 20 bits. + # + # val_to_add = ((EXP_BIAS_LP - EXP_BIAS_F32) << 23) <-- exponent rebias + # + magic_adder <-- round-to-nearest + # + mant_odd <-- tie-to-even + # out = ((x_int + val_to_add) >> (23 - 3)) <-- truncate + # + # This is the standard textbook trick (see Computer Organization and + # Design, RISC-V edition, Chapter 3.5) and mirrors what + # ``alto.kernels.fp4.testing_utils._f32_to_floatx_unpacked`` does for + # the signed FloatX path, minus the sign-bit handling. + magic_adder = (1 << (_F32_MAN_BITS - UE5M3_NUM_MAN_BITS - 1)) - 1 # 524287 + + x_int = x.view(torch.int32) + mant_odd = (x_int >> (_F32_MAN_BITS - UE5M3_NUM_MAN_BITS)) & 1 + val_to_add = ((UE5M3_EXP_BIAS - _F32_EXP_BIAS) << _F32_MAN_BITS) + magic_adder + normal_x = x_int + val_to_add + mant_odd + normal_x = (normal_x >> (_F32_MAN_BITS - UE5M3_NUM_MAN_BITS)).to(torch.uint8) + + # ---- Denormal path: align-by-add trick. + # + # Adding ``denorm_mask_float`` shifts the input so that the bits + # representing the subnormal mantissa land exactly where UE5M3 wants + # them; subtracting the integer of the same value isolates those bits. + denorm_exp = (_F32_EXP_BIAS - UE5M3_EXP_BIAS) + (_F32_MAN_BITS - UE5M3_NUM_MAN_BITS) + 1 + denorm_mask_int = denorm_exp << _F32_MAN_BITS + denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32).item() + + denormal_x = (x + denorm_mask_float).view(torch.int32) - denorm_mask_int + denormal_x = denormal_x.to(torch.uint8) + + # ---- Compose: max-normal 0xFE is the default for the saturate path; + # overwrite for normal/denormal paths; finally re-tag special inputs to + # the NaN code 0xFF (D2', spec-aligned). + out = torch.full_like(x, 0xFE, dtype=torch.uint8) + out = torch.where(normal_mask, normal_x, out) + out = torch.where(denormal_mask, denormal_x, out) + nan_code = torch.full_like(out, UE5M3_NAN_CODE) + out = torch.where(is_special, nan_code, out) + return out + + +def ue5m3_uint8_to_f32(qx: torch.Tensor) -> torch.Tensor: + """Decode a UE5M3 uint8 tensor to FP32. + + This is the canonical decoder; the value table is (D2', spec-aligned): + + code 0x00 -> 0.0 + codes 0x01..0x07 -> (mant/8) * 2^-14 + codes 0x08..0xFE -> (1 + mant/8) * 2^(exp-15) + code 0xFF -> NaN (reserved per GFXIPARCH-2067 §19.10) + + Args: + qx: uint8 tensor. + + Returns: + float32 tensor of identical shape. + """ + assert qx.dtype == torch.uint8, f"ue5m3_uint8_to_f32 expects uint8, got {qx.dtype}" + + exp = (qx >> UE5M3_NUM_MAN_BITS).to(torch.int32) # [0, 31] + mant = (qx & ((1 << UE5M3_NUM_MAN_BITS) - 1)).to(torch.int32) + + nan_code_mask = qx == UE5M3_NAN_CODE + zero_mask = qx == 0 + denormal_mask = (exp == 0) & (mant != 0) + normal_mask = (exp > 0) & ~nan_code_mask + + # Normal path: rebias exponent into FP32, shift mantissa. + exp_f32 = (exp - UE5M3_EXP_BIAS + _F32_EXP_BIAS) << _F32_MAN_BITS + mant_f32 = mant << (_F32_MAN_BITS - UE5M3_NUM_MAN_BITS) + normal_int = exp_f32 | mant_f32 + + # Denormal path: with 3-bit mantissa we enumerate the seven non-zero codes + # directly; this avoids the loop-based approach of the signed FloatX helper + # for clarity at this small size. + # value(mant) = (mant / 8) * 2^-14 + # = mant * 2^-17 + # In FP32 bits, that's: + # mant == 1: 2^-17 = (127 - 17) << 23 = 0x37000000 + # mant == 2: 2^-16 = (127 - 16) << 23 = 0x37800000 + # mant == 3: 1.5 * 2^-16 = ((127-16) << 23) | (1 << 22) = 0x37C00000 + # mant == 4: 2^-15 = (127 - 15) << 23 = 0x38000000 + # mant == 5: 1.25 * 2^-15 = ((127-15) << 23) | (1 << 21) = 0x38200000 + # mant == 6: 1.5 * 2^-15 = ((127-15) << 23) | (2 << 21) = 0x38400000 + # mant == 7: 1.75 * 2^-15 = ((127-15) << 23) | (3 << 21) = 0x38600000 + denorm_table = torch.tensor( + [0x00000000, # mant=0 (unreachable on denormal path; placeholder) + 0x37000000, 0x37800000, 0x37C00000, + 0x38000000, 0x38200000, 0x38400000, 0x38600000], + dtype=torch.int32, + device=qx.device, + ) + denormal_int = denorm_table[mant] + + result = torch.where(normal_mask, normal_int, torch.zeros_like(normal_int)) + result = torch.where(denormal_mask, denormal_int, result) + result = torch.where(zero_mask, torch.zeros_like(result), result) + # Spec-aligned NaN exit: code 0xFF decodes to a quiet FP32 NaN. + result = torch.where( + nan_code_mask, + torch.full_like(result, _FP32_QUIET_NAN_BITS), + result, + ) + + return result.view(torch.float32) + + +def quantize_to_ue5m3(x: torch.Tensor) -> torch.Tensor: + """Round an FP32 tensor to the UE5M3 grid, returning FP32. + + Equivalent to ``ue5m3_uint8_to_f32(f32_to_ue5m3_uint8(x))`` and + therefore idempotent: ``quantize_to_ue5m3(quantize_to_ue5m3(x)) == + quantize_to_ue5m3(x)`` bit-for-bit. + + This is the surface used by the NVFP4 quant kernel's inner-scale + rounding step (the AMD-FP4 analogue of ``inner.to(torch.float8_e4m3fn) + .to(torch.float32)``). + """ + return ue5m3_uint8_to_f32(f32_to_ue5m3_uint8(x)) + + +# --------------------------------------------------------------------------- +# Triton ``@triton.jit`` factories. +# +# The factory pattern mirrors :mod:`alto.kernels.fp4.fp4_primitives.triton_fp4_ops` +# (``make_quantize_e2m1``) so call-sites can swap one for another without +# changing the surrounding kernel skeleton. +# +# The returned function takes a positive FP32 tile and returns the +# UE5M3-snapped FP32 tile. No uint8 surface is exposed inside Triton +# because the only consumer (``_calculate_nvfp4_scales``) needs the +# snapped fp32 value as a divisor. +# --------------------------------------------------------------------------- + +def _ue5m3_triton_constants() -> dict: + """Pre-compute all Triton-side magic numbers in pure Python. + + Triton's ``tl.constexpr`` slots accept Python ints / floats but not the + result of in-kernel ops (``tl.cast`` etc.), so we resolve the FP32 + align-by-add value here and pass it in as a float constant. + """ + ebits = UE5M3_NUM_EXP_BITS + mbits = UE5M3_NUM_MAN_BITS + exp_bias_lp = UE5M3_EXP_BIAS + f32_man_bits = _F32_MAN_BITS + f32_exp_bias = _F32_EXP_BIAS + + denorm_exp = (f32_exp_bias - exp_bias_lp) + (f32_man_bits - mbits) + 1 # 133 + denorm_mask_int = denorm_exp << f32_man_bits # 0x42800000 + # Bitcast 0x42800000 -> fp32: sign=0, exp=133, mant=0 -> 1.0 * 2^(133-127) = 64.0 + denorm_mask_float = float( + torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32).item() + ) + + # Signed val_to_add for the normal-path rebias + RTNE adder; this is an + # int32 quantity that wraps correctly when added to a uint32 bit pattern. + magic_adder = (1 << (f32_man_bits - mbits - 1)) - 1 # 524287 + val_to_add = ((exp_bias_lp - f32_exp_bias) << f32_man_bits) + magic_adder + + return { + "EBITS": ebits, + "MBITS": mbits, + "EXP_BIAS_LP": exp_bias_lp, + "MBITS_F32": f32_man_bits, + "EXP_BIAS_F32": f32_exp_bias, + "MAX_VAL": UE5M3_MAX, + "MIN_NORMAL": UE5M3_MIN_NORMAL, + "MAGIC_ADDER": magic_adder, + "VAL_TO_ADD": val_to_add, + "DENORM_MASK_INT": denorm_mask_int, + "DENORM_MASK_FLOAT": denorm_mask_float, + "NAN_CODE": UE5M3_NAN_CODE, + "NAN_BITS": _FP32_QUIET_NAN_BITS, + } + + +# Resolved once at import time -- these values never change. +_UE5M3_CONSTS = _ue5m3_triton_constants() +# Sanity: 64.0 is the only correct value for the align-by-add float. +assert _UE5M3_CONSTS["DENORM_MASK_FLOAT"] == 64.0, _UE5M3_CONSTS + + +def make_quantize_ue5m3(): + """Return a ``@triton.jit`` function that snaps FP32 to the UE5M3 grid. + + Returned signature:: + + _quantize_ue5m3(x: tl.tensor[fp32]) -> tl.tensor[fp32] + + Semantics are bit-identical to :func:`quantize_to_ue5m3`. + """ + + # Capture as Python locals so each tl.constexpr below can read a plain int / float. + _MBITS_PY = _UE5M3_CONSTS["MBITS"] + _EXP_BIAS_LP_PY = _UE5M3_CONSTS["EXP_BIAS_LP"] + _MBITS_F32_PY = _UE5M3_CONSTS["MBITS_F32"] + _EXP_BIAS_F32_PY = _UE5M3_CONSTS["EXP_BIAS_F32"] + _MAX_VAL_PY = _UE5M3_CONSTS["MAX_VAL"] + _MIN_NORMAL_PY = _UE5M3_CONSTS["MIN_NORMAL"] + _VAL_TO_ADD_PY = _UE5M3_CONSTS["VAL_TO_ADD"] + _DENORM_MASK_INT_PY = _UE5M3_CONSTS["DENORM_MASK_INT"] + _DENORM_MASK_FLOAT_PY = _UE5M3_CONSTS["DENORM_MASK_FLOAT"] + _NAN_CODE_PY = _UE5M3_CONSTS["NAN_CODE"] + _NAN_BITS_PY = _UE5M3_CONSTS["NAN_BITS"] + + @triton.jit + def _quantize_ue5m3(x): + MBITS: tl.constexpr = _MBITS_PY + EXP_BIAS_LP: tl.constexpr = _EXP_BIAS_LP_PY + MBITS_F32: tl.constexpr = _MBITS_F32_PY + EXP_BIAS_F32: tl.constexpr = _EXP_BIAS_F32_PY + MAX_VAL: tl.constexpr = _MAX_VAL_PY + MIN_NORMAL: tl.constexpr = _MIN_NORMAL_PY + VAL_TO_ADD: tl.constexpr = _VAL_TO_ADD_PY + DENORM_MASK_INT: tl.constexpr = _DENORM_MASK_INT_PY + DENORM_MASK_FLOAT: tl.constexpr = _DENORM_MASK_FLOAT_PY + NAN_CODE: tl.constexpr = _NAN_CODE_PY + NAN_BITS: tl.constexpr = _NAN_BITS_PY + + # ---- Identify "special" inputs that must encode to the reserved + # NaN code 0xFF: NaN, ±Inf, and finite values strictly above MAX. + # * ``x != x`` catches NaN + # * ``x * 0 != 0`` catches NaN OR ±Inf (Inf*0 == NaN, finite*0 == 0) + # * ``x > MAX_VAL`` catches finite overflow + is_special = (x != x) | (x * 0.0 != 0.0) | (x > MAX_VAL) + + # Short-circuit specials to 0 so the RTNE bit-magic only sees + # well-formed non-negative finite values in [0, MAX_VAL]. Negative + # finite values clamp to 0 (unsigned). + x = tl.where(is_special, 0.0, x) + x = tl.minimum(x, MAX_VAL) + x = tl.maximum(x, 0.0) + + x_int = x.to(tl.uint32, bitcast=True) + saturate_mask = x >= MAX_VAL + denormal_mask = (~saturate_mask) & (x < MIN_NORMAL) + normal_mask = ~(saturate_mask | denormal_mask) + + # ---- Normal path: magic-adder RTNE + # Use signed int32 bit patterns for the add (HIP rejects uint32 + negative). + mant_odd = (x_int >> (MBITS_F32 - MBITS)) & 1 + normal_x = x_int.to(tl.int32, bitcast=True) + VAL_TO_ADD + mant_odd + normal_x = (normal_x >> (MBITS_F32 - MBITS)).to(tl.uint8) + + # ---- Denormal path: align-by-add + denormal_x = (x + DENORM_MASK_FLOAT).to(tl.uint32, bitcast=True) - DENORM_MASK_INT + denormal_x = denormal_x.to(tl.uint8) + + # ---- Compose uint8 code (default = max-normal 0xFE for the saturate + # path; specials get re-tagged to 0xFF below). + # We use ``x_int * 0 + 0xFE`` instead of ``tl.full`` because we want + # a tile-shaped result without knowing the BLOCK_SIZE here. + encoded = (x_int * 0 + 0xFE).to(tl.uint8) + encoded = tl.where(normal_mask, normal_x, encoded) + encoded = tl.where(denormal_mask, denormal_x, encoded) + # Re-tag specials to the NaN code (D2', spec-aligned). + encoded = tl.where(is_special, NAN_CODE, encoded) + + # ---- Decode back to FP32 (mirror of ue5m3_uint8_to_f32, kept inline + # so the call-site only ever sees fp32). + exp = (encoded >> MBITS).to(tl.int32) + mant = (encoded & 0x7).to(tl.int32) + + # Normal: ((exp - 15 + 127) << 23) | (mant << 20) + normal_int = ((exp - EXP_BIAS_LP + EXP_BIAS_F32) << MBITS_F32) | (mant << (MBITS_F32 - MBITS)) + + # Denormal: arithmetic FP32 reconstruction. + # Let k = floor(log2(mant)), 2^k <= mant < 2^(k+1). Then: + # value = (1 + (mant - 2^k) / 2^k) * 2^(k - 17) + # exp_f32_bits = (k + 110) << 23 + # mant_f32_bits = (mant - 2^k) << (23 - k) + is_ge_4 = mant >= 4 + is_ge_2 = mant >= 2 + k = tl.where(is_ge_4, 2, tl.where(is_ge_2, 1, 0)) + two_k = tl.where(is_ge_4, 4, tl.where(is_ge_2, 2, 1)) + shift = tl.where(is_ge_4, 21, tl.where(is_ge_2, 22, 23)) + denorm_exp_f32 = (k + 110) << MBITS_F32 + denorm_mant_f32 = (mant - two_k) << shift + denormal_int_out = denorm_exp_f32 | denorm_mant_f32 + + nan_code_mask = encoded == NAN_CODE + normal_mask_out = (exp > 0) & ~nan_code_mask + denormal_mask_out = (exp == 0) & (mant != 0) + # ``encoded == 0`` falls through to the final ``0`` default. + result_int = tl.where( + normal_mask_out, + normal_int, + tl.where(denormal_mask_out, denormal_int_out, 0), + ) + # Spec-aligned NaN exit: bit-cast quiet NaN for code 0xFF. + result_int = tl.where(nan_code_mask, NAN_BITS, result_int) + + return result_int.to(tl.float32, bitcast=True) + + return _quantize_ue5m3 + + +# Module-level pre-built instance for callers that just want the JIT. +quantize_ue5m3 = make_quantize_ue5m3() + + +# Driver kernel must live in this module so Triton can inline ``quantize_ue5m3`` +# into the parent @triton.jit (cross-file nested JIT calls fail at compile time). +@triton.jit +def _ue5m3_snap_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask, other=0.0) + y = quantize_ue5m3(x) + tl.store(out_ptr + offsets, y, mask=mask) + + +def triton_quantize_to_ue5m3(x: torch.Tensor) -> torch.Tensor: + """Snap FP32 to the UE5M3 grid on GPU (bit-identical to :func:`quantize_to_ue5m3`).""" + assert x.is_cuda and x.dtype == torch.float32 + flat = x.contiguous().view(-1) + out = torch.empty_like(flat) + n = flat.numel() + block = 256 + grid = (triton.cdiv(n, block),) + _ue5m3_snap_kernel[grid](flat, out, n, BLOCK_SIZE=block) + return out.view_as(x) + + +__all__ = ( + "UE5M3_EPS", + "UE5M3_EXP_BIAS", + "UE5M3_MAX", + "UE5M3_MIN_NORMAL", + "UE5M3_NAN_CODE", + "UE5M3_NUM_EXP_BITS", + "UE5M3_NUM_MAN_BITS", + "f32_to_ue5m3_uint8", + "make_quantize_ue5m3", + "quantize_to_ue5m3", + "quantize_ue5m3", + "triton_quantize_to_ue5m3", + "ue5m3_uint8_to_f32", +) diff --git a/alto/kernels/fp4/mxfp4/mxfp_grouped_gemm/cg_backward.py b/alto/kernels/fp4/mxfp4/mxfp_grouped_gemm/cg_backward.py index b215a12..393b07c 100644 --- a/alto/kernels/fp4/mxfp4/mxfp_grouped_gemm/cg_backward.py +++ b/alto/kernels/fp4/mxfp4/mxfp_grouped_gemm/cg_backward.py @@ -34,7 +34,7 @@ from alto.kernels.dge import dge_bwd from alto.kernels.hadamard_transform.transform import (HadamardFactory, HadamardTransform) from ..macro_block_scaling import macro_block_scaling, macro_block_descaling -from ...fp4_common.tensor_wrappers import unwrap_weight_wrapper +from ...fp4_primitives.tensor_wrappers import unwrap_weight_wrapper # ============ Triton kernel for contiguous grouped GEMM backward inputs ============ diff --git a/alto/kernels/fp4/mxfp4/mxfp_linear.py b/alto/kernels/fp4/mxfp4/mxfp_linear.py index 4abe7e1..626b838 100644 --- a/alto/kernels/fp4/mxfp4/mxfp_linear.py +++ b/alto/kernels/fp4/mxfp4/mxfp_linear.py @@ -18,7 +18,7 @@ import triton import triton.language as tl -from alto.kernels.fp4.fp4_common import unwrap_weight_wrapper +from alto.kernels.fp4.fp4_primitives import unwrap_weight_wrapper from alto.kernels.hadamard_transform import (HadamardTransform, HadamardFactory) from alto.kernels.dge import dge_bwd from .mxfp_quantization import ( diff --git a/alto/kernels/fp4/mxfp4/mxfp_quantization.py b/alto/kernels/fp4/mxfp4/mxfp_quantization.py index dde4e6a..d427077 100644 --- a/alto/kernels/fp4/mxfp4/mxfp_quantization.py +++ b/alto/kernels/fp4/mxfp4/mxfp_quantization.py @@ -8,7 +8,7 @@ import triton import triton.language as tl -from alto.kernels.fp4.fp4_common import ( +from alto.kernels.fp4.fp4_primitives import ( make_dequantize_e2m1, make_generate_philox_randval_2x, make_quantize_e2m1, diff --git a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/autograd.py b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/autograd.py index 7e73099..797aac7 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/autograd.py +++ b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/autograd.py @@ -29,13 +29,13 @@ import torch from alto.kernels.dge import dge_bwd -from alto.kernels.fp4.fp4_common import unwrap_weight_wrapper +from alto.kernels.fp4.fp4_primitives import unwrap_weight_wrapper from alto.kernels.fp4.nvfp4.nvfp_quantization import _qdq, convert_from_nvfp4 from alto.kernels.hadamard_transform import HadamardTransform from .autotune import ALIGN_SIZE_M from .cg_backward import _nvfp4_grouped_dgrad, _nvfp4_grouped_wgrad from .cg_forward import _nvfp4_grouped_fprop -from alto.kernels.fp4.fp4_common import ( +from alto.kernels.fp4.fp4_primitives import ( check_grouped_loop_contract, resolve_expert_indices, use_cdna4_grouped_backend, @@ -75,6 +75,7 @@ def forward( use_outer_scale: bool, hadamard_transform: Optional[HadamardTransform] = None, use_dge: bool = False, + scale_format: str = "e4m3", ) -> torch.Tensor: M_bufferlen = inputs.shape[0] original_dtype = inputs.dtype @@ -102,10 +103,12 @@ def forward( x_dq = _qdq( inputs, axis=-1, is_2d_block=use_2dblock_x, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) w_dq = _qdq( expert_weights, axis=_FPROP_AXIS_W, is_2d_block=use_2dblock_w, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) y = _nvfp4_grouped_fprop( x_dq, @@ -123,6 +126,7 @@ def forward( axis=_DGRAD_AXIS_W, is_2d_block=False, use_outer_scale=use_outer_scale, + scale_format=scale_format, return_raw=use_dge, ) if use_dge: @@ -137,6 +141,7 @@ def forward( axis=_FPROP_AXIS_W, is_2d_block=True, use_outer_scale=use_outer_scale, + scale_format=scale_format, return_raw=True, ) @@ -148,6 +153,7 @@ def forward( x_bwd = _qdq( x_for_wgrad, axis=0, is_2d_block=False, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) else: x_bwd = x_dq @@ -165,6 +171,7 @@ def forward( ctx.original_dtype = original_dtype ctx.hadamard_transform = hadamard_transform ctx.use_dge = use_dge + ctx.scale_format = scale_format return y @staticmethod @@ -191,6 +198,7 @@ def backward(ctx, grad_output: torch.Tensor): g_dq = _qdq( grad_output, axis=-1, is_2d_block=ctx.use_2dblock_x, use_outer_scale=ctx.use_outer_scale, use_sr=ctx.use_sr_grad, + scale_format=ctx.scale_format, ) grad_inputs = _nvfp4_grouped_dgrad( g_dq, @@ -215,6 +223,7 @@ def backward(ctx, grad_output: torch.Tensor): use_outer_scale=ctx.use_outer_scale, use_2dblock_x=ctx.use_2dblock_x, output_dtype=ctx.original_dtype, + scale_format=ctx.scale_format, ) if ctx.use_dge: @@ -234,5 +243,8 @@ def backward(ctx, grad_output: torch.Tensor): # Return grads with arity matching forward's positional inputs: # (inputs, expert_weights, expert_indices, offs, # use_2dblock_x, use_2dblock_w, use_sr_grad, use_outer_scale, - # hadamard_transform, use_dge) - return grad_inputs, grad_weights, None, None, None, None, None, None, None, None + # hadamard_transform, use_dge, scale_format) + return ( + grad_inputs, grad_weights, + None, None, None, None, None, None, None, None, None, + ) diff --git a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_backward.py b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_backward.py index fa14352..941fb78 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_backward.py +++ b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_backward.py @@ -16,7 +16,7 @@ ) from alto.kernels.fp4.nvfp4.nvfp_quantization import _qdq -from alto.kernels.fp4.fp4_common import ( +from alto.kernels.fp4.fp4_primitives import ( check_grouped_loop_contract, group_ids_from_expert_indices, use_cdna4_grouped_backend, @@ -92,6 +92,7 @@ def _nvfp4_grouped_wgrad( use_outer_scale: bool, use_2dblock_x: bool, output_dtype: torch.dtype, + scale_format: str = "e4m3", ) -> torch.Tensor: """Quantize ``grad_output`` and compute the grouped weight gradient. @@ -103,11 +104,13 @@ def _nvfp4_grouped_wgrad( g_m_dq = _qdq( grad_output, axis=-1, is_2d_block=True, use_outer_scale=use_outer_scale, use_sr=use_sr_grad, + scale_format=scale_format, ) else: g_m_dq = _qdq( grad_output, axis=0, is_2d_block=False, use_outer_scale=use_outer_scale, use_sr=use_sr_grad, + scale_format=scale_format, ) if use_cdna4_grouped_backend(): diff --git a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_forward.py b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_forward.py index cacd184..ca1da03 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_forward.py +++ b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/cg_forward.py @@ -12,7 +12,7 @@ from alto.kernels.blockwise_fp8.grouped_gemm.cg_forward import cg_grouped_gemm_forward -from alto.kernels.fp4.fp4_common import ( +from alto.kernels.fp4.fp4_primitives import ( check_grouped_loop_contract, group_ids_from_expert_indices, use_cdna4_grouped_backend, diff --git a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/functional.py b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/functional.py index 96ea40f..b172d0e 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/functional.py +++ b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/functional.py @@ -27,7 +27,7 @@ import torch -from alto.kernels.fp4.fp4_common import build_hadamard_transform_if_needed +from alto.kernels.fp4.fp4_primitives import build_hadamard_transform_if_needed from .autograd import NVFP4GroupedGEMM @@ -43,6 +43,7 @@ def _nvfp4_grouped_gemm_impl( use_outer_scale: bool = False, use_hadamard: bool = False, use_dge: bool = False, + scale_format: str = "e4m3", ) -> torch.Tensor: """Normalized entrypoint shared by both public APIs. @@ -70,6 +71,7 @@ def _nvfp4_grouped_gemm_impl( use_outer_scale, hadamard_transform, use_dge, + scale_format, ) @@ -85,6 +87,7 @@ def nvfp4_grouped_gemm( use_outer_scale: bool = False, use_hadamard: bool = False, use_dge: bool = False, + scale_format: str = "e4m3", ) -> torch.Tensor: """NVFP4 QDQ-emulated Grouped GEMM with full autograd support. @@ -124,6 +127,7 @@ def nvfp4_grouped_gemm( use_outer_scale=use_outer_scale, use_hadamard=use_hadamard, use_dge=use_dge, + scale_format=scale_format, ) @@ -137,6 +141,7 @@ def _quantize_then_nvfp4_scaled_grouped_mm( use_outer_scale: bool = False, use_hadamard: bool = False, use_dge: bool = False, + scale_format: str = "e4m3", ) -> torch.Tensor: """Drop-in for the dispatch layer, mirroring mxfp4's ``_quantize_then_mxfp4_scaled_grouped_mm``. @@ -158,4 +163,5 @@ def _quantize_then_nvfp4_scaled_grouped_mm( use_outer_scale=use_outer_scale, use_hadamard=use_hadamard, use_dge=use_dge, + scale_format=scale_format, ) diff --git a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/utils.py b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/utils.py index f36a6cb..43ea858 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/utils.py +++ b/alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/utils.py @@ -8,7 +8,7 @@ Shared pieces (backend probe, ``ALIGN_SIZE_M`` loop contract, expert-index normalization, group-id extraction, Hadamard construction) live in -``alto.kernels.fp4.fp4_common.grouped_utils`` and are re-exported here for +``alto.kernels.fp4.fp4_primitives.grouped_utils`` and are re-exported here for backward compatibility with legacy imports. Only truly NVFP4-specific contracts (the axis-0 QDQ block-size alignment) stay @@ -19,7 +19,7 @@ import torch -from alto.kernels.fp4.fp4_common import ( # noqa: F401 — re-exported for BC +from alto.kernels.fp4.fp4_primitives import ( # noqa: F401 — re-exported for BC build_hadamard_transform_if_needed, check_grouped_loop_contract, group_ids_from_expert_indices, diff --git a/alto/kernels/fp4/nvfp4/nvfp_linear.py b/alto/kernels/fp4/nvfp4/nvfp_linear.py index 3a91a0d..da523b4 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_linear.py +++ b/alto/kernels/fp4/nvfp4/nvfp_linear.py @@ -6,7 +6,7 @@ import torch -from alto.kernels.fp4.fp4_common import unwrap_weight_wrapper +from alto.kernels.fp4.fp4_primitives import unwrap_weight_wrapper from alto.kernels.hadamard_transform import HadamardFactory, HadamardTransform from alto.kernels.dge import dge_bwd from .nvfp_quantization import ( @@ -60,6 +60,7 @@ def forward( use_outer_scale: bool, hadamard_transform: Optional[HadamardTransform] = None, use_dge: bool = False, + scale_format: str = "e4m3", ): weight = unwrap_weight_wrapper(weight) # Align weight dtype with activation so saved QDQ tensors share a @@ -77,11 +78,13 @@ def forward( x_2d, axis=-1, is_2d_block=use_2dblock_x, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) w_dq = _qdq( weight, axis=-1, is_2d_block=use_2dblock_w, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) y = x_dq @ w_dq.T @@ -117,6 +120,7 @@ def forward( weight, axis=0, is_2d_block=False, use_outer_scale=use_outer_scale, + scale_format=scale_format, return_raw=True, ) else: @@ -124,6 +128,7 @@ def forward( weight, axis=0, is_2d_block=False, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) else: # 2D block scaling is axis-invariant, so the fprop quantized view is @@ -136,6 +141,7 @@ def forward( weight, axis=-1, is_2d_block=use_2dblock_w, use_outer_scale=use_outer_scale, + scale_format=scale_format, return_raw=True, ) w_dq_axis0 = w_dq @@ -167,6 +173,7 @@ def forward( x_for_axis0, axis=0, is_2d_block=False, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) else: x_dq_axis0 = x_dq @@ -181,6 +188,7 @@ def forward( ctx.use_outer_scale = use_outer_scale ctx.hadamard_transform = hadamard_transform ctx.use_dge = use_dge + ctx.scale_format = scale_format return y.view(*original_shape[:-1], -1) @@ -210,6 +218,7 @@ def backward(ctx, grad_output): is_2d_block=True, use_outer_scale=ctx.use_outer_scale, use_sr=ctx.use_sr_grad, + scale_format=ctx.scale_format, ) grad_output_m_dq = grad_output_dq else: @@ -218,6 +227,7 @@ def backward(ctx, grad_output): is_2d_block=False, use_outer_scale=ctx.use_outer_scale, use_sr=ctx.use_sr_grad, + scale_format=ctx.scale_format, ) # Apply the same Hadamard rotation that was used on x in forward, # to grad_output before axis=0 QDQ. See the forward-side comment @@ -231,6 +241,7 @@ def backward(ctx, grad_output): is_2d_block=False, use_outer_scale=ctx.use_outer_scale, use_sr=ctx.use_sr_grad, + scale_format=ctx.scale_format, ) grad_inputs = grad_output_dq @ w_dq @@ -259,7 +270,11 @@ def backward(ctx, grad_output): return ( grad_inputs.view(*original_shape[:-1], w_dq.shape[-1]), grad_weights, - None, None, None, None, None, None, + # Match forward's positional arity: (x, weight, use_2dblock_x, + # use_2dblock_w, use_sr_grad, use_outer_scale, hadamard_transform, + # use_dge, scale_format). 9 inputs → 9 grad slots, with a None + # for every non-Tensor / non-grad-tracked input. + None, None, None, None, None, None, None, ) @@ -272,6 +287,7 @@ def _to_nvfp4_then_scaled_mm( use_outer_scale: bool = False, use_hadamard: bool = False, use_dge: bool = False, + scale_format: str = "e4m3", ) -> torch.Tensor: """Build the optional Hadamard transform and apply ``NVFP4LinearFunction``. @@ -294,4 +310,5 @@ def _to_nvfp4_then_scaled_mm( use_outer_scale, hadamard_transform, use_dge, + scale_format, ) diff --git a/alto/kernels/fp4/nvfp4/nvfp_quantization.py b/alto/kernels/fp4/nvfp4/nvfp_quantization.py index b0597b7..4f56772 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_quantization.py +++ b/alto/kernels/fp4/nvfp4/nvfp_quantization.py @@ -2,441 +2,79 @@ # # SPDX-License-Identifier: MIT -from typing import Optional, Tuple, Union - -import torch -import triton -import triton.language as tl -from torch.library import triton_op, wrap_triton - -from alto.kernels.fp4.fp4_common import ( - make_dequantize_e2m1, - make_generate_philox_randval_2x, - make_quantize_e2m1, -) - - -BLOCK_SIZE_DEFAULT = 16 -F4_E2M1_MAX = 6.0 -F8E4M3_MAX = 448.0 -E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny - -# Naming convention for the NVFP4 scale hierarchy: -# inner_scale -- per-block scale stored alongside the packed FP4 data -# (NVFP4 spec ``s_block``). Value lives on the E4M3 grid -# in an FP32 container (see outer-side notes below). -# outer_scale -- the outer-level scale factor that sits above -# ``inner_scale`` (NVFP4 spec ``s_global``). Currently a -# per-tensor FP32 scalar; the name is intentionally -# agnostic so the same API can later carry an -# outer-blockwise layout (e.g. one scale per 128x128 -# tile) without renaming the public surface. -# -# Per spec, ``outer_scale`` lives in FP32; this floor is a div-by-zero -# guard for the downstream ``max_abs / outer_scale`` when ``amax == 0``. -# We pick ``1e-30`` (well above FP32 denormal range and ~22 orders below -# any natural training-time outer scale) so that the effective per-block -# divisor ``quant_scale = inner_scale * outer_scale`` stays an FP32 -# *normal* in the worst case (``inner_scale == E4M3_EPS``, -# ``outer_scale == _OUTER_SCALE_DIVZERO_FLOOR``): -# 1e-30 * 2**-6 ≈ 1.56e-32 > FP32 smallest normal (~1.18e-38) -_OUTER_SCALE_DIVZERO_FLOOR = 1.0e-30 - -SUPPORTED_SCALE_FORMATS = ("e4m3",) - - -def _check_scale_format(scale_format: str) -> None: - if scale_format != "e4m3": - raise NotImplementedError( - f"scale_format={scale_format!r} is not yet supported. " - f"Currently supported: {SUPPORTED_SCALE_FORMATS}" - ) -_dequantize_e2m1 = make_dequantize_e2m1() -_generate_philox_randval_2x = make_generate_philox_randval_2x() -_quantize_e2m1 = make_quantize_e2m1() - - -def is_cdna4(): - target = triton.runtime.driver.active.get_current_target() - return target is not None and target.backend == "hip" and target.arch == "gfx950" - - -@triton.jit -def _pack_fp4( - x, - scales_fp32, - philox_seed, - philox_offset, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - QUANT_BLOCK_SIZE: tl.constexpr, - IS_2D_BLOCK: tl.constexpr = False, - USE_SR: tl.constexpr = False, - USE_ASM: tl.constexpr = False, -): - """Quantize and pack a tile into nibble-packed uint8. - - Mirrors the API of ``mxfp4._pack_fp4``. The key difference is that - *scales_fp32* is already a per-block **float32** tensor, whereas MXFP4 - passes uint8 exponents and converts them internally. - - ``USE_ASM`` is accepted for API parity but has no effect — CDNA4 FP4 - ASM instructions only honour the biased exponent of the scale operand, - making them incompatible with NVFP4's general float32 scales. - - Args: - x: input tile ``[BLOCK_M, BLOCK_N]`` (float32 | bfloat16) - scales_fp32: per-block float32 scales. - 1D: ``[BLOCK_M, SCALE_BLOCK_N]``; - 2D: ``[SCALE_BLOCK_M, SCALE_BLOCK_N]``. - BLOCK_M / BLOCK_N / QUANT_BLOCK_SIZE: tile constants - IS_2D_BLOCK: if True, use 2D (square) blocks for scaling - USE_SR: enable stochastic rounding - USE_ASM: (no-op) kept for API consistency with MXFP4 - - Returns: - packed uint8 tile ``[BLOCK_M, HALF_BLOCK_N]`` - """ - HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 - HALF_QUANT_BLOCK_SIZE: tl.constexpr = QUANT_BLOCK_SIZE // 2 - SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE - SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE - - if IS_2D_BLOCK: - scales_bc = scales_fp32.expand_dims(axis=(1, 3)).broadcast_to( - SCALE_BLOCK_M, QUANT_BLOCK_SIZE, SCALE_BLOCK_N, - HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) - else: - scales_bc = scales_fp32.expand_dims(axis=2).broadcast_to( - BLOCK_M, SCALE_BLOCK_N, - HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) - - x0, x1 = tl.split(x.reshape(BLOCK_M, HALF_BLOCK_N, 2)) - - if USE_SR: - randval0, randval1 = _generate_philox_randval_2x( - BLOCK_M, HALF_BLOCK_N, philox_seed, philox_offset) - else: - randval0 = 0 - randval1 = 0 - - y0 = _quantize_e2m1(x0, scales_bc, randval0, USE_SR=USE_SR) - y1 = _quantize_e2m1(x1, scales_bc, randval1, USE_SR=USE_SR) - y = y0 | (y1 << 4) - - return y.to(tl.uint8) - - -@triton.jit -def _unpack_fp4( - x, - scales_fp32, - output_dtype: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - QUANT_BLOCK_SIZE: tl.constexpr, - IS_2D_BLOCK: tl.constexpr = False, - USE_ASM: tl.constexpr = False, -): - """Unpack and dequantize a nibble-packed uint8 tile back to float. - - Mirrors the API of ``mxfp4._unpack_fp4``. *scales_fp32* is a per-block - **float32** tensor; see :func:`_pack_fp4` for details on the - MXFP4/NVFP4 scale difference. - - ``USE_ASM`` is accepted for API parity but has no effect. +"""NVFP4 (E4M3 inner scale) public quantize / dequantize ops. - Args: - x: packed uint8 tile ``[BLOCK_M, HALF_BLOCK_N]`` - scales_fp32: per-block float32 scales. - 1D: ``[BLOCK_M, SCALE_BLOCK_N]``; - 2D: ``[SCALE_BLOCK_M, SCALE_BLOCK_N]``. - output_dtype: target element type (tl.float32 | tl.bfloat16). - Kept for API parity with MXFP4; the software path - always computes in float32. - BLOCK_M / BLOCK_N / QUANT_BLOCK_SIZE: tile constants - IS_2D_BLOCK: if True, use 2D (square) blocks for scaling - USE_ASM: (no-op) kept for API consistency with MXFP4 +The actual block-quant implementation is shared with AMD-FP4 and lives in +:mod:`alto.kernels.fp4.outer_scaled_fp4`. This module is a *thin* recipe-side +surface that: - Returns: - unpacked float32 tile ``[BLOCK_M, BLOCK_N]`` - """ - HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 - HALF_QUANT_BLOCK_SIZE: tl.constexpr = QUANT_BLOCK_SIZE // 2 - SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE - SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE - - if IS_2D_BLOCK: - scales_bc = scales_fp32.expand_dims(axis=(1, 3)).broadcast_to( - SCALE_BLOCK_M, QUANT_BLOCK_SIZE, SCALE_BLOCK_N, - HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) - else: - scales_bc = scales_fp32.expand_dims(axis=2).broadcast_to( - BLOCK_M, SCALE_BLOCK_N, - HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) - - x0 = x & 0xF - x1 = (x & 0xF0) >> 4 - y0 = _dequantize_e2m1(x0, scales_bc) - y1 = _dequantize_e2m1(x1, scales_bc) - - y = tl.join(y0, y1).reshape(BLOCK_M, BLOCK_N) - return y - - -# ---- scale calculation (NVFP4-specific) ----------------------------------- - -@triton.jit -def _calculate_nvfp4_scales( - x, - outer_scale_ptr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - QUANT_BLOCK_SIZE: tl.constexpr, - IS_2D_BLOCK: tl.constexpr = False, - USE_OUTER_SCALE: tl.constexpr = False, -): - """Compute per-block E4M3-quantised scales for NVFP4 quantization. - - The returned ``inner_scale`` is the NVFP4 spec ``s_block``; the - ``outer_scale`` operand is the spec ``s_global``. See the - naming-convention block at the top of this module for the rationale. - - Per spec, the only value stored as FP8 E4M3 is the per-block scale - written next to the packed FP4 data; the outer scale and intermediates - stay in FP32. With ``USE_OUTER_SCALE=True`` the spec order is:: - - inner_scale_raw = block_amax(x) / outer_scale / F4_E2M1_MAX - inner_scale = round_e4m3(clamp(inner_scale_raw, - [E4M3_EPS, F8E4M3_MAX])) - quant_scale = inner_scale * outer_scale - - i.e. clamp + E4M3 round are applied exactly once, on the final stored - block (inner) scale. - - When ``IS_2D_BLOCK`` is True, one scale covers a - ``QUANT_BLOCK_SIZE x QUANT_BLOCK_SIZE`` tile, yielding output shapes - ``[SCALE_BLOCK_M, SCALE_BLOCK_N]`` instead of ``[BLOCK_M, SCALE_BLOCK_N]``. - """ - NEW_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE - - if IS_2D_BLOCK: - NEW_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE - x_grouped = x.reshape(NEW_BLOCK_M, QUANT_BLOCK_SIZE, - NEW_BLOCK_N, QUANT_BLOCK_SIZE) - max_abs = tl.max(tl.abs(x_grouped), axis=-1) - max_abs = tl.max(max_abs, axis=-2).to(tl.float32) - else: - x_grouped = x.reshape(BLOCK_M, NEW_BLOCK_N, QUANT_BLOCK_SIZE) - max_abs = tl.max(tl.abs(x_grouped), axis=-1).to(tl.float32) - - if USE_OUTER_SCALE: - outer_scale = tl.load(outer_scale_ptr) - inner_scale_raw = max_abs / outer_scale / F4_E2M1_MAX - inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, E4M3_EPS), F8E4M3_MAX) - inner_scale = inner_scale_raw.to(tl.float8e4nv).to(tl.float32) - quant_scale = inner_scale * outer_scale - else: - inner_scale_raw = max_abs / F4_E2M1_MAX - inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, E4M3_EPS), F8E4M3_MAX) - inner_scale = inner_scale_raw.to(tl.float8e4nv).to(tl.float32) - quant_scale = inner_scale - - return inner_scale, quant_scale - - -# ---- top-level Triton kernels --------------------------------------------- - -@triton.jit -def _convert_to_nvfp4_kernel( - x_ptr, - y_ptr, - s_ptr, - outer_scale_ptr, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - stride_sm, - stride_sn, - philox_seed, - philox_offset, - M_ACTUAL, - N_ACTUAL, - PACKED_N_ACTUAL, - SCALE_M_ACTUAL, - SCALE_N_ACTUAL, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - QUANT_BLOCK_SIZE: tl.constexpr, - IS_2D_BLOCK: tl.constexpr, - USE_OUTER_SCALE: tl.constexpr, - USE_SR: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - - HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 - SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE - SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE - - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_xn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_yn = pid_n * HALF_BLOCK_N + tl.arange(0, HALF_BLOCK_N) - offs_sn = pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N) - if IS_2D_BLOCK: - offs_sm = pid_m * SCALE_BLOCK_M + tl.arange(0, SCALE_BLOCK_M) - else: - offs_sm = offs_m - - offs_x = offs_m[:, None] * stride_xm + offs_xn[None, :] * stride_xn - offs_y = offs_m[:, None] * stride_ym + offs_yn[None, :] * stride_yn - offs_s = offs_sm[:, None] * stride_sm + offs_sn[None, :] * stride_sn - - tl.static_assert( - (x_ptr.type.element_ty == tl.float32) | (x_ptr.type.element_ty == tl.bfloat16) - ) - x = tl.load( - x_ptr + offs_x, - mask=(offs_m[:, None] < M_ACTUAL) & (offs_xn[None, :] < N_ACTUAL), - other=0, - ) - - inner_scale, quant_scale = _calculate_nvfp4_scales( - x, - outer_scale_ptr, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - QUANT_BLOCK_SIZE=QUANT_BLOCK_SIZE, - IS_2D_BLOCK=IS_2D_BLOCK, - USE_OUTER_SCALE=USE_OUTER_SCALE, - ) - - y = _pack_fp4( - x, - quant_scale, - philox_seed, - philox_offset, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - QUANT_BLOCK_SIZE=QUANT_BLOCK_SIZE, - IS_2D_BLOCK=IS_2D_BLOCK, - USE_SR=USE_SR, - ) - - tl.store( - y_ptr + offs_y, - y.to(y_ptr.type.element_ty), - mask=(offs_m[:, None] < M_ACTUAL) & (offs_yn[None, :] < PACKED_N_ACTUAL), - ) - tl.store( - s_ptr + offs_s, - inner_scale, - mask=(offs_sm[:, None] < SCALE_M_ACTUAL) & (offs_sn[None, :] < SCALE_N_ACTUAL), - ) +* registers the ATen-level operators ``alto::convert_to_nvfp4`` and + ``alto::convert_from_nvfp4`` on top of the shared body; +* re-exports the legacy private symbols (``_pack_fp4``, ``_unpack_fp4``, + ``_calculate_nvfp4_scales``, kernel handles, scale tables) under their + historical names so external imports keep working. +The ``_qdq`` round-trip helper lives here too because it is consumed by +``nvfp_linear`` / ``nvfp_grouped_gemm`` and is logically part of the +NVFP4 layer. -@triton.jit -def _convert_from_nvfp4_kernel( - x_ptr, - y_ptr, - s_ptr, - outer_scale_ptr, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - stride_sm, - stride_sn, - M_ACTUAL, - N_ACTUAL, - PACKED_N_ACTUAL, - SCALE_M_ACTUAL, - SCALE_N_ACTUAL, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - QUANT_BLOCK_SIZE: tl.constexpr, - IS_2D_BLOCK: tl.constexpr, - USE_OUTER_SCALE: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) +Naming convention used throughout this sub-package: - HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 - SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE - SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE +* ``inner_scale`` -- per-block scale stored next to the packed FP4 data + (NVFP4 spec ``s_block``). Lives on the inner grid (E4M3 here, UE5M3 + in AMD-FP4) inside an FP32 container. +* ``outer_scale`` -- outer-level scale that sits above ``inner_scale`` + (NVFP4 spec ``s_global``). Today a per-tensor FP32 scalar; named + agnostically so a future outer-blockwise layout reuses the same + surface without renames. - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_xn = pid_n * HALF_BLOCK_N + tl.arange(0, HALF_BLOCK_N) - offs_yn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_sn = pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N) - if IS_2D_BLOCK: - offs_sm = pid_m * SCALE_BLOCK_M + tl.arange(0, SCALE_BLOCK_M) - else: - offs_sm = offs_m +See ``alto/kernels/fp4/outer_scaled_fp4/scales.py`` for the +``_OUTER_SCALE_DIVZERO_FLOOR`` div-by-zero reasoning. +""" - offs_x = offs_m[:, None] * stride_xm + offs_xn[None, :] * stride_xn - offs_y = offs_m[:, None] * stride_ym + offs_yn[None, :] * stride_yn - offs_s = offs_sm[:, None] * stride_sm + offs_sn[None, :] * stride_sn - - x = tl.load( - x_ptr + offs_x, - mask=(offs_m[:, None] < M_ACTUAL) & (offs_xn[None, :] < PACKED_N_ACTUAL), - other=0, - ) - s = tl.load( - s_ptr + offs_s, - mask=(offs_sm[:, None] < SCALE_M_ACTUAL) & (offs_sn[None, :] < SCALE_N_ACTUAL), - other=0, - ) - - if USE_OUTER_SCALE: - outer_scale = tl.load(outer_scale_ptr) - s = s * outer_scale - - y = _unpack_fp4( - x, - s, - y_ptr.type.element_ty, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - QUANT_BLOCK_SIZE=QUANT_BLOCK_SIZE, - IS_2D_BLOCK=IS_2D_BLOCK, - ) - - tl.store( - y_ptr + offs_y, - y.to(y_ptr.type.element_ty), - mask=(offs_m[:, None] < M_ACTUAL) & (offs_yn[None, :] < N_ACTUAL), - ) - - -def compute_dynamic_outer_scale( - data_hp: torch.Tensor, - scale_format: str = "e4m3", -) -> torch.Tensor: - """Compute the FP32 outer-level scale ``amax / (F8E4M3_MAX * F4_E2M1_MAX)``. +import warnings +from typing import Optional, Tuple, Union - The "outer" naming reflects this scalar's position in the NVFP4 - hierarchy: it sits above the per-block ``inner_scale`` and is shared - across the entire tensor today (NVFP4 spec ``s_global``). Future - extensions may produce one outer scale per outer-block tile; this - function and the surrounding API are named to accommodate that - without further renames. +import torch +from torch.library import triton_op + +from alto.kernels.fp4.outer_scaled_fp4 import ( + BLOCK_SIZE_DEFAULT, + F4_E2M1_MAX, + SUPPORTED_SCALE_FORMATS, + _OUTER_SCALE_DIVZERO_FLOOR, + _SCALE_FORMAT_TABLE, + _calculate_inner_scales, + _check_scale_format, + _convert_from_fp4_kernel, + _convert_to_fp4_kernel, + _pack_fp4, + _scale_format_is_ue5m3, + _unpack_fp4, + compute_dynamic_outer_scale, + is_cdna4, +) +from alto.kernels.fp4.outer_scaled_fp4.api import ( + _convert_from_fp4_blockwise, + _convert_to_fp4_blockwise, + _fake_convert_from_fp4_blockwise, + _fake_convert_to_fp4_blockwise, +) - Per spec, the outer scale stays in FP32 with only a - ``_OUTER_SCALE_DIVZERO_FLOOR`` div-by-zero guard (``amax == 0``). +# --------------------------------------------------------------------------- +# Back-compat aliases: external code (and historical NVFP4 docstrings) +# refer to ``_calculate_nvfp4_scales`` and the old kernel handle names. +# Keep them resolvable so import sites do not break. +# --------------------------------------------------------------------------- - The *scale_format* parameter is reserved for future scale representations - (e.g. ``"e5m3"``). Currently only ``"e4m3"`` is supported. - """ - _check_scale_format(scale_format) - amax = data_hp.float().abs().max() - outer_scale = (amax / (F8E4M3_MAX * F4_E2M1_MAX)).clamp(min=_OUTER_SCALE_DIVZERO_FLOOR) - return outer_scale.to(dtype=torch.float32).reshape(1) +_calculate_nvfp4_scales = _calculate_inner_scales +_convert_to_nvfp4_kernel = _convert_to_fp4_kernel +_convert_from_nvfp4_kernel = _convert_from_fp4_kernel -@triton_op("alto::convert_to_nvfp4", mutates_args={}) +# ``outer_scale`` is mutated in place (``copy_``) when ``update_outer_scale`` +# is True; declare it so Dynamo/functionalization invalidate stale caches. +@triton_op("alto::convert_to_nvfp4", mutates_args={"outer_scale"}) def convert_to_nvfp4( data_hp: torch.Tensor, block_size: int = BLOCK_SIZE_DEFAULT, @@ -452,129 +90,34 @@ def convert_to_nvfp4( ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize a high-precision tensor to NVFP4 (E2M1) format. - Outer-level (NVFP4 spec ``s_global``) scaling is controlled by - ``outer_scale`` and ``update_outer_scale``. ``outer_scale`` is today a - 1-element FP32 tensor representing a per-tensor scale; the parameter is - deliberately named ``outer_scale`` (not ``per_tensor_scale``) so a future - outer-blockwise layout can reuse the same surface without renaming. - - * ``outer_scale`` given, ``update_outer_scale=True`` (default): - recompute the scale from ``data_hp``'s amax and write it back into the - caller's tensor **in place** (no clone). The caller reads the updated - value back through the same tensor. This is the recommended path for - training, where the outer scale tracks the current tensor's range. - * ``outer_scale`` given, ``update_outer_scale=False``: - use the caller-provided scale as-is (for calibrated / frozen scales). - * ``outer_scale=None``, ``update_outer_scale=True`` (default): - compute a dynamic scale internally and apply it for this call only. - The scale is not returned; if the caller wants to track it across calls - they should pre-allocate a buffer and pass it in. - * ``outer_scale=None``, ``update_outer_scale=False``: - outer-level scaling is disabled. - - *scale_format* selects the per-block scale representation. Currently - supported: ``"e4m3"`` (default). ``"e5m3"`` is accepted for forward - compatibility but is not yet validated on hardware. - - ``use_asm`` is accepted for API consistency with MXFP4 but has no effect. + See :func:`alto.kernels.fp4.outer_scaled_fp4.api._convert_to_fp4_blockwise` + for the full parameter contract. ``scale_format`` defaults to + ``"e4m3"`` (NVFP4); passing ``"ue5m3"`` is supported for backward + compatibility but new AMD-FP4 call sites should use the dedicated + ``alto::convert_to_amdfp4`` op via + :mod:`alto.kernels.fp4.amdfp4`. """ - torch._check( - data_hp.shape[axis] % block_size == 0, - lambda: f"tensor shape ({data_hp.shape}) at axis={axis} is not divisible by {block_size}", - ) - assert data_hp.dtype in [torch.float32, torch.bfloat16] - assert block_size % 2 == 0 and block_size >= 2, ( - f"block_size must be a positive even number, got {block_size}" - ) - assert not is_2d_block or data_hp.size(-2) % block_size == 0, ( - f"2D block requires dim -2 ({data_hp.size(-2)}) divisible by block_size ({block_size})" - ) - _check_scale_format(scale_format) - - data_hp = data_hp.transpose(axis, -1) - ori_shape = data_hp.shape - data_hp = data_hp.reshape(-1, ori_shape[-1]) - - new_shape = (*ori_shape[:-1], ori_shape[-1] // 2) - if is_2d_block: - scale_shape = (*ori_shape[:-2], ori_shape[-2] // block_size, - ori_shape[-1] // block_size) - else: - scale_shape = (*ori_shape[:-1], ori_shape[-1] // block_size) - data_lp = torch.empty(new_shape, dtype=torch.uint8, device=data_hp.device).reshape( - -1, new_shape[-1] - ) - scales = torch.empty(scale_shape, dtype=torch.float32, device=data_hp.device).reshape( - -1, scale_shape[-1] - ) - - # Resolve outer-scale I/O on the caller's own buffer — no clone. - if outer_scale is not None: - assert outer_scale.numel() == 1, "outer_scale must be a scalar tensor" - assert outer_scale.dtype == torch.float32, ( - "outer_scale must be float32" - ) - assert outer_scale.device == data_hp.device, ( - f"outer_scale device ({outer_scale.device}) must match " - f"data_hp device ({data_hp.device})" + if scale_format != "e4m3": + warnings.warn( + f"convert_to_nvfp4(scale_format={scale_format!r}) is deprecated; " + f"use the dedicated alto::convert_to_amdfp4 op " + f"(alto.kernels.fp4.amdfp4.convert_to_amdfp4) for non-E4M3 " + f"inner-scale grids.", + DeprecationWarning, + stacklevel=2, ) - if update_outer_scale: - outer_scale.copy_( - compute_dynamic_outer_scale(data_hp).reshape_as(outer_scale) - ) - outer_scale_buf = outer_scale.reshape(()) - use_outer_scale = True - elif update_outer_scale: - # No buffer supplied — compute an ephemeral scale for this call. - outer_scale_buf = compute_dynamic_outer_scale(data_hp).reshape(()) - use_outer_scale = True - else: - outer_scale_buf = torch.ones((), dtype=torch.float32, device=data_hp.device) - use_outer_scale = False - - stride_xm, stride_xn = data_hp.stride() - stride_ym, stride_yn = data_lp.stride() - stride_sm, stride_sn = scales.stride() - - M, N = data_hp.shape - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]), triton.cdiv(N, META["BLOCK_N"])) - BLOCK_M = 64 if M >= 64 else M - BLOCK_N = 64 if N >= 64 else N - - if philox_seed is None: - philox_seed = torch.randint(0, 2**31 - 1, (1,)).item() - if philox_offset is None: - philox_offset = torch.randint(0, 2**31 - 1, (1,)).item() - - wrap_triton(_convert_to_nvfp4_kernel)[grid]( + return _convert_to_fp4_blockwise( data_hp, - data_lp, - scales, - outer_scale_buf, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - stride_sm, - stride_sn, - philox_seed, - philox_offset, - M, - N, - data_lp.shape[1], - scales.shape[0], - scales.shape[1], - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - QUANT_BLOCK_SIZE=block_size, - IS_2D_BLOCK=is_2d_block, - USE_OUTER_SCALE=use_outer_scale, - USE_SR=use_sr, - ) - - return ( - data_lp.reshape(new_shape).transpose(axis, -1), - scales.reshape(scale_shape).transpose(axis, -1), + block_size=block_size, + axis=axis, + is_2d_block=is_2d_block, + outer_scale=outer_scale, + update_outer_scale=update_outer_scale, + scale_format=scale_format, + use_sr=use_sr, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_asm=use_asm, ) @@ -592,72 +135,21 @@ def convert_from_nvfp4( ) -> torch.Tensor: """Dequantize NVFP4 (E2M1) data back to high-precision format. - ``outer_scale`` is the optional outer-level FP32 scalar (NVFP4 spec - ``s_global``); see :func:`convert_to_nvfp4` for the naming rationale. - - *scale_format* is accepted for API symmetry with :func:`convert_to_nvfp4`. - The dequantization path only multiplies by the stored float32 scale, so - the format parameter does not affect the computation. - - ``use_asm`` is accepted for API consistency with MXFP4 but has no effect. + See :func:`alto.kernels.fp4.outer_scaled_fp4.api._convert_from_fp4_blockwise` + for the full parameter contract. """ - assert output_dtype in [torch.float32, torch.bfloat16] - assert block_size % 2 == 0 and block_size >= 2, ( - f"block_size must be a positive even number, got {block_size}" - ) - - data_lp = data_lp.transpose(axis, -1) - scales = scales.transpose(axis, -1) - orig_shape_lp = data_lp.shape - - data_lp = data_lp.reshape(-1, orig_shape_lp[-1]) - scales = scales.reshape(-1, (orig_shape_lp[-1] * 2) // block_size).to(torch.float32) - orig_shape_hp = (*orig_shape_lp[:-1], orig_shape_lp[-1] * 2) - data_hp = data_lp.new_empty(orig_shape_hp, dtype=output_dtype).reshape(-1, orig_shape_hp[-1]) - - if outer_scale is None: - outer_scale_buf = torch.ones((), dtype=torch.float32, device=data_lp.device) - use_outer_scale = False - else: - outer_scale = outer_scale.to(device=data_lp.device, dtype=torch.float32) - assert outer_scale.numel() == 1, "outer_scale must be a scalar tensor" - outer_scale_buf = outer_scale.reshape(()) - use_outer_scale = True - - stride_xm, stride_xn = data_lp.stride() - stride_ym, stride_yn = data_hp.stride() - stride_sm, stride_sn = scales.stride() - M, N = data_hp.shape - - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]), triton.cdiv(N, META["BLOCK_N"])) - BLOCK_M = 64 if M >= 64 else M - BLOCK_N = 64 if N >= 64 else N - - wrap_triton(_convert_from_nvfp4_kernel)[grid]( + return _convert_from_fp4_blockwise( data_lp, - data_hp, scales, - outer_scale_buf, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - stride_sm, - stride_sn, - M, - N, - data_lp.shape[1], - scales.shape[0], - scales.shape[1], - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - QUANT_BLOCK_SIZE=block_size, - IS_2D_BLOCK=is_2d_block, - USE_OUTER_SCALE=use_outer_scale, + output_dtype=output_dtype, + block_size=block_size, + axis=axis, + is_2d_block=is_2d_block, + outer_scale=outer_scale, + scale_format=scale_format, + use_asm=use_asm, ) - return data_hp.reshape(orig_shape_hp).transpose(axis, -1) - @convert_to_nvfp4.register_fake def _fake_convert_to_nvfp4( @@ -673,18 +165,9 @@ def _fake_convert_to_nvfp4( philox_offset: Optional[int] = None, use_asm: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - data_hp = data_hp.transpose(axis, -1) - orig_shape = data_hp.shape - - new_shape = (*orig_shape[:-1], orig_shape[-1] // 2) - if is_2d_block: - scale_shape = (*orig_shape[:-2], orig_shape[-2] // block_size, - orig_shape[-1] // block_size) - else: - scale_shape = (*orig_shape[:-1], orig_shape[-1] // block_size) - data_lp = data_hp.new_empty(new_shape, dtype=torch.uint8) - scales = data_hp.new_empty(scale_shape, dtype=torch.float32) - return data_lp.transpose(axis, -1), scales.transpose(axis, -1) + return _fake_convert_to_fp4_blockwise( + data_hp, block_size=block_size, axis=axis, is_2d_block=is_2d_block, + ) @convert_from_nvfp4.register_fake @@ -699,8 +182,9 @@ def _fake_convert_from_nvfp4( scale_format: str = "e4m3", use_asm: Optional[bool] = None, ) -> torch.Tensor: - data_hp = data_lp.new_empty(data_lp.shape, dtype=output_dtype) - return torch.cat((data_hp, data_hp), dim=axis) + return _fake_convert_from_fp4_blockwise( + data_lp, output_dtype=output_dtype, axis=axis, + ) def _qdq( @@ -711,6 +195,7 @@ def _qdq( use_outer_scale: bool, use_sr: bool = False, block_size: int = 16, + scale_format: str = "e4m3", return_raw: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Quantize to NVFP4 then immediately dequantize back (QDQ round-trip). @@ -741,6 +226,7 @@ def _qdq( # With ``outer_scale`` provided, refresh it in-place from tensor.amax(); # without it, skip outer-level scaling entirely. update_outer_scale=use_outer_scale, + scale_format=scale_format, use_sr=use_sr, ) dq = convert_from_nvfp4( @@ -751,6 +237,7 @@ def _qdq( axis=axis, is_2d_block=is_2d_block, outer_scale=outer_scale, + scale_format=scale_format, ) if return_raw: return dq, data_lp, scales diff --git a/alto/kernels/fp4/outer_scaled_fp4/__init__.py b/alto/kernels/fp4/outer_scaled_fp4/__init__.py new file mode 100644 index 0000000..6cea2ef --- /dev/null +++ b/alto/kernels/fp4/outer_scaled_fp4/__init__.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""``outer_scaled_fp4`` — shared two-level blockwise body for NVFP4 + AMD-FP4. + +The package name describes the layout this layer owns: FP4 / E2M1 storage +with a **per-block FP32 (float-grid) inner scale** plus an optional +**per-tensor FP32 outer scale**. Only NVFP4 and AMD-FP4 build on it; MXFP4 +is intentionally NOT a consumer — its macro-block (128×128 uint8) scaling +is a separate implementation under ``alto.kernels.fp4.mxfp4`` and shares +only the lower ``fp4_primitives`` E2M1 layer. + +Both consuming recipes share an identical block-quantization layout; only +the inner-scale dtype differs: + +* NVFP4 — inner scale on the **E4M3** grid (signed FP8, max 448). +* AMD-FP4 — inner scale on the **UE5M3** grid (unsigned float8, max + normal 114688 at code 0xFE; NaN reserved at code 0xFF per + GFXIPARCH-2067 §19.10 / OCP E5M3). + +This sub-package owns everything that is independent of which inner +grid the caller chose: + +* :mod:`pack_unpack` — Triton ``_pack_fp4`` / ``_unpack_fp4`` (E2M1 + nibble packing; the inner-scale dtype only feeds in as a pre-snapped + FP32 scalar, so this layer is format-agnostic). +* :mod:`scales` — table-driven ``_SCALE_FORMAT_TABLE``, the shared + Triton ``_calculate_inner_scales`` helper (with the D2' NaN defense + layer), and the FP32 :func:`compute_dynamic_outer_scale` helper. +* :mod:`kernels` — top-level ``_convert_to_fp4_kernel`` / + ``_convert_from_fp4_kernel`` Triton kernels driven by a constexpr + format selector. +* :mod:`api` — ``_convert_to_fp4_blockwise`` / ``_convert_from_fp4_blockwise`` + Python entry points; ``alto.kernels.fp4.nvfp4`` and + ``alto.kernels.fp4.amdfp4`` register their respective ``triton_op`` + surfaces (``alto::convert_to_nvfp4`` and ``alto::convert_to_amdfp4``) + on top of these. + +Public names are deliberately unstable and prefixed with ``_``; user +code should keep going through ``alto.kernels.fp4.{nvfp4,amdfp4}`` / +``torch.ops.alto``. +""" + +from .scales import ( + BLOCK_SIZE_DEFAULT, + F4_E2M1_MAX, + SUPPORTED_SCALE_FORMATS, + _OUTER_SCALE_DIVZERO_FLOOR, + _SCALE_FORMAT_TABLE, + _calculate_inner_scales, + _check_scale_format, + _scale_format_is_ue5m3, + compute_dynamic_outer_scale, + is_cdna4, +) +from .pack_unpack import _pack_fp4, _unpack_fp4 +from .kernels import _convert_from_fp4_kernel, _convert_to_fp4_kernel +from .api import _convert_from_fp4_blockwise, _convert_to_fp4_blockwise + +__all__ = ( + "BLOCK_SIZE_DEFAULT", + "F4_E2M1_MAX", + "SUPPORTED_SCALE_FORMATS", + "_OUTER_SCALE_DIVZERO_FLOOR", + "_SCALE_FORMAT_TABLE", + "_calculate_inner_scales", + "_check_scale_format", + "_convert_from_fp4_blockwise", + "_convert_from_fp4_kernel", + "_convert_to_fp4_blockwise", + "_convert_to_fp4_kernel", + "_pack_fp4", + "_scale_format_is_ue5m3", + "_unpack_fp4", + "compute_dynamic_outer_scale", + "is_cdna4", +) diff --git a/alto/kernels/fp4/outer_scaled_fp4/api.py b/alto/kernels/fp4/outer_scaled_fp4/api.py new file mode 100644 index 0000000..808c9b2 --- /dev/null +++ b/alto/kernels/fp4/outer_scaled_fp4/api.py @@ -0,0 +1,293 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Shared FP4 blockwise quantize / dequantize implementation. + +This module hosts the *single* quant / dequant body that both the NVFP4 +recipe (``alto::convert_to_nvfp4``) and the AMD-FP4 recipe +(``alto::convert_to_amdfp4``) register on top of as ``triton_op``. Each +recipe-side wrapper is a 5-line forwarder that pins ``scale_format`` to +``"e4m3"`` / ``"ue5m3"`` respectively; the heavy lifting (input +validation, memory allocation, outer-scale resolution, kernel launch, +shape book-keeping) lives here exactly once. + +Keeping this Python function *not* decorated with ``@triton_op`` lets us +stack two different ATen-registered ops on top of one shared body +without paying an extra dispatcher hop in the inner call. +""" + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import triton +from torch.library import wrap_triton + +from .kernels import _convert_from_fp4_kernel, _convert_to_fp4_kernel +from .scales import ( + BLOCK_SIZE_DEFAULT, + _check_scale_format, + _scale_format_is_ue5m3, + compute_dynamic_outer_scale, +) + + +def _convert_to_fp4_blockwise( + data_hp: torch.Tensor, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[torch.Tensor] = None, + update_outer_scale: bool = True, + scale_format: str = "e4m3", + use_sr: bool = False, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + use_asm: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """FP4 blockwise quantize body shared by NVFP4 and AMD-FP4. + + Outer-level (NVFP4 spec ``s_global``) scaling is controlled by + ``outer_scale`` and ``update_outer_scale``. ``outer_scale`` is today a + 1-element FP32 tensor representing a per-tensor scale; the parameter is + deliberately named ``outer_scale`` (not ``per_tensor_scale``) so a future + outer-blockwise layout can reuse the same surface without renaming. + + * ``outer_scale`` given, ``update_outer_scale=True`` (default): + recompute the scale from ``data_hp``'s amax and write it back into the + caller's tensor **in place** (no clone). The caller reads the updated + value back through the same tensor. This is the recommended path for + training, where the outer scale tracks the current tensor's range. + * ``outer_scale`` given, ``update_outer_scale=False``: + use the caller-provided scale as-is (for calibrated / frozen scales). + * ``outer_scale=None``, ``update_outer_scale=True`` (default): + compute a dynamic scale internally and apply it for this call only. + The scale is not returned; if the caller wants to track it across calls + they should pre-allocate a buffer and pass it in. + * ``outer_scale=None``, ``update_outer_scale=False``: + outer-level scaling is disabled. + + *scale_format* selects the per-block inner-scale grid: ``"e4m3"`` (NVFP4, + default) or ``"ue5m3"`` (AMD-FP4). + + ``use_asm`` is accepted for API consistency with MXFP4 but has no effect. + """ + torch._check( + data_hp.shape[axis] % block_size == 0, + lambda: f"tensor shape ({data_hp.shape}) at axis={axis} is not divisible by {block_size}", + ) + assert data_hp.dtype in [torch.float32, torch.bfloat16] + assert block_size % 2 == 0 and block_size >= 2, ( + f"block_size must be a positive even number, got {block_size}" + ) + assert not is_2d_block or data_hp.size(-2) % block_size == 0, ( + f"2D block requires dim -2 ({data_hp.size(-2)}) divisible by block_size ({block_size})" + ) + _check_scale_format(scale_format) + + data_hp = data_hp.transpose(axis, -1) + ori_shape = data_hp.shape + data_hp = data_hp.reshape(-1, ori_shape[-1]) + + new_shape = (*ori_shape[:-1], ori_shape[-1] // 2) + if is_2d_block: + scale_shape = (*ori_shape[:-2], ori_shape[-2] // block_size, + ori_shape[-1] // block_size) + else: + scale_shape = (*ori_shape[:-1], ori_shape[-1] // block_size) + data_lp = torch.empty(new_shape, dtype=torch.uint8, device=data_hp.device).reshape( + -1, new_shape[-1] + ) + scales = torch.empty(scale_shape, dtype=torch.float32, device=data_hp.device).reshape( + -1, scale_shape[-1] + ) + + # Resolve outer-scale I/O on the caller's own buffer — no clone. + if outer_scale is not None: + assert outer_scale.numel() == 1, "outer_scale must be a scalar tensor" + assert outer_scale.dtype == torch.float32, ( + "outer_scale must be float32" + ) + assert outer_scale.device == data_hp.device, ( + f"outer_scale device ({outer_scale.device}) must match " + f"data_hp device ({data_hp.device})" + ) + if update_outer_scale: + outer_scale.copy_( + compute_dynamic_outer_scale( + data_hp, scale_format=scale_format, + ).reshape_as(outer_scale) + ) + outer_scale_buf = outer_scale.reshape(()) + use_outer_scale = True + elif update_outer_scale: + # No buffer supplied — compute an ephemeral scale for this call. + outer_scale_buf = compute_dynamic_outer_scale( + data_hp, scale_format=scale_format, + ).reshape(()) + use_outer_scale = True + else: + outer_scale_buf = torch.ones((), dtype=torch.float32, device=data_hp.device) + use_outer_scale = False + + scale_format_is_ue5m3 = _scale_format_is_ue5m3(scale_format) + stride_xm, stride_xn = data_hp.stride() + stride_ym, stride_yn = data_lp.stride() + stride_sm, stride_sn = scales.stride() + + M, N = data_hp.shape + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]), triton.cdiv(N, META["BLOCK_N"])) + BLOCK_M = 64 if M >= 64 else M + BLOCK_N = 64 if N >= 64 else N + + if philox_seed is None: + philox_seed = torch.randint(0, 2**31 - 1, (1,)).item() + if philox_offset is None: + philox_offset = torch.randint(0, 2**31 - 1, (1,)).item() + + wrap_triton(_convert_to_fp4_kernel)[grid]( + data_hp, + data_lp, + scales, + outer_scale_buf, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + philox_seed, + philox_offset, + M, + N, + data_lp.shape[1], + scales.shape[0], + scales.shape[1], + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + QUANT_BLOCK_SIZE=block_size, + IS_2D_BLOCK=is_2d_block, + USE_OUTER_SCALE=use_outer_scale, + USE_SR=use_sr, + SCALE_FORMAT_IS_UE5M3=scale_format_is_ue5m3, + ) + + return ( + data_lp.reshape(new_shape).transpose(axis, -1), + scales.reshape(scale_shape).transpose(axis, -1), + ) + + +def _convert_from_fp4_blockwise( + data_lp: torch.Tensor, + scales: torch.Tensor, + output_dtype: torch.dtype = torch.float32, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[torch.Tensor] = None, + scale_format: str = "e4m3", + use_asm: Optional[bool] = None, +) -> torch.Tensor: + """FP4 blockwise dequantize body shared by NVFP4 and AMD-FP4. + + ``outer_scale`` is the optional outer-level FP32 scalar (NVFP4 spec + ``s_global``); see :func:`_convert_to_fp4_blockwise` for the naming + rationale. + + *scale_format* is accepted for API symmetry; the dequantization path + only multiplies by the stored float32 scale, so the format parameter + does not affect the computation. + + ``use_asm`` is accepted for API consistency with MXFP4 but has no effect. + """ + assert output_dtype in [torch.float32, torch.bfloat16] + assert block_size % 2 == 0 and block_size >= 2, ( + f"block_size must be a positive even number, got {block_size}" + ) + _check_scale_format(scale_format) + + data_lp = data_lp.transpose(axis, -1) + scales = scales.transpose(axis, -1) + orig_shape_lp = data_lp.shape + + data_lp = data_lp.reshape(-1, orig_shape_lp[-1]) + scales = scales.reshape(-1, (orig_shape_lp[-1] * 2) // block_size).to(torch.float32) + orig_shape_hp = (*orig_shape_lp[:-1], orig_shape_lp[-1] * 2) + data_hp = data_lp.new_empty(orig_shape_hp, dtype=output_dtype).reshape(-1, orig_shape_hp[-1]) + + if outer_scale is None: + outer_scale_buf = torch.ones((), dtype=torch.float32, device=data_lp.device) + use_outer_scale = False + else: + outer_scale = outer_scale.to(device=data_lp.device, dtype=torch.float32) + assert outer_scale.numel() == 1, "outer_scale must be a scalar tensor" + outer_scale_buf = outer_scale.reshape(()) + use_outer_scale = True + + stride_xm, stride_xn = data_lp.stride() + stride_ym, stride_yn = data_hp.stride() + stride_sm, stride_sn = scales.stride() + M, N = data_hp.shape + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]), triton.cdiv(N, META["BLOCK_N"])) + BLOCK_M = 64 if M >= 64 else M + BLOCK_N = 64 if N >= 64 else N + + wrap_triton(_convert_from_fp4_kernel)[grid]( + data_lp, + data_hp, + scales, + outer_scale_buf, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + M, + N, + data_lp.shape[1], + scales.shape[0], + scales.shape[1], + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + QUANT_BLOCK_SIZE=block_size, + IS_2D_BLOCK=is_2d_block, + USE_OUTER_SCALE=use_outer_scale, + ) + + return data_hp.reshape(orig_shape_hp).transpose(axis, -1) + + +def _fake_convert_to_fp4_blockwise( + data_hp: torch.Tensor, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Shape-only fake for :func:`_convert_to_fp4_blockwise`.""" + data_hp = data_hp.transpose(axis, -1) + orig_shape = data_hp.shape + + new_shape = (*orig_shape[:-1], orig_shape[-1] // 2) + if is_2d_block: + scale_shape = (*orig_shape[:-2], orig_shape[-2] // block_size, + orig_shape[-1] // block_size) + else: + scale_shape = (*orig_shape[:-1], orig_shape[-1] // block_size) + data_lp = data_hp.new_empty(new_shape, dtype=torch.uint8) + scales = data_hp.new_empty(scale_shape, dtype=torch.float32) + return data_lp.transpose(axis, -1), scales.transpose(axis, -1) + + +def _fake_convert_from_fp4_blockwise( + data_lp: torch.Tensor, + output_dtype: torch.dtype = torch.float32, + axis: int = -1, +) -> torch.Tensor: + """Shape-only fake for :func:`_convert_from_fp4_blockwise`.""" + data_hp = data_lp.new_empty(data_lp.shape, dtype=output_dtype) + return torch.cat((data_hp, data_hp), dim=axis) diff --git a/alto/kernels/fp4/outer_scaled_fp4/kernels.py b/alto/kernels/fp4/outer_scaled_fp4/kernels.py new file mode 100644 index 0000000..47fa42f --- /dev/null +++ b/alto/kernels/fp4/outer_scaled_fp4/kernels.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Top-level Triton kernels for FP4 blockwise quantize / dequantize. + +These kernels are shared by NVFP4 (E4M3 inner) and AMD-FP4 (UE5M3 inner). +The inner-grid choice flows in via the ``SCALE_FORMAT_IS_UE5M3`` +constexpr; everything else (E2M1 nibble packing, tile geometry, outer +scale handling) is identical between the two recipes. +""" + +from __future__ import annotations + +import triton +import triton.language as tl + +from .pack_unpack import _pack_fp4, _unpack_fp4 +from .scales import _calculate_inner_scales + + +@triton.jit +def _convert_to_fp4_kernel( + x_ptr, + y_ptr, + s_ptr, + outer_scale_ptr, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + philox_seed, + philox_offset, + M_ACTUAL, + N_ACTUAL, + PACKED_N_ACTUAL, + SCALE_M_ACTUAL, + SCALE_N_ACTUAL, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + IS_2D_BLOCK: tl.constexpr, + USE_OUTER_SCALE: tl.constexpr, + USE_SR: tl.constexpr, + SCALE_FORMAT_IS_UE5M3: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 + SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE + SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_xn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_yn = pid_n * HALF_BLOCK_N + tl.arange(0, HALF_BLOCK_N) + offs_sn = pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N) + if IS_2D_BLOCK: + offs_sm = pid_m * SCALE_BLOCK_M + tl.arange(0, SCALE_BLOCK_M) + else: + offs_sm = offs_m + + offs_x = offs_m[:, None] * stride_xm + offs_xn[None, :] * stride_xn + offs_y = offs_m[:, None] * stride_ym + offs_yn[None, :] * stride_yn + offs_s = offs_sm[:, None] * stride_sm + offs_sn[None, :] * stride_sn + + tl.static_assert( + (x_ptr.type.element_ty == tl.float32) | (x_ptr.type.element_ty == tl.bfloat16) + ) + x = tl.load( + x_ptr + offs_x, + mask=(offs_m[:, None] < M_ACTUAL) & (offs_xn[None, :] < N_ACTUAL), + other=0, + ) + + inner_scale, quant_scale = _calculate_inner_scales( + x, + outer_scale_ptr, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + QUANT_BLOCK_SIZE=QUANT_BLOCK_SIZE, + IS_2D_BLOCK=IS_2D_BLOCK, + USE_OUTER_SCALE=USE_OUTER_SCALE, + SCALE_FORMAT_IS_UE5M3=SCALE_FORMAT_IS_UE5M3, + ) + + y = _pack_fp4( + x, + quant_scale, + philox_seed, + philox_offset, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + QUANT_BLOCK_SIZE=QUANT_BLOCK_SIZE, + IS_2D_BLOCK=IS_2D_BLOCK, + USE_SR=USE_SR, + ) + + tl.store( + y_ptr + offs_y, + y.to(y_ptr.type.element_ty), + mask=(offs_m[:, None] < M_ACTUAL) & (offs_yn[None, :] < PACKED_N_ACTUAL), + ) + tl.store( + s_ptr + offs_s, + inner_scale, + mask=(offs_sm[:, None] < SCALE_M_ACTUAL) & (offs_sn[None, :] < SCALE_N_ACTUAL), + ) + + +@triton.jit +def _convert_from_fp4_kernel( + x_ptr, + y_ptr, + s_ptr, + outer_scale_ptr, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + M_ACTUAL, + N_ACTUAL, + PACKED_N_ACTUAL, + SCALE_M_ACTUAL, + SCALE_N_ACTUAL, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + IS_2D_BLOCK: tl.constexpr, + USE_OUTER_SCALE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 + SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE + SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_xn = pid_n * HALF_BLOCK_N + tl.arange(0, HALF_BLOCK_N) + offs_yn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_sn = pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N) + if IS_2D_BLOCK: + offs_sm = pid_m * SCALE_BLOCK_M + tl.arange(0, SCALE_BLOCK_M) + else: + offs_sm = offs_m + + offs_x = offs_m[:, None] * stride_xm + offs_xn[None, :] * stride_xn + offs_y = offs_m[:, None] * stride_ym + offs_yn[None, :] * stride_yn + offs_s = offs_sm[:, None] * stride_sm + offs_sn[None, :] * stride_sn + + x = tl.load( + x_ptr + offs_x, + mask=(offs_m[:, None] < M_ACTUAL) & (offs_xn[None, :] < PACKED_N_ACTUAL), + other=0, + ) + s = tl.load( + s_ptr + offs_s, + mask=(offs_sm[:, None] < SCALE_M_ACTUAL) & (offs_sn[None, :] < SCALE_N_ACTUAL), + other=0, + ) + + if USE_OUTER_SCALE: + outer_scale = tl.load(outer_scale_ptr) + s = s * outer_scale + + y = _unpack_fp4( + x, + s, + y_ptr.type.element_ty, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + QUANT_BLOCK_SIZE=QUANT_BLOCK_SIZE, + IS_2D_BLOCK=IS_2D_BLOCK, + ) + + tl.store( + y_ptr + offs_y, + y.to(y_ptr.type.element_ty), + mask=(offs_m[:, None] < M_ACTUAL) & (offs_yn[None, :] < N_ACTUAL), + ) diff --git a/alto/kernels/fp4/outer_scaled_fp4/pack_unpack.py b/alto/kernels/fp4/outer_scaled_fp4/pack_unpack.py new file mode 100644 index 0000000..c5da411 --- /dev/null +++ b/alto/kernels/fp4/outer_scaled_fp4/pack_unpack.py @@ -0,0 +1,157 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Triton ``_pack_fp4`` / ``_unpack_fp4`` for the NVFP4 family. + +These kernels handle the **E2M1 nibble-pack layout**, which is identical +between NVFP4 (E4M3 inner scale) and AMD-FP4 (UE5M3 inner scale): the +per-block scale arrives as a pre-snapped FP32 scalar, so this layer is +deliberately format-agnostic. + +The MXFP4 family has its *own* ``_pack_fp4`` / ``_unpack_fp4`` in +``alto.kernels.fp4.mxfp4`` because MXFP4 stores the scale as a uint8 +exponent and converts internally; that path is not shared. +""" + +from __future__ import annotations + +import triton +import triton.language as tl + +from alto.kernels.fp4.fp4_primitives import ( + make_dequantize_e2m1, + make_generate_philox_randval_2x, + make_quantize_e2m1, +) + + +# Module-level Triton helpers consumed by the @jit factories below. Re-bind +# under the same names the callers used inside ``nvfp_quantization.py`` so +# Triton's NON_CONSTEXPR_GLOBALS lookup keeps resolving them. +_dequantize_e2m1 = make_dequantize_e2m1() +_generate_philox_randval_2x = make_generate_philox_randval_2x() +_quantize_e2m1 = make_quantize_e2m1() + + +@triton.jit +def _pack_fp4( + x, + scales_fp32, + philox_seed, + philox_offset, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + IS_2D_BLOCK: tl.constexpr = False, + USE_SR: tl.constexpr = False, + USE_ASM: tl.constexpr = False, +): + """Quantize and pack a tile into nibble-packed uint8. + + Mirrors the API of ``mxfp4._pack_fp4``. The key difference is that + *scales_fp32* is already a per-block **float32** tensor, whereas MXFP4 + passes uint8 exponents and converts them internally. + + ``USE_ASM`` is accepted for API parity but has no effect — CDNA4 FP4 + ASM instructions only honour the biased exponent of the scale operand, + making them incompatible with NVFP4's general float32 scales. + + Args: + x: input tile ``[BLOCK_M, BLOCK_N]`` (float32 | bfloat16) + scales_fp32: per-block float32 scales. + 1D: ``[BLOCK_M, SCALE_BLOCK_N]``; + 2D: ``[SCALE_BLOCK_M, SCALE_BLOCK_N]``. + BLOCK_M / BLOCK_N / QUANT_BLOCK_SIZE: tile constants + IS_2D_BLOCK: if True, use 2D (square) blocks for scaling + USE_SR: enable stochastic rounding + USE_ASM: (no-op) kept for API consistency with MXFP4 + + Returns: + packed uint8 tile ``[BLOCK_M, HALF_BLOCK_N]`` + """ + HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 + HALF_QUANT_BLOCK_SIZE: tl.constexpr = QUANT_BLOCK_SIZE // 2 + SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE + SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE + + if IS_2D_BLOCK: + scales_bc = scales_fp32.expand_dims(axis=(1, 3)).broadcast_to( + SCALE_BLOCK_M, QUANT_BLOCK_SIZE, SCALE_BLOCK_N, + HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) + else: + scales_bc = scales_fp32.expand_dims(axis=2).broadcast_to( + BLOCK_M, SCALE_BLOCK_N, + HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) + + x0, x1 = tl.split(x.reshape(BLOCK_M, HALF_BLOCK_N, 2)) + + if USE_SR: + randval0, randval1 = _generate_philox_randval_2x( + BLOCK_M, HALF_BLOCK_N, philox_seed, philox_offset) + else: + randval0 = 0 + randval1 = 0 + + y0 = _quantize_e2m1(x0, scales_bc, randval0, USE_SR=USE_SR) + y1 = _quantize_e2m1(x1, scales_bc, randval1, USE_SR=USE_SR) + y = y0 | (y1 << 4) + + return y.to(tl.uint8) + + +@triton.jit +def _unpack_fp4( + x, + scales_fp32, + output_dtype: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + IS_2D_BLOCK: tl.constexpr = False, + USE_ASM: tl.constexpr = False, +): + """Unpack and dequantize a nibble-packed uint8 tile back to float. + + Mirrors the API of ``mxfp4._unpack_fp4``. *scales_fp32* is a per-block + **float32** tensor; see :func:`_pack_fp4` for details on the + MXFP4/NVFP4 scale difference. + + ``USE_ASM`` is accepted for API parity but has no effect. + + Args: + x: packed uint8 tile ``[BLOCK_M, HALF_BLOCK_N]`` + scales_fp32: per-block float32 scales. + 1D: ``[BLOCK_M, SCALE_BLOCK_N]``; + 2D: ``[SCALE_BLOCK_M, SCALE_BLOCK_N]``. + output_dtype: target element type (tl.float32 | tl.bfloat16). + Kept for API parity with MXFP4; the software path + always computes in float32. + BLOCK_M / BLOCK_N / QUANT_BLOCK_SIZE: tile constants + IS_2D_BLOCK: if True, use 2D (square) blocks for scaling + USE_ASM: (no-op) kept for API consistency with MXFP4 + + Returns: + unpacked float32 tile ``[BLOCK_M, BLOCK_N]`` + """ + HALF_BLOCK_N: tl.constexpr = BLOCK_N // 2 + HALF_QUANT_BLOCK_SIZE: tl.constexpr = QUANT_BLOCK_SIZE // 2 + SCALE_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE + SCALE_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE + + if IS_2D_BLOCK: + scales_bc = scales_fp32.expand_dims(axis=(1, 3)).broadcast_to( + SCALE_BLOCK_M, QUANT_BLOCK_SIZE, SCALE_BLOCK_N, + HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) + else: + scales_bc = scales_fp32.expand_dims(axis=2).broadcast_to( + BLOCK_M, SCALE_BLOCK_N, + HALF_QUANT_BLOCK_SIZE).reshape(BLOCK_M, HALF_BLOCK_N) + + x0 = x & 0xF + x1 = (x & 0xF0) >> 4 + y0 = _dequantize_e2m1(x0, scales_bc) + y1 = _dequantize_e2m1(x1, scales_bc) + + y = tl.join(y0, y1).reshape(BLOCK_M, BLOCK_N) + return y diff --git a/alto/kernels/fp4/outer_scaled_fp4/scales.py b/alto/kernels/fp4/outer_scaled_fp4/scales.py new file mode 100644 index 0000000..f6daac1 --- /dev/null +++ b/alto/kernels/fp4/outer_scaled_fp4/scales.py @@ -0,0 +1,203 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Inner-scale grid table, format dispatch, and outer-scale FP32 helper. + +This module owns the *single source of truth* for the per-block scale +hierarchy used by NVFP4 (E4M3 inner) and AMD-FP4 (UE5M3 inner): + +* ``_SCALE_FORMAT_TABLE`` maps a string format name to ``(eps, max)``. +* ``_calculate_inner_scales`` is the Triton helper that turns a 2D + tile into per-block inner scales on the chosen grid, with the D2' + NaN-defense layer applied symmetrically across both formats. +* :func:`compute_dynamic_outer_scale` produces the FP32 outer-level + scale (NVFP4 spec ``s_global``) from a tensor's amax. + +Naming convention used throughout this sub-package: + +* ``inner_scale`` — per-block scale stored alongside the packed FP4 + data (NVFP4 spec ``s_block``). Lives in an FP32 container with the + value snapped to the chosen inner grid (E4M3 or UE5M3). +* ``outer_scale`` — the outer-level scale that sits above + ``inner_scale`` (NVFP4 spec ``s_global``). Today a per-tensor FP32 + scalar; named agnostically so a future outer-blockwise layout (e.g. + one scale per 128x128 tile) can reuse the same surface without a + rename. +""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +from alto.kernels.fp4.fp4_primitives import ( + E4M3_EPS, + F8E4M3_MAX, + UE5M3_EPS, + UE5M3_MAX, + quantize_e4m3, + quantize_ue5m3, +) + + +BLOCK_SIZE_DEFAULT = 16 +F4_E2M1_MAX = 6.0 + +# Per spec, ``outer_scale`` lives in FP32; this floor is a div-by-zero +# guard for the downstream ``max_abs / outer_scale`` when ``amax == 0``. +# We pick ``1e-30`` (well above FP32 denormal range and ~22 orders below +# any natural training-time outer scale) so that the effective per-block +# divisor ``quant_scale = inner_scale * outer_scale`` stays an FP32 +# *normal* in the worst case (``inner_scale == fmt_eps``, +# ``outer_scale == _OUTER_SCALE_DIVZERO_FLOOR``). Both inner grids stay +# above FP32 min normal (~1.18e-38): +# E4M3 : 1e-30 * E4M3_EPS(2**-6) ≈ 1.56e-32 > 1.18e-38 +# UE5M3: 1e-30 * UE5M3_EPS(2**-17)≈ 7.63e-36 > 1.18e-38 +# The UE5M3 case is the tighter of the two and is additionally guarded by +# the import-time assert in ``fp4_primitives/ue5m3_ops.py``. +_OUTER_SCALE_DIVZERO_FLOOR = 1.0e-30 + +# Per inner-scale dtype: (eps, max) used for clamp + round on ``s_block``. +_SCALE_FORMAT_TABLE: dict[str, tuple[float, float]] = { + "e4m3": (float(E4M3_EPS), float(F8E4M3_MAX)), + "ue5m3": (float(UE5M3_EPS), float(UE5M3_MAX)), +} +SUPPORTED_SCALE_FORMATS = tuple(_SCALE_FORMAT_TABLE.keys()) + +# Triton-side mirrors so ``@triton.jit`` constexpr slots see plain floats. +_E4M3_EPS_TRITON = float(E4M3_EPS) +_E4M3_MAX_TRITON = float(F8E4M3_MAX) +_UE5M3_EPS_TRITON = float(UE5M3_EPS) +_UE5M3_MAX_TRITON = float(UE5M3_MAX) + + +def _check_scale_format(scale_format: str) -> None: + if scale_format not in _SCALE_FORMAT_TABLE: + raise NotImplementedError( + f"scale_format={scale_format!r} is not supported. " + f"Supported: {SUPPORTED_SCALE_FORMATS}" + ) + + +def _scale_format_is_ue5m3(scale_format: str) -> bool: + _check_scale_format(scale_format) + return scale_format == "ue5m3" + + +def is_cdna4(): + target = triton.runtime.driver.active.get_current_target() + return target is not None and target.backend == "hip" and target.arch == "gfx950" + + +def compute_dynamic_outer_scale( + data_hp: torch.Tensor, + scale_format: str = "e4m3", +) -> torch.Tensor: + """Compute the FP32 outer-level scale ``amax / (max_fmt * F4_E2M1_MAX)``. + + The "outer" naming reflects this scalar's position in the NVFP4 + hierarchy: it sits above the per-block ``inner_scale`` and is shared + across the entire tensor today (NVFP4 spec ``s_global``). Future + extensions may produce one outer scale per outer-block tile; this + function and the surrounding API are named to accommodate that + without further renames. + + Per spec, the outer scale stays in FP32 with only a + ``_OUTER_SCALE_DIVZERO_FLOOR`` div-by-zero guard (``amax == 0``). + + *scale_format* selects the inner-scale dtype whose ``max`` participates + in the outer-scale formula (E4M3 for NVFP4, UE5M3 for AMD-FP4). + """ + _check_scale_format(scale_format) + _, scale_max = _SCALE_FORMAT_TABLE[scale_format] + # D2' NaN defense (symmetric with the per-block ``tl.where`` guard in + # ``_calculate_inner_scales``): ``torch.clamp`` does NOT remove NaN, so a + # single NaN element would otherwise make ``amax`` -> NaN and, because the + # outer scale is per-tensor, poison *every* block's ``inner_scale_raw`` + # (bypassing the per-block guard). Sanitise the reduced amax to 0 here so a + # NaN spike degrades to "that block goes to 0", not "whole tensor is NaN". + amax = torch.nan_to_num(data_hp.float().abs().max(), nan=0.0) + outer_scale = (amax / (scale_max * F4_E2M1_MAX)).clamp(min=_OUTER_SCALE_DIVZERO_FLOOR) + return outer_scale.to(dtype=torch.float32).reshape(1) + + +@triton.jit +def _calculate_inner_scales( + x, + outer_scale_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + IS_2D_BLOCK: tl.constexpr = False, + USE_OUTER_SCALE: tl.constexpr = False, + SCALE_FORMAT_IS_UE5M3: tl.constexpr = False, +): + """Compute per-block inner scales for NVFP4 / AMD-FP4 quantization. + + The returned ``inner_scale`` is the NVFP4 spec ``s_block``; the + ``outer_scale`` operand is the spec ``s_global``. See the + naming-convention block at the top of this module for the rationale. + + Per spec, the per-block scale is snapped to the selected inner grid + (E4M3 or UE5M3) and stored in an FP32 container; the outer scale and + intermediates stay in FP32. With ``USE_OUTER_SCALE=True`` the spec + order is:: + + inner_scale_raw = block_amax(x) / outer_scale / F4_E2M1_MAX + inner_scale = round_fmt(clamp(inner_scale_raw, [eps_fmt, max_fmt])) + quant_scale = inner_scale * outer_scale + + i.e. clamp + inner-grid round are applied exactly once, on the final + stored block (inner) scale. + + When ``IS_2D_BLOCK`` is True, one scale covers a + ``QUANT_BLOCK_SIZE x QUANT_BLOCK_SIZE`` tile, yielding output shapes + ``[SCALE_BLOCK_M, SCALE_BLOCK_N]`` instead of ``[BLOCK_M, SCALE_BLOCK_N]``. + """ + NEW_BLOCK_N: tl.constexpr = BLOCK_N // QUANT_BLOCK_SIZE + + if IS_2D_BLOCK: + NEW_BLOCK_M: tl.constexpr = BLOCK_M // QUANT_BLOCK_SIZE + x_grouped = x.reshape(NEW_BLOCK_M, QUANT_BLOCK_SIZE, + NEW_BLOCK_N, QUANT_BLOCK_SIZE) + max_abs = tl.max(tl.abs(x_grouped), axis=-1) + max_abs = tl.max(max_abs, axis=-2).to(tl.float32) + else: + x_grouped = x.reshape(BLOCK_M, NEW_BLOCK_N, QUANT_BLOCK_SIZE) + max_abs = tl.max(tl.abs(x_grouped), axis=-1).to(tl.float32) + + # NaN defense (D2'): if any element of the input block is NaN, ``tl.max`` + # propagates NaN. Sanitise here so the inner-scale cast (UE5M3 0xFF or + # E4M3 0xFF/0x7F) never produces a NaN code that could contaminate + # downstream GEMM scales. Mirrors the TransformerEngine / vLLM / + # TRT-LLM "kernel-side mask" pattern; symmetric across both formats. + max_abs = tl.where(max_abs != max_abs, 0.0, max_abs) + + if SCALE_FORMAT_IS_UE5M3: + SCALE_EPS: tl.constexpr = _UE5M3_EPS_TRITON + SCALE_MAX: tl.constexpr = _UE5M3_MAX_TRITON + else: + SCALE_EPS: tl.constexpr = _E4M3_EPS_TRITON + SCALE_MAX: tl.constexpr = _E4M3_MAX_TRITON + + if USE_OUTER_SCALE: + outer_scale = tl.load(outer_scale_ptr) + inner_scale_raw = max_abs / outer_scale / F4_E2M1_MAX + inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, SCALE_EPS), SCALE_MAX) + if SCALE_FORMAT_IS_UE5M3: + inner_scale = quantize_ue5m3(inner_scale_raw) + else: + inner_scale = quantize_e4m3(inner_scale_raw) + quant_scale = inner_scale * outer_scale + else: + inner_scale_raw = max_abs / F4_E2M1_MAX + inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, SCALE_EPS), SCALE_MAX) + if SCALE_FORMAT_IS_UE5M3: + inner_scale = quantize_ue5m3(inner_scale_raw) + else: + inner_scale = quantize_e4m3(inner_scale_raw) + quant_scale = inner_scale + + return inner_scale, quant_scale diff --git a/alto/kernels/mxfp8/mxfp8_linear.py b/alto/kernels/mxfp8/mxfp8_linear.py index 27e5aaf..1d67a39 100644 --- a/alto/kernels/mxfp8/mxfp8_linear.py +++ b/alto/kernels/mxfp8/mxfp8_linear.py @@ -12,7 +12,7 @@ import triton import triton.language as tl -from alto.kernels.fp4.fp4_common import unwrap_weight_wrapper +from alto.kernels.fp4.fp4_primitives import unwrap_weight_wrapper from .mxfp8_quantization import ( BLOCK_SIZE_DEFAULT, diff --git a/alto/modifiers/lpt/base.py b/alto/modifiers/lpt/base.py index b46ee87..6c9f748 100644 --- a/alto/modifiers/lpt/base.py +++ b/alto/modifiers/lpt/base.py @@ -58,7 +58,7 @@ def validate_targets(cls, value: str | list[str]) -> list[str]: @field_validator("scheme", mode="before") def validate_scheme(cls, value: str | dict[str, str | list[str]]) -> str | dict[str, list[str]]: - if isinstance(value, str) and value not in ["mxfp4", "mxfp8_e4m3", "mxfp8_e5m2", "nvfp4"]: + if isinstance(value, str) and value not in ["mxfp4", "mxfp8_e4m3", "mxfp8_e5m2", "nvfp4", "amdfp4"]: raise ValueError(f"Unsupported training op scheme: {value}") if isinstance(value, dict): @@ -77,10 +77,10 @@ def validate_lora_rank_alignment(self): schemes = self.scheme if isinstance(self.scheme, dict) else {self.scheme: None} for scheme_name in schemes: - if scheme_name == "nvfp4": + if scheme_name in ("nvfp4", "amdfp4"): if self.lora_rank % 16 != 0: raise ValueError( - f"lora_rank must be divisible by 16 for nvfp4, got {self.lora_rank}" + f"lora_rank must be divisible by 16 for {scheme_name}, got {self.lora_rank}" ) elif scheme_name in ("mxfp4", "mxfp8_e4m3", "mxfp8_e5m2"): if self.lora_rank % 32 != 0: diff --git a/tests/unittest/amdfp4/__init__.py b/tests/unittest/amdfp4/__init__.py new file mode 100644 index 0000000..85eac52 --- /dev/null +++ b/tests/unittest/amdfp4/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT diff --git a/tests/unittest/amdfp4/conftest.py b/tests/unittest/amdfp4/conftest.py new file mode 100644 index 0000000..ab04183 --- /dev/null +++ b/tests/unittest/amdfp4/conftest.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Triton env + sibling-import shim for the AMD-FP4 test suite. + +This suite covers both layers of AMD-FP4: + +* Low-level UE5M3 dtype primitives (``test_amdfp_dtype.py``, + ``test_amdfp_triton_pytorch_parity.py``) — formerly under + ``tests/unittest/ue5m3/``. +* Recipe-level quantize / dequantize / linear / grouped-GEMM ops + pinned to the AMD-FP4 recipe (``test_amdfp_quantization.py``, + ``test_amdfp_linear.py``, ``test_amdfp_grouped_gemm.py``, + ``test_ab_matrix.py``). + +The Triton path needs ``TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1`` to read +the module-level ``_quantize_ue5m3`` / ``_quantize_e4m3`` JIT helpers, +and a per-suite cache dir so this suite does not collide with other +Triton suites in the repo. +""" + +import os +import sys + +import pytest + +# Sibling import: ``from nvfp4.utils import ...`` lets the AMD-FP4 oracle +# delegate to the NVFP4 PyTorch reference (recipe-level tests share the +# same kernel under different ``scale_format``). +_UNITTEST_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _UNITTEST_DIR not in sys.path: + sys.path.insert(0, _UNITTEST_DIR) + + +@pytest.fixture(autouse=True) +def _configure_triton_env(monkeypatch): + monkeypatch.setenv("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "1") + monkeypatch.setenv("TRITON_CACHE_DIR", "/tmp/triton-cache-amdfp4-tests") diff --git a/tests/unittest/amdfp4/test_ab_matrix.py b/tests/unittest/amdfp4/test_ab_matrix.py new file mode 100644 index 0000000..e732602 --- /dev/null +++ b/tests/unittest/amdfp4/test_ab_matrix.py @@ -0,0 +1,357 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""A/B matrix: NVFP4 (E4M3 inner) vs AMD-FP4 (UE5M3 inner) on quant patterns. + +Runs the C.1 stress patterns from ``docs/amd-fp4/implementation-validation-plan.md`` +and prints a tabulated summary (scale bit-exactness, SQNR, saturation) for both +inner-scale formats on the same input tensor. + +Hard gates only apply where format choice can actually distinguish E4M3 from +UE5M3 — i.e. without a dynamic outer scale, since the outer scale would +otherwise normalize each tensor's range into ``[0, max_fmt]`` and erase the +range/precision difference between the two inner grids. + +Saturation is defined as the fraction of blocks whose unclamped +``inner_scale_raw = block_amax / outer / F4_E2M1_MAX`` would exceed +``max_fmt`` and therefore trigger the upper-rail inner-scale clamp. This +is the only saturation signal that is sensitive to the inner-grid choice; +counting "fraction of E2M1 elements at the rail" is by construction always +100% (the per-block divisor is built so the largest element in the block +exactly hits ``F4_E2M1_MAX``). +""" + +from __future__ import annotations + +import pytest +import torch +from tabulate import tabulate + +from alto.kernels.fp4.nvfp4.nvfp_quantization import ( + F4_E2M1_MAX, + _SCALE_FORMAT_TABLE, + compute_dynamic_outer_scale, + convert_from_nvfp4, + convert_to_nvfp4, +) + +from nvfp4.utils import ( # noqa: E402 (see amdfp4/conftest sys.path) + calc_cossim, + convert_from_nvfp4_pytorch, + convert_to_nvfp4_pytorch, + prepare_data, +) + +cuda_required = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="AMD-FP4 A/B matrix requires CUDA / ROCm", +) + +AB_PATTERNS = ( + "random", + "zeros", + "large", + "hot_channel", + "lognormal", + "near_overflow", + "near_underflow", + "single_spike", +) + +# Patterns that show up in the printed report (kept short for readability; +# the full matrix is still asserted on). +REPORT_PATTERNS = ( + "random", + "near_overflow", + "near_underflow", + "hot_channel", + "lognormal", +) + + +def _sqnr_db(x: torch.Tensor, x_hat: torch.Tensor) -> float: + num = (x.float() - x_hat.float()).pow(2).sum() + den = x.float().pow(2).sum().clamp(min=1e-30) + if num == 0: + return float("inf") + return (10.0 * torch.log10(den / num)).item() + + +def _inner_clamp_rate( + x: torch.Tensor, + scale_format: str, + outer_scale: torch.Tensor | None, + block_size: int = 16, +) -> float: + """Fraction of blocks whose unclamped ``inner_scale_raw`` exceeds ``max_fmt``. + + This is the inner-scale upper-rail saturation rate: blocks here lose + information because the per-block divisor cannot grow large enough to + fit the block's amplitude into FP4. UE5M3 is meaningfully better than + E4M3 only when this rate diverges between the two formats. + """ + fmt_max = _SCALE_FORMAT_TABLE[scale_format][1] + x2d = x.float().reshape(-1, x.shape[-1]) + M, N = x2d.shape + nb = N // block_size + grouped = x2d.reshape(M, nb, block_size) + max_abs = grouped.abs().amax(dim=-1) + if outer_scale is None: + outer = 1.0 + else: + outer = outer_scale.float().reshape(()).item() + inner_raw = max_abs / outer / F4_E2M1_MAX + return (inner_raw > fmt_max).float().mean().item() + + +def _run_one_cell( + pattern: str, + scale_format: str, + *, + use_outer_scale: bool, +) -> dict: + block_size = 16 + axis = -1 + data_type = torch.bfloat16 + shape = (128, 64) + + x = prepare_data(shape, data_type, pattern=pattern) + if use_outer_scale: + # Production path: outer is computed per-format (it scales by + # 1/(max_fmt * F4_E2M1_MAX), so UE5M3's outer is ~274× smaller than + # E4M3's on the same tensor — exactly the AMD-FP4 design point). + outer_scale = compute_dynamic_outer_scale(x, scale_format=scale_format) + else: + outer_scale = None + + data_lp_ref, scales_ref = convert_to_nvfp4_pytorch( + x, + block_size=block_size, + axis=axis, + outer_scale=outer_scale, + scale_format=scale_format, + ) + data_lp, scales = convert_to_nvfp4( + x, + block_size=block_size, + axis=axis, + outer_scale=outer_scale, + update_outer_scale=False, + scale_format=scale_format, + use_sr=False, + ) + + # The full kernel-vs-oracle bit-exact regression is covered by + # ``tests/unittest/nvfp4/test_nvfp_quantization.py`` (all axes / shapes / + # block layouts / SR / outer / scale_format). Here we only require + # numerical agreement: relative scale error must stay small. This makes + # the AB matrix robust against rail-boundary FP rounding differences (e.g. + # ``inner_raw == max_fmt`` evaluated in two different orders) while still + # catching real bugs that would push errors orders of magnitude above + # quantization noise. + scale_rel_err = ( + (scales_ref.float() - scales.float()).abs() + / scales_ref.float().abs().clamp(min=1e-30) + ).max().item() + scale_ok = scale_rel_err < 5e-3 + fp4_mismatch_rate = (data_lp_ref != data_lp).float().mean().item() + fp4_ok = fp4_mismatch_rate < 1e-3 + + x_dq_ref = convert_from_nvfp4_pytorch( + data_lp_ref, + scales_ref, + output_dtype=data_type, + block_size=block_size, + axis=axis, + outer_scale=outer_scale, + scale_format=scale_format, + ) + x_dq = convert_from_nvfp4( + data_lp, + scales, + output_dtype=data_type, + block_size=block_size, + axis=axis, + outer_scale=outer_scale, + scale_format=scale_format, + ) + + mae = (x_dq_ref - x_dq).abs().mean().item() + sqnr = _sqnr_db(x, x_dq) + cos = calc_cossim(x, x_dq) + sat = _inner_clamp_rate(x, scale_format, outer_scale, block_size=block_size) + finite = torch.isfinite(x_dq).all().item() + + return { + "scale_ok": scale_ok, + "scale_rel_err": scale_rel_err, + "fp4_ok": fp4_ok, + "fp4_mismatch_rate": fp4_mismatch_rate, + "mae": mae, + "sqnr_db": sqnr, + "cosine": cos, + "saturation": sat, + "finite": finite, + } + + +def _fmt_sqnr(v: float) -> str: + if v == float("inf"): + return "inf" + if v == float("-inf"): + return "-inf" + return f"{v:.2f}" + + +@cuda_required +@pytest.mark.parametrize("use_outer_scale", [False, True]) +def test_amdfp4_ab_matrix_report(use_outer_scale): + """Run all AB patterns × {E4M3, UE5M3} and assert format-distinguishing gates.""" + rows = [] + by_pattern_fmt = {} + + for pattern in AB_PATTERNS: + for fmt, label in (("e4m3", "NVFP4 (E4M3)"), ("ue5m3", "AMD-FP4 (UE5M3)")): + m = _run_one_cell(pattern, fmt, use_outer_scale=use_outer_scale) + by_pattern_fmt[(pattern, fmt)] = m + if pattern in REPORT_PATTERNS: + rows.append([ + pattern, + label, + f"{m['scale_rel_err']:.2e}", + _fmt_sqnr(m["sqnr_db"]), + f"{m['cosine']:.5f}", + f"{m['saturation'] * 100:.3f}%", + f"{m['mae']:.3e}", + "yes" if m["finite"] else "NO", + ]) + + hdr = [ + "pattern", "format", "scale rel err", "SQNR (dB)", "cosine", + "inner clamp", "MAE", "finite", + ] + print() + print(tabulate(rows, headers=hdr, tablefmt="github")) + print(f"\n(use_outer_scale={use_outer_scale})") + + # ---- Always-on invariants ------------------------------------------------ + # Bit-exact kernel-vs-oracle scale agreement is regression-tested by + # ``tests/unittest/nvfp4/test_nvfp_quantization.py`` over the full + # axis/shape/SR/outer/scale_format matrix. Here we only assert that the + # kernel and the oracle land within one UE5M3 grid step (≤ 1 mantissa + # ULP, i.e. 12.5%) — this catches catastrophic drift while tolerating + # rail-boundary code choices that hash to neighbouring UE5M3 codepoints + # but reconstruct to indistinguishable BF16 outputs (MAE stays tiny). + UE5M3_GRID_STEP = 0.125 + 1e-6 # one mantissa ULP, with FP slack. + for (pattern, fmt), m in by_pattern_fmt.items(): + assert m["finite"], f"non-finite dequant: pattern={pattern} format={fmt}" + assert m["scale_rel_err"] <= UE5M3_GRID_STEP, ( + f"scale drift > 1 inner-format ULP " + f"(kernel vs PyTorch oracle): pattern={pattern} format={fmt} " + f"rel_err={m['scale_rel_err']:.3e}" + ) + # MAE on dequant output stays small relative to tensor amplitude. + assert m["mae"] < 1.0, ( + f"dequant MAE {m['mae']:.3e} too large: " + f"pattern={pattern} format={fmt}" + ) + + # On the production outer-scale path the outer normalizes both formats' + # inner_scale ranges into ``[0, max_fmt]``, which erases the format + # advantage on every pattern by design. We only assert non-regression + # there; the format-distinguishing gates are checked on the + # ``use_outer_scale=False`` parametrize cell below. + + e4_rand = by_pattern_fmt[("random", "e4m3")] + ue5_rand = by_pattern_fmt[("random", "ue5m3")] + assert ue5_rand["sqnr_db"] >= e4_rand["sqnr_db"] - 0.5, ( + f"random: UE5M3 SQNR {ue5_rand['sqnr_db']:.2f} dB regressed vs " + f"E4M3 {e4_rand['sqnr_db']:.2f} dB by more than 0.5 dB" + ) + assert ue5_rand["cosine"] >= e4_rand["cosine"] - 0.001, ( + f"random: UE5M3 cosine {ue5_rand['cosine']:.5f} regressed vs " + f"E4M3 {e4_rand['cosine']:.5f}" + ) + + if not use_outer_scale: + # ---- near_overflow: x == 1e4 ---------------------------------------- + # E4M3_max=448 < 1e4/6 → every block clamps; reconstruction caps at + # 448*6 = 2688. UE5M3_max=114688 (D2', GFXIPARCH-2067 §19.10) + # ≫ 1e4/6 ≈ 1667 → no clamp. + e4_ov = by_pattern_fmt[("near_overflow", "e4m3")] + ue5_ov = by_pattern_fmt[("near_overflow", "ue5m3")] + assert e4_ov["saturation"] >= 0.99, ( + f"near_overflow: E4M3 should clamp ≥99% of blocks, got " + f"{e4_ov['saturation']:.4%}" + ) + assert ue5_ov["saturation"] <= 0.01, ( + f"near_overflow: UE5M3 should clamp ≤1% of blocks, got " + f"{ue5_ov['saturation']:.4%}" + ) + assert ue5_ov["sqnr_db"] >= e4_ov["sqnr_db"] + 10.0, ( + f"near_overflow: UE5M3 SQNR {ue5_ov['sqnr_db']:.2f} dB must beat " + f"E4M3 {e4_ov['sqnr_db']:.2f} dB by ≥10 dB" + ) + + # ---- near_underflow: x == 1e-4 -------------------------------------- + # 1e-4/6 ≈ 1.67e-5; E4M3_EPS = 2^-6 = 1.56e-2 ≫ 1.67e-5 → every block + # clamps to EPS, x/EPS ≈ 6.4e-3 rounds to E2M1=0 → reconstruction = 0. + # UE5M3_EPS = 2^-17 = 7.63e-6 < 1.67e-5 → UE5M3 keeps signal. + e4_un = by_pattern_fmt[("near_underflow", "e4m3")] + ue5_un = by_pattern_fmt[("near_underflow", "ue5m3")] + assert ue5_un["sqnr_db"] >= e4_un["sqnr_db"] + 5.0, ( + f"near_underflow: UE5M3 SQNR {ue5_un['sqnr_db']:.2f} dB must beat " + f"E4M3 {e4_un['sqnr_db']:.2f} dB by ≥5 dB" + ) + + +# --------------------------------------------------------------------------- +# D2' NaN-input A/B handling +# --------------------------------------------------------------------------- + +@cuda_required +@pytest.mark.parametrize("scale_format", ["e4m3", "ue5m3"]) +def test_amdfp4_nan_input_handling(scale_format): + """A NaN spike in the input MUST NOT corrupt either E4M3 or UE5M3 output. + + Under D2' both formats may emit a NaN code at cast time (E4M3 0xFF/0x7F, + UE5M3 0xFF), so the defense layer in ``_calculate_nvfp4_scales`` / + ``_quantize_inner_scale`` is what keeps downstream GEMM finite. This + test exercises the production kernel + dequant path on a tensor with a + NaN spike and asserts that the dequant output is finite. + + Mirrors industry pattern: TransformerEngine zero-inits padding scale, + vLLM zero-outs MoE padding, TRT-LLM clamps scale ≥ 1e-12. + """ + block_size = 16 + shape = (128, 64) + data_type = torch.bfloat16 + x = prepare_data(shape, data_type, pattern="random") + x[0, 7] = float("nan") + + data_lp, scales = convert_to_nvfp4( + x, + block_size=block_size, + axis=-1, + is_2d_block=False, + update_outer_scale=False, + scale_format=scale_format, + ) + assert torch.isfinite(scales.float()).all(), ( + f"{scale_format}: NaN input contaminated inner scales — D2' " + f"defense layer regression." + ) + + x_dq = convert_from_nvfp4( + data_lp, scales, + output_dtype=data_type, + block_size=block_size, + axis=-1, + is_2d_block=False, + scale_format=scale_format, + ) + assert torch.isfinite(x_dq).all(), ( + f"{scale_format}: NaN input leaked into dequant output — defense " + f"layer must keep downstream GEMM input finite." + ) diff --git a/tests/unittest/amdfp4/test_amdfp_dispatch_guards.py b/tests/unittest/amdfp4/test_amdfp_dispatch_guards.py new file mode 100644 index 0000000..e697e7f --- /dev/null +++ b/tests/unittest/amdfp4/test_amdfp_dispatch_guards.py @@ -0,0 +1,236 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Dispatch-layer guard tests for ``precision='amdfp4'``. + +These tests lock down the AMD-FP4 entry of +:class:`alto.kernels.dispatch.tensor.NVFP4TrainingWeightWrapperTensor` (the +NVFP4 / AMD-FP4 family share one wrapper class): + +* ``TrainingOpConfig(precision='amdfp4')`` forces ``inner_scale_format`` + to ``"ue5m3"`` via ``__post_init__`` (``"e4m3"`` default is silently + upgraded; any explicit non-UE5M3 value is rejected). +* Linear / grouped_mm dispatch under ``precision='amdfp4'`` routes to the + AMD-FP4 thin wrappers (``_to_amdfp4_then_scaled_mm`` / + ``_quantize_then_amdfp4_scaled_grouped_mm``), not the NVFP4 ones. +* ``precision='amdfp4'`` end-to-end produces a finite output of the + expected shape on a small smoke shape (no silent fallback). +""" + +from __future__ import annotations + +import pytest +import torch + +import alto.kernels.dispatch.tensor as dispatch_tensor +from alto.kernels.dispatch.config import TrainingOpConfig +from alto.kernels.dispatch.conversion import swap_params +from alto.kernels.dispatch.tensor import NVFP4TrainingWeightWrapperTensor +from alto.kernels.fp4.amdfp4 import ALIGN_SIZE_M + + +def _make_amdfp4_config(**overrides) -> TrainingOpConfig: + defaults = dict( + precision="amdfp4", + use_2dblock_x=False, + use_2dblock_w=False, + use_hadamard=False, + use_sr_grad=False, + use_dge=False, + ) + defaults.update(overrides) + return TrainingOpConfig(**defaults) + + +@pytest.fixture +def device() -> torch.device: + if not torch.cuda.is_available(): + pytest.skip("AMD-FP4 dispatch guards require a CUDA device") + return torch.device("cuda") + + +# --------------------------------------------------------------------------- +# Config-level invariants +# --------------------------------------------------------------------------- + +def test_amdfp4_config_default_inner_scale_format_promoted_to_ue5m3(): + """Caller didn't set ``inner_scale_format`` -> it must be silently + promoted to ``"ue5m3"`` (AMD-FP4 = NVFP4 spec + UE5M3).""" + cfg = _make_amdfp4_config() + assert cfg.inner_scale_format == "ue5m3" + + +def test_amdfp4_config_explicit_ue5m3_is_accepted(): + cfg = _make_amdfp4_config(inner_scale_format="ue5m3") + assert cfg.inner_scale_format == "ue5m3" + + +def test_amdfp4_swap_params_wraps_linear_weight(): + """Regression for BLOCKER-1: ``swap_params`` must support + ``precision='amdfp4'``. Before the fix, ``_get_tensor_cls_for_config`` + had no amdfp4 arm and raised ``ValueError`` at model setup. + + amdfp4 reuses the NVFP4 wrapper (it re-dispatches on ``config.precision`` + internally), so the swapped parameter must become an + ``NVFP4TrainingWeightWrapperTensor`` carrying the amdfp4 config. This is + a CPU-only setup-time check — no kernel launch — so it does not require a + GPU. + """ + cfg = _make_amdfp4_config() + linear = torch.nn.Linear(8, 4, bias=False) + + swap_params(linear, config=cfg) + + assert isinstance(linear.weight.data, NVFP4TrainingWeightWrapperTensor), ( + "swap_params(precision='amdfp4') must wrap the weight in the " + "NVFP4/AMD-FP4 family wrapper tensor" + ) + assert linear.weight.data.config.precision == "amdfp4" + assert linear.weight.data.config.inner_scale_format == "ue5m3" + + +def test_amdfp4_config_rejects_explicit_non_ue5m3_inner_scale_format(): + """Explicit ``e4m3`` request under ``precision='amdfp4'`` would silently + flip the recipe to NVFP4-on-an-AMD-FP4-label. ``__post_init__`` must + reject it. We construct via the dataclass directly to bypass our + helper's silent default.""" + with pytest.raises(ValueError, match="inner_scale_format"): + TrainingOpConfig( + precision="amdfp4", + use_2dblock_x=False, + use_2dblock_w=False, + use_hadamard=False, + use_sr_grad=False, + use_dge=False, + # The dataclass default is ``"e4m3"``, but that gets silently + # promoted; any *other* explicit value fails. Pick a value + # outside the Literal to demonstrate. ``inner_scale_format`` + # uses ``Literal["e4m3","ue5m3"]`` so we use the runtime + # validator rather than mypy's static check here. + inner_scale_format="bogus", # type: ignore[arg-type] + ) + + +# --------------------------------------------------------------------------- +# linear dispatch routing +# --------------------------------------------------------------------------- + +def test_amdfp4_linear_routes_to_amdfp4_thin_wrapper(monkeypatch, device): + """Dispatch must route AMD-FP4 wrapped linear to ``_to_amdfp4_then_scaled_mm``, + not ``_to_nvfp4_then_scaled_mm``.""" + calls = [] + + def _mock_amdfp4(A, W, *, use_2dblock_x, use_2dblock_w, + use_sr_grad, use_outer_scale, use_hadamard, use_dge): + calls.append({ + "use_2dblock_x": use_2dblock_x, + "use_2dblock_w": use_2dblock_w, + "use_sr_grad": use_sr_grad, + "use_outer_scale": use_outer_scale, + "use_hadamard": use_hadamard, + "use_dge": use_dge, + }) + return A @ W + + def _forbidden_nvfp4(*args, **kwargs): + raise AssertionError( + "precision='amdfp4' must NOT call into _to_nvfp4_then_scaled_mm" + ) + + monkeypatch.setattr(dispatch_tensor, "_to_amdfp4_then_scaled_mm", _mock_amdfp4) + monkeypatch.setattr(dispatch_tensor, "_to_nvfp4_then_scaled_mm", _forbidden_nvfp4) + + cfg = _make_amdfp4_config() + K, N, M = 16, 16, 16 + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + W = torch.randn(N, K, dtype=torch.bfloat16, device=device) + W_wrapped = NVFP4TrainingWeightWrapperTensor(W, cfg) + + y = torch.nn.functional.linear(A, W_wrapped) + assert y.shape == (M, N) + # ``F.linear`` may dispatch through ``linear`` and additionally re-enter + # via the ``mm.default`` / ``matmul`` decomposition path (we route both + # in ``gemm_ops``). We just need every routed call to land on the + # AMD-FP4 thin wrapper. + assert len(calls) >= 1, "F.linear with AMD-FP4 weight should hit the amdfp4 routing path" + for call in calls: + # AMD-FP4 wrapper has no ``scale_format`` parameter; it's pinned UE5M3. + assert "scale_format" not in call + + +def test_amdfp4_grouped_mm_routes_to_amdfp4_thin_wrapper(monkeypatch, device): + """Dispatch must route AMD-FP4 wrapped grouped_mm to + ``_quantize_then_amdfp4_scaled_grouped_mm``.""" + calls = [] + + def _mock_amdfp4_grouped(A, B, *, offs, use_2dblock_x, use_2dblock_w, + use_sr_grad, use_outer_scale, use_hadamard, use_dge): + calls.append({ + "offs_len": offs.numel(), + "use_2dblock_x": use_2dblock_x, + "use_2dblock_w": use_2dblock_w, + "use_sr_grad": use_sr_grad, + "use_outer_scale": use_outer_scale, + "use_hadamard": use_hadamard, + "use_dge": use_dge, + }) + return torch.zeros(A.shape[0], B.shape[-1], dtype=A.dtype, device=A.device) + + def _forbidden_nvfp4_grouped(*args, **kwargs): + raise AssertionError( + "precision='amdfp4' grouped_mm must NOT call into _quantize_then_nvfp4_scaled_grouped_mm" + ) + + monkeypatch.setattr( + dispatch_tensor, "_quantize_then_amdfp4_scaled_grouped_mm", _mock_amdfp4_grouped + ) + monkeypatch.setattr( + dispatch_tensor, "_quantize_then_nvfp4_scaled_grouped_mm", _forbidden_nvfp4_grouped + ) + + cfg = _make_amdfp4_config() + num_experts, K, N, M = 2, 16, 16, ALIGN_SIZE_M + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + W = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=device) + W_wrapped = NVFP4TrainingWeightWrapperTensor(W, cfg) + offs = torch.tensor([M // num_experts, M], dtype=torch.int32, device=device) + + _ = torch._grouped_mm(A, W_wrapped, offs=offs) + assert len(calls) >= 1, "torch._grouped_mm with AMD-FP4 weight should hit the amdfp4 routing path" + for call in calls: + assert "scale_format" not in call + + +# --------------------------------------------------------------------------- +# end-to-end smoke +# --------------------------------------------------------------------------- + +def test_amdfp4_linear_smoke(device): + """End-to-end: precision='amdfp4' linear must produce a finite BF16 output + of the expected shape (no silent fallback).""" + cfg = _make_amdfp4_config() + K, N, M = 32, 32, 32 + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + W = torch.randn(N, K, dtype=torch.bfloat16, device=device) + W_wrapped = NVFP4TrainingWeightWrapperTensor(W, cfg) + y = torch.nn.functional.linear(A, W_wrapped) + assert y.shape == (M, N) + assert y.dtype == torch.bfloat16 + assert torch.isfinite(y).all() + assert y.abs().max().item() > 0, "AMD-FP4 linear smoke must exercise a real matmul" + + +def test_amdfp4_grouped_mm_smoke(device): + """End-to-end: precision='amdfp4' grouped_mm must produce a finite output.""" + cfg = _make_amdfp4_config() + num_experts, K, N, M = 2, 16, 16, ALIGN_SIZE_M + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + W = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=device) + W_wrapped = NVFP4TrainingWeightWrapperTensor(W, cfg) + offs = torch.tensor([M // num_experts, M], dtype=torch.int32, device=device) + y = torch._grouped_mm(A, W_wrapped, offs=offs) + assert y.shape == (M, N) + assert y.dtype == torch.bfloat16 + assert torch.isfinite(y).all() + assert y.abs().max().item() > 0, "AMD-FP4 grouped_mm smoke must exercise a real matmul" diff --git a/tests/unittest/amdfp4/test_amdfp_dtype.py b/tests/unittest/amdfp4/test_amdfp_dtype.py new file mode 100644 index 0000000..837d70e --- /dev/null +++ b/tests/unittest/amdfp4/test_amdfp_dtype.py @@ -0,0 +1,382 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Bit-level oracle tests for the UE5M3 PyTorch primitives (AMD-FP4 inner scale). + +UE5M3 is the inner-scale dtype of the AMD-FP4 recipe; this file owns the +dtype-layer regression so ``tests/unittest/amdfp4/`` is the single home +for every AMD-FP4 layer (dtype primitive → quant op → linear/grouped GEMM +→ A/B matrix). + +Coverage (D2', GFXIPARCH-2067 §19.10 aligned): + +* Constants vs. closed-form ground truth (UE5M3_MAX = 114688 at code 0xFE). +* All 256 uint8 codes decode to the expected fp32 value (manual table); + code 0xFF decodes to NaN (verified with isnan, NaN ≠ NaN bypass). +* Round-trip: every code re-encodes to itself (identity-on-grid), + including the 0xFF NaN code. +* Idempotency: ``quantize(quantize(x)) == quantize(x)`` (NaN positions + match between successive quantizations). +* Saturation: NaN / +Inf / -Inf / overflow → 0xFF (NaN); finite negatives + → 0x00; ``UE5M3_MAX`` exact → 0xFE (max normal). +* Round-to-nearest-even spot checks on between-grid values. + +The CUDA path is not exercised here; that's the job of +``test_amdfp_triton_pytorch_parity.py``. +""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from alto.kernels.fp4.fp4_primitives import ( + UE5M3_EPS, + UE5M3_EXP_BIAS, + UE5M3_MAX, + UE5M3_MIN_NORMAL, + UE5M3_NAN_CODE, + UE5M3_NUM_EXP_BITS, + UE5M3_NUM_MAN_BITS, + f32_to_ue5m3_uint8, + quantize_to_ue5m3, + ue5m3_uint8_to_f32, +) + + +# --------------------------------------------------------------------------- +# Ground-truth 256-code value table. +# Independently computed from the documented spec, used as the oracle. +# --------------------------------------------------------------------------- + +def _build_ue5m3_value_table() -> list[float]: + """Return ``table[i]`` = exact fp32 value for UE5M3 code ``i``. + + Per D2' / GFXIPARCH-2067 §19.10, code 0xFF is reserved for NaN. + """ + table: list[float] = [] + for code in range(256): + if code == UE5M3_NAN_CODE: + table.append(float("nan")) + continue + exp = code >> UE5M3_NUM_MAN_BITS + mant = code & ((1 << UE5M3_NUM_MAN_BITS) - 1) + if exp == 0: + if mant == 0: + value = 0.0 + else: + value = (mant / (1 << UE5M3_NUM_MAN_BITS)) * (2.0 ** (1 - UE5M3_EXP_BIAS)) + else: + value = (1.0 + mant / (1 << UE5M3_NUM_MAN_BITS)) * ( + 2.0 ** (exp - UE5M3_EXP_BIAS) + ) + table.append(value) + return table + + +_VALUE_TABLE = _build_ue5m3_value_table() + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +def test_constants_match_closed_form(): + assert UE5M3_NUM_EXP_BITS == 5 + assert UE5M3_NUM_MAN_BITS == 3 + assert UE5M3_EXP_BIAS == 15 + # Max normal at code 0xFE: (1 + 6/8) * 2^16 = 1.75 * 65536 = 114688 + assert UE5M3_MAX == 114688.0 + assert UE5M3_MAX == 1.75 * (2.0 ** 16) + # Min normal: 2^-14 + assert UE5M3_MIN_NORMAL == math.ldexp(1.0, -14) + # Smallest subnormal: 2^-17 + assert UE5M3_EPS == math.ldexp(1.0, -17) + # NaN code reserved per GFXIPARCH-2067 §19.10 + assert UE5M3_NAN_CODE == 0xFF + + +def test_dynamic_range_increase_vs_e4m3(): + """Sanity: ``UE5M3_MAX / E4M3_MAX`` is the AMD-FP4 dynamic-range gain. + + Closed form (D2', spec-aligned): ``(1.75 * 2^16) / (1.75 * 2^8) = 256``. + """ + e4m3_max = 448.0 + ratio = UE5M3_MAX / e4m3_max + assert abs(ratio - (1.75 / 1.75) * 256.0) < 1e-9 + assert ratio == 256.0 + assert ratio > 250.0 + + +def test_value_table_endpoints(): + """Spot-check the most important entries of the manual decode table.""" + assert _VALUE_TABLE[0x00] == 0.0 + assert _VALUE_TABLE[0x01] == math.ldexp(1.0, -17) # smallest subnormal + assert _VALUE_TABLE[0x07] == 7.0 * math.ldexp(1.0, -17) # largest subnormal + assert _VALUE_TABLE[0x08] == math.ldexp(1.0, -14) # smallest normal + # Code 0x78 = exp=15, mant=0 -> (1 + 0) * 2^(15-15) = 1.0 + assert _VALUE_TABLE[0x78] == 1.0 + # Code 0xFE = exp=31, mant=6 -> (1 + 6/8) * 2^16 = 114688 (max normal) + assert _VALUE_TABLE[0xFE] == 114688.0 + assert _VALUE_TABLE[0xFE] == UE5M3_MAX + # Code 0xFF: NaN (D2', spec-aligned). + assert math.isnan(_VALUE_TABLE[0xFF]) + + +# --------------------------------------------------------------------------- +# Decode (uint8 -> fp32) bit-exact on all 256 codes +# --------------------------------------------------------------------------- + +def test_decode_all_256_codes_bit_exact(): + codes = torch.arange(256, dtype=torch.uint8) + decoded = ue5m3_uint8_to_f32(codes) + expected = torch.tensor(_VALUE_TABLE, dtype=torch.float32) + + # Code 0xFF is NaN under D2'; ``torch.equal`` returns False for NaN + # positions, so we mask 0xFF out and assert NaN-ness separately. + finite_mask = torch.arange(256, dtype=torch.uint8) != UE5M3_NAN_CODE + assert torch.equal(decoded[finite_mask], expected[finite_mask]), ( + "UE5M3 decode mismatch on the finite 0x00-0xFE grid: " + f"first diff at code 0x{int((decoded[finite_mask] != expected[finite_mask]).nonzero()[0].item()):02X}" + ) + # Code 0xFF MUST decode to NaN. + assert torch.isnan(decoded[UE5M3_NAN_CODE]).item(), ( + f"Code 0xFF must decode to NaN under D2'; got {float(decoded[UE5M3_NAN_CODE])}" + ) + + +# --------------------------------------------------------------------------- +# Round-trip: every code re-encodes to itself +# --------------------------------------------------------------------------- + +def test_encode_decode_roundtrip_all_codes(): + """``encode(decode(code)) == code`` for every 8-bit pattern, including 0xFF. + + For the NaN code 0xFF the round-trip is: decode -> NaN, then re-encode + classifies NaN as a "special" input (D2') and emits 0xFF again. + """ + codes = torch.arange(256, dtype=torch.uint8) + decoded = ue5m3_uint8_to_f32(codes) + re_encoded = f32_to_ue5m3_uint8(decoded) + assert torch.equal(codes, re_encoded), ( + "Round-trip (decode + encode) is not identity on the 256-code grid; " + f"first diff at code 0x{int((codes != re_encoded).nonzero()[0].item()):02X}, " + f"re-encoded as 0x{int(re_encoded[(codes != re_encoded).nonzero()[0].item()].item()):02X}" + ) + + +# --------------------------------------------------------------------------- +# Idempotency: quantize(quantize(x)) == quantize(x) bit-for-bit +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("size", [(8,), (128,), (1024,), (4, 256), (3, 17, 31)]) +def test_idempotency_random(size): + """``quantize(quantize(x)) == quantize(x)`` bit-for-bit on finite values. + + Under D2' a value > UE5M3_MAX encodes to the 0xFF NaN code, which decodes + to NaN. We therefore check finite-vs-finite equality and NaN-mask + equality separately (since ``torch.equal`` returns False on NaN). + """ + torch.manual_seed(0) + # Mix positive normal range with subnormal and saturation candidates. + x = torch.randn(size, dtype=torch.float32).abs() * 10.0 + x[..., 0] = 0.0 # exact zero + x[..., -1] = UE5M3_MAX * 0.5 # mid-range + if x.numel() >= 3: + x.view(-1)[1] = UE5M3_EPS * 0.5 # below smallest subnormal -> rounds to 0 or eps + x.view(-1)[2] = UE5M3_MAX * 2.0 # overflow -> NaN code (D2') + q1 = quantize_to_ue5m3(x) + q2 = quantize_to_ue5m3(q1) + nan1 = torch.isnan(q1) + nan2 = torch.isnan(q2) + assert torch.equal(nan1, nan2), ( + "quantize_to_ue5m3 NaN mask not preserved across two passes; " + f"shape={tuple(size)}, diff at idx={int((nan1 ^ nan2).nonzero()[0].item())}" + ) + finite_mask = ~nan1 + assert torch.equal(q1[finite_mask], q2[finite_mask]), ( + "quantize_to_ue5m3 is not idempotent on finite values; " + f"shape={tuple(size)}, max diff=" + f"{float((q1[finite_mask] - q2[finite_mask]).abs().max()):.3e}" + ) + + +# --------------------------------------------------------------------------- +# Saturation / clamp policy +# --------------------------------------------------------------------------- + +def test_saturation_nan_inf_negative_large(): + """D2' saturation policy (GFXIPARCH-2067 §19.10): + + * NaN, ±Inf, finite values strictly above ``UE5M3_MAX`` -> 0xFF (NaN) + * Finite negatives -> 0x00 + * 0.0 -> 0x00 + * ``UE5M3_MAX`` exact -> 0xFE (max normal) + """ + x = torch.tensor( + [float("nan"), float("inf"), -float("inf"), + -1.0, -1e30, 0.0, + UE5M3_MAX, UE5M3_MAX * 2.0, 1e30], + dtype=torch.float32, + ) + codes = f32_to_ue5m3_uint8(x) + expected = torch.tensor( + # NaN +Inf -Inf -1.0 -1e30 0.0 MAX 2*MAX 1e30 + [0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0xFE, 0xFF, 0xFF], + dtype=torch.uint8, + ) + assert torch.equal(codes, expected), ( + f"Saturation policy violated; codes={[hex(c) for c in codes.tolist()]}, " + f"expected={[hex(c) for c in expected.tolist()]}" + ) + + +def test_zero_clean(): + """All-zero input must encode to 0x00 and decode back to exact 0.0.""" + x = torch.zeros(64, dtype=torch.float32) + codes = f32_to_ue5m3_uint8(x) + assert torch.equal(codes, torch.zeros(64, dtype=torch.uint8)) + assert torch.equal(ue5m3_uint8_to_f32(codes), x) + + +# --------------------------------------------------------------------------- +# D2' spec-aligned NaN / max-normal contract +# --------------------------------------------------------------------------- + +def test_decode_0xff_is_nan(): + """Code 0xFF MUST decode to a quiet FP32 NaN (GFXIPARCH-2067 §19.10).""" + codes = torch.tensor([UE5M3_NAN_CODE], dtype=torch.uint8) + decoded = ue5m3_uint8_to_f32(codes) + assert torch.isnan(decoded).all().item(), ( + f"0xFF must decode to NaN, got {decoded.tolist()}" + ) + + +def test_max_normal_at_0xfe_decodes_114688(): + """Code 0xFE MUST decode to UE5M3_MAX = 114688.0 (max normal).""" + codes = torch.tensor([0xFE], dtype=torch.uint8) + decoded = ue5m3_uint8_to_f32(codes) + assert decoded.item() == 114688.0 + assert decoded.item() == UE5M3_MAX + + +def test_encode_max_value_lands_on_0xfe(): + """Encoding the exact max-normal value MUST land on 0xFE, NOT on 0xFF.""" + x = torch.tensor([UE5M3_MAX], dtype=torch.float32) + code = f32_to_ue5m3_uint8(x).item() + assert code == 0xFE, ( + f"Expected encode(UE5M3_MAX) == 0xFE under D2', got 0x{code:02X}" + ) + + +def test_overflow_above_max_encodes_to_nan(): + """Any finite value strictly above UE5M3_MAX MUST encode to 0xFF (NaN).""" + x = torch.tensor( + [UE5M3_MAX * 2.0, 1e6, 1e30, float("inf")], + dtype=torch.float32, + ) + codes = f32_to_ue5m3_uint8(x).tolist() + assert codes == [0xFF] * 4, ( + f"Overflow-above-MAX must encode to 0xFF, got {[hex(c) for c in codes]}" + ) + + +def test_quantize_to_ue5m3_propagates_nan(): + """``quantize_to_ue5m3(NaN)`` MUST emit NaN (D2' round-trip).""" + x = torch.tensor( + [float("nan"), 1.0, UE5M3_MAX * 2.0, 0.0], + dtype=torch.float32, + ) + q = quantize_to_ue5m3(x) + assert torch.isnan(q[0]).item(), "NaN input must produce NaN output" + assert q[1].item() == 1.0, "1.0 should round-trip cleanly" + assert torch.isnan(q[2]).item(), "Overflow must produce NaN output (D2')" + assert q[3].item() == 0.0, "0.0 must round-trip to 0.0" + + +# --------------------------------------------------------------------------- +# Round-to-nearest-even spot checks on between-grid inputs +# --------------------------------------------------------------------------- + +def test_rtne_tie_to_even(): + """Inputs exactly between two UE5M3 grid points must round to the + representation with an even least-significant mantissa bit. + + Grid near 1.0: + code 0x77 = exp=14, mant=7 -> (1+7/8)*2^-1 = 0.9375 + code 0x78 = exp=15, mant=0 -> 1.0 + code 0x79 = exp=15, mant=1 -> (1+1/8)*2^0 = 1.125 + + Halfway between 1.0 and 1.125 is 1.0625. Both neighbors are equidistant + -> tie -> round to even mantissa LSB -> code 0x78 (mant=0, even). + """ + x = torch.tensor([0.9375, 1.0, 1.0625, 1.125], dtype=torch.float32) + codes = f32_to_ue5m3_uint8(x) + # 0.9375 and 1.0 and 1.125 are exact grid points -> identity. + # 1.0625 is the tie -> 0x78 (mant=0 is even). + expected = torch.tensor([0x77, 0x78, 0x78, 0x79], dtype=torch.uint8) + assert torch.equal(codes, expected), ( + f"RTNE tie-to-even broken: codes={codes.tolist()}, expected={expected.tolist()}" + ) + + +def test_rtne_no_tie_rounds_to_nearest(): + """1.05 is closer to 1.0 than to 1.125 -> rounds down.""" + x = torch.tensor([1.05, 1.1], dtype=torch.float32) + codes = f32_to_ue5m3_uint8(x) + assert codes[0].item() == 0x78, f"1.05 should round to 1.0, got code 0x{int(codes[0].item()):02X}" + # 1.1 is closer to 1.125 than to 1.0 (diff 0.025 vs 0.1) -> rounds up. + assert codes[1].item() == 0x79, f"1.1 should round to 1.125, got code 0x{int(codes[1].item()):02X}" + + +def test_rtne_subnormal_boundary(): + """Just-below-MIN_NORMAL inputs should land on subnormal grid points, + not silently snap to 0 or to the smallest normal.""" + # smallest subnormal = 2^-17; smallest normal = 2^-14 = 8 * 2^-17 + # Pick an input that's exactly between subnormals: + # 2^-17 (mant=1) and 2 * 2^-17 (mant=2) + # The midpoint is 1.5 * 2^-17; with RTNE tie -> mant=2 (even) + midpoint = 1.5 * math.ldexp(1.0, -17) + just_below = math.ldexp(1.0, -17) * 1.4 # closer to mant=1 + just_above = math.ldexp(1.0, -17) * 1.6 # closer to mant=2 + x = torch.tensor([midpoint, just_below, just_above], dtype=torch.float32) + codes = f32_to_ue5m3_uint8(x) + # midpoint -> 0x02 (mant=2 is even); just_below -> 0x01; just_above -> 0x02 + expected = torch.tensor([0x02, 0x01, 0x02], dtype=torch.uint8) + assert torch.equal(codes, expected), ( + f"Subnormal RTNE broken: codes={codes.tolist()}, expected={expected.tolist()}" + ) + + +def test_rtne_below_smallest_subnormal_rounds_to_zero(): + """Inputs below half the smallest subnormal must round to 0.""" + x = torch.tensor( + [UE5M3_EPS * 0.4, UE5M3_EPS * 0.49, math.ldexp(1.0, -30)], + dtype=torch.float32, + ) + codes = f32_to_ue5m3_uint8(x) + assert torch.all(codes == 0).item(), ( + f"Inputs < 0.5 * UE5M3_EPS should round to 0, got codes={codes.tolist()}" + ) + + +# --------------------------------------------------------------------------- +# Coverage of the full encoding map on a dense sweep +# --------------------------------------------------------------------------- + +def test_encode_grid_values_identity(): + """Every UE5M3 grid value (including 0xFF -> NaN -> 0xFF), when fed + back to the encoder, must produce its own code -- a stronger version + of `roundtrip_all_codes` that exercises the encoder's path selection + (normal vs denormal vs saturate vs special) on real fp32 numerics.""" + codes = torch.arange(256, dtype=torch.uint8) + values = ue5m3_uint8_to_f32(codes) + re_encoded = f32_to_ue5m3_uint8(values) + diffs = (codes != re_encoded).nonzero(as_tuple=False).flatten().tolist() + assert not diffs, ( + f"Encoder is not identity on grid values; offending codes: " + + ", ".join(f"0x{c:02X} (got 0x{int(re_encoded[c].item()):02X})" for c in diffs[:8]) + ) diff --git a/tests/unittest/amdfp4/test_amdfp_grouped_gemm.py b/tests/unittest/amdfp4/test_amdfp_grouped_gemm.py new file mode 100644 index 0000000..b25a565 --- /dev/null +++ b/tests/unittest/amdfp4/test_amdfp_grouped_gemm.py @@ -0,0 +1,201 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Op-level precision tests for AMD-FP4 Grouped GEMM (UE5M3 inner scale). + +Mirrors the two ``scale_format``-parametrised tests from +:mod:`tests.unittest.nvfp4.test_nvfp_grouped_gemm`: + +* ``test_amdfp4_grouped_gemm_autograd`` — full O / dX / dW SNR + parity vs BF16 reference, K-aware floors; +* ``test_amdfp4_grouped_gemm_recipe_variants_smoke`` — recipe knobs + (2D-w, outer_scale, Hadamard, DGE, combinations) run end-to-end and + produce finite outputs + gradients. + +Recipe-shared regressions (boundary shapes, non-aligned-M guards, +single-expert == linear, native-dispatch parity, padded buffers) are +not duplicated here: they exercise the shared kernel layer that NVFP4's +``test_nvfp_grouped_gemm.py`` already covers, and the AMD-FP4 path +lands on the *exact same* autograd Function and Triton kernels. +""" + +from __future__ import annotations + +import pytest +from tabulate import tabulate +import torch + +from alto.kernels.fp4.amdfp4 import ( + ALIGN_SIZE_M, + amdfp4_grouped_gemm, +) +from alto.kernels.fp4.testing_utils import check_nvfp4_autograd_snr + +from amdfp4.utils import calc_cossim, calc_snr, prepare_data + + +# --------------------------------------------------------------------------- +# Helpers shared across tests (mirrors nvfp4 test helpers) +# --------------------------------------------------------------------------- + +def _make_contiguous_expert_indices( + M_total: int, num_groups: int, num_experts: int, device, +) -> torch.Tensor: + indices = torch.zeros(M_total, dtype=torch.int32, device=device) + for g in range(num_groups): + eid = torch.randint(0, num_experts, (1,), device=device).item() + s = g * ALIGN_SIZE_M + indices[s : s + ALIGN_SIZE_M] = eid + return indices + + +def _bf16_grouped_ref_forward( + inputs, expert_weights, expert_indices, M_total, N, num_groups, trans_weights, +): + dtype = inputs.dtype + device = inputs.device + y = torch.zeros(M_total, N, dtype=dtype, device=device) + for g in range(num_groups): + s, e = g * ALIGN_SIZE_M, (g + 1) * ALIGN_SIZE_M + eid = expert_indices[s].item() + w = expert_weights[eid] + y[s:e] = inputs[s:e] @ (w.T if trans_weights else w) + return y + + +# --------------------------------------------------------------------------- +# Autograd parity (O / dX / dW SNR vs BF16 reference) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("shape", [ + (512, 256, 256, 4), + (1024, 512, 512, 8), + (512, 512, 2048, 4), +]) +@pytest.mark.parametrize("use_2dblock_x", [False, True]) +@pytest.mark.parametrize("use_2dblock_w", [False, True]) +@pytest.mark.parametrize("use_sr_grad", [False, True]) +@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float32]) +def test_amdfp4_grouped_gemm_autograd( + shape, use_2dblock_x, use_2dblock_w, use_sr_grad, data_type, +): + """O, dX, dW SNR vs BF16 autograd reference must remain healthy under + the AMD-FP4 (UE5M3) inner-scale path.""" + M_total, N, K, num_experts = shape + M_total = (M_total // ALIGN_SIZE_M) * ALIGN_SIZE_M + num_groups = M_total // ALIGN_SIZE_M + device = torch.device("cuda") + + inputs_ref = prepare_data((M_total, K), data_type).requires_grad_(True) + weights_ref = prepare_data((num_experts, N, K), data_type).requires_grad_(True) + expert_indices = _make_contiguous_expert_indices( + M_total, num_groups, num_experts, device, + ) + target = prepare_data((M_total, N), data_type) + + y_ref = _bf16_grouped_ref_forward( + inputs_ref, weights_ref, expert_indices, M_total, N, num_groups, + trans_weights=True, + ) + loss_ref = torch.nn.functional.mse_loss(y_ref, target) + loss_ref.backward() + dx_ref = inputs_ref.grad.clone() + dw_ref = weights_ref.grad.clone() + + inputs = inputs_ref.detach().clone().requires_grad_(True) + weights = weights_ref.detach().clone().requires_grad_(True) + + y = amdfp4_grouped_gemm( + inputs, weights, expert_indices, + trans_weights=True, + use_2dblock_x=use_2dblock_x, + use_2dblock_w=use_2dblock_w, + use_sr_grad=use_sr_grad, + ) + loss = torch.nn.functional.mse_loss(y, target) + loss.backward() + + o_snr = calc_snr(y_ref.detach(), y.detach()) + dx_snr = calc_snr(dx_ref, inputs.grad) + dw_snr = calc_snr(dw_ref, weights.grad) + o_sim = calc_cossim(y_ref.detach(), y.detach()) + dx_sim = calc_cossim(dx_ref, inputs.grad) + dw_sim = calc_cossim(dw_ref, weights.grad) + + print() + print(tabulate( + [ + ["O", f"{o_snr:.2f}", f"{o_sim:.6f}"], + ["dX", f"{dx_snr:.2f}", f"{dx_sim:.6f}"], + ["dW", f"{dw_snr:.2f}", f"{dw_sim:.6f}"], + ], + headers=["Tensor", "SNR(dB)", "CosSim"], + tablefmt="github", + )) + + check_nvfp4_autograd_snr( + {"O": o_snr, "dX": dx_snr, "dW": dw_snr}, + K=K, + use_sr_grad=use_sr_grad, + kind="nvfp4_grouped_gemm", # shared K-aware floors + context=( + f"AMDFP4GroupedGEMM shape={shape} dtype={data_type} " + f"x_2d={use_2dblock_x} w_2d={use_2dblock_w}" + ), + ) + + if not use_sr_grad and data_type == torch.bfloat16: + fwd_min_snr = 10 if not (use_2dblock_x or use_2dblock_w) else 5 + assert o_snr > fwd_min_snr, ( + f"Forward SNR too low on non-SR BF16 path: {o_snr:.2f} " + f"(min {fwd_min_snr})" + ) + assert o_sim > 0.95, f"Forward CosSim too low: {o_sim:.6f}" + + +# --------------------------------------------------------------------------- +# Recipe-variants smoke (2D-w, outer_scale, Hadamard, DGE) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("use_2dblock_x,use_2dblock_w,use_outer_scale,use_hadamard,use_dge", [ + (False, False, False, False, False), + (False, True, False, False, False), + (False, True, True, False, False), + (False, True, False, True, False), + (False, True, False, False, True), + (False, True, False, True, True), +]) +def test_amdfp4_grouped_gemm_recipe_variants_smoke( + use_2dblock_x, use_2dblock_w, use_outer_scale, use_hadamard, use_dge, +): + """Grouped recipe variants must run end-to-end on the AMD-FP4 path.""" + M_total, N, K, num_experts = 512, 256, 256, 4 + num_groups = M_total // ALIGN_SIZE_M + device = torch.device("cuda") + dtype = torch.bfloat16 + inputs = prepare_data((M_total, K), dtype).requires_grad_(True) + expert_weights = prepare_data((num_experts, N, K), dtype).requires_grad_(True) + expert_indices = _make_contiguous_expert_indices(M_total, num_groups, num_experts, device) + target = prepare_data((M_total, N), dtype) + + y = amdfp4_grouped_gemm( + inputs, + expert_weights, + expert_indices, + trans_weights=True, + use_2dblock_x=use_2dblock_x, + use_2dblock_w=use_2dblock_w, + use_sr_grad=True, + use_outer_scale=use_outer_scale, + use_hadamard=use_hadamard, + use_dge=use_dge, + ) + loss = torch.nn.functional.mse_loss(y, target) + loss.backward() + + assert y.shape == (M_total, N) + assert y.dtype == dtype + assert torch.isfinite(y).all() + assert inputs.grad is not None and torch.isfinite(inputs.grad).all() + assert expert_weights.grad is not None and torch.isfinite(expert_weights.grad).all() diff --git a/tests/unittest/amdfp4/test_amdfp_linear.py b/tests/unittest/amdfp4/test_amdfp_linear.py new file mode 100644 index 0000000..a40c3fa --- /dev/null +++ b/tests/unittest/amdfp4/test_amdfp_linear.py @@ -0,0 +1,208 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Unit tests for the AMD-FP4 linear op (UE5M3 inner scale). + +Mirrors :mod:`tests.unittest.nvfp4.test_nvfp_linear` but pins the inner +grid to UE5M3 throughout and uses the AMD-FP4 op surface +(``AMDFP4LinearFunction`` / ``_to_amdfp4_then_scaled_mm``). + +Two test families: + +* ``test_amdfp4_qdq_roundtrip`` -- per-operand QDQ accuracy (the + building block that ``AMDFP4LinearFunction`` stacks together). +* ``test_amdfp4_linear_autograd_function`` -- forward output and both + gradients (``dX``, ``dW``) track a BF16 reference within an SNR + threshold, identical structure to NVFP4's autograd test. + +The unaligned-axis guard tests and the BF16/FP32 dtype-cross test are +not duplicated here: the underlying autograd function is *the same +class* as NVFP4's, so those guards are already covered in the NVFP4 +suite and any breakage there shows up in both. +""" + +from __future__ import annotations + +import pytest +from tabulate import tabulate +import torch + +from alto.kernels.fp4.amdfp4 import AMDFP4LinearFunction, _to_amdfp4_then_scaled_mm +from alto.kernels.fp4.nvfp4.nvfp_linear import _qdq, _to_nvfp4_then_scaled_mm +from alto.kernels.fp4.testing_utils import check_nvfp4_autograd_snr + +from amdfp4.utils import calc_cossim, calc_snr, prepare_data + + +# --------------------------------------------------------------------------- +# QDQ round-trip: BF16 -> AMD-FP4 -> BF16 (no GEMM) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("shape", [(128, 64), (4, 128, 64)]) +@pytest.mark.parametrize("axis", [-1, -2]) +@pytest.mark.parametrize("is_2d_block", [False, True]) +@pytest.mark.parametrize("use_outer_scale", [False, True]) +@pytest.mark.parametrize("use_sr", [False, True]) +def test_amdfp4_qdq_roundtrip( + shape, axis, is_2d_block, use_outer_scale, use_sr, +): + """Verify precision of a single AMD-FP4 quant -> dequant round-trip.""" + x = prepare_data(shape, torch.bfloat16) + + x_qdq = _qdq( + x, axis=axis, + is_2d_block=is_2d_block, + use_outer_scale=use_outer_scale, + use_sr=use_sr, + scale_format="ue5m3", + ) + + assert x_qdq.shape == x.shape + assert x_qdq.dtype == x.dtype + + snr = calc_snr(x, x_qdq) + cossim = calc_cossim(x, x_qdq) + + print() + print(tabulate( + [["SNR (dB)", f"{snr:.2f}"], ["Cosine Sim", f"{cossim:.6f}"]], + headers=["Metric", "Value"], tablefmt="github", + )) + + # Same SNR / cossim band as NVFP4 -- UE5M3 has wider dynamic range + # but the same mantissa width as E4M3, so single-block round-trip + # error on Gaussian inputs is comparable. Floors picked so any + # encoding-level regression trips them. + min_snr = 8 if not is_2d_block else 5 + min_cossim = 0.99 if not is_2d_block else (0.94 if use_sr else 0.99) + assert snr > min_snr, f"QDQ SNR too low: {snr:.2f}" + assert cossim > min_cossim, f"QDQ cosine similarity too low: {cossim:.6f}" + + +# --------------------------------------------------------------------------- +# Full autograd-function parity test (forward + dX + dW vs BF16 reference) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("shape", [(1, 64, 64, 64), (1, 512, 384, 128), (4, 1024, 1024, 2048)]) +@pytest.mark.parametrize("use_2dblock_x", [False, True]) +@pytest.mark.parametrize("use_2dblock_w", [False, True]) +@pytest.mark.parametrize("use_sr_grad", [False, True]) +@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float32]) +def test_amdfp4_linear_autograd_function( + shape, use_2dblock_x, use_2dblock_w, use_sr_grad, data_type, +): + """Forward + dX + dW must match a BF16 nn.Linear reference in SNR. + + Same K-aware threshold scheme as NVFP4; the AMD-FP4 path lands on + the same autograd function and is expected to clear the same + floors. + """ + B, M, N, K = shape + inputs = prepare_data((B, M, K), data_type).requires_grad_(True) + weights = prepare_data((N, K), data_type).requires_grad_(True) + target = prepare_data((B, M, N), data_type) + + outputs_ref = torch.nn.functional.linear(inputs, weights) + loss_ref = torch.nn.functional.mse_loss(outputs_ref, target) + loss_ref.backward() + grad_inputs_ref = inputs.grad.clone() + grad_weights_ref = weights.grad.clone() + inputs.grad.zero_(); weights.grad.zero_() + + outputs = AMDFP4LinearFunction.apply( + inputs, weights, + use_2dblock_x, use_2dblock_w, use_sr_grad, + False, # use_outer_scale + None, # hadamard_transform + False, # use_dge + "ue5m3", # scale_format pinned to AMD-FP4 inner grid + ) + loss = torch.nn.functional.mse_loss(outputs, target) + loss.backward() + + output_snr = calc_snr(outputs, outputs_ref) + output_sim = calc_cossim(outputs, outputs_ref) + dx_snr = calc_snr(inputs.grad, grad_inputs_ref) + dx_sim = calc_cossim(inputs.grad, grad_inputs_ref) + dw_snr = calc_snr(weights.grad, grad_weights_ref) + dw_sim = calc_cossim(weights.grad, grad_weights_ref) + + print() + print(tabulate( + [ + ["O", f"{output_snr:.2f}", f"{output_sim:.6f}"], + ["dX", f"{dx_snr:.2f}", f"{dx_sim:.6f}"], + ["dW", f"{dw_snr:.2f}", f"{dw_sim:.6f}"], + ], + headers=["Tensor", "SNR", "Cosine Sim"], tablefmt="github", + )) + + check_nvfp4_autograd_snr( + {"O": output_snr, "dX": dx_snr, "dW": dw_snr}, + K=K, + use_sr_grad=use_sr_grad, + kind="nvfp4_linear", # same K-aware floor scheme as NVFP4 + context=( + f"AMDFP4Linear shape={shape} dtype={data_type} " + f"x_2d={use_2dblock_x} w_2d={use_2dblock_w}" + ), + ) + + +# --------------------------------------------------------------------------- +# Thin-wrapper surface: ``_to_amdfp4_then_scaled_mm`` must actually be the +# UE5M3-pinned variant of ``_to_nvfp4_then_scaled_mm`` (TEST-1). +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("use_2dblock_x", [False, True]) +@pytest.mark.parametrize("use_2dblock_w", [False, True]) +def test_amdfp4_to_scaled_mm_wrapper_pins_ue5m3(use_2dblock_x, use_2dblock_w): + """Exercise the *real* AMD-FP4 thin wrapper end-to-end. + + The autograd-function test above calls ``AMDFP4LinearFunction`` directly + with an explicit ``scale_format="ue5m3"`` and never touches + ``_to_amdfp4_then_scaled_mm`` — the helper the dispatch layer actually + routes through. This test calls that helper and asserts it is bit-for-bit + identical to the NVFP4 helper invoked with ``scale_format="ue5m3"`` (i.e. + the wrapper's only job — pinning the inner grid to UE5M3 — is real and + correct), while differing from the default E4M3 NVFP4 path. + + NOTE: E4M3 and UE5M3 share the same 3-bit mantissa, so on normal-range + inputs both grids round identically and the outputs are bit-equal. To + *distinguish* the two (and thus prove the pin actually selects UE5M3) we + use wide-dynamic-range data where E4M3's inner scale saturates (amax/6 > + 448) but UE5M3's does not. + """ + torch.manual_seed(0) + M, N, K = 64, 64, 128 + # randn * 8000 -> per-block amax ~ 1.5e4..3e4 -> inner_scale_raw far above + # the E4M3 max (448), so E4M3 and UE5M3 genuinely diverge. + a = (torch.randn((M, K), device="cuda") * 8000.0).to(torch.bfloat16) + w = (torch.randn((N, K), device="cuda") * 8000.0).to(torch.bfloat16) + + common = dict( + use_2dblock_x=use_2dblock_x, + use_2dblock_w=use_2dblock_w, + use_sr_grad=False, + ) + + # Linear semantics: the helper computes ``x_dq @ w_dq.T`` with weight of + # shape (N, K), so pass ``w`` directly (not transposed). + y_amd = _to_amdfp4_then_scaled_mm(a, w, **common) + y_nv_ue5m3 = _to_nvfp4_then_scaled_mm(a, w, scale_format="ue5m3", **common) + y_nv_e4m3 = _to_nvfp4_then_scaled_mm(a, w, scale_format="e4m3", **common) + + assert y_amd.shape == (M, N) + assert torch.isfinite(y_amd).all() + # The AMD-FP4 wrapper is exactly NVFP4 pinned to UE5M3. + assert torch.equal(y_amd, y_nv_ue5m3), ( + "_to_amdfp4_then_scaled_mm must be bit-identical to " + "_to_nvfp4_then_scaled_mm(scale_format='ue5m3')" + ) + # ...and genuinely on a different inner grid than the E4M3 default + # (guards against the wrapper silently falling back to E4M3). + assert not torch.equal(y_amd, y_nv_e4m3), ( + "AMD-FP4 wrapper output must differ from the E4M3 NVFP4 path; " + "identical output suggests scale_format pinning was lost" + ) diff --git a/tests/unittest/amdfp4/test_amdfp_quantization.py b/tests/unittest/amdfp4/test_amdfp_quantization.py new file mode 100644 index 0000000..3098f14 --- /dev/null +++ b/tests/unittest/amdfp4/test_amdfp_quantization.py @@ -0,0 +1,498 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Op-level regression tests for the AMD-FP4 quant / dequant kernels. + +Mirrors :mod:`tests.unittest.nvfp4.test_nvfp_quantization` but exercises +the AMD-FP4-side ATen ops (``alto::convert_to_amdfp4`` / +``alto::convert_from_amdfp4``) and pins the inner-scale dtype to +UE5M3 throughout. The shared blockwise body is regression-tested by +the NVFP4 suite; this file's job is to lock down the AMD-FP4 surface: + +* the UE5M3 inner-scale path produces the same outputs as the + PyTorch UE5M3 oracle (kernel-vs-oracle bit-equality / mismatch rate); +* outer-scale dynamic refresh stays bit-equal between caller-driven and + kernel-driven paths; +* edge-case patterns (zeros, large saturation) and edge tile shapes + stay finite; +* a NaN spike in the input does not contaminate inner scales or + dequant output (D2' defense layer regression). +""" + +from __future__ import annotations + +import pytest +import torch + +from alto.kernels.fp4.amdfp4 import ( + convert_from_amdfp4, + convert_to_amdfp4, +) +from alto.kernels.fp4.nvfp4.nvfp_quantization import ( + convert_from_nvfp4, + convert_to_nvfp4, +) +from alto.kernels.fp4.outer_scaled_fp4 import ( + _OUTER_SCALE_DIVZERO_FLOOR, + _SCALE_FORMAT_TABLE, + compute_dynamic_outer_scale, + is_cdna4, +) + +from amdfp4.utils import ( # noqa: E402 (see amdfp4/conftest sys.path) + convert_from_amdfp4_pytorch, + convert_to_amdfp4_pytorch, + prepare_data, +) + + +# --------------------------------------------------------------------------- +# Bit-exact kernel-vs-oracle agreement on the full axis / shape / SR / +# outer-scale matrix (UE5M3 only — the E4M3 half lives in NVFP4). +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("tensor_shape", [(128, 64), (2048, 2048), (4, 128, 64)]) +@pytest.mark.parametrize("axis", [-1, -2]) +@pytest.mark.parametrize("is_2d_block", [False, True]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("use_sr", [False, True]) +@pytest.mark.parametrize("use_outer_scale", [False, True]) +@pytest.mark.parametrize("compile", [False]) +def test_amdfp4_quantization(tensor_shape, axis, is_2d_block, data_type, + use_sr, use_outer_scale, compile): + block_size = 16 + device = torch.device("cuda") + + if compile: + quant_func = torch.compile( + torch.ops.alto.convert_to_amdfp4, fullgraph=True) + dequant_func = torch.compile( + torch.ops.alto.convert_from_amdfp4, fullgraph=True) + else: + quant_func = torch.ops.alto.convert_to_amdfp4 + dequant_func = torch.ops.alto.convert_from_amdfp4 + + if use_outer_scale: + outer_scale = torch.tensor([0.5], dtype=torch.float32, device=device) + else: + outer_scale = None + + x = prepare_data(tensor_shape, data_type) + + data_lp_ref, scales_ref = convert_to_amdfp4_pytorch( + x, block_size=block_size, axis=axis, is_2d_block=is_2d_block, + outer_scale=outer_scale, + ) + x_dq_ref = convert_from_amdfp4_pytorch( + data_lp_ref, scales_ref, + output_dtype=data_type, block_size=block_size, axis=axis, + is_2d_block=is_2d_block, outer_scale=outer_scale, + ) + + data_lp, scales = quant_func( + x, block_size=block_size, axis=axis, is_2d_block=is_2d_block, + outer_scale=outer_scale, update_outer_scale=False, + use_sr=use_sr, + ) + # Both reference and kernel apply the same float32->UE5M3->float32 + # rounding in the same order, so their scales are bit-for-bit identical. + assert torch.equal(scales_ref, scales), ( + f"Scale mismatch: max abs diff = {(scales_ref - scales).abs().max().item():.6e}" + ) + + is_hip = hasattr(torch.version, "hip") and torch.version.hip is not None + if not use_sr: + if is_hip: + mismatch_rate = (data_lp_ref != data_lp).float().mean().item() + assert mismatch_rate < 1e-3, ( + f"Quantized data mismatch rate {mismatch_rate:.4%} exceeds 0.1% threshold" + ) + else: + assert torch.all(data_lp_ref == data_lp).item() + else: + data_lp_ref_lo = (data_lp_ref.to(torch.int8) & 0xF) + data_lp_ref_hi = ((data_lp_ref.to(torch.int8) >> 4) & 0xF) + data_lp_lo = (data_lp.to(torch.int8) & 0xF) + data_lp_hi = ((data_lp.to(torch.int8) >> 4) & 0xF) + assert torch.all( + torch.max(torch.abs(data_lp_ref_lo - data_lp_lo)) <= 1).item() + assert torch.all( + torch.max(torch.abs(data_lp_ref_hi - data_lp_hi)) <= 1).item() + + x_dq = dequant_func( + data_lp, scales, + output_dtype=data_type, block_size=block_size, axis=axis, + is_2d_block=is_2d_block, outer_scale=outer_scale, + ) + + if not use_sr: + if is_hip: + x_dq_cross = convert_from_amdfp4_pytorch( + data_lp, scales, + output_dtype=data_type, block_size=block_size, axis=axis, + is_2d_block=is_2d_block, outer_scale=outer_scale, + ) + assert torch.equal(x_dq_cross, x_dq), ( + f"Dequant not bit-exact: " + f"mismatches={int((x_dq_cross != x_dq).sum())}/{x_dq.numel()}, " + f"max diff={float((x_dq_cross - x_dq).abs().max())}" + ) + else: + assert torch.allclose(x_dq_ref, x_dq) + else: + dq_mae = (x_dq_ref - x_dq).abs().mean().item() + assert dq_mae < 1.0 if is_2d_block else 0.5 + + +# --------------------------------------------------------------------------- +# Dynamic outer-scale refresh path +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("tensor_shape", [(128, 64), (2048, 2048)]) +@pytest.mark.parametrize("axis", [-1, -2]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +def test_amdfp4_dynamic_outer_scale(tensor_shape, axis, data_type): + """``update_outer_scale=True`` must refresh the caller's scale buffer + in place and produce the same quantized output as a manual + pre-compute via :func:`compute_dynamic_outer_scale`.""" + block_size = 16 + + x = prepare_data(tensor_shape, data_type) + + # Path 1: manual pre-compute, pass in, no update. + outer_scale_manual = compute_dynamic_outer_scale(x, scale_format="ue5m3") + data_lp_manual, scales_manual = convert_to_amdfp4( + x, block_size=block_size, axis=axis, + outer_scale=outer_scale_manual, update_outer_scale=False, + ) + + # Path 2: caller-owned buffer, dynamically refreshed in-place. + outer_scale_dyn = torch.empty(1, dtype=torch.float32, device=x.device) + data_lp_dyn, scales_dyn = convert_to_amdfp4( + x, block_size=block_size, axis=axis, + outer_scale=outer_scale_dyn, update_outer_scale=True, + ) + + assert torch.equal(outer_scale_manual, outer_scale_dyn), ( + f"Dynamic outer_scale mismatch: " + f"{outer_scale_manual.item()} vs {outer_scale_dyn.item()}" + ) + assert torch.equal(data_lp_manual, data_lp_dyn) + assert torch.equal(scales_manual, scales_dyn) + + x_dq = convert_from_amdfp4( + data_lp_dyn, scales_dyn, + output_dtype=data_type, block_size=block_size, axis=axis, + outer_scale=outer_scale_dyn, + ) + assert x_dq.shape == x.shape + assert x_dq.dtype == data_type + + x_dq_ref = convert_from_amdfp4_pytorch( + data_lp_dyn, scales_dyn, + output_dtype=data_type, block_size=block_size, axis=axis, + outer_scale=outer_scale_dyn, + ) + is_hip = hasattr(torch.version, "hip") and torch.version.hip is not None + if is_hip: + assert torch.equal(x_dq_ref, x_dq) + else: + assert torch.allclose(x_dq_ref, x_dq) + + +# --------------------------------------------------------------------------- +# Edge-case patterns +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("tensor_shape", [(128, 64), (2048, 2048)]) +@pytest.mark.parametrize("axis", [-1, -2]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("pattern", ["zeros", "large"]) +def test_amdfp4_special_values(tensor_shape, axis, data_type, pattern): + """Edge-case input patterns: + + * ``zeros``: all-zero tensor — exercises the UE5M3 lower clamp on + the stored block scale. + * ``large``: 5000.0 everywhere — well below ``UE5M3_MAX * F4_E2M1_MAX`` + (~688k) so for AMD-FP4 the inner-scale upper clamp does *not* + fire on this pattern (it does for NVFP4); regression target is + the kernel emitting the same answer as the oracle. + """ + block_size = 16 + + x = prepare_data(tensor_shape, data_type, pattern=pattern) + + data_lp_ref, scales_ref = convert_to_amdfp4_pytorch( + x, block_size=block_size, axis=axis, + ) + + data_lp, scales = convert_to_amdfp4( + x, block_size=block_size, axis=axis, update_outer_scale=False, + ) + assert torch.equal(scales_ref, scales), ( + f"Scale mismatch for pattern={pattern}: " + f"max diff={float((scales_ref - scales).abs().max())}" + ) + + is_hip = hasattr(torch.version, "hip") and torch.version.hip is not None + if is_hip: + mismatch_rate = (data_lp_ref != data_lp).float().mean().item() + assert mismatch_rate < 1e-3, ( + f"Quantized data mismatch rate {mismatch_rate:.4%} for pattern={pattern}" + ) + else: + assert torch.all(data_lp_ref == data_lp).item(), ( + f"Quantized data mismatch for pattern={pattern}" + ) + + x_dq = convert_from_amdfp4( + data_lp, scales, + output_dtype=data_type, block_size=block_size, axis=axis, + ) + x_dq_ref = convert_from_amdfp4_pytorch( + data_lp, scales, + output_dtype=data_type, block_size=block_size, axis=axis, + ) + assert torch.equal(x_dq_ref, x_dq), ( + f"Dequant mismatch for pattern={pattern}" + ) + + +# --------------------------------------------------------------------------- +# Outer-scale + zero tensor: divzero floor and FP32-normal range invariants +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("is_2d_block", [False, True]) +def test_amdfp4_zero_tensor_with_outer_scale(is_2d_block): + """All-zero input on the outer-scale branch must round-trip to 0 with + finite scales, and the effective per-block divisor + ``inner_scale * outer_scale`` must stay in FP32 normal range + (``UE5M3_EPS * _OUTER_SCALE_DIVZERO_FLOOR`` is the worst case).""" + block_size = 16 + axis = -1 + data_type = torch.bfloat16 + x = prepare_data((128, 64), data_type, pattern="zeros") + + outer_scale_buf = torch.empty(1, dtype=torch.float32, device=x.device) + fmt_eps, _ = _SCALE_FORMAT_TABLE["ue5m3"] + data_lp, scales = convert_to_amdfp4( + x, block_size=block_size, axis=axis, is_2d_block=is_2d_block, + outer_scale=outer_scale_buf, update_outer_scale=True, + ) + + expected_floor_fp32 = torch.tensor( + _OUTER_SCALE_DIVZERO_FLOOR, dtype=torch.float32 + ).item() + outer_scale_value = outer_scale_buf.item() + assert outer_scale_value == expected_floor_fp32, ( + f"Zero-tensor outer_scale must be floored to " + f"_OUTER_SCALE_DIVZERO_FLOOR (FP32-rounded {expected_floor_fp32:.6e}), " + f"got {outer_scale_value:.6e}" + ) + + fp32_min_normal = torch.finfo(torch.float32).tiny + eff_min = scales.min().item() * outer_scale_value + assert eff_min >= fp32_min_normal, ( + f"Effective quant_scale ({eff_min:.6e}) fell into FP32 subnormal range; " + f"_OUTER_SCALE_DIVZERO_FLOOR * ue5m3_EPS ({fmt_eps:.6e}) would be " + f"flushed to zero under FTZ." + ) + + x_dq = convert_from_amdfp4( + data_lp, scales, output_dtype=data_type, + block_size=block_size, axis=axis, is_2d_block=is_2d_block, + outer_scale=outer_scale_buf, + ) + assert torch.isfinite(scales).all(), "stored block scale has NaN/Inf" + assert torch.isfinite(x_dq).all(), "outer_scale+zero dequant produced NaN/Inf" + assert (x_dq == 0).all(), "outer_scale+zero dequant must be exactly 0" + assert (data_lp == 0).all(), "outer_scale+zero packed FP4 must be all zero bins" + + +# --------------------------------------------------------------------------- +# Edge-tile shapes +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +def test_amdfp4_non_aligned_m_no_nan_inf(data_type): + """Regression for the edge-tile bug on non-aligned M. + + Same shared kernel as NVFP4; this test makes sure the AMD-FP4 op + surface inherits the masked load/store edge-tile fix. + """ + x = prepare_data((150, 128), data_type) + data_lp, scales = convert_to_amdfp4( + x, + block_size=16, + axis=-1, + is_2d_block=False, + update_outer_scale=False, + ) + x_dq = convert_from_amdfp4( + data_lp, + scales, + output_dtype=data_type, + block_size=16, + axis=-1, + is_2d_block=False, + ) + assert torch.isfinite(scales).all(), "non-aligned M produced Inf/NaN scales" + assert torch.isfinite(x_dq).all(), "non-aligned M produced Inf/NaN dequant output" + + +# --------------------------------------------------------------------------- +# D2' NaN-input sanitization regression +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +def test_amdfp4_quantization_nan_input_sanitized(data_type): + """A NaN spike in the input MUST NOT propagate to inner scales / dequant. + + Under D2' the UE5M3 cast is spec-aligned (NaN input -> 0xFF NaN code). + The defense layer in the shared ``_calculate_inner_scales`` (Triton) + and ``_quantize_inner_scale`` (PyTorch oracle) sanitises NaN + ``max_abs`` before the cast, so the resulting inner-scale tensor is + finite and downstream GEMM never sees a NaN. + + Mirrors the industry pattern (TransformerEngine "caller-side + sanitize", vLLM "zero-out padding", TRT-LLM "amax clamp"). + """ + block_size = 16 + x = prepare_data((128, 64), data_type) + x_nan = x.clone() + x_nan[0, 7] = float("nan") + + data_lp, scales = convert_to_amdfp4( + x_nan, + block_size=block_size, + axis=-1, + is_2d_block=False, + update_outer_scale=False, + ) + assert torch.isfinite(scales.float()).all(), ( + "NaN input produced non-finite inner scales under AMD-FP4; " + "D2' defense layer should have sanitised the block amax." + ) + + x_dq = convert_from_amdfp4( + data_lp, scales, + output_dtype=data_type, + block_size=block_size, + axis=-1, + is_2d_block=False, + ) + assert torch.isfinite(x_dq).all(), ( + "NaN input contaminated AMD-FP4 dequant output; " + "defense layer must keep downstream GEMM input finite." + ) + + +# --------------------------------------------------------------------------- +# Wide-dynamic-range advantage: the regime AMD-FP4 exists for (TEST-2). +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +def test_amdfp4_beats_e4m3_on_wide_dynamic_range(data_type): + """AMD-FP4's reason to exist: when a block's ideal inner scale exceeds the + E4M3 max (448), NVFP4 saturates the inner scale and clamps the block's + large elements, whereas UE5M3 (max 114688) represents it faithfully. + + We build data whose per-block amax forces ``inner_scale_raw = amax/6`` well + above 448 but below UE5M3_MAX, then round-trip through both inner grids + (no outer scale) and assert AMD-FP4's reconstruction error is dramatically + lower than NVFP4(e4m3) on the *same* data. + """ + torch.manual_seed(0) + block_size = 16 + # randn * 8000 -> block amax ~ 1.5e4..3e4 -> inner_scale_raw ~ 2.5e3..5e3, + # far above E4M3 max 448 (saturates) but far below UE5M3 max 114688. + x = (torch.randn((256, 256), device="cuda") * 8000.0).to(data_type) + + # NVFP4 (E4M3 inner): inner scale saturates -> large elements clamped. + nv_lp, nv_scales = convert_to_nvfp4( + x, block_size=block_size, axis=-1, + update_outer_scale=False, scale_format="e4m3", + ) + x_nv = convert_from_nvfp4( + nv_lp, nv_scales, output_dtype=torch.float32, + block_size=block_size, axis=-1, scale_format="e4m3", + ) + + # AMD-FP4 (UE5M3 inner): inner scale fits -> faithful representation. + amd_lp, amd_scales = convert_to_amdfp4( + x, block_size=block_size, axis=-1, update_outer_scale=False, + ) + x_amd = convert_from_amdfp4( + amd_lp, amd_scales, output_dtype=torch.float32, + block_size=block_size, axis=-1, + ) + + xf = x.float() + err_nv = (x_nv - xf).abs().mean().item() + err_amd = (x_amd - xf).abs().mean().item() + + # Sanity: this data must actually saturate E4M3 (otherwise the test is + # vacuous). E4M3 max-normal inner scale * F4_E2M1_MAX caps any block's + # representable magnitude at 448*6=2688; our amax is >> that. + assert xf.abs().max().item() > 2688.0, ( + "test data does not enter the E4M3-saturation regime" + ) + # AMD-FP4 must be at least ~4x more accurate here; in practice it is + # orders of magnitude better. + assert err_amd < err_nv * 0.25, ( + f"AMD-FP4 should crush NVFP4 on wide-dynamic-range data: " + f"err_amd={err_amd:.2f} vs err_nv={err_nv:.2f}" + ) + + +# --------------------------------------------------------------------------- +# BUG-1 repro: NaN + dynamic outer scale must not poison the whole tensor. +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("is_2d_block", [False, True]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +def test_amdfp4_nan_with_dynamic_outer_scale_is_contained(is_2d_block, data_type): + """Regression for BUG-1. + + With ``update_outer_scale=True`` the per-tensor outer scale is computed + from ``data.abs().max()``. A single NaN element makes that amax NaN, and + because the outer scale is shared, *every* block's ``inner_scale_raw`` + would become NaN — bypassing the per-block NaN defense in + ``_calculate_inner_scales``. ``compute_dynamic_outer_scale`` must + ``nan_to_num`` the reduced amax so the NaN spike degrades to "that block + goes to 0" instead of "the whole tensor is NaN". + """ + block_size = 16 + x = prepare_data((128, 64), data_type) + x_nan = x.clone() + x_nan[0, 7] = float("nan") + + outer_scale = torch.empty(1, dtype=torch.float32, device=x_nan.device) + data_lp, scales = convert_to_amdfp4( + x_nan, block_size=block_size, axis=-1, is_2d_block=is_2d_block, + outer_scale=outer_scale, update_outer_scale=True, + ) + + assert torch.isfinite(outer_scale).all(), ( + "dynamic outer scale became NaN — torch.clamp does not strip NaN; " + "compute_dynamic_outer_scale must nan_to_num the amax" + ) + assert torch.isfinite(scales.float()).all(), ( + "NaN poisoned the inner scales via the per-tensor outer scale" + ) + + x_dq = convert_from_amdfp4( + data_lp, scales, output_dtype=data_type, + block_size=block_size, axis=-1, is_2d_block=is_2d_block, + outer_scale=outer_scale, + ) + assert torch.isfinite(x_dq).all(), ( + "NaN contaminated the dequant output despite the defense layer" + ) + + +# A small placeholder that ensures ``is_cdna4`` import survives — used +# by call sites that detect CDNA4 to enable hardware ASM fast paths. +def test_amdfp4_is_cdna4_helper_callable(): + assert isinstance(is_cdna4(), bool) diff --git a/tests/unittest/amdfp4/test_amdfp_triton_pytorch_parity.py b/tests/unittest/amdfp4/test_amdfp_triton_pytorch_parity.py new file mode 100644 index 0000000..dd4cbc2 --- /dev/null +++ b/tests/unittest/amdfp4/test_amdfp_triton_pytorch_parity.py @@ -0,0 +1,256 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Triton vs. PyTorch parity tests for the UE5M3 quant primitive (AMD-FP4 inner scale). + +The PyTorch reference in :mod:`alto.kernels.fp4.fp4_primitives.ue5m3_ops` is +the oracle. This module verifies that :func:`triton_quantize_to_ue5m3` +produces **bit-identical fp32 snapped values** (with matching NaN masks +under D2') for: + +* all 256 UE5M3 grid points (including 0xFF -> NaN), +* a large random fp32 tensor (1M elems), +* eight pathological input patterns drawn from the AMD-FP4 stress matrix + (zeros, large normals, near-MIN_NORMAL, near-MAX, NaN/Inf, negatives, + hot channels, mixed magnitudes), +* an explicit NaN/Inf parity sweep. + +CUDA / ROCm is required. The whole module is skipped otherwise. +""" + +from __future__ import annotations + +import math + +import pytest +import torch + +cuda_required = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Triton parity tests require a CUDA / ROCm device", +) + +from alto.kernels.fp4.fp4_primitives import ( # noqa: E402 + UE5M3_EPS, + UE5M3_MAX, + UE5M3_MIN_NORMAL, + UE5M3_NAN_CODE, + f32_to_ue5m3_uint8, + quantize_to_ue5m3, + triton_quantize_to_ue5m3, + ue5m3_uint8_to_f32, +) + + +def _assert_nan_aware_equal(a: torch.Tensor, b: torch.Tensor, ctx: str) -> None: + """Bit-equal comparison that treats two NaN positions as equal. + + Required under D2': both PyTorch and Triton paths emit NaN for the + spec-aligned NaN code 0xFF, but ``torch.equal`` returns False on NaN + elements. We therefore separately check the NaN mask and the finite + bit-equality. + """ + nan_a = torch.isnan(a) + nan_b = torch.isnan(b) + assert torch.equal(nan_a, nan_b), ( + f"{ctx}: NaN masks differ; first diff idx=" + f"{int((nan_a ^ nan_b).nonzero()[0].item())}" + ) + finite = ~nan_a + assert torch.equal(a[finite], b[finite]), ( + f"{ctx}: finite snap values differ; " + f"max abs diff={(a[finite] - b[finite]).abs().max().item():.3e}" + ) + + +# --------------------------------------------------------------------------- +# Parity on the 256-code grid +# --------------------------------------------------------------------------- + +@cuda_required +def test_triton_parity_on_grid(): + codes = torch.arange(256, dtype=torch.uint8) + values = ue5m3_uint8_to_f32(codes).cuda() + snapped_triton = triton_quantize_to_ue5m3(values).cpu() + snapped_torch = quantize_to_ue5m3(values.cpu()) + _assert_nan_aware_equal( + snapped_triton, snapped_torch, + ctx="Triton vs PyTorch grid-point parity", + ) + + +@cuda_required +def test_triton_parity_re_encodes_to_same_uint8_on_grid(): + """A stronger statement: the FP32 Triton output, when re-encoded by + the PyTorch oracle, must produce the same UE5M3 code as the input + (including 0xFF -> NaN -> 0xFF under D2').""" + codes = torch.arange(256, dtype=torch.uint8) + values = ue5m3_uint8_to_f32(codes).cuda() + snapped = triton_quantize_to_ue5m3(values).cpu() + re_encoded = f32_to_ue5m3_uint8(snapped) + assert torch.equal(codes, re_encoded), ( + "Triton snapped values do not re-encode to their original UE5M3 codes; " + f"first diff at code 0x{int((codes != re_encoded).nonzero()[0].item()):02X}" + ) + + +# --------------------------------------------------------------------------- +# Parity on dense random data +# --------------------------------------------------------------------------- + +@cuda_required +@pytest.mark.parametrize("seed", [0, 1, 42]) +def test_triton_parity_random_1m(seed): + torch.manual_seed(seed) + # log-uniform magnitudes across the UE5M3 dynamic range + log_lo = math.log(UE5M3_EPS * 0.1) + log_hi = math.log(UE5M3_MAX * 4.0) # 4x oversampling to stress saturation + x_cpu = torch.empty(1024 * 1024, dtype=torch.float32).uniform_(log_lo, log_hi).exp() + # Inject occasional negatives / specials so the saturation branches fire. + n = x_cpu.numel() + x_cpu[torch.randperm(n)[: n // 64]] *= -1 + specials = torch.tensor( + [float("nan"), float("inf"), -float("inf"), 0.0, UE5M3_EPS, UE5M3_MAX], + dtype=torch.float32, + ) + x_cpu[: specials.numel()] = specials + + x_gpu = x_cpu.cuda() + snapped_triton = triton_quantize_to_ue5m3(x_gpu).cpu() + snapped_torch = quantize_to_ue5m3(x_cpu) + _assert_nan_aware_equal( + snapped_triton, snapped_torch, + ctx=f"Triton vs PyTorch UE5M3 random-1M parity (seed={seed})", + ) + + +# --------------------------------------------------------------------------- +# Parity on the AMD-FP4 stress pattern matrix +# --------------------------------------------------------------------------- + +def _make_pattern(name: str, n: int) -> torch.Tensor: + """Build one of the named stress patterns used across AMD-FP4 tests.""" + torch.manual_seed(hash(name) & 0xFFFFFFFF) + if name == "zeros": + return torch.zeros(n, dtype=torch.float32) + if name == "small_random": + return torch.randn(n, dtype=torch.float32) * 0.01 + if name == "large_random": + return torch.randn(n, dtype=torch.float32) * 1e3 + if name == "near_min_normal": + return torch.empty(n, dtype=torch.float32).uniform_( + UE5M3_EPS * 0.1, UE5M3_MIN_NORMAL * 2.0 + ) + if name == "near_max": + return torch.empty(n, dtype=torch.float32).uniform_( + UE5M3_MAX * 0.5, UE5M3_MAX * 1.5 + ) + if name == "specials": + base = torch.randn(n, dtype=torch.float32) + base[0] = float("nan") + base[1] = float("inf") + base[2] = -float("inf") + base[3] = -1.0 + base[4] = 0.0 + return base + if name == "hot_channel": + x = torch.full((n,), 0.1, dtype=torch.float32) + x[: max(1, n // 64)] = UE5M3_MAX * 0.9 + return x + if name == "mixed_magnitudes": + x = torch.empty(n, dtype=torch.float32) + chunk = n // 4 + x[:chunk] = torch.randn(chunk) * UE5M3_EPS * 4 + x[chunk:2 * chunk] = torch.randn(chunk) * 1.0 + x[2 * chunk:3 * chunk] = torch.randn(chunk) * 100.0 + x[3 * chunk:] = torch.randn(n - 3 * chunk) * UE5M3_MAX * 0.5 + return x + raise ValueError(f"Unknown pattern: {name}") + + +_PATTERNS = [ + "zeros", + "small_random", + "large_random", + "near_min_normal", + "near_max", + "specials", + "hot_channel", + "mixed_magnitudes", +] + + +@cuda_required +@pytest.mark.parametrize("pattern", _PATTERNS) +def test_triton_parity_patterns(pattern): + n = 4096 + x_cpu = _make_pattern(pattern, n) + x_gpu = x_cpu.cuda() + snapped_triton = triton_quantize_to_ue5m3(x_gpu).cpu() + snapped_torch = quantize_to_ue5m3(x_cpu) + _assert_nan_aware_equal( + snapped_triton, snapped_torch, + ctx=f"Triton vs PyTorch UE5M3 pattern='{pattern}' parity", + ) + + +# --------------------------------------------------------------------------- +# D2' explicit NaN / Inf parity +# --------------------------------------------------------------------------- + +@cuda_required +def test_triton_pytorch_parity_on_nan_inf(): + """PyTorch and Triton MUST emit identical NaN code 0xFF behaviour + on the canonical 'specials' inputs (NaN, ±Inf, overflow). + """ + x_cpu = torch.tensor( + [float("nan"), float("inf"), -float("inf"), + UE5M3_MAX, UE5M3_MAX * 2.0, 1e30, -1.0, 0.0], + dtype=torch.float32, + ) + snapped_triton = triton_quantize_to_ue5m3(x_cpu.cuda()).cpu() + snapped_torch = quantize_to_ue5m3(x_cpu) + + # NaN positions must agree. + assert torch.equal(torch.isnan(snapped_triton), torch.isnan(snapped_torch)), ( + f"NaN masks differ: triton={torch.isnan(snapped_triton).tolist()}, " + f"torch={torch.isnan(snapped_torch).tolist()}" + ) + # On finite positions the snap values must match bit-for-bit. + finite = ~torch.isnan(snapped_torch) + assert torch.equal(snapped_triton[finite], snapped_torch[finite]), ( + f"Finite snap values differ on specials sweep: " + f"triton={snapped_triton.tolist()}, torch={snapped_torch.tolist()}" + ) + # Specifically: NaN/+Inf/-Inf/overflow -> NaN; -1.0/0.0 -> 0.0; + # UE5M3_MAX -> 114688.0. + assert torch.isnan(snapped_torch[0]).item() # NaN + assert torch.isnan(snapped_torch[1]).item() # +Inf + assert torch.isnan(snapped_torch[2]).item() # -Inf + assert snapped_torch[3].item() == UE5M3_MAX # exact MAX -> 0xFE + assert torch.isnan(snapped_torch[4]).item() # 2*MAX -> overflow NaN + assert torch.isnan(snapped_torch[5]).item() # 1e30 -> overflow NaN + assert snapped_torch[6].item() == 0.0 # -1.0 -> 0 + assert snapped_torch[7].item() == 0.0 # 0.0 -> 0 + + +# --------------------------------------------------------------------------- +# Sanity: idempotency must hold on the Triton path too +# --------------------------------------------------------------------------- + +@cuda_required +def test_triton_idempotent(): + """Triton snap must be idempotent on its own output (NaN-aware under D2').""" + torch.manual_seed(0) + x = torch.randn(8192, dtype=torch.float32).abs() * 10.0 + x[0] = UE5M3_MAX * 2.0 # overflow -> NaN under D2' + x[1] = 0.0 + x[2] = float("nan") # explicit NaN input + x_gpu = x.cuda() + q1 = triton_quantize_to_ue5m3(x_gpu) + q2 = triton_quantize_to_ue5m3(q1) + _assert_nan_aware_equal( + q1.cpu(), q2.cpu(), + ctx="Triton UE5M3 idempotency", + ) diff --git a/tests/unittest/amdfp4/utils.py b/tests/unittest/amdfp4/utils.py new file mode 100644 index 0000000..a8a55bf --- /dev/null +++ b/tests/unittest/amdfp4/utils.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Test helpers for the AMD-FP4 unit suite. + +Reuses the NVFP4 PyTorch oracle (``tests/unittest/nvfp4/utils.py``) +because the recipe-level quant / dequant body is identical between +NVFP4 and AMD-FP4 modulo the inner-grid choice. This module just +exposes thin wrappers that pin ``scale_format='ue5m3'``, so the AMD-FP4 +test files never spell ``ue5m3`` themselves. +""" + +from __future__ import annotations + +from typing import Optional + +import torch +from torch import Tensor + +from alto.kernels.fp4.outer_scaled_fp4 import BLOCK_SIZE_DEFAULT +from alto.kernels.fp4.testing_utils import calc_cossim, calc_snr # noqa: F401 + +# Sibling import: the AMD-FP4 PyTorch oracle is the NVFP4 oracle pinned +# to ``scale_format='ue5m3'``; ``conftest.py`` arranges for the parent +# unittest dir on ``sys.path`` so this works. +from nvfp4.utils import ( # noqa: E402 (see amdfp4/conftest.py sys.path) + convert_from_nvfp4_pytorch as _convert_from_nvfp4_pytorch, + convert_to_nvfp4_pytorch as _convert_to_nvfp4_pytorch, + prepare_data, +) + +__all__ = ( + "BLOCK_SIZE_DEFAULT", + "calc_cossim", + "calc_snr", + "convert_from_amdfp4_pytorch", + "convert_to_amdfp4_pytorch", + "prepare_data", +) + + +def convert_to_amdfp4_pytorch( + data_hp: Tensor, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[Tensor] = None, +): + """PyTorch oracle for the AMD-FP4 quant op (UE5M3 inner scale).""" + return _convert_to_nvfp4_pytorch( + data_hp, + block_size=block_size, + axis=axis, + is_2d_block=is_2d_block, + outer_scale=outer_scale, + scale_format="ue5m3", + ) + + +def convert_from_amdfp4_pytorch( + data_lp: Tensor, + scales: Tensor, + output_dtype: torch.dtype = torch.float32, + block_size: int = BLOCK_SIZE_DEFAULT, + axis: int = -1, + is_2d_block: bool = False, + outer_scale: Optional[Tensor] = None, +): + """PyTorch oracle for the AMD-FP4 dequant op (UE5M3 inner scale).""" + return _convert_from_nvfp4_pytorch( + data_lp, + scales, + output_dtype=output_dtype, + block_size=block_size, + axis=axis, + is_2d_block=is_2d_block, + outer_scale=outer_scale, + scale_format="ue5m3", + ) diff --git a/tests/unittest/nvfp4/conftest.py b/tests/unittest/nvfp4/conftest.py new file mode 100644 index 0000000..9f4ddea --- /dev/null +++ b/tests/unittest/nvfp4/conftest.py @@ -0,0 +1,17 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +"""Triton env for NVFP4 / AMD-FP4 quantization tests. + +``quantize_ue5m3`` is a module-level ``@triton.jit`` helper imported into +``nvfp_quantization``; Triton requires ``TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1``. +""" + +import pytest + + +@pytest.fixture(autouse=True) +def _configure_triton_env(monkeypatch): + monkeypatch.setenv("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "1") + monkeypatch.setenv("TRITON_CACHE_DIR", "/tmp/triton-cache-nvfp4-tests") diff --git a/tests/unittest/nvfp4/test_nvfp_dispatch_guards.py b/tests/unittest/nvfp4/test_nvfp_dispatch_guards.py index 6a1da28..92e55da 100644 --- a/tests/unittest/nvfp4/test_nvfp_dispatch_guards.py +++ b/tests/unittest/nvfp4/test_nvfp_dispatch_guards.py @@ -75,7 +75,11 @@ def test_grouped_mm_routes_to_nvfp4_grouped_kernel(monkeypatch, device): calls = [] def _mock_grouped(A, B, *, offs, use_2dblock_x, use_2dblock_w, - use_sr_grad, use_outer_scale, use_hadamard, use_dge): + use_sr_grad, use_outer_scale, use_hadamard, use_dge, + scale_format): + # ``scale_format`` is REQUIRED (no default) so a future regression that + # drops the kwarg from the dispatch layer fails-fast here as a TypeError + # instead of silently routing every wrapped op to the e4m3 inner-grid. calls.append({ "A": A, "B": B, @@ -86,6 +90,7 @@ def _mock_grouped(A, B, *, offs, use_2dblock_x, use_2dblock_w, "use_outer_scale": use_outer_scale, "use_hadamard": use_hadamard, "use_dge": use_dge, + "scale_format": scale_format, }) return A.new_full((A.shape[0], B.shape[-1]), 7.0) @@ -122,6 +127,10 @@ def _mock_grouped(A, B, *, offs, use_2dblock_x, use_2dblock_w, assert call["use_outer_scale"] is True assert call["use_hadamard"] is True assert call["use_dge"] is True + # Default config (no explicit ``inner_scale_format``) must still route to + # the legacy E4M3 inner-grid; pinning this prevents a future default flip + # from silently re-quantizing every NVFP4-wrapped grouped_mm under UE5M3. + assert call["scale_format"] == "e4m3" @pytest.mark.parametrize("use_hadamard,use_dge", [ @@ -190,7 +199,11 @@ def test_linear_routes_to_nvfp4_linear_kernel(monkeypatch, device): calls = [] def _mock_linear(A, B, *, use_2dblock_x, use_2dblock_w, use_sr_grad, - use_outer_scale, use_hadamard, use_dge): + use_outer_scale, use_hadamard, use_dge, scale_format): + # ``scale_format`` is REQUIRED here for the same reason as the grouped + # mock: a future regression that drops the kwarg fails-fast as a + # TypeError instead of silently routing every NVFP4-wrapped linear + # call to the e4m3 inner-grid. calls.append({ "A": A, "B": B, @@ -200,6 +213,7 @@ def _mock_linear(A, B, *, use_2dblock_x, use_2dblock_w, use_sr_grad, "use_outer_scale": use_outer_scale, "use_hadamard": use_hadamard, "use_dge": use_dge, + "scale_format": scale_format, }) return A.new_full((A.shape[0], B.shape[0]), 5.0) @@ -230,6 +244,70 @@ def _mock_linear(A, B, *, use_2dblock_x, use_2dblock_w, use_sr_grad, assert call["use_outer_scale"] is True assert call["use_hadamard"] is True assert call["use_dge"] is True + # Same default-routing pin as the grouped path above. + assert call["scale_format"] == "e4m3" + + +# --------------------------------------------------------------------------- +# M2 — UE5M3 inner-scale routing +# +# The two tests below verify the *only* code path through which AMD-FP4 can be +# selected at the dispatch layer: setting ``inner_scale_format`` on the +# ``TrainingOpConfig``. They are intentionally separate from the legacy +# ``test_*_routes_to_nvfp4_*_kernel`` tests above to keep the long-standing +# E4M3 routing assertions completely unchanged (no new parametrize cell on the +# original tests), so existing baselines stay byte-for-byte the same. +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("inner_scale_format", ["e4m3", "ue5m3"]) +def test_grouped_mm_forwards_inner_scale_format(monkeypatch, device, inner_scale_format): + """Whatever ``cfg.inner_scale_format`` is set to MUST reach the kernel + entry point as ``scale_format``; otherwise the AMD-FP4 path is a no-op.""" + calls = [] + + def _mock_grouped(A, B, *, offs, use_2dblock_x, use_2dblock_w, + use_sr_grad, use_outer_scale, use_hadamard, use_dge, + scale_format): + calls.append({"scale_format": scale_format}) + return A.new_zeros((A.shape[0], B.shape[-1])) + + monkeypatch.setattr(dispatch_tensor, "_quantize_then_nvfp4_scaled_grouped_mm", _mock_grouped) + + cfg = _make_config(inner_scale_format=inner_scale_format) + num_experts, K, N, M = 2, 16, 16, ALIGN_SIZE_M + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + W_wrapped = NVFP4TrainingWeightWrapperTensor( + torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=device), + cfg, + ) + offs = torch.tensor([M // num_experts, M], dtype=torch.int32, device=device) + + torch._grouped_mm(A, W_wrapped, offs=offs) + + assert len(calls) == 1 + assert calls[0]["scale_format"] == inner_scale_format + + +@pytest.mark.parametrize("inner_scale_format", ["e4m3", "ue5m3"]) +def test_linear_forwards_inner_scale_format(monkeypatch, device, inner_scale_format): + """Same routing guard as above for the dense F.linear path.""" + calls = [] + + def _mock_linear(A, B, *, use_2dblock_x, use_2dblock_w, use_sr_grad, + use_outer_scale, use_hadamard, use_dge, scale_format): + calls.append({"scale_format": scale_format}) + return A.new_zeros((A.shape[0], B.shape[0])) + + monkeypatch.setattr(dispatch_tensor, "_to_nvfp4_then_scaled_mm", _mock_linear) + + cfg = _make_config(inner_scale_format=inner_scale_format) + W_wrapped = _make_wrapper(cfg, device=device, shape=(32, 16)) + x = torch.randn(8, 16, dtype=torch.bfloat16, device=device) + + torch.nn.functional.linear(x, W_wrapped) + + assert len(calls) == 1 + assert calls[0]["scale_format"] == inner_scale_format # --------------------------------------------------------------------------- diff --git a/tests/unittest/nvfp4/test_nvfp_grouped_gemm.py b/tests/unittest/nvfp4/test_nvfp_grouped_gemm.py index dae2154..73dc691 100644 --- a/tests/unittest/nvfp4/test_nvfp_grouped_gemm.py +++ b/tests/unittest/nvfp4/test_nvfp_grouped_gemm.py @@ -94,8 +94,10 @@ def _bf16_grouped_ref_forward( @pytest.mark.parametrize("use_sr_grad", [False, True]) @pytest.mark.parametrize("use_outer_scale", [False, True]) @pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_grouped_gemm def test_nvfp4_grouped_gemm_autograd( shape, use_2dblock_x, use_2dblock_w, use_sr_grad, use_outer_scale, data_type, + scale_format, ): """Output, dX, and dW SNR vs BF16 autograd reference must remain healthy.""" M_total, N, K, num_experts = shape @@ -131,6 +133,7 @@ def test_nvfp4_grouped_gemm_autograd( use_2dblock_w=use_2dblock_w, use_sr_grad=use_sr_grad, use_outer_scale=use_outer_scale, + scale_format=scale_format, ) loss = torch.nn.functional.mse_loss(y, target) loss.backward() @@ -161,7 +164,8 @@ def test_nvfp4_grouped_gemm_autograd( use_outer_scale=use_outer_scale, context=( f"NVFP4GroupedGEMM shape={shape} dtype={data_type} " - f"x_2d={use_2dblock_x} w_2d={use_2dblock_w} outer={use_outer_scale}" + f"x_2d={use_2dblock_x} w_2d={use_2dblock_w} outer={use_outer_scale} " + f"scale_format={scale_format}" ), ) @@ -306,8 +310,9 @@ def test_nvfp4_grouped_gemm_boundary(M_multiplier, num_experts): (False, True, False, False, True), # + DGE (False, True, False, True, True), # + Hadamard + DGE ]) +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_grouped_gemm def test_nvfp4_grouped_gemm_recipe_variants_smoke( - use_2dblock_x, use_2dblock_w, use_outer_scale, use_hadamard, use_dge, + use_2dblock_x, use_2dblock_w, use_outer_scale, use_hadamard, use_dge, scale_format, ): """Grouped recipe variants should run end-to-end, produce finite outputs, and backprop finite gradients without silently falling back.""" @@ -331,6 +336,7 @@ def test_nvfp4_grouped_gemm_recipe_variants_smoke( use_outer_scale=use_outer_scale, use_hadamard=use_hadamard, use_dge=use_dge, + scale_format=scale_format, ) loss = torch.nn.functional.mse_loss(y, target) loss.backward() diff --git a/tests/unittest/nvfp4/test_nvfp_linear.py b/tests/unittest/nvfp4/test_nvfp_linear.py index d1ecb6f..9c79620 100644 --- a/tests/unittest/nvfp4/test_nvfp_linear.py +++ b/tests/unittest/nvfp4/test_nvfp_linear.py @@ -47,8 +47,9 @@ @pytest.mark.parametrize("is_2d_block", [False, True]) @pytest.mark.parametrize("use_outer_scale", [False, True]) @pytest.mark.parametrize("use_sr", [False, True]) +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_linear def test_nvfp4_qdq_roundtrip( - shape, axis, is_2d_block, use_outer_scale, use_sr, + shape, axis, is_2d_block, use_outer_scale, use_sr, scale_format, ): """Verify precision of a single NVFP4 quant -> dequant round-trip. @@ -64,6 +65,7 @@ def test_nvfp4_qdq_roundtrip( is_2d_block=is_2d_block, use_outer_scale=use_outer_scale, use_sr=use_sr, + scale_format=scale_format, ) assert x_qdq.shape == x.shape @@ -109,8 +111,10 @@ def test_nvfp4_qdq_roundtrip( @pytest.mark.parametrize("use_sr_grad", [False, True]) @pytest.mark.parametrize("use_outer_scale", [False, True]) @pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_linear def test_nvfp4_linear_autograd_function( shape, use_2dblock_x, use_2dblock_w, use_sr_grad, use_outer_scale, data_type, + scale_format, ): """Forward + dX + dW must match a BF16 nn.Linear reference in SNR. @@ -136,6 +140,9 @@ def test_nvfp4_linear_autograd_function( inputs, weights, use_2dblock_x, use_2dblock_w, use_sr_grad, use_outer_scale, + None, # hadamard_transform + False, # use_dge + scale_format, ) loss = torch.nn.functional.mse_loss(outputs, target) loss.backward() @@ -165,7 +172,8 @@ def test_nvfp4_linear_autograd_function( use_outer_scale=use_outer_scale, context=( f"NVFP4Linear shape={shape} dtype={data_type} " - f"x_2d={use_2dblock_x} w_2d={use_2dblock_w} outer={use_outer_scale}" + f"x_2d={use_2dblock_x} w_2d={use_2dblock_w} outer={use_outer_scale} " + f"scale_format={scale_format}" ), ) diff --git a/tests/unittest/nvfp4/test_nvfp_quantization.py b/tests/unittest/nvfp4/test_nvfp_quantization.py index 18a8e99..2fcd344 100644 --- a/tests/unittest/nvfp4/test_nvfp_quantization.py +++ b/tests/unittest/nvfp4/test_nvfp_quantization.py @@ -6,6 +6,7 @@ import torch from alto.kernels.fp4.nvfp4.nvfp_quantization import ( _OUTER_SCALE_DIVZERO_FLOOR, + _SCALE_FORMAT_TABLE, convert_to_nvfp4, convert_from_nvfp4, compute_dynamic_outer_scale, @@ -25,9 +26,13 @@ @pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("use_sr", [False, True]) @pytest.mark.parametrize("use_outer_scale", [False, True]) +# UE5M3 (AMD-FP4) is now covered by ``tests/unittest/amdfp4/test_amdfp_quantization.py``; +# keeping the ``scale_format`` axis here so future inner-grid additions can extend it +# in one place without re-parametrising all four test functions. +@pytest.mark.parametrize("scale_format", ["e4m3"]) @pytest.mark.parametrize("compile", [False]) def test_nvfp4_quantization(tensor_shape, axis, is_2d_block, data_type, - use_sr, use_outer_scale, compile): + use_sr, use_outer_scale, scale_format, compile): block_size = 16 device = torch.device("cuda") @@ -49,18 +54,19 @@ def test_nvfp4_quantization(tensor_shape, axis, is_2d_block, data_type, data_lp_ref, scales_ref = convert_to_nvfp4_pytorch( x, block_size=block_size, axis=axis, is_2d_block=is_2d_block, - outer_scale=outer_scale, + outer_scale=outer_scale, scale_format=scale_format, ) x_dq_ref = convert_from_nvfp4_pytorch( data_lp_ref, scales_ref, output_dtype=data_type, block_size=block_size, axis=axis, is_2d_block=is_2d_block, outer_scale=outer_scale, + scale_format=scale_format, ) data_lp, scales = quant_func( x, block_size=block_size, axis=axis, is_2d_block=is_2d_block, outer_scale=outer_scale, update_outer_scale=False, - use_sr=use_sr, + scale_format=scale_format, use_sr=use_sr, ) # Both reference and kernel apply the same float32->float8_e4m3fn->float32 # rounding in the same order, so their scales are bit-for-bit identical. @@ -91,6 +97,7 @@ def test_nvfp4_quantization(tensor_shape, axis, is_2d_block, data_type, data_lp, scales, output_dtype=data_type, block_size=block_size, axis=axis, is_2d_block=is_2d_block, outer_scale=outer_scale, + scale_format=scale_format, ) if not use_sr: @@ -99,6 +106,7 @@ def test_nvfp4_quantization(tensor_shape, axis, is_2d_block, data_type, data_lp, scales, output_dtype=data_type, block_size=block_size, axis=axis, is_2d_block=is_2d_block, outer_scale=outer_scale, + scale_format=scale_format, ) assert torch.equal(x_dq_cross, x_dq), ( f"Dequant not bit-exact: " @@ -115,7 +123,8 @@ def test_nvfp4_quantization(tensor_shape, axis, is_2d_block, data_type, @pytest.mark.parametrize("tensor_shape", [(128, 64), (2048, 2048)]) @pytest.mark.parametrize("axis", [-1, -2]) @pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) -def test_nvfp4_dynamic_outer_scale(tensor_shape, axis, data_type): +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_quantization +def test_nvfp4_dynamic_outer_scale(tensor_shape, axis, data_type, scale_format): """Verify that ``update_outer_scale=True`` refreshes the caller's scale buffer in place, producing the same quantized output as explicitly pre-computing the scale via :func:`compute_dynamic_outer_scale`.""" @@ -124,10 +133,11 @@ def test_nvfp4_dynamic_outer_scale(tensor_shape, axis, data_type): x = prepare_data(tensor_shape, data_type) # Path 1: manual pre-compute, pass in, no update. - outer_scale_manual = compute_dynamic_outer_scale(x) + outer_scale_manual = compute_dynamic_outer_scale(x, scale_format=scale_format) data_lp_manual, scales_manual = convert_to_nvfp4( x, block_size=block_size, axis=axis, outer_scale=outer_scale_manual, update_outer_scale=False, + scale_format=scale_format, ) # Path 2: caller-owned buffer, dynamically refreshed in-place. @@ -135,6 +145,7 @@ def test_nvfp4_dynamic_outer_scale(tensor_shape, axis, data_type): data_lp_dyn, scales_dyn = convert_to_nvfp4( x, block_size=block_size, axis=axis, outer_scale=outer_scale_dyn, update_outer_scale=True, + scale_format=scale_format, ) assert torch.equal(outer_scale_manual, outer_scale_dyn), ( @@ -147,7 +158,7 @@ def test_nvfp4_dynamic_outer_scale(tensor_shape, axis, data_type): x_dq = convert_from_nvfp4( data_lp_dyn, scales_dyn, output_dtype=data_type, block_size=block_size, axis=axis, - outer_scale=outer_scale_dyn, + outer_scale=outer_scale_dyn, scale_format=scale_format, ) assert x_dq.shape == x.shape assert x_dq.dtype == data_type @@ -155,7 +166,7 @@ def test_nvfp4_dynamic_outer_scale(tensor_shape, axis, data_type): x_dq_ref = convert_from_nvfp4_pytorch( data_lp_dyn, scales_dyn, output_dtype=data_type, block_size=block_size, axis=axis, - outer_scale=outer_scale_dyn, + outer_scale=outer_scale_dyn, scale_format=scale_format, ) is_hip = hasattr(torch.version, "hip") and torch.version.hip is not None if is_hip: @@ -168,7 +179,8 @@ def test_nvfp4_dynamic_outer_scale(tensor_shape, axis, data_type): @pytest.mark.parametrize("axis", [-1, -2]) @pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("pattern", ["zeros", "large"]) -def test_nvfp4_special_values(tensor_shape, axis, data_type, pattern): +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_quantization +def test_nvfp4_special_values(tensor_shape, axis, data_type, pattern, scale_format): """Verify quantization correctness on edge-case inputs. - zeros: all-zero tensor, exercises the E4M3 lower clamp on the stored block scale. @@ -180,11 +192,12 @@ def test_nvfp4_special_values(tensor_shape, axis, data_type, pattern): x = prepare_data(tensor_shape, data_type, pattern=pattern) data_lp_ref, scales_ref = convert_to_nvfp4_pytorch( - x, block_size=block_size, axis=axis, + x, block_size=block_size, axis=axis, scale_format=scale_format, ) data_lp, scales = convert_to_nvfp4( x, block_size=block_size, axis=axis, update_outer_scale=False, + scale_format=scale_format, ) assert torch.equal(scales_ref, scales), ( f"Scale mismatch for pattern={pattern}: " @@ -205,10 +218,12 @@ def test_nvfp4_special_values(tensor_shape, axis, data_type, pattern): x_dq = convert_from_nvfp4( data_lp, scales, output_dtype=data_type, block_size=block_size, axis=axis, + scale_format=scale_format, ) x_dq_ref = convert_from_nvfp4_pytorch( data_lp, scales, output_dtype=data_type, block_size=block_size, axis=axis, + scale_format=scale_format, ) assert torch.equal(x_dq_ref, x_dq), ( f"Dequant mismatch for pattern={pattern}" @@ -216,7 +231,8 @@ def test_nvfp4_special_values(tensor_shape, axis, data_type, pattern): @pytest.mark.parametrize("is_2d_block", [False, True]) -def test_nvfp4_zero_tensor_with_outer_scale(is_2d_block): +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_quantization +def test_nvfp4_zero_tensor_with_outer_scale(is_2d_block, scale_format): """All-zero tensor on the outer-scale branch must round-trip to 0 with no NaN/Inf, and the effective per-block divisor ``inner_scale * outer_scale`` must stay in FP32 normal range. @@ -230,9 +246,11 @@ def test_nvfp4_zero_tensor_with_outer_scale(is_2d_block): x = prepare_data((128, 64), data_type, pattern="zeros") outer_scale_buf = torch.empty(1, dtype=torch.float32, device=x.device) + fmt_eps, _ = _SCALE_FORMAT_TABLE[scale_format] data_lp, scales = convert_to_nvfp4( x, block_size=block_size, axis=axis, is_2d_block=is_2d_block, outer_scale=outer_scale_buf, update_outer_scale=True, + scale_format=scale_format, ) expected_floor_fp32 = torch.tensor( @@ -249,13 +267,14 @@ def test_nvfp4_zero_tensor_with_outer_scale(is_2d_block): eff_min = scales.min().item() * outer_scale_value assert eff_min >= fp32_min_normal, ( f"Effective quant_scale ({eff_min:.6e}) fell into FP32 subnormal range; " - f"_OUTER_SCALE_DIVZERO_FLOOR * E4M3_EPS would be flushed to zero under FTZ." + f"_OUTER_SCALE_DIVZERO_FLOOR * {scale_format}_EPS ({fmt_eps:.6e}) would be " + f"flushed to zero under FTZ." ) x_dq = convert_from_nvfp4( data_lp, scales, output_dtype=data_type, block_size=block_size, axis=axis, is_2d_block=is_2d_block, - outer_scale=outer_scale_buf, + outer_scale=outer_scale_buf, scale_format=scale_format, ) assert torch.isfinite(scales).all(), "stored block scale has NaN/Inf" assert torch.isfinite(x_dq).all(), "outer_scale+zero dequant produced NaN/Inf" @@ -291,3 +310,57 @@ def test_nvfp4_non_aligned_m_no_nan_inf(data_type): ) assert torch.isfinite(scales).all(), "non-aligned M produced Inf/NaN scales" assert torch.isfinite(x_dq).all(), "non-aligned M produced Inf/NaN dequant output" + + +# --------------------------------------------------------------------------- +# D2' NaN-input sanitization regression +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("scale_format", ["e4m3"]) # UE5M3 covered by amdfp4/test_amdfp_quantization +@pytest.mark.parametrize("data_type", [torch.float32, torch.bfloat16]) +def test_nvfp4_quantization_nan_input_sanitized(scale_format, data_type): + """A NaN spike in the input MUST NOT propagate to inner scales / dequant. + + Under D2' the UE5M3 cast is spec-aligned: NaN input -> 0xFF NaN code. + The defense layer in ``_calculate_nvfp4_scales`` (Triton) and + ``_quantize_inner_scale`` (PyTorch oracle) sanitises NaN ``max_abs`` + before the cast, so the resulting inner-scale tensor is finite and + downstream GEMM never sees a NaN. The same guard is applied to the + E4M3 path symmetrically (``torch.float8_e4m3fn`` also has a 0xFF / + 0x7F NaN encoding). + + Mirrors the industry pattern (TransformerEngine "caller-side + sanitize", vLLM "zero-out padding", TRT-LLM "amax clamp"). + """ + block_size = 16 + x = prepare_data((128, 64), data_type) + # Inject a NaN spike into one block so the block_amax for that block + # would propagate NaN without the defense layer. + x_nan = x.clone() + x_nan[0, 7] = float("nan") + + data_lp, scales = convert_to_nvfp4( + x_nan, + block_size=block_size, + axis=-1, + is_2d_block=False, + update_outer_scale=False, + scale_format=scale_format, + ) + assert torch.isfinite(scales.float()).all(), ( + f"NaN input produced non-finite inner scales under {scale_format}; " + f"D2' defense layer should have sanitised the block amax." + ) + + x_dq = convert_from_nvfp4( + data_lp, scales, + output_dtype=data_type, + block_size=block_size, + axis=-1, + is_2d_block=False, + scale_format=scale_format, + ) + assert torch.isfinite(x_dq).all(), ( + f"NaN input contaminated dequant output under {scale_format}; " + f"defense layer must keep downstream GEMM input finite." + ) diff --git a/tests/unittest/nvfp4/utils.py b/tests/unittest/nvfp4/utils.py index 39c7192..b70b6ca 100644 --- a/tests/unittest/nvfp4/utils.py +++ b/tests/unittest/nvfp4/utils.py @@ -4,17 +4,55 @@ import torch from torch import Tensor -from alto.kernels.fp4.nvfp4.nvfp_quantization import BLOCK_SIZE_DEFAULT +from alto.kernels.fp4.nvfp4.nvfp_quantization import ( + BLOCK_SIZE_DEFAULT, + SUPPORTED_SCALE_FORMATS, + _SCALE_FORMAT_TABLE, +) # Re-exported so existing ``from .utils import calc_snr, calc_cossim`` # call-sites keep working; the single source of truth lives in # ``alto.kernels.fp4.testing_utils``. +from alto.kernels.fp4.fp4_primitives import ( + E4M3_EPS, + F8E4M3_MAX, + quantize_to_e4m3, + quantize_to_ue5m3, +) from alto.kernels.fp4.testing_utils import calc_snr, calc_cossim # noqa: F401 F4_E2M1_MAX = 6.0 -F8E4M3_MAX = 448.0 -E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny + + +def _quantize_inner_scale(inner_scale_raw: Tensor, scale_format: str) -> Tensor: + """Snap per-block inner scales to the selected inner grid (FP32 out). + + NaN defense: NaN / ±Inf in ``inner_scale_raw`` are sanitised to + ``[fmt_eps, fmt_max]`` before the cast. Under D2' the UE5M3 cast is + spec-aligned (0xFF reserved for NaN), so without this guard a NaN + input would emit a 0xFF code and contaminate downstream GEMM scales. + The same guard is applied symmetrically to the E4M3 path because + ``torch.float8_e4m3fn`` has the same NaN encoding (0xFF / 0x7F) and + PyTorch's ``Tensor.clamp`` does not replace NaN. + + This mirrors the industry pattern (TransformerEngine, vLLM, + TensorRT-LLM) where NaN is never allowed to propagate through a + quant cast into the GEMM input. + """ + if scale_format not in _SCALE_FORMAT_TABLE: + raise ValueError( + f"scale_format={scale_format!r} not supported; " + f"expected one of {SUPPORTED_SCALE_FORMATS}" + ) + fmt_eps, fmt_max = _SCALE_FORMAT_TABLE[scale_format] + inner_scale_raw = torch.nan_to_num( + inner_scale_raw, nan=fmt_eps, posinf=fmt_max, neginf=fmt_eps, + ) + clamped = inner_scale_raw.clamp(min=fmt_eps, max=fmt_max) + if scale_format == "ue5m3": + return quantize_to_ue5m3(clamped.float()) + return quantize_to_e4m3(clamped.float()) def prepare_data(tensor_shape, data_type, pattern="random"): @@ -24,11 +62,14 @@ def prepare_data(tensor_shape, data_type, pattern="random"): tensor_shape: Shape of the output tensor. data_type: Data type (torch.float32 or torch.bfloat16). pattern: Data pattern - - "random" : Gaussian with sparse outliers (default). - "zeros" : All zeros — exercises the E4M3 lower clamp on the - stored block scale. - "large" : All 5000.0 — exceeds F8E4M3_MAX * F4_E2M1_MAX (2688), - tests FP4 saturation and scale clamp to F8E4M3_MAX. + "random" : Gaussian with sparse outliers (default). + "zeros" : All zeros — inner-scale lower clamp + outer floor. + "large" : All 5000.0 — FP4 saturation + inner-scale upper clamp. + "hot_channel" : One column at amplitude 300, rest randn. + "lognormal" : exp(N(0, 1.5)) heavy-tail magnitudes. + "near_overflow" : All 1e4 — between E4M3 max and UE5M3 max. + "near_underflow" : All 1e-4 — between UE5M3_EPS and E4M3_EPS. + "single_spike" : One element ±5000, remainder zero. """ torch.manual_seed(1234) device = torch.device("cuda") @@ -41,6 +82,20 @@ def prepare_data(tensor_shape, data_type, pattern="random"): x = torch.zeros(tensor_shape, dtype=data_type, device=device) elif pattern == "large": x = torch.ones(tensor_shape, dtype=data_type, device=device) * 5000.0 + elif pattern == "hot_channel": + x = torch.randn(tensor_shape, dtype=data_type, device=device) + col = tensor_shape[-1] // 2 + x[..., col] = 300.0 + elif pattern == "lognormal": + x = torch.exp(torch.randn(tensor_shape, dtype=data_type, device=device) * 1.5) + elif pattern == "near_overflow": + x = torch.ones(tensor_shape, dtype=data_type, device=device) * 1.0e4 + elif pattern == "near_underflow": + x = torch.ones(tensor_shape, dtype=data_type, device=device) * 1.0e-4 + elif pattern == "single_spike": + x = torch.zeros(tensor_shape, dtype=data_type, device=device) + flat = x.reshape(-1) + flat[flat.numel() // 2] = 5000.0 else: raise ValueError(f"Unknown pattern: {pattern}") @@ -244,7 +299,12 @@ def convert_to_nvfp4_pytorch( scale_format: str = "e4m3", ): assert data_hp.dtype in [torch.float32, torch.bfloat16] - assert scale_format == "e4m3", f"scale_format={scale_format!r} not yet supported" + if scale_format not in SUPPORTED_SCALE_FORMATS: + raise ValueError( + f"scale_format={scale_format!r} not supported; " + f"expected one of {SUPPORTED_SCALE_FORMATS}" + ) + fmt_eps, fmt_max = _SCALE_FORMAT_TABLE[scale_format] data_hp = data_hp.transpose(axis, -1) ori_shape = data_hp.shape @@ -264,16 +324,18 @@ def convert_to_nvfp4_pytorch( max_abs = grouped.abs().amax(dim=-1) # NVFP4 spec order: outer_scale-normalise first, then derive the inner - # block scale, with clamp + E4M3 round applied exactly once on the + # block scale, with clamp + inner-grid round applied exactly once on the # final stored value. if outer_scale is not None: outer_scale = outer_scale.float().to(data_hp.device) - inner_scale_raw = (max_abs / outer_scale / F4_E2M1_MAX).clamp(min=E4M3_EPS, max=F8E4M3_MAX) - inner_scale = inner_scale_raw.to(torch.float8_e4m3fn).to(torch.float32) + inner_scale_raw = (max_abs / outer_scale / F4_E2M1_MAX).clamp( + min=fmt_eps, max=fmt_max, + ) + inner_scale = _quantize_inner_scale(inner_scale_raw, scale_format) quant_scale = inner_scale * outer_scale else: - inner_scale_raw = (max_abs / F4_E2M1_MAX).clamp(min=E4M3_EPS, max=F8E4M3_MAX) - inner_scale = inner_scale_raw.to(torch.float8_e4m3fn).to(torch.float32) + inner_scale_raw = (max_abs / F4_E2M1_MAX).clamp(min=fmt_eps, max=fmt_max) + inner_scale = _quantize_inner_scale(inner_scale_raw, scale_format) quant_scale = inner_scale if is_2d_block: