[None][perf] triton paged attention: non-pow2 head_dim, decode speedup, logit cap #13010
nvchenghaoz wants to merge 4 commits into NVIDIA:main from
Conversation
📝 Walkthrough
Added logit soft-capping, padded head-dimension handling, a two-chunk stage1 kernel for non-power-of-2 head dims, a WRITE_DIRECT decode path, a tiled two-pass stage2 reduction, and corresponding API and test updates.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ✅ 3 passed
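The tiled two-pass stage2 reduction mentioned in the walkthrough merges per-split partial softmax state. A minimal scalar sketch of the standard log-sum-exp combine (illustrative Python, not the Triton kernel code):

```python
import math

def combine_splits(partials):
    # Each partial is (m, l, acc): running max of the logits, softmax
    # denominator, and unnormalized weighted value sum for one KV split.
    m = max(p[0] for p in partials)
    l_total = 0.0
    acc_total = 0.0
    for m_i, l_i, acc_i in partials:
        scale = math.exp(m_i - m)  # rescale each split to the global max
        l_total += l_i * scale
        acc_total += acc_i * scale
    return acc_total / l_total     # final normalized attention output

# With a single split (the WRITE_DIRECT case) the combine is the identity,
# which is why stage2 can be short-circuited when num_splits == 1.
m, l, acc = 0.5, 2.0, 3.0
assert combine_splits([(m, l, acc)]) == acc / l
```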
Actionable comments posted: 1
🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)
189-193: Use ASCII "x" instead of Unicode multiplication sign "×".
The comment uses the Unicode multiplication sign "×", which Ruff flags (RUF003). Consider using the ASCII letter "x" for consistency.
✏️ Suggested fix

```diff
- # fewer longer splits are fine. SW shapes keep the higher 2×num_sms threshold
+ # fewer longer splits are fine. SW shapes keep the higher 2x num_sms threshold
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py` around lines 189 - 193, The inline comment above the splits1_threshold assignment uses a Unicode multiplication sign "×" which triggers lint RUF003; update the comment text to use the ASCII letter "x" instead (e.g., change "2×num_sms" to "2x num_sms" or "2 * num_sms") in the comment near the splits1_threshold = num_sms if sliding_window == 0 else num_sms * 2 line, leaving the logic and variable names (splits1_threshold, num_sms, sliding_window) unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 526-532: The WRITE_DIRECT path hardcodes tl.bfloat16 when storing
out_val which forces unnecessary dtype conversions for float16 models; update
the store to use the same output dtype as the V tensor (e.g. derive a triton
dtype from v.dtype) or expose the output dtype as a constexpr parameter used
when computing out_val and in the tl.store call (affecting variables out_val,
acc, l_i_safe, and the tl.store invocation with direct_o_ptr/do_offsets and
masks head_mask/head_dim_mask) so the write path matches the attention weight
dtype logic that already uses v.dtype.
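One way to satisfy this comment is to resolve the Triton dtype on the Python side and pass it into the kernel as a constexpr. The helper name and mapping below are hypothetical (not from the PR); they only sketch the idea:

```python
# Hypothetical helper: map a torch dtype name to the Triton dtype string
# that would be passed to the kernel as the OUT_DTYPE constexpr, so the
# WRITE_DIRECT store matches the model's dtype instead of a hardcoded one.
_TRITON_DTYPES = {
    "torch.bfloat16": "tl.bfloat16",
    "torch.float16": "tl.float16",
    "torch.float32": "tl.float32",
}

def out_dtype_for(torch_dtype_name: str) -> str:
    # Fall back to float32 (the accumulator dtype) for unrecognized dtypes.
    return _TRITON_DTYPES.get(torch_dtype_name, "tl.float32")

assert out_dtype_for("torch.float16") == "tl.float16"
assert out_dtype_for("torch.bfloat16") == "tl.bfloat16"
```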
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 48c37878-20f2-4fe7-afea-9943b030d2d8
📒 Files selected for processing (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
…p, logit cap, FP8 KV cache support

- HEAD_DIM_PADDED: pad non-power-of-2 head dims to the next power of 2 for Triton efficiency
- WRITE_DIRECT optimization: skip the intermediate buffer for direct output writes
- KV offset hoisting: precompute KV offsets outside the inner loop
- SW-aware splits: sliding-window-aware chunking for context attention
- Logit cap support: per-head logit soft-capping (Gemma 4 alt-attention)
- Always-Triton threshold: route to Triton for all seq_lens above the threshold
- Two-chunk gather: efficient gather for non-pow2 head dims in stage2
- FP8 KV cache casting: cast K/V to query dtype at all load sites (decode SW path, context block_ptr loads, stage2, fallback SDPA path)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
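The HEAD_DIM_PADDED item relies on rounding a head dim up to the next power of two. In plain Python (equivalent to `triton.next_power_of_2`):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, via the bit length of n - 1.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

# head_dim=176 pads to 256 in registers; power-of-2 dims are unchanged.
assert next_power_of_2(176) == 256
assert next_power_of_2(128) == 128
```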
Force-pushed f2a82ba to 1c1e5ec (Compare)
…, tests, cleanup)

- OUT_DTYPE constexpr: WRITE_DIRECT stores now use OUT_DTYPE (passed from Python as tl.bfloat16 or tl.float16 based on q.dtype) instead of hardcoded tl.bfloat16. Fixes potential wrong-dtype writes for float16 queries (addresses CodeRabbit review).
- SW loop dedup: merge the sliding-window fast path (first_valid_pos_in_page <= 0) into the slow path by using a single loop that always computes window_mask. When first_valid_pos_in_page <= 0, window_mask is all-True, giving the same result as no masking. Reduces 3 loop bodies to 2 in both decode kernels.
- Add tests for the inactive-split path (TestInactiveSplits):
  - test_inactive_split_correctness: mixed seq-length batch where long seqs force num_splits > 1 while a short seq has inactive splits (page_split_start >= num_pages). Verifies each element's output matches the single-element decode reference.
  - test_write_direct_inactive_split_with_empty_sequence: WRITE_DIRECT path (num_splits == 1) with an empty sequence (num_pages = 0), which triggers page_split_start >= num_pages for split_id = 0. Verifies finite, zero output.
- Remove all "Iter N:" development-history prefixes from inline comments.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
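The window_mask dedup described in the commit message can be sketched as a scalar Python model of the kernel loop (illustrative only, not the Triton code): when first_valid_pos_in_page <= 0 the mask is all-True, so one masked loop reproduces the unmasked fast path.

```python
def window_mask(first_valid_pos_in_page: int, page_size: int) -> list:
    # True where the KV position within the page is inside the sliding
    # window; positions before first_valid_pos_in_page fall outside it.
    return [pos >= first_valid_pos_in_page for pos in range(page_size)]

# Fast-path case: whole page inside the window -> mask is all-True, so a
# single loop that always applies the mask matches the unmasked path.
assert all(window_mask(first_valid_pos_in_page=-3, page_size=4))
assert all(window_mask(first_valid_pos_in_page=0, page_size=4))

# Slow-path case: the leading positions are outside the window.
assert window_mask(first_valid_pos_in_page=2, page_size=4) == [False, False, True, True]
```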
@coderabbitai review again

✅ Actions performed: full review triggered.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)
1621-1636: Minor optimization: consider allocating the gather buffer with the output dtype directly.
When kv_cache.dtype != q.dtype (e.g., FP8 KV cache with FP16 queries):

- kv_buf is allocated with kv_cache.dtype
- .to(q.dtype) creates new tensors, orphaning kv_buf
- the gather kernel writes to the new tensors (Triton handles dtype conversion during store)

The current code is correct but allocates an unused buffer. Since SDPA expects q.dtype anyway, consider:
♻️ Suggested optimization

```diff
+    # Allocate with q.dtype directly since SDPA requires matching dtypes
+    # and Triton store handles dtype conversion from kv_cache
     kv_buf = torch.empty(
         2, num_seq, n_kv_heads, max_kv_len, head_dim,
-        dtype=kv_cache.dtype,
+        dtype=q.dtype,
         device=kv_cache.device,
     )
     k_sdpa = kv_buf[0]
     v_sdpa = kv_buf[1]
-    # Cast k/v to query dtype if kv cache uses a different dtype (e.g., fp8)
-    if kv_cache.dtype != q.dtype:
-        k_sdpa = k_sdpa.to(q.dtype)
-        v_sdpa = v_sdpa.to(q.dtype)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py` around lines 1621 - 1636, The allocation currently creates kv_buf with kv_cache.dtype and then casts k_sdpa/v_sdpa to q.dtype which orphanes the original buffer; change the allocation logic in the block that sets kv_buf/k_sdpa/v_sdpa so that if kv_cache.dtype != q.dtype you allocate kv_buf with dtype=q (using device=kv_cache.device and same shape), otherwise allocate with kv_cache.dtype, then proceed to assign k_sdpa=kv_buf[0] and v_sdpa=kv_buf[1] and remove the .to(q.dtype) casts; update references to kv_buf, k_sdpa, v_sdpa, kv_cache, and q accordingly so SDPA works with q.dtype without creating an unused buffer.
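The suggested change reduces to choosing the allocation dtype up front. A minimal sketch with plain strings standing in for torch dtypes (illustrative, not the PR code):

```python
def gather_buffer_dtype(kv_cache_dtype: str, q_dtype: str) -> str:
    # Allocate the SDPA gather buffer directly in the query dtype when the
    # KV cache uses a different dtype (e.g., fp8): the Triton gather store
    # converts on write, so no second buffer or .to() cast is needed.
    # Otherwise keep the cache dtype.
    return q_dtype if kv_cache_dtype != q_dtype else kv_cache_dtype

assert gather_buffer_dtype("float8_e4m3", "float16") == "float16"
assert gather_buffer_dtype("bfloat16", "bfloat16") == "bfloat16"
```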
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 704d0db4-6bd3-469e-8ab3-f0edb7e00db7
📒 Files selected for processing (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #43113 [ run ] triggered by Bot. Commit:

PR_Github #43113 [ run ] completed with state
…len in tests

The triton_paged_decode() parameter is max_decode_seq_len; the new TestInactiveSplits tests were passing the old name max_seq_len.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #43253 [ run ] triggered by Bot. Commit:

PR_Github #43253 [ run ] completed with state

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #43524 [ run ] triggered by Bot. Commit:

PR_Github #43524 [ run ] completed with state

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #43816 [ run ] triggered by Bot. Commit:
Summary
Performance improvements to the AutoDeploy Triton paged attention kernel:
- Head-dim handling (e.g., head_dim=176, head_dim=512): the decode kernel pads non-power-of-2 head dims to the next power of 2 in registers and gathers the partial result at stage-2 combine time, eliminating zero-padding overhead.
- When num_splits == 1, the stage-2 combine is short-circuited and the single-split result is written directly to the output tensor, saving one kernel launch per decode step for short contexts.
- For sliding-window attention, num_splits is capped by the sliding window size rather than the full sequence length, reducing wasted compute for SW attention models.
- New logit_cap parameter (used by Gemma 4's logit_softcapping=30.0).
- _TRITON_T_THRESHOLD tuned so the Triton kernel is always selected for Gemma 4's head dimensions, avoiding extra FlashInfer dispatch steps.
- SLIDING_WINDOW added to the autotune key; the num_stages/num_warps search space extended.

Test plan

- pytest tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py covers the new configurations

🤖 Generated with Claude Code