Vedaanta/sdpa d256 cudnn 9 23 bypass oss #272

Open
vedaanta wants to merge 3 commits into NVIDIA:develop from vedaanta:vedaanta/sdpa-d256-cudnn-9-23-bypass-oss

Conversation

@vedaanta
Collaborator

@vedaanta vedaanta commented Jun 2, 2026

cuDNN 9.23.0 added native d=256 SDPA fprop and bprop support in the graph backend, so the OSS (cuteDSL) kernels at `cudnn.experimental.ops.sdpa` are no longer required when the linked backend is recent enough.

Add `_cudnn_supports_native_d256()` gated on `cudnn.backend_version() >= 92300` and require it to be `False` before routing fprop/bprop through the SM100 OSS wrappers. The pre-existing SM100+ device check is kept so older cuDNN versions still light up the OSS path on Blackwell.
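The gating described above can be sketched as a pure function. This is a minimal illustration, not the code in this PR: `encode_cudnn_version`, `should_route_d256_to_oss`, and the constant name are hypothetical, and the version encoding (major * 10000 + minor * 100 + patch) is an assumption inferred from the 9.23.0 / 92300 correspondence stated above.

```python
# Threshold from the PR description: cuDNN 9.23.0 corresponds to 92300.
CUDNN_BACKEND_D256_VERSION = 92300


def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    """Assumed integer encoding, consistent with 9.23.0 -> 92300."""
    return major * 10000 + minor * 100 + patch


def should_route_d256_to_oss(backend_version: int, sm_major: int) -> bool:
    """Hypothetical stand-in for the routing decision.

    The OSS (cuteDSL) path is chosen only when the device is SM100+
    (compute capability major >= 10) and the linked cuDNN backend
    is older than the release that added d=256 support.
    """
    device_is_sm100_plus = sm_major >= 10
    backend_supports_d256 = backend_version >= CUDNN_BACKEND_D256_VERSION
    return device_is_sm100_plus and not backend_supports_d256
```

Under these assumptions, a Blackwell device linked against cuDNN 9.22 would still take the OSS path, while the same device on 9.23.0 or later would go through the cuDNN graph backend.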

The `test_d256_uses_oss_forward_path` test now skips on cuDNN 9.23+ since the OSS bypass is intentional, and a new `test_d256_uses_graph_path_on_cudnn_9_23_plus` asserts that fprop/bprop populate the cuDNN graph cache (proving the OSS path is bypassed).

Also: `_skip_if_unsupported_d256` and `test_d256_uses_oss_forward_path` used `import cudnn.sdpa` inside the function body, which made `cudnn` a local variable and shadowed the module-level import as soon as any earlier line referenced `cudnn` (e.g. the new `cudnn.backend_version()` check). Switch to `importlib.import_module("cudnn.sdpa")` to avoid the binding.
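The shadowing issue can be reproduced with the standard library alone. The sketch below uses `os` / `os.path` as a stand-in for `cudnn` / `cudnn.sdpa`; the function names `shadowed` and `fixed` are illustrative and do not appear in the PR.

```python
import importlib
import os  # module-level import, analogous to `import cudnn` at file scope


def shadowed() -> str:
    # The `import os.path` statement below binds the name `os` inside this
    # function, so Python treats `os` as a local for the whole body.
    # This earlier reference therefore raises UnboundLocalError.
    sep = os.sep
    import os.path
    return sep


def fixed() -> str:
    sep = os.sep  # resolves to the module-level `os`
    # import_module returns the submodule without binding `os` locally.
    os_path = importlib.import_module("os.path")
    return sep + os_path.basename("a/b")
```

Calling `shadowed()` fails at the `os.sep` line even though `os` is imported at module scope, which mirrors what happened once a `cudnn.backend_version()` call was added above the in-function import.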

Summary by CodeRabbit

  • New Features

    • Improved routing logic for SDPA operations with dimension 256 to leverage native cuDNN backend support on compatible versions.
  • Tests

    • Added test coverage to verify SDPA correctly uses the cuDNN backend path with supported versions.

vedaanta and others added 3 commits May 28, 2026 09:58
cuDNN 9.23.0 added native d=256 SDPA fprop and bprop support in the
graph backend, so the OSS (cuteDSL) kernels at
`cudnn.experimental.ops.sdpa` are no longer required when the linked
backend is recent enough.

Add `_cudnn_supports_native_d256()` gated on
`cudnn.backend_version() >= 92300` and require it to be `False` before
routing fprop/bprop through the SM100 OSS wrappers. The pre-existing
SM100+ device check is kept so older cuDNN versions still light up the
OSS path on Blackwell.

The `test_d256_uses_oss_forward_path` test now skips on cuDNN 9.23+
since the OSS bypass is intentional, and a new
`test_d256_uses_graph_path_on_cudnn_9_23_plus` asserts that fprop/bprop
populate the cuDNN graph cache (proving the OSS path is bypassed).

Also: `_skip_if_unsupported_d256` and `test_d256_uses_oss_forward_path`
used `import cudnn.sdpa` inside the function body, which made `cudnn`
a local variable and shadowed the module-level import as soon as any
earlier line referenced `cudnn` (e.g. the new `cudnn.backend_version()`
check). Switch to `importlib.import_module("cudnn.sdpa")` to avoid the
binding.
- Rename `_CUDNN_NATIVE_D256_VERSION` → `_CUDNN_BACKEND_D256_VERSION`
  and `_cudnn_supports_native_d256()` → `_cudnn_backend_supports_d256()`
  per @Anerudhan's request that we say "cuDNN backend" instead of
  "cuDNN native". Update the surrounding log messages and skip strings
  to match.

- Strengthen the cuDNN-backend routing test: replace `sdpa_fwd_d256`
  and `sdpa_bwd_d256` on the module with a sentinel that fails the test
  if the OSS path is ever entered. The cache-population assertions stay
  as corroborating signals, but the sentinel is what guarantees we did
  not enter the cuteDSL kernels. Rename the test to
  `test_d256_uses_cudnn_backend_on_cudnn_9_23_plus`.
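The sentinel approach from the second commit can be sketched with a toy dispatcher. Everything here is hypothetical (the `oss_module` namespace, `dispatch_fwd`, and `run_with_sentinel` are stand-ins, not the real module or test); the point is only that a replaced function which raises gives a hard guarantee that the OSS path was not entered.

```python
import types

# Hypothetical stand-in for the module that exposes the OSS kernel.
oss_module = types.SimpleNamespace(
    sdpa_fwd_d256=lambda q: ("oss", q),
)


def dispatch_fwd(q, backend_supports_d256: bool):
    """Toy dispatcher: uses the OSS kernel only when the backend cannot."""
    if backend_supports_d256:
        return ("backend", q)
    return oss_module.sdpa_fwd_d256(q)


def _sentinel(*args, **kwargs):
    raise AssertionError("OSS d=256 path was entered unexpectedly")


def run_with_sentinel(q, backend_supports_d256: bool):
    """Swap in the sentinel, call the dispatcher, then restore the original."""
    original = oss_module.sdpa_fwd_d256
    oss_module.sdpa_fwd_d256 = _sentinel
    try:
        return dispatch_fwd(q, backend_supports_d256)
    finally:
        oss_module.sdpa_fwd_d256 = original
```

In a pytest test, `monkeypatch.setattr` would typically handle the swap and the restore, so the manual `try`/`finally` is only needed in this standalone sketch.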
@vedaanta vedaanta requested a review from Anerudhan June 2, 2026 20:47
@coderabbitai

coderabbitai Bot commented Jun 2, 2026

Walkthrough

The PR adds backend version gating to prefer native cuDNN d=256 support over OSS routing in scaled dot product attention. A new version threshold constant and helper function determine backend capability, which now gates both forward and backward paths. Tests updated to reflect version-aware behavior with new coverage for cuDNN backend path execution.

Changes

cuDNN d=256 Backend Version Gating for SDPA

| Layer | File(s) | Summary |
|---|---|---|
| Backend version support infrastructure | `python/cudnn/experimental/ops/sdpa.py` | `_CUDNN_BACKEND_D256_VERSION` constant and `_cudnn_backend_supports_d256()` helper enable version-aware routing decisions. |
| Forward path d=256 routing with version gating | `python/cudnn/experimental/ops/sdpa.py` | Forward path now computes `use_d256_oss_fwd` only when backend does not support d=256; updated debug logging and OSS dependency fallback warning. |
| Backward path d=256 routing with version gating | `python/cudnn/experimental/ops/sdpa.py` | Backward path applies same backend-version gating for `use_d256_oss_bwd` and updates corresponding fallback warning. |
| Test infrastructure and skip helper updates | `test/python/test_cudnn_sdpa_op.py` | Imports now include backend version constant and caches; `_skip_if_unsupported_d256` short-circuits for supported backends or probes OSS availability via dynamic import. |
| Existing OSS forward test update | `test/python/test_cudnn_sdpa_op.py` | `test_d256_uses_oss_forward_path` updated to use `importlib.import_module()` instead of direct import. |
| New cuDNN backend d=256 verification test | `test/python/test_cudnn_sdpa_op.py` | `test_d256_uses_cudnn_backend_on_cudnn_9_23_plus` monkeypatches OSS functions, verifies non-invocation, confirms cuDNN cache population, and validates correctness. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 A cuDNN update hops along the trail,

D=256 now routes with version details,

Backend support decides the path with care,

Tests verify that cuDNN gets its fair share! 🚀

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title describes the main change: bypassing OSS (cuteDSL) SDPA kernels for d=256 when cuDNN 9.23+ is available, which matches the core objective. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@Anerudhan Anerudhan added the mod-frontend, orig-nv-eng, and cat-enhancements labels Jun 2, 2026
@Anerudhan
Collaborator

@cudnn-ci-bot run

@Anerudhan
Collaborator

Source looks good. Can merge once CI passes

@cudnn-ci-bot

🚀 Running mirror pipeline

Branch: cudnn-gh/pr-272-31c8558
Pipeline: 53454746

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
python/cudnn/experimental/ops/sdpa.py (1)

651-657: ⚡ Quick win

Misleading debug message when routing to the cuDNN backend.

This branch fires whenever can_use_d256_oss_fwd is true but use_d256_oss_fwd is false — which now happens on SM100+ devices precisely because _cudnn_backend_supports_d256() is true. In that case the device does satisfy SM100+, yet the log says OSS path requires SM100+, got <device>, which points a debugging reader at the wrong cause.

♻️ Disambiguate the reason in the log
     if can_use_d256_oss_fwd and not use_d256_oss_fwd:
         _logger.debug(
-            "Routing d=256 forward through the cuDNN backend " "(cuDNN backend version %d, OSS path requires SM100+, got %s)",
+            "Routing d=256 forward through the cuDNN backend "
+            "(cuDNN backend version %d; native d=256 support=%s, device=%s)",
             cudnn.backend_version(),
+            _cudnn_backend_supports_d256(),
             q.device,
         )
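One way to make the logged cause unambiguous is to derive the reason from the two conditions explicitly. The helper below is a hypothetical sketch (the name `backend_routing_reason` and its boolean parameters are not part of the PR); it only illustrates how the two causes could be told apart in a log message.

```python
def backend_routing_reason(device_is_sm100_plus: bool, backend_supports_d256: bool) -> str:
    """Explain why d=256 goes through the cuDNN backend instead of the OSS path."""
    reasons = []
    if backend_supports_d256:
        reasons.append("cuDNN backend already supports d=256")
    if not device_is_sm100_plus:
        reasons.append("device lacks SM100+ required by the OSS path")
    # If neither condition holds, something else (for example missing OSS
    # dependencies) must have disabled the OSS path.
    return "; ".join(reasons) or "OSS path unavailable for another reason"
```

The returned string could then be passed as a single `%s` argument to the existing `_logger.debug` call alongside the backend version and device.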
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudnn/experimental/ops/sdpa.py` around lines 651 - 657, The debug
message is misleading when can_use_d256_oss_fwd is true but use_d256_oss_fwd is
false; update the _logger.debug call in the block that checks
can_use_d256_oss_fwd and not use_d256_oss_fwd to disambiguate the cause: inspect
_device_supports_d256(q.device) and _cudnn_backend_supports_d256() and log
whether routing to the cuDNN backend is happening because the device lacks
SM100+ (i.e., _device_supports_d256 is false) or because the cuDNN backend
already supports d=256 (i.e., _cudnn_backend_supports_d256 is true), still
including context like cudnn.backend_version() and q.device in the message;
reference the variables use_d256_oss_fwd, can_use_d256_oss_fwd,
_device_supports_d256, _cudnn_backend_supports_d256, and the existing
_logger.debug call to locate where to change the text.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: abf7efa5-5691-4713-b7f5-436438c85c50

📥 Commits

Reviewing files that changed from the base of the PR and between e851cc8 and 31c8558.

📒 Files selected for processing (2)
  • python/cudnn/experimental/ops/sdpa.py
  • test/python/test_cudnn_sdpa_op.py

Comment on lines 154 to 156
    major, minor = torch.cuda.get_device_capability()
    if major < 10:
        pytest.skip("d=256 backward path requires SM100+")

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Discard the unused minor to satisfy lint.

minor is never used; Ruff flags this (RUF059). Prefix with _ to keep the lint clean.

🧹 Proposed fix
-    major, minor = torch.cuda.get_device_capability()
+    major, _minor = torch.cuda.get_device_capability()
     if major < 10:
🧰 Tools
🪛 Ruff (0.15.15)

[warning] 154-154: Unpacked variable minor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/test_cudnn_sdpa_op.py` around lines 154 - 156, The variable
`minor` from the call to torch.cuda.get_device_capability() is unused and
triggers a lint warning; change the tuple unpack to capture it as a disposable
name (e.g., use `_minor` or `_`) so only `major` is used (leave the existing
check if major < 10 and the pytest.skip call intact), updating the unpack in the
test_cudnn_sdpa_op.py code where torch.cuda.get_device_capability() is invoked.

@Anerudhan
Collaborator

The CI failed. @vedaanta can you help update the tests for Ampere as discussed

