Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def nvfp4_grouped_gemm(
[M_total, N] output tensor.
"""
if not trans_weights:
# User/dispatch stored weights as [E, K, N]; materialize the canonical
# [E, N, K] layout exactly once, here at the wrapper boundary.
expert_weights = expert_weights.transpose(-2, -1).contiguous()
# Downstream grouped kernels and _qdq both consume strided
# expert_weights, so no .contiguous() needed.
expert_weights = expert_weights.transpose(-2, -1)
return _nvfp4_grouped_gemm_impl(
inputs,
expert_weights,
Expand Down Expand Up @@ -146,7 +146,8 @@ def _quantize_then_nvfp4_scaled_grouped_mm(
in dispatch layout ``[E, K, N]`` and is transposed once to the canonical
``[E, N, K]`` layout shared by the autograd function and its primitives.
"""
B_canonical = B.transpose(-2, -1).contiguous()
# Zero-copy stride view; see nvfp4_grouped_gemm for rationale.
B_canonical = B.transpose(-2, -1)
return _nvfp4_grouped_gemm_impl(
A,
B_canonical,
Expand Down
48 changes: 38 additions & 10 deletions alto/kernels/fp4/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,52 @@ def calc_cossim(x: torch.Tensor, y: torch.Tensor) -> float:
return (torch.dot(x_flat, y_flat) / (x_flat.norm() * y_flat.norm())).item()


# NVFP4 autograd SNR floors are empirical, not derived from a closed form.
# Calibrated on MI355X by running the full op-level autograd matrix, then
# setting each (kind, outer, K bucket, SR) tier to ~0.5–2 dB below observed
# per-tensor mins and median(O,dX,dW), with pytest pass/fail tweaks.
#
# Partitioning: K>=1024 stress (SR noise ~sqrt(K)) gets lower hard floors;
# use_outer_scale=True (production) is much tighter than inner-only; grouped
# is slightly looser than linear on the same bucket.
#
# Used by check_nvfp4_autograd_snr: hard_floor guards every O/dX/dW; aggregate
# floor guards median/mean so a low large-K SR hard floor does not weaken all cases.


def _nvfp4_autograd_snr_thresholds(
*,
K: int,
use_sr_grad: bool,
kind: str,
use_outer_scale: bool = False,
) -> tuple[float, float]:
"""Return ``(hard_floor, aggregate_floor)`` for NVFP4 autograd SNR tests.

The floors are intentionally K-aware. Large reduction dimensions amplify
stochastic-rounding noise roughly as ``sqrt(K)``, so production-scale
stress cases such as K=2048 need a lower per-tensor hard floor while small
and medium cases should keep much tighter checks.
"""
"""Return ``(hard_floor, aggregate_floor)`` for NVFP4 autograd SNR tests."""
if kind == "nvfp4_linear":
if use_outer_scale:
# Production default (outer + block scales). Tighter than the
# historical inner-only large-K SR floor (3 dB) but with margin for
# 2D-block + SR stress (measured dX/dW can sit in the high single
# digits on grouped K=2048 + SR).
if K >= 1024:
return (9.0, 11.0) if use_sr_grad else (12.0, 14.0)
if K >= 256:
return (10.0, 12.0) if use_sr_grad else (14.0, 17.0)
return (12.0, 14.0) if use_sr_grad else (16.0, 17.5)
if K >= 1024:
return (3.0, 4.5) if use_sr_grad else (4.0, 8.0)
if K >= 256:
return (10.0, 12.0) if use_sr_grad else (12.0, 14.0)
return (12.0, 14.0) if use_sr_grad else (14.0, 15.0)

if kind == "nvfp4_grouped_gemm":
if use_outer_scale:
# Grouped paths (especially 2D-block + large G) trail linear outer
# SNR slightly on dX/dW; keep ~1 dB margin below observed mins.
if K >= 1024:
# 2D-x+w at K=2048 can pull median dX/dW to ~12.5 dB (still >> inner-only).
return (9.0, 10.0) if use_sr_grad else (12.0, 12.5)
return (10.0, 12.0) if use_sr_grad else (13.0, 15.5)
if K >= 1024:
return (4.0, 6.0) if use_sr_grad else (7.0, 10.0)
return (7.0, 10.0) if use_sr_grad else (10.0, 12.0)
Expand All @@ -84,6 +109,7 @@ def check_nvfp4_autograd_snr(
K: int,
use_sr_grad: bool,
kind: str,
use_outer_scale: bool = False,
context: str = "",
) -> None:
"""Validate NVFP4 autograd SNR with per-tensor and aggregate checks.
Expand All @@ -103,14 +129,16 @@ def check_nvfp4_autograd_snr(
K=K,
use_sr_grad=use_sr_grad,
kind=kind,
use_outer_scale=use_outer_scale,
)
ctx = f"{context}: " if context else ""
outer_tag = ", use_outer_scale=True" if use_outer_scale else ""

Comment on lines 134 to 136
for name, value in snrs.items():
assert value > hard_floor, (
f"{ctx}{name} SNR too low: {value:.2f} dB "
f"< hard_floor={hard_floor:.2f} dB "
f"(K={K}, use_sr_grad={use_sr_grad}, kind={kind}). "
f"(K={K}, use_sr_grad={use_sr_grad}, kind={kind}{outer_tag}). "
"This likely indicates a real regression in quantization, "
"axis alignment, scale handling, or gradient propagation."
)
Expand All @@ -121,10 +149,10 @@ def check_nvfp4_autograd_snr(
assert median_snr > aggregate_floor, (
f"{ctx}median SNR too low: {median_snr:.2f} dB "
f"< aggregate_floor={aggregate_floor:.2f} dB "
f"(values={snrs}, K={K}, use_sr_grad={use_sr_grad}, kind={kind})."
f"(values={snrs}, K={K}, use_sr_grad={use_sr_grad}, kind={kind}{outer_tag})."
)
assert mean_snr > aggregate_floor, (
f"{ctx}mean SNR too low: {mean_snr:.2f} dB "
f"< aggregate_floor={aggregate_floor:.2f} dB "
f"(values={snrs}, K={K}, use_sr_grad={use_sr_grad}, kind={kind})."
f"(values={snrs}, K={K}, use_sr_grad={use_sr_grad}, kind={kind}{outer_tag})."
)
26 changes: 11 additions & 15 deletions tests/unittest/nvfp4/test_nvfp_grouped_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ def _bf16_grouped_ref_forward(
@pytest.mark.parametrize("use_2dblock_x", [False, True])
@pytest.mark.parametrize("use_2dblock_w", [False, True])
@pytest.mark.parametrize("use_sr_grad", [False, True])
@pytest.mark.parametrize("use_outer_scale", [False, True])
@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float32])
Comment on lines 92 to 96
def test_nvfp4_grouped_gemm_autograd(
shape, use_2dblock_x, use_2dblock_w, use_sr_grad, data_type
shape, use_2dblock_x, use_2dblock_w, use_sr_grad, use_outer_scale, data_type,
):
"""Output, dX, and dW SNR vs BF16 autograd reference must remain healthy."""
M_total, N, K, num_experts = shape
Expand Down Expand Up @@ -129,6 +130,7 @@ def test_nvfp4_grouped_gemm_autograd(
use_2dblock_x=use_2dblock_x,
use_2dblock_w=use_2dblock_w,
use_sr_grad=use_sr_grad,
use_outer_scale=use_outer_scale,
)
loss = torch.nn.functional.mse_loss(y, target)
loss.backward()
Expand Down Expand Up @@ -156,9 +158,10 @@ def test_nvfp4_grouped_gemm_autograd(
K=K,
use_sr_grad=use_sr_grad,
kind="nvfp4_grouped_gemm",
use_outer_scale=use_outer_scale,
context=(
f"NVFP4GroupedGEMM shape={shape} dtype={data_type} "
f"x_2d={use_2dblock_x} w_2d={use_2dblock_w}"
f"x_2d={use_2dblock_x} w_2d={use_2dblock_w} outer={use_outer_scale}"
),
)

Expand All @@ -182,8 +185,9 @@ def test_nvfp4_grouped_gemm_autograd(
((1024, 512, 512, 8), True, False, False),
((1024, 512, 512, 8), False, False, True),
])
@pytest.mark.parametrize("use_outer_scale", [False, True])
def test_nvfp4_grouped_gemm_forward_compares_with_mxfp4(
shape, trans_weights, use_2dblock_x, use_2dblock_w,
shape, trans_weights, use_2dblock_x, use_2dblock_w, use_outer_scale,
):
"""NVFP4 and MXFP4 grouped forward paths should both track BF16 reference.

Expand Down Expand Up @@ -219,10 +223,9 @@ def test_nvfp4_grouped_gemm_forward_compares_with_mxfp4(
trans_weights=trans_weights,
)

# Use analogous minimal grouped recipes for both format families. In
# particular, do not enable MXFP4 clipping or macro-block scaling here:
# those are valuable recipe knobs, but this test is only a cross-format
# smoke/reference check for the common routed-expert shape.
# Minimal grouped recipes for both formats. Intentional asymmetry: NVFP4
# sweeps ``use_outer_scale`` (two-level FP32 scaling path); MXFP4 stays
# fixed (no clipping, no macro-block). Cross-format smoke/reference check.
y_nv = nvfp4_grouped_gemm(
inputs,
expert_weights,
Expand All @@ -231,6 +234,7 @@ def test_nvfp4_grouped_gemm_forward_compares_with_mxfp4(
use_2dblock_x=use_2dblock_x,
use_2dblock_w=use_2dblock_w,
use_sr_grad=False,
use_outer_scale=use_outer_scale,
)
y_mx = mxfp4_grouped_gemm(
inputs,
Expand All @@ -248,14 +252,6 @@ def test_nvfp4_grouped_gemm_forward_compares_with_mxfp4(

nv_snr = calc_snr(y_ref, y_nv)
mx_snr = calc_snr(y_ref, y_mx)
print()
print(tabulate(
[
["NVFP4", f"{nv_snr:.2f}", f"{calc_cossim(y_ref, y_nv):.6f}"],
["MXFP4", f"{mx_snr:.2f}", f"{calc_cossim(y_ref, y_mx):.6f}"],
],
headers=["Format", "SNR", "CosSim"], tablefmt="github",
))

assert torch.isfinite(y_nv).all()
assert torch.isfinite(y_mx).all()
Expand Down
37 changes: 11 additions & 26 deletions tests/unittest/nvfp4/test_nvfp_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ def test_nvfp4_qdq_roundtrip(
@pytest.mark.parametrize("use_2dblock_x", [False, True])
@pytest.mark.parametrize("use_2dblock_w", [False, True])
@pytest.mark.parametrize("use_sr_grad", [False, True])
@pytest.mark.parametrize("use_outer_scale", [False, True])
@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float32])
def test_nvfp4_linear_autograd_function(
shape, use_2dblock_x, use_2dblock_w, use_sr_grad, data_type,
shape, use_2dblock_x, use_2dblock_w, use_sr_grad, use_outer_scale, data_type,
):
"""Forward + dX + dW must match a BF16 nn.Linear reference in SNR.

Expand All @@ -131,17 +132,10 @@ def test_nvfp4_linear_autograd_function(
grad_weights_ref = weights.grad.clone()
inputs.grad.zero_(); weights.grad.zero_()

# NVFP4 QDQ path. This is NOT a shared MXFP4/NVFP4 test; it is specific
# to NVFP4's autograd function. We still keep
# ``use_outer_scale=False`` here on purpose so the parity check stays
# focused on the shared 6-QDQ linear math (forward + dX + dW). The
# outer-level scale path is exercised separately in the QDQ round-trip
# and quantization tests, where it can be isolated without broadening
# this matrix or coupling the SNR thresholds to an NVFP4-only knob.
outputs = NVFP4LinearFunction.apply(
inputs, weights,
use_2dblock_x, use_2dblock_w, use_sr_grad,
False, # use_outer_scale
use_outer_scale,
)
loss = torch.nn.functional.mse_loss(outputs, target)
loss.backward()
Expand All @@ -168,21 +162,22 @@ def test_nvfp4_linear_autograd_function(
K=K,
use_sr_grad=use_sr_grad,
kind="nvfp4_linear",
use_outer_scale=use_outer_scale,
context=(
f"NVFP4Linear shape={shape} dtype={data_type} "
f"x_2d={use_2dblock_x} w_2d={use_2dblock_w}"
f"x_2d={use_2dblock_x} w_2d={use_2dblock_w} outer={use_outer_scale}"
),
)


@pytest.mark.parametrize("shape,use_2dblock_x,use_2dblock_w", [
# Shared with both NVFP4 and MXFP4 linear-autograd test matrices.
((1, 512, 384, 128), False, False),
((1, 512, 384, 128), False, True),
((4, 1024, 1024, 2048), False, False),
])
@pytest.mark.parametrize("use_outer_scale", [False, True])
def test_nvfp4_linear_forward_compares_with_mxfp4_on_shared_cases(
shape, use_2dblock_x, use_2dblock_w,
shape, use_2dblock_x, use_2dblock_w, use_outer_scale,
):
"""NVFP4 and MXFP4 forward paths should both stay close to BF16 reference.

Expand All @@ -200,18 +195,16 @@ def test_nvfp4_linear_forward_compares_with_mxfp4_on_shared_cases(
# "NVFP4 vs BF16" and "MXFP4 vs BF16", not "NVFP4 vs MXFP4".
y_ref = torch.nn.functional.linear(x, w)

# Keep both recipes deliberately minimal and analogous: 1D block scales,
# deterministic rounding on the forward path, no clipping, no Hadamard,
# no DGE, no macro/outer scaling. This keeps the test focused on whether
# each format family can run the same dense-linear shape and produce a
# finite output with reasonable forward error.
# Minimal forward config for both formats (1D blocks, RNE, no clip/Hadamard/
# DGE). Intentional asymmetry: NVFP4 sweeps ``use_outer_scale`` to exercise
# its two-level FP32 scaling path; MXFP4 stays fixed (no macro-block).
y_nv = _to_nvfp4_then_scaled_mm(
x,
w,
use_2dblock_x=use_2dblock_x,
use_2dblock_w=use_2dblock_w,
use_sr_grad=False,
use_outer_scale=False,
use_outer_scale=use_outer_scale,
)
y_mx = _to_mxfp4_then_scaled_mm(
x,
Expand All @@ -227,14 +220,6 @@ def test_nvfp4_linear_forward_compares_with_mxfp4_on_shared_cases(

nv_snr = calc_snr(y_ref, y_nv)
mx_snr = calc_snr(y_ref, y_mx)
print()
print(tabulate(
[
["NVFP4", f"{nv_snr:.2f}", f"{calc_cossim(y_ref, y_nv):.6f}"],
["MXFP4", f"{mx_snr:.2f}", f"{calc_cossim(y_ref, y_mx):.6f}"],
],
headers=["Format", "SNR", "Cosine Sim"], tablefmt="github",
))

assert torch.isfinite(y_nv).all()
assert torch.isfinite(y_mx).all()
Expand Down