
Commit f2a82ba

[None][perf] triton paged attention: non-pow2 head_dim, decode speedup, logit cap
Several performance improvements to the Triton paged attention kernel:

1. HEAD_DIM_PADDED support for non-power-of-2 head dimensions (e.g. Gemma 4's head_dim=176, processed as padded 128 + 64 chunks, and head_dim=512). The decode kernel handles the non-pow2 head_dim by storing partial results in a HEAD_DIM_PADDED staging area and gathering them at stage-2 combine time, eliminating zero-padding overhead.
2. Two-chunk gather kernel for non-pow2 head_dim: processes the two chunks in separate vectorised passes to improve memory bandwidth utilisation.
3. WRITE_DIRECT optimisation: when num_splits == 1, the stage-2 combine is short-circuited and the single-split result is written directly to the output tensor, saving one kernel launch per decode step.
4. KV offset hoisting: loop-invariant K/V tensor offsets and block pointers are hoisted out of the inner attention loop, giving a 9-22% decode speedup in H100/B200 benchmarks.
5. Sliding-window-aware split computation: the number of KV splits is capped by the sliding window size rather than the full sequence length, reducing wasted compute for sliding-window attention models (a host-side sketch follows this list).
6. Logit soft-cap support via a logit_cap parameter for models that apply logit capping (e.g. Gemma 4 with logit_softcapping=30.0).
7. SLIDING_WINDOW added to the autotune key; the num_stages/num_warps search space is extended for better occupancy on newer hardware.
8. Always-Triton threshold: _TRITON_T_THRESHOLD is raised so Triton paged attention is always used regardless of batch size, avoiding extra FlashInfer dispatch overhead in piecewise CUDA graph workloads.

Test: updated test_triton_paged_attention.py to cover the new configurations.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
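For items 3 and 5, the sketch below illustrates host-side split planning under stated assumptions; it is not the commit's actual code. The names plan_decode_splits and needs_stage2_combine, and the SPLIT_SIZE / MAX_NUM_SPLITS values, are hypothetical stand-ins for whatever the kernel launcher and autotuner actually use.

```python
import math

# Illustrative constants; the real kernel derives these via autotuning.
SPLIT_SIZE = 256        # KV tokens handled per split (assumed value)
MAX_NUM_SPLITS = 16     # cap on parallel splits (assumed value)


def plan_decode_splits(seq_len: int, sliding_window: int | None) -> int:
    """Sliding-window-aware split count (item 5, sketch).

    Only the last `sliding_window` KV tokens can contribute to the output,
    so the split count is derived from the effective KV length rather than
    the full sequence length.
    """
    effective_len = seq_len if sliding_window is None else min(seq_len, sliding_window)
    return max(1, min(math.ceil(effective_len / SPLIT_SIZE), MAX_NUM_SPLITS))


def needs_stage2_combine(seq_len: int, sliding_window: int | None) -> bool:
    """WRITE_DIRECT short-circuit (item 3, sketch).

    With a single split, stage 1 can write its result directly to the
    output tensor and the stage-2 combine launch is skipped entirely.
    """
    return plan_decode_splits(seq_len, sliding_window) > 1


if __name__ == "__main__":
    # 8k context with a 1k sliding window: splits are capped by the window,
    # not by the full sequence length.
    print(plan_decode_splits(seq_len=8192, sliding_window=1024))   # 4
    print(needs_stage2_combine(seq_len=200, sliding_window=None))  # False
```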
1 parent 2fe39c1 commit f2a82ba

2 files changed

Lines changed: 1230 additions & 230 deletions
