Vedaanta/sdpa d256 cudnn 9 23 bypass oss#272
cuDNN 9.23.0 added native d=256 SDPA fprop and bprop support in the
graph backend, so the OSS (cuteDSL) kernels at
`cudnn.experimental.ops.sdpa` are no longer required when the linked
backend is recent enough.
Add `_cudnn_supports_native_d256()` gated on
`cudnn.backend_version() >= 92300` and require it to be `False` before
routing fprop/bprop through the SM100 OSS wrappers. The pre-existing
SM100+ device check is kept so older cuDNN versions still light up the
OSS path on Blackwell.
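The gating described above can be sketched as a pure function. This is an illustrative sketch, not the code in `sdpa.py`: `use_d256_oss_path` and its parameters are hypothetical names, and the backend version and SM major are passed in as plain integers (in the real module they would presumably come from `cudnn.backend_version()` and `torch.cuda.get_device_capability()`).

```python
# Threshold taken from the PR description: cuDNN 9.23.0 reports 92300.
_CUDNN_BACKEND_D256_VERSION = 92300


def cudnn_backend_supports_d256(backend_version: int) -> bool:
    """True when the linked cuDNN backend handles d=256 SDPA itself."""
    return backend_version >= _CUDNN_BACKEND_D256_VERSION


def use_d256_oss_path(head_dim: int, sm_major: int, backend_version: int) -> bool:
    """Hypothetical helper: decide whether d=256 goes through the OSS wrappers."""
    if head_dim != 256:
        return False
    device_ok = sm_major >= 10  # the pre-existing SM100+ device check
    # The OSS path is only taken when the backend cannot handle d=256 itself.
    return device_ok and not cudnn_backend_supports_d256(backend_version)
```

With this shape, an SM100 device on an older backend still takes the OSS path, while the same device on 9.23+ falls through to the cuDNN graph backend.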
The `test_d256_uses_oss_forward_path` test now skips on cuDNN 9.23+
since the OSS bypass is intentional, and a new
`test_d256_uses_graph_path_on_cudnn_9_23_plus` asserts that fprop/bprop
populate the cuDNN graph cache (proving the OSS path is bypassed).
Also: `_skip_if_unsupported_d256` and `test_d256_uses_oss_forward_path`
used `import cudnn.sdpa` inside the function body, which made `cudnn`
a local variable and shadowed the module-level import as soon as any
earlier line referenced `cudnn` (e.g. the new `cudnn.backend_version()`
check). Switch to `importlib.import_module("cudnn.sdpa")` to avoid the
binding.
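The shadowing issue can be reproduced with the standard library alone. The sketch below uses `os` / `os.path` as a stand-in for `cudnn` / `cudnn.sdpa`; the function names are made up for illustration.

```python
import importlib
import os


def shadowed() -> str:
    """`import os.path` below binds `os` as a local for the whole function."""
    try:
        os.getcwd()  # fails: local `os` is referenced before assignment
    except UnboundLocalError:
        return "UnboundLocalError"
    import os.path  # this statement is what makes `os` function-local
    return "ok"


def fixed() -> bool:
    """importlib.import_module does not bind the top-level package name."""
    cwd = os.getcwd()  # resolves to the module-level import
    os_path = importlib.import_module("os.path")
    return os_path.isabs(cwd)
```

Because Python decides name scope at compile time, the failure appears as soon as any line above the in-function import touches the package name, which matches what happened once the `cudnn.backend_version()` check was added.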
- Rename `_CUDNN_NATIVE_D256_VERSION` → `_CUDNN_BACKEND_D256_VERSION` and `_cudnn_supports_native_d256()` → `_cudnn_backend_supports_d256()` per @Anerudhan's request that we say "cuDNN backend" instead of "cuDNN native". Update the surrounding log messages and skip strings to match.
- Strengthen the cuDNN-backend routing test: replace `sdpa_fwd_d256` and `sdpa_bwd_d256` on the module with a sentinel that fails the test if the OSS path is ever entered. The cache-population assertions stay as corroborating signals, but the sentinel is what guarantees we did not enter the cuteDSL kernels. Rename the test to `test_d256_uses_cudnn_backend_on_cudnn_9_23_plus`.
📝 Walkthrough
The PR adds backend version gating to prefer native cuDNN d=256 support over OSS routing in scaled dot product attention. A new version threshold constant and helper function determine backend capability, which now gates both forward and backward paths. Tests are updated to reflect version-aware behavior, with new coverage for cuDNN backend path execution.
Changes: cuDNN d=256 Backend Version Gating for SDPA
@cudnn-ci-bot run
Source looks good. Can merge once CI passes.
🚀 Running mirror pipeline
Branch: cudnn-gh/pr-272-31c8558
Actionable comments posted: 1
🧹 Nitpick comments (1)
python/cudnn/experimental/ops/sdpa.py (1)
651-657: ⚡ Quick win
Misleading debug message when routing to the cuDNN backend.
This branch fires whenever `can_use_d256_oss_fwd` is true but `use_d256_oss_fwd` is false, which now happens on SM100+ devices precisely because `_cudnn_backend_supports_d256()` is true. In that case the device does satisfy SM100+, yet the log says `OSS path requires SM100+, got <device>`, which points a debugging reader at the wrong cause.
♻️ Disambiguate the reason in the log
  if can_use_d256_oss_fwd and not use_d256_oss_fwd:
      _logger.debug(
-         "Routing d=256 forward through the cuDNN backend " "(cuDNN backend version %d, OSS path requires SM100+, got %s)",
+         "Routing d=256 forward through the cuDNN backend "
+         "(cuDNN backend version %d; native d=256 support=%s, device=%s)",
          cudnn.backend_version(),
+         _cudnn_backend_supports_d256(),
          q.device,
      )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/cudnn/experimental/ops/sdpa.py` around lines 651 - 657, The debug message is misleading when can_use_d256_oss_fwd is true but use_d256_oss_fwd is false; update the _logger.debug call in the block that checks can_use_d256_oss_fwd and not use_d256_oss_fwd to disambiguate the cause: inspect _device_supports_d256(q.device) and _cudnn_backend_supports_d256() and log whether routing to the cuDNN backend is happening because the device lacks SM100+ (i.e., _device_supports_d256 is false) or because the cuDNN backend already supports d=256 (i.e., _cudnn_backend_supports_d256 is true), still including context like cudnn.backend_version() and q.device in the message; reference the variables use_d256_oss_fwd, can_use_d256_oss_fwd, _device_supports_d256, _cudnn_backend_supports_d256, and the existing _logger.debug call to locate where to change the text.
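One way to separate the two causes named in the prompt is to compute the reason before logging it. This is only a sketch: `d256_backend_routing_reason` is a hypothetical helper, the two booleans stand in for `_device_supports_d256(q.device)` and `_cudnn_backend_supports_d256()`, and checking the backend first is an arbitrary choice for the case where both conditions hold.

```python
import logging
from typing import Optional

_logger = logging.getLogger("sdpa_d256_sketch")


def d256_backend_routing_reason(
    device_supports_d256: bool, backend_supports_d256: bool
) -> Optional[str]:
    """Return why d=256 is routed to the cuDNN backend, or None for the OSS path."""
    if backend_supports_d256:
        return "cuDNN backend already supports d=256"
    if not device_supports_d256:
        return "OSS path requires SM100+"
    return None


def log_d256_routing(device_supports_d256, backend_supports_d256, backend_version, device):
    reason = d256_backend_routing_reason(device_supports_d256, backend_supports_d256)
    if reason is not None:
        _logger.debug(
            "Routing d=256 forward through the cuDNN backend "
            "(%s; cuDNN backend version %d, device=%s)",
            reason, backend_version, device,
        )
    return reason
```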
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@test/python/test_cudnn_sdpa_op.py`:
- Around line 154-156: The variable `minor` from the call to
torch.cuda.get_device_capability() is unused and triggers a lint warning; change
the tuple unpack to capture it as a disposable name (e.g., use `_minor` or `_`)
so only `major` is used (leave the existing check if major < 10 and the
pytest.skip call intact), updating the unpack in the test_cudnn_sdpa_op.py code
where torch.cuda.get_device_capability() is invoked.
---
Nitpick comments:
In `@python/cudnn/experimental/ops/sdpa.py`:
- Around line 651-657: The debug message is misleading when can_use_d256_oss_fwd
is true but use_d256_oss_fwd is false; update the _logger.debug call in the
block that checks can_use_d256_oss_fwd and not use_d256_oss_fwd to disambiguate
the cause: inspect _device_supports_d256(q.device) and
_cudnn_backend_supports_d256() and log whether routing to the cuDNN backend is
happening because the device lacks SM100+ (i.e., _device_supports_d256 is false)
or because the cuDNN backend already supports d=256 (i.e.,
_cudnn_backend_supports_d256 is true), still including context like
cudnn.backend_version() and q.device in the message; reference the variables
use_d256_oss_fwd, can_use_d256_oss_fwd, _device_supports_d256,
_cudnn_backend_supports_d256, and the existing _logger.debug call to locate
where to change the text.
📒 Files selected for processing (2)
python/cudnn/experimental/ops/sdpa.py
test/python/test_cudnn_sdpa_op.py
major, minor = torch.cuda.get_device_capability()
if major < 10:
    pytest.skip("d=256 backward path requires SM100+")
Discard the unused minor to satisfy lint.
minor is never used; Ruff flags this (RUF059). Prefix with _ to keep the lint clean.
🧹 Proposed fix
- major, minor = torch.cuda.get_device_capability()
+ major, _minor = torch.cuda.get_device_capability()
  if major < 10:
📝 Committable suggestion
major, _minor = torch.cuda.get_device_capability()
if major < 10:
    pytest.skip("d=256 backward path requires SM100+")
🧰 Tools
🪛 Ruff (0.15.15)
[warning] 154-154: Unpacked variable minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
The CI failed. @vedaanta, can you help update the tests for Ampere as discussed?