Draft
Changes from all commits
194 commits
61b5b36
Add CUDA graph MVP for multi-task kernels
hughperkins Mar 1, 2026
49ce3c1
bug fixes for cuda graph
hughperkins Mar 1, 2026
9c32a28
Add per-kernel @qd.kernel(cuda_graph=True) API
hughperkins Mar 1, 2026
cffb9ae
Add cross-platform test for cuda_graph=True annotation
hughperkins Mar 1, 2026
ed1cff9
Handle argument changes in CUDA graph replay
hughperkins Mar 1, 2026
85dc8db
Fix formatting and disable cuda_graph on adjoint kernels
hughperkins Mar 11, 2026
d9ca32a
Add graph_while conditional nodes for GPU-side iteration loops
hughperkins Mar 1, 2026
d6cbd15
Fix graph_while arg_id for struct parameters and add cross-platform f…
hughperkins Mar 1, 2026
0573c12
Add static_assert on CudaGraphNodeParams size to catch ABI drift
hughperkins Mar 5, 2026
7fd81d3
Add compute capability check for graph_while (requires SM 9.0+)
hughperkins Mar 5, 2026
9c75cee
Use CUDA_HOME/CUDA_PATH env vars to find libcudadevrt.a
hughperkins Mar 5, 2026
7f80b72
Restore documentation comments removed during cuda-graph refactor
hughperkins Mar 5, 2026
7762fd9
Add CUDA graph documentation and do-while semantics warning
hughperkins Mar 5, 2026
47d59dc
Apply clang-format to kernel_launcher.h static_assert
hughperkins Mar 5, 2026
ad4eab6
Fix lint: formatting (black, clang-format, ruff)
hughperkins Mar 11, 2026
e00fc15
Fix clang-format whitespace in kernel_launcher.cpp
hughperkins Mar 11, 2026
9bcc487
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 11, 2026
0031619
Reject cuda_graph=True on kernels with struct return values
hughperkins Mar 11, 2026
792ff34
Add test for cuda_graph with different-sized arrays
hughperkins Mar 11, 2026
334c2e8
Restore comments removed during cuda graph refactor
hughperkins Mar 11, 2026
8f56ffd
Add test for cuda_graph after qd.reset()
hughperkins Mar 11, 2026
5dd2d66
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 11, 2026
f8ff3ee
Fix graph_while cache staleness when counter ndarray changes
hughperkins Mar 11, 2026
96b43de
Validate graph_while parameter name at decoration time
hughperkins Mar 11, 2026
8caa42c
Merge remote-tracking branch 'origin/main' into hp/cuda-graph-mvp-1-g…
hughperkins Mar 13, 2026
501362f
Add CUDA graph documentation page
hughperkins Mar 13, 2026
517d3db
Expose CUDA graph cache size for test observability
hughperkins Mar 13, 2026
da3ff27
Add get_cuda_graph_cache_used_on_last_call() for test observability
hughperkins Mar 13, 2026
a2abceb
Add cache size and cache used assertions to all CUDA graph tests
hughperkins Mar 13, 2026
720f5d8
Inline expected cache size in cross-platform test assertion
hughperkins Mar 13, 2026
f158fd4
Run all CUDA graph tests on all platforms
hughperkins Mar 13, 2026
a8e6b8f
update doc
hughperkins Mar 13, 2026
dd4f48b
Add comment documenting resolve_ctx_ndarray_ptrs contract
hughperkins Mar 13, 2026
aa08442
Add comment explaining contexts_ population in graph path
hughperkins Mar 13, 2026
98bf081
Add comment explaining single-task graph fallback guard
hughperkins Mar 13, 2026
7b18674
Add comment explaining resolve_ctx_ndarray_ptrs fallback check
hughperkins Mar 13, 2026
9907333
Add comment explaining kernelParams vs extra in graph node params
hughperkins Mar 13, 2026
6ff327e
Add comment explaining graph_exec field in CachedCudaGraph
hughperkins Mar 13, 2026
6796baf
Add comment explaining cuda_graph_cache_ key
hughperkins Mar 13, 2026
1d4ebef
Parametrize test_cuda_graph_changed_args over ndarray and field
hughperkins Mar 13, 2026
a4cfdc3
Parametrize all CUDA graph tests over ndarray and field
hughperkins Mar 13, 2026
2db1b05
Parametrize test_cuda_graph_different_sizes over ndarray and field
hughperkins Mar 13, 2026
7775a4e
Merge branch 'main' into hp/cuda-graph-mvp-1-graph-build
hughperkins Mar 13, 2026
f1d397a
Merge remote-tracking branch 'origin/hp/cuda-graph-mvp-1-graph-build'…
hughperkins Mar 13, 2026
f5ff0af
Rename cuda_graphs.md to cuda_graph.md
hughperkins Mar 13, 2026
a91878b
merge doc from pr 1
hughperkins Mar 13, 2026
cd23e79
Use index.md from cuda-graph-mvp-1 branch
hughperkins Mar 13, 2026
aad0dd9
fix up merge
hughperkins Mar 13, 2026
d559a92
Use [()] instead of [None] in CUDA graph docs
hughperkins Mar 13, 2026
712846f
ndarray vs field
hughperkins Mar 14, 2026
a14a072
Rename graph_while to graph_do_while
hughperkins Mar 14, 2026
aa4dd52
add caveats to doc
hughperkins Mar 14, 2026
44781b8
Add comments to AMDGPU graph_do_while fallback code
hughperkins Mar 14, 2026
c63e201
Remove graph_do_while fallback, require CUDA with SM 9.0+
hughperkins Mar 14, 2026
1605188
Remove cross-backend graph_do_while tests
hughperkins Mar 14, 2026
9fbb433
Allow cuda_graph=True for single-task kernels
hughperkins Mar 14, 2026
ae8d5a9
Remove test_cuda_graph_single_loop test
hughperkins Mar 14, 2026
398a101
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
aa82bcb
Fix cuda-cudart-dev package name in GPU workflow
hughperkins Mar 14, 2026
4de7a64
Fix formatting (black + clang-format)
hughperkins Mar 14, 2026
b737cce
Simplify cuda_graph doc caveats section
hughperkins Mar 14, 2026
81f580f
Fix wording: "Older CUDA GPUs"
hughperkins Mar 14, 2026
8333f14
Update kernel_impl.py docstring: non-CUDA not supported
hughperkins Mar 14, 2026
aedac28
Add comments to JIT linker function declarations
hughperkins Mar 14, 2026
a7a3ad9
Add comments to graph_do_while condition kernel PTX
hughperkins Mar 14, 2026
cc64157
Improve comments in graph_do_while condition kernel PTX
hughperkins Mar 14, 2026
e235b23
Assert no gradient pointers in cuda_graph path
hughperkins Mar 14, 2026
ce09e5e
Add /*name=*/ comment to link_add_data call
hughperkins Mar 14, 2026
ea83519
Add description comment to ensure_condition_kernel_loaded
hughperkins Mar 14, 2026
0301b76
Throw error if graph_do_while condition ndarray changes between calls
hughperkins Mar 14, 2026
3402fae
Extract add_conditional_while_node from launch_llvm_kernel_graph
hughperkins Mar 14, 2026
0603fd5
Add comments to link_state and conditional graph structure
hughperkins Mar 14, 2026
03e8142
Error instead of fallback when graph_do_while has host-resident ndarrays
hughperkins Mar 14, 2026
0c481e8
Extract add_kernel_node helper to deduplicate graph kernel node creation
hughperkins Mar 14, 2026
e223689
Add comment explaining why condition kernel must be last in body graph
hughperkins Mar 14, 2026
ff2d2ab
Add comment for conditional node in body graph
hughperkins Mar 14, 2026
b11fe26
Add comment explaining cached graph_do_while_flag_dev_ptr
hughperkins Mar 14, 2026
ee5b0b8
Extract CudaGraphManager from KernelLauncher into separate class
hughperkins Mar 14, 2026
54c8bf0
Fix awkward string literal split in QD_TRACE
hughperkins Mar 14, 2026
cf86442
Extract launch_cached_graph from try_launch
hughperkins Mar 14, 2026
ddb552e
Extract CudaGraphManager from KernelLauncher into separate class
hughperkins Mar 14, 2026
76181bf
Make on_cuda_device a free function shared by both launch paths
hughperkins Mar 14, 2026
f6d531e
Make on_cuda_device a free function shared by both launch paths
hughperkins Mar 14, 2026
3fc7a9a
Move on_cuda_device to cuda_context where it belongs
hughperkins Mar 14, 2026
683194d
Move on_cuda_device to cuda_context where it belongs
hughperkins Mar 14, 2026
c17e59c
Move on_cuda_device to runtime/cuda/cuda_utils
hughperkins Mar 14, 2026
ccf9a37
Move on_cuda_device to runtime/cuda/cuda_utils
hughperkins Mar 14, 2026
3411acb
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
2e42c12
Extract resolve_device_alloc_ptr helper to deduplicate DeviceAllocati…
hughperkins Mar 14, 2026
c94b41c
Extract resolve_device_alloc_ptr helper to deduplicate DeviceAllocati…
hughperkins Mar 14, 2026
55f03c1
Revert "Extract resolve_device_alloc_ptr helper to deduplicate Device…
hughperkins Mar 14, 2026
680c7dc
Revert "Extract resolve_device_alloc_ptr helper to deduplicate Device…
hughperkins Mar 14, 2026
4f78c21
Error on gradient pointers in cuda_graph path instead of silently res…
hughperkins Mar 14, 2026
e0200e5
Add comment explaining scalar parameter skip in resolve_ctx_ndarray_ptrs
hughperkins Mar 14, 2026
2f411a5
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
05a7e4f
Clarify that fields are template parameters and not handled here
hughperkins Mar 14, 2026
ff5d021
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
a55c234
Re-add comments lost during merge conflict resolution
hughperkins Mar 14, 2026
b73dfb8
Add comment explaining resolved_data variable
hughperkins Mar 14, 2026
34f685c
Add comment noting cache_size and used_on_last_call are for tests
hughperkins Mar 14, 2026
bf41337
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
e88fad4
Apply clang-format
hughperkins Mar 14, 2026
51f898f
Apply clang-format
hughperkins Mar 14, 2026
caefc2b
Merge hp/cuda-graph-mvp-1-graph-build into hp/cuda-graph-mvp-2-graph-…
hughperkins Mar 14, 2026
90639dc
Add comment explaining why CudaGraphNodeParams is defined locally
hughperkins Mar 14, 2026
e9d4af4
Add comment explaining CudaGraphNodeParams vs CudaKernelNodeParams
hughperkins Mar 14, 2026
359c7d8
Rename increment_loop to graph_loop in test_graph_do_while_counter
hughperkins Mar 14, 2026
99889ee
Remove unnecessary qd.sync() calls from do-while tests
hughperkins Mar 14, 2026
b57582a
Add second call with different counter to test_graph_do_while_counter
hughperkins Mar 14, 2026
844a454
Add second call to all do-while tests to verify graph reuse
hughperkins Mar 14, 2026
4a62b03
Use different values on second call in do-while tests
hughperkins Mar 14, 2026
bf9e9ba
Make threshold a runtime ndarray parameter in boolean done test
hughperkins Mar 14, 2026
ec0daca
Pass threshold as scalar int instead of ndarray
hughperkins Mar 14, 2026
5a2a41a
Remove comment
hughperkins Mar 14, 2026
2bdb112
Remove redundant test_graph_do_while_replay
hughperkins Mar 14, 2026
955413d
Simplify changed-condition-ndarray test
hughperkins Mar 14, 2026
3dfcf8a
Replace [None] with [()] in do-while tests
hughperkins Mar 14, 2026
a17f845
Add fallback behavior for graph_do_while
hughperkins Mar 14, 2026
f175378
Error instead of fallback when cuda_graph gets host-resident arrays
hughperkins Mar 14, 2026
857b59f
Merge branch 1: error on host-resident arrays in cuda_graph
hughperkins Mar 14, 2026
4d94c95
Merge branch 2: error on host-resident arrays in cuda_graph
hughperkins Mar 14, 2026
684c448
Update cuda_graph docs for cross-platform graph_do_while fallback
hughperkins Mar 14, 2026
02f959a
Revert cuda_graph docs to original version
hughperkins Mar 14, 2026
6d8bc5b
Factor out graph_do_while fallback loops into separate functions
hughperkins Mar 14, 2026
e4d224f
Extract kernel launch loop into shared function called by both paths
hughperkins Mar 14, 2026
3d997a0
Rename to launch_offloaded_tasks and launch_offloaded_tasks_with_do_w…
hughperkins Mar 14, 2026
72b052e
Replace redundant flag_dev_ptr check with QD_ASSERT
hughperkins Mar 14, 2026
a9607bd
Fix clang-format
hughperkins Mar 14, 2026
18913d1
Error instead of rebuild when graph_do_while condition ndarray changes
hughperkins Mar 14, 2026
d0372f6
Add Restrictions section for graph_do_while condition ndarray
hughperkins Mar 14, 2026
0088af7
Run graph_do_while tests on all platforms
hughperkins Mar 14, 2026
afad5ab
Use proper do-while condition instead of break
hughperkins Mar 14, 2026
4b091aa
Fix KernelHandle -> Handle in gfx kernel_launcher.cpp
hughperkins Mar 14, 2026
a710fc2
Inline grad_ptr check in resolve_ctx_ndarray_ptrs
hughperkins Mar 14, 2026
3ef5ec4
Error instead of warn when libcudadevrt.a is not found
hughperkins Mar 14, 2026
b222dd1
Align autograd check and libcudadevrt error message with branch 3
hughperkins Mar 14, 2026
339b084
Reorder use_graph_do_while declaration to match branch 3
hughperkins Mar 14, 2026
dd480d9
Fix clang-format indentation in QD_ERROR_IF
hughperkins Mar 14, 2026
7206fcd
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
779b660
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
794f723
Fix macro parse error: avoid brace-init-list inside QD_ERROR_IF
hughperkins Mar 14, 2026
7964d82
Fix macro parse error: avoid brace-init-list inside QD_ERROR_IF
hughperkins Mar 14, 2026
11a2206
Fix gfx graph_do_while: use readback_data instead of map
hughperkins Mar 14, 2026
d2563b9
Skip graph_do_while tests on SM < 90
hughperkins Mar 14, 2026
0c33d05
Revert "Skip graph_do_while tests on SM < 90"
hughperkins Mar 14, 2026
a5ebc33
xfail graph_do_while tests on SM < 90
hughperkins Mar 14, 2026
c7ae995
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
0b3fdc3
xfail graph_do_while tests on CUDA SM < 90
hughperkins Mar 14, 2026
32b1341
Add num_offloaded_tasks query for compiled kernels
hughperkins Mar 14, 2026
2c1464b
Expose CUDA graph node count for test assertions
hughperkins Mar 14, 2026
470b073
Add multi-func cuda graph test with 9 offloaded tasks
hughperkins Mar 14, 2026
3079912
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
39a5dc5
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
70aac93
Change graph_do_while syntax from decorator param to in-kernel while …
hughperkins Mar 14, 2026
791bab4
Remove implicit cuda_graph=True note from docs
hughperkins Mar 14, 2026
ee9178c
Require cuda_graph=True for graph_do_while instead of implicitly enab…
hughperkins Mar 14, 2026
6623375
Update graph_do_while docstring to reflect SM 9.0+ only support
hughperkins Mar 14, 2026
b2d375e
Add tests for graph_do_while syntax errors
hughperkins Mar 14, 2026
6748e17
Apply black formatting to ast_transformer.py
hughperkins Mar 14, 2026
2996cb9
Fix offloaded tasks assertions to use >= for x64 ndarray compatibility
hughperkins Mar 14, 2026
e6d5adc
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
e5f39e5
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
9932ca2
Update graph_do_while docstring to reflect host-side fallback support
hughperkins Mar 14, 2026
05cb7ec
Error instead of warn when condition kernel is unavailable
hughperkins Mar 14, 2026
b39c3a9
Fix cuda graph tests: derive expected node count from offloaded tasks
hughperkins Mar 14, 2026
d750b1d
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
c278455
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 14, 2026
3497560
Add graph_do_while to public API test list
hughperkins Mar 15, 2026
2cc4fd0
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 15, 2026
31b73a6
Merge branch 'main' into hp/cuda-graph-mvp-1-graph-build
hughperkins Mar 15, 2026
34de2b4
Merge branch 'hp/cuda-graph-mvp-1-graph-build' into hp/cuda-graph-mvp…
hughperkins Mar 15, 2026
e9e5235
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 15, 2026
a32efba
Merge origin/main into hp/cuda-graph-mvp-2-graph-while
hughperkins Mar 16, 2026
000dda9
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 16, 2026
1bdd202
Fix end-of-file newline in env.sh
hughperkins Mar 16, 2026
aa0cad8
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 16, 2026
9de68d5
Allow swapping graph_do_while counter ndarray between calls
hughperkins Mar 16, 2026
6418b6f
Refactor CachedCudaGraph to use RAII for device memory
hughperkins Mar 16, 2026
f9fbf19
Revert "Refactor CachedCudaGraph to use RAII for device memory"
hughperkins Mar 16, 2026
7154ca2
Add parameterized CachedCudaGraph constructor
hughperkins Mar 16, 2026
df0f753
Remove env.sh from git and add to .gitignore
hughperkins Mar 16, 2026
5947e54
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 16, 2026
27355d7
Merge branch 'hp/cuda-graph-mvp-3-add-fallback' into hp/cuda-graph-mv…
hughperkins Mar 16, 2026
8d10a35
Merge remote-tracking branch 'origin/main' into hp/cuda-graph-mvp-2-g…
hughperkins Mar 16, 2026
438499a
Merge branch 'hp/cuda-graph-mvp-2-graph-while' into hp/cuda-graph-mvp…
hughperkins Mar 16, 2026
c3d9cb8
Merge branch 'hp/cuda-graph-mvp-3-add-fallback' into hp/cuda-graph-mv…
hughperkins Mar 16, 2026
2a0dc75
Improve test docstrings for counter ndarray swap tests
hughperkins Mar 16, 2026
0f82db8
Add total_builds counter to verify graph reuse in tests
hughperkins Mar 16, 2026
cd7ef61
Merge remote-tracking branch 'origin/main' into hp/cuda-graph-mvp-3-a…
hughperkins Mar 16, 2026
19099e6
Merge branch 'hp/cuda-graph-mvp-3-add-fallback' into hp/cuda-graph-mv…
hughperkins Mar 16, 2026
5c6fe38
Fix stale conflict marker in cuda_graph_manager.cpp
hughperkins Mar 16, 2026
4dc369e
Merge branch 'hp/cuda-graph-mvp-3-add-fallback' into hp/cuda-graph-mv…
hughperkins Mar 16, 2026
8663cc6
Merge remote-tracking branch 'origin/main' into hp/cuda-graph-mvp-4-h…
hughperkins Mar 16, 2026
c8dce56
Fix black formatting: remove extra blank line
hughperkins Mar 16, 2026
4 changes: 4 additions & 0 deletions quadrants/program/kernel_launcher.h
@@ -24,6 +24,10 @@ class KernelLauncher {
return 0;
}

virtual std::size_t get_cuda_graph_total_builds() const {
return 0;
}

virtual ~KernelLauncher() = default;
};

4 changes: 4 additions & 0 deletions quadrants/program/program.h
@@ -151,6 +151,10 @@ class QD_DLL_EXPORT Program {
.get_cuda_graph_num_nodes_on_last_call();
}

std::size_t get_cuda_graph_total_builds() {
return program_impl_->get_kernel_launcher().get_cuda_graph_total_builds();
}

DeviceCapabilityConfig get_device_caps() {
return program_impl_->get_device_caps();
}
4 changes: 3 additions & 1 deletion quadrants/python/export_lang.cpp
@@ -502,7 +502,9 @@ void export_lang(py::module &m) {
.def("get_num_offloaded_tasks_on_last_call",
&Program::get_num_offloaded_tasks_on_last_call)
.def("get_cuda_graph_num_nodes_on_last_call",
&Program::get_cuda_graph_num_nodes_on_last_call);
&Program::get_cuda_graph_num_nodes_on_last_call)
.def("get_cuda_graph_total_builds",
&Program::get_cuda_graph_total_builds);

py::class_<CompileResult>(m, "CompileResult")
.def_property_readonly(
101 changes: 64 additions & 37 deletions quadrants/runtime/cuda/cuda_graph_manager.cpp
@@ -11,12 +11,17 @@ namespace quadrants::lang {
namespace cuda {

// Condition kernel for graph_do_while. Reads the user's i32 loop-control flag
// from GPU memory and tells the CUDA graph's conditional while node whether to
// run another iteration — all without returning to the host.
// from GPU memory via an indirection slot, and tells the CUDA graph's
// conditional while node whether to run another iteration.
//
// The indirection allows swapping the counter ndarray between calls without
// rebuilding the graph: the slot's address is baked into the graph, but the
// pointer it contains can be updated via memcpy before each launch.
//
// Parameters:
// param_0: conditional node handle (passed to cudaGraphSetConditional)
// param_1: pointer to the user's qd.i32 flag ndarray on the GPU
// param_1: pointer to a device-side slot (void**) that holds the address
// of the user's qd.i32 flag ndarray
//
// Compiled from CUDA C with: nvcc -ptx -arch=sm_90 -rdc=true
// Requires SM 9.0+ (Hopper) for cudaGraphSetConditional / conditional nodes.
@@ -37,25 +42,28 @@

// Entry point: called by the CUDA graph's conditional while node each iteration.
// param_0 (u64): conditional node handle
// param_1 (u64): pointer to the user's qd.i32 flag in GPU global memory
// param_1 (u64): pointer to device-side slot holding address of user's i32 flag
.visible .entry _qd_graph_do_while_cond(
.param .u64 _qd_graph_do_while_cond_param_0,
.param .u64 _qd_graph_do_while_cond_param_1
)
{
.reg .pred %p<2>;
.reg .b32 %r<3>;
.reg .b64 %rd<4>;
.reg .b64 %rd<5>;

// Load the two kernel parameters into registers:
// %rd1 = conditional node handle
[Review comment, Collaborator Author: "let's get these comments back"]
// %rd2 = pointer to user's i32 flag
// %rd2 = pointer to device-side indirection slot
ld.param.u64 %rd1, [_qd_graph_do_while_cond_param_0];
ld.param.u64 %rd2, [_qd_graph_do_while_cond_param_1];

// Convert generic pointer to global address space, then read the flag value
// Dereference the indirection slot to get the actual flag pointer
cvta.to.global.u64 %rd3, %rd2;
ld.global.u32 %r1, [%rd3];
ld.global.u64 %rd4, [%rd3];

// Read the flag value from the actual counter ndarray
ld.global.u32 %r1, [%rd4];

// Convert flag to boolean: %r2 = (flag != 0) ? 1 : 0
setp.ne.s32 %p1, %r1, 0;
@@ -75,6 +83,30 @@
}
)PTX";
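The indirection described in the comments above can be illustrated with a small host-side sketch. This is a hypothetical Python analogue, not the real CUDA path: the dict `slot` stands in for the device-side `void**` slot, and `launch`, `cond_kernel`, and `body` are invented names. It shows why retargeting the counter is a single slot write rather than a graph rebuild, and why the loop has do-while semantics (the body always runs at least once, with the condition checked after it).

```python
# Hypothetical host-side analogue of the graph_do_while indirection slot.
# One mutable cell holds a reference to the currently active counter.
slot = {"ptr": None}


def cond_kernel():
    # Mirrors the PTX: load the slot, then load the counter through it.
    return slot["ptr"]["value"] != 0


def body():
    # Stands in for the kernel body, which decrements the counter on device.
    slot["ptr"]["value"] -= 1


def launch(counter):
    # The memcpy of the counter's address into the slot before each launch.
    slot["ptr"] = counter
    while True:  # do-while: body first, condition checked afterwards
        body()
        if not cond_kernel():
            break


c1 = {"value": 3}
c2 = {"value": 7}
launch(c1)
assert c1["value"] == 0
launch(c2)  # swapped counter: only the slot changes, no "rebuild"
assert c2["value"] == 0
```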

CachedCudaGraph::CachedCudaGraph(std::size_t arg_buf_size,
std::size_t result_buf_size,
bool needs_counter_ptr_slot,
LlvmRuntimeExecutor *executor)
: arg_buffer_size(arg_buf_size), result_buffer_size(result_buf_size) {
CUDADriver::get_instance().malloc(
(void **)&persistent_device_result_buffer,
std::max(result_buffer_size, sizeof(uint64)));

if (arg_buffer_size > 0) {
CUDADriver::get_instance().malloc((void **)&persistent_device_arg_buffer,
arg_buffer_size);
}

if (needs_counter_ptr_slot) {
CUDADriver::get_instance().malloc(&counter_ptr_slot, sizeof(void *));
}

persistent_ctx.runtime = executor->get_llvm_runtime();
persistent_ctx.arg_buffer = persistent_device_arg_buffer;
persistent_ctx.result_buffer = (uint64 *)persistent_device_result_buffer;
persistent_ctx.cpu_thread_id = 0;
}

CachedCudaGraph::~CachedCudaGraph() {
if (graph_exec) {
CUDADriver::get_instance().graph_exec_destroy(graph_exec);
@@ -85,6 +117,9 @@
if (persistent_device_result_buffer) {
CUDADriver::get_instance().mem_free(persistent_device_result_buffer);
}
if (counter_ptr_slot) {
CUDADriver::get_instance().mem_free(counter_ptr_slot);
}
}

CachedCudaGraph::CachedCudaGraph(CachedCudaGraph &&other) noexcept
@@ -94,11 +129,12 @@
persistent_ctx(other.persistent_ctx),
arg_buffer_size(other.arg_buffer_size),
result_buffer_size(other.result_buffer_size),
graph_do_while_flag_dev_ptr(other.graph_do_while_flag_dev_ptr),
counter_ptr_slot(other.counter_ptr_slot),
num_nodes(other.num_nodes) {
other.graph_exec = nullptr;
other.persistent_device_arg_buffer = nullptr;
other.persistent_device_result_buffer = nullptr;
other.counter_ptr_slot = nullptr;
}

CachedCudaGraph &CachedCudaGraph::operator=(CachedCudaGraph &&other) noexcept {
@@ -116,12 +152,13 @@
persistent_ctx = other.persistent_ctx;
arg_buffer_size = other.arg_buffer_size;
result_buffer_size = other.result_buffer_size;
graph_do_while_flag_dev_ptr = other.graph_do_while_flag_dev_ptr;
counter_ptr_slot = other.counter_ptr_slot;
num_nodes = other.num_nodes;

other.graph_exec = nullptr;
other.persistent_device_arg_buffer = nullptr;
other.persistent_device_result_buffer = nullptr;
other.counter_ptr_slot = nullptr;
}
return *this;
}
@@ -314,11 +351,13 @@ void *CudaGraphManager::add_conditional_while_node(
bool CudaGraphManager::launch_cached_graph(CachedCudaGraph &cached,
LaunchContextBuilder &ctx,
bool use_graph_do_while) {
QD_ERROR_IF(
use_graph_do_while &&
cached.graph_do_while_flag_dev_ptr != ctx.graph_do_while_flag_dev_ptr,
"graph_do_while condition ndarray changed between calls. "
"Reuse the same ndarray for the condition parameter across calls.");
// TODO: these two memcpy_host_to_device calls could be async
// (cuMemcpyHtoDAsync) on the launch stream for better CPU-GPU overlap.
if (use_graph_do_while && cached.counter_ptr_slot) {
void *flag_ptr = ctx.graph_do_while_flag_dev_ptr;
CUDADriver::get_instance().memcpy_host_to_device(cached.counter_ptr_slot,
&flag_ptr, sizeof(void *));
}

if (ctx.arg_buffer_size > 0) {
CUDADriver::get_instance().memcpy_host_to_device(
@@ -358,30 +397,15 @@

CUDAContext::get_instance().make_current();

CachedCudaGraph cached;

// --- Allocate persistent buffers ---
cached.result_buffer_size = std::max(ctx.result_buffer_size, sizeof(uint64));
CUDADriver::get_instance().malloc(
(void **)&cached.persistent_device_result_buffer,
cached.result_buffer_size);
CachedCudaGraph cached(ctx.arg_buffer_size, ctx.result_buffer_size,
use_graph_do_while, executor);

cached.arg_buffer_size = ctx.arg_buffer_size;
if (cached.arg_buffer_size > 0) {
CUDADriver::get_instance().malloc(
(void **)&cached.persistent_device_arg_buffer, cached.arg_buffer_size);
CUDADriver::get_instance().memcpy_host_to_device(
cached.persistent_device_arg_buffer, ctx.get_context().arg_buffer,
cached.arg_buffer_size);
}

// --- Build persistent RuntimeContext ---
cached.persistent_ctx.runtime = executor->get_llvm_runtime();
cached.persistent_ctx.arg_buffer = cached.persistent_device_arg_buffer;
cached.persistent_ctx.result_buffer =
(uint64 *)cached.persistent_device_result_buffer;
cached.persistent_ctx.cpu_thread_id = 0;

// --- Build CUDA graph ---
void *graph = nullptr;
CUDADriver::get_instance().graph_create(&graph, 0);
@@ -424,8 +448,14 @@
if (use_graph_do_while) {
QD_ASSERT(ctx.graph_do_while_flag_dev_ptr);

// Write the initial counter address into the persistent indirection slot
// (allocated by the constructor). The condition kernel reads through this
// slot, so swapping the counter ndarray later only requires updating it.
void *flag_ptr = ctx.graph_do_while_flag_dev_ptr;
void *cond_args[2] = {&cond_handle, &flag_ptr};
CUDADriver::get_instance().memcpy_host_to_device(cached.counter_ptr_slot,
&flag_ptr, sizeof(void *));

void *cond_args[2] = {&cond_handle, &cached.counter_ptr_slot};

add_kernel_node(kernel_target_graph, prev_node, cond_kernel_func_, 1, 1, 0,
cond_args);
@@ -446,11 +476,8 @@
cached.num_nodes, launch_id,
use_graph_do_while ? " (with graph_do_while)" : "");

if (use_graph_do_while) {
cached.graph_do_while_flag_dev_ptr = ctx.graph_do_while_flag_dev_ptr;
}

num_nodes_on_last_call_ = cached.num_nodes;
++total_builds_;
cache_.emplace(launch_id, std::move(cached));
used_on_last_call_ = true;
return true;
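The build-once/replay flow in `try_launch` and `launch_cached_graph` above can be condensed into a hedged sketch. Class and field names here are invented stand-ins for the C++ in the diff, which this code does not call; it only models the observable behavior the tests assert on: a cache miss builds and counts one graph, and later calls with a swapped counter refresh the indirection slot without rebuilding.

```python
# Hypothetical sketch of the caching behavior implemented above: first call
# per launch_id builds and caches a graph; later calls reuse it and only
# rewrite the indirection slot, so total_builds stays at 1.
class CudaGraphManagerSketch:
    def __init__(self):
        self.cache = {}  # launch_id -> cached "graph" (CachedCudaGraph stand-in)
        self.total_builds = 0

    def launch(self, launch_id, counter_ptr):
        if launch_id not in self.cache:
            # Build path: allocate persistent buffers, capture the graph.
            self.cache[launch_id] = {"slot": None}
            self.total_builds += 1  # the ++total_builds_ in the diff
        cached = self.cache[launch_id]
        # Replay path: refresh the slot (the memcpy_host_to_device of flag_ptr).
        cached["slot"] = counter_ptr
        return cached


mgr = CudaGraphManagerSketch()
mgr.launch(0, 0x1000)
mgr.launch(0, 0x2000)  # swapped counter: cache hit, no rebuild
assert mgr.total_builds == 1
assert len(mgr.cache) == 1
```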
15 changes: 13 additions & 2 deletions quadrants/runtime/cuda/cuda_graph_manager.h
@@ -60,10 +60,17 @@ struct CachedCudaGraph {
RuntimeContext persistent_ctx{};
std::size_t arg_buffer_size{0};
std::size_t result_buffer_size{0};
void *graph_do_while_flag_dev_ptr{nullptr};
// Device-side pointer slot for graph_do_while indirection. Holds the address
// of the user's counter ndarray. The condition kernel reads through this
// slot, allowing the counter ndarray to change between calls without
// rebuilding.
void *counter_ptr_slot{nullptr};
std::size_t num_nodes{0};

CachedCudaGraph() = default;
CachedCudaGraph(std::size_t arg_buffer_size,
std::size_t result_buffer_size,
bool needs_counter_ptr_slot,
LlvmRuntimeExecutor *executor);
~CachedCudaGraph();
CachedCudaGraph(const CachedCudaGraph &) = delete;
CachedCudaGraph &operator=(const CachedCudaGraph &) = delete;
@@ -100,6 +107,9 @@ class CudaGraphManager {
std::size_t num_nodes_on_last_call() const {
return num_nodes_on_last_call_;
}
std::size_t total_builds() const {
return total_builds_;
}

private:
bool launch_cached_graph(CachedCudaGraph &cached,
@@ -125,6 +135,7 @@
std::unordered_map<int, CachedCudaGraph> cache_;
bool used_on_last_call_{false};
std::size_t num_nodes_on_last_call_{0};
std::size_t total_builds_{0};

// JIT-compiled condition kernel for graph_do_while conditional nodes
void *cond_kernel_module_{nullptr}; // CUmodule
3 changes: 3 additions & 0 deletions quadrants/runtime/cuda/kernel_launcher.h
@@ -34,6 +34,9 @@ class KernelLauncher : public LLVM::KernelLauncher {
std::size_t get_cuda_graph_num_nodes_on_last_call() const override {
return graph_manager_.num_nodes_on_last_call();
}
std::size_t get_cuda_graph_total_builds() const override {
return graph_manager_.total_builds();
}

private:
void launch_offloaded_tasks(
86 changes: 79 additions & 7 deletions tests/python/test_cuda_graph_do_while.py
@@ -15,6 +15,10 @@ def _cuda_graph_used():
return impl.get_runtime().prog.get_cuda_graph_cache_used_on_last_call()


def _cuda_graph_total_builds():
return impl.get_runtime().prog.get_cuda_graph_total_builds()


def _on_cuda():
return impl.current_cfg().arch == qd.cuda

@@ -160,10 +164,18 @@ def multi_loop(
np.testing.assert_allclose(y.to_numpy(), np.full(N, 10.0))


@test_utils.test(arch=[qd.cuda])
def test_graph_do_while_changed_condition_ndarray_raises():
"""Passing a different ndarray for the condition parameter should raise."""
@test_utils.test()
def test_graph_do_while_swap_counter_ndarray():
"""Swapping the counter ndarray between calls should work correctly.

Creates one counter c1, runs the kernel with counter=3, verifies x is all
3s. Then creates a new ndarray c2 (different device pointer), runs the same
kernel with counter=7, verifies x is all 7s. Confirms cache size stays 1 --
the graph wasn't rebuilt, it just updated the indirection slot with c2's
pointer.
"""
_xfail_if_cuda_without_hopper()
N = 32

@qd.kernel(cuda_graph=True)
def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)):
@@ -173,15 +185,75 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)):
for i in range(1):
c[()] = c[()] - 1

x = qd.ndarray(qd.i32, shape=(4,))
x = qd.ndarray(qd.i32, shape=(N,))
c1 = qd.ndarray(qd.i32, shape=())
c1.from_numpy(np.array(1, dtype=np.int32))

x.from_numpy(np.zeros(N, dtype=np.int32))
c1.from_numpy(np.array(3, dtype=np.int32))
k(x, c1)
if _on_cuda():
assert _cuda_graph_used()
assert _cuda_graph_cache_size() == 1
assert c1.to_numpy() == 0
np.testing.assert_array_equal(x.to_numpy(), np.full(N, 3, dtype=np.int32))

c2 = qd.ndarray(qd.i32, shape=())
x.from_numpy(np.zeros(N, dtype=np.int32))
c2.from_numpy(np.array(7, dtype=np.int32))
k(x, c2)
if _on_cuda():
assert _cuda_graph_used()
assert _cuda_graph_cache_size() == 1
assert _cuda_graph_total_builds() == 1
assert c2.to_numpy() == 0
np.testing.assert_array_equal(x.to_numpy(), np.full(N, 7, dtype=np.int32))


@test_utils.test()
def test_graph_do_while_alternate_counter_ndarrays():
"""Alternating between two counter ndarrays should work correctly.

Creates c1 and c2 upfront, then alternates between them for 3 rounds (6
kernel calls). Each call uses a different iteration count (count and
count+10). Confirms the slot update works back and forth, not just as a
one-time swap. Cache size is checked once at the end -- still 1.
"""
_xfail_if_cuda_without_hopper()
N = 16

@qd.kernel(cuda_graph=True)
def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)):
while qd.graph_do_while(c):
for i in range(x.shape[0]):
x[i] = x[i] + 1
for i in range(1):
c[()] = c[()] - 1

x = qd.ndarray(qd.i32, shape=(N,))
c1 = qd.ndarray(qd.i32, shape=())
c2 = qd.ndarray(qd.i32, shape=())
c2.from_numpy(np.array(1, dtype=np.int32))
with pytest.raises(RuntimeError, match="condition ndarray changed"):

for iteration in range(3):
[Review thread, Collaborator Author: "I think we should check that we aren't simply rebuilding the graph each call." Reply: "Added _cuda_graph_total_builds assert."]

count = iteration + 2
x.from_numpy(np.zeros(N, dtype=np.int32))
c1.from_numpy(np.array(count, dtype=np.int32))
k(x, c1)
if _on_cuda():
assert _cuda_graph_used()
assert c1.to_numpy() == 0
np.testing.assert_array_equal(x.to_numpy(), np.full(N, count, dtype=np.int32))

x.from_numpy(np.zeros(N, dtype=np.int32))
c2.from_numpy(np.array(count + 10, dtype=np.int32))
k(x, c2)
if _on_cuda():
assert _cuda_graph_used()
assert c2.to_numpy() == 0
np.testing.assert_array_equal(x.to_numpy(), np.full(N, count + 10, dtype=np.int32))

if _on_cuda():
assert _cuda_graph_cache_size() == 1
assert _cuda_graph_total_builds() == 1


@test_utils.test()