[Perf] CUDA graph 3: add fallbacks #416

Merged
hughperkins merged 182 commits into main from
hp/cuda-graph-mvp-3-add-fallback
Mar 16, 2026
Conversation

@hughperkins (Collaborator)

  • add fallbacks on non-CUDA platforms

Made-with: Cursor

Issue: #


When QD_CUDA_GRAPH=1, kernels with 2+ top-level for loops (offloaded
tasks) are captured into a CUDA graph on first launch and replayed on
subsequent launches, eliminating per-kernel launch overhead.

Uses the explicit graph node API (cuGraphAddKernelNode) with persistent
device arg/result buffers. Assumes stable ndarray device pointers.
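The capture-once / replay-after behavior can be modeled in plain Python (the real implementation uses the CUDA driver graph API; the `KernelLauncher` shape here is illustrative, not the actual quadrants code):

```python
class KernelLauncher:
    """Toy model: the first launch 'captures' the work, later launches replay it."""
    def __init__(self):
        self.graph = None
        self.captures = 0
        self.replays = 0

    def launch(self, tasks, args):
        if self.graph is None:
            self.graph = list(tasks)   # capture all top-level tasks once
            self.captures += 1
        else:
            self.replays += 1          # replay: no per-task launch overhead
        result = args
        for task in self.graph:
            result = task(result)
        return result

launcher = KernelLauncher()
double = lambda xs: [2 * x for x in xs]
inc = lambda xs: [x + 1 for x in xs]
out1 = launcher.launch([double, inc], [1, 2])   # first call: capture
out2 = launcher.launch([double, inc], [3])      # second call: replay
```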

Made-with: Cursor
Replace the global QD_CUDA_GRAPH=1 env var with a per-kernel opt-in.
The flag flows from the Python decorator through LaunchContextBuilder
to the CUDA kernel launcher, avoiding interference with internal
kernels like ndarray_to_ext_arr.
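A minimal sketch of how a per-kernel opt-in flag can flow from a decorator down to a launcher, so internal kernels stay on the plain path (names `kernel`, `cuda_graph`, and `Launcher` are illustrative, not the actual quadrants API):

```python
import functools

class Launcher:
    """Toy launcher: only kernels that opted in take the graph path."""
    def __init__(self):
        self.graph_launches = 0
        self.plain_launches = 0

    def launch(self, fn, use_cuda_graph, *args):
        if use_cuda_graph:
            self.graph_launches += 1   # capture/replay path
        else:
            self.plain_launches += 1   # ordinary launch path
        return fn(*args)

LAUNCHER = Launcher()

def kernel(fn=None, *, cuda_graph=False):
    """Decorator that carries a per-kernel flag down to the launcher."""
    def wrap(f):
        @functools.wraps(f)
        def runner(*args):
            return LAUNCHER.launch(f, cuda_graph, *args)
        return runner
    return wrap(fn) if fn is not None else wrap

@kernel(cuda_graph=True)
def saxpy(a, x, y):
    return [a * xi + yi for xi, yi in zip(x, y)]

@kernel
def internal_copy(x):   # no opt-in: stays on the plain path
    return list(x)

out = saxpy(2.0, [1.0, 2.0], [3.0, 4.0])
```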

Made-with: Cursor
Verify that cuda_graph=True is a harmless no-op on non-CUDA backends
(tested on x64/CPU). Passes on both x64 and CUDA.

Made-with: Cursor
On each graph replay, re-resolve ndarray device pointers and re-upload
the arg buffer to the persistent device buffer. This ensures correct
results when the kernel is called with different ndarrays after the
graph was first captured.

Refactored ndarray pointer resolution into resolve_ctx_ndarray_ptrs().

Made-with: Cursor
Apply lint formatting fixes (clang-format, ruff) and remove
cuda_graph flag from autodiff adjoint kernel until the interaction
with reverse-mode AD is validated.
Implements @qd.kernel(graph_while='flag_arg'), which wraps the kernel's
offloaded tasks in a CUDA conditional while node (requires SM 9.0+).
The named argument is a scalar i32 ndarray on device; the loop
continues while its value is non-zero.

Key implementation details:
- Condition kernel compiled as PTX and JIT-linked with libcudadevrt.a
  at runtime to access cudaGraphSetConditional device function
- CU_GRAPH_COND_ASSIGN_DEFAULT flag ensures handle is reset each launch
- Works with both counter-based (decrement to 0) and boolean flag
  (set to 0 when done) patterns
- graph_while implicitly enables cuda_graph=True

Tests: counter, boolean done flag, multiple loops, graph replay.
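The do-while semantics of the conditional node can be illustrated with a plain-Python model (illustrative only; the real condition kernel runs on device): the body always executes at least once, and the loop repeats while the flag's value is non-zero.

```python
def graph_while_run(body, flag):
    """Model of the conditional while node. `flag` is a 1-element list
    standing in for a scalar i32 device ndarray."""
    iterations = 0
    while True:
        body(flag)          # body runs before the condition is re-read
        iterations += 1
        if flag[0] == 0:    # loop continues while the value is non-zero
            break
    return iterations

# counter pattern: decrement to 0
counter = [3]
n = graph_while_run(lambda f: f.__setitem__(0, f[0] - 1), counter)

# boolean done-flag pattern: body sets the flag to 0 when finished
state = {"steps": 0}
def step(flag):
    state["steps"] += 1
    if state["steps"] >= 5:
        flag[0] = 0

flag = [1]
m = graph_while_run(step, flag)
```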
…allback

The graph_while_arg_id was computed using Python-level parameter indices,
which is wrong when struct parameters are flattened into many C++ args
(e.g. Genesis solver has 40 C++ params from 6 Python params). Now tracks
the flattened C++ arg index during launch context setup and caches it.

Also adds C++ do-while fallback loops for CPU, CUDA (non-graph path), and
AMDGPU backends so graph_while works identically on all platforms.
Falls back to non-graph path with a warning on pre-Hopper GPUs,
instead of failing with an unhelpful JIT link error.
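The flattened-index bookkeeping described above can be sketched as follows (a toy helper; the parameter names and `fields_per_struct` mapping are illustrative, not the actual quadrants internals):

```python
def flattened_arg_index(py_params, target_name, fields_per_struct):
    """Map a Python-level parameter name to its flattened C++ arg index.
    fields_per_struct: number of C++ args each struct-typed parameter
    flattens into; everything else contributes a single arg."""
    idx = 0
    for name in py_params:
        if name == target_name:
            return idx
        idx += fields_per_struct.get(name, 1)  # scalars/ndarrays -> 1 arg
    raise ValueError(f"graph_while arg {target_name!r} not found")

# e.g. 6 Python params where two structs flatten into many C++ args
params = ["state", "dt", "flag", "out", "cfg", "n"]
widths = {"state": 20, "cfg": 17}
i = flattened_arg_index(params, "flag", widths)
```

Using the Python-level index (2) here would point at the wrong C++ argument once `state` expands into 20 args.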
Checks env-var-derived paths before the hardcoded fallbacks, so
custom toolkit installs (e.g. conda, non-default prefix) are found.
Document cuda_graph=True and graph_while API in kernel() docstring,
and add a user guide page covering usage patterns, cross-platform
behavior, and the do-while semantics constraint.
The graph path doesn't copy the result buffer back to the host,
so struct returns would silently return stale data. Error early
instead of producing wrong results.
Verifies that calling a cuda_graph=True kernel first with small
arrays then with larger ones produces correct results for all
elements — catches stale grid dims if the graph were incorrectly
replayed from the first capture.
Re-add documentation comments for |transfers|, |device_ptrs|,
zero-sized array handling, external array logic, and the
host copy-back section in the non-graph launch path.
Verify that a cuda_graph=True kernel works correctly after a
reset/reinit cycle — exercises the full teardown and rebuild
of the KernelLauncher and its graph cache.
The condition kernel's flag pointer was baked into the CUDA graph at
creation time. Passing a different ndarray on replay would cause the
condition kernel to read from a stale device address. Invalidate the
cached graph when the flag pointer changes so it gets rebuilt.
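A toy model of that invalidation rule (illustrative; the real code keys a captured CUDA graph, not a dict entry):

```python
class GraphCache:
    """Caches one 'captured graph' per kernel; the entry is invalidated
    when the condition flag's device pointer changes."""
    def __init__(self):
        self.entries = {}   # kernel name -> (flag_ptr, graph)
        self.captures = 0

    def get_or_capture(self, kernel, flag_ptr):
        entry = self.entries.get(kernel)
        if entry is not None and entry[0] == flag_ptr:
            return entry[1]                 # same pointer: replay
        self.captures += 1                  # pointer changed: rebuild
        graph = f"graph#{self.captures}"
        self.entries[kernel] = (flag_ptr, graph)
        return graph

cache = GraphCache()
g1 = cache.get_or_capture("step", flag_ptr=0x1000)
g2 = cache.get_or_capture("step", flag_ptr=0x1000)  # same ndarray: replay
g3 = cache.get_or_capture("step", flag_ptr=0x2000)  # new ndarray: rebuild
```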
Raise ValueError immediately if the graph_while name doesn't match any
kernel parameter, instead of silently running the kernel once without
looping. Also document the CUDA API version for CudaGraphNodeParams.
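That fail-fast validation could look like this (an illustrative sketch using `inspect.signature`, not the actual quadrants code):

```python
import inspect

def check_graph_while_arg(fn, graph_while):
    """Raise ValueError immediately if graph_while names no parameter."""
    params = list(inspect.signature(fn).parameters)
    if graph_while not in params:
        raise ValueError(
            f"graph_while={graph_while!r} does not match any parameter "
            f"of {fn.__name__}; available: {params}")
    return params.index(graph_while)

def step(state, flag, out):
    pass

ok = check_graph_while_arg(step, "flag")
try:
    check_graph_while_arg(step, "flags")   # typo: should fail loudly
    failed = False
except ValueError:
    failed = True
```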
Add get_cuda_graph_cache_size() through the KernelLauncher -> Program ->
pybind chain so tests can verify that graphs are actually being created
(or not) rather than only checking output correctness.

Made-with: Cursor
Tracks whether the CUDA graph cache was used on the most recent kernel
launch, exposed through KernelLauncher -> Program -> pybind so tests
can assert the graph path was (or was not) taken.

Made-with: Cursor
Every test now verifies graph caching behavior, not just output
correctness. Cross-platform test uses platform_supports_graph to
make assertions conditional on the backend.

Made-with: Cursor
The LLVM x64 backend generates extra tasks per ndarray argument for
serialization/setup, so exact equality checks fail. Use >= instead.

Made-with: Cursor
…-3-add-fallback

Made-with: Cursor

# Conflicts:
#	docs/source/user_guide/cuda_graph.md
@hughperkins (Collaborator, Author)

I have read and reviewed every line added in this PR. I take responsibility for the lines added and removed in this PR, and won't blame any issues on Opus.

Resolve conflicts from squash-merged MVP-1 PR (#405) vs branch's
pre-existing MVP-1 merge commits. Keep all graph_do_while (MVP-2)
additions. Incorporate grad_ptr local variable cleanup from main.
env.sh is generated by ./build.py and should not be tracked.
…raph-while

# Conflicts:
#	.github/workflows/test_gpu.yml
Base automatically changed from hp/cuda-graph-mvp-2-graph-while to main March 16, 2026 18:32
…dd-fallback

# Conflicts:
#	docs/source/user_guide/cuda_graph.md
#	python/quadrants/lang/misc.py
#	quadrants/runtime/amdgpu/kernel_launcher.cpp
#	quadrants/runtime/cpu/kernel_launcher.cpp
#	quadrants/runtime/cuda/cuda_graph_manager.cpp
#	tests/python/test_cuda_graph_do_while.py
@hughperkins hughperkins enabled auto-merge (squash) March 16, 2026 18:48
@hughperkins hughperkins merged commit a346d2d into main Mar 16, 2026
47 checks passed
@hughperkins hughperkins deleted the hp/cuda-graph-mvp-3-add-fallback branch March 16, 2026 19:46