
[Pallas] Skip trivial reduction mask when RDIM size equals actual dim #1993

Merged
norx1991 merged 1 commit into main from yifeixu/pallas-skip-trivial-mask
Apr 13, 2026

norx1991 commented Apr 9, 2026

Summary

  • When the Pallas backend sets RDIM_SIZE equal to the actual dimension size (no rounding), the reduction mask is always true and can be skipped entirely
  • Previously, the mask was skipped only for power-of-2 dimensions. That was sufficient for Triton, which rounds RDIM up to the next power of 2, but Pallas uses exact RDIM sizes, so for non-power-of-2 dims the mask was always generated even though it was never needed
  • This change generalizes the check by comparing backend.static_rdim_size(numel) against numel, which naturally handles both backends
  • Guards numel > 0 to preserve correct behavior for zero-size reductions (static_rdim_size(0) returns 1 via next_power_of_2, which would incorrectly skip the mask)
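
For readers unfamiliar with the check, here is a minimal standalone sketch of the rule described above. It is not the code from this PR: `TritonLikeBackend`, `ExactBackend`, and `can_skip_reduction_mask` are hypothetical names, and the two `static_rdim_size` implementations only model the rounding behavior stated in the summary (next power of 2 versus exact size).

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of 2 that is >= n (1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


class TritonLikeBackend:
    """Hypothetical stand-in: rounds RDIM up to the next power of 2."""

    def static_rdim_size(self, numel: int) -> int:
        return next_power_of_2(numel)


class ExactBackend:
    """Hypothetical stand-in for a backend that uses the exact dim as RDIM size."""

    def static_rdim_size(self, numel: int) -> int:
        return numel


def can_skip_reduction_mask(backend, numel: int) -> bool:
    # The mask `indices < numel` is all-true only when RDIM has no padding.
    # The `numel > 0` guard keeps the mask for zero-size reductions,
    # whatever the backend returns for 0.
    return numel > 0 and backend.static_rdim_size(numel) == numel


for name, backend in [("triton-like", TritonLikeBackend()), ("exact", ExactBackend())]:
    for numel in (0, 1000, 1024):
        print(name, numel, can_skip_reduction_mask(backend, numel))
```

Under these assumptions, a dim of 1000 keeps its mask on the rounding backend (RDIM is 1024) but drops it on the exact backend, while 1024 drops it on both.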

Before (unnecessary mask for Pallas on non-power-of-2 dim like 1000):

def _helion_pallas_reduce_non_pow2(x, out, _RDIM_SIZE_1: int):
    indices_1 = jnp.arange(0, _RDIM_SIZE_1, dtype=jnp.int32)
    mask_1 = indices_1 < 1000
    row = x[:, :]
    _mask_to = jnp.where(jnp.broadcast_to(mask_1[None, :], ...),
                         row, jnp.full([], float('-inf'), jnp.float32))
    max_val = ...jnp.max(_mask_to, axis=1)...
    v_1 = jnp.exp(v_0)
    _mask_to_1 = jnp.where(jnp.broadcast_to(mask_1[None, :], ...),
                           v_1, jnp.full([], 0, jnp.float32))
    sum_1 = ...jnp.sum(_mask_to_1, axis=1)...

After (no mask; clean and correct):

def _helion_pallas_reduce_non_pow2(x, out):
    row = x[:, :]
    max_val = ...jnp.max(row, axis=1)...
    v_1 = jnp.exp(v_0)
    sum_1 = ...jnp.sum(v_1, axis=1)...
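
Why dropping the mask is safe can be checked with a plain-Python analogue of the masked max and masked sum above. This is only an illustration, not the generated kernel: `masked_max_and_sum` is a hypothetical helper, and the padding value `0.0` stands in for whatever data happens to sit in the padded lanes.

```python
import math


def masked_max_and_sum(row, rdim_size):
    """Masked max and sum(exp(x - max)) over one row padded to rdim_size lanes."""
    n = len(row)
    padded = row + [0.0] * (rdim_size - n)      # padded lanes hold arbitrary data
    mask = [i < n for i in range(rdim_size)]    # analogue of `indices < numel`
    # Masked-out lanes contribute -inf to the max and 0 to the sum.
    max_val = max(x if m else -math.inf for x, m in zip(padded, mask))
    total = sum(math.exp(x - max_val) if m else 0.0 for x, m in zip(padded, mask))
    return mask, max_val, total


row = [-1.0, -3.0, -2.0]

# Exact RDIM size: the mask is all-true, so the unmasked reductions agree.
mask, max_val, total = masked_max_and_sum(row, len(row))
plain_max = max(row)
plain_sum = sum(math.exp(x - plain_max) for x in row)
print(all(mask), max_val == plain_max, math.isclose(total, plain_sum))

# Rounded-up RDIM size: the mask is needed, otherwise the 0.0 pad would win the max.
mask4, max4, total4 = masked_max_and_sum(row, 4)
print(mask4, max4)
```

With the exact size the mask contributes nothing, which is the case this change optimizes away. With a rounded-up size the padded lane would otherwise corrupt the maximum.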

meta-cla bot added the CLA Signed label Apr 9, 2026
norx1991 force-pushed the yifeixu/pallas-skip-trivial-mask branch 11 times, most recently from ad8be77 to 9b9ac7e April 9, 2026 22:13
norx1991 marked this pull request as ready for review April 9, 2026 22:44
norx1991 force-pushed the yifeixu/pallas-skip-trivial-mask branch from 9b9ac7e to 292b3c5 April 13, 2026 17:16
norx1991 merged commit debd2d8 into main Apr 13, 2026
22 checks passed