Commit f2a82ba
[None][perf] triton paged attention: non-pow2 head_dim, decode speedup, logit cap
Several performance improvements to the Triton paged attention kernel:
1. HEAD_DIM_PADDED support for non-power-of-2 head dimensions (e.g. Gemma 4's
head_dim=176) and for large head dimensions such as 512. The decode kernel handles
a non-pow2 head_dim by storing partial results in a HEAD_DIM_PADDED staging area and
gathering them at stage-2 combine time, eliminating zero-padding overhead.
(Illustrative sketches for several of these items follow after the list.)
2. Two-chunk gather kernel for non-pow2 head_dim: processes the two chunks
in separate vectorised passes to improve memory bandwidth utilisation.
3. WRITE_DIRECT optimisation: when num_splits==1 the stage-2 combine is
short-circuited and the single-split result is written directly to the
output tensor, saving one kernel launch per decode step.
4. KV offset hoisting: loop-invariant K/V tensor offsets and block pointers
are hoisted out of the inner attention loop — 9-22% decode speedup on
H100/B200 benchmarks.
5. Sliding-window-aware split computation: the number of KV splits is capped
by the sliding window size rather than the full sequence length, reducing
wasted compute for SW attention models.
6. Logit soft-cap support via a logit_cap parameter for models that apply
logit capping (e.g. Gemma 4 with logit_softcapping=30.0).
7. SLIDING_WINDOW added to autotune key; num_stages/num_warps search space
extended for better occupancy on newer hardware.
8. Always-Triton threshold: raised _TRITON_T_THRESHOLD so Triton paged
attention is always used regardless of batch size, avoiding extra
FlashInfer dispatch overhead in piecewise CUDA graph workloads.
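Illustrative sketch for item 1 (not code from this commit): the core HEAD_DIM_PADDED idea is
to pad the head dimension up to the next power of two inside the kernel and mask off the
extra lanes. Kernel and parameter names below are hypothetical placeholders.
```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_head_padded(x_ptr, out_ptr, head_dim, HEAD_DIM_PADDED: tl.constexpr):
    # One program per row; HEAD_DIM_PADDED is the next power of two >= head_dim.
    row = tl.program_id(0)
    offs = tl.arange(0, HEAD_DIM_PADDED)
    mask = offs < head_dim            # lanes beyond head_dim stay masked out
    vals = tl.load(x_ptr + row * head_dim + offs, mask=mask, other=0.0)
    # ... the real decode kernel runs its attention math on the padded registers ...
    tl.store(out_ptr + row * head_dim + offs, vals, mask=mask)

def copy_head_padded(x: torch.Tensor) -> torch.Tensor:
    rows, head_dim = x.shape          # x assumed contiguous, e.g. head_dim=176
    out = torch.empty_like(x)
    _copy_head_padded[(rows,)](x, out, head_dim,
                               HEAD_DIM_PADDED=triton.next_power_of_2(head_dim))
    return out
```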
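Illustrative sketch for item 2: gathering a staging row back to the true head_dim in two
vectorised passes. The chunk sizes and names are assumptions for illustration; e.g.
head_dim=176 could use CHUNK0=128 plus a masked CHUNK1=64 covering the remaining 48 elements.
```python
import triton
import triton.language as tl

@triton.jit
def _gather_two_chunks(staging_ptr, out_ptr, head_dim, stride_staging,
                       CHUNK0: tl.constexpr, CHUNK1: tl.constexpr):
    row = tl.program_id(0)
    src = staging_ptr + row * stride_staging   # padded staging row
    dst = out_ptr + row * head_dim             # tightly packed output row
    # pass 1: the power-of-2 prefix vectorises cleanly and needs no mask
    offs0 = tl.arange(0, CHUNK0)
    tl.store(dst + offs0, tl.load(src + offs0))
    # pass 2: the remainder chunk, masked down to the true head_dim
    offs1 = CHUNK0 + tl.arange(0, CHUNK1)
    mask1 = offs1 < head_dim
    tl.store(dst + offs1, tl.load(src + offs1, mask=mask1, other=0.0), mask=mask1)
```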
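Illustrative sketch for item 3: a PyTorch reference of the stage-2 combine showing why
num_splits==1 can skip the combine entirely. The tensor layout here is an assumption, not
the kernel's actual layout.
```python
import torch

def combine_splits(partial_out: torch.Tensor, partial_lse: torch.Tensor) -> torch.Tensor:
    # partial_out: (num_splits, tokens, head_dim), partial_lse: (num_splits, tokens)
    if partial_out.shape[0] == 1:
        # WRITE_DIRECT case: a single split already holds the final result, so
        # the combine (and its kernel launch) can be skipped and stage 1 can
        # write straight into the output buffer.
        return partial_out[0]
    # log-sum-exp weighted combine of the per-split partial outputs
    weights = torch.softmax(partial_lse, dim=0)              # (num_splits, tokens)
    return (weights.unsqueeze(-1) * partial_out).sum(dim=0)  # (tokens, head_dim)
```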
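Illustrative sketch for item 4: the hoisting pattern in a toy Triton kernel; the per-head
base pointer is loop-invariant, so it is built once before the KV loop rather than on every
iteration. The kernel is a stand-in, not the real attention kernel.
```python
import triton
import triton.language as tl

@triton.jit
def _hoisted_loop_demo(k_ptr, out_ptr, seq_len, stride_kh, stride_ks,
                       BLOCK: tl.constexpr, HEAD_DIM: tl.constexpr):
    head = tl.program_id(0)
    offs_d = tl.arange(0, HEAD_DIM)
    # Hoisted: computed once here instead of being recomputed inside the loop.
    k_head_ptr = k_ptr + head * stride_kh + offs_d[None, :]
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
    for start in range(0, seq_len, BLOCK):
        offs_s = start + tl.arange(0, BLOCK)
        k = tl.load(k_head_ptr + offs_s[:, None] * stride_ks,
                    mask=(offs_s < seq_len)[:, None], other=0.0)
        acc += tl.sum(k.to(tl.float32), axis=0)   # stand-in for the attention math
    tl.store(out_ptr + head * HEAD_DIM + offs_d, acc)
```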
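Illustrative sketch for item 5: a hypothetical helper for the sliding-window-aware split
count; the function name and arguments are assumptions.
```python
import math

def compute_num_splits(seq_len: int, sliding_window: int, tokens_per_split: int,
                       max_splits: int) -> int:
    # Cap the KV length used for split planning by the sliding window, so no
    # splits are spent on KV blocks the window can never attend to.
    effective_len = seq_len if sliding_window <= 0 else min(seq_len, sliding_window)
    return max(1, min(max_splits, math.ceil(effective_len / tokens_per_split)))

# e.g. seq_len=32768, sliding_window=4096, tokens_per_split=1024, max_splits=64
# -> 4 splits instead of the 32 a full-sequence computation would request.
```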
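Illustrative sketch for item 6: the standard logit soft-cap, cap * tanh(score / cap), shown
as a PyTorch reference; in the Triton kernel it is applied to the q·k scores before the
softmax.
```python
import torch

def softcap_logits(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # Soft-cap squashes raw attention scores into (-logit_cap, logit_cap).
    if logit_cap <= 0.0:      # cap disabled
        return scores
    return logit_cap * torch.tanh(scores / logit_cap)
```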
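Illustrative sketch for item 7: what an autotune key including the sliding window and a
widened num_warps/num_stages grid can look like; the concrete key entries and config values
here are assumptions, not the commit's exact search space.
```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in (4, 8, 16)          # widened search space (assumed values)
        for s in (1, 2, 3, 4)
    ],
    key=["SEQ_LEN", "SLIDING_WINDOW"],   # window size now participates in tuning
)
@triton.jit
def _paged_decode_kernel(q_ptr, out_ptr, SEQ_LEN, SLIDING_WINDOW,
                         HEAD_DIM: tl.constexpr):
    pid = tl.program_id(0)               # trivial body; real kernel elided
    tl.store(out_ptr + pid, tl.load(q_ptr + pid))
```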
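Illustrative sketch for item 8: a hypothetical dispatch helper showing the effect of raising
_TRITON_T_THRESHOLD; only the constant name comes from the commit.
```python
_TRITON_T_THRESHOLD = float("inf")   # raised so the check below always passes

def pick_decode_backend(num_decode_tokens: int) -> str:
    # With the threshold effectively unbounded, Triton paged attention is used
    # for every batch size, so the FlashInfer dispatch path is never entered
    # inside piecewise CUDA graph regions.
    return "triton" if num_decode_tokens <= _TRITON_T_THRESHOLD else "flashinfer"
```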
Test: updated test_triton_paged_attention.py to cover new configurations.
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
2 files changed (1230 additions, 230 deletions):
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention