kernels: nvfp4: make scale constants tl.constexpr for triton>=3.6#29
Open
zhitwang17 wants to merge 1 commit into
Open
kernels: nvfp4: make scale constants tl.constexpr for triton>=3.6#29zhitwang17 wants to merge 1 commit into
zhitwang17 wants to merge 1 commit into
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
NVFP4 quantization kernels failed to compile on triton >= 3.6 with
NameError: Cannot access global variable F4_E2M1_MAX from within @jit'ed function.triton >= 3.6 forbids
@triton.jitkernels from reading plain (non-constexpr)module globals, which broke the module-level scale constants
(
F4_E2M1_MAX,F8E4M3_MAX,E4M3_EPS) used inside_calculate_nvfp4_scales.Previously this required the
TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1env var as aworkaround.
Fix
Split the constants into two clearly-scoped definitions instead of overloading a
single name:
F4_E2M1_MAX/F8E4M3_MAX/E4M3_EPSstay plain Pythonfloats, so eager/host code (e.g.
compute_dynamic_outer_scale) uses themdirectly with no accessor.
_F4_E2M1_MAX/_F8E4M3_MAX/_E4M3_EPSaretl.constexprmirrors, referenced only inside@triton.jitkernels.This removes the need for
TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1and eliminates aclass of bugs where host code touching a
tl.constexprobject would raise anAttributeError(it avoids the fragile.value-everywhere pattern).No behavioral change: the inlined constant values are identical, and the
generated AMDGCN is byte-for-byte equal (only DWARF debug line numbers shift),
so there is no performance impact.
Test status
test_decomposed_linear.py: 512 passed,run without
TRITON_ALLOW_NON_CONSTEXPR_GLOBALS.floats and work directlyin arithmetic /
.clamp(...)calls.