
[None][perf] triton paged attention: non-pow2 head_dim, decode speedup, logit cap#13010

Open
nvchenghaoz wants to merge 4 commits into NVIDIA:main from nv-auto-deploy:chenghao/triton-paged-attention-0413

Conversation

@nvchenghaoz
Collaborator

@nvchenghaoz nvchenghaoz commented Apr 13, 2026

Summary

Performance improvements to the AutoDeploy Triton paged attention kernel:

  • HEAD_DIM_PADDED support for non-power-of-2 head dimensions (e.g. Gemma 4's head_dim=176, head_dim=512). The decode kernel pads to the next power-of-2 in registers and gathers the partial result at stage-2 combine time, eliminating zero-padding overhead.
  • Two-chunk gather kernel for non-pow2 head_dim: processes the two chunks (e.g. 128+64) in separate vectorised passes for better memory bandwidth utilisation.
  • WRITE_DIRECT: when num_splits==1 the stage-2 combine is short-circuited and the single-split result is written directly to the output tensor — saves one kernel launch per decode step for short contexts.
  • KV offset hoisting: loop-invariant K/V tensor offsets and block pointers are hoisted out of the inner attention loop — 9-22% decode speedup on H100/B200 benchmarks.
  • Sliding-window-aware split computation: num_splits is capped by the sliding window size rather than the full sequence length, reducing wasted compute for SW attention models.
  • Logit soft-cap support via a logit_cap parameter (used by Gemma 4's logit_softcapping=30.0).
  • Always-Triton threshold: raise _TRITON_T_THRESHOLD so the Triton kernel is always selected for Gemma 4's head dimensions, avoiding extra FlashInfer dispatch steps.
  • SLIDING_WINDOW added to autotune key; num_stages/num_warps search space extended.
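The two core tricks in the first bullets can be sketched in a few lines of plain Python; the helper names here are illustrative, not the kernel's actual identifiers:

```python
import math

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; the decode kernel pads head_dim this way
    # in registers (e.g. 176 -> 256) and masks off the padded lanes.
    return 1 << (n - 1).bit_length()

def soft_cap(score: float, logit_cap: float) -> float:
    # tanh-based logit soft-capping: smoothly bounds attention scores to
    # the open interval (-logit_cap, logit_cap), e.g. logit_cap=30.0.
    return logit_cap * math.tanh(score / logit_cap)
```

In the real kernel both operate on Triton blocks rather than scalars, and the padding is combined with masked loads/stores so the padded lanes never hit memory.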

Test plan

  • pytest tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py — covers new configurations
  • AutoDeploy integration test with Gemma 4 26B-A4B-it: verify no accuracy regression

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added configurable logit soft-capping for attention operations
    • Direct-write decode path to skip intermediate buffers when possible
  • Improvements

    • Better handling of non-power-of-2 head dimensions and padded head widths
    • Tiled/two-pass reductions and split heuristics for improved decode/context performance
  • Tests

    • New correctness tests for logit soft-capping and inactive-split/write-direct decode behavior

@nvchenghaoz nvchenghaoz requested a review from a team as a code owner April 13, 2026 18:33
@nvchenghaoz nvchenghaoz requested a review from lucaslie April 13, 2026 18:33
@nvchenghaoz nvchenghaoz marked this pull request as draft April 13, 2026 18:38
@coderabbitai
Contributor

coderabbitai bot commented Apr 13, 2026

📝 Walkthrough

Walkthrough

Added logit soft-capping, padded head-dimension handling, a two-chunk stage1 kernel for non-power-of-2 head dims, a WRITE_DIRECT decode path, tiled two-pass stage2 reduction, and corresponding API and test updates.
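The two-chunk stage1 decomposition mentioned above splits a non-power-of-2 head_dim into two power-of-2 sub-ranges so each pass stays vector-friendly; a sketch of the idea (the kernel's actual chunking policy is an assumption here):

```python
def two_chunk_split(head_dim: int) -> tuple:
    # Largest power of two not exceeding head_dim, then pad the remainder
    # up to a power of two, e.g. 192 -> (128, 64); for 176 the second
    # chunk is also 64 with only 48 valid lanes (masked at load time).
    chunk0 = 1 << (head_dim.bit_length() - 1)
    rem = head_dim - chunk0
    chunk1 = 0 if rem == 0 else 1 << (rem - 1).bit_length()
    return chunk0, chunk1
```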

Changes

Cohort / File(s) Summary
Core Attention Kernels & Dispatch
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
Added LOGIT_CAP (tanh-based soft-cap) across decode/context paths; introduced HEAD_DIM_PADDED with masked loads/stores; new _flash_decode_stage1_two_chunk_kernel for non-power-of-2 head_dim; WRITE_DIRECT path when num_splits==1; stage2 reworked to tiled, two-pass reduction; _get_num_splits(...) heuristic adjusted (sliding_window param, final cap 32). SDPA fast-gather updated to support two-chunk gather; SDPA fastpath disabled when logit_cap!=0.
APIs & Wrappers
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
Function signatures changed: triton_paged_decode(..., logit_cap: Optional[float]=None, max_decode_seq_len: Optional[int]=None), triton_paged_context(..., logit_cap: Optional[float]=None), triton_paged_mha_with_cache(..., logit_cap: Optional[float]=None), and fake variant; TritonPagedAttention.get_constants(...) now returns [scale, sliding_window, logit_cap].
Paged Context Path
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (same file, paged-context sections)
Paged-context kernels updated for HEAD_DIM_PADDED and LOGIT_CAP; masked head-dim loads, logit soft-capping applied in both phase1/phase2; q_positions_2d hoisted and the per-page causal-skip branch removed by tightening loop bounds.
Tests — Added / Removed
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
Removed TestSDPADispatch (SDPA-dispatch and FP8 KV-cache tests). Added _reference_attention_with_logit_cap and new tests: TestLogitSoftCap (decode/context correctness vs reference; logit_cap None equivalence) and TestInactiveSplits (inactive-split and WRITE_DIRECT regression tests, including empty-sequence slot).
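The split heuristic and the WRITE_DIRECT short-circuit interact as follows; a hedged sketch, since the real `_get_num_splits` also factors in SM count and autotuned block sizes:

```python
def get_num_splits(seq_len: int, page_size: int,
                   sliding_window: int = 0, max_splits: int = 32) -> int:
    # Sliding-window attention only ever reads the trailing `sliding_window`
    # tokens, so cap the effective length before splitting (the
    # "sliding-window-aware split computation" described above).
    effective_len = seq_len if sliding_window <= 0 else min(seq_len, sliding_window)
    num_pages = max(1, -(-effective_len // page_size))  # ceil division
    return max(1, min(num_pages, max_splits))           # final cap of 32
```

When this returns 1, the WRITE_DIRECT path applies: the single split's result is stored straight to the output tensor and the stage-2 combine launch is skipped.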

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly summarizes the main performance improvements: non-power-of-2 head dimension support, decode speedup, and logit soft-cap support for the Triton paged attention kernel.
Description check ✅ Passed The PR description is comprehensive, covering all major changes with clear explanations of features, performance improvements, and test coverage including unit tests and integration testing with Gemma 4.
Docstring Coverage ✅ Passed Docstring coverage is 90.48% which is sufficient. The required threshold is 80.00%.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)

189-193: Use ASCII "x" instead of Unicode multiplication sign "×".

The comment uses the Unicode multiplication sign "×" which Ruff flags (RUF003). Consider using the ASCII letter "x" for consistency.

✏️ Suggested fix
-    # fewer longer splits are fine. SW shapes keep the higher 2×num_sms threshold
+    # fewer longer splits are fine. SW shapes keep the higher 2x num_sms threshold
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`
around lines 189 - 193, The inline comment above the splits1_threshold
assignment uses a Unicode multiplication sign "×" which triggers lint RUF003;
update the comment text to use the ASCII letter "x" instead (e.g., change
"2×num_sms" to "2x num_sms" or "2 * num_sms") in the comment near the
splits1_threshold = num_sms if sliding_window == 0 else num_sms * 2 line,
leaving the logic and variable names (splits1_threshold, num_sms,
sliding_window) unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 526-532: The WRITE_DIRECT path hardcodes tl.bfloat16 when storing
out_val which forces unnecessary dtype conversions for float16 models; update
the store to use the same output dtype as the V tensor (e.g. derive a triton
dtype from v.dtype) or expose the output dtype as a constexpr parameter used
when computing out_val and in the tl.store call (affecting variables out_val,
acc, l_i_safe, and the tl.store invocation with direct_o_ptr/do_offsets and
masks head_mask/head_dim_mask) so the write path matches the attention weight
dtype logic that already uses v.dtype.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 48c37878-20f2-4fe7-afea-9943b030d2d8

📥 Commits

Reviewing files that changed from the base of the PR and between fc2ba32 and f2a82ba.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py

…p, logit cap, FP8 KV cache support

- HEAD_DIM_PADDED: pad non-power-of-2 head dims to next power of 2 for Triton efficiency
- WRITE_DIRECT optimization: skip intermediate buffer for direct output writes
- KV offset hoisting: precompute KV offsets outside inner loop
- SW-aware splits: sliding window-aware chunking for context attention
- Logit cap support: per-head logit softcapping (Gemma4 alt-attention)
- Always-Triton threshold: route to Triton for all seq_lens above threshold
- Two-chunk gather: efficient gather for non-pow2 head dims in stage2
- FP8 KV cache casting: cast K/V to query dtype at all load sites (decode
  SW path, context block_ptr loads, stage2, fallback SDPA path)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz nvchenghaoz force-pushed the chenghao/triton-paged-attention-0413 branch from f2a82ba to 1c1e5ec Compare April 13, 2026 20:57
…, tests, cleanup)

- OUT_DTYPE constexpr: WRITE_DIRECT stores now use OUT_DTYPE (passed from Python
  as tl.bfloat16 or tl.float16 based on q.dtype) instead of hardcoded tl.bfloat16.
  Fixes potential wrong-dtype writes for float16 queries (addresses CodeRabbit review).

- SW loop dedup: merge the sliding-window fast path (first_valid_pos_in_page <= 0)
  into the slow path by using a single loop that always computes window_mask.
  When first_valid_pos_in_page <= 0, window_mask is all-True, giving the same
  result as no masking. Reduces 3 loop bodies to 2 in both decode kernels.

- Add tests for inactive split path (TestInactiveSplits):
  - test_inactive_split_correctness: mixed seq-length batch where long seqs force
    num_splits > 1 while a short seq has inactive splits (page_split_start >= num_pages).
    Verifies each element's output matches single-element decode reference.
  - test_write_direct_inactive_split_with_empty_sequence: WRITE_DIRECT path
    (num_splits==1) with an empty sequence (num_pages=0), which triggers
    page_split_start >= num_pages for split_id=0. Verifies finite, zero output.

- Remove all "Iter N:" development-history prefixes from inline comments.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
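The loop-dedup commit above hinges on an all-True mask being a no-op; a minimal sketch of that idea (names like `window_mask` are assumptions, not the kernel's identifiers):

```python
def window_mask(kv_positions, query_pos_end, sliding_window):
    # A KV position is in-window if it lies within the last `sliding_window`
    # positions before the query. When the whole page is in-window
    # (first_valid <= every position on the page), the mask is all-True and
    # the masked loop body computes exactly what the old fast path did.
    first_valid = query_pos_end - sliding_window
    return [p >= first_valid for p in kv_positions]
```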
@nvchenghaoz
Collaborator Author

@coderabbitai review again

@coderabbitai
Contributor

coderabbitai bot commented Apr 13, 2026

@nvchenghaoz Sure! I'll re-review the PR now.

✅ Actions performed

Full review triggered.

@nvchenghaoz nvchenghaoz marked this pull request as ready for review April 13, 2026 22:49
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)

1621-1636: Minor optimization: consider allocating gather buffer with output dtype directly.

When kv_cache.dtype != q.dtype (e.g., FP8 KV cache with FP16 queries):

  1. kv_buf is allocated with kv_cache.dtype
  2. .to(q.dtype) creates new tensors, orphaning kv_buf
  3. Gather kernel writes to the new tensors (Triton handles dtype conversion during store)

The current code is correct but allocates an unused buffer. Since SDPA expects q.dtype anyway, consider:

♻️ Suggested optimization
+        # Allocate with q.dtype directly since SDPA requires matching dtypes
+        # and Triton store handles dtype conversion from kv_cache
         kv_buf = torch.empty(
             2,
             num_seq,
             n_kv_heads,
             max_kv_len,
             head_dim,
-            dtype=kv_cache.dtype,
+            dtype=q.dtype,
             device=kv_cache.device,
         )
         k_sdpa = kv_buf[0]
         v_sdpa = kv_buf[1]
-        # Cast k/v to query dtype if kv cache uses a different dtype (e.g., fp8)
-        if kv_cache.dtype != q.dtype:
-            k_sdpa = k_sdpa.to(q.dtype)
-            v_sdpa = v_sdpa.to(q.dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`
around lines 1621 - 1636, The allocation currently creates kv_buf with
kv_cache.dtype and then casts k_sdpa/v_sdpa to q.dtype, which orphans the
original buffer; change the allocation logic in the block that sets
kv_buf/k_sdpa/v_sdpa so that if kv_cache.dtype != q.dtype you allocate kv_buf
with dtype=q.dtype (using device=kv_cache.device and the same shape), otherwise
allocate with kv_cache.dtype, then assign k_sdpa=kv_buf[0] and
v_sdpa=kv_buf[1] and remove the .to(q.dtype) casts; update references to kv_buf,
k_sdpa, v_sdpa, kv_cache, and q accordingly so SDPA works with q.dtype without
creating an unused buffer.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 704d0db4-6bd3-469e-8ab3-f0edb7e00db7

📥 Commits

Reviewing files that changed from the base of the PR and between 293b1c9 and af20896.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py

@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43113 [ run ] triggered by Bot. Commit: af20896 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43113 [ run ] completed with state SUCCESS. Commit: af20896
/LLM/main/L0_MergeRequest_PR pipeline #33747 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

…len in tests

triton_paged_decode() parameter is max_decode_seq_len; the new
TestInactiveSplits tests were passing the old name max_seq_len.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43253 [ run ] triggered by Bot. Commit: cd24415 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43253 [ run ] completed with state DISABLED
Freeze main and open the PR merge only after CI is back to healthy https://nvidia.slack.com/archives/C059LSY62BT/p1776141760843319?thread_ts=1775985925.442509&cid=C059LSY62BT

Link to invocation

@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43524 [ run ] triggered by Bot. Commit: cd24415 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43524 [ run ] completed with state SUCCESS. Commit: cd24415
/LLM/main/L0_MergeRequest_PR pipeline #34037 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43816 [ run ] triggered by Bot. Commit: 01e5db0 Link to invocation

