
CUDA graphs reduce kernel launch overhead by capturing a sequence of GPU operations into a graph, then replaying it in a single launch. On non-CUDA platforms, the `cuda_graph` annotation is simply ignored, and code runs normally.

## Basic usage

Add `cuda_graph=True` to a `@qd.kernel` decorator:
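
A minimal sketch (the kernel body is a placeholder and `N` is an assumed array length):

```python
N = 1024  # assumed array length

@qd.kernel(cuda_graph=True)
def my_kernel(x: qd.types.ndarray(qd.f32, ndim=1),
              y: qd.types.ndarray(qd.f32, ndim=1)):
    # placeholder body; any kernel body is captured the same way
    for i in range(x.shape[0]):
        y[i] = x[i] * 2.0

x = qd.ndarray(qd.f32, shape=(N,))
y = qd.ndarray(qd.f32, shape=(N,))
my_kernel(x, y)    # first call: compiles the kernel and captures the graph

x2 = qd.ndarray(qd.f32, shape=(N,))
y2 = qd.ndarray(qd.f32, shape=(N,))
my_kernel(x2, y2)  # replays graph with new array pointers
```

Subsequent calls replay the captured graph in a single launch, which is where the launch-overhead savings come from.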

### Fields as arguments

When different fields are passed as template arguments, each unique combination of fields produces a separately compiled kernel with its own graph cache entry. There is no interference between them.
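
A sketch of what this means in practice, assuming field construction and template-argument syntax along the lines of `qd.field(...)` and `qd.template()` (the exact spellings may differ):

```python
# Illustrative sketch: qd.field / qd.template() spellings are assumptions.
a = qd.field(qd.f32, shape=(N,))
b = qd.field(qd.f32, shape=(N,))
c = qd.field(qd.f32, shape=(N,))

@qd.kernel(cuda_graph=True)
def scale(src: qd.template(), dst: qd.template()):
    for i in src:
        dst[i] = src[i] * 2.0

scale(a, b)  # compiles a kernel and captures a graph for the (a, b) combination
scale(a, c)  # new combination (a, c): separately compiled kernel, separate graph cache entry
scale(a, b)  # replays the cached (a, b) graph
```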

## GPU-side iteration with `graph_do_while`

For iterative algorithms (physics solvers, convergence loops), you often want to repeat the kernel body until a condition is met, without returning to the host each iteration. Use `while qd.graph_do_while(flag):` inside a `cuda_graph=True` kernel:

```python
@qd.kernel(cuda_graph=True)
def solve(x: qd.types.ndarray(qd.f32, ndim=1),
          counter: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(counter):    # repeat the body while counter is non-zero
        for i in range(x.shape[0]):
            x[i] = x[i] + 1.0
        for i in range(1):               # single-iteration loop: decrement once per pass
            counter[()] = counter[()] - 1

x = qd.ndarray(qd.f32, shape=(N,))
counter = qd.ndarray(qd.i32, shape=())
counter.from_numpy(np.array(10, dtype=np.int32))
solve(x, counter)
# x is now incremented 10 times; counter is 0
```

The argument to `qd.graph_do_while()` must be the name of a scalar `qd.i32` ndarray parameter. The loop body repeats while this value is non-zero.

- On SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement.
- Older CUDA GPUs and non-CUDA backends are not currently supported; a host-driven workaround is sketched below.
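
On unsupported platforms, one workaround (a sketch, not a built-in fallback) is to keep `cuda_graph=True` on the body kernel and drive the iteration from the host:

```python
@qd.kernel(cuda_graph=True)
def step_once(x: qd.types.ndarray(qd.f32, ndim=1)):
    for i in range(x.shape[0]):
        x[i] = x[i] + 1.0

# Host-driven iteration: each call replays the captured graph,
# but control returns to the host between iterations.
for _ in range(10):
    step_once(x)
```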

### Patterns

**Counter-based**: set the counter to N, decrement each iteration. The body runs exactly N times.

```python
@qd.kernel(cuda_graph=True)
def iterate(x: qd.types.ndarray(qd.f32, ndim=1),
            counter: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(counter):    # counter starts at N on the host
        for i in range(x.shape[0]):
            x[i] = x[i] + 1.0
        for i in range(1):
            counter[()] = counter[()] - 1    # body has run N times when this reaches 0
```

**Boolean flag**: set a `keep_going` flag to 1, have the kernel set it to 0 when a convergence criterion is met.

```python
@qd.kernel(cuda_graph=True)
def converge(x: qd.types.ndarray(qd.f32, ndim=1),
             keep_going: qd.types.ndarray(qd.i32, ndim=0)):
    while qd.graph_do_while(keep_going):
        for i in range(x.shape[0]):
            x[i] = x[i] * 0.5            # placeholder for the real solver update
        for i in range(1):
            if x[0] < 1e-3:              # placeholder convergence criterion
                keep_going[()] = 0
```

### Do-while semantics

`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time: passing 0 to a kernel that decrements the counter runs the body once, leaves the counter negative, and effectively loops forever.
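
A small host-side guard (a sketch, reusing the `solve` kernel and arrays from above) makes the requirement explicit:

```python
n_iters = 10
assert n_iters >= 1, "graph_do_while flag must be >= 1 at launch"

counter.from_numpy(np.array(n_iters, dtype=np.int32))
solve(x, counter)  # the body executes exactly n_iters times (at least once, do-while)
```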

### ndarray vs field

The parameter used by `graph_do_while` **must** be an ndarray; a field cannot serve as the loop flag. Other parameters, however, can be any supported Quadrants kernel parameter type.

### Restrictions

- The same physical ndarray must be used for the counter parameter on every call. Passing a different ndarray raises an error, because the counter's device pointer is baked into the CUDA graph at creation time (a sketch follows below).
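
A sketch of what this restriction means in practice, reusing the `solve` kernel and `x` from above:

```python
counter = qd.ndarray(qd.i32, shape=())
counter.from_numpy(np.array(5, dtype=np.int32))
solve(x, counter)   # first call: the counter's device pointer is baked into the graph

counter.from_numpy(np.array(5, dtype=np.int32))
solve(x, counter)   # OK: same ndarray, value simply reset on the host

other = qd.ndarray(qd.i32, shape=())
other.from_numpy(np.array(5, dtype=np.int32))
solve(x, other)     # error: a different counter ndarray than the one captured
```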

### Caveats

`graph_do_while` runs only on CUDA (SM 9.0+). Unlike plain `cuda_graph=True`, there is currently no fallback on non-CUDA platforms.