Merged
148 commits
61b5b36
Add CUDA graph MVP for multi-task kernels
hughperkins Mar 1, 2026
49ce3c1
bug fixes for cuda graph
hughperkins Mar 1, 2026
9c32a28
Add per-kernel @qd.kernel(cuda_graph=True) API
hughperkins Mar 1, 2026
cffb9ae
Add cross-platform test for cuda_graph=True annotation
hughperkins Mar 1, 2026
ed1cff9
Handle argument changes in CUDA graph replay
hughperkins Mar 1, 2026
85dc8db
Fix formatting and disable cuda_graph on adjoint kernels
hughperkins Mar 11, 2026
d9ca32a
Add graph_while conditional nodes for GPU-side iteration loops
hughperkins Mar 1, 2026
d6cbd15
Fix graph_while arg_id for struct parameters and add cross-platform f…
hughperkins Mar 1, 2026
0573c12
Add static_assert on CudaGraphNodeParams size to catch ABI drift
hughperkins Mar 5, 2026
7fd81d3
Add compute capability check for graph_while (requires SM 9.0+)
hughperkins Mar 5, 2026
9c75cee
Use CUDA_HOME/CUDA_PATH env vars to find libcudadevrt.a
hughperkins Mar 5, 2026
7f80b72
Restore documentation comments removed during cuda-graph refactor
hughperkins Mar 5, 2026
7762fd9
Add CUDA graph documentation and do-while semantics warning
hughperkins Mar 5, 2026
47d59dc
Apply clang-format to kernel_launcher.h static_assert
hughperkins Mar 5, 2026
ad4eab6
Fix lint: formatting (black, clang-format, ruff)
hughperkins Mar 11, 2026
e00fc15
Fix clang-format whitespace in kernel_launcher.cpp
hughperkins Mar 11, 2026
9bcc487
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 11, 2026
0031619
Reject cuda_graph=True on kernels with struct return values
hughperkins Mar 11, 2026
792ff34
Add test for cuda_graph with different-sized arrays
hughperkins Mar 11, 2026
334c2e8
Restore comments removed during cuda graph refactor
hughperkins Mar 11, 2026
8f56ffd
Add test for cuda_graph after qd.reset()
hughperkins Mar 11, 2026
5dd2d66
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 11, 2026
f8ff3ee
Fix graph_while cache staleness when counter ndarray changes
hughperkins Mar 11, 2026
96b43de
Validate graph_while parameter name at decoration time
hughperkins Mar 11, 2026
8caa42c
Merge remote-tracking branch 'origin/main' into hp/cuda-graph-mvp-1-g…
hughperkins Mar 13, 2026
501362f
Add CUDA graph documentation page
hughperkins Mar 13, 2026
517d3db
Expose CUDA graph cache size for test observability
hughperkins Mar 13, 2026
da3ff27
Add get_cuda_graph_cache_used_on_last_call() for test observability
hughperkins Mar 13, 2026
a2abceb
Add cache size and cache used assertions to all CUDA graph tests
hughperkins Mar 13, 2026
720f5d8
Inline expected cache size in cross-platform test assertion
hughperkins Mar 13, 2026
f158fd4
Run all CUDA graph tests on all platforms
hughperkins Mar 13, 2026
a8e6b8f
update doc
hughperkins Mar 13, 2026
dd4f48b
Add comment documenting resolve_ctx_ndarray_ptrs contract
hughperkins Mar 13, 2026
aa08442
Add comment explaining contexts_ population in graph path
hughperkins Mar 13, 2026
98bf081
Add comment explaining single-task graph fallback guard
hughperkins Mar 13, 2026
7b18674
Add comment explaining resolve_ctx_ndarray_ptrs fallback check
hughperkins Mar 13, 2026
9907333
Add comment explaining kernelParams vs extra in graph node params
hughperkins Mar 13, 2026
6ff327e
Add comment explaining graph_exec field in CachedCudaGraph
hughperkins Mar 13, 2026
6796baf
Add comment explaining cuda_graph_cache_ key
hughperkins Mar 13, 2026
1d4ebef
Parametrize test_cuda_graph_changed_args over ndarray and field
hughperkins Mar 13, 2026
a4cfdc3
Parametrize all CUDA graph tests over ndarray and field
hughperkins Mar 13, 2026
2db1b05
Parametrize test_cuda_graph_different_sizes over ndarray and field
hughperkins Mar 13, 2026
7775a4e
Merge branch 'main' into hp/cuda-graph-mvp-1-graph-build
hughperkins Mar 13, 2026
f1d397a
Merge remote-tracking branch 'origin/hp/cuda-graph-mvp-1-graph-build'…
hughperkins Mar 13, 2026
f5ff0af
Rename cuda_graphs.md to cuda_graph.md
hughperkins Mar 13, 2026
a91878b
merge doc from pr 1
hughperkins Mar 13, 2026
cd23e79
Use index.md from cuda-graph-mvp-1 branch
hughperkins Mar 13, 2026
aad0dd9
fix up merge
hughperkins Mar 13, 2026
d559a92
Use [()] instead of [None] in CUDA graph docs
hughperkins Mar 13, 2026
712846f
ndarray vs field
hughperkins Mar 14, 2026
a14a072
Rename graph_while to graph_do_while
hughperkins Mar 14, 2026
aa4dd52
add caveats to doc
hughperkins Mar 14, 2026
44781b8
Add comments to AMDGPU graph_do_while fallback code
hughperkins Mar 14, 2026
c63e201
Remove graph_do_while fallback, require CUDA with SM 9.0+
hughperkins Mar 14, 2026
1605188
Remove cross-backend graph_do_while tests
hughperkins Mar 14, 2026
9fbb433
Allow cuda_graph=True for single-task kernels
hughperkins Mar 14, 2026
ae8d5a9
Remove test_cuda_graph_single_loop test
hughperkins Mar 14, 2026
398a101
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
aa82bcb
Fix cuda-cudart-dev package name in GPU workflow
hughperkins Mar 14, 2026
4de7a64
Fix formatting (black + clang-format)
hughperkins Mar 14, 2026
b737cce
Simplify cuda_graph doc caveats section
hughperkins Mar 14, 2026
81f580f
Fix wording: "Older CUDA GPUs"
hughperkins Mar 14, 2026
8333f14
Update kernel_impl.py docstring: non-CUDA not supported
hughperkins Mar 14, 2026
aedac28
Add comments to JIT linker function declarations
hughperkins Mar 14, 2026
a7a3ad9
Add comments to graph_do_while condition kernel PTX
hughperkins Mar 14, 2026
cc64157
Improve comments in graph_do_while condition kernel PTX
hughperkins Mar 14, 2026
e235b23
Assert no gradient pointers in cuda_graph path
hughperkins Mar 14, 2026
ce09e5e
Add /*name=*/ comment to link_add_data call
hughperkins Mar 14, 2026
ea83519
Add description comment to ensure_condition_kernel_loaded
hughperkins Mar 14, 2026
0301b76
Throw error if graph_do_while condition ndarray changes between calls
hughperkins Mar 14, 2026
3402fae
Extract add_conditional_while_node from launch_llvm_kernel_graph
hughperkins Mar 14, 2026
0603fd5
Add comments to link_state and conditional graph structure
hughperkins Mar 14, 2026
03e8142
Error instead of fallback when graph_do_while has host-resident ndarrays
hughperkins Mar 14, 2026
0c481e8
Extract add_kernel_node helper to deduplicate graph kernel node creation
hughperkins Mar 14, 2026
e223689
Add comment explaining why condition kernel must be last in body graph
hughperkins Mar 14, 2026
ff2d2ab
Add comment for conditional node in body graph
hughperkins Mar 14, 2026
b11fe26
Add comment explaining cached graph_do_while_flag_dev_ptr
hughperkins Mar 14, 2026
ee5b0b8
Extract CudaGraphManager from KernelLauncher into separate class
hughperkins Mar 14, 2026
54c8bf0
Fix awkward string literal split in QD_TRACE
hughperkins Mar 14, 2026
cf86442
Extract launch_cached_graph from try_launch
hughperkins Mar 14, 2026
ddb552e
Extract CudaGraphManager from KernelLauncher into separate class
hughperkins Mar 14, 2026
76181bf
Make on_cuda_device a free function shared by both launch paths
hughperkins Mar 14, 2026
f6d531e
Make on_cuda_device a free function shared by both launch paths
hughperkins Mar 14, 2026
3fc7a9a
Move on_cuda_device to cuda_context where it belongs
hughperkins Mar 14, 2026
683194d
Move on_cuda_device to cuda_context where it belongs
hughperkins Mar 14, 2026
c17e59c
Move on_cuda_device to runtime/cuda/cuda_utils
hughperkins Mar 14, 2026
ccf9a37
Move on_cuda_device to runtime/cuda/cuda_utils
hughperkins Mar 14, 2026
3411acb
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
2e42c12
Extract resolve_device_alloc_ptr helper to deduplicate DeviceAllocati…
hughperkins Mar 14, 2026
c94b41c
Extract resolve_device_alloc_ptr helper to deduplicate DeviceAllocati…
hughperkins Mar 14, 2026
55f03c1
Revert "Extract resolve_device_alloc_ptr helper to deduplicate Device…
hughperkins Mar 14, 2026
680c7dc
Revert "Extract resolve_device_alloc_ptr helper to deduplicate Device…
hughperkins Mar 14, 2026
4f78c21
Error on gradient pointers in cuda_graph path instead of silently res…
hughperkins Mar 14, 2026
e0200e5
Add comment explaining scalar parameter skip in resolve_ctx_ndarray_ptrs
hughperkins Mar 14, 2026
2f411a5
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
05a7e4f
Clarify that fields are template parameters and not handled here
hughperkins Mar 14, 2026
ff5d021
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
a55c234
Re-add comments lost during merge conflict resolution
hughperkins Mar 14, 2026
b73dfb8
Add comment explaining resolved_data variable
hughperkins Mar 14, 2026
34f685c
Add comment noting cache_size and used_on_last_call are for tests
hughperkins Mar 14, 2026
bf41337
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
e88fad4
Apply clang-format
hughperkins Mar 14, 2026
51f898f
Apply clang-format
hughperkins Mar 14, 2026
caefc2b
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
90639dc
Add comment explaining why CudaGraphNodeParams is defined locally
hughperkins Mar 14, 2026
e9d4af4
Add comment explaining CudaGraphNodeParams vs CudaKernelNodeParams
hughperkins Mar 14, 2026
359c7d8
Rename increment_loop to graph_loop in test_graph_do_while_counter
hughperkins Mar 14, 2026
99889ee
Remove unnecessary qd.sync() calls from do-while tests
hughperkins Mar 14, 2026
b57582a
Add second call with different counter to test_graph_do_while_counter
hughperkins Mar 14, 2026
844a454
Add second call to all do-while tests to verify graph reuse
hughperkins Mar 14, 2026
4a62b03
Use different values on second call in do-while tests
hughperkins Mar 14, 2026
bf9e9ba
Make threshold a runtime ndarray parameter in boolean done test
hughperkins Mar 14, 2026
ec0daca
Pass threshold as scalar int instead of ndarray
hughperkins Mar 14, 2026
5a2a41a
Remove comment
hughperkins Mar 14, 2026
2bdb112
Remove redundant test_graph_do_while_replay
hughperkins Mar 14, 2026
955413d
Simplify changed-condition-ndarray test
hughperkins Mar 14, 2026
3dfcf8a
Replace [None] with [()] in do-while tests
hughperkins Mar 14, 2026
f175378
Error instead of fallback when cuda_graph gets host-resident arrays
hughperkins Mar 14, 2026
857b59f
Merge branch 1: error on host-resident arrays in cuda_graph
hughperkins Mar 14, 2026
b222dd1
Align autograd check and libcudadevrt error message with branch 3
hughperkins Mar 14, 2026
339b084
Reorder use_graph_do_while declaration to match branch 3
hughperkins Mar 14, 2026
dd480d9
Fix clang-format indentation in QD_ERROR_IF
hughperkins Mar 14, 2026
7206fcd
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
7964d82
Fix macro parse error: avoid brace-init-list inside QD_ERROR_IF
hughperkins Mar 14, 2026
d2563b9
Skip graph_do_while tests on SM < 90
hughperkins Mar 14, 2026
0c33d05
Revert "Skip graph_do_while tests on SM < 90"
hughperkins Mar 14, 2026
a5ebc33
xfail graph_do_while tests on SM < 90
hughperkins Mar 14, 2026
32b1341
Add num_offloaded_tasks query for compiled kernels
hughperkins Mar 14, 2026
2c1464b
Expose CUDA graph node count for test assertions
hughperkins Mar 14, 2026
470b073
Add multi-func cuda graph test with 9 offloaded tasks
hughperkins Mar 14, 2026
3079912
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
70aac93
Change graph_do_while syntax from decorator param to in-kernel while …
hughperkins Mar 14, 2026
791bab4
Remove implicit cuda_graph=True note from docs
hughperkins Mar 14, 2026
ee9178c
Require cuda_graph=True for graph_do_while instead of implicitly enab…
hughperkins Mar 14, 2026
6623375
Update graph_do_while docstring to reflect SM 9.0+ only support
hughperkins Mar 14, 2026
b2d375e
Add tests for graph_do_while syntax errors
hughperkins Mar 14, 2026
6748e17
Apply black formatting to ast_transformer.py
hughperkins Mar 14, 2026
2996cb9
Fix offloaded tasks assertions to use >= for x64 ndarray compatibility
hughperkins Mar 14, 2026
e6d5adc
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
b39c3a9
Fix cuda graph tests: derive expected node count from offloaded tasks
hughperkins Mar 14, 2026
d750b1d
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
3497560
Add graph_do_while to public API test list
hughperkins Mar 15, 2026
31b73a6
Merge branch 'main' into hp/cuda-graph-mvp-1-graph-build
hughperkins Mar 15, 2026
34de2b4
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 15, 2026
a32efba
Merge origin/main into hp/cuda-graph-mvp-2-graph-while
hughperkins Mar 16, 2026
1bdd202
Fix end-of-file newline in env.sh
hughperkins Mar 16, 2026
df0f753
Remove env.sh from git and add to .gitignore
hughperkins Mar 16, 2026
8d10a35
Merge remote-tracking branch 'origin/main' into hp/cuda-graph-mvp-2-g…
hughperkins Mar 16, 2026
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -96,3 +96,4 @@ imgui.ini
stubs/
CHANGELOG.md
python/quadrants/_version.py
env.sh
79 changes: 78 additions & 1 deletion docs/source/user_guide/cuda_graph.md
@@ -2,7 +2,7 @@

CUDA graphs reduce kernel launch overhead by capturing a sequence of GPU operations into a graph, then replaying it in a single launch. On non-CUDA platforms, the `cuda_graph` annotation is simply ignored and the kernel runs normally.

## Usage
## Basic usage

Add `cuda_graph=True` to a `@qd.kernel` decorator:

@@ -52,3 +52,80 @@ my_kernel(x2, y2) # replays graph with new array pointers
### Fields as arguments

When different fields are passed as template arguments, each unique combination of fields produces a separately compiled kernel with its own graph cache entry. There is no interference between them.
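The caching scheme described above can be pictured as a dictionary keyed by the identity of the field arguments. A minimal pure-Python sketch of the idea (the class and names here are illustrative, not the actual Quadrants internals):

```python
# Sketch of per-field-combination graph caching (hypothetical names,
# not the real Quadrants implementation).
class GraphCache:
    def __init__(self):
        self._cache = {}  # key: tuple of field identities -> built graph

    def get_or_build(self, *fields):
        key = tuple(id(f) for f in fields)
        if key not in self._cache:
            # Stand-in for compiling the kernel and building its CUDA graph.
            self._cache[key] = f"graph#{len(self._cache)}"
        return self._cache[key]


class Field:  # stand-in for a qd.field
    pass


cache = GraphCache()
a, b, c = Field(), Field(), Field()
g1 = cache.get_or_build(a, b)
g2 = cache.get_or_build(a, c)  # different field combination -> new cache entry
g3 = cache.get_or_build(a, b)  # same combination -> reuses the first entry
assert g1 == g3 and g1 != g2
assert len(cache._cache) == 2
```

Because the key is the full combination of field identities, two kernels launched with different fields never share (or invalidate) each other's graphs.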


## GPU-side iteration with `graph_do_while`

For iterative algorithms (physics solvers, convergence loops), you often want to repeat the kernel body until a condition is met, without returning to the host each iteration. Use `while qd.graph_do_while(flag):` inside a `cuda_graph=True` kernel:

```python
@qd.kernel(cuda_graph=True)
def solve(x: qd.types.ndarray(qd.f32, ndim=1),
          counter: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(counter):
        for i in range(x.shape[0]):
            x[i] = x[i] + 1.0
        for i in range(1):
            counter[()] = counter[()] - 1

x = qd.ndarray(qd.f32, shape=(N,))
counter = qd.ndarray(qd.i32, shape=())
counter.from_numpy(np.array(10, dtype=np.int32))
solve(x, counter)
# x is now incremented 10 times; counter is 0
```

The argument to `qd.graph_do_while()` must be one of the kernel's parameters, and that parameter must be a scalar (`ndim=0`) `qd.i32` ndarray. The loop body repeats while its value is non-zero.

- On SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement.
- Older CUDA GPUs and non-CUDA backends are not currently supported.
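The capability gate above amounts to a simple version comparison. A sketch of the check (the real test lives in the C++ runtime; the function name here is illustrative):

```python
def supports_graph_do_while(sm_major: int, sm_minor: int) -> bool:
    # CUDA conditional graph nodes require compute capability 9.0+ (Hopper).
    return (sm_major, sm_minor) >= (9, 0)


assert supports_graph_do_while(9, 0)       # Hopper
assert not supports_graph_do_while(8, 6)   # Ampere-class GPU: unsupported
```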

### Patterns

**Counter-based**: set the counter to N, decrement each iteration. The body runs exactly N times.

```python
@qd.kernel(cuda_graph=True)
def iterate(x: qd.types.ndarray(qd.f32, ndim=1),
            counter: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(counter):
        for i in range(x.shape[0]):
            x[i] = x[i] + 1.0
        for i in range(1):
            counter[()] = counter[()] - 1
```

**Boolean flag**: set a `keep_going` flag to 1, have the kernel set it to 0 when a convergence criterion is met.

```python
@qd.kernel(cuda_graph=True)
def converge(x: qd.types.ndarray(qd.f32, ndim=1),
             keep_going: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(keep_going):
        for i in range(x.shape[0]):
            # ... do work ...
            pass
        for i in range(1):
            if some_condition(x):
                keep_going[()] = 0
```

### Do-while semantics

`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time; passing 0 to a kernel that decrements the counter causes an infinite loop, because the counter steps past zero and the non-zero condition never fails.
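The semantics can be modeled in pure Python (a sketch, not the runtime implementation; the `max_safety` guard stands in for the runaway case that on the GPU would loop forever):

```python
def do_while_counter(n_iterations: int, max_safety: int = 1000) -> int:
    """Model of graph_do_while with a decrementing counter.

    The body runs once BEFORE the condition is first checked
    (do-while), so a launch value of 0 decrements to -1 and the
    counter never returns to zero.
    """
    counter = n_iterations
    body_runs = 0
    while True:
        body_runs += 1           # kernel body executes
        counter -= 1             # body decrements the counter
        if counter == 0:         # condition checked AFTER the body
            break
        if body_runs >= max_safety:  # guard modeling the infinite loop
            break
    return body_runs


assert do_while_counter(10) == 10   # counter >= 1: body runs exactly N times
assert do_while_counter(0) == 1000  # counter == 0: runaway, hits the guard
```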

### ndarray vs field

The parameter passed to `graph_do_while` MUST be an ndarray. The other parameters, however, can be any supported Quadrants kernel parameter type.

### Restrictions

- The same physical ndarray must be used for the counter parameter on every
call. Passing a different ndarray raises an error, because the counter's
device pointer is baked into the CUDA graph at creation time.
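This restriction can be sketched in pure Python, using object identity as a stand-in for the device pointer (hypothetical class, not the Quadrants implementation):

```python
# Sketch of the baked-in counter pointer check (hypothetical names).
class GraphDoWhileKernel:
    def __init__(self):
        self._counter_ptr = None  # device pointer captured at graph build time

    def launch(self, counter):
        ptr = id(counter)  # stand-in for the ndarray's device pointer
        if self._counter_ptr is None:
            self._counter_ptr = ptr  # first call: pointer baked into the graph
        elif self._counter_ptr != ptr:
            raise RuntimeError(
                "graph_do_while counter ndarray changed between calls; "
                "the graph was built against the original device pointer"
            )
        # ... replay the cached graph ...


k = GraphDoWhileKernel()
c1, c2 = object(), object()
k.launch(c1)
k.launch(c1)  # same ndarray: replays fine
try:
    k.launch(c2)  # different ndarray: rejected
    raised = False
except RuntimeError:
    raised = True
assert raised
```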

### Caveats

Only runs on CUDA. No fallback on non-CUDA platforms currently.
30 changes: 30 additions & 0 deletions python/quadrants/lang/ast/ast_transformer.py
@@ -1193,11 +1193,41 @@ def build_For(ctx: ASTTransformerFuncContext, node: ast.For) -> None:
# Struct for
return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)

    @staticmethod
    def _is_graph_do_while_call(node: ast.expr) -> str | None:
        """If *node* is ``qd.graph_do_while(var)`` return the arg name, else None."""
        if not isinstance(node, ast.Call):
            return None
        func = node.func
        if isinstance(func, ast.Attribute) and func.attr == "graph_do_while":
            if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
                return node.args[0].id
        if isinstance(func, ast.Name) and func.id == "graph_do_while":
            if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
                return node.args[0].id
        return None

    @staticmethod
    def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None:
        if node.orelse:
            raise QuadrantsSyntaxError("'else' clause for 'while' not supported in Quadrants kernels")

        graph_do_while_arg = ASTTransformer._is_graph_do_while_call(node.test)
        if graph_do_while_arg is not None:
            kernel = ctx.global_context.current_kernel
            arg_names = [m.name for m in kernel.arg_metas]
            if graph_do_while_arg not in arg_names:
                raise QuadrantsSyntaxError(
                    f"qd.graph_do_while({graph_do_while_arg!r}) does not match any "
                    f"parameter of kernel {kernel.func.__name__!r}. "
                    f"Available parameters: {arg_names}"
                )
            if not kernel.use_cuda_graph:

Review comment (hughperkins): raise error if not already in cuda graph

                raise QuadrantsSyntaxError("qd.graph_do_while() requires @qd.kernel(cuda_graph=True)")
            kernel.graph_do_while_arg = graph_do_while_arg
            build_stmts(ctx, node.body)
            return None

        with ctx.loop_scope_guard():
            stmt_dbg_info = _qd_core.DebugInfo(ctx.get_pos_info(node))
            ctx.ast_builder.begin_frontend_while(expr.Expr(1, dtype=primitive_types.i32).ptr, stmt_dbg_info)
5 changes: 5 additions & 0 deletions python/quadrants/lang/kernel.py
Expand Up @@ -292,6 +292,7 @@ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _is_classkernel
self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
self.has_print = False
self.use_cuda_graph: bool = False
self.graph_do_while_arg: str | None = None
self.quadrants_callable: QuadrantsCallable | None = None
self.visited_functions: set[FunctionSourceInfo] = set()
self.kernel_function_info: FunctionSourceInfo | None = None
@@ -444,6 +445,8 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled
                template_num += 1
                i_out += 1
                continue
            if self.graph_do_while_arg is not None and self.arg_metas[i_in].name == self.graph_do_while_arg:
                self._graph_do_while_cpp_arg_id = i_out - template_num
            num_args_, is_launch_ctx_cacheable_ = self._recursive_set_args(
                self.used_py_dataclass_parameters_by_key_enforcing[key],
                self.arg_metas[i_in].name,
@@ -505,6 +508,8 @@
            self.src_ll_cache_observations.cache_stored = True
            self._last_compiled_kernel_data = compiled_kernel_data
            launch_ctx.use_cuda_graph = self.use_cuda_graph
            if self.graph_do_while_arg is not None and hasattr(self, "_graph_do_while_cpp_arg_id"):
                launch_ctx.graph_do_while_arg_id = self._graph_do_while_cpp_arg_id
            prog.launch_kernel(compiled_kernel_data, launch_ctx)
        except Exception as e:
            e = handle_exception_from_cpp(e)
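The `i_out - template_num` expression in the hunk above maps a Python-side parameter index to its C++ launch-context arg id by skipping template parameters. A pure-Python sketch of that mapping (hypothetical helper name, not part of the Quadrants API):

```python
def cpp_arg_id(is_template: list, target: int) -> int:
    # The C++ arg id is the parameter's position among non-template
    # parameters: i_out minus the number of template parameters seen
    # so far, mirroring `i_out - template_num` in launch_kernel.
    template_num = sum(1 for t in is_template[:target] if t)
    return target - template_num


# params: [field (template), counter ndarray, ndarray]
assert cpp_arg_id([True, False, False], 1) == 0  # counter -> C++ arg 0
assert cpp_arg_id([True, False, False], 2) == 1
assert cpp_arg_id([False, False], 1) == 1        # no templates: identity
```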
11 changes: 10 additions & 1 deletion python/quadrants/lang/kernel_impl.py
@@ -124,7 +124,10 @@ def _inside_class(level_of_class_stackframe: int) -> bool:


def _kernel_impl(
    _func: Callable, level_of_class_stackframe: int, verbose: bool = False, cuda_graph: bool = False
    _func: Callable,
    level_of_class_stackframe: int,
    verbose: bool = False,
    cuda_graph: bool = False,
) -> QuadrantsCallable:
    # Can decorators determine if a function is being defined inside a class?
    # https://stackoverflow.com/a/8793684/12003165
@@ -206,6 +209,12 @@ def kernel(

    Kernel's gradient kernel would be generated automatically by the AutoDiff system.

    Args:
        cuda_graph: If True, kernels with 2+ top-level for loops are captured
            into a CUDA graph on first launch and replayed on subsequent
            launches, reducing per-kernel launch overhead. Non-CUDA backends
            are not supported currently.

    Example::

        >>> x = qd.field(qd.i32, shape=(4, 8))
19 changes: 19 additions & 0 deletions python/quadrants/lang/misc.py
@@ -701,6 +701,24 @@ def copy():
    _bit_vectorize()


def graph_do_while(condition) -> bool:
    """Marks a while loop as a CUDA graph do-while conditional node.

    Used as ``while qd.graph_do_while(flag):`` inside a
    ``@qd.kernel(cuda_graph=True)`` kernel. The loop body repeats while
    ``flag`` (a scalar ``qd.i32`` ndarray) is non-zero.

    On SM 9.0+ (Hopper) GPUs this compiles to a native CUDA graph
    conditional while node. Older CUDA GPUs and non-CUDA backends
    are not currently supported.

    This function should not be called directly at runtime; it is
    recognised and transformed during AST compilation.
    Requires ``@qd.kernel(cuda_graph=True)``.
    """
    return bool(condition)


def global_thread_idx():
    """Returns the global thread id of this running thread,
    only available for cpu and cuda backends.
@@ -837,6 +855,7 @@ def dump_compile_config() -> None:
    "python",
    "vulkan",
    "extension",
    "graph_do_while",
    "loop_config",
    "global_thread_idx",
    "assume_in_range",
2 changes: 2 additions & 0 deletions quadrants/program/launch_context_builder.h
@@ -151,6 +151,8 @@ class LaunchContextBuilder {
  const StructType *args_type{nullptr};
  size_t result_buffer_size{0};
  bool use_cuda_graph{false};
  int graph_do_while_arg_id{-1};
  void *graph_do_while_flag_dev_ptr{nullptr};

  // Note that I've tried to group `array_runtime_size` and
  // `is_device_allocations` into a small struct. However, it caused some test
4 changes: 3 additions & 1 deletion quadrants/python/export_lang.cpp
@@ -667,7 +667,9 @@ void export_lang(py::module &m) {
      .def("get_struct_ret_int", &LaunchContextBuilder::get_struct_ret_int)
      .def("get_struct_ret_uint", &LaunchContextBuilder::get_struct_ret_uint)
      .def("get_struct_ret_float", &LaunchContextBuilder::get_struct_ret_float)
      .def_readwrite("use_cuda_graph", &LaunchContextBuilder::use_cuda_graph);
      .def_readwrite("use_cuda_graph", &LaunchContextBuilder::use_cuda_graph)
      .def_readwrite("graph_do_while_arg_id",
                     &LaunchContextBuilder::graph_do_while_arg_id);

  py::class_<Function>(m, "Function")
      .def("insert_scalar_param", &Function::insert_scalar_param)
10 changes: 10 additions & 0 deletions quadrants/rhi/cuda/cuda_driver_functions.inc.h
@@ -73,8 +73,18 @@ PER_CUDA_FUNCTION(import_external_semaphore, cuImportExternalSemaphore,CUexterna
// Graph management
PER_CUDA_FUNCTION(graph_create, cuGraphCreate, void **, uint32);
PER_CUDA_FUNCTION(graph_add_kernel_node, cuGraphAddKernelNode, void **, void *, const void *, std::size_t, const void *);
PER_CUDA_FUNCTION(graph_add_node, cuGraphAddNode, void **, void *, const void *, std::size_t, void *);
PER_CUDA_FUNCTION(graph_instantiate, cuGraphInstantiate, void **, void *, void *, char *, std::size_t);
PER_CUDA_FUNCTION(graph_launch, cuGraphLaunch, void *, void *);
PER_CUDA_FUNCTION(graph_destroy, cuGraphDestroy, void *);
PER_CUDA_FUNCTION(graph_exec_destroy, cuGraphExecDestroy, void *);
PER_CUDA_FUNCTION(graph_conditional_handle_create, cuGraphConditionalHandleCreate, void *, void *, void *, uint32, uint32);

// JIT linker (for loading condition kernel with cudadevrt)
PER_CUDA_FUNCTION(link_create, cuLinkCreate_v2, uint32, void *, void *, void **);

Review comment (hughperkins): let's have some comments describing each of these

PER_CUDA_FUNCTION(link_add_data, cuLinkAddData_v2, void *, uint32, void *, std::size_t, const char *, uint32, void *, void *);
PER_CUDA_FUNCTION(link_add_file, cuLinkAddFile_v2, void *, uint32, const char *, uint32, void *, void *);
PER_CUDA_FUNCTION(link_complete, cuLinkComplete, void *, void **, std::size_t *);
PER_CUDA_FUNCTION(link_destroy, cuLinkDestroy, void *);
PER_CUDA_FUNCTION(module_load_data, cuModuleLoadData, void **, const void *);
// clang-format on
3 changes: 3 additions & 0 deletions quadrants/runtime/amdgpu/kernel_launcher.cpp
@@ -110,6 +110,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,

  AMDGPUContext::get_instance().push_back_kernel_arg_pointer(context_pointer);

  QD_ERROR_IF(ctx.graph_do_while_arg_id >= 0,
              "graph_do_while is only supported on the CUDA backend");

  for (auto &task : offloaded_tasks) {
    QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
             task.block_dim);
3 changes: 3 additions & 0 deletions quadrants/runtime/cpu/kernel_launcher.cpp
@@ -41,6 +41,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
      }
    }
  }
  QD_ERROR_IF(ctx.graph_do_while_arg_id >= 0,
              "graph_do_while is only supported on the CUDA backend");

  for (auto task : launcher_ctx.task_funcs) {
    task(&ctx.get_context());
  }