[Pef] CUDA graph 4: call from multiple locations #420

Draft
hughperkins wants to merge 194 commits into main from hp/cuda-graph-mvp-4-handle-ndarray-change-2

Conversation

@hughperkins
Collaborator

Issue: #

Brief Summary

When multiple locations call the same graph, the counter ndarray will very likely be a physically different object at each calling site.

Prior to this PR, this caused an exception.

This PR makes it possible to pass in a different ndarray object without triggering a recompile, throwing an exception, or producing incorrect results (the three possible outcomes before this fix).

copilot:summary

Walkthrough

copilot:walkthrough

When QD_CUDA_GRAPH=1, kernels with 2+ top-level for loops (offloaded
tasks) are captured into a CUDA graph on first launch and replayed on
subsequent launches, eliminating per-kernel launch overhead.

Uses the explicit graph node API (cuGraphAddKernelNode) with persistent
device arg/result buffers. Assumes stable ndarray device pointers.
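
The capture-once / replay-many behavior can be sketched with a small pure-Python cache model. `GraphCache`, `launch`, and `total_builds` are illustrative stand-ins for this sketch, not the real launcher API:

```python
# Minimal pure-Python model of "capture on first launch, replay after":
# the first launch of a kernel builds a graph; later launches reuse it.

class GraphCache:
    def __init__(self):
        self.graphs = {}       # kernel name -> "captured graph"
        self.total_builds = 0  # how many captures actually happened

    def launch(self, kernel_name, body, args):
        if kernel_name not in self.graphs:
            # First launch: "capture" the kernel body into a graph.
            self.graphs[kernel_name] = body
            self.total_builds += 1
        # Every launch, including the first, executes via the graph.
        return self.graphs[kernel_name](*args)

cache = GraphCache()
double = lambda xs: [2 * x for x in xs]
out1 = cache.launch("double", double, ([1, 2, 3],))
out2 = cache.launch("double", double, ([4, 5],))
# Two launches, but only one graph build.
```
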

Made-with: Cursor

Replace the global QD_CUDA_GRAPH=1 env var with a per-kernel opt-in.
The flag flows from the Python decorator through LaunchContextBuilder
to the CUDA kernel launcher, avoiding interference with internal
kernels like ndarray_to_ext_arr.

Made-with: Cursor

Verify that cuda_graph=True is a harmless no-op on non-CUDA backends
(tested on x64/CPU). Passes on both x64 and CUDA.

Made-with: Cursor

On each graph replay, re-resolve ndarray device pointers and re-upload
the arg buffer to the persistent device buffer. This ensures correct
results when the kernel is called with different ndarrays after the
graph was first captured.

Refactored ndarray pointer resolution into resolve_ctx_ndarray_ptrs().
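
A rough pure-Python model of this re-resolution step (only the helper name `resolve_ctx_ndarray_ptrs` comes from the commit above; `FakeNdarray`, the helper's signature, and `CachedGraph` are illustrative):

```python
# Sketch: before every replay, current device pointers are resolved and
# written into the persistent arg buffer, so a graph captured with one
# ndarray still computes on whichever ndarray is passed later.

class FakeNdarray:
    _next_addr = 0x1000
    def __init__(self, data):
        self.data = data
        self.device_ptr = FakeNdarray._next_addr  # stand-in for a CUDA pointer
        FakeNdarray._next_addr += 0x100

def resolve_ctx_ndarray_ptrs(args):
    """Collect the current device pointer of every ndarray argument."""
    return [a.device_ptr for a in args]

class CachedGraph:
    def __init__(self):
        self.persistent_arg_buffer = []  # stand-in for the device arg buffer

    def launch(self, args):
        # Re-upload freshly resolved pointers before each replay.
        self.persistent_arg_buffer = resolve_ctx_ndarray_ptrs(args)
        return self.persistent_arg_buffer

g = CachedGraph()
a, b = FakeNdarray([1]), FakeNdarray([2])
first = g.launch([a])
second = g.launch([b])   # different ndarray: buffer updated, no rebuild
```
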

Made-with: Cursor

Apply lint formatting fixes (clang-format, ruff) and remove
cuda_graph flag from autodiff adjoint kernel until the interaction
with reverse-mode AD is validated.

Implements @qd.kernel(graph_while='flag_arg') which wraps the kernel
offloaded tasks in a CUDA conditional while node (requires SM 9.0+).
The named argument is a scalar i32 ndarray on device; the loop
continues while its value is non-zero.

Key implementation details:
- Condition kernel compiled as PTX and JIT-linked with libcudadevrt.a
  at runtime to access cudaGraphSetConditional device function
- CU_GRAPH_COND_ASSIGN_DEFAULT flag ensures handle is reset each launch
- Works with both counter-based (decrement to 0) and boolean flag
  (set to 0 when done) patterns
- graph_while implicitly enables cuda_graph=True

Tests: counter, boolean done flag, multiple loops, graph replay.
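
The loop semantics can be modeled in plain Python. `run_graph_while` is an illustrative stand-in for the captured conditional graph, and a 1-element list stands in for the scalar i32 ndarray:

```python
# Pure-Python model of graph_while control flow: the offloaded tasks run
# at least once (do-while), and the loop continues while the named flag
# is non-zero.

def run_graph_while(body, flag):
    """flag is a 1-element list standing in for a scalar i32 ndarray."""
    iterations = 0
    while True:
        body(flag)           # body always executes at least once
        iterations += 1
        if flag[0] == 0:     # condition node: loop while non-zero
            break
    return iterations

# Counter pattern: decrement to zero.
counter = [3]
n_counter = run_graph_while(lambda f: f.__setitem__(0, f[0] - 1), counter)

# Boolean done-flag pattern: set the flag to 0 when finished.
state = {"steps": 0}
def step(flag):
    state["steps"] += 1
    if state["steps"] >= 2:
        flag[0] = 0
done = [1]
n_flag = run_graph_while(step, done)
```

Note the do-while constraint mentioned above: even a flag that starts at zero would still run the body once.
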
…allback

The graph_while_arg_id was computed using Python-level parameter indices,
which is wrong when struct parameters are flattened into many C++ args
(e.g. Genesis solver has 40 C++ params from 6 Python params). Now tracks
the flattened C++ arg index during launch context setup and caches it.

Also adds C++ do-while fallback loops for CPU, CUDA (non-graph path), and
AMDGPU backends so graph_while works identically on all platforms.
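
A toy illustration of the index mismatch described above (parameter names and per-parameter field counts here are made up):

```python
# Why Python-level parameter indices go wrong: a struct parameter
# flattens into several C++ args, so the flag's C++ arg index must be
# tracked during flattening rather than taken from the Python signature.

def flattened_arg_index(params, flag_name):
    """params: list of (name, number_of_cpp_args_after_flattening)."""
    cpp_index = 0
    for name, width in params:
        if name == flag_name:
            return cpp_index
        cpp_index += width
    raise ValueError(f"graph_while arg {flag_name!r} not found")

params = [("solver_state", 7), ("grid", 3), ("steps_left", 1)]
# Python-level index of 'steps_left' is 2, but its C++ arg index is 10.
idx = flattened_arg_index(params, "steps_left")
```
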
Falls back to non-graph path with a warning on pre-Hopper GPUs,
instead of failing with an unhelpful JIT link error.

Checks env-var-derived paths before the hardcoded fallbacks, so
custom toolkit installs (e.g. conda, non-default prefix) are found.

Document cuda_graph=True and graph_while API in kernel() docstring,
and add a user guide page covering usage patterns, cross-platform
behavior, and the do-while semantics constraint.

The graph path doesn't copy the result buffer back to the host,
so struct returns would silently return stale data. Error early
instead of producing wrong results.

Verifies that calling a cuda_graph=True kernel first with small
arrays then with larger ones produces correct results for all
elements — catches stale grid dims if the graph were incorrectly
replayed from the first capture.

Re-add documentation comments for |transfers|, |device_ptrs|,
zero-sized array handling, external array logic, and the
host copy-back section in the non-graph launch path.

Verify that a cuda_graph=True kernel works correctly after a
reset/reinit cycle — exercises the full teardown and rebuild
of the KernelLauncher and its graph cache.

The condition kernel's flag pointer was baked into the CUDA graph at
creation time. Passing a different ndarray on replay would cause the
condition kernel to read from a stale device address. Invalidate the
cached graph when the flag pointer changes so it gets rebuilt.

Raise ValueError immediately if the graph_while name doesn't match any
kernel parameter, instead of silently running the kernel once without
looping. Also document the CUDA API version for CudaGraphNodeParams.

Add get_cuda_graph_cache_size() through the KernelLauncher -> Program ->
pybind chain so tests can verify that graphs are actually being created
(or not) rather than only checking output correctness.

Made-with: Cursor

Tracks whether the CUDA graph cache was used on the most recent kernel
launch, exposed through KernelLauncher -> Program -> pybind so tests
can assert the graph path was (or was not) taken.

Made-with: Cursor

Every test now verifies graph caching behavior, not just output
correctness. Cross-platform test uses platform_supports_graph to
make assertions conditional on the backend.

Made-with: Cursor
@hughperkins hughperkins changed the base branch from main to hp/cuda-graph-mvp-3-add-fallback March 16, 2026 16:57
.reg .b64 %rd<5>;

// Load the two kernel parameters into registers:
// %rd1 = conditional node handle
Collaborator Author

Let's get these comments back.

@hughperkins hughperkins force-pushed the hp/cuda-graph-mvp-4-handle-ndarray-change-2 branch from 36d6d2d to fd78ff6 Compare March 16, 2026 17:08
selp.u32 %r2, 1, 0, %p1;

// Tell the conditional while node whether to loop again or stop.
// cudaGraphSetConditional(handle, should_continue)
Collaborator Author

Let's get this comment back.

@hughperkins hughperkins force-pushed the hp/cuda-graph-mvp-4-handle-ndarray-change-2 branch from fd78ff6 to 7f03036 Compare March 16, 2026 17:10
"Reuse the same ndarray for the condition parameter across calls.");
if (use_graph_do_while && cached.counter_ptr_slot) {
void *flag_ptr = ctx.graph_do_while_flag_dev_ptr;
CUDADriver::get_instance().memcpy_host_to_device(cached.counter_ptr_slot,
Collaborator Author

I wonder if this could/should be async?

Use device-side pointer indirection so the condition kernel reads the
counter address through a persistent slot. Updating the slot via memcpy
before each launch lets different ndarrays be used without rebuilding
the CUDA graph.

Replaces the previous error ("condition ndarray changed between calls")
with transparent support for swapping.
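
The indirection scheme can be modeled with a fake flat memory in Python; the addresses and names here are illustrative, not real device pointers:

```python
# Pure-Python model of device-side pointer indirection: the graph bakes
# in the address of a persistent slot, and the slot holds the current
# flag address. Swapping ndarrays only rewrites the slot, so the graph
# never needs rebuilding.

memory = {}          # fake device memory: address -> value

SLOT = 0x10          # persistent slot; this address is baked into the graph
memory[SLOT] = 0

def condition_kernel():
    flag_addr = memory[SLOT]        # indirection: slot -> current flag addr
    return memory[flag_addr] != 0   # then read the flag value itself

def launch(flag_addr):
    memory[SLOT] = flag_addr        # cheap memcpy before each launch
    return condition_kernel()

memory[0x100] = 1    # counter ndarray at one calling site
memory[0x200] = 0    # a different ndarray at another calling site
r1 = launch(0x100)   # reads non-zero: keep looping
r2 = launch(0x200)   # reads zero: stop
```

This also shows why the per-launch memcpy in the review comment above is on the critical path: the slot must hold the right address before the condition kernel runs.
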
@hughperkins hughperkins force-pushed the hp/cuda-graph-mvp-4-handle-ndarray-change-2 branch from 7f03036 to 9de68d5 Compare March 16, 2026 17:16
// Allocate a persistent device-side pointer slot and write the initial
// counter address into it. The condition kernel reads through this slot,
// so swapping the counter ndarray later only requires updating the slot.
CUDADriver::get_instance().malloc(&cached.counter_ptr_slot, sizeof(void *));
Collaborator Author

Hard to tell if this is OK. Let's move this to the constructor of CachedCudaGraph, perhaps?

Introduce CudaDeviceBuffer, an RAII wrapper around CUDADriver::malloc/mem_free,
replacing raw void*/char* pointers for persistent_device_arg_buffer,
persistent_device_result_buffer, and counter_ptr_slot. Add a parameterized
CachedCudaGraph constructor that allocates all device buffers upfront,
eliminating scattered malloc calls in try_launch.
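
A rough Python analogue of the RAII idea (the real class is C++; `driver_malloc` and `driver_free` here are stand-ins for CUDADriver::malloc/mem_free, and the context-manager form is Python's closest equivalent):

```python
# Tie a device allocation's lifetime to an object so the free can never
# be forgotten, mirroring what an RAII wrapper like CudaDeviceBuffer does.

live_allocations = set()

def driver_malloc(size):
    addr = 0x1000 + 0x100 * len(live_allocations)
    live_allocations.add(addr)
    return addr

def driver_free(addr):
    live_allocations.discard(addr)

class DeviceBuffer:
    def __init__(self, size):
        self.addr = driver_malloc(size)
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        driver_free(self.addr)

with DeviceBuffer(8) as arg_buf, DeviceBuffer(8) as result_buf:
    in_scope = len(live_allocations)    # both buffers live here
after_scope = len(live_allocations)     # both freed on scope exit
```
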

Move device buffer allocation (arg, result, counter_ptr_slot) and
RuntimeContext setup into a new constructor, removing scattered malloc
calls from try_launch.

env.sh is generated by ./build.py and should not be tracked.
…p-4-handle-ndarray-change-2

# Conflicts:
#	env.sh
…raph-while

# Conflicts:
#	.github/workflows/test_gpu.yml
c2.from_numpy(np.array(1, dtype=np.int32))
with pytest.raises(RuntimeError, match="condition ndarray changed"):

for iteration in range(3):
Collaborator Author

I think we should check that we aren't simply rebuilding the graph on each call.

Collaborator Author

added _cuda_graph_total_builds assert

@hughperkins hughperkins changed the title [Pef] CUDA graph 4: handle multiple locations calling same graph [Pef] CUDA graph 4: call from multiple locations Mar 16, 2026
…dd-fallback

# Conflicts:
#	docs/source/user_guide/cuda_graph.md
#	python/quadrants/lang/misc.py
#	quadrants/runtime/amdgpu/kernel_launcher.cpp
#	quadrants/runtime/cpu/kernel_launcher.cpp
#	quadrants/runtime/cuda/cuda_graph_manager.cpp
#	tests/python/test_cuda_graph_do_while.py
Base automatically changed from hp/cuda-graph-mvp-3-add-fallback to main March 16, 2026 19:46