
Commit 79dd8bf

[Perf] CUDA graph 2: graph_do_while (#406)
1 parent 9c8f4dd commit 79dd8bf

15 files changed

Lines changed: 640 additions & 9 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -96,3 +96,4 @@ imgui.ini
 stubs/
 CHANGELOG.md
 python/quadrants/_version.py
+env.sh

docs/source/user_guide/cuda_graph.md

Lines changed: 78 additions & 1 deletion
@@ -2,7 +2,7 @@

CUDA graphs reduce kernel launch overhead by capturing a sequence of GPU operations into a graph, then replaying it in a single launch. On non-CUDA platforms, the CUDA graph annotation is simply ignored, and code runs normally.

-## Usage
+## Basic usage

Add `cuda_graph=True` to a `@qd.kernel` decorator:

@@ -52,3 +52,80 @@ my_kernel(x2, y2) # replays graph with new array pointers
### Fields as arguments

When different fields are passed as template arguments, each unique combination of fields produces a separately compiled kernel with its own graph cache entry. There is no interference between them, as the sketch below illustrates.
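A minimal sketch of this caching behavior. The `qd.template()` annotation for field parameters is an assumption here (by analogy with Taichi-style APIs); it is not shown elsewhere in this commit.

```python
N = 1024
f1 = qd.field(qd.f32, shape=(N,))
f2 = qd.field(qd.f32, shape=(N,))

@qd.kernel(cuda_graph=True)
def double(f: qd.template()):  # qd.template() is assumed, not confirmed by this commit
    for i in range(N):
        f[i] = f[i] * 2.0

double(f1)  # first call with f1: compiles the kernel and captures its graph
double(f2)  # new field combination: separately compiled kernel + graph cache entry
double(f1)  # replays the graph cached for f1; f2's entry is untouched
```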
## GPU-side iteration with `graph_do_while`

For iterative algorithms (physics solvers, convergence loops), you often want to repeat the kernel body until a condition is met, without returning to the host each iteration. Use `while qd.graph_do_while(flag):` inside a `cuda_graph=True` kernel:

```python
@qd.kernel(cuda_graph=True)
def solve(x: qd.types.ndarray(qd.f32, ndim=1),
          counter: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(counter):
        for i in range(x.shape[0]):
            x[i] = x[i] + 1.0
        for i in range(1):
            counter[()] = counter[()] - 1

x = qd.ndarray(qd.f32, shape=(N,))
counter = qd.ndarray(qd.i32, shape=())
counter.from_numpy(np.array(10, dtype=np.int32))
solve(x, counter)
# x is now incremented 10 times; counter is 0
```

The argument to `qd.graph_do_while()` must be the name of a scalar `qd.i32` ndarray parameter. The loop body repeats while this value is non-zero.

- On SM 9.0+ (Hopper), this uses CUDA conditional while nodes: the entire iteration runs on the GPU with no host involvement.
- Older CUDA GPUs and non-CUDA backends are not currently supported.

### Patterns

**Counter-based**: set the counter to N, decrement each iteration. The body runs exactly N times.

```python
@qd.kernel(cuda_graph=True)
def iterate(x: qd.types.ndarray(qd.f32, ndim=1),
            counter: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(counter):
        for i in range(x.shape[0]):
            x[i] = x[i] + 1.0
        for i in range(1):
            counter[()] = counter[()] - 1
```

**Boolean flag**: set a `keep_going` flag to 1, and have the kernel set it to 0 when a convergence criterion is met.

```python
@qd.kernel(cuda_graph=True)
def converge(x: qd.types.ndarray(qd.f32, ndim=1),
             keep_going: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(keep_going):
        for i in range(x.shape[0]):
            # ... do work ...
            pass
        for i in range(1):
            if some_condition(x):  # placeholder for a real convergence check
                keep_going[()] = 0
```

### Do-while semantics

`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time; passing 0 to a kernel that decrements the counter causes an effectively infinite loop, because the body still runs once and drives the counter to a non-zero negative value.
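To make the at-least-once behavior concrete, here is a short follow-on to the `solve` example from the basic usage section above (no new APIs involved):

```python
# Do-while: the body runs once before the first condition check.
counter.from_numpy(np.array(1, dtype=np.int32))
solve(x, counter)  # body executes exactly once; counter ends at 0

# Launching with counter == 0 is a bug for a decrementing kernel:
# the body still runs once, the counter becomes -1 (non-zero),
# and the loop keeps iterating.
```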

### ndarray vs field

The parameter used by `graph_do_while` MUST be an ndarray. Other parameters, however, can be any supported Quadrants kernel parameter type, as in the sketch below.
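For example, a sketch combining parameter kinds. As in the fields sketch above, `qd.template()` and passing a plain scalar (`qd.f32`) are assumptions about supported parameter types, not guarantees from this commit:

```python
f = qd.field(qd.f32, shape=(N,))
steps = qd.ndarray(qd.i32, shape=())  # the loop flag itself must be an ndarray

@qd.kernel(cuda_graph=True)
def relax(grid: qd.template(),  # field parameter (assumed supported)
          scale: qd.f32,        # scalar parameter (assumed supported)
          steps: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(steps):
        for i in range(N):
            grid[i] = grid[i] * scale
        for i in range(1):
            steps[()] = steps[()] - 1
```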
### Restrictions

- The same physical ndarray must be used for the counter parameter on every call. Passing a different ndarray raises an error, because the counter's device pointer is baked into the CUDA graph at creation time. The snippet below shows the failure mode.
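Concretely, reusing the `solve` kernel from the basic example:

```python
c1 = qd.ndarray(qd.i32, shape=())
c2 = qd.ndarray(qd.i32, shape=())

c1.from_numpy(np.array(5, dtype=np.int32))
solve(x, c1)  # first launch: graph captured with c1's device pointer

c1.from_numpy(np.array(5, dtype=np.int32))
solve(x, c1)  # OK: same physical ndarray, the cached graph replays

c2.from_numpy(np.array(5, dtype=np.int32))
solve(x, c2)  # raises: c2 is not the ndarray the graph was captured with
```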
### Caveats

`graph_do_while` only runs on CUDA; there is currently no fallback on non-CUDA platforms.

python/quadrants/lang/ast/ast_transformer.py

Lines changed: 30 additions & 0 deletions
@@ -1193,11 +1193,41 @@ def build_For(ctx: ASTTransformerFuncContext, node: ast.For) -> None:
         # Struct for
         return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)

+    @staticmethod
+    def _is_graph_do_while_call(node: ast.expr) -> str | None:
+        """If *node* is ``qd.graph_do_while(var)`` return the arg name, else None."""
+        if not isinstance(node, ast.Call):
+            return None
+        func = node.func
+        if isinstance(func, ast.Attribute) and func.attr == "graph_do_while":
+            if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
+                return node.args[0].id
+        if isinstance(func, ast.Name) and func.id == "graph_do_while":
+            if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
+                return node.args[0].id
+        return None
+
     @staticmethod
     def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None:
         if node.orelse:
             raise QuadrantsSyntaxError("'else' clause for 'while' not supported in Quadrants kernels")

+        graph_do_while_arg = ASTTransformer._is_graph_do_while_call(node.test)
+        if graph_do_while_arg is not None:
+            kernel = ctx.global_context.current_kernel
+            arg_names = [m.name for m in kernel.arg_metas]
+            if graph_do_while_arg not in arg_names:
+                raise QuadrantsSyntaxError(
+                    f"qd.graph_do_while({graph_do_while_arg!r}) does not match any "
+                    f"parameter of kernel {kernel.func.__name__!r}. "
+                    f"Available parameters: {arg_names}"
+                )
+            if not kernel.use_cuda_graph:
+                raise QuadrantsSyntaxError("qd.graph_do_while() requires @qd.kernel(cuda_graph=True)")
+            kernel.graph_do_while_arg = graph_do_while_arg
+            build_stmts(ctx, node.body)
+            return None
+
         with ctx.loop_scope_guard():
             stmt_dbg_info = _qd_core.DebugInfo(ctx.get_pos_info(node))
             ctx.ast_builder.begin_frontend_while(expr.Expr(1, dtype=primitive_types.i32).ptr, stmt_dbg_info)
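
For illustration, the matching logic above can be exercised standalone with the stdlib `ast` module; the helper body is copied into a plain function here so the snippet runs without Quadrants. Both the `qd.`-prefixed form and the bare-name form are recognized:

```python
import ast

def is_graph_do_while_call(node):
    """Mirror of ASTTransformer._is_graph_do_while_call, for illustration."""
    if not isinstance(node, ast.Call):
        return None
    func = node.func
    if isinstance(func, ast.Attribute) and func.attr == "graph_do_while":
        if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
            return node.args[0].id
    if isinstance(func, ast.Name) and func.id == "graph_do_while":
        if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
            return node.args[0].id
    return None

test = ast.parse("while qd.graph_do_while(counter):\n    pass").body[0].test
print(is_graph_do_while_call(test))  # -> "counter" (attribute form)

test = ast.parse("while graph_do_while(flag):\n    pass").body[0].test
print(is_graph_do_while_call(test))  # -> "flag" (bare-name form)

test = ast.parse("while qd.graph_do_while(f()):\n    pass").body[0].test
print(is_graph_do_while_call(test))  # -> None (argument is not a bare name)
```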

python/quadrants/lang/kernel.py

Lines changed: 5 additions & 0 deletions
@@ -292,6 +292,7 @@ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _is_classkernel
         self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
         self.has_print = False
         self.use_cuda_graph: bool = False
+        self.graph_do_while_arg: str | None = None
         self.quadrants_callable: QuadrantsCallable | None = None
         self.visited_functions: set[FunctionSourceInfo] = set()
         self.kernel_function_info: FunctionSourceInfo | None = None

@@ -444,6 +445,8 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled
                 template_num += 1
                 i_out += 1
                 continue
+            if self.graph_do_while_arg is not None and self.arg_metas[i_in].name == self.graph_do_while_arg:
+                self._graph_do_while_cpp_arg_id = i_out - template_num
             num_args_, is_launch_ctx_cacheable_ = self._recursive_set_args(
                 self.used_py_dataclass_parameters_by_key_enforcing[key],
                 self.arg_metas[i_in].name,

@@ -505,6 +508,8 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled
                 self.src_ll_cache_observations.cache_stored = True
             self._last_compiled_kernel_data = compiled_kernel_data
             launch_ctx.use_cuda_graph = self.use_cuda_graph
+            if self.graph_do_while_arg is not None and hasattr(self, "_graph_do_while_cpp_arg_id"):
+                launch_ctx.graph_do_while_arg_id = self._graph_do_while_cpp_arg_id
             prog.launch_kernel(compiled_kernel_data, launch_ctx)
         except Exception as e:
             e = handle_exception_from_cpp(e)

python/quadrants/lang/kernel_impl.py

Lines changed: 10 additions & 1 deletion
@@ -124,7 +124,10 @@ def _inside_class(level_of_class_stackframe: int) -> bool:


 def _kernel_impl(
-    _func: Callable, level_of_class_stackframe: int, verbose: bool = False, cuda_graph: bool = False
+    _func: Callable,
+    level_of_class_stackframe: int,
+    verbose: bool = False,
+    cuda_graph: bool = False,
 ) -> QuadrantsCallable:
     # Can decorators determine if a function is being defined inside a class?
     # https://stackoverflow.com/a/8793684/12003165

@@ -206,6 +209,12 @@ def kernel(

     Kernel's gradient kernel would be generated automatically by the AutoDiff system.

+    Args:
+        cuda_graph: If True, kernels with 2+ top-level for loops are captured
+            into a CUDA graph on first launch and replayed on subsequent
+            launches, reducing per-kernel launch overhead. Non-CUDA backends
+            are not supported currently.
+
     Example::

         >>> x = qd.field(qd.i32, shape=(4, 8))

python/quadrants/lang/misc.py

Lines changed: 19 additions & 0 deletions
@@ -701,6 +701,24 @@ def copy():
     _bit_vectorize()


+def graph_do_while(condition) -> bool:
+    """Marks a while loop as a CUDA graph do-while conditional node.
+
+    Used as ``while qd.graph_do_while(flag):`` inside a
+    ``@qd.kernel(cuda_graph=True)`` kernel. The loop body repeats while
+    ``flag`` (a scalar ``qd.i32`` ndarray) is non-zero.
+
+    On SM 9.0+ (Hopper) GPUs this compiles to a native CUDA graph
+    conditional while node. Older CUDA GPUs and non-CUDA backends
+    are not currently supported.
+
+    This function should not be called directly at runtime; it is
+    recognised and transformed during AST compilation.
+    Requires ``@qd.kernel(cuda_graph=True)``.
+    """
+    return bool(condition)
+
+
 def global_thread_idx():
     """Returns the global thread id of this running thread,
     only available for cpu and cuda backends.

@@ -837,6 +855,7 @@ def dump_compile_config() -> None:
     "python",
     "vulkan",
     "extension",
+    "graph_do_while",
     "loop_config",
     "global_thread_idx",
     "assume_in_range",

quadrants/program/launch_context_builder.h

Lines changed: 2 additions & 0 deletions
@@ -151,6 +151,8 @@ class LaunchContextBuilder {
   const StructType *args_type{nullptr};
   size_t result_buffer_size{0};
   bool use_cuda_graph{false};
+  int graph_do_while_arg_id{-1};
+  void *graph_do_while_flag_dev_ptr{nullptr};

   // Note that I've tried to group `array_runtime_size` and
   // `is_device_allocations` into a small struct. However, it caused some test

quadrants/python/export_lang.cpp

Lines changed: 3 additions & 1 deletion
@@ -667,7 +667,9 @@ void export_lang(py::module &m) {
       .def("get_struct_ret_int", &LaunchContextBuilder::get_struct_ret_int)
       .def("get_struct_ret_uint", &LaunchContextBuilder::get_struct_ret_uint)
       .def("get_struct_ret_float", &LaunchContextBuilder::get_struct_ret_float)
-      .def_readwrite("use_cuda_graph", &LaunchContextBuilder::use_cuda_graph);
+      .def_readwrite("use_cuda_graph", &LaunchContextBuilder::use_cuda_graph)
+      .def_readwrite("graph_do_while_arg_id",
+                     &LaunchContextBuilder::graph_do_while_arg_id);

   py::class_<Function>(m, "Function")
       .def("insert_scalar_param", &Function::insert_scalar_param)

quadrants/rhi/cuda/cuda_driver_functions.inc.h

Lines changed: 10 additions & 0 deletions
@@ -73,8 +73,18 @@ PER_CUDA_FUNCTION(import_external_semaphore, cuImportExternalSemaphore,CUexterna
 // Graph management
 PER_CUDA_FUNCTION(graph_create, cuGraphCreate, void **, uint32);
 PER_CUDA_FUNCTION(graph_add_kernel_node, cuGraphAddKernelNode, void **, void *, const void *, std::size_t, const void *);
+PER_CUDA_FUNCTION(graph_add_node, cuGraphAddNode, void **, void *, const void *, std::size_t, void *);
 PER_CUDA_FUNCTION(graph_instantiate, cuGraphInstantiate, void **, void *, void *, char *, std::size_t);
 PER_CUDA_FUNCTION(graph_launch, cuGraphLaunch, void *, void *);
 PER_CUDA_FUNCTION(graph_destroy, cuGraphDestroy, void *);
 PER_CUDA_FUNCTION(graph_exec_destroy, cuGraphExecDestroy, void *);
+PER_CUDA_FUNCTION(graph_conditional_handle_create, cuGraphConditionalHandleCreate, void *, void *, void *, uint32, uint32);
+
+// JIT linker (for loading condition kernel with cudadevrt)
+PER_CUDA_FUNCTION(link_create, cuLinkCreate_v2, uint32, void *, void *, void **);
+PER_CUDA_FUNCTION(link_add_data, cuLinkAddData_v2, void *, uint32, void *, std::size_t, const char *, uint32, void *, void *);
+PER_CUDA_FUNCTION(link_add_file, cuLinkAddFile_v2, void *, uint32, const char *, uint32, void *, void *);
+PER_CUDA_FUNCTION(link_complete, cuLinkComplete, void *, void **, std::size_t *);
+PER_CUDA_FUNCTION(link_destroy, cuLinkDestroy, void *);
+PER_CUDA_FUNCTION(module_load_data, cuModuleLoadData, void **, const void *);
 // clang-format on

quadrants/runtime/amdgpu/kernel_launcher.cpp

Lines changed: 3 additions & 0 deletions
@@ -110,6 +110,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,

   AMDGPUContext::get_instance().push_back_kernel_arg_pointer(context_pointer);

+  QD_ERROR_IF(ctx.graph_do_while_arg_id >= 0,
+              "graph_do_while is only supported on the CUDA backend");
+
   for (auto &task : offloaded_tasks) {
     QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
              task.block_dim);
