diff --git a/.gitignore b/.gitignore index 6d05d1ed4..0aaf63e31 100644 --- a/.gitignore +++ b/.gitignore @@ -96,3 +96,4 @@ imgui.ini stubs/ CHANGELOG.md python/quadrants/_version.py +env.sh diff --git a/docs/source/user_guide/cuda_graph.md b/docs/source/user_guide/cuda_graph.md index 4855d7090..cd93aaa1a 100644 --- a/docs/source/user_guide/cuda_graph.md +++ b/docs/source/user_guide/cuda_graph.md @@ -2,7 +2,7 @@ CUDA graphs reduce kernel launch overhead by capturing a sequence of GPU operations into a graph, then replaying it in a single launch. On non-CUDA platforms, the cuda graph annotation is simply ignored, and code runs normally. -## Usage +## Basic usage Add `cuda_graph=True` to a `@qd.kernel` decorator: @@ -52,3 +52,80 @@ my_kernel(x2, y2) # replays graph with new array pointers ### Fields as arguments When different fields are passed as template arguments, each unique combination of fields produces a separately compiled kernel with its own graph cache entry. There is no interference between them. + + +## GPU-side iteration with `graph_do_while` + +For iterative algorithms (physics solvers, convergence loops), you often want to repeat the kernel body until a condition is met, without returning to the host each iteration. Use `while qd.graph_do_while(flag):` inside a `cuda_graph=True` kernel: + +```python +@qd.kernel(cuda_graph=True) +def solve(x: qd.types.ndarray(qd.f32, ndim=1), + counter: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(counter): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + for i in range(1): + counter[()] = counter[()] - 1 + +x = qd.ndarray(qd.f32, shape=(N,)) +counter = qd.ndarray(qd.i32, shape=()) +counter.from_numpy(np.array(10, dtype=np.int32)) +solve(x, counter) +# x is now incremented 10 times; counter is 0 +``` + +The argument to `qd.graph_do_while()` must be the name of a scalar `qd.i32` ndarray parameter. The loop body repeats while this value is non-zero. 
+ +- On SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement. +- Older CUDA GPUs and non-CUDA backends are not currently supported. + +### Patterns + +**Counter-based**: set the counter to N, decrement each iteration. The body runs exactly N times. + +```python +@qd.kernel(cuda_graph=True) +def iterate(x: qd.types.ndarray(qd.f32, ndim=1), + counter: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(counter): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + for i in range(1): + counter[()] = counter[()] - 1 +``` + +**Boolean flag**: set a `keep_going` flag to 1, have the kernel set it to 0 when a convergence criterion is met. + +```python +@qd.kernel(cuda_graph=True) +def converge(x: qd.types.ndarray(qd.f32, ndim=1), + keep_going: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(keep_going): + for i in range(x.shape[0]): + # ... do work ... + pass + for i in range(1): + if some_condition(x): + keep_going[()] = 0 +``` + +### Do-while semantics + +`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time. Passing 0 to a kernel that decrements the counter will cause an infinite loop: the body still runs once, the counter becomes -1, and the condition (flag is non-zero) stays true. + +### ndarray vs field + +The parameter used by `graph_do_while` MUST be an ndarray. + +However, other parameters can be any supported Quadrants kernel parameter type. + +### Restrictions + +- The same physical ndarray must be used for the counter parameter on every + call. Passing a different ndarray raises an error, because the counter's + device pointer is baked into the CUDA graph at creation time. + +### Caveats + +`graph_do_while` only runs on CUDA GPUs with SM 9.0+ (Hopper). There is currently no fallback on older CUDA GPUs or non-CUDA platforms. 
diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 1b13ead0f..79c639761 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1193,11 +1193,41 @@ def build_For(ctx: ASTTransformerFuncContext, node: ast.For) -> None: # Struct for return ASTTransformer.build_struct_for(ctx, node, is_grouped=False) + @staticmethod + def _is_graph_do_while_call(node: ast.expr) -> str | None: + """If *node* is ``qd.graph_do_while(var)`` return the arg name, else None.""" + if not isinstance(node, ast.Call): + return None + func = node.func + if isinstance(func, ast.Attribute) and func.attr == "graph_do_while": + if len(node.args) == 1 and isinstance(node.args[0], ast.Name): + return node.args[0].id + if isinstance(func, ast.Name) and func.id == "graph_do_while": + if len(node.args) == 1 and isinstance(node.args[0], ast.Name): + return node.args[0].id + return None + @staticmethod def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: if node.orelse: raise QuadrantsSyntaxError("'else' clause for 'while' not supported in Quadrants kernels") + graph_do_while_arg = ASTTransformer._is_graph_do_while_call(node.test) + if graph_do_while_arg is not None: + kernel = ctx.global_context.current_kernel + arg_names = [m.name for m in kernel.arg_metas] + if graph_do_while_arg not in arg_names: + raise QuadrantsSyntaxError( + f"qd.graph_do_while({graph_do_while_arg!r}) does not match any " + f"parameter of kernel {kernel.func.__name__!r}. 
" + f"Available parameters: {arg_names}" + ) + if not kernel.use_cuda_graph: + raise QuadrantsSyntaxError("qd.graph_do_while() requires @qd.kernel(cuda_graph=True)") + kernel.graph_do_while_arg = graph_do_while_arg + build_stmts(ctx, node.body) + return None + with ctx.loop_scope_guard(): stmt_dbg_info = _qd_core.DebugInfo(ctx.get_pos_info(node)) ctx.ast_builder.begin_frontend_while(expr.Expr(1, dtype=primitive_types.i32).ptr, stmt_dbg_info) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 2c95dc474..d2102c354 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -292,6 +292,7 @@ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _is_classkernel self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {} self.has_print = False self.use_cuda_graph: bool = False + self.graph_do_while_arg: str | None = None self.quadrants_callable: QuadrantsCallable | None = None self.visited_functions: set[FunctionSourceInfo] = set() self.kernel_function_info: FunctionSourceInfo | None = None @@ -444,6 +445,8 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled template_num += 1 i_out += 1 continue + if self.graph_do_while_arg is not None and self.arg_metas[i_in].name == self.graph_do_while_arg: + self._graph_do_while_cpp_arg_id = i_out - template_num num_args_, is_launch_ctx_cacheable_ = self._recursive_set_args( self.used_py_dataclass_parameters_by_key_enforcing[key], self.arg_metas[i_in].name, @@ -505,6 +508,8 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled self.src_ll_cache_observations.cache_stored = True self._last_compiled_kernel_data = compiled_kernel_data launch_ctx.use_cuda_graph = self.use_cuda_graph + if self.graph_do_while_arg is not None and hasattr(self, "_graph_do_while_cpp_arg_id"): + launch_ctx.graph_do_while_arg_id = self._graph_do_while_cpp_arg_id prog.launch_kernel(compiled_kernel_data, launch_ctx) except 
Exception as e: e = handle_exception_from_cpp(e) diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 8002a5a34..b61eb24de 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -124,7 +124,10 @@ def _inside_class(level_of_class_stackframe: int) -> bool: def _kernel_impl( - _func: Callable, level_of_class_stackframe: int, verbose: bool = False, cuda_graph: bool = False + _func: Callable, + level_of_class_stackframe: int, + verbose: bool = False, + cuda_graph: bool = False, ) -> QuadrantsCallable: # Can decorators determine if a function is being defined inside a class? # https://stackoverflow.com/a/8793684/12003165 @@ -206,6 +209,12 @@ def kernel( Kernel's gradient kernel would be generated automatically by the AutoDiff system. + Args: + cuda_graph: If True, kernels with 2+ top-level for loops are captured + into a CUDA graph on first launch and replayed on subsequent + launches, reducing per-kernel launch overhead. Non-CUDA backends + are not supported currently. + Example:: >>> x = qd.field(qd.i32, shape=(4, 8)) diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index fa08aad77..a40c2dfe0 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -701,6 +701,24 @@ def copy(): _bit_vectorize() +def graph_do_while(condition) -> bool: + """Marks a while loop as a CUDA graph do-while conditional node. + + Used as ``while qd.graph_do_while(flag):`` inside a + ``@qd.kernel(cuda_graph=True)`` kernel. The loop body repeats while + ``flag`` (a scalar ``qd.i32`` ndarray) is non-zero. + + On SM 9.0+ (Hopper) GPUs this compiles to a native CUDA graph + conditional while node. Older CUDA GPUs and non-CUDA backends + are not currently supported. + + This function should not be called directly at runtime; it is + recognised and transformed during AST compilation. + Requires ``@qd.kernel(cuda_graph=True)``. 
+ """ + return bool(condition) + + def global_thread_idx(): """Returns the global thread id of this running thread, only available for cpu and cuda backends. @@ -837,6 +855,7 @@ def dump_compile_config() -> None: "python", "vulkan", "extension", + "graph_do_while", "loop_config", "global_thread_idx", "assume_in_range", diff --git a/quadrants/program/launch_context_builder.h b/quadrants/program/launch_context_builder.h index 91a2590b0..d0abd782f 100644 --- a/quadrants/program/launch_context_builder.h +++ b/quadrants/program/launch_context_builder.h @@ -151,6 +151,8 @@ class LaunchContextBuilder { const StructType *args_type{nullptr}; size_t result_buffer_size{0}; bool use_cuda_graph{false}; + int graph_do_while_arg_id{-1}; + void *graph_do_while_flag_dev_ptr{nullptr}; // Note that I've tried to group `array_runtime_size` and // `is_device_allocations` into a small struct. However, it caused some test diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index 88fd7a09e..abbba729a 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -667,7 +667,9 @@ void export_lang(py::module &m) { .def("get_struct_ret_int", &LaunchContextBuilder::get_struct_ret_int) .def("get_struct_ret_uint", &LaunchContextBuilder::get_struct_ret_uint) .def("get_struct_ret_float", &LaunchContextBuilder::get_struct_ret_float) - .def_readwrite("use_cuda_graph", &LaunchContextBuilder::use_cuda_graph); + .def_readwrite("use_cuda_graph", &LaunchContextBuilder::use_cuda_graph) + .def_readwrite("graph_do_while_arg_id", + &LaunchContextBuilder::graph_do_while_arg_id); py::class_(m, "Function") .def("insert_scalar_param", &Function::insert_scalar_param) diff --git a/quadrants/rhi/cuda/cuda_driver_functions.inc.h b/quadrants/rhi/cuda/cuda_driver_functions.inc.h index 2da4799b9..9fe0e543d 100644 --- a/quadrants/rhi/cuda/cuda_driver_functions.inc.h +++ b/quadrants/rhi/cuda/cuda_driver_functions.inc.h @@ -73,8 +73,18 @@ 
PER_CUDA_FUNCTION(import_external_semaphore, cuImportExternalSemaphore,CUexterna // Graph management PER_CUDA_FUNCTION(graph_create, cuGraphCreate, void **, uint32); PER_CUDA_FUNCTION(graph_add_kernel_node, cuGraphAddKernelNode, void **, void *, const void *, std::size_t, const void *); +PER_CUDA_FUNCTION(graph_add_node, cuGraphAddNode, void **, void *, const void *, std::size_t, void *); PER_CUDA_FUNCTION(graph_instantiate, cuGraphInstantiate, void **, void *, void *, char *, std::size_t); PER_CUDA_FUNCTION(graph_launch, cuGraphLaunch, void *, void *); PER_CUDA_FUNCTION(graph_destroy, cuGraphDestroy, void *); PER_CUDA_FUNCTION(graph_exec_destroy, cuGraphExecDestroy, void *); +PER_CUDA_FUNCTION(graph_conditional_handle_create, cuGraphConditionalHandleCreate, void *, void *, void *, uint32, uint32); + +// JIT linker (for loading condition kernel with cudadevrt) +PER_CUDA_FUNCTION(link_create, cuLinkCreate_v2, uint32, void *, void *, void **); +PER_CUDA_FUNCTION(link_add_data, cuLinkAddData_v2, void *, uint32, void *, std::size_t, const char *, uint32, void *, void *); +PER_CUDA_FUNCTION(link_add_file, cuLinkAddFile_v2, void *, uint32, const char *, uint32, void *, void *); +PER_CUDA_FUNCTION(link_complete, cuLinkComplete, void *, void **, std::size_t *); +PER_CUDA_FUNCTION(link_destroy, cuLinkDestroy, void *); +PER_CUDA_FUNCTION(module_load_data, cuModuleLoadData, void **, const void *); // clang-format on diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp b/quadrants/runtime/amdgpu/kernel_launcher.cpp index 6ef0b0e0e..62a64a65f 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -110,6 +110,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, AMDGPUContext::get_instance().push_back_kernel_arg_pointer(context_pointer); + QD_ERROR_IF(ctx.graph_do_while_arg_id >= 0, + "graph_do_while is only supported on the CUDA backend"); + for (auto &task : offloaded_tasks) { QD_TRACE("Launching kernel 
{}<<<{}, {}>>>", task.name, task.grid_dim, task.block_dim); diff --git a/quadrants/runtime/cpu/kernel_launcher.cpp b/quadrants/runtime/cpu/kernel_launcher.cpp index d7dd8df25..1f34dced1 100644 --- a/quadrants/runtime/cpu/kernel_launcher.cpp +++ b/quadrants/runtime/cpu/kernel_launcher.cpp @@ -41,6 +41,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, } } } + QD_ERROR_IF(ctx.graph_do_while_arg_id >= 0, + "graph_do_while is only supported on the CUDA backend"); + for (auto task : launcher_ctx.task_funcs) { task(&ctx.get_context()); } diff --git a/quadrants/runtime/cuda/cuda_graph_manager.cpp b/quadrants/runtime/cuda/cuda_graph_manager.cpp index 3de86de5c..8466b52e3 100644 --- a/quadrants/runtime/cuda/cuda_graph_manager.cpp +++ b/quadrants/runtime/cuda/cuda_graph_manager.cpp @@ -2,9 +2,79 @@ #include "quadrants/runtime/cuda/cuda_utils.h" #include "quadrants/rhi/cuda/cuda_context.h" +#include +#include +#include +#include + namespace quadrants::lang { namespace cuda { +// Condition kernel for graph_do_while. Reads the user's i32 loop-control flag +// from GPU memory and tells the CUDA graph's conditional while node whether to +// run another iteration — all without returning to the host. +// +// Parameters: +// param_0: conditional node handle (passed to cudaGraphSetConditional) +// param_1: pointer to the user's qd.i32 flag ndarray on the GPU +// +// Compiled from CUDA C with: nvcc -ptx -arch=sm_90 -rdc=true +// Requires SM 9.0+ (Hopper) for cudaGraphSetConditional / conditional nodes. +// Requires JIT linking with libcudadevrt.a at runtime. +static const char *kConditionKernelPTX = R"PTX( +.version 8.8 +.target sm_90 +.address_size 64 + +// Declare the device-side cudaGraphSetConditional function (from libcudadevrt). +// Takes a conditional node handle (u64) and a boolean (u32: 1=continue, 0=stop). 
+.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) +; + +// Entry point: called by the CUDA graph's conditional while node each iteration. +// param_0 (u64): conditional node handle +// param_1 (u64): pointer to the user's qd.i32 flag in GPU global memory +.visible .entry _qd_graph_do_while_cond( + .param .u64 _qd_graph_do_while_cond_param_0, + .param .u64 _qd_graph_do_while_cond_param_1 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<3>; + .reg .b64 %rd<4>; + + // Load the two kernel parameters into registers: + // %rd1 = conditional node handle + // %rd2 = pointer to user's i32 flag + ld.param.u64 %rd1, [_qd_graph_do_while_cond_param_0]; + ld.param.u64 %rd2, [_qd_graph_do_while_cond_param_1]; + + // Convert generic pointer to global address space, then read the flag value + cvta.to.global.u64 %rd3, %rd2; + ld.global.u32 %r1, [%rd3]; + + // Convert flag to boolean: %r2 = (flag != 0) ? 1 : 0 + setp.ne.s32 %p1, %r1, 0; + selp.u32 %r2, 1, 0, %p1; + + // Tell the conditional while node whether to loop again or stop. 
+ // cudaGraphSetConditional(handle, should_continue) + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], %r2; + call.uni cudaGraphSetConditional, (param0, param1); + } // callseq 0 + ret; +} +)PTX"; + CachedCudaGraph::~CachedCudaGraph() { if (graph_exec) { CUDADriver::get_instance().graph_exec_destroy(graph_exec); @@ -24,6 +94,7 @@ CachedCudaGraph::CachedCudaGraph(CachedCudaGraph &&other) noexcept persistent_ctx(other.persistent_ctx), arg_buffer_size(other.arg_buffer_size), result_buffer_size(other.result_buffer_size), + graph_do_while_flag_dev_ptr(other.graph_do_while_flag_dev_ptr), num_nodes(other.num_nodes) { other.graph_exec = nullptr; other.persistent_device_arg_buffer = nullptr; @@ -45,6 +116,7 @@ CachedCudaGraph &CachedCudaGraph::operator=(CachedCudaGraph &&other) noexcept { persistent_ctx = other.persistent_ctx; arg_buffer_size = other.arg_buffer_size; result_buffer_size = other.result_buffer_size; + graph_do_while_flag_dev_ptr = other.graph_do_while_flag_dev_ptr; num_nodes = other.num_nodes; other.graph_exec = nullptr; @@ -106,11 +178,76 @@ void CudaGraphManager::resolve_ctx_ndarray_ptrs( if (resolved_data) { ctx.set_ndarray_ptrs(arg_id, (uint64)resolved_data, (uint64) nullptr); + if (arg_id == ctx.graph_do_while_arg_id) { + ctx.graph_do_while_flag_dev_ptr = resolved_data; + } } } } } +// Lazily JIT-compiles and loads the graph_do_while condition kernel. +// Links the PTX (kConditionKernelPTX) with libcudadevrt.a to produce a cubin, +// then loads the _qd_graph_do_while_cond function for use in conditional +// while nodes. Only called once; subsequent calls are no-ops. 
+void CudaGraphManager::ensure_condition_kernel_loaded() { + if (cond_kernel_func_) + return; + + int cc = CUDAContext::get_instance().get_compute_capability(); + QD_ERROR_IF(cc < 90, + "graph_do_while requires SM 9.0+ (Hopper), but this device is " + "SM {}.", + cc); + + auto &driver = CUDADriver::get_instance(); + + std::string cudadevrt_path; + std::vector candidates; + for (const char *env_name : {"CUDA_HOME", "CUDA_PATH"}) { + if (const char *env_val = std::getenv(env_name)) { + candidates.push_back(std::string(env_val) + "/lib64/libcudadevrt.a"); + candidates.push_back(std::string(env_val) + "/lib/libcudadevrt.a"); + } + } + candidates.push_back("/usr/local/cuda/lib64/libcudadevrt.a"); + candidates.push_back("/usr/lib/x86_64-linux-gnu/libcudadevrt.a"); + for (const auto &candidate : candidates) { + if (std::filesystem::exists(candidate)) { + cudadevrt_path = candidate; + break; + } + } + QD_ERROR_IF(cudadevrt_path.empty(), + "Cannot find libcudadevrt.a — required for graph_do_while. " + "Install the CUDA toolkit and set CUDA_HOME."); + + // CUlinkState handle for the JIT linker session that combines our PTX + // with libcudadevrt.a to resolve the cudaGraphSetConditional extern. 
+ void *link_state = nullptr; + driver.link_create(0, nullptr, nullptr, &link_state); + + std::size_t ptx_len = std::strlen(kConditionKernelPTX) + 1; + driver.link_add_data(link_state, /*CU_JIT_INPUT_PTX=*/1, + const_cast(kConditionKernelPTX), ptx_len, + /*name=*/"qd_cond", 0, nullptr, nullptr); + + driver.link_add_file(link_state, /*CU_JIT_INPUT_LIBRARY=*/4, + cudadevrt_path.c_str(), 0, nullptr, nullptr); + + void *cubin = nullptr; + std::size_t cubin_size = 0; + driver.link_complete(link_state, &cubin, &cubin_size); + + driver.module_load_data(&cond_kernel_module_, cubin); + driver.module_get_function(&cond_kernel_func_, cond_kernel_module_, + "_qd_graph_do_while_cond"); + driver.link_destroy(link_state); + + QD_TRACE("Loaded graph_do_while condition kernel ({} bytes cubin)", + cubin_size); +} + void *CudaGraphManager::add_kernel_node(void *graph, void *prev_node, void *func, @@ -137,8 +274,49 @@ void *CudaGraphManager::add_kernel_node(void *graph, return node; } +void *CudaGraphManager::add_conditional_while_node( + void *graph, + unsigned long long *cond_handle_out) { + ensure_condition_kernel_loaded(); + QD_ASSERT(cond_kernel_func_); + + void *cu_ctx = CUDAContext::get_instance().get_context(); + + CUDADriver::get_instance().graph_conditional_handle_create( + cond_handle_out, graph, cu_ctx, + /*defaultLaunchValue=*/1, + /*flags=CU_GRAPH_COND_ASSIGN_DEFAULT=*/1); + + CudaGraphNodeParams cond_node_params{}; + cond_node_params.type = 13; // CU_GRAPH_NODE_TYPE_CONDITIONAL + cond_node_params.handle = *cond_handle_out; + cond_node_params.condType = 1; // CU_GRAPH_COND_TYPE_WHILE + cond_node_params.size = 1; + cond_node_params.phGraph_out = nullptr; // CUDA will populate this + cond_node_params.ctx = cu_ctx; + + void *cond_node = nullptr; + CUDADriver::get_instance().graph_add_node(&cond_node, graph, nullptr, 0, + &cond_node_params); + + // CUDA replaces phGraph_out with a pointer to its owned array + void **body_graphs = (void **)cond_node_params.phGraph_out; + 
QD_ASSERT(body_graphs && body_graphs[0]); + + QD_TRACE("CUDA graph_do_while: conditional node created, body graph={}", + body_graphs[0]); + return body_graphs[0]; +} + bool CudaGraphManager::launch_cached_graph(CachedCudaGraph &cached, - LaunchContextBuilder &ctx) { + LaunchContextBuilder &ctx, + bool use_graph_do_while) { + QD_ERROR_IF( + use_graph_do_while && + cached.graph_do_while_flag_dev_ptr != ctx.graph_do_while_flag_dev_ptr, + "graph_do_while condition ndarray changed between calls. " + "Reuse the same ndarray for the condition parameter across calls."); + if (ctx.arg_buffer_size > 0) { CUDADriver::get_instance().memcpy_host_to_device( cached.persistent_device_arg_buffer, ctx.get_context().arg_buffer, @@ -162,6 +340,8 @@ bool CudaGraphManager::try_launch( return false; } + const bool use_graph_do_while = ctx.graph_do_while_arg_id >= 0; + QD_ERROR_IF(ctx.result_buffer_size > 0, "cuda_graph=True is not supported for kernels with struct return " "values; remove cuda_graph=True or avoid returning values"); @@ -170,7 +350,7 @@ bool CudaGraphManager::try_launch( auto it = cache_.find(launch_id); if (it != cache_.end()) { - return launch_cached_graph(it->second, ctx); + return launch_cached_graph(it->second, ctx, use_graph_do_while); } CUDAContext::get_instance().make_current(); @@ -203,15 +383,48 @@ bool CudaGraphManager::try_launch( void *graph = nullptr; CUDADriver::get_instance().graph_create(&graph, 0); + // Target graph for kernel nodes. Without graph_do_while, work kernels go + // directly into the top-level graph. With graph_do_while, they go into + // a body graph inside a conditional while node: + // + // Top-level graph + // └── Conditional while node (repeats while flag != 0) + // └── Body graph + // ├── Work kernel 1 + // ├── Work kernel 2 + // └── Condition kernel (reads flag, calls + // cudaGraphSetConditional) + // + // The condition kernel must be the last node in the body graph. 
It reads the + // flag after the work kernels have updated it, so the loop-continue decision + // reflects this iteration's result. Putting it first would cause an extra + // iteration: the condition would see the flag from before the work ran. + void *kernel_target_graph = graph; + unsigned long long cond_handle = 0; + + if (use_graph_do_while) { + kernel_target_graph = add_conditional_while_node(graph, &cond_handle); + } + void *prev_node = nullptr; for (const auto &task : offloaded_tasks) { void *ctx_ptr = &cached.persistent_ctx; prev_node = add_kernel_node( - graph, prev_node, cuda_module->lookup_function(task.name), + kernel_target_graph, prev_node, cuda_module->lookup_function(task.name), (unsigned int)task.grid_dim, (unsigned int)task.block_dim, (unsigned int)task.dynamic_shared_array_bytes, &ctx_ptr); } + if (use_graph_do_while) { + QD_ASSERT(ctx.graph_do_while_flag_dev_ptr); + + void *flag_ptr = ctx.graph_do_while_flag_dev_ptr; + void *cond_args[2] = {&cond_handle, &flag_ptr}; + + add_kernel_node(kernel_target_graph, prev_node, cond_kernel_func_, 1, 1, 0, + cond_args); + } + // --- Instantiate and launch --- CUDADriver::get_instance().graph_instantiate(&cached.graph_exec, graph, nullptr, nullptr, 0); @@ -223,8 +436,13 @@ bool CudaGraphManager::try_launch( cached.num_nodes = offloaded_tasks.size(); - QD_TRACE("CUDA graph created with {} kernel nodes for launch_id={}", - cached.num_nodes, launch_id); + QD_TRACE("CUDA graph created with {} kernel nodes for launch_id={}{}", + cached.num_nodes, launch_id, + use_graph_do_while ? 
" (with graph_do_while)" : ""); + + if (use_graph_do_while) { + cached.graph_do_while_flag_dev_ptr = ctx.graph_do_while_flag_dev_ptr; + } num_nodes_on_last_call_ = cached.num_nodes; cache_.emplace(launch_id, std::move(cached)); diff --git a/quadrants/runtime/cuda/cuda_graph_manager.h b/quadrants/runtime/cuda/cuda_graph_manager.h index 1df10bb7b..85781a5d9 100644 --- a/quadrants/runtime/cuda/cuda_graph_manager.h +++ b/quadrants/runtime/cuda/cuda_graph_manager.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -23,6 +24,33 @@ struct CudaKernelNodeParams { void **extra; }; +// Mirrors CUDA driver API CUgraphNodeParams / CUDA_CONDITIONAL_NODE_PARAMS. +// We define our own copy because Quadrants loads the CUDA driver dynamically +// rather than linking against it, so we don't have access to those headers. +// Field order verified against cuda-python bindings (handle, type, size, +// phGraph_out, ctx). Introduced in CUDA 12.4; layout stable through 13.2+. +// +// Used to add the conditional while node via cuGraphAddNode. Normal kernel +// nodes have a dedicated cuGraphAddKernelNode API with CudaKernelNodeParams, +// but conditional nodes use the generic cuGraphAddNode which takes this +// catch-all 256-byte union. The type field selects the variant; we only use +// the conditional node variant, so most of the bytes are padding. 
+struct CudaGraphNodeParams { + unsigned int type; // CU_GRAPH_NODE_TYPE_CONDITIONAL = 13 + int reserved0[3]; + // Union starts at offset 16 (232 bytes total) + unsigned long long handle; // CUgraphConditionalHandle + unsigned int condType; // CU_GRAPH_COND_TYPE_WHILE = 1 + unsigned int size; // 1 for while + void *phGraph_out; // CUgraph* output array + void *ctx; // CUcontext + char _pad[232 - 8 - 4 - 4 - 8 - 8]; + long long reserved2; +}; +static_assert( + sizeof(CudaGraphNodeParams) == 256, + "CudaGraphNodeParams layout must match CUgraphNodeParams (256 bytes)"); + struct CachedCudaGraph { // CUgraphExec handle (typed as void* since driver API is loaded dynamically). // This is the instantiated, launchable form of the captured CUDA graph. @@ -32,6 +60,7 @@ struct CachedCudaGraph { RuntimeContext persistent_ctx{}; std::size_t arg_buffer_size{0}; std::size_t result_buffer_size{0}; + void *graph_do_while_flag_dev_ptr{nullptr}; std::size_t num_nodes{0}; CachedCudaGraph() = default; @@ -47,6 +76,8 @@ class CudaGraphManager { // Attempts to launch the kernel via a cached or newly built CUDA graph. // Returns true on success; false if the graph path can't be used (e.g. // host-resident ndarrays) and the caller should fall back to normal launch. + // Internally tracks whether the graph was used, queryable via + // used_on_last_call(). 
bool try_launch( int launch_id, LaunchContextBuilder &ctx, @@ -71,11 +102,16 @@ class CudaGraphManager { } private: - bool launch_cached_graph(CachedCudaGraph &cached, LaunchContextBuilder &ctx); + bool launch_cached_graph(CachedCudaGraph &cached, + LaunchContextBuilder &ctx, + bool use_graph_do_while); void resolve_ctx_ndarray_ptrs( LaunchContextBuilder &ctx, const std::vector> ¶meters, LlvmRuntimeExecutor *executor); + void ensure_condition_kernel_loaded(); + void *add_conditional_while_node(void *graph, + unsigned long long *cond_handle_out); void *add_kernel_node(void *graph, void *prev_node, void *func, @@ -89,6 +125,10 @@ class CudaGraphManager { std::unordered_map cache_; bool used_on_last_call_{false}; std::size_t num_nodes_on_last_call_{0}; + + // JIT-compiled condition kernel for graph_do_while conditional nodes + void *cond_kernel_module_{nullptr}; // CUmodule + void *cond_kernel_func_{nullptr}; // CUfunction }; } // namespace cuda diff --git a/tests/python/test_api.py b/tests/python/test_api.py index cf12abc39..9b931488e 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -140,6 +140,7 @@ def _get_expected_matrix_apis(): "get_addr", "global_thread_idx", "gpu", + "graph_do_while", "grouped", "i", "i16", diff --git a/tests/python/test_cuda_graph_do_while.py b/tests/python/test_cuda_graph_do_while.py new file mode 100644 index 000000000..59cc30a29 --- /dev/null +++ b/tests/python/test_cuda_graph_do_while.py @@ -0,0 +1,211 @@ +import numpy as np +import pytest + +import quadrants as qd +from quadrants.lang import impl + +from tests import test_utils + + +def _cuda_graph_cache_size(): + return impl.get_runtime().prog.get_cuda_graph_cache_size() + + +def _cuda_graph_used(): + return impl.get_runtime().prog.get_cuda_graph_cache_used_on_last_call() + + +def _on_hopper(): + return qd.lang.impl.get_cuda_compute_capability() >= 90 + + +@test_utils.test(arch=[qd.cuda]) +def test_graph_do_while_counter(): + """Test graph_do_while with a counter 
that decrements each iteration.""" + if not _on_hopper(): + pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + N = 64 + + @qd.kernel(cuda_graph=True) + def graph_loop(x: qd.types.ndarray(qd.i32, ndim=1), counter: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(counter): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + for i in range(1): + counter[()] = counter[()] - 1 + + x = qd.ndarray(qd.i32, shape=(N,)) + counter = qd.ndarray(qd.i32, shape=()) + + x.from_numpy(np.zeros(N, dtype=np.int32)) + counter.from_numpy(np.array(5, dtype=np.int32)) + + graph_loop(x, counter) + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 + + assert counter.to_numpy() == 0 + np.testing.assert_array_equal(x.to_numpy(), np.full(N, 5, dtype=np.int32)) + + x.from_numpy(np.zeros(N, dtype=np.int32)) + counter.from_numpy(np.array(10, dtype=np.int32)) + + graph_loop(x, counter) + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 + + assert counter.to_numpy() == 0 + np.testing.assert_array_equal(x.to_numpy(), np.full(N, 10, dtype=np.int32)) + + +@test_utils.test(arch=[qd.cuda]) +def test_graph_do_while_boolean_done(): + """Test graph_do_while with a boolean 'continue' flag (non-zero = keep going).""" + if not _on_hopper(): + pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + N = 64 + + @qd.kernel(cuda_graph=True) + def increment_until_threshold( + x: qd.types.ndarray(qd.i32, ndim=1), + threshold: qd.i32, + keep_going: qd.types.ndarray(qd.i32, ndim=0), + ): + while qd.graph_do_while(keep_going): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + for i in range(1): + if x[0] >= threshold: + keep_going[()] = 0 + + x = qd.ndarray(qd.i32, shape=(N,)) + keep_going = qd.ndarray(qd.i32, shape=()) + + x.from_numpy(np.zeros(N, dtype=np.int32)) + keep_going.from_numpy(np.array(1, dtype=np.int32)) + + increment_until_threshold(x, 7, keep_going) + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 + + assert keep_going.to_numpy() == 0 
+ np.testing.assert_array_equal(x.to_numpy(), np.full(N, 7, dtype=np.int32)) + + x.from_numpy(np.zeros(N, dtype=np.int32)) + keep_going.from_numpy(np.array(1, dtype=np.int32)) + + increment_until_threshold(x, 12, keep_going) + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 + + assert keep_going.to_numpy() == 0 + np.testing.assert_array_equal(x.to_numpy(), np.full(N, 12, dtype=np.int32)) + + +@test_utils.test(arch=[qd.cuda]) +def test_graph_do_while_multiple_loops(): + """Test graph_do_while with multiple top-level loops in the kernel body.""" + if not _on_hopper(): + pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + N = 32 + + @qd.kernel(cuda_graph=True) + def multi_loop( + x: qd.types.ndarray(qd.f32, ndim=1), + y: qd.types.ndarray(qd.f32, ndim=1), + counter: qd.types.ndarray(qd.i32, ndim=0), + ): + while qd.graph_do_while(counter): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + for i in range(y.shape[0]): + y[i] = y[i] + 2.0 + for i in range(1): + counter[()] = counter[()] - 1 + + x = qd.ndarray(qd.f32, shape=(N,)) + y = qd.ndarray(qd.f32, shape=(N,)) + counter = qd.ndarray(qd.i32, shape=()) + + x.from_numpy(np.zeros(N, dtype=np.float32)) + y.from_numpy(np.zeros(N, dtype=np.float32)) + counter.from_numpy(np.array(10, dtype=np.int32)) + + multi_loop(x, y, counter) + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 + + assert counter.to_numpy() == 0 + np.testing.assert_allclose(x.to_numpy(), np.full(N, 10.0)) + np.testing.assert_allclose(y.to_numpy(), np.full(N, 20.0)) + + x.from_numpy(np.zeros(N, dtype=np.float32)) + y.from_numpy(np.zeros(N, dtype=np.float32)) + counter.from_numpy(np.array(5, dtype=np.int32)) + + multi_loop(x, y, counter) + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 + + assert counter.to_numpy() == 0 + np.testing.assert_allclose(x.to_numpy(), np.full(N, 5.0)) + np.testing.assert_allclose(y.to_numpy(), np.full(N, 10.0)) + + +@test_utils.test(arch=[qd.cuda]) +def 
test_graph_do_while_changed_condition_ndarray_raises(): + """Passing a different ndarray for the condition parameter should raise.""" + if not _on_hopper(): + pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + + @qd.kernel(cuda_graph=True) + def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(c): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + for i in range(1): + c[()] = c[()] - 1 + + x = qd.ndarray(qd.i32, shape=(4,)) + c1 = qd.ndarray(qd.i32, shape=()) + c1.from_numpy(np.array(1, dtype=np.int32)) + k(x, c1) + + c2 = qd.ndarray(qd.i32, shape=()) + c2.from_numpy(np.array(1, dtype=np.int32)) + with pytest.raises(RuntimeError, match="condition ndarray changed"): + k(x, c2) + + +@test_utils.test() +def test_graph_do_while_without_cuda_graph_raises(): + """Using qd.graph_do_while without cuda_graph=True should raise.""" + + @qd.kernel + def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(c): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + + x = qd.ndarray(qd.i32, shape=(4,)) + c = qd.ndarray(qd.i32, shape=()) + c.from_numpy(np.array(1, dtype=np.int32)) + with pytest.raises(qd.QuadrantsSyntaxError, match="requires @qd.kernel\\(cuda_graph=True\\)"): + k(x, c) + + +@test_utils.test() +def test_graph_do_while_nonexistent_arg_raises(): + """Using a variable name that isn't a kernel parameter should raise.""" + + @qd.kernel(cuda_graph=True) + def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(nonexistent): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + + x = qd.ndarray(qd.i32, shape=(4,)) + c = qd.ndarray(qd.i32, shape=()) + c.from_numpy(np.array(1, dtype=np.int32)) + with pytest.raises(qd.QuadrantsSyntaxError, match="does not match any parameter"): + k(x, c)