DSA: fix CuTe DSL guards and add SM90 indexer forward#263
DSA: fix CuTe DSL guards and add SM90 indexer forward#263jiayus-nvidia wants to merge 6 commits into
Conversation
|
@cudnn-ci-bot run |
|
🚀 Running mirror pipeline Branch: cudnn-gh/pr-263-3a24267 |
|
@jiayus-nvidia . Related issue |
|
@cudnn-ci-bot run |
|
🚀 Running mirror pipeline Branch: cudnn-gh/pr-263-6f62b9c |
Compile-cache keys across the deepseek_sparse_attention kernels included
runtime-only values (batch/seqlen/seqlen_k, sm_scale, tensor shapes/strides,
num_head, num_threads), forcing spurious recompiles under varlen / changing
batch even though one compiled kernel serves them all. Drop those fields and
keep only params that change generated code.
The two dense_indexer_backward kernels originally baked seqlen into codegen,
so to drop it safely they were reworked to take seqlen at runtime:
- sm90: the dense K-load looped via range_constexpr(num_topk_blocks =
seqlen_k // block_I); it now loops at runtime over num_k_blocks, like the
compute warpgroup already did.
- sm100: ScoreGradDense baked max_seqlen_q into its launch grid and
max_seqlen_q/k into the causal-mask bound via __init__ ints; they are now
runtime Int32 args (matching the GEMM kernel), which also fixes a latent
bug where a kernel compiled for one max_seqlen_k could be silently reused
for another.
Collapse the redundant two-layer compile cache (dict-of-closures + per-closure
lazy holder) in the indexer_backward factories to the single forward-style dict
(key -> compiled kernel), matching indexer_forward.
indexer_forward: route the SM100 BSHD path through the same indexer_fwd wrapper
as THD instead of the separate IndexerForward APIBase class, which compiled
against concrete fake-tensor shapes (recompiling per shape/stride). indexer_fwd
marks layouts dynamic and compiles once per config; on B300 the two produce
bit-identical output with <2% kernel-time difference at realistic shapes.
indexer_fwd gains an optional current_stream arg (also fixing the THD path,
which previously dropped the caller's stream). The public IndexerForward
class/export is retained.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
|
Sorry for another commit. The new commit mainly trims runtime-only values (batch / seqlen / some shape-dependent fields / runtime scalars) from the compile-cache keys where the kernels already take them as runtime args. Otherwise there'll be a lot of recompile when using THD. This should be the final commit for this PR. Please trigger ci again, thanks. |
|
@cudnn-ci-bot run |
|
🚀 Running mirror pipeline Branch: cudnn-gh/pr-263-2f9b6e0 |
| k, | ||
| w, | ||
| ratio=ratio, | ||
| qhead_per_kv_head=qhead_per_kv_head, |
There was a problem hiding this comment.
I believe we're missing the stream argument here
There was a problem hiding this comment.
Also, we're effectively ignoring the m_block_size, n_block_size, etc parameters for this codepath – should we be performing checks on/rejecting unsupported configurations during check_support?
There was a problem hiding this comment.
Fixed in 9e20e47. The SM90 path now passes current_stream through to the SM90 interface, and rejects non-default m_block_size/n_block_size/q_stage/kv_stage instead of silently ignoring them. I also checked the other DSA kernels for the same stream propagation issue and fixed sparse_attention_backward.
|
Hi @jiayus-nvidia , Can you also run the pre-commit check. Thanks |
📝 WalkthroughWalkthroughThis PR refactors DSA (DeepSeek Sparse Attention) backward and forward kernels to improve compile-time caching efficiency, expands SM90 architecture support, and adds stream parameter support across interfaces. The main changes separate code-generation-affecting parameters from runtime arguments in cache keys, implement a complete SM90 forward kernel with staged TMA loading and optional TMA store, and update weight caching to FP32 for improved accuracy. ChangesBackward Kernel Refactoring
SM90 Forward Kernel Implementation and Dispatcher
Score Recompute and Weight Caching
Architecture Support and Documentation
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
Hi @Anerudhan , fixed in 7e52a09. Thanks. |
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py (1)
107-121:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winContradictory comments about
seqlen_kin compile key.The comment at lines 107-110 states "seqlen_k IS kept" because it affects
num_topk_blocks, but the actualcompile_keyat line 156 is(is_varlen, heads, dim, block_I, ratio)which does not includeseqlen_k. The comment at lines 151-156 correctly states "seqlen_k is runtime now."The comment at lines 107-110 appears to be stale and should be updated to reflect the actual implementation.
📝 Suggested fix
- # batch/seqlen are runtime (dynamic grid dim + Int32 args), so they're not - # keyed. seqlen_k IS kept: the dense K-load unrolls `range_constexpr( - # self.num_topk_blocks)` where num_topk_blocks = seqlen_k // block_I, so it - # changes generated code. sm_scale is a runtime Float32 arg (not keyed). + # batch/seqlen/seqlen_k are runtime (dynamic grid dim + Int32 args), so + # they're not keyed. sm_scale is a runtime Float32 arg (not keyed).🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py` around lines 107 - 121, Update the stale comment that claims "seqlen_k IS kept" to reflect the actual implementation: seqlen_k is treated as a runtime (non-keyed) argument and therefore omitted from the compile key; change the comment near the call to _build_cute_dsl_dense_kernel and the surrounding explanation so it matches the compile_key tuple (is_varlen, heads, dim, block_I, ratio) and the later note "seqlen_k is runtime now"; reference seqlen_k, _build_cute_dsl_dense_kernel, and compile_key when editing the comment to avoid future confusion.
🧹 Nitpick comments (1)
python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py (1)
26-30: 💤 Low valueUnused dtype mappings in
torch2cute_dtype_map.The map includes
float16andfloat32entries, but validation enforcesbfloat16only. Consider removing unused entries or documenting that they're reserved for future use.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py` around lines 26 - 30, The dictionary torch2cute_dtype_map contains entries for torch.float16 and torch.float32 that are never used because validation currently enforces torch.bfloat16 only; either remove the unused mappings (delete the keys torch.float16 and torch.float32 from torch2cute_dtype_map) or explicitly document the remaining entries as "reserved for future use" next to torch2cute_dtype_map so reviewers know they are intentional; update any related docstring or comment near torch2cute_dtype_map to reflect the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py`:
- Around line 172-199: The load_Weights_packed_f32 routine currently hardcodes
cutlass.BFloat16 when calling sm90_ops.elem_pointer_packed_i64 which causes FP16
payloads to be misinterpreted as BF16; change the call to use the actual weights
dtype (e.g., self.mQ_type or the attribute that represents the packed-weights
dtype) instead of cutlass.BFloat16 so elem_pointer_packed_i64 reads the correct
element type, and keep the existing conversion to cutlass.Float32(gmem_val[0])
after loading; also update the function docstring to reflect that the input
packed type is dynamic rather than always BF16.
---
Outside diff comments:
In
`@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py`:
- Around line 107-121: Update the stale comment that claims "seqlen_k IS kept"
to reflect the actual implementation: seqlen_k is treated as a runtime
(non-keyed) argument and therefore omitted from the compile key; change the
comment near the call to _build_cute_dsl_dense_kernel and the surrounding
explanation so it matches the compile_key tuple (is_varlen, heads, dim, block_I,
ratio) and the later note "seqlen_k is runtime now"; reference seqlen_k,
_build_cute_dsl_dense_kernel, and compile_key when editing the comment to avoid
future confusion.
---
Nitpick comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py`:
- Around line 26-30: The dictionary torch2cute_dtype_map contains entries for
torch.float16 and torch.float32 that are never used because validation currently
enforces torch.bfloat16 only; either remove the unused mappings (delete the keys
torch.float16 and torch.float32 from torch2cute_dtype_map) or explicitly
document the remaining entries as "reserved for future use" next to
torch2cute_dtype_map so reviewers know they are intentional; update any related
docstring or comment near torch2cute_dtype_map to reflect the chosen approach.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: cdf7bcb7-471d-4068-b420-ed9c1f0cc5fc
📒 Files selected for processing (24)
python/cudnn/deepseek_sparse_attention/README.mdpython/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.pypython/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.pypython/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm100.pypython/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.pypython/cudnn/deepseek_sparse_attention/indexer_forward/_interface.pypython/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.pypython/cudnn/deepseek_sparse_attention/indexer_forward/api.pypython/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.pypython/cudnn/deepseek_sparse_attention/indexer_top_k/api.pypython/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.pypython/cudnn/deepseek_sparse_attention/indexer_top_k/local_to_global_dsl.pypython/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.pypython/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.pypython/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.pypython/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.pypython/cudnn/deepseek_sparse_attention/score_recompute/sparse_score_recompute_sm100.pypython/cudnn/deepseek_sparse_attention/sparse_attention_backward/_interface_sm100.pypython/cudnn/deepseek_sparse_attention/sparse_attention_backward/_interface_sm90.pypython/cudnn/deepseek_sparse_attention/sparse_attention_backward/api.pypython/cudnn/deepseek_sparse_attention/sparse_attention_backward/dsa_bwd_sm90.pytest/python/fe_api/dsa/dsa_reference.pytest/python/fe_api/dsa/test_DSA_indexer_forward.pytest/python/fe_api/dsa/test_DSA_indexer_top_k.py
💤 Files with no reviewable changes (1)
- python/cudnn/deepseek_sparse_attention/indexer_top_k/local_to_global_dsl.py
| def load_Weights_packed_f32( | ||
| self, | ||
| base_ptr_i64: cutlass.Int64, | ||
| seqlen_q: cutlass.Int32, | ||
| sWeights: cute.Tensor, # (tile_m,) - FP32 SMEM buffer | ||
| m_block: cutlass.Int32, | ||
| tile_m: cutlass.Constexpr[int], | ||
| tidx: cutlass.Int32, # 0-31 for Warp 0 | ||
| ): | ||
| """Load BF16 weights from packed global memory and convert once to FP32 SMEM.""" | ||
| qhpkv = self.qhead_per_kvhead | ||
| rows_per_thread = cute.ceil_div(tile_m, cute.arch.WARP_SIZE) | ||
| for i in cutlass.range_constexpr(rows_per_thread): | ||
| row = i * cute.arch.WARP_SIZE + tidx | ||
| if row < tile_m: | ||
| idx = m_block * tile_m + row | ||
| m_idx = idx // qhpkv | ||
| h_idx = idx - m_idx * qhpkv | ||
| ptr = sm90_ops.elem_pointer_packed_i64( | ||
| base_ptr_i64, | ||
| h_idx, | ||
| m_idx, | ||
| seqlen_q, | ||
| cutlass.BFloat16, | ||
| cute.AddressSpace.gmem, | ||
| ) | ||
| gmem_val = cute.make_tensor(ptr, (1,)) | ||
| sWeights[row] = cutlass.Float32(gmem_val[0]) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Verify dtype contract vs packed-load source dtype
rg -n -C2 "_check_type|Float16|BFloat16|load_Weights_packed_f32|elem_pointer_packed_i64" \
python/cudnn/deepseek_sparse_attention/score_recomputeRepository: NVIDIA/cudnn-frontend
Length of output: 11961
Fix FP16 correctness in packed-weight loads (avoid hardcoded BF16 reads)
python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py hardcodes cutlass.BFloat16 in load_Weights_packed_f32 (elem_pointer_packed_i64(...)), but DenseScoreRecomputeSm90 (and SparseScoreRecomputeSm90) allow index-score mode with cutlass.Float16 weights (and enforce weights dtype matches mQ_type). In that case, FP16 payloads get interpreted as BF16 before converting to FP32, corrupting scores.
💡 Proposed fix
diff --git a/python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py b/python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py
@@
def load_Weights_packed_f32(
self,
base_ptr_i64: cutlass.Int64,
seqlen_q: cutlass.Int32,
sWeights: cute.Tensor, # (tile_m,) - FP32 SMEM buffer
m_block: cutlass.Int32,
tile_m: cutlass.Constexpr[int],
+ src_weight_dtype: cutlass.Constexpr[cutlass.Numeric],
tidx: cutlass.Int32, # 0-31 for Warp 0
):
@@
ptr = sm90_ops.elem_pointer_packed_i64(
base_ptr_i64,
h_idx,
m_idx,
seqlen_q,
- cutlass.BFloat16,
+ src_weight_dtype,
cute.AddressSpace.gmem,
)
gmem_val = cute.make_tensor(ptr, (1,))
sWeights[row] = cutlass.Float32(gmem_val[0])diff --git a/python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py b/python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py
@@
_pack_gqa.load_Weights_packed_f32(
mWeights_cur.iterator.toint(),
seqlen_q_packed,
sWeights,
eff_m_block,
self.tile_m,
+ self.dtype,
wg_tidx,
)
@@
_pack_gqa.load_Weights_packed_f32(
mWeights_cur.iterator.toint(),
seqlen_q_packed,
sWeights,
eff_m_block,
self.tile_m,
+ self.dtype,
wg_tidx,
)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py` around
lines 172 - 199, The load_Weights_packed_f32 routine currently hardcodes
cutlass.BFloat16 when calling sm90_ops.elem_pointer_packed_i64 which causes FP16
payloads to be misinterpreted as BF16; change the call to use the actual weights
dtype (e.g., self.mQ_type or the attribute that represents the packed-weights
dtype) instead of cutlass.BFloat16 so elem_pointer_packed_i64 reads the correct
element type, and keep the existing conversion to cutlass.Float32(gmem_val[0])
after loading; also update the function docstring to reflect that the input
packed type is dynamic rather than always BF16.
|
@cudnn-ci-bot run |
|
🚀 Running mirror pipeline Branch: cudnn-gh/pr-263-cb25733 |
saltyminty
left a comment
There was a problem hiding this comment.
Approved but please check comments before merging.
| ) | ||
| return TupleDict(scores=scores) | ||
|
|
||
| b, s_q, h_q, d = q.shape |
There was a problem hiding this comment.
checking: removing this fallback means that this will no longer run on sm80 (since it's only sm90 or sm100). Likely not an issue, just making sure.
| current_stream = resolve_stream(current_stream) | ||
|
|
||
| has_topk_length = topk_length is not None | ||
| compile_key = (dtype, head_dim, head_dim_v, block_tile, has_topk_length, num_head) |
There was a problem hiding this comment.
I believe this head_dim is still needed right?
There was a problem hiding this comment.
Yes, head_dim is still needed.
|
@cudnn-ci-bot run |
|
🚀 Running mirror pipeline Branch: cudnn-gh/pr-263-cb25733 |
Summary
const_exprbranches.indexer_forward_wrappercovers both SM90 and SM100 paths.Validation
git diff --check HEAD~1..HEADPYTHONPYCACHEPREFIX=/tmp/codex_cudnn_github_pycache python3 -m py_compile ...for changed DSA Python files and testsnvidia-cutlass-dsl4.5.0 and 4.5.2nvidia-cutlass-dsl4.5.0 and 4.5.2test/python/fe_api/dsa/test_DSA_indexer_forward.pyon GB200 and H100 with 4.5.0 and 4.5.2Summary by CodeRabbit
New Features
Bug Fixes
Refactor