diff --git a/quadrants/program/kernel_launcher.h b/quadrants/program/kernel_launcher.h
index 12f294611..1c5f59621 100644
--- a/quadrants/program/kernel_launcher.h
+++ b/quadrants/program/kernel_launcher.h
@@ -24,6 +24,10 @@ class KernelLauncher {
     return 0;
   }
 
+  virtual std::size_t get_cuda_graph_total_builds() const {
+    return 0;
+  }
+
   virtual ~KernelLauncher() = default;
 };
 
diff --git a/quadrants/program/program.h b/quadrants/program/program.h
index 1312f0441..8091912cb 100644
--- a/quadrants/program/program.h
+++ b/quadrants/program/program.h
@@ -151,6 +151,10 @@ class QD_DLL_EXPORT Program {
         .get_cuda_graph_num_nodes_on_last_call();
   }
 
+  std::size_t get_cuda_graph_total_builds() {
+    return program_impl_->get_kernel_launcher().get_cuda_graph_total_builds();
+  }
+
   DeviceCapabilityConfig get_device_caps() {
     return program_impl_->get_device_caps();
   }
diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp
index abbba729a..64c21add1 100644
--- a/quadrants/python/export_lang.cpp
+++ b/quadrants/python/export_lang.cpp
@@ -502,7 +502,9 @@ void export_lang(py::module &m) {
       .def("get_num_offloaded_tasks_on_last_call",
            &Program::get_num_offloaded_tasks_on_last_call)
       .def("get_cuda_graph_num_nodes_on_last_call",
-           &Program::get_cuda_graph_num_nodes_on_last_call);
+           &Program::get_cuda_graph_num_nodes_on_last_call)
+      .def("get_cuda_graph_total_builds",
+           &Program::get_cuda_graph_total_builds);
 
   py::class_(m, "CompileResult")
       .def_property_readonly(
diff --git a/quadrants/runtime/cuda/cuda_graph_manager.cpp b/quadrants/runtime/cuda/cuda_graph_manager.cpp
index 203794669..61501c48b 100644
--- a/quadrants/runtime/cuda/cuda_graph_manager.cpp
+++ b/quadrants/runtime/cuda/cuda_graph_manager.cpp
@@ -11,12 +11,17 @@ namespace quadrants::lang {
 namespace cuda {
 
 // Condition kernel for graph_do_while. Reads the user's i32 loop-control flag
-// from GPU memory and tells the CUDA graph's conditional while node whether to
-// run another iteration — all without returning to the host.
+// from GPU memory via an indirection slot, and tells the CUDA graph's
+// conditional while node whether to run another iteration.
+//
+// The indirection allows swapping the counter ndarray between calls without
+// rebuilding the graph: the slot's address is baked into the graph, but the
+// pointer it contains can be updated via memcpy before each launch.
 //
 // Parameters:
 //   param_0: conditional node handle (passed to cudaGraphSetConditional)
-//   param_1: pointer to the user's qd.i32 flag ndarray on the GPU
+//   param_1: pointer to a device-side slot (void**) that holds the address
+//            of the user's qd.i32 flag ndarray
 //
 // Compiled from CUDA C with: nvcc -ptx -arch=sm_90 -rdc=true
 // Requires SM 9.0+ (Hopper) for cudaGraphSetConditional / conditional nodes.
@@ -37,7 +42,7 @@ static const char *kConditionKernelPTX = R"PTX(
 
 // Entry point: called by the CUDA graph's conditional while node each iteration.
 //   param_0 (u64): conditional node handle
-//   param_1 (u64): pointer to the user's qd.i32 flag in GPU global memory
+//   param_1 (u64): pointer to device-side slot holding address of user's i32 flag
 .visible .entry _qd_graph_do_while_cond(
     .param .u64 _qd_graph_do_while_cond_param_0,
     .param .u64 _qd_graph_do_while_cond_param_1
@@ -45,17 +50,20 @@ static const char *kConditionKernelPTX = R"PTX(
 {
     .reg .pred %p<2>;
     .reg .b32 %r<3>;
-    .reg .b64 %rd<4>;
+    .reg .b64 %rd<5>;
 
     // Load the two kernel parameters into registers:
     //   %rd1 = conditional node handle
-    //   %rd2 = pointer to user's i32 flag
+    //   %rd2 = pointer to device-side indirection slot
     ld.param.u64 %rd1, [_qd_graph_do_while_cond_param_0];
     ld.param.u64 %rd2, [_qd_graph_do_while_cond_param_1];
 
-    // Convert generic pointer to global address space, then read the flag value
+    // Dereference the indirection slot to get the actual flag pointer
     cvta.to.global.u64 %rd3, %rd2;
-    ld.global.u32 %r1, [%rd3];
+    ld.global.u64 %rd4, [%rd3];
+
+    // Generic load: %rd4 is a generic address, not converted via cvta
+    ld.u32 %r1, [%rd4];
 
     // Convert flag to boolean: %r2 = (flag != 0) ? 1 : 0
     setp.ne.s32 %p1, %r1, 0;
@@ -75,6 +83,32 @@ static const char *kConditionKernelPTX = R"PTX(
 }
 )PTX";
 
+CachedCudaGraph::CachedCudaGraph(std::size_t arg_buf_size,
+                                 std::size_t result_buf_size,
+                                 bool needs_counter_ptr_slot,
+                                 LlvmRuntimeExecutor *executor)
+    : arg_buffer_size(arg_buf_size),
+      // Store the clamped size so the member always matches the allocation
+      // (the pre-refactor code stored max(size, sizeof(uint64)) here too).
+      result_buffer_size(std::max(result_buf_size, sizeof(uint64))) {
+  CUDADriver::get_instance().malloc(
+      (void **)&persistent_device_result_buffer, result_buffer_size);
+
+  if (arg_buffer_size > 0) {
+    CUDADriver::get_instance().malloc((void **)&persistent_device_arg_buffer,
+                                      arg_buffer_size);
+  }
+
+  if (needs_counter_ptr_slot) {
+    CUDADriver::get_instance().malloc(&counter_ptr_slot, sizeof(void *));
+  }
+
+  persistent_ctx.runtime = executor->get_llvm_runtime();
+  persistent_ctx.arg_buffer = persistent_device_arg_buffer;
+  persistent_ctx.result_buffer = (uint64 *)persistent_device_result_buffer;
+  persistent_ctx.cpu_thread_id = 0;
+}
+
 CachedCudaGraph::~CachedCudaGraph() {
   if (graph_exec) {
     CUDADriver::get_instance().graph_exec_destroy(graph_exec);
@@ -85,6 +119,9 @@ CachedCudaGraph::~CachedCudaGraph() {
   if (persistent_device_result_buffer) {
     CUDADriver::get_instance().mem_free(persistent_device_result_buffer);
   }
+  if (counter_ptr_slot) {
+    CUDADriver::get_instance().mem_free(counter_ptr_slot);
+  }
 }
 
 CachedCudaGraph::CachedCudaGraph(CachedCudaGraph &&other) noexcept
@@ -94,11 +131,12 @@ CachedCudaGraph::CachedCudaGraph(CachedCudaGraph &&other) noexcept
       persistent_ctx(other.persistent_ctx),
       arg_buffer_size(other.arg_buffer_size),
       result_buffer_size(other.result_buffer_size),
-      graph_do_while_flag_dev_ptr(other.graph_do_while_flag_dev_ptr),
+      counter_ptr_slot(other.counter_ptr_slot),
       num_nodes(other.num_nodes) {
   other.graph_exec = nullptr;
   other.persistent_device_arg_buffer = nullptr;
   other.persistent_device_result_buffer = nullptr;
+  other.counter_ptr_slot = nullptr;
 }
 
 CachedCudaGraph &CachedCudaGraph::operator=(CachedCudaGraph &&other) noexcept {
@@ -116,12 +154,17 @@ CachedCudaGraph &CachedCudaGraph::operator=(CachedCudaGraph &&other) noexcept {
     persistent_ctx = other.persistent_ctx;
     arg_buffer_size = other.arg_buffer_size;
     result_buffer_size = other.result_buffer_size;
-    graph_do_while_flag_dev_ptr = other.graph_do_while_flag_dev_ptr;
+    // Release our own slot before adopting other's; otherwise it would leak.
+    if (counter_ptr_slot) {
+      CUDADriver::get_instance().mem_free(counter_ptr_slot);
+    }
+    counter_ptr_slot = other.counter_ptr_slot;
     num_nodes = other.num_nodes;
 
     other.graph_exec = nullptr;
     other.persistent_device_arg_buffer = nullptr;
     other.persistent_device_result_buffer = nullptr;
+    other.counter_ptr_slot = nullptr;
   }
   return *this;
 }
@@ -314,11 +357,13 @@ void *CudaGraphManager::add_conditional_while_node(
 bool CudaGraphManager::launch_cached_graph(CachedCudaGraph &cached,
                                            LaunchContextBuilder &ctx,
                                            bool use_graph_do_while) {
-  QD_ERROR_IF(
-      use_graph_do_while &&
-          cached.graph_do_while_flag_dev_ptr != ctx.graph_do_while_flag_dev_ptr,
-      "graph_do_while condition ndarray changed between calls. "
-      "Reuse the same ndarray for the condition parameter across calls.");
+  // TODO: these two memcpy_host_to_device calls could be async
+  // (cuMemcpyHtoDAsync) on the launch stream for better CPU-GPU overlap.
+  if (use_graph_do_while && cached.counter_ptr_slot) {
+    void *flag_ptr = ctx.graph_do_while_flag_dev_ptr;
+    CUDADriver::get_instance().memcpy_host_to_device(cached.counter_ptr_slot,
+                                                     &flag_ptr, sizeof(void *));
+  }
 
   if (ctx.arg_buffer_size > 0) {
     CUDADriver::get_instance().memcpy_host_to_device(
@@ -358,30 +403,15 @@ bool CudaGraphManager::try_launch(
 
   CUDAContext::get_instance().make_current();
 
-  CachedCudaGraph cached;
-
-  // --- Allocate persistent buffers ---
-  cached.result_buffer_size = std::max(ctx.result_buffer_size, sizeof(uint64));
-  CUDADriver::get_instance().malloc(
-      (void **)&cached.persistent_device_result_buffer,
-      cached.result_buffer_size);
+  CachedCudaGraph cached(ctx.arg_buffer_size, ctx.result_buffer_size,
+                         use_graph_do_while, executor);
 
-  cached.arg_buffer_size = ctx.arg_buffer_size;
   if (cached.arg_buffer_size > 0) {
-    CUDADriver::get_instance().malloc(
-        (void **)&cached.persistent_device_arg_buffer, cached.arg_buffer_size);
     CUDADriver::get_instance().memcpy_host_to_device(
         cached.persistent_device_arg_buffer, ctx.get_context().arg_buffer,
         cached.arg_buffer_size);
   }
 
-  // --- Build persistent RuntimeContext ---
-  cached.persistent_ctx.runtime = executor->get_llvm_runtime();
-  cached.persistent_ctx.arg_buffer = cached.persistent_device_arg_buffer;
-  cached.persistent_ctx.result_buffer =
-      (uint64 *)cached.persistent_device_result_buffer;
-  cached.persistent_ctx.cpu_thread_id = 0;
-
   // --- Build CUDA graph ---
   void *graph = nullptr;
   CUDADriver::get_instance().graph_create(&graph, 0);
@@ -424,8 +454,14 @@ bool CudaGraphManager::try_launch(
 
   if (use_graph_do_while) {
     QD_ASSERT(ctx.graph_do_while_flag_dev_ptr);
+    // Write the initial counter address into the persistent indirection slot
+    // (allocated by the constructor). The condition kernel reads through this
+    // slot, so swapping the counter ndarray later only requires updating it.
     void *flag_ptr = ctx.graph_do_while_flag_dev_ptr;
-    void *cond_args[2] = {&cond_handle, &flag_ptr};
+    CUDADriver::get_instance().memcpy_host_to_device(cached.counter_ptr_slot,
+                                                     &flag_ptr, sizeof(void *));
+
+    void *cond_args[2] = {&cond_handle, &cached.counter_ptr_slot};
     add_kernel_node(kernel_target_graph, prev_node, cond_kernel_func_, 1, 1, 0,
                     cond_args);
 
@@ -446,11 +482,8 @@ bool CudaGraphManager::try_launch(
            cached.num_nodes, launch_id,
            use_graph_do_while ? " (with graph_do_while)" : "");
 
-  if (use_graph_do_while) {
-    cached.graph_do_while_flag_dev_ptr = ctx.graph_do_while_flag_dev_ptr;
-  }
-
   num_nodes_on_last_call_ = cached.num_nodes;
+  ++total_builds_;
   cache_.emplace(launch_id, std::move(cached));
   used_on_last_call_ = true;
   return true;
diff --git a/quadrants/runtime/cuda/cuda_graph_manager.h b/quadrants/runtime/cuda/cuda_graph_manager.h
index 85781a5d9..341c2a3cf 100644
--- a/quadrants/runtime/cuda/cuda_graph_manager.h
+++ b/quadrants/runtime/cuda/cuda_graph_manager.h
@@ -60,10 +60,17 @@ struct CachedCudaGraph {
   RuntimeContext persistent_ctx{};
   std::size_t arg_buffer_size{0};
   std::size_t result_buffer_size{0};
-  void *graph_do_while_flag_dev_ptr{nullptr};
+  // Device-side pointer slot for graph_do_while indirection. Holds the address
+  // of the user's counter ndarray. The condition kernel reads through this
+  // slot, allowing the counter ndarray to change between calls without
+  // rebuilding.
+  void *counter_ptr_slot{nullptr};
   std::size_t num_nodes{0};
 
-  CachedCudaGraph() = default;
+  CachedCudaGraph(std::size_t arg_buffer_size,
+                  std::size_t result_buffer_size,
+                  bool needs_counter_ptr_slot,
+                  LlvmRuntimeExecutor *executor);
   ~CachedCudaGraph();
   CachedCudaGraph(const CachedCudaGraph &) = delete;
   CachedCudaGraph &operator=(const CachedCudaGraph &) = delete;
@@ -100,6 +107,9 @@ class CudaGraphManager {
   std::size_t num_nodes_on_last_call() const {
     return num_nodes_on_last_call_;
   }
+  std::size_t total_builds() const {
+    return total_builds_;
+  }
 
  private:
   bool launch_cached_graph(CachedCudaGraph &cached,
@@ -125,6 +135,7 @@ class CudaGraphManager {
   std::unordered_map cache_;
   bool used_on_last_call_{false};
   std::size_t num_nodes_on_last_call_{0};
+  std::size_t total_builds_{0};
 
   // JIT-compiled condition kernel for graph_do_while conditional nodes
   void *cond_kernel_module_{nullptr}; // CUmodule
diff --git a/quadrants/runtime/cuda/kernel_launcher.h b/quadrants/runtime/cuda/kernel_launcher.h
index 8c54aaaea..669809881 100644
--- a/quadrants/runtime/cuda/kernel_launcher.h
+++ b/quadrants/runtime/cuda/kernel_launcher.h
@@ -34,6 +34,9 @@ class KernelLauncher : public LLVM::KernelLauncher {
   std::size_t get_cuda_graph_num_nodes_on_last_call() const override {
     return graph_manager_.num_nodes_on_last_call();
   }
+  std::size_t get_cuda_graph_total_builds() const override {
+    return graph_manager_.total_builds();
+  }
 
  private:
   void launch_offloaded_tasks(
diff --git a/tests/python/test_cuda_graph_do_while.py b/tests/python/test_cuda_graph_do_while.py
index 202b52600..82ec19803 100644
--- a/tests/python/test_cuda_graph_do_while.py
+++ b/tests/python/test_cuda_graph_do_while.py
@@ -15,6 +15,10 @@ def _cuda_graph_used():
     return impl.get_runtime().prog.get_cuda_graph_cache_used_on_last_call()
 
 
+def _cuda_graph_total_builds():
+    return impl.get_runtime().prog.get_cuda_graph_total_builds()
+
+
 def _on_cuda():
     return impl.current_cfg().arch == qd.cuda
 
@@ -160,10 +164,18 @@ def multi_loop(
     np.testing.assert_allclose(y.to_numpy(), np.full(N, 10.0))
 
 
-@test_utils.test(arch=[qd.cuda])
-def test_graph_do_while_changed_condition_ndarray_raises():
-    """Passing a different ndarray for the condition parameter should raise."""
+@test_utils.test()
+def test_graph_do_while_swap_counter_ndarray():
+    """Swapping the counter ndarray between calls should work correctly.
+
+    Creates one counter c1, runs the kernel with counter=3, verifies x is all
+    3s. Then creates a new ndarray c2 (different device pointer), runs the same
+    kernel with counter=7, verifies x is all 7s. Confirms cache size stays 1 --
+    the graph wasn't rebuilt, it just updated the indirection slot with c2's
+    pointer.
+    """
     _xfail_if_cuda_without_hopper()
+    N = 32
 
     @qd.kernel(cuda_graph=True)
     def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)):
@@ -173,15 +185,75 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)):
         for i in range(1):
             c[()] = c[()] - 1
 
-    x = qd.ndarray(qd.i32, shape=(4,))
+    x = qd.ndarray(qd.i32, shape=(N,))
     c1 = qd.ndarray(qd.i32, shape=())
-    c1.from_numpy(np.array(1, dtype=np.int32))
+
+    x.from_numpy(np.zeros(N, dtype=np.int32))
+    c1.from_numpy(np.array(3, dtype=np.int32))
     k(x, c1)
+    if _on_cuda():
+        assert _cuda_graph_used()
+        assert _cuda_graph_cache_size() == 1
+    assert c1.to_numpy() == 0
+    np.testing.assert_array_equal(x.to_numpy(), np.full(N, 3, dtype=np.int32))
+
+    c2 = qd.ndarray(qd.i32, shape=())
+    x.from_numpy(np.zeros(N, dtype=np.int32))
+    c2.from_numpy(np.array(7, dtype=np.int32))
+    k(x, c2)
+    if _on_cuda():
+        assert _cuda_graph_used()
+        assert _cuda_graph_cache_size() == 1
+        assert _cuda_graph_total_builds() == 1
+    assert c2.to_numpy() == 0
+    np.testing.assert_array_equal(x.to_numpy(), np.full(N, 7, dtype=np.int32))
 
+
+@test_utils.test()
+def test_graph_do_while_alternate_counter_ndarrays():
+    """Alternating between two counter ndarrays should work correctly.
+
+    Creates c1 and c2 upfront, then alternates between them for 3 rounds (6
+    kernel calls). Each call uses a different iteration count (count and
+    count+10). Confirms the slot update works back and forth, not just as a
+    one-time swap. Cache size is checked once at the end -- still 1.
+    """
+    _xfail_if_cuda_without_hopper()
+    N = 16
+
+    @qd.kernel(cuda_graph=True)
+    def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)):
+        while qd.graph_do_while(c):
+            for i in range(x.shape[0]):
+                x[i] = x[i] + 1
+            for i in range(1):
+                c[()] = c[()] - 1
+
+    x = qd.ndarray(qd.i32, shape=(N,))
+    c1 = qd.ndarray(qd.i32, shape=())
     c2 = qd.ndarray(qd.i32, shape=())
-    c2.from_numpy(np.array(1, dtype=np.int32))
-    with pytest.raises(RuntimeError, match="condition ndarray changed"):
+
+    for iteration in range(3):
+        count = iteration + 2
+        x.from_numpy(np.zeros(N, dtype=np.int32))
+        c1.from_numpy(np.array(count, dtype=np.int32))
+        k(x, c1)
+        if _on_cuda():
+            assert _cuda_graph_used()
+        assert c1.to_numpy() == 0
+        np.testing.assert_array_equal(x.to_numpy(), np.full(N, count, dtype=np.int32))
+
+        x.from_numpy(np.zeros(N, dtype=np.int32))
+        c2.from_numpy(np.array(count + 10, dtype=np.int32))
         k(x, c2)
+        if _on_cuda():
+            assert _cuda_graph_used()
+        assert c2.to_numpy() == 0
+        np.testing.assert_array_equal(x.to_numpy(), np.full(N, count + 10, dtype=np.int32))
+
+    if _on_cuda():
+        assert _cuda_graph_cache_size() == 1
+        assert _cuda_graph_total_builds() == 1
 
 
 @test_utils.test()