merge recent changes in NVFP4 #26

Merged: hann-wang merged 6 commits into main from dev on Jun 4, 2026
@hann-wang (Collaborator)

No description provided.

Copilot AI review requested due to automatic review settings June 4, 2026 01:47
@hann-wang hann-wang merged commit 6773d13 into main Jun 4, 2026

Copilot AI left a comment

Pull request overview

This PR updates NVFP4 unit tests and test utilities to cover the two-level (outer/global) scaling path, and reduces unnecessary tensor materialization in the NVFP4 grouped GEMM wrapper by avoiding .contiguous() after a transpose.

Changes:

  • Expand NVFP4 linear and grouped GEMM test matrices to parametrize use_outer_scale and propagate that flag into SNR checks and test context strings.
  • Extend NVFP4 autograd SNR threshold logic to be aware of use_outer_scale, with new calibrated threshold tiers.
  • Avoid .contiguous() when transposing expert weights in NVFP4 grouped GEMM entrypoints (use strided views instead).
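The last bullet relies on the fact that a transpose in PyTorch is a stride swap rather than a copy. A minimal sketch of that behavior follows; the tensor name and shape are hypothetical stand-ins for stacked expert weights, and nothing here is taken from the repository's `functional.py`.

```python
import torch

# Hypothetical stand-in for stacked expert weights:
# (num_experts, out_features, in_features). Not the repository's real layout.
weights = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)

# Transposing the last two dimensions only swaps strides; no data is moved.
view = weights.transpose(-2, -1)
shares_storage = view.data_ptr() == weights.data_ptr()
view_is_contiguous = view.is_contiguous()

# .contiguous() materializes a new row-major tensor, i.e. an extra allocation and copy.
copy = view.contiguous()
copy_shares_storage = copy.data_ptr() == weights.data_ptr()

# Both hold the same values, so a consumer that honours strides can take the view directly.
same_values = torch.equal(view, copy)

print(view.stride(), copy.stride())
```

Dropping the copy is only safe when the downstream kernel reads its operand through strides, which is what the PR overview describes for the grouped GEMM entrypoints.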

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| tests/unittest/nvfp4/test_nvfp_linear.py | Adds use_outer_scale coverage to NVFP4 linear autograd and cross-format forward tests. |
| tests/unittest/nvfp4/test_nvfp_grouped_gemm.py | Adds use_outer_scale coverage to NVFP4 grouped GEMM autograd and cross-format forward tests. |
| alto/kernels/fp4/testing_utils.py | Makes NVFP4 autograd SNR thresholds depend on use_outer_scale and threads the flag through error messages. |
| alto/kernels/fp4/nvfp4/nvfp_grouped_gemm/functional.py | Removes .contiguous() after transposing expert weights (keeps zero-copy strided view). |

Comment on lines 134 to 136
ctx = f"{context}: " if context else ""
outer_tag = ", use_outer_scale=True" if use_outer_scale else ""
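
To illustrate how the two strings above could feed into a threshold check, here is a rough sketch of an SNR assertion that picks its threshold from `use_outer_scale`. The helper names `snr_db` and `check_snr` and the threshold values are assumptions made for illustration; the PR's calibrated tiers are not shown in this conversation.

```python
import math

# Hypothetical thresholds in dB, keyed by use_outer_scale.
# The calibrated tiers from the PR are NOT reproduced here.
SNR_THRESHOLD_DB = {False: 20.0, True: 18.0}


def snr_db(reference, actual):
    """Signal-to-noise ratio in dB between a reference and a test sequence."""
    signal = sum(r * r for r in reference)
    noise = sum((r - a) ** 2 for r, a in zip(reference, actual))
    if noise == 0.0:
        return math.inf
    if signal == 0.0:
        return -math.inf
    return 10.0 * math.log10(signal / noise)


def check_snr(reference, actual, use_outer_scale=False, context=""):
    """Raise AssertionError if the SNR falls below the tier for this scaling mode."""
    ctx = f"{context}: " if context else ""
    outer_tag = ", use_outer_scale=True" if use_outer_scale else ""
    threshold = SNR_THRESHOLD_DB[use_outer_scale]
    value = snr_db(reference, actual)
    if value < threshold:
        raise AssertionError(
            f"{ctx}SNR {value:.2f} dB below threshold {threshold:.2f} dB{outer_tag}"
        )
    return value
```

With this shape, a failure under the two-level path is identifiable from the message alone, since the tag is appended only when the flag is set.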

Comment on lines 106 to +110
@pytest.mark.parametrize("shape", [(1, 64, 64, 64), (1, 512, 384, 128), (4, 1024, 1024, 2048)])
@pytest.mark.parametrize("use_2dblock_x", [False, True])
@pytest.mark.parametrize("use_2dblock_w", [False, True])
@pytest.mark.parametrize("use_sr_grad", [False, True])
@pytest.mark.parametrize("use_outer_scale", [False, True])
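
Stacked `parametrize` decorators multiply, so each new boolean flag doubles the number of generated cases. The sketch below counts only the five decorators visible in this snippet; the test may carry further decorators outside the quoted lines, so the totals are a lower bound under that assumption.

```python
from itertools import product

# Parameter lists copied from the decorators quoted above.
params = {
    "shape": [(1, 64, 64, 64), (1, 512, 384, 128), (4, 1024, 1024, 2048)],
    "use_2dblock_x": [False, True],
    "use_2dblock_w": [False, True],
    "use_sr_grad": [False, True],
    "use_outer_scale": [False, True],
}


def count_cases(param_lists):
    """Number of combinations that stacked parametrize decorators would generate."""
    return len(list(product(*param_lists.values())))


without_flag = {k: v for k, v in params.items() if k != "use_outer_scale"}
cases_before = count_cases(without_flag)
cases_after = count_cases(params)

print(cases_before, cases_after)
```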
Comment on lines 92 to 96
@pytest.mark.parametrize("use_2dblock_x", [False, True])
@pytest.mark.parametrize("use_2dblock_w", [False, True])
@pytest.mark.parametrize("use_sr_grad", [False, True])
@pytest.mark.parametrize("use_outer_scale", [False, True])
@pytest.mark.parametrize("data_type", [torch.bfloat16, torch.float32])
3 participants