Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/quadrants/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from quadrants.lang.runtime_ops import *
from quadrants.lang.snode import *
from quadrants.lang.source_builder import *
from quadrants.lang.stream import *
from quadrants.lang.struct import *
from quadrants.types.enums import DeviceCapability, Format, Layout # noqa: F401

Expand Down Expand Up @@ -45,6 +46,7 @@
"shell",
"snode",
"source_builder",
"stream",
"struct",
"util",
]
Expand Down
16 changes: 13 additions & 3 deletions python/quadrants/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,9 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, .
]
runtime._current_global_context = None

def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args) -> Any:
def launch_kernel(
self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args, qd_stream=None
) -> Any:
assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided"

callbacks: list[Callable[[], None]] = []
Expand Down Expand Up @@ -503,7 +505,14 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled
)
self.src_ll_cache_observations.cache_stored = True
self._last_compiled_kernel_data = compiled_kernel_data
prog.launch_kernel(compiled_kernel_data, launch_ctx)
stream_handle = qd_stream.handle if qd_stream is not None else 0
if stream_handle:
prog.set_current_cuda_stream(stream_handle)
try:
prog.launch_kernel(compiled_kernel_data, launch_ctx)
finally:
if stream_handle:
prog.set_current_cuda_stream(0)
except Exception as e:
e = handle_exception_from_cpp(e)
if impl.get_runtime().print_full_traceback:
Expand Down Expand Up @@ -547,6 +556,7 @@ def ensure_compiled(self, *py_args: tuple[Any, ...]) -> tuple[Callable, int, Aut
# Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
@_shell_pop_print
def __call__(self, *py_args, **kwargs) -> Any:
qd_stream = kwargs.pop("qd_stream", None)
if impl.get_runtime()._arch == _ARCH_PYTHON:
return self.func(*py_args, **kwargs)
config = impl.current_cfg()
Expand Down Expand Up @@ -578,7 +588,7 @@ def __call__(self, *py_args, **kwargs) -> Any:
kernel_cpp = self.materialized_kernels[key]
compiled_kernel_data = self.compiled_kernel_data_by_key.get(key, None)
self.launch_observations.found_kernel_in_materialize_cache = compiled_kernel_data is not None
ret = self.launch_kernel(key, kernel_cpp, compiled_kernel_data, *py_args)
ret = self.launch_kernel(key, kernel_cpp, compiled_kernel_data, *py_args, qd_stream=qd_stream)
if compiled_kernel_data is None:
assert self._last_compiled_kernel_data is not None
self.compiled_kernel_data_by_key[key] = self._last_compiled_kernel_data
Expand Down
124 changes: 124 additions & 0 deletions python/quadrants/lang/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import weakref

from quadrants.lang import impl


def _get_prog_weakref():
    """Return a weak reference to the current runtime's program object."""
    runtime = impl.get_runtime()
    return weakref.ref(runtime.prog)


class Stream:
    """Wraps a backend-specific GPU stream for concurrent kernel execution.

    On backends without native streams (e.g. CPU), the handle is 0 and all
    operations are no-ops. Call destroy() explicitly or use the object as a
    context manager to ensure the underlying backend stream is released.
    """

    def __init__(self, handle: int, prog_ref: "weakref.ref | None" = None):
        # Opaque backend stream handle; 0 means "no native stream".
        self._handle = handle
        # Weak reference to the owning program. Used only by __del__ so the
        # stream can still be released if the user forgot to call destroy(),
        # without keeping the program alive.
        self._prog_ref = prog_ref

    @property
    def handle(self) -> int:
        """Opaque backend handle of this stream (0 once destroyed)."""
        return self._handle

    def synchronize(self):
        """Block until all operations on this stream complete."""
        prog = impl.get_runtime().prog
        prog.stream_synchronize(self._handle)

    def destroy(self):
        """Explicitly destroy the stream. Safe to call multiple times."""
        if self._handle != 0:
            prog = impl.get_runtime().prog
            prog.stream_destroy(self._handle)
            self._handle = 0

    def __del__(self):
        # Best-effort cleanup: only attempt it when the program that created
        # the stream is still alive (i.e. the weakref has not been cleared).
        if self._handle != 0 and self._prog_ref is not None:
            prog = self._prog_ref()
            if prog is not None:
                try:
                    prog.stream_destroy(self._handle)
                    self._handle = 0
                except Exception:
                    # Never raise from a finalizer — during interpreter
                    # shutdown the backend may already be torn down.
                    pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.destroy()


class Event:
    """Wraps a backend-specific GPU event for stream synchronization.

    An event is a synchronization marker: record() places it at the current
    position of one stream, and wait() makes another stream (or the default
    stream) pause until that marker has been reached. synchronize() blocks
    the host instead.

    On backends without native events (e.g. CPU), the handle is 0 and all
    operations are no-ops. Call destroy() explicitly or use the object as a
    context manager to ensure the underlying backend event is released.
    """

    def __init__(self, handle: int, prog_ref: "weakref.ref | None" = None):
        # Opaque backend event handle; 0 means "no native event".
        self._handle = handle
        # Weak reference to the owning program, used only by __del__ for
        # best-effort cleanup without keeping the program alive.
        self._prog_ref = prog_ref

    @property
    def handle(self) -> int:
        """Opaque backend handle of this event (0 once destroyed)."""
        return self._handle

    def record(self, qd_stream: "Stream | None" = None):
        """Record this event on a stream. None means the default stream."""
        prog = impl.get_runtime().prog
        stream_handle = qd_stream.handle if qd_stream is not None else 0
        prog.event_record(self._handle, stream_handle)

    def wait(self, qd_stream: "Stream | None" = None):
        """Make a stream wait for this event. None means the default stream."""
        prog = impl.get_runtime().prog
        stream_handle = qd_stream.handle if qd_stream is not None else 0
        prog.stream_wait_event(stream_handle, self._handle)

    def synchronize(self):
        """Block the host until this event has been reached."""
        prog = impl.get_runtime().prog
        prog.event_synchronize(self._handle)

    def destroy(self):
        """Explicitly destroy the event. Safe to call multiple times."""
        if self._handle != 0:
            prog = impl.get_runtime().prog
            prog.event_destroy(self._handle)
            self._handle = 0

    def __del__(self):
        # Best-effort cleanup: only attempt it when the program that created
        # the event is still alive (i.e. the weakref has not been cleared).
        if self._handle != 0 and self._prog_ref is not None:
            prog = self._prog_ref()
            if prog is not None:
                try:
                    prog.event_destroy(self._handle)
                    self._handle = 0
                except Exception:
                    # Never raise from a finalizer — during interpreter
                    # shutdown the backend may already be torn down.
                    pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.destroy()


def create_stream() -> Stream:
    """Create a new GPU stream for concurrent kernel execution."""
    runtime_prog = impl.get_runtime().prog
    return Stream(runtime_prog.stream_create(), _get_prog_weakref())


def create_event() -> Event:
    """Create a new GPU event for stream synchronization."""
    runtime_prog = impl.get_runtime().prog
    return Event(runtime_prog.event_create(), _get_prog_weakref())


__all__ = ["Stream", "Event", "create_stream", "create_event"]
93 changes: 93 additions & 0 deletions quadrants/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
#include "quadrants/codegen/llvm/struct_llvm.h"
#endif

#ifdef QD_WITH_CUDA
#include "quadrants/rhi/cuda/cuda_driver.h"
#include "quadrants/rhi/cuda/cuda_context.h"
#endif

#ifdef QD_WITH_VULKAN
#include "quadrants/runtime/program_impls/vulkan/vulkan_program.h"
#include "quadrants/rhi/vulkan/vulkan_loader.h"
Expand Down Expand Up @@ -481,4 +486,92 @@ void Program::enqueue_compute_op_lambda(
program_impl_->enqueue_compute_op_lambda(op, image_refs);
}

uint64 Program::stream_create() {
  // Create a backend stream for asynchronous kernel launches. Returns an
  // opaque handle, or 0 when the active arch has no native stream support.
#ifdef QD_WITH_CUDA
  if (compile_config().arch == Arch::cuda) {
    void *raw_stream = nullptr;
    CUDADriver::get_instance().stream_create(&raw_stream, /*flags=*/0);
    return reinterpret_cast<uint64>(raw_stream);
  }
#endif
  return 0;
}

void Program::stream_destroy(uint64 stream_handle) {
  // Release a stream previously obtained from stream_create(). A zero
  // handle denotes "no stream" and is ignored.
#ifdef QD_WITH_CUDA
  if (stream_handle == 0 || compile_config().arch != Arch::cuda) {
    return;
  }
  CUDADriver::get_instance().stream_destroy(
      reinterpret_cast<void *>(stream_handle));
#endif
}

void Program::stream_synchronize(uint64 stream_handle) {
  // Block the host until all work queued on the given stream has finished.
  // A zero handle (no native stream) is a no-op.
#ifdef QD_WITH_CUDA
  if (stream_handle == 0 || compile_config().arch != Arch::cuda) {
    return;
  }
  CUDADriver::get_instance().stream_synchronize(
      reinterpret_cast<void *>(stream_handle));
#endif
}

void Program::set_current_cuda_stream(uint64 stream_handle) {
  // Redirect subsequent kernel launches to the given CUDA stream. A zero
  // handle is passed through and selects the default stream.
#ifdef QD_WITH_CUDA
  if (compile_config().arch != Arch::cuda) {
    return;
  }
  CUDAContext::get_instance().set_stream(
      reinterpret_cast<void *>(stream_handle));
#endif
}

uint64 Program::event_create() {
  // Create a backend synchronization event. Returns an opaque handle, or 0
  // when the active arch has no native event support.
#ifdef QD_WITH_CUDA
  if (compile_config().arch == Arch::cuda) {
    void *raw_event = nullptr;
    // 0x02 == CU_EVENT_DISABLE_TIMING: timing is not needed for pure
    // synchronization, and disabling it is documented to make
    // record/wait cheaper.
    CUDADriver::get_instance().event_create(&raw_event, 0x02);
    return reinterpret_cast<uint64>(raw_event);
  }
#endif
  return 0;
}

void Program::event_destroy(uint64 event_handle) {
  // Release an event previously obtained from event_create(). A zero
  // handle denotes "no event" and is ignored.
#ifdef QD_WITH_CUDA
  if (event_handle == 0 || compile_config().arch != Arch::cuda) {
    return;
  }
  CUDADriver::get_instance().event_destroy(
      reinterpret_cast<void *>(event_handle));
#endif
}

void Program::event_record(uint64 event_handle, uint64 stream_handle) {
  // Record the event at the current position of the given stream
  // (stream_handle == 0 selects the default stream). A zero event handle
  // is a no-op.
#ifdef QD_WITH_CUDA
  if (event_handle == 0 || compile_config().arch != Arch::cuda) {
    return;
  }
  CUDADriver::get_instance().event_record(
      reinterpret_cast<void *>(event_handle),
      reinterpret_cast<void *>(stream_handle));
#endif
}

void Program::event_synchronize(uint64 event_handle) {
  // Block the host until the event has been reached. A zero handle
  // (no native event) is a no-op.
#ifdef QD_WITH_CUDA
  if (event_handle == 0 || compile_config().arch != Arch::cuda) {
    return;
  }
  CUDADriver::get_instance().event_synchronize(
      reinterpret_cast<void *>(event_handle));
#endif
}

void Program::stream_wait_event(uint64 stream_handle, uint64 event_handle) {
  // Make the given stream (stream_handle == 0 selects the default stream)
  // wait until the event has been reached. A zero event handle is a no-op.
#ifdef QD_WITH_CUDA
  if (event_handle == 0 || compile_config().arch != Arch::cuda) {
    return;
  }
  CUDADriver::get_instance().stream_wait_event(
      reinterpret_cast<void *>(stream_handle),
      reinterpret_cast<void *>(event_handle), /*flags=*/0);
#endif
}

} // namespace quadrants::lang
10 changes: 10 additions & 0 deletions quadrants/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,16 @@ class QD_DLL_EXPORT Program {
return ndarrays_.size();
}

uint64 stream_create();
void stream_destroy(uint64 stream_handle);
void stream_synchronize(uint64 stream_handle);
void set_current_cuda_stream(uint64 stream_handle);
uint64 event_create();
void event_destroy(uint64 event_handle);
void event_record(uint64 event_handle, uint64 stream_handle);
void event_synchronize(uint64 event_handle);
void stream_wait_event(uint64 stream_handle, uint64 event_handle);

// TODO(zhanlue): Move these members and corresponding interfaces to
// ProgramImpl Ideally, Program should serve as a pure interface class and all
// the implementations should fall inside ProgramImpl
Expand Down
11 changes: 10 additions & 1 deletion quadrants/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,16 @@ void export_lang(py::module &m) {
.def("compile_kernel", &Program::compile_kernel,
py::return_value_policy::reference)
.def("launch_kernel", &Program::launch_kernel)
.def("get_device_caps", &Program::get_device_caps);
.def("get_device_caps", &Program::get_device_caps)
.def("stream_create", &Program::stream_create)
.def("stream_destroy", &Program::stream_destroy)
.def("stream_synchronize", &Program::stream_synchronize)
.def("set_current_cuda_stream", &Program::set_current_cuda_stream)
.def("event_create", &Program::event_create)
.def("event_destroy", &Program::event_destroy)
.def("event_record", &Program::event_record)
.def("event_synchronize", &Program::event_synchronize)
.def("stream_wait_event", &Program::stream_wait_event);
Comment on lines +499 to +507
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is CUDA-specific and what is not? Is only 'set_current_cuda_stream' CUDA-specific? If so, are streams still usable on other backends, or is this function necessary to make them useful?


py::class_<CompileResult>(m, "CompileResult")
.def_property_readonly(
Expand Down
6 changes: 3 additions & 3 deletions quadrants/rhi/cuda/cuda_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

namespace quadrants::lang {

thread_local void *CUDAContext::stream_ = nullptr;

CUDAContext::CUDAContext()
: profiler_(nullptr),
driver_(CUDADriver::get_instance_without_context()),
stream_(nullptr) {
: profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()) {
// CUDA initialization
dev_count_ = 0;
driver_.init(0);
Expand Down
2 changes: 1 addition & 1 deletion quadrants/rhi/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CUDAContext {
int max_shared_memory_bytes_;
bool debug_;
bool supports_mem_pool_;
void *stream_;
static thread_local void *stream_;

public:
CUDAContext();
Expand Down
2 changes: 2 additions & 0 deletions quadrants/rhi/cuda/cuda_driver_functions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ PER_CUDA_FUNCTION(context_set_limit, cuCtxSetLimit, int, std::size_t);

// Stream management
PER_CUDA_FUNCTION(stream_create, cuStreamCreate, void **, uint32);
PER_CUDA_FUNCTION(stream_destroy, cuStreamDestroy_v2, void *);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is 'cuStreamDestroy_v2'? That is a very odd name.

Why do we have functions with a '_v2' suffix in multiple places?


// Memory management
PER_CUDA_FUNCTION(memcpy_host_to_device, cuMemcpyHtoD_v2, void *, void *, std::size_t);
Expand Down Expand Up @@ -52,6 +53,7 @@ PER_CUDA_FUNCTION(kernel_set_attribute, cuFuncSetAttribute, void *, CUfunction_a

// Stream management
PER_CUDA_FUNCTION(stream_synchronize, cuStreamSynchronize, void *);
PER_CUDA_FUNCTION(stream_wait_event, cuStreamWaitEvent, void *, void *, uint32);

// Event management
PER_CUDA_FUNCTION(event_create, cuEventCreate, void **, uint32)
Expand Down
Loading
Loading