[Perf] CUDA graph 4: call from multiple locations #420
Draft
hughperkins wants to merge 194 commits into main from
Conversation
When QD_CUDA_GRAPH=1, kernels with 2+ top-level for loops (offloaded tasks) are captured into a CUDA graph on first launch and replayed on subsequent launches, eliminating per-kernel launch overhead. Uses the explicit graph node API (cuGraphAddKernelNode) with persistent device arg/result buffers. Assumes stable ndarray device pointers. Made-with: Cursor
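The capture-once, replay-thereafter behavior can be sketched in plain Python; the `GraphCache` name and shape here are illustrative stand-ins, not the actual launcher code:

```python
class GraphCache:
    """Toy model of the capture-on-first-launch pattern: the expensive
    build happens once per kernel; later launches replay the cached graph."""

    def __init__(self):
        self.graphs = {}   # kernel name -> "captured graph"
        self.builds = 0    # how many times the capture cost was paid

    def launch(self, kernel_name, body, args):
        if kernel_name not in self.graphs:
            self.builds += 1                   # one-time capture cost
            self.graphs[kernel_name] = body    # stand-in for graph instantiation
        return self.graphs[kernel_name](args)  # replay

cache = GraphCache()
double_all = lambda xs: [2 * x for x in xs]
out1 = cache.launch("double", double_all, [1, 2, 3])
out2 = cache.launch("double", double_all, [4, 5])
```

Repeated launches of the same kernel pay the build cost only once, which is the per-kernel launch-overhead saving described above.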
Replace the global QD_CUDA_GRAPH=1 env var with a per-kernel opt-in. The flag flows from the Python decorator through LaunchContextBuilder to the CUDA kernel launcher, avoiding interference with internal kernels like ndarray_to_ext_arr.
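A minimal sketch of how a per-kernel flag can flow from the decorator to the launch path; the `kernel` decorator here is a hypothetical stand-in, not the real quadrants implementation:

```python
import functools

def kernel(cuda_graph=False):
    """Hypothetical per-kernel opt-in: the decorator records the flag on
    the wrapper so the launch path can read it later, instead of
    consulting a global environment variable."""
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        wrapper.cuda_graph = cuda_graph  # read later by the launcher
        return wrapper
    return decorate

@kernel(cuda_graph=True)
def user_kernel(x):
    return x + 1

@kernel()  # internal kernels stay off the graph path by default
def internal_helper(x):
    return x
```

Because the flag lives on the kernel object rather than in the environment, internal kernels are unaffected unless they opt in explicitly.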
Verify that cuda_graph=True is a harmless no-op on non-CUDA backends (tested on x64/CPU). Passes on both x64 and CUDA.
On each graph replay, re-resolve ndarray device pointers and re-upload the arg buffer to the persistent device buffer. This ensures correct results when the kernel is called with different ndarrays after the graph was first captured. Refactored ndarray pointer resolution into resolve_ctx_ndarray_ptrs().
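The re-resolve-and-re-upload step can be modeled in pure Python; `replay_with_fresh_args` and the use of `id()` as a stand-in for device pointers are illustrative assumptions:

```python
def replay_with_fresh_args(persistent_buffer, ndarrays):
    """Toy model of the fix: before every replay, re-resolve each
    ndarray's 'device pointer' and overwrite the persistent arg buffer,
    so the replayed graph reads whatever arrays the caller passed this
    time rather than the ones present at capture."""
    resolved = [id(a) for a in ndarrays]  # stand-in for resolve_ctx_ndarray_ptrs()
    persistent_buffer.clear()
    persistent_buffer.extend(resolved)    # stand-in for the host->device upload
    return list(persistent_buffer)

buf = []
a, b = [1, 2], [3, 4]
first = replay_with_fresh_args(buf, [a, b])
c = [5, 6]
second = replay_with_fresh_args(buf, [a, c])  # different ndarray on 2nd call
```

The second call overwrites only the slot for the swapped array, so the graph itself never needs rebuilding.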
Apply lint formatting fixes (clang-format, ruff) and remove cuda_graph flag from autodiff adjoint kernel until the interaction with reverse-mode AD is validated.
Implements @qd.kernel(graph_while='flag_arg'), which wraps the kernel's offloaded tasks in a CUDA conditional while node (requires SM 9.0+). The named argument is a scalar i32 ndarray on device; the loop continues while its value is non-zero. Key implementation details:
- Condition kernel compiled as PTX and JIT-linked with libcudadevrt.a at runtime to access the cudaGraphSetConditional device function
- CU_GRAPH_COND_ASSIGN_DEFAULT flag ensures the handle is reset each launch
- Works with both counter-based (decrement to 0) and boolean flag (set to 0 when done) patterns
- graph_while implicitly enables cuda_graph=True
Tests: counter, boolean done flag, multiple loops, graph replay.
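The do-while semantics of the conditional node can be modeled in pure Python, independent of CUDA; both supported patterns are shown, and all names here are illustrative:

```python
def graph_while(body, flag):
    """Pure-Python model of the conditional-while semantics: the body
    runs at least once (do-while) and the loop continues while the
    scalar flag (flag[0]) is non-zero, matching both the counter
    pattern and the boolean done-flag pattern."""
    iterations = 0
    while True:
        body(flag)        # offloaded tasks run inside the loop
        iterations += 1
        if flag[0] == 0:  # models cudaGraphSetConditional(handle, flag != 0)
            break
    return iterations

# Counter pattern: decrement to zero.
counter = [3]
n1 = graph_while(lambda f: f.__setitem__(0, f[0] - 1), counter)

# Boolean done-flag pattern: clear the flag on the final pass.
state = {"work_left": 2}
def step(f):
    state["work_left"] -= 1
    if state["work_left"] == 0:
        f[0] = 0
flag = [1]
n2 = graph_while(step, flag)
```

Note the do-while shape: even a flag that starts at zero would still get one body execution, which is the semantics constraint the user guide documents.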
…allback
The graph_while_arg_id was computed using Python-level parameter indices, which is wrong when struct parameters are flattened into many C++ args (e.g. the Genesis solver has 40 C++ params from 6 Python params). Now tracks the flattened C++ arg index during launch context setup and caches it. Also adds C++ do-while fallback loops for the CPU, CUDA (non-graph path), and AMDGPU backends so graph_while works identically on all platforms.
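The index-flattening bug can be illustrated with a small sketch; the `(name, n_fields)` parameter layout is an assumption for illustration, not the real launch-context structure:

```python
def flattened_flag_index(params, flag_name):
    """Sketch of the fix: walk the parameters in launch order, counting
    how many flattened C++ args each one contributes, so the flag's
    index stays correct even when a struct param expands to many fields."""
    cpp_index = 0
    for name, n_fields in params:
        if name == flag_name:
            return cpp_index
        cpp_index += n_fields
    raise ValueError(f"graph_while arg {flag_name!r} not found")

# A struct param contributing 5 flattened C++ args shifts later indices:
params = [("solver_state", 5), ("dt", 1), ("done_flag", 1)]
idx = flattened_flag_index(params, "done_flag")
```

The naive Python-level index of `done_flag` would be 2, but the flattened C++ index is 6, which is why indexing by Python parameter position was wrong.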
Falls back to non-graph path with a warning on pre-Hopper GPUs, instead of failing with an unhelpful JIT link error.
Checks env-var-derived paths before the hardcoded fallbacks, so custom toolkit installs (e.g. conda, non-default prefix) are found.
Document cuda_graph=True and graph_while API in kernel() docstring, and add a user guide page covering usage patterns, cross-platform behavior, and the do-while semantics constraint.
The graph path doesn't copy the result buffer back to the host, so struct returns would silently return stale data. Error early instead of producing wrong results.
Verifies that calling a cuda_graph=True kernel first with small arrays then with larger ones produces correct results for all elements — catches stale grid dims if the graph were incorrectly replayed from the first capture.
Re-add documentation comments for |transfers|, |device_ptrs|, zero-sized array handling, external array logic, and the host copy-back section in the non-graph launch path.
Verify that a cuda_graph=True kernel works correctly after a reset/reinit cycle — exercises the full teardown and rebuild of the KernelLauncher and its graph cache.
The condition kernel's flag pointer was baked into the CUDA graph at creation time. Passing a different ndarray on replay would cause the condition kernel to read from a stale device address. Invalidate the cached graph when the flag pointer changes so it gets rebuilt.
Raise ValueError immediately if the graph_while name doesn't match any kernel parameter, instead of silently running the kernel once without looping. Also document the CUDA API version for CudaGraphNodeParams.
Add get_cuda_graph_cache_size() through the KernelLauncher -> Program -> pybind chain so tests can verify that graphs are actually being created (or not) rather than only checking output correctness.
Tracks whether the CUDA graph cache was used on the most recent kernel launch, exposed through KernelLauncher -> Program -> pybind so tests can assert the graph path was (or was not) taken.
Every test now verifies graph caching behavior, not just output correctness. Cross-platform test uses platform_supports_graph to make assertions conditional on the backend.
hughperkins
commented
Mar 16, 2026
.reg .b64 %rd<5>;

// Load the two kernel parameters into registers:
// %rd1 = conditional node handle
Collaborator
Author
let's get these comments back
Force-pushed from 36d6d2d to fd78ff6
hughperkins
commented
Mar 16, 2026
selp.u32 %r2, 1, 0, %p1;

// Tell the conditional while node whether to loop again or stop.
// cudaGraphSetConditional(handle, should_continue)
Collaborator
Author
let's get this comment back
Force-pushed from fd78ff6 to 7f03036
hughperkins
commented
Mar 16, 2026
    "Reuse the same ndarray for the condition parameter across calls.");
if (use_graph_do_while && cached.counter_ptr_slot) {
  void *flag_ptr = ctx.graph_do_while_flag_dev_ptr;
  CUDADriver::get_instance().memcpy_host_to_device(cached.counter_ptr_slot,
Collaborator
Author
I wonder if this could/should be async?
Use device-side pointer indirection so the condition kernel reads the
counter address through a persistent slot. Updating the slot via memcpy
before each launch lets different ndarrays be used without rebuilding
the CUDA graph.
Replaces the previous error ("condition ndarray changed between calls")
with transparent support for swapping.
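The pointer-indirection scheme can be modeled in plain Python, with a list standing in for device memory and an attribute standing in for the persistent slot; all names here are hypothetical:

```python
class ConditionSlot:
    """Toy model of pointer indirection: the graph bakes in the address
    of a persistent slot, and the slot stores the current flag buffer.
    Swapping ndarrays only rewrites the slot (a small memcpy before
    launch), never the graph itself."""

    def __init__(self, flag_buffer):
        self.target = flag_buffer  # slot -> current flag storage
        self.rebuilds = 0          # graph rebuilds (should stay 0)

    def point_at(self, flag_buffer):
        # models memcpy_host_to_device(counter_ptr_slot, &new_flag_ptr)
        self.target = flag_buffer

    def read_flag(self):
        # the condition kernel dereferences the slot, then the flag
        return self.target[0]

a, b = [1], [7]
slot = ConditionSlot(a)
v1 = slot.read_flag()
slot.point_at(b)  # caller passes a different ndarray; same graph
v2 = slot.read_flag()
```

Because the graph only ever sees the slot's fixed address, swapping the condition ndarray between calls needs no rebuild, which is what replaced the earlier hard error.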
Force-pushed from 7f03036 to 9de68d5
hughperkins
commented
Mar 16, 2026
// Allocate a persistent device-side pointer slot and write the initial
// counter address into it. The condition kernel reads through this slot,
// so swapping the counter ndarray later only requires updating the slot.
CUDADriver::get_instance().malloc(&cached.counter_ptr_slot, sizeof(void *));
Collaborator
Author
Hard to tell if this is OK. Let's move this to the constructor of CachedCudaGraph, perhaps?
Introduce CudaDeviceBuffer, an RAII wrapper around CUDADriver::malloc/mem_free, replacing raw void*/char* pointers for persistent_device_arg_buffer, persistent_device_result_buffer, and counter_ptr_slot. Add a parameterized CachedCudaGraph constructor that allocates all device buffers upfront, eliminating scattered malloc calls in try_launch.
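The RAII ownership idea behind CudaDeviceBuffer can be mirrored in Python with a small owning wrapper; `FakeDriver` is a stand-in allocator so the sketch is runnable here, not the real CUDADriver API:

```python
class FakeDriver:
    """Stand-in allocator tracking live allocations, so the ownership
    pattern below can be exercised without a GPU."""
    live = set()

    @classmethod
    def malloc(cls, size):
        handle = object()
        cls.live.add(handle)
        return handle

    @classmethod
    def mem_free(cls, handle):
        cls.live.discard(handle)

class DeviceBuffer:
    """RAII-style wrapper: acquisition in __init__, release in close(),
    mirroring how CudaDeviceBuffer pairs malloc with mem_free so the
    launch path no longer scatters raw malloc calls."""

    def __init__(self, size):
        self.handle = FakeDriver.malloc(size)

    def close(self):
        if self.handle is not None:
            FakeDriver.mem_free(self.handle)
            self.handle = None

buf = DeviceBuffer(8)
allocated = buf.handle in FakeDriver.live
buf.close()
freed = buf.handle is None and not FakeDriver.live
```

Tying allocation to object lifetime means every code path that drops the buffer also frees the device memory, which is the point of moving the scattered mallocs into a constructor.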
This reverts commit 6418b6f.
Move device buffer allocation (arg, result, counter_ptr_slot) and RuntimeContext setup into a new constructor, removing scattered malloc calls from try_launch.
env.sh is generated by ./build.py and should not be tracked.
…p-4-handle-ndarray-change-2
# Conflicts:
#	env.sh
…raph-while
# Conflicts:
#	.github/workflows/test_gpu.yml
…p-4-handle-ndarray-change-2
hughperkins
commented
Mar 16, 2026
c2.from_numpy(np.array(1, dtype=np.int32))
with pytest.raises(RuntimeError, match="condition ndarray changed"):

for iteration in range(3):
Collaborator
Author
I think we should check that we aren't simply rebuilding the graph on each call.
Collaborator
Author
added _cuda_graph_total_builds assert
…dd-fallback
# Conflicts:
#	docs/source/user_guide/cuda_graph.md
#	python/quadrants/lang/misc.py
#	quadrants/runtime/amdgpu/kernel_launcher.cpp
#	quadrants/runtime/cpu/kernel_launcher.cpp
#	quadrants/runtime/cuda/cuda_graph_manager.cpp
#	tests/python/test_cuda_graph_do_while.py
…p-4-handle-ndarray-change-2
…p-4-handle-ndarray-change-2
…andle-ndarray-change-2
# Conflicts:
#	tests/python/test_cuda_graph_do_while.py
Issue: #
Brief Summary
When multiple locations call the same graph, very likely the counter ndarray will be different physically between each calling site.
Prior to this PR, this causes an exception.
In this PR, we make it possible to pass in a different ndarray object without triggering a recompile, throwing an exception, or producing incorrect results (the three options before this fix).