Skip to content

kernels: nvfp4: support decomposed linear (LoRA outlier compensation)#24

Draft
zhitwang17 wants to merge 2 commits into
mainfrom
zhitao/nvfp4-support-decompose-linear
Draft

kernels: nvfp4: support decomposed linear (LoRA outlier compensation)#24
zhitwang17 wants to merge 2 commits into
mainfrom
zhitao/nvfp4-support-decompose-linear

Conversation

@zhitwang17

Copy link
Copy Markdown
Collaborator

Support DecomposedLinear (LoRA outlier compensation) under NVFP4

Summary

Enables DecomposedLinear low-rank outlier compensation for NVFP4 low-precision
training. Quantizing weights to FP4 corrupts a small set of outlier values; a low-rank
U·Σ·V term (LoRA-style) recovers that lost signal at negligible cost. This path already
existed for MXFP4 but was never exercised end-to-end under NVFP4 or under TorchTitan's
meta-device model construction, where it crashed during conversion.

The dispatch/quantization wrappers are scheme-agnostic, so no kernel changes are required;
the fixes are confined to DecomposedLinear construction and its test coverage.

Changes

  • Meta-device-safe from_linear: allocate u/v/sigma on the source weight's
    device/dtype (incl. meta), fixing a set_data crash during TorchTitan's meta-device
    conversion.
  • lora_rank validation: require a positive multiple of 16 (the FP4 GEMM block size)
    at construction, replacing an opaque deep-kernel torch._check failure with a fast,
    actionable error.
  • Tests: parametrize the suite over NVFP4 + MXFP4, add a meta-device build/swap test,
    and trim the CUDA SNR matrix to [128, 512] for cheaper CI.

Op-level test results

All green; no regressions in the shared dispatch / quantization paths.

Suite Result
tests/unittest/nn/test_decomposed_linear.py (incl. CUDA SNR, NVFP4 + MXFP4) 94 passed
nvfp4/ + mxfp4/ op-level suites (dispatch guards, linear, quantization) 526 passed, 400 skipped

E2E validation (debug model)

Three-way 10,000-step comparison on gpt-oss-debug (~82.8M total / 47.3M active, MoE),
GBS=16, 2 × MI300X. NVFP4 recipes are identical except for the compensation knob
(lora_rank 0 vs 32). All runs were stable: monotonic loss, healthy grad norms, no NaNs.

Recipe Final train loss (step 10k) Δ vs BF16 Δ vs NVFP4 rank0
BF16 baseline 0.0727
NVFP4, compensation off (rank0) 0.0901 +0.0174
NVFP4 + DecomposedLinear (rank32) 0.0833 +0.0106 −0.0068
image

DecomposedLinear closes ~39% of the NVFP4↔BF16 final-loss gap, confirming the
compensation is both correct and effective under NVFP4.

@zhitwang17 zhitwang17 self-assigned this Jun 3, 2026
@zhitwang17 zhitwang17 changed the title kernels: nvfp4: support decomposed linear kernels: nvfp4: support decomposed linear (LoRA outlier compensation) Jun 3, 2026
@zhitwang17 zhitwang17 force-pushed the zhitao/nvfp4-support-decompose-linear branch from 5be4d63 to 4522237 Compare June 12, 2026 06:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant