From ab15b1b82c1cc2ef2d0029db9faf913ce4ef2145 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 16:40:10 -0700 Subject: [PATCH 1/2] Add CUDA stream and event API for concurrent kernel execution Introduces qd.create_stream() and qd.create_event() for launching kernels on separate CUDA streams with event-based synchronization. The qd_stream kwarg on kernel calls routes the launch to a specific stream. Non-CUDA backends return no-op handles (0). Routes kernel launcher memory ops through the active stream. --- python/quadrants/lang/__init__.py | 2 + python/quadrants/lang/kernel.py | 16 +- python/quadrants/lang/stream.py | 96 +++++++++ quadrants/program/program.cpp | 93 +++++++++ quadrants/program/program.h | 10 + quadrants/python/export_lang.cpp | 11 +- .../rhi/cuda/cuda_driver_functions.inc.h | 2 + quadrants/runtime/cuda/kernel_launcher.cpp | 20 +- tests/python/test_api.py | 4 + tests/python/test_cache.py | 8 +- tests/python/test_streams.py | 197 ++++++++++++++++++ 11 files changed, 443 insertions(+), 16 deletions(-) create mode 100644 python/quadrants/lang/stream.py create mode 100644 tests/python/test_streams.py diff --git a/python/quadrants/lang/__init__.py b/python/quadrants/lang/__init__.py index dc4fb2cf19..43a4b44b89 100644 --- a/python/quadrants/lang/__init__.py +++ b/python/quadrants/lang/__init__.py @@ -15,6 +15,7 @@ from quadrants.lang.runtime_ops import * from quadrants.lang.snode import * from quadrants.lang.source_builder import * +from quadrants.lang.stream import * from quadrants.lang.struct import * from quadrants.types.enums import DeviceCapability, Format, Layout # noqa: F401 @@ -45,6 +46,7 @@ "shell", "snode", "source_builder", + "stream", "struct", "util", ] diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index af6dbdacb5..4b1578ac4b 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -424,7 +424,9 @@ def materialize(self, key: "CompiledKernelKeyType | None", 
py_args: tuple[Any, . ] runtime._current_global_context = None - def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args) -> Any: + def launch_kernel( + self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args, qd_stream=None + ) -> Any: assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided" callbacks: list[Callable[[], None]] = [] @@ -503,7 +505,14 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled ) self.src_ll_cache_observations.cache_stored = True self._last_compiled_kernel_data = compiled_kernel_data - prog.launch_kernel(compiled_kernel_data, launch_ctx) + stream_handle = qd_stream.handle if qd_stream is not None else 0 + if stream_handle: + prog.set_current_cuda_stream(stream_handle) + try: + prog.launch_kernel(compiled_kernel_data, launch_ctx) + finally: + if stream_handle: + prog.set_current_cuda_stream(0) except Exception as e: e = handle_exception_from_cpp(e) if impl.get_runtime().print_full_traceback: @@ -547,6 +556,7 @@ def ensure_compiled(self, *py_args: tuple[Any, ...]) -> tuple[Callable, int, Aut # Thus this part needs to be fast. (i.e. 
< 3us on a 4 GHz x64 CPU) @_shell_pop_print def __call__(self, *py_args, **kwargs) -> Any: + qd_stream = kwargs.pop("qd_stream", None) if impl.get_runtime()._arch == _ARCH_PYTHON: return self.func(*py_args, **kwargs) config = impl.current_cfg() @@ -578,7 +588,7 @@ def __call__(self, *py_args, **kwargs) -> Any: kernel_cpp = self.materialized_kernels[key] compiled_kernel_data = self.compiled_kernel_data_by_key.get(key, None) self.launch_observations.found_kernel_in_materialize_cache = compiled_kernel_data is not None - ret = self.launch_kernel(key, kernel_cpp, compiled_kernel_data, *py_args) + ret = self.launch_kernel(key, kernel_cpp, compiled_kernel_data, *py_args, qd_stream=qd_stream) if compiled_kernel_data is None: assert self._last_compiled_kernel_data is not None self.compiled_kernel_data_by_key[key] = self._last_compiled_kernel_data diff --git a/python/quadrants/lang/stream.py b/python/quadrants/lang/stream.py new file mode 100644 index 0000000000..8530982455 --- /dev/null +++ b/python/quadrants/lang/stream.py @@ -0,0 +1,96 @@ +from quadrants.lang import impl + + +class Stream: + """Wraps a backend-specific GPU stream for concurrent kernel execution. + + On backends without native streams (e.g. CPU), this is a no-op object. + """ + + def __init__(self, handle: int): + self._handle = handle + + @property + def handle(self) -> int: + return self._handle + + def synchronize(self): + """Block until all operations on this stream complete.""" + prog = impl.get_runtime().prog + prog.stream_synchronize(self._handle) + + def destroy(self): + """Explicitly destroy the stream. Safe to call multiple times.""" + if self._handle != 0: + prog = impl.get_runtime().prog + prog.stream_destroy(self._handle) + self._handle = 0 + + def __del__(self): + if self._handle != 0: + try: + self.destroy() + except Exception: + pass + + +class Event: + """Wraps a backend-specific GPU event for stream synchronization. + + On backends without native events (e.g. CPU), this is a no-op object. 
+ """ + + def __init__(self, handle: int): + self._handle = handle + + @property + def handle(self) -> int: + return self._handle + + def record(self, stream: Stream | None = None): + """Record this event on a stream. None means the default stream.""" + prog = impl.get_runtime().prog + stream_handle = stream.handle if stream is not None else 0 + prog.event_record(self._handle, stream_handle) + + def wait(self, qd_stream: Stream | None = None): + """Make a stream wait for this event. None means the default stream.""" + prog = impl.get_runtime().prog + stream_handle = qd_stream.handle if qd_stream is not None else 0 + prog.stream_wait_event(stream_handle, self._handle) + + def synchronize(self): + """Block the host until this event has been reached.""" + prog = impl.get_runtime().prog + prog.event_synchronize(self._handle) + + def destroy(self): + """Explicitly destroy the event. Safe to call multiple times.""" + if self._handle != 0: + prog = impl.get_runtime().prog + prog.event_destroy(self._handle) + self._handle = 0 + + def __del__(self): + if self._handle != 0: + try: + self.destroy() + except Exception: + pass + + +def create_stream() -> Stream: + """Create a new GPU stream for concurrent kernel execution.""" + prog = impl.get_runtime().prog + handle = prog.stream_create() + return Stream(handle) + + +def create_event() -> Event: + """Create a new GPU event for stream synchronization.""" + prog = impl.get_runtime().prog + handle = prog.event_create() + return Event(handle) + + +__all__ = ["Stream", "Event", "create_stream", "create_event"] diff --git a/quadrants/program/program.cpp b/quadrants/program/program.cpp index 7f5dfef2d8..9b2ff0886b 100644 --- a/quadrants/program/program.cpp +++ b/quadrants/program/program.cpp @@ -20,6 +20,11 @@ #include "quadrants/codegen/llvm/struct_llvm.h" #endif +#ifdef QD_WITH_CUDA +#include "quadrants/rhi/cuda/cuda_driver.h" +#include "quadrants/rhi/cuda/cuda_context.h" +#endif + #ifdef QD_WITH_VULKAN #include 
"quadrants/runtime/program_impls/vulkan/vulkan_program.h" #include "quadrants/rhi/vulkan/vulkan_loader.h" @@ -481,4 +486,92 @@ void Program::enqueue_compute_op_lambda( program_impl_->enqueue_compute_op_lambda(op, image_refs); } +uint64 Program::stream_create() { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda) { + void *stream = nullptr; + CUDADriver::get_instance().stream_create(&stream, 0 /*flags*/); + return reinterpret_cast(stream); + } +#endif + return 0; +} + +void Program::stream_destroy(uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda && stream_handle != 0) { + CUDADriver::get_instance().stream_destroy( + reinterpret_cast(stream_handle)); + } +#endif +} + +void Program::stream_synchronize(uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda) { + CUDADriver::get_instance().stream_synchronize( + reinterpret_cast(stream_handle)); + } +#endif +} + +void Program::set_current_cuda_stream(uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda) { + CUDAContext::get_instance().set_stream( + reinterpret_cast(stream_handle)); + } +#endif +} + +uint64 Program::event_create() { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda) { + void *event = nullptr; + CUDADriver::get_instance().event_create(&event, + 0x02 /*CU_EVENT_DISABLE_TIMING*/); + return reinterpret_cast(event); + } +#endif + return 0; +} + +void Program::event_destroy(uint64 event_handle) { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda && event_handle != 0) { + CUDADriver::get_instance().event_destroy( + reinterpret_cast(event_handle)); + } +#endif +} + +void Program::event_record(uint64 event_handle, uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda && event_handle != 0) { + CUDADriver::get_instance().event_record( + reinterpret_cast(event_handle), + reinterpret_cast(stream_handle)); + } +#endif +} + +void 
Program::event_synchronize(uint64 event_handle) { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda && event_handle != 0) { + CUDADriver::get_instance().event_synchronize( + reinterpret_cast<void *>(event_handle)); + } +#endif +} + +void Program::stream_wait_event(uint64 stream_handle, uint64 event_handle) { +#ifdef QD_WITH_CUDA + if (compile_config().arch == Arch::cuda && event_handle != 0) { + CUDADriver::get_instance().stream_wait_event( + reinterpret_cast<void *>(stream_handle), + reinterpret_cast<void *>(event_handle), 0 /*flags*/); + } +#endif +} + } // namespace quadrants::lang diff --git a/quadrants/program/program.h b/quadrants/program/program.h index 1fa2c2ac57..9568c371c8 100644 --- a/quadrants/program/program.h +++ b/quadrants/program/program.h @@ -300,6 +300,16 @@ class QD_DLL_EXPORT Program { return ndarrays_.size(); } + uint64 stream_create(); + void stream_destroy(uint64 stream_handle); + void stream_synchronize(uint64 stream_handle); + void set_current_cuda_stream(uint64 stream_handle); + uint64 event_create(); + void event_destroy(uint64 event_handle); + void event_record(uint64 event_handle, uint64 stream_handle); + void event_synchronize(uint64 event_handle); + void stream_wait_event(uint64 stream_handle, uint64 event_handle); + // TODO(zhanlue): Move these members and corresponding interfaces to // ProgramImpl Ideally, Program should serve as a pure interface class and all // the implementations should fall inside ProgramImpl diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index b3d23c0037..2f5da8b1b4 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -495,7 +495,16 @@ void export_lang(py::module &m) { .def("compile_kernel", &Program::compile_kernel, py::return_value_policy::reference) .def("launch_kernel", &Program::launch_kernel) - .def("get_device_caps", &Program::get_device_caps); + .def("get_device_caps", &Program::get_device_caps) + .def("stream_create", &Program::stream_create) + 
.def("stream_destroy", &Program::stream_destroy) + .def("stream_synchronize", &Program::stream_synchronize) + .def("set_current_cuda_stream", &Program::set_current_cuda_stream) + .def("event_create", &Program::event_create) + .def("event_destroy", &Program::event_destroy) + .def("event_record", &Program::event_record) + .def("event_synchronize", &Program::event_synchronize) + .def("stream_wait_event", &Program::stream_wait_event); py::class_(m, "CompileResult") .def_property_readonly( diff --git a/quadrants/rhi/cuda/cuda_driver_functions.inc.h b/quadrants/rhi/cuda/cuda_driver_functions.inc.h index 25b3c7958e..a9690ca10b 100644 --- a/quadrants/rhi/cuda/cuda_driver_functions.inc.h +++ b/quadrants/rhi/cuda/cuda_driver_functions.inc.h @@ -20,6 +20,7 @@ PER_CUDA_FUNCTION(context_set_limit, cuCtxSetLimit, int, std::size_t); // Stream management PER_CUDA_FUNCTION(stream_create, cuStreamCreate, void **, uint32); +PER_CUDA_FUNCTION(stream_destroy, cuStreamDestroy_v2, void *); // Memory management PER_CUDA_FUNCTION(memcpy_host_to_device, cuMemcpyHtoD_v2, void *, void *, std::size_t); @@ -52,6 +53,7 @@ PER_CUDA_FUNCTION(kernel_set_attribute, cuFuncSetAttribute, void *, CUfunction_a // Stream management PER_CUDA_FUNCTION(stream_synchronize, cuStreamSynchronize, void *); +PER_CUDA_FUNCTION(stream_wait_event, cuStreamWaitEvent, void *, void *, uint32); // Event management PER_CUDA_FUNCTION(event_create, cuEventCreate, void **, uint32) diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index 5eae5e747d..13845d5a9b 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -1,5 +1,6 @@ #include "quadrants/runtime/cuda/kernel_launcher.h" #include "quadrants/rhi/cuda/cuda_context.h" +#include "quadrants/rhi/cuda/cuda_driver.h" namespace quadrants::lang { namespace cuda { @@ -43,10 +44,12 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, // kernels. 
std::unordered_map device_ptrs; + auto *active_stream = CUDAContext::get_instance().get_stream(); + char *device_result_buffer{nullptr}; CUDADriver::get_instance().malloc_async( (void **)&device_result_buffer, - std::max(ctx.result_buffer_size, sizeof(uint64)), nullptr); + std::max(ctx.result_buffer_size, sizeof(uint64)), active_stream); ctx.get_context().runtime = executor->get_llvm_runtime(); for (int i = 0; i < (int)parameters.size(); i++) { @@ -120,7 +123,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, } } if (transfers.size() > 0) { - CUDADriver::get_instance().stream_synchronize(nullptr); + CUDADriver::get_instance().stream_synchronize(active_stream); } char *host_result_buffer = (char *)ctx.get_context().result_buffer; if (ctx.result_buffer_size > 0) { @@ -129,10 +132,10 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, char *device_arg_buffer = nullptr; if (ctx.arg_buffer_size > 0) { CUDADriver::get_instance().malloc_async((void **)&device_arg_buffer, - ctx.arg_buffer_size, nullptr); + ctx.arg_buffer_size, active_stream); CUDADriver::get_instance().memcpy_host_to_device_async( device_arg_buffer, ctx.get_context().arg_buffer, ctx.arg_buffer_size, - nullptr); + active_stream); ctx.get_context().arg_buffer = device_arg_buffer; } @@ -144,17 +147,18 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, {}); } if (ctx.arg_buffer_size > 0) { - CUDADriver::get_instance().mem_free_async(device_arg_buffer, nullptr); + CUDADriver::get_instance().mem_free_async(device_arg_buffer, active_stream); } if (ctx.result_buffer_size > 0) { CUDADriver::get_instance().memcpy_device_to_host_async( host_result_buffer, device_result_buffer, ctx.result_buffer_size, - nullptr); + active_stream); } - CUDADriver::get_instance().mem_free_async(device_result_buffer, nullptr); + CUDADriver::get_instance().mem_free_async(device_result_buffer, + active_stream); // copy data back to host if (transfers.size() > 0) { - 
CUDADriver::get_instance().stream_synchronize(nullptr); + CUDADriver::get_instance().stream_synchronize(active_stream); for (auto itr = transfers.begin(); itr != transfers.end(); itr++) { auto &idx = itr->first; CUDADriver::get_instance().memcpy_device_to_host( diff --git a/tests/python/test_api.py b/tests/python/test_api.py index cf12abc393..002014c960 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -59,6 +59,7 @@ def _get_expected_matrix_apis(): "DEBUG", "DeviceCapability", "ERROR", + "Event", "Field", "FieldsBuilder", "Format", @@ -73,6 +74,7 @@ def _get_expected_matrix_apis(): "SNode", "ScalarField", "ScalarNdarray", + "Stream", "Struct", "StructField", "TRACE", @@ -117,6 +119,8 @@ def _get_expected_matrix_apis(): "clock_freq_hz", "cos", "cpu", + "create_event", + "create_stream", "cuda", "data_oriented", "dataclass", diff --git a/tests/python/test_cache.py b/tests/python/test_cache.py index c3821e44c5..e31daf61e7 100644 --- a/tests/python/test_cache.py +++ b/tests/python/test_cache.py @@ -216,11 +216,11 @@ def test_fastcache(tmp_path: pathlib.Path, monkeypatch): qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) is_valid = False - def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args): + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): nonlocal is_valid is_valid = True assert compiled_kernel_data is None - return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) @@ -242,11 +242,11 @@ def fun(value: qd.types.ndarray(), offset: qd.template()): qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) is_valid = False - def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args): + def launch_kernel(self, key, t_kernel, 
compiled_kernel_data, *args, qd_stream=None): nonlocal is_valid is_valid = True assert compiled_kernel_data is not None - return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) diff --git a/tests/python/test_streams.py b/tests/python/test_streams.py new file mode 100644 index 0000000000..fabc217e96 --- /dev/null +++ b/tests/python/test_streams.py @@ -0,0 +1,197 @@ +"""Tests for GPU stream and event support.""" + +import numpy as np + +import quadrants as qd +from quadrants.lang.stream import Event, Stream + +from tests import test_utils + + +@test_utils.test(arch=[qd.cuda]) +def test_create_and_destroy_stream(): + s = qd.create_stream() + assert isinstance(s, Stream) + assert s.handle != 0 + s.destroy() + assert s.handle == 0 + + +@test_utils.test(arch=[qd.cuda]) +def test_create_and_destroy_event(): + e = qd.create_event() + assert isinstance(e, Event) + assert e.handle != 0 + e.destroy() + assert e.handle == 0 + + +@test_utils.test() +def test_kernel_on_stream(): + N = 1024 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 42.0 + + s = qd.create_stream() + fill(qd_stream=s) + s.synchronize() + assert np.allclose(x.to_numpy(), 42.0) + s.destroy() + + +@test_utils.test() +def test_two_streams(): + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_a(): + for i in range(N): + a[i] = 1.0 + + @qd.kernel + def fill_b(): + for i in range(N): + b[i] = 2.0 + + s1 = qd.create_stream() + s2 = qd.create_stream() + fill_a(qd_stream=s1) + fill_b(qd_stream=s2) + s1.synchronize() + s2.synchronize() + assert np.allclose(a.to_numpy(), 1.0) + assert np.allclose(b.to_numpy(), 2.0) + s1.destroy() + s2.destroy() + + +@test_utils.test() +def test_event_synchronization(): + N = 1024 + x = 
qd.field(qd.f32, shape=(N,)) + y = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_x(): + for i in range(N): + x[i] = 10.0 + + @qd.kernel + def copy_x_to_y(): + for i in range(N): + y[i] = x[i] + + s1 = qd.create_stream() + fill_x(qd_stream=s1) + + e = qd.create_event() + e.record(s1) + + # Default stream waits for s1 to finish fill_x + e.wait() + copy_x_to_y() + qd.sync() + + assert np.allclose(y.to_numpy(), 10.0) + + e.destroy() + s1.destroy() + + +@test_utils.test() +def test_event_wait_on_stream(): + N = 1024 + x = qd.field(qd.f32, shape=(N,)) + y = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_x(): + for i in range(N): + x[i] = 5.0 + + @qd.kernel + def copy_x_to_y(): + for i in range(N): + y[i] = x[i] + + s1 = qd.create_stream() + s2 = qd.create_stream() + + fill_x(qd_stream=s1) + + e = qd.create_event() + e.record(s1) + + # s2 waits for s1's event before running + e.wait(qd_stream=s2) + copy_x_to_y(qd_stream=s2) + s2.synchronize() + + assert np.allclose(y.to_numpy(), 5.0) + + e.destroy() + s1.destroy() + s2.destroy() + + +@test_utils.test() +def test_default_stream_kernel(): + N = 1024 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 7.0 + + fill() + qd.sync() + assert np.allclose(x.to_numpy(), 7.0) + + +@test_utils.test(arch=[qd.cpu]) +def test_stream_noop_on_cpu(): + """Streams should be no-ops on CPU without errors.""" + N = 64 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 3.0 + + s = qd.create_stream() + assert s.handle == 0 + fill(qd_stream=s) + qd.sync() + assert np.allclose(x.to_numpy(), 3.0) + + e = qd.create_event() + assert e.handle == 0 + e.record(s) + e.wait() + s.destroy() + e.destroy() + + +@test_utils.test() +def test_stream_with_ndarray(): + N = 1024 + + @qd.kernel + def fill(arr: qd.types.ndarray(dtype=qd.f32, ndim=1)): + for i in range(N): + arr[i] = 99.0 + + arr = qd.ndarray(qd.f32, shape=(N,)) + s = qd.create_stream() + fill(arr, 
qd_stream=s) + s.synchronize() + assert np.allclose(arr.to_numpy(), 99.0) + s.destroy() From b856b33247dfbb55ca5f781e788fc50d5e32c9e9 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 17:25:18 -0700 Subject: [PATCH 2/2] Address review feedback for CUDA streams PR - Make CUDAContext::stream_ thread_local for thread-safety - Convert sync memcpy_host_to_device to async on active_stream - Use weakref in Stream/Event __del__ to safely handle interpreter shutdown - Add __enter__/__exit__ context manager support for Stream and Event - Use consistent qd_stream parameter naming in Event.record and Event.wait - Add handle==0 guard to stream_synchronize --- python/quadrants/lang/stream.py | 60 ++++++++++++++++------ quadrants/program/program.cpp | 2 +- quadrants/rhi/cuda/cuda_context.cpp | 6 +-- quadrants/rhi/cuda/cuda_context.h | 2 +- quadrants/runtime/cuda/kernel_launcher.cpp | 10 ++-- 5 files changed, 55 insertions(+), 25 deletions(-) diff --git a/python/quadrants/lang/stream.py b/python/quadrants/lang/stream.py index 8530982455..8f6cfab3d6 100644 --- a/python/quadrants/lang/stream.py +++ b/python/quadrants/lang/stream.py @@ -1,14 +1,22 @@ +import weakref + from quadrants.lang import impl +def _get_prog_weakref(): + return weakref.ref(impl.get_runtime().prog) + + class Stream: """Wraps a backend-specific GPU stream for concurrent kernel execution. On backends without native streams (e.g. CPU), this is a no-op object. + Call destroy() explicitly or use as a context manager to ensure cleanup. 
""" - def __init__(self, handle: int): + def __init__(self, handle: int, prog_ref: weakref.ref | None = None): self._handle = handle + self._prog_ref = prog_ref @property def handle(self) -> int: @@ -27,30 +35,41 @@ def destroy(self): self._handle = 0 def __del__(self): - if self._handle != 0: - try: - self.destroy() - except Exception: - pass + if self._handle != 0 and self._prog_ref is not None: + prog = self._prog_ref() + if prog is not None: + try: + prog.stream_destroy(self._handle) + self._handle = 0 + except Exception: + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + self.destroy() class Event: """Wraps a backend-specific GPU event for stream synchronization. On backends without native events (e.g. CPU), this is a no-op object. + Call destroy() explicitly or use as a context manager to ensure cleanup. """ - def __init__(self, handle: int): + def __init__(self, handle: int, prog_ref: weakref.ref | None = None): self._handle = handle + self._prog_ref = prog_ref @property def handle(self) -> int: return self._handle - def record(self, stream: Stream | None = None): + def record(self, qd_stream: Stream | None = None): """Record this event on a stream. 
None means the default stream.""" prog = impl.get_runtime().prog - stream_handle = stream.handle if stream is not None else 0 + stream_handle = qd_stream.handle if qd_stream is not None else 0 prog.event_record(self._handle, stream_handle) def wait(self, qd_stream: Stream | None = None): @@ -72,25 +91,34 @@ def destroy(self): self._handle = 0 def __del__(self): - if self._handle != 0: - try: - self.destroy() - except Exception: - pass + if self._handle != 0 and self._prog_ref is not None: + prog = self._prog_ref() + if prog is not None: + try: + prog.event_destroy(self._handle) + self._handle = 0 + except Exception: + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + self.destroy() def create_stream() -> Stream: """Create a new GPU stream for concurrent kernel execution.""" prog = impl.get_runtime().prog handle = prog.stream_create() - return Stream(handle) + return Stream(handle, _get_prog_weakref()) def create_event() -> Event: """Create a new GPU event for stream synchronization.""" prog = impl.get_runtime().prog handle = prog.event_create() - return Event(handle) + return Event(handle, _get_prog_weakref()) __all__ = ["Stream", "Event", "create_stream", "create_event"] diff --git a/quadrants/program/program.cpp b/quadrants/program/program.cpp index 9b2ff0886b..be152d02da 100644 --- a/quadrants/program/program.cpp +++ b/quadrants/program/program.cpp @@ -508,7 +508,7 @@ void Program::stream_destroy(uint64 stream_handle) { void Program::stream_synchronize(uint64 stream_handle) { #ifdef QD_WITH_CUDA - if (compile_config().arch == Arch::cuda) { + if (compile_config().arch == Arch::cuda && stream_handle != 0) { CUDADriver::get_instance().stream_synchronize( reinterpret_cast(stream_handle)); } diff --git a/quadrants/rhi/cuda/cuda_context.cpp b/quadrants/rhi/cuda/cuda_context.cpp index 89c16135a2..23399649a9 100644 --- a/quadrants/rhi/cuda/cuda_context.cpp +++ b/quadrants/rhi/cuda/cuda_context.cpp @@ -11,10 +11,10 @@ namespace quadrants::lang 
{ +thread_local void *CUDAContext::stream_ = nullptr; + CUDAContext::CUDAContext() - : profiler_(nullptr), - driver_(CUDADriver::get_instance_without_context()), - stream_(nullptr) { + : profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()) { // CUDA initialization dev_count_ = 0; driver_.init(0); diff --git a/quadrants/rhi/cuda/cuda_context.h b/quadrants/rhi/cuda/cuda_context.h index c57baa3d92..ba891644a7 100644 --- a/quadrants/rhi/cuda/cuda_context.h +++ b/quadrants/rhi/cuda/cuda_context.h @@ -30,7 +30,7 @@ class CUDAContext { int max_shared_memory_bytes_; bool debug_; bool supports_mem_pool_; - void *stream_; + static thread_local void *stream_; public: CUDAContext(); diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index 13845d5a9b..9bbf75044e 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -85,8 +85,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, executor->get_device_alloc_info_ptr(devalloc); transfers[data_ptr_idx] = {data_ptr, devalloc}; - CUDADriver::get_instance().memcpy_host_to_device( - (void *)device_ptrs[data_ptr_idx], data_ptr, arr_sz); + CUDADriver::get_instance().memcpy_host_to_device_async( + (void *)device_ptrs[data_ptr_idx], data_ptr, arr_sz, + active_stream); if (grad_ptr != nullptr) { DeviceAllocation grad_devalloc = executor->allocate_memory_on_device( @@ -95,8 +96,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, executor->get_device_alloc_info_ptr(grad_devalloc); transfers[grad_ptr_idx] = {grad_ptr, grad_devalloc}; - CUDADriver::get_instance().memcpy_host_to_device( - (void *)device_ptrs[grad_ptr_idx], grad_ptr, arr_sz); + CUDADriver::get_instance().memcpy_host_to_device_async( + (void *)device_ptrs[grad_ptr_idx], grad_ptr, arr_sz, + active_stream); } else { device_ptrs[grad_ptr_idx] = nullptr; }