Skip to content

DSA: fix CuTe DSL guards and add SM90 indexer forward#263

Open
jiayus-nvidia wants to merge 6 commits into
NVIDIA:developfrom
jiayus-nvidia:jiayus/dsa-cutedsl-guard-sm90-indexer
Open

DSA: fix CuTe DSL guards and add SM90 indexer forward#263
jiayus-nvidia wants to merge 6 commits into
NVIDIA:developfrom
jiayus-nvidia:jiayus/dsa-cutedsl-guard-sm90-indexer

Conversation

@jiayus-nvidia
Copy link
Copy Markdown

@jiayus-nvidia jiayus-nvidia commented Jun 1, 2026

Summary

  • Split CuTe DSL compile-time and runtime checks in DSA kernels to avoid mixing runtime values inside const_expr branches.
  • Add Hopper/SM90 indexer forward CuTe DSL dispatch and tests so indexer_forward_wrapper covers both SM90 and SM100 paths.
  • Sync recent indexer kernel optimizations for dense backward and dense score recompute, including ratio-aware bounds and fp32 packed weight staging.

Validation

  • git diff --check HEAD~1..HEAD
  • PYTHONPYCACHEPREFIX=/tmp/codex_cudnn_github_pycache python3 -m py_compile ... for changed DSA Python files and tests
  • GB200/SM100 score recompute smoke with nvidia-cutlass-dsl 4.5.0 and 4.5.2
  • H100/SM90 indexer forward smoke with nvidia-cutlass-dsl 4.5.0 and 4.5.2
  • test/python/fe_api/dsa/test_DSA_indexer_forward.py on GB200 and H100 with 4.5.0 and 4.5.2

Summary by CodeRabbit

  • New Features

    • Extended GPU platform support to SM90 and beyond for additional operations.
    • Added optional stream parameter to backward operations for explicit CUDA stream control.
  • Bug Fixes

    • Improved causal masking correctness for skipped/masked output positions.
  • Refactor

    • Optimized compilation caching strategy to reduce recompilations across varying runtime parameters.

@Anerudhan
Copy link
Copy Markdown
Collaborator

@cudnn-ci-bot run

@cudnn-ci-bot
Copy link
Copy Markdown

🚀 Running mirror pipeline

Branch: cudnn-gh/pr-263-3a24267
Pipeline: 53347842

@Anerudhan Anerudhan added mod-cutedsl CuTeDSL kernels, generated kernels, examples, or related integration work. cat-enhancements orig-nv-eng Reported or requested by NVIDIA engineering. labels Jun 2, 2026
@Anerudhan
Copy link
Copy Markdown
Collaborator

@jiayus-nvidia . Related issue
#264

@Anerudhan
Copy link
Copy Markdown
Collaborator

@cudnn-ci-bot run

@cudnn-ci-bot
Copy link
Copy Markdown

🚀 Running mirror pipeline

Branch: cudnn-gh/pr-263-6f62b9c
Pipeline: 53425860

Compile-cache keys across the deepseek_sparse_attention kernels included
runtime-only values (batch/seqlen/seqlen_k, sm_scale, tensor shapes/strides,
num_head, num_threads), forcing spurious recompiles under varlen / changing
batch even though one compiled kernel serves them all. Drop those fields and
keep only params that change generated code.

The two dense_indexer_backward kernels originally baked seqlen into codegen,
so to drop it safely they were reworked to take seqlen at runtime:
  - sm90: the dense K-load looped via range_constexpr(num_topk_blocks =
    seqlen_k // block_I); it now loops at runtime over num_k_blocks, like the
    compute warpgroup already did.
  - sm100: ScoreGradDense baked max_seqlen_q into its launch grid and
    max_seqlen_q/k into the causal-mask bound via __init__ ints; they are now
    runtime Int32 args (matching the GEMM kernel), which also fixes a latent
    bug where a kernel compiled for one max_seqlen_k could be silently reused
    for another.

Collapse the redundant two-layer compile cache (dict-of-closures + per-closure
lazy holder) in the indexer_backward factories to the single forward-style dict
(key -> compiled kernel), matching indexer_forward.

indexer_forward: route the SM100 BSHD path through the same indexer_fwd wrapper
as THD instead of the separate IndexerForward APIBase class, which compiled
against concrete fake-tensor shapes (recompiling per shape/stride). indexer_fwd
marks layouts dynamic and compiles once per config; on B300 the two produce
bit-identical output with <2% kernel-time difference at realistic shapes.
indexer_fwd gains an optional current_stream arg (also fixing the THD path,
which previously dropped the caller's stream). The public IndexerForward
class/export is retained.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@jiayus-nvidia
Copy link
Copy Markdown
Author

Sorry for another commit.

The new commit mainly trims runtime-only values (batch / seqlen / some shape-dependent fields / runtime scalars) from the compile-cache keys where the kernels already take them as runtime args. Otherwise there'll be a lot of recompile when using THD.

This should be the final commit for this PR. Please trigger ci again, thanks.

@Anerudhan Anerudhan requested review from Anerudhan and saltyminty June 3, 2026 18:04
@Anerudhan
Copy link
Copy Markdown
Collaborator

@cudnn-ci-bot run

@cudnn-ci-bot
Copy link
Copy Markdown

🚀 Running mirror pipeline

Branch: cudnn-gh/pr-263-2f9b6e0
Pipeline: 53564443

k,
w,
ratio=ratio,
qhead_per_kv_head=qhead_per_kv_head,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we're missing the stream argument here

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we're effectively ignoring the m_block_size, n_block_size, etc parameters for this codepath – should we be performing checks on/rejecting unsupported configurations during check_support?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 9e20e47. The SM90 path now passes current_stream through to the SM90 interface, and rejects non-default m_block_size/n_block_size/q_stage/kv_stage instead of silently ignoring them. I also checked the other DSA kernels for the same stream propagation issue and fixed sparse_attention_backward.

@Anerudhan
Copy link
Copy Markdown
Collaborator

Hi @jiayus-nvidia ,

Can you also run the pre-commit check.

Warning: Python 3.12 cannot parse code formatted for Python 3.14. To fix this: run Black with Python 3.14, set --target-version to py312, or use --fast to skip the safety check. Black's safety check verifies equivalence by parsing the AST, which fails when the running Python is older than the target version.
would reformat python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py
would reformat python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
would reformat python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
would reformat python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py

Thanks

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented Jun 4, 2026

Review Change Stack

📝 Walkthrough

Walkthrough

This PR refactors DSA (DeepSeek Sparse Attention) backward and forward kernels to improve compile-time caching efficiency, expands SM90 architecture support, and adds stream parameter support across interfaces. The main changes separate code-generation-affecting parameters from runtime arguments in cache keys, implement a complete SM90 forward kernel with staged TMA loading and optional TMA store, and update weight caching to FP32 for improved accuracy.

Changes

Backward Kernel Refactoring

Layer / File(s) Summary
Dense SM100 backward runtime arguments and ratio support
indexer_backward/dense_indexer_backward_sm100.py
ScoreGradDense refactored to accept max_seqlen_q/max_seqlen_k as runtime Int32 arguments instead of baking them into the constructor, enabling seqlen-flexible kernels. DenseIndexerBackward2QGemmSm100 adds ratio parameter for ratio-based causal KV-block limiting. Compilation split into _score_grad_compile_cache and _gemm_compile_cache, and 2Q token scheduling uses guarded has_q0/has_q1 checks with offset-based scheme.
SM90/SM100 backward compile caching
indexer_backward/dense_indexer_backward_sm90.py, indexer_backward/indexer_backward_sm100.py
Compile caching moved into inner kernel builders with cache keys containing only code-generation parameters (heads, dim, topk, block_I, ratio, score_input_is_log), excluding batch/seqlen/sm_scale. Lazy _ensure_compiled closures replace per-invocation compiled_holder pattern, enabling kernel reuse across varying runtime values.
SM90 backward kernel ratio and synchronization
indexer_backward/indexer_backward_sm90.py
IndexerBackwardSm90 adds ratio parameter and _dense_num_k_blocks helper to limit dense K-blocks by causal constraints. SharedStorage refactored with explicit sGradSignal, per-stage sIndices buffers, and aligned SMEM layout. Barrier synchronization rewritten with explicit cute.arch.barrier_arrive/wait calls for ping-pong WG coordination. Pointer construction switches to explicit cute.make_ptr with assumed_align.
Backward stream parameter support
sparse_attention_backward/_interface_sm{90,100}.py, sparse_attention_backward/api.py
Flash_attn_bwd_sm90/sm100 accept optional current_stream parameter, propagate it to kernel compilation/launch. Compile cache keys updated: SM100 drops num_head dependency; both maintain code-gen params.

SM90 Forward Kernel Implementation and Dispatcher

Layer / File(s) Summary
SM90 forward kernel core
indexer_forward/indexer_fwd_sm90.py
Complete IndexerForwardSm90 kernel with producer (staged TMA Q/K loads), consumer (WGMMA score GEMM), and score_store (ratio-causal masking, weight reduction, optional TMA S2G store). Implements mbarrier synchronization, optional TMA load/store descriptors, and varlen/non-varlen tensor layout handling.
SM90 forward interface
indexer_forward/_interface_sm90.py
indexer_fwd validates bfloat16 CUDA tensors, enforces head_dim=128 and n_heads_kv=1, selects TMA store by stride alignment, builds compile_key with dtype/ratio/varlen, lazily compiles IndexerForwardSm90, pre-fills output with -inf, executes with NVTX, returns result.
Forward interface stream and cache refactoring
indexer_forward/_interface.py
indexer_fwd adds optional current_stream parameter, validates input dtypes/threads, refactors compile cache key to exclude sm_scale and seqlen dimensions while adding q/k/w/out dtypes, resolves stream for both compilation and launch.
Forward dispatcher SM90/SM100 routing
indexer_forward/api.py
indexer_forward_wrapper branches by device_major(): SM90 enforces default tuning params and dispatches to indexer_fwd_sm90; other devices dispatch to indexer_fwd_sm100. Removes prior output allocation, TMA padding, cache lookup, and IndexerForward instantiation.

Score Recompute and Weight Caching

Layer / File(s) Summary
FP32 weight caching and K-block computation
score_recompute/dense_score_recompute_sm90.py, score_recompute/pack_gqa.py
DenseScoreRecomputeSm90 unconditionally sets weights_or_lse_dtype to FP32, adds _dense_compute_n_blocks helper for ratio-causal K-block limits. PackGQA adds new load_Weights_packed_f32 method that loads BF16 weights and converts to FP32 in SMEM.
Score recompute cache key refactoring
score_recompute/_interface_sm{90,100}.py
Compile cache keys refactored to exclude sm_scale and per-call seqlen values; kernels receive them as runtime arguments (max_q_arg, max_k_arg, or max_seqlen). Enables cache reuse across varying sm_scale and sequence-length sizes.
Output zeroing for masked positions
score_recompute/_interface_sm90.py
Dense score recompute (both direct and varlen) explicitly zeroes output before launch to ensure masked/skipped positions contain zeros even when caller provides existing out tensor.
Epilogue logic separation
score_recompute/sparse_score_recompute_sm100.py
SM100 sparse epilogues introduce separate should_copy_tmem and should_accumulate_score booleans, decoupling TMEM copy decisions from score accumulation decisions (replacing combined conditional).

Architecture Support and Documentation

Layer / File(s) Summary
SM90+ support expansion
README.md, indexer_top_k/api.py, indexer_top_k/indexer_top_k_decode_varlen.py, indexer_top_k/local_to_global_dsl.py
README expanded to distinguish SM90 (via indexer_forward_wrapper) and SM100+ (via IndexerForward) forward paths. IndexerTopK and top-K decode modules updated to target SM90+ instead of SM100+.
Test configuration and validation
test/python/fe_api/dsa/*
Test functions pass min_compute_capability=90 to dsa_init, and exception handlers narrowed to catch only ValueError and NotImplementedError. Reference implementation uses shared _bottom_right_causal_mask helper and adds explicit -inf location equality checks.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 A rabbit hops through cache keys so lean,
Where SM90 now thrives on the scene,
Weights turn to FP32, barriers align,
TMA flows staged in a kernel so fine,
Runtime streams flow, compile keys shrink—
Sparse attention bounds in a blink!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 42.22% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main changes: fixing CuTe DSL guards and adding SM90 indexer forward support, which aligns with the primary objectives of the PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@jiayus-nvidia
Copy link
Copy Markdown
Author

Hi @Anerudhan , fixed in 7e52a09.

Thanks.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py (1)

107-121: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Contradictory comments about seqlen_k in compile key.

The comment at lines 107-110 states "seqlen_k IS kept" because it affects num_topk_blocks, but the actual compile_key at line 156 is (is_varlen, heads, dim, block_I, ratio) which does not include seqlen_k. The comment at lines 151-156 correctly states "seqlen_k is runtime now."

The comment at lines 107-110 appears to be stale and should be updated to reflect the actual implementation.

📝 Suggested fix
-    # batch/seqlen are runtime (dynamic grid dim + Int32 args), so they're not
-    # keyed. seqlen_k IS kept: the dense K-load unrolls `range_constexpr(
-    # self.num_topk_blocks)` where num_topk_blocks = seqlen_k // block_I, so it
-    # changes generated code. sm_scale is a runtime Float32 arg (not keyed).
+    # batch/seqlen/seqlen_k are runtime (dynamic grid dim + Int32 args), so
+    # they're not keyed. sm_scale is a runtime Float32 arg (not keyed).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py`
around lines 107 - 121, Update the stale comment that claims "seqlen_k IS kept"
to reflect the actual implementation: seqlen_k is treated as a runtime
(non-keyed) argument and therefore omitted from the compile key; change the
comment near the call to _build_cute_dsl_dense_kernel and the surrounding
explanation so it matches the compile_key tuple (is_varlen, heads, dim, block_I,
ratio) and the later note "seqlen_k is runtime now"; reference seqlen_k,
_build_cute_dsl_dense_kernel, and compile_key when editing the comment to avoid
future confusion.
🧹 Nitpick comments (1)
python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py (1)

26-30: 💤 Low value

Unused dtype mappings in torch2cute_dtype_map.

The map includes float16 and float32 entries, but validation enforces bfloat16 only. Consider removing unused entries or documenting that they're reserved for future use.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py`
around lines 26 - 30, The dictionary torch2cute_dtype_map contains entries for
torch.float16 and torch.float32 that are never used because validation currently
enforces torch.bfloat16 only; either remove the unused mappings (delete the keys
torch.float16 and torch.float32 from torch2cute_dtype_map) or explicitly
document the remaining entries as "reserved for future use" next to
torch2cute_dtype_map so reviewers know they are intentional; update any related
docstring or comment near torch2cute_dtype_map to reflect the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py`:
- Around line 172-199: The load_Weights_packed_f32 routine currently hardcodes
cutlass.BFloat16 when calling sm90_ops.elem_pointer_packed_i64 which causes FP16
payloads to be misinterpreted as BF16; change the call to use the actual weights
dtype (e.g., self.mQ_type or the attribute that represents the packed-weights
dtype) instead of cutlass.BFloat16 so elem_pointer_packed_i64 reads the correct
element type, and keep the existing conversion to cutlass.Float32(gmem_val[0])
after loading; also update the function docstring to reflect that the input
packed type is dynamic rather than always BF16.

---

Outside diff comments:
In
`@python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py`:
- Around line 107-121: Update the stale comment that claims "seqlen_k IS kept"
to reflect the actual implementation: seqlen_k is treated as a runtime
(non-keyed) argument and therefore omitted from the compile key; change the
comment near the call to _build_cute_dsl_dense_kernel and the surrounding
explanation so it matches the compile_key tuple (is_varlen, heads, dim, block_I,
ratio) and the later note "seqlen_k is runtime now"; reference seqlen_k,
_build_cute_dsl_dense_kernel, and compile_key when editing the comment to avoid
future confusion.

---

Nitpick comments:
In `@python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py`:
- Around line 26-30: The dictionary torch2cute_dtype_map contains entries for
torch.float16 and torch.float32 that are never used because validation currently
enforces torch.bfloat16 only; either remove the unused mappings (delete the keys
torch.float16 and torch.float32 from torch2cute_dtype_map) or explicitly
document the remaining entries as "reserved for future use" next to
torch2cute_dtype_map so reviewers know they are intentional; update any related
docstring or comment near torch2cute_dtype_map to reflect the chosen approach.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: cdf7bcb7-471d-4068-b420-ed9c1f0cc5fc

📥 Commits

Reviewing files that changed from the base of the PR and between 1a2799b and cb25733.

📒 Files selected for processing (24)
  • python/cudnn/deepseek_sparse_attention/README.md
  • python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm100.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/_interface.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/api.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/api.py
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/indexer_top_k_decode_varlen.py
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/local_to_global_dsl.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm100.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/_interface_sm90.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py
  • python/cudnn/deepseek_sparse_attention/score_recompute/sparse_score_recompute_sm100.py
  • python/cudnn/deepseek_sparse_attention/sparse_attention_backward/_interface_sm100.py
  • python/cudnn/deepseek_sparse_attention/sparse_attention_backward/_interface_sm90.py
  • python/cudnn/deepseek_sparse_attention/sparse_attention_backward/api.py
  • python/cudnn/deepseek_sparse_attention/sparse_attention_backward/dsa_bwd_sm90.py
  • test/python/fe_api/dsa/dsa_reference.py
  • test/python/fe_api/dsa/test_DSA_indexer_forward.py
  • test/python/fe_api/dsa/test_DSA_indexer_top_k.py
💤 Files with no reviewable changes (1)
  • python/cudnn/deepseek_sparse_attention/indexer_top_k/local_to_global_dsl.py

Comment on lines +172 to +199
def load_Weights_packed_f32(
self,
base_ptr_i64: cutlass.Int64,
seqlen_q: cutlass.Int32,
sWeights: cute.Tensor, # (tile_m,) - FP32 SMEM buffer
m_block: cutlass.Int32,
tile_m: cutlass.Constexpr[int],
tidx: cutlass.Int32, # 0-31 for Warp 0
):
"""Load BF16 weights from packed global memory and convert once to FP32 SMEM."""
qhpkv = self.qhead_per_kvhead
rows_per_thread = cute.ceil_div(tile_m, cute.arch.WARP_SIZE)
for i in cutlass.range_constexpr(rows_per_thread):
row = i * cute.arch.WARP_SIZE + tidx
if row < tile_m:
idx = m_block * tile_m + row
m_idx = idx // qhpkv
h_idx = idx - m_idx * qhpkv
ptr = sm90_ops.elem_pointer_packed_i64(
base_ptr_i64,
h_idx,
m_idx,
seqlen_q,
cutlass.BFloat16,
cute.AddressSpace.gmem,
)
gmem_val = cute.make_tensor(ptr, (1,))
sWeights[row] = cutlass.Float32(gmem_val[0])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Verify dtype contract vs packed-load source dtype
rg -n -C2 "_check_type|Float16|BFloat16|load_Weights_packed_f32|elem_pointer_packed_i64" \
  python/cudnn/deepseek_sparse_attention/score_recompute

Repository: NVIDIA/cudnn-frontend

Length of output: 11961


Fix FP16 correctness in packed-weight loads (avoid hardcoded BF16 reads)

python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py hardcodes cutlass.BFloat16 in load_Weights_packed_f32 (elem_pointer_packed_i64(...)), but DenseScoreRecomputeSm90 (and SparseScoreRecomputeSm90) allow index-score mode with cutlass.Float16 weights (and enforce weights dtype matches mQ_type). In that case, FP16 payloads get interpreted as BF16 before converting to FP32, corrupting scores.

💡 Proposed fix
diff --git a/python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py b/python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py
@@
     def load_Weights_packed_f32(
         self,
         base_ptr_i64: cutlass.Int64,
         seqlen_q: cutlass.Int32,
         sWeights: cute.Tensor,  # (tile_m,) - FP32 SMEM buffer
         m_block: cutlass.Int32,
         tile_m: cutlass.Constexpr[int],
+        src_weight_dtype: cutlass.Constexpr[cutlass.Numeric],
         tidx: cutlass.Int32,  # 0-31 for Warp 0
     ):
@@
                 ptr = sm90_ops.elem_pointer_packed_i64(
                     base_ptr_i64,
                     h_idx,
                     m_idx,
                     seqlen_q,
-                    cutlass.BFloat16,
+                    src_weight_dtype,
                     cute.AddressSpace.gmem,
                 )
                 gmem_val = cute.make_tensor(ptr, (1,))
                 sWeights[row] = cutlass.Float32(gmem_val[0])
diff --git a/python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py b/python/cudnn/deepseek_sparse_attention/score_recompute/dense_score_recompute_sm90.py
@@
                     _pack_gqa.load_Weights_packed_f32(
                         mWeights_cur.iterator.toint(),
                         seqlen_q_packed,
                         sWeights,
                         eff_m_block,
                         self.tile_m,
+                        self.dtype,
                         wg_tidx,
                     )
@@
                     _pack_gqa.load_Weights_packed_f32(
                         mWeights_cur.iterator.toint(),
                         seqlen_q_packed,
                         sWeights,
                         eff_m_block,
                         self.tile_m,
+                        self.dtype,
                         wg_tidx,
                     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/deepseek_sparse_attention/score_recompute/pack_gqa.py` around
lines 172 - 199, The load_Weights_packed_f32 routine currently hardcodes
cutlass.BFloat16 when calling sm90_ops.elem_pointer_packed_i64 which causes FP16
payloads to be misinterpreted as BF16; change the call to use the actual weights
dtype (e.g., self.mQ_type or the attribute that represents the packed-weights
dtype) instead of cutlass.BFloat16 so elem_pointer_packed_i64 reads the correct
element type, and keep the existing conversion to cutlass.Float32(gmem_val[0])
after loading; also update the function docstring to reflect that the input
packed type is dynamic rather than always BF16.

@Anerudhan
Copy link
Copy Markdown
Collaborator

@cudnn-ci-bot run

@cudnn-ci-bot
Copy link
Copy Markdown

🚀 Running mirror pipeline

Branch: cudnn-gh/pr-263-cb25733
Pipeline: 53680579

Copy link
Copy Markdown
Collaborator

@saltyminty saltyminty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved but please check comments before merging.

)
return TupleDict(scores=scores)

b, s_q, h_q, d = q.shape
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking: removing this fallback means that this will no longer run on sm80 (since it's only sm90 or sm100). Likely not an issue, just making sure.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's intentional.

current_stream = resolve_stream(current_stream)

has_topk_length = topk_length is not None
compile_key = (dtype, head_dim, head_dim_v, block_tile, has_topk_length, num_head)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this head_dim is still needed right?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, head_dim is still needed.

@Anerudhan
Copy link
Copy Markdown
Collaborator

@cudnn-ci-bot run

@cudnn-ci-bot
Copy link
Copy Markdown

🚀 Running mirror pipeline

Branch: cudnn-gh/pr-263-cb25733
Pipeline: 53769862

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cat-enhancements mod-cutedsl CuTeDSL kernels, generated kernels, examples, or related integration work. orig-nv-eng Reported or requested by NVIDIA engineering.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants