
[Perf] CUDA graph 2: graph_do_while#406

Merged

hughperkins merged 148 commits into main from hp/cuda-graph-mvp-2-graph-while on Mar 16, 2026

[Perf] CUDA graph 2: graph_do_while#406
hughperkins merged 148 commits intomainfrom
hp/cuda-graph-mvp-2-graph-while

Conversation

@hughperkins
Collaborator

@hughperkins hughperkins commented Mar 11, 2026

Issue: #

Brief Summary

In this PR, we add graph_do_while for CUDA, and do NOT add fallbacks on other platforms; a later PR will add those fallbacks.

  • In addition, callers must always use the same physical ndarray condition object when calling into a do_while graph.

The do-while is implemented using a CUDA graph conditional node:

  • The condition kernel for this node must be written in PTX and compiled at runtime, linking against libcudadevrt.a.

copilot:summary

Walkthrough

copilot:walkthrough

When QD_CUDA_GRAPH=1, kernels with 2+ top-level for loops (offloaded
tasks) are captured into a CUDA graph on first launch and replayed on
subsequent launches, eliminating per-kernel launch overhead.

Uses the explicit graph node API (cuGraphAddKernelNode) with persistent
device arg/result buffers. Assumes stable ndarray device pointers.

Made-with: Cursor
Replace the global QD_CUDA_GRAPH=1 env var with a per-kernel opt-in.
The flag flows from the Python decorator through LaunchContextBuilder
to the CUDA kernel launcher, avoiding interference with internal
kernels like ndarray_to_ext_arr.

Made-with: Cursor
Verify that cuda_graph=True is a harmless no-op on non-CUDA backends
(tested on x64/CPU). Passes on both x64 and CUDA.

Made-with: Cursor
On each graph replay, re-resolve ndarray device pointers and re-upload
the arg buffer to the persistent device buffer. This ensures correct
results when the kernel is called with different ndarrays after the
graph was first captured.

Refactored ndarray pointer resolution into resolve_ctx_ndarray_ptrs().

Made-with: Cursor
Apply lint formatting fixes (clang-format, ruff) and remove
cuda_graph flag from autodiff adjoint kernel until the interaction
with reverse-mode AD is validated.
Implements @qd.kernel(graph_while='flag_arg') which wraps the kernel
offloaded tasks in a CUDA conditional while node (requires SM 9.0+).
The named argument is a scalar i32 ndarray on device; the loop
continues while its value is non-zero.

Key implementation details:
- Condition kernel compiled as PTX and JIT-linked with libcudadevrt.a
  at runtime to access cudaGraphSetConditional device function
- CU_GRAPH_COND_ASSIGN_DEFAULT flag ensures handle is reset each launch
- Works with both counter-based (decrement to 0) and boolean flag
  (set to 0 when done) patterns
- graph_while implicitly enables cuda_graph=True

Tests: counter, boolean done flag, multiple loops, graph replay.
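The do-while semantics above can be simulated in plain Python; this is an illustrative sketch of the control flow, not the real launcher:

```python
# Plain-Python simulation of the CUDA conditional while node's do-while
# semantics: the body always executes once before the condition kernel
# evaluates the flag. Illustrative only; no CUDA involved.
def run_graph_while(body, flag):
    # flag: 1-element list standing in for the scalar i32 device ndarray
    while True:             # default launch value primes the first pass
        body(flag)          # the captured work kernels execute
        if flag[0] == 0:    # condition kernel: continue while non-zero
            break

# Counter pattern: the body decrements the flag toward zero.
steps = []

def body(flag):
    steps.append(flag[0])
    flag[0] -= 1

counter = [3]
run_graph_while(body, counter)
print(steps, counter[0])  # [3, 2, 1] 0
```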
…allback

The graph_while_arg_id was computed using Python-level parameter indices,
which is wrong when struct parameters are flattened into many C++ args
(e.g. Genesis solver has 40 C++ params from 6 Python params). Now tracks
the flattened C++ arg index during launch context setup and caches it.

Also adds C++ do-while fallback loops for CPU, CUDA (non-graph path), and
AMDGPU backends so graph_while works identically on all platforms.
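The index fix can be sketched as follows; the helper, parameter names, and per-param arg counts are invented for illustration:

```python
# Hypothetical sketch of the fix: track the flattened C++ arg index while
# building the launch context, instead of using the Python parameter index.
def flattened_arg_index(params, target):
    """params: list of (name, n_cxx_args); a struct param may expand to many."""
    idx = 0
    for name, n_cxx_args in params:
        if name == target:
            return idx
        idx += n_cxx_args
    raise ValueError(f"no parameter named {target!r}")

# A struct param expanding to 7 C++ args shifts everything after it.
params = [("solver_state", 7), ("counter", 1), ("x", 1)]
print(flattened_arg_index(params, "counter"))  # 7, not the Python index 1
```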
Falls back to non-graph path with a warning on pre-Hopper GPUs,
instead of failing with an unhelpful JIT link error.
Checks env-var-derived paths before the hardcoded fallbacks, so
custom toolkit installs (e.g. conda, non-default prefix) are found.
Document cuda_graph=True and graph_while API in kernel() docstring,
and add a user guide page covering usage patterns, cross-platform
behavior, and the do-while semantics constraint.
The graph path doesn't copy the result buffer back to the host,
so struct returns would silently return stale data. Error early
instead of producing wrong results.
Verifies that calling a cuda_graph=True kernel first with small
arrays then with larger ones produces correct results for all
elements — catches stale grid dims if the graph were incorrectly
replayed from the first capture.
Re-add documentation comments for |transfers|, |device_ptrs|,
zero-sized array handling, external array logic, and the
host copy-back section in the non-graph launch path.
Verify that a cuda_graph=True kernel works correctly after a
reset/reinit cycle — exercises the full teardown and rebuild
of the KernelLauncher and its graph cache.
@hughperkins hughperkins marked this pull request as draft March 11, 2026 22:38
@hughperkins
Copy link
Collaborator Author

Opus 4.6 review:

What it does

Adds graph_while — a GPU-side iteration primitive built on the CUDA graph infrastructure from MVP-1. Annotating a kernel with @qd.kernel(graph_while="counter") repeats the kernel body while a device-side i32 flag is non-zero:

  • SM 9.0+ (Hopper): Uses CUDA conditional while graph nodes. A JIT-compiled PTX condition kernel calls cudaGraphSetConditional to control the loop entirely on-GPU.
  • Older CUDA / CPU / AMDGPU: Falls back to a host-side do-while loop with stream_synchronize + D2H memcpy to poll the flag each iteration.

The implementation threads through cleanly: Python decorator → graph_while_arg_id on LaunchContextBuilder → backend-specific loop logic.

What's good

  1. Clean layering. The Python side (kernel_impl.py, kernel.py) is minimal — just plumbs the arg name through and resolves the C++ arg index. All the heavy lifting is in the C++ launchers.

  2. Cross-platform fallback. CPU and AMDGPU get do-while fallbacks with identical semantics. cuda_graph=True is implicitly set, so the user writes one annotation and it works everywhere.

  3. Solid documentation. The new cuda_graphs.md page explains both cuda_graph=True and graph_while, provides counter-based and boolean-flag patterns, and clearly warns about do-while semantics.

  4. Good test coverage. Both CUDA-specific tests and cross-backend tests covering counter, boolean flag, multi-loop, replay, and single-iteration edge cases.

  5. ABI guard. The static_assert(sizeof(CudaGraphNodeParams) == 256, ...) is a smart safety net for the hand-rolled struct.

Concerns

1. Graph replay bug with changed ndarray pointers (correctness)

When the CUDA graph is replayed (cache hit), the work kernels are fine because the arg buffer is re-uploaded. But the condition kernel's flag_ptr is baked into the graph at creation time:

// kernel_launcher.cpp, inside launch_llvm_kernel_graph
void *flag_ptr = ctx.graph_while_flag_dev_ptr;
void *cond_args[2] = {&cond_handle, &flag_ptr};
// ... added as a kernel node — pointer captured permanently

If a user creates a new ndarray for the counter between calls, the condition kernel still reads from the old device address. The tests pass because they reuse the same ndarray object — but that's a latent bug.

Fix: Store the condition kernel node in CachedCudaGraph and update its params on replay (via cuGraphExecKernelNodeSetParams or equivalent), or invalidate the cache when the flag pointer changes.
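The invalidation option can be sketched in plain Python; class and function names here are illustrative, not the actual C++ launcher code:

```python
# Sketch: the cached graph remembers the flag pointer it was captured
# against and is rebuilt if a different pointer shows up on a later call.
class CachedCudaGraph:
    def __init__(self, flag_ptr):
        self.flag_ptr = flag_ptr  # device address baked in at capture time

cache = {}

def get_graph(kernel_id, flag_ptr, build):
    g = cache.get(kernel_id)
    if g is None or g.flag_ptr != flag_ptr:
        g = build(flag_ptr)       # re-capture against the current pointer
        cache[kernel_id] = g
    return g

g1 = get_graph("k", 0x1000, CachedCudaGraph)
g2 = get_graph("k", 0x1000, CachedCudaGraph)  # same pointer: cache hit
g3 = get_graph("k", 0x2000, CachedCudaGraph)  # new pointer: rebuilt
print(g1 is g2, g1 is g3)  # True False
```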

2. Fragile ABI-coupled struct (maintainability)

CudaGraphNodeParams is a hand-rolled mirror of CUgraphNodeParams with manual padding:

// kernel_launcher.h
struct CudaGraphNodeParams {
  unsigned int type;
  int reserved0[3];
  unsigned long long handle;
  unsigned int condType;
  unsigned int size;
  void *phGraph_out;
  void *ctx;
  char _pad[232 - 8 - 4 - 4 - 8 - 8];
  long long reserved2;
};

The static_assert on total size catches gross drift, but internal layout drift (e.g., NVIDIA changing field offsets within the union) would silently corrupt data and still pass the size check. A comment referencing the exact CUDA driver API version this was modeled against would help future maintainers know when to re-verify.

3. No validation on the graph_while parameter type (usability)

There's no check that the named parameter is actually a scalar i32 ndarray. Passing a f32 array, a 1D array, or a misspelled name would produce undefined behavior or a silent wrong result rather than a clear error.

Fix: Validate at kernel definition or first-call time that the named arg exists and has the expected dtype/shape.
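One way to implement the suggested validation, sketched in plain Python with invented names:

```python
# Hypothetical sketch of the suggested check: fail fast if the named
# graph_while arg doesn't exist. Function and message text are invented.
def validate_graph_while_arg(arg_names, graph_while):
    if graph_while not in arg_names:
        raise ValueError(
            f"graph_while={graph_while!r} does not match any kernel "
            f"parameter. Available parameters: {arg_names}")
    return arg_names.index(graph_while)

print(validate_graph_while_arg(["counter", "x"], "counter"))  # 0
try:
    validate_graph_while_arg(["counter", "x"], "countr")  # typo
except ValueError as e:
    print("countr" in str(e))  # True
```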

4. Do-while footgun with counter=0 (usability)

The docs warn about it, which is good. But there's no runtime guard — counter=0 will infinite-loop on CPU/AMDGPU fallbacks (work, sync, read 0, decrement to −1, −1 ≠ 0, loop forever). A simple QD_ASSERT(initial_value >= 1) or at least a QD_WARN on the fallback paths would be cheap insurance.
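A plain-Python simulation of the footgun (an iteration cap is added so the demo terminates; the real loop would not):

```python
# Simulation of the counter=0 footgun under do-while semantics: the body
# decrements before the flag is checked, so 0 goes to -1 and the loop
# never sees 0 again. Illustrative only; no CUDA involved.
def do_while_steps(counter, max_iters=10):
    flag = [counter]
    iters = 0
    while True:
        flag[0] -= 1          # body: decrement the counter
        iters += 1
        if flag[0] == 0 or iters >= max_iters:
            break
    return iters, flag[0]

print(do_while_steps(3))  # (3, 0): terminates normally
print(do_while_steps(0))  # (10, -10): 0 -> -1, capped only by max_iters
```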

5. Host-side fallback performance (docs)

The AMDGPU and CUDA non-graph fallbacks call stream_synchronize + D2H memcpy every iteration. For fine-grained loops this could be slower than just running the iterations without the graph_while machinery. Worth a note in the docs that the fallback is semantically correct but not a performance win.

6. Minor: stream_synchronize(nullptr) in AMDGPU (correctness)

The AMDGPU fallback passes nullptr for the stream, which may synchronize all streams rather than just the relevant one. The CUDA fallback correctly uses CUDAContext::get_instance().get_stream().

Summary

Well-structured feature with clean layering, good docs, and solid tests. #1 (replay pointer staleness) is a correctness bug that should be fixed before merge. #3 (input validation) is a close second for user-facing ergonomics. The rest are hardening items that could be tracked as follow-ups.

The condition kernel's flag pointer was baked into the CUDA graph at
creation time. Passing a different ndarray on replay would cause the
condition kernel to read from a stale device address. Invalidate the
cached graph when the flag pointer changes so it gets rebuilt.
Raise ValueError immediately if the graph_while name doesn't match any
kernel parameter, instead of silently running the kernel once without
looping. Also document the CUDA API version for CudaGraphNodeParams.
@hughperkins
Collaborator Author

Update from Opus on concern 4:

Thinking this through more carefully: counter=0 causes an infinite loop on all paths, not just the fallbacks. On the CUDA graph path, defaultLaunchValue=1 forces the body to execute once, the body decrements 0 to -1, the condition kernel sees -1 != 0, and it loops forever too.

The question is whether a runtime guard is practical. The counter lives on device — to check its initial value at launch time you'd need a D2H memcpy + sync, which adds latency on every launch and undermines the whole point of keeping things on the GPU. On CPU you could peek cheaply, but adding a check only on one backend is inconsistent. You could do a post-first-iteration sanity check (if the counter went negative, something is wrong), but that's complex, still requires a sync on the fallback paths, and doesn't help the CUDA graph path where the loop is entirely on-device.

I think this is adequately handled by the documentation warning. The do-while semantics are inherent to CUDA conditional while nodes — it's not a quirk of this implementation, it's how the hardware primitive works. Users of this API are writing GPU iteration loops and should be expected to read the docs. The name validation we just added catches the most common real-world mistake (typos); counter=0 is a logic error that's hard to guard against cheaply. I wouldn't add anything here.

@hughperkins
Collaborator Author

Update from Opus on concern 5:

Concern 5 (host-side fallback performance) is just a documentation suggestion — noting that the fallback is correct but not a performance win. Looking at it again, I think the existing docs already cover this adequately:

docs/source/user_guide/user_guide/cuda_graphs.md lines 68-70

  • On SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement.
  • On older CUDA GPUs and non-CUDA backends, it falls back to a host-side do-while loop.
  • graph_while implicitly enables cuda_graph=True.

The fact that it's a "host-side do-while loop" already implies sync-per-iteration. Anyone who understands enough to care about this overhead will read "host-side" and understand the implication. Adding a paragraph spelling out "this means a stream synchronize and D2H memcpy every iteration" would be accurate but is probably over-explaining for the target audience. I wouldn't change anything.

@hughperkins
Collaborator Author

update from Opus for concern 6:

Looking at the code, the existing AMDGPU backend already uses stream_synchronize(nullptr) everywhere — line 95 has the same pattern, predating graph_while. The new code at line 129 is just following the established convention for this backend.

This isn't something the branch introduced; it's how the AMDGPU backend works. Changing it would be a broader AMDGPU backend concern, not something to fix in the graph_while PR. I'd drop this from the review.

@hughperkins
Collaborator Author

Addressed all 6 of Opus's concerns

…-2-graph-while

Made-with: Cursor

# Conflicts:
#	quadrants/runtime/cuda/cuda_graph_manager.cpp
#	quadrants/runtime/cuda/cuda_graph_manager.h
…loop

Replace `@qd.kernel(graph_do_while="var")` with explicit
`while qd.graph_do_while(var):` inside the kernel body. The AST
transformer recognises the pattern and sets the condition arg without
emitting a while-loop IR node. No C++ changes needed.

Made-with: Cursor
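A simplified sketch of how such a pattern can be recognised with Python's ast module; illustrative only, not the actual transformer:

```python
# Find `while qd.graph_do_while(<name>):` in a kernel body and return the
# condition argument's name. Simplified vs the real AST transformer, which
# also suppresses emission of the while-loop IR node.
import ast
import textwrap

def find_graph_do_while_arg(src):
    tree = ast.parse(textwrap.dedent(src))
    for node in ast.walk(tree):
        if isinstance(node, ast.While) and isinstance(node.test, ast.Call):
            func = node.test.func
            if (isinstance(func, ast.Attribute)
                    and func.attr == "graph_do_while"
                    and isinstance(func.value, ast.Name)
                    and func.value.id == "qd"):
                arg = node.test.args[0]
                if isinstance(arg, ast.Name):
                    return arg.id
    return None

src = """
def step(counter, x):
    while qd.graph_do_while(counter):
        x[0] += 1.0
        counter[0] -= 1
"""
print(find_graph_do_while_arg(src))  # counter
```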
- On SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement.
- Older CUDA GPUs and non-CUDA backends are not currently supported.
- `graph_do_while` implicitly enables `cuda_graph=True`.
- Using `qd.graph_do_while()` implicitly enables `cuda_graph=True` if not already set.
Collaborator Author

this should be removed

f"Available parameters: {arg_names}"
)
kernel.graph_do_while_arg = graph_do_while_arg
if not kernel.use_cuda_graph:
Collaborator Author

raise error if not in cuda graph already

Test that using qd.graph_do_while() without cuda_graph=True and with a
non-existent parameter name both raise QuadrantsSyntaxError.

Made-with: Cursor
@hughperkins
Collaborator Author

I have read and reviewed every line added in this PR. I take responsibility for the lines added and removed in this PR, and won't blame any issues on Opus.

The LLVM x64 backend generates extra tasks per ndarray argument for
serialization/setup, so exact equality checks fail. Use >= instead.

Made-with: Cursor
Ndarray kernels can produce additional serial tasks beyond the
user-visible loops, so hardcoding expected node counts breaks.
Use the actual num_offloaded_tasks instead.
Base automatically changed from hp/cuda-graph-mvp-1-graph-build to main March 16, 2026 16:14
Resolve conflicts from squash-merged MVP-1 PR (#405) vs branch's
pre-existing MVP-1 merge commits. Keep all graph_do_while (MVP-2)
additions. Incorporate grad_ptr local variable cleanup from main.
@hughperkins hughperkins enabled auto-merge (squash) March 16, 2026 16:45
@hughperkins hughperkins disabled auto-merge March 16, 2026 17:34
@hughperkins
Collaborator Author

env.sh shouldn't be here

env.sh is generated by ./build.py and should not be tracked.
…raph-while

# Conflicts:
#	.github/workflows/test_gpu.yml
@hughperkins hughperkins enabled auto-merge (squash) March 16, 2026 17:45
@hughperkins hughperkins merged commit 79dd8bf into main Mar 16, 2026
46 checks passed
@hughperkins hughperkins deleted the hp/cuda-graph-mvp-2-graph-while branch March 16, 2026 18:32
hughperkins added a commit that referenced this pull request Mar 16, 2026