diff --git a/docs/source/user_guide/cuda_graph.md b/docs/source/user_guide/cuda_graph.md index cd93aaa1a..d32ea0baa 100644 --- a/docs/source/user_guide/cuda_graph.md +++ b/docs/source/user_guide/cuda_graph.md @@ -78,7 +78,7 @@ solve(x, counter) The argument to `qd.graph_do_while()` must be the name of a scalar `qd.i32` ndarray parameter. The loop body repeats while this value is non-zero. - On SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement. -- Older CUDA GPUs, and non-CUDA backends not currently supported. +- On older CUDA GPUs and non-CUDA backends, it falls back to a host-side do-while loop. ### Patterns @@ -128,4 +128,12 @@ However, other parameters can be any supported Quadrants kernel parameter type. ### Caveats -Only runs on CUDA. No fallback on non-CUDA platforms currently. +On currently unsupported GPU platforms, such as AMDGPU at the time of writing, the value of the `graph_do_while` parameter will be copied from the GPU to the host each iteration, in order to check whether we should continue iterating. This causes a GPU pipeline stall. At the end of each loop iteration: +- wait for GPU async queue to finish processing +- copy condition value to hostside +- evaluate condition value on hostside +- launch new kernels for next loop iteration, if not finished yet + +Therefore on unsupported platforms, you might consider creating a second implementation, which works differently. e.g.: +- fixed number of loop iterations, so no dependency on gpu data for kernel launch; combined perhaps with: +- make each kernel 'short-circuit', exit quickly, if the task has already been completed; to avoid running the GPU more than necessary diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index a40c2dfe0..91e48e407 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -709,8 +709,8 @@ def graph_do_while(condition) -> bool: ``flag`` (a scalar ``qd.i32`` ndarray) is non-zero. On SM 9.0+ (Hopper) GPUs this compiles to a native CUDA graph - conditional while node. Older CUDA GPUs and non-CUDA backends - are not currently supported. + conditional while node. On older CUDA GPUs and non-CUDA backends + it falls back to a host-side do-while loop. This function should not be called directly at runtime; it is recognised and transformed during AST compilation. diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp b/quadrants/runtime/amdgpu/kernel_launcher.cpp index 62a64a65f..eac751e37 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -5,6 +5,37 @@ namespace quadrants::lang { namespace amdgpu { +void KernelLauncher::launch_offloaded_tasks( + JITModule *amdgpu_module, + const std::vector &offloaded_tasks, + void *context_pointer, + int arg_size) { + for (const auto &task : offloaded_tasks) { + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, + task.block_dim); + amdgpu_module->launch(task.name, task.grid_dim, task.block_dim, + task.dynamic_shared_array_bytes, + {(void *)&context_pointer}, {arg_size}); + } +} + +void KernelLauncher::launch_offloaded_tasks_with_do_while( + LaunchContextBuilder &ctx, + JITModule *amdgpu_module, + const std::vector &offloaded_tasks, + void *context_pointer, + int arg_size) { + int32_t counter_val; + do { + launch_offloaded_tasks(amdgpu_module, offloaded_tasks, context_pointer, + arg_size); + counter_val = 0; + AMDGPUDriver::get_instance().stream_synchronize(nullptr); + AMDGPUDriver::get_instance().memcpy_device_to_host( + &counter_val, ctx.graph_do_while_flag_dev_ptr, sizeof(int32_t)); + } while (counter_val != 0); +} + bool KernelLauncher::on_amdgpu_device(void *ptr) { unsigned int attr_val[8]; // mem_get_attribute doesn't work well on ROCm @@ -74,6 +105,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, } ctx.set_ndarray_ptrs(arg_id, (uint64)device_ptrs[data_ptr_idx], (uint64)ctx.array_ptrs[grad_ptr_idx]); + if (arg_id == ctx.graph_do_while_arg_id) { + ctx.graph_do_while_flag_dev_ptr = device_ptrs[data_ptr_idx]; + } } else if (arr_sz > 0) { // why use arr_sz constrain? // Ndarray DeviceAllocation *ptr = static_cast(data_ptr); @@ -82,6 +116,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, ctx.set_ndarray_ptrs(arg_id, (uint64)device_ptrs[data_ptr_idx], (uint64)ctx.array_ptrs[grad_ptr_idx]); + if (arg_id == ctx.graph_do_while_arg_id) { + ctx.graph_do_while_flag_dev_ptr = device_ptrs[data_ptr_idx]; + } } } } @@ -110,15 +147,13 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, AMDGPUContext::get_instance().push_back_kernel_arg_pointer(context_pointer); - QD_ERROR_IF(ctx.graph_do_while_arg_id >= 0, - "graph_do_while is only supported on the CUDA backend"); - - for (auto &task : offloaded_tasks) { - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, - task.block_dim); - amdgpu_module->launch(task.name, task.grid_dim, task.block_dim, - task.dynamic_shared_array_bytes, - {(void *)&context_pointer}, {arg_size}); + if (ctx.graph_do_while_arg_id >= 0) { + QD_ASSERT(ctx.graph_do_while_flag_dev_ptr); + launch_offloaded_tasks_with_do_while(ctx, amdgpu_module, offloaded_tasks, + context_pointer, arg_size); + } else { + launch_offloaded_tasks(amdgpu_module, offloaded_tasks, context_pointer, + arg_size); } QD_TRACE("Launching kernel"); if (ctx.arg_buffer_size > 0) { diff --git a/quadrants/runtime/amdgpu/kernel_launcher.h b/quadrants/runtime/amdgpu/kernel_launcher.h index 0d3924dd3..be4ca6c25 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.h +++ b/quadrants/runtime/amdgpu/kernel_launcher.h @@ -23,6 +23,16 @@ class KernelLauncher : public LLVM::KernelLauncher { const LLVM::CompiledKernelData &compiled) override; private: + void launch_offloaded_tasks(JITModule *amdgpu_module, + const std::vector &offloaded_tasks, + void *context_pointer, + int arg_size); + void launch_offloaded_tasks_with_do_while( + LaunchContextBuilder &ctx, + JITModule *amdgpu_module, + const std::vector &offloaded_tasks, + void *context_pointer, + int arg_size); bool on_amdgpu_device(void *ptr); std::vector contexts_; }; diff --git a/quadrants/runtime/cpu/kernel_launcher.cpp b/quadrants/runtime/cpu/kernel_launcher.cpp index 1f34dced1..9499ef97f 100644 --- a/quadrants/runtime/cpu/kernel_launcher.cpp +++ b/quadrants/runtime/cpu/kernel_launcher.cpp @@ -4,6 +4,22 @@ namespace quadrants::lang { namespace cpu { +void KernelLauncher::launch_offloaded_tasks( + LaunchContextBuilder &ctx, + const std::vector &task_funcs) { + for (auto task : task_funcs) { + task(&ctx.get_context()); + } +} + +void KernelLauncher::launch_offloaded_tasks_with_do_while( + LaunchContextBuilder &ctx, + const std::vector &task_funcs) { + do { + launch_offloaded_tasks(ctx, task_funcs); + } while (*static_cast(ctx.graph_do_while_flag_dev_ptr) != 0); +} + void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx) { QD_ASSERT(handle.get_launch_id() < contexts_.size()); @@ -27,6 +43,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, if (ctx.device_allocation_type[arg_id] == LaunchContextBuilder::DevAllocType::kNone) { ctx.set_ndarray_ptrs(arg_id, (uint64)data_ptr, (uint64)grad_ptr); + if (arg_id == ctx.graph_do_while_arg_id) { + ctx.graph_do_while_flag_dev_ptr = data_ptr; + } } else if (ctx.array_runtime_sizes[arg_id] > 0) { uint64 host_ptr = (uint64)executor->get_device_alloc_info_ptr( *static_cast(data_ptr)); @@ -38,14 +57,17 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, : (uint64)executor->get_device_alloc_info_ptr( *static_cast(grad_ptr)); ctx.set_ndarray_ptrs(arg_id, host_ptr, host_ptr_grad); + if (arg_id == ctx.graph_do_while_arg_id) { + ctx.graph_do_while_flag_dev_ptr = (void *)host_ptr; + } } } } - QD_ERROR_IF(ctx.graph_do_while_arg_id >= 0, - "graph_do_while is only supported on the CUDA backend"); - - for (auto task : launcher_ctx.task_funcs) { - task(&ctx.get_context()); + if (ctx.graph_do_while_arg_id >= 0) { + QD_ASSERT(ctx.graph_do_while_flag_dev_ptr); + launch_offloaded_tasks_with_do_while(ctx, launcher_ctx.task_funcs); + } else { + launch_offloaded_tasks(ctx, launcher_ctx.task_funcs); } } @@ -64,8 +86,6 @@ KernelLauncher::Handle KernelLauncher::register_llvm_kernel( auto data = compiled.get_internal_data().compiled_data.clone(); auto *jit_module = executor->create_jit_module(std::move(data.module)); - // Construct task_funcs - using TaskFunc = int32 (*)(void *); std::vector task_funcs; task_funcs.reserve(data.tasks.size()); for (auto &task : data.tasks) { diff --git a/quadrants/runtime/cpu/kernel_launcher.h b/quadrants/runtime/cpu/kernel_launcher.h index c69b546f4..478219ef4 100644 --- a/quadrants/runtime/cpu/kernel_launcher.h +++ b/quadrants/runtime/cpu/kernel_launcher.h @@ -9,8 +9,9 @@ namespace cpu { class KernelLauncher : public LLVM::KernelLauncher { using Base = LLVM::KernelLauncher; + using TaskFunc = int32 (*)(void *); + struct Context { - using TaskFunc = int32 (*)(void *); std::vector task_funcs; const std::vector> *parameters; }; @@ -23,6 +24,12 @@ class KernelLauncher : public LLVM::KernelLauncher { const LLVM::CompiledKernelData &compiled) override; private: + void launch_offloaded_tasks(LaunchContextBuilder &ctx, + const std::vector &task_funcs); + void launch_offloaded_tasks_with_do_while( + LaunchContextBuilder &ctx, + const std::vector &task_funcs); + std::vector contexts_; }; diff --git a/quadrants/runtime/cuda/cuda_graph_manager.cpp b/quadrants/runtime/cuda/cuda_graph_manager.cpp index 8466b52e3..203794669 100644 --- a/quadrants/runtime/cuda/cuda_graph_manager.cpp +++ b/quadrants/runtime/cuda/cuda_graph_manager.cpp @@ -195,10 +195,13 @@ void CudaGraphManager::ensure_condition_kernel_loaded() { return; int cc = CUDAContext::get_instance().get_compute_capability(); - QD_ERROR_IF(cc < 90, - "graph_do_while requires SM 9.0+ (Hopper), but this device is " - "SM {}.", - cc); + if (cc < 90) { + QD_WARN( + "graph_do_while requires SM 9.0+ (Hopper), but this device is SM {}. " + "Falling back to non-graph path.", + cc); + return; + } auto &driver = CUDADriver::get_instance(); @@ -403,6 +406,9 @@ bool CudaGraphManager::try_launch( unsigned long long cond_handle = 0; if (use_graph_do_while) { + ensure_condition_kernel_loaded(); + QD_ERROR_IF(!cond_kernel_func_, + "Condition kernel not available; cannot build graph_do_while"); kernel_target_graph = add_conditional_while_node(graph, &cond_handle); } diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index ad19d607a..392b7ff6d 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -7,6 +7,34 @@ namespace quadrants::lang { namespace cuda { +void KernelLauncher::launch_offloaded_tasks( + LaunchContextBuilder &ctx, + JITModule *cuda_module, + const std::vector &offloaded_tasks) { + for (const auto &task : offloaded_tasks) { + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, + task.block_dim); + cuda_module->launch(task.name, task.grid_dim, task.block_dim, + task.dynamic_shared_array_bytes, {&ctx.get_context()}, + {}); + } +} + +void KernelLauncher::launch_offloaded_tasks_with_do_while( + LaunchContextBuilder &ctx, + JITModule *cuda_module, + const std::vector &offloaded_tasks) { + int32_t counter_val; + do { + launch_offloaded_tasks(ctx, cuda_module, offloaded_tasks); + counter_val = 0; + auto *stream = CUDAContext::get_instance().get_stream(); + CUDADriver::get_instance().stream_synchronize(stream); + CUDADriver::get_instance().memcpy_device_to_host( + &counter_val, ctx.graph_do_while_flag_dev_ptr, sizeof(int32_t)); + } while (counter_val != 0); +} + void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx) { QD_ASSERT(handle.get_launch_id() < contexts_.size()); @@ -107,6 +135,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, ctx.set_ndarray_ptrs(arg_id, (uint64)device_ptrs[data_ptr_idx], (uint64)device_ptrs[grad_ptr_idx]); + if (arg_id == ctx.graph_do_while_arg_id) { + ctx.graph_do_while_flag_dev_ptr = device_ptrs[data_ptr_idx]; + } } else if (arr_sz > 0) { // Ndarray DeviceAllocation *ptr = static_cast(data_ptr); @@ -122,6 +153,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, ctx.set_ndarray_ptrs(arg_id, (uint64)device_ptrs[data_ptr_idx], (uint64)device_ptrs[grad_ptr_idx]); + if (arg_id == ctx.graph_do_while_arg_id) { + ctx.graph_do_while_flag_dev_ptr = device_ptrs[data_ptr_idx]; + } } } } @@ -142,12 +176,11 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, ctx.get_context().arg_buffer = device_arg_buffer; } - for (auto task : offloaded_tasks) { - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, - task.block_dim); - cuda_module->launch(task.name, task.grid_dim, task.block_dim, - task.dynamic_shared_array_bytes, {&ctx.get_context()}, - {}); + if (ctx.graph_do_while_arg_id >= 0) { + QD_ASSERT(ctx.graph_do_while_flag_dev_ptr); + launch_offloaded_tasks_with_do_while(ctx, cuda_module, offloaded_tasks); + } else { + launch_offloaded_tasks(ctx, cuda_module, offloaded_tasks); } if (ctx.arg_buffer_size > 0) { CUDADriver::get_instance().mem_free_async(device_arg_buffer, nullptr); diff --git a/quadrants/runtime/cuda/kernel_launcher.h b/quadrants/runtime/cuda/kernel_launcher.h index cd060fccc..8c54aaaea 100644 --- a/quadrants/runtime/cuda/kernel_launcher.h +++ b/quadrants/runtime/cuda/kernel_launcher.h @@ -36,6 +36,15 @@ class KernelLauncher : public LLVM::KernelLauncher { } private: + void launch_offloaded_tasks( + LaunchContextBuilder &ctx, + JITModule *cuda_module, + const std::vector &offloaded_tasks); + void launch_offloaded_tasks_with_do_while( + LaunchContextBuilder &ctx, + JITModule *cuda_module, + const std::vector &offloaded_tasks); + std::vector contexts_; CudaGraphManager graph_manager_; }; diff --git a/quadrants/runtime/gfx/kernel_launcher.cpp b/quadrants/runtime/gfx/kernel_launcher.cpp index 9b51c5c5a..63e953c3b 100644 --- a/quadrants/runtime/gfx/kernel_launcher.cpp +++ b/quadrants/runtime/gfx/kernel_launcher.cpp @@ -7,11 +7,39 @@ namespace gfx { KernelLauncher::KernelLauncher(Config config) : config_(std::move(config)) { } +void KernelLauncher::launch_offloaded_tasks_with_do_while( + Handle handle, + LaunchContextBuilder &ctx) { + const ArgArrayPtrKey key{ctx.graph_do_while_arg_id, + TypeFactory::DATA_PTR_POS_IN_NDARRAY}; + auto it = ctx.array_ptrs.find(key); + QD_ASSERT(it != ctx.array_ptrs.end()); + + auto *device = config_.gfx_runtime_->get_ti_device(); + DeviceAllocation alloc = *(static_cast(it->second)); + DevicePtr dev_ptr = alloc.get_ptr(0); + + int32_t flag_val; + do { + config_.gfx_runtime_->launch_kernel(handle, ctx); + config_.gfx_runtime_->synchronize(); + void *host_ptr = &flag_val; + size_t sz = sizeof(int32_t); + QD_ASSERT(device->readback_data(&dev_ptr, &host_ptr, &sz, 1) == + RhiResult::success); + } while (flag_val != 0); +} + void KernelLauncher::launch_kernel( const lang::CompiledKernelData &compiled_kernel_data, LaunchContextBuilder &ctx) { auto handle = register_kernel(compiled_kernel_data); - config_.gfx_runtime_->launch_kernel(handle, ctx); + + if (ctx.graph_do_while_arg_id >= 0) { + launch_offloaded_tasks_with_do_while(handle, ctx); + } else { + config_.gfx_runtime_->launch_kernel(handle, ctx); + } } KernelLauncher::Handle KernelLauncher::register_kernel( diff --git a/quadrants/runtime/gfx/kernel_launcher.h b/quadrants/runtime/gfx/kernel_launcher.h index 0efe47907..ca38513a0 100644 --- a/quadrants/runtime/gfx/kernel_launcher.h +++ b/quadrants/runtime/gfx/kernel_launcher.h @@ -18,6 +18,8 @@ class KernelLauncher : public lang::KernelLauncher { LaunchContextBuilder &ctx) override; private: + void launch_offloaded_tasks_with_do_while(Handle handle, + LaunchContextBuilder &ctx); Handle register_kernel(const lang::CompiledKernelData &compiled_kernel_data); Config config_; diff --git a/tests/python/test_cuda_graph_do_while.py b/tests/python/test_cuda_graph_do_while.py index 59cc30a29..202b52600 100644 --- a/tests/python/test_cuda_graph_do_while.py +++ b/tests/python/test_cuda_graph_do_while.py @@ -15,15 +15,19 @@ def _cuda_graph_used(): return impl.get_runtime().prog.get_cuda_graph_cache_used_on_last_call() -def _on_hopper(): - return qd.lang.impl.get_cuda_compute_capability() >= 90 +def _on_cuda(): + return impl.current_cfg().arch == qd.cuda -@test_utils.test(arch=[qd.cuda]) +def _xfail_if_cuda_without_hopper(): + if _on_cuda() and qd.lang.impl.get_cuda_compute_capability() < 90: + pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + + +@test_utils.test() def test_graph_do_while_counter(): """Test graph_do_while with a counter that decrements each iteration.""" - if not _on_hopper(): - pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + _xfail_if_cuda_without_hopper() N = 64 @qd.kernel(cuda_graph=True) @@ -41,8 +45,9 @@ def graph_loop(x: qd.types.ndarray(qd.i32, ndim=1), counter: qd.types.ndarray(qd counter.from_numpy(np.array(5, dtype=np.int32)) graph_loop(x, counter) - assert _cuda_graph_used() - assert _cuda_graph_cache_size() == 1 + if _on_cuda(): + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 assert counter.to_numpy() == 0 np.testing.assert_array_equal(x.to_numpy(), np.full(N, 5, dtype=np.int32)) @@ -51,18 +56,18 @@ def graph_loop(x: qd.types.ndarray(qd.i32, ndim=1), counter: qd.types.ndarray(qd counter.from_numpy(np.array(10, dtype=np.int32)) graph_loop(x, counter) - assert _cuda_graph_used() - assert _cuda_graph_cache_size() == 1 + if _on_cuda(): + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 assert counter.to_numpy() == 0 np.testing.assert_array_equal(x.to_numpy(), np.full(N, 10, dtype=np.int32)) -@test_utils.test(arch=[qd.cuda]) +@test_utils.test() def test_graph_do_while_boolean_done(): """Test graph_do_while with a boolean 'continue' flag (non-zero = keep going).""" - if not _on_hopper(): - pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + _xfail_if_cuda_without_hopper() N = 64 @qd.kernel(cuda_graph=True) @@ -85,8 +90,9 @@ def increment_until_threshold( keep_going.from_numpy(np.array(1, dtype=np.int32)) increment_until_threshold(x, 7, keep_going) - assert _cuda_graph_used() - assert _cuda_graph_cache_size() == 1 + if _on_cuda(): + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 assert keep_going.to_numpy() == 0 np.testing.assert_array_equal(x.to_numpy(), np.full(N, 7, dtype=np.int32)) @@ -95,18 +101,18 @@ def increment_until_threshold( keep_going.from_numpy(np.array(1, dtype=np.int32)) increment_until_threshold(x, 12, keep_going) - assert _cuda_graph_used() - assert _cuda_graph_cache_size() == 1 + if _on_cuda(): + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 assert keep_going.to_numpy() == 0 np.testing.assert_array_equal(x.to_numpy(), np.full(N, 12, dtype=np.int32)) -@test_utils.test(arch=[qd.cuda]) +@test_utils.test() def test_graph_do_while_multiple_loops(): """Test graph_do_while with multiple top-level loops in the kernel body.""" - if not _on_hopper(): - pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + _xfail_if_cuda_without_hopper() N = 32 @qd.kernel(cuda_graph=True) @@ -132,8 +138,9 @@ def multi_loop( counter.from_numpy(np.array(10, dtype=np.int32)) multi_loop(x, y, counter) - assert _cuda_graph_used() - assert _cuda_graph_cache_size() == 1 + if _on_cuda(): + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 assert counter.to_numpy() == 0 np.testing.assert_allclose(x.to_numpy(), np.full(N, 10.0)) @@ -144,8 +151,9 @@ def multi_loop( counter.from_numpy(np.array(5, dtype=np.int32)) multi_loop(x, y, counter) - assert _cuda_graph_used() - assert _cuda_graph_cache_size() == 1 + if _on_cuda(): + assert _cuda_graph_used() + assert _cuda_graph_cache_size() == 1 assert counter.to_numpy() == 0 np.testing.assert_allclose(x.to_numpy(), np.full(N, 5.0)) @@ -155,8 +163,7 @@ def multi_loop( @test_utils.test(arch=[qd.cuda]) def test_graph_do_while_changed_condition_ndarray_raises(): """Passing a different ndarray for the condition parameter should raise.""" - if not _on_hopper(): - pytest.xfail("graph_do_while requires SM 9.0+ (Hopper)") + _xfail_if_cuda_without_hopper() @qd.kernel(cuda_graph=True) def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)):