Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions quadrants/rhi/amdgpu/amdgpu_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ void AMDGPUContext::launch(void *func,
}

AMDGPUContext::~AMDGPUContext() {
for (auto *s : stream_pool_) {
driver_.stream_destroy(s);
}
stream_pool_.clear();
if (context_) {
driver_.device_primary_ctx_release(device_);
}
Expand Down
18 changes: 18 additions & 0 deletions quadrants/rhi/amdgpu/amdgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class AMDGPUContext {
AMDGPUDriver &driver_;
bool debug_{false};
static thread_local void *stream_;
std::vector<void *> stream_pool_;
std::vector<void *> kernel_arg_pointer_;

public:
Expand Down Expand Up @@ -125,6 +126,23 @@ class AMDGPUContext {
return stream_;
}

void *acquire_stream() {
std::lock_guard<std::mutex> _(lock_);
if (!stream_pool_.empty()) {
auto s = stream_pool_.back();
stream_pool_.pop_back();
return s;
}
void *s = nullptr;
AMDGPUDriver::get_instance().stream_create(&s, 0);
return s;
}

// Returns a stream previously obtained from acquire_stream() to the pool
// so a later acquire can reuse it instead of creating a new one.
void release_stream(void *s) {
  std::lock_guard<std::mutex> guard(lock_);
  stream_pool_.emplace_back(s);
}

static AMDGPUContext &get_instance();
};

Expand Down
11 changes: 4 additions & 7 deletions quadrants/rhi/cuda/cuda_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,10 @@ void CUDAContext::launch(void *func,
}

CUDAContext::~CUDAContext() {
  // Destroy every stream still sitting in the pool. Streams are created
  // lazily by acquire_stream() and handed back via release_stream(), so
  // at destruction time the pool owns all of them.
  // (Stale commented-out teardown of the context buffer / modules / context
  // was removed; it was dead code behind a TODO.)
  for (auto *s : stream_pool_) {
    driver_.stream_destroy(s);
  }
  stream_pool_.clear();
}

CUDAContext &CUDAContext::get_instance() {
Expand Down
19 changes: 19 additions & 0 deletions quadrants/rhi/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <mutex>
#include <unordered_map>
#include <thread>
#include <vector>

#include "quadrants/program/kernel_profiler.h"
#include "quadrants/rhi/cuda/cuda_driver.h"
Expand Down Expand Up @@ -31,6 +32,7 @@ class CUDAContext {
bool debug_;
bool supports_mem_pool_;
static thread_local void *stream_;
std::vector<void *> stream_pool_;

public:
CUDAContext();
Expand Down Expand Up @@ -120,6 +122,23 @@ class CUDAContext {
// Returns the stream currently bound to the calling thread (stream_ is
// thread_local), or nullptr if none has been set for this thread.
void *get_stream() const {
  return stream_;
}

void *acquire_stream() {
std::lock_guard<std::mutex> _(lock_);
if (!stream_pool_.empty()) {
auto s = stream_pool_.back();
stream_pool_.pop_back();
return s;
}
void *s = nullptr;
CUDADriver::get_instance().stream_create(&s, 0);
return s;
}

// Returns a stream previously obtained from acquire_stream() to the pool
// so a later acquire can reuse it instead of creating a new one.
void release_stream(void *s) {
  std::lock_guard<std::mutex> guard(lock_);
  stream_pool_.emplace_back(s);
}
};

} // namespace quadrants::lang
8 changes: 2 additions & 6 deletions quadrants/runtime/amdgpu/kernel_launcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,11 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
i++;
}

// Create one stream per unique group ID. Streams are created/destroyed
// per launch; a stream pool could reduce overhead for hot loops.
std::map<int, void *> stream_by_id;
for (size_t j = group_start; j < i; j++) {
int sid = offloaded_tasks[j].stream_parallel_group_id;
if (stream_by_id.find(sid) == stream_by_id.end()) {
void *s = nullptr;
AMDGPUDriver::get_instance().stream_create(&s, 0);
stream_by_id[sid] = s;
stream_by_id[sid] = AMDGPUContext::get_instance().acquire_stream();
}
}

Expand All @@ -155,7 +151,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
AMDGPUDriver::get_instance().stream_synchronize(s);
}
for (auto &[sid, s] : stream_by_id) {
AMDGPUDriver::get_instance().stream_destroy(s);
AMDGPUContext::get_instance().release_stream(s);
}

AMDGPUContext::get_instance().set_stream(active_stream);
Expand Down
8 changes: 2 additions & 6 deletions quadrants/runtime/cuda/kernel_launcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,11 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
i++;
}

// Create one stream per unique group ID. Streams are created/destroyed
// per launch; a stream pool could reduce overhead for hot loops.
std::map<int, void *> stream_by_id;
for (size_t j = group_start; j < i; j++) {
int sid = offloaded_tasks[j].stream_parallel_group_id;
if (stream_by_id.find(sid) == stream_by_id.end()) {
void *s = nullptr;
CUDADriver::get_instance().stream_create(&s, 0);
stream_by_id[sid] = s;
stream_by_id[sid] = CUDAContext::get_instance().acquire_stream();
}
}

Expand All @@ -187,7 +183,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
CUDADriver::get_instance().stream_synchronize(s);
}
for (auto &[sid, s] : stream_by_id) {
CUDADriver::get_instance().stream_destroy(s);
CUDAContext::get_instance().release_stream(s);
}

CUDAContext::get_instance().set_stream(active_stream);
Expand Down
28 changes: 28 additions & 0 deletions tests/python/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,31 @@ def fill(arr: qd.types.ndarray(dtype=qd.f32, ndim=1)):
s.synchronize()
assert np.allclose(arr.to_numpy(), 99.0)
s.destroy()


@test_utils.test()
def test_stream_pool_reuse():
    """Repeated stream_parallel invocations reuse pooled streams correctly."""
    N = 128
    a = qd.ndarray(qd.f32, shape=(N,))
    b = qd.ndarray(qd.f32, shape=(N,))

    @qd.kernel
    def parallel_fill(
        x: qd.types.ndarray(dtype=qd.f32, ndim=1),
        y: qd.types.ndarray(dtype=qd.f32, ndim=1),
        val: qd.f32,
    ):
        # Two separate stream_parallel regions in a single kernel: each
        # launch acquires stream(s) from the context's pool and releases
        # them afterwards (see acquire_stream/release_stream in the RHI).
        with qd.stream_parallel():
            for i in range(N):
                x[i] = val
        with qd.stream_parallel():
            for i in range(N):
                y[i] = val * 2.0

    # Launch repeatedly: the first iteration creates streams; later
    # iterations should recycle them from the pool. Each iteration uses a
    # distinct fill value so stale results from a prior launch would be
    # detected by the assertions below.
    for iteration in range(5):
        v = float(iteration + 1)
        parallel_fill(a, b, v)
        qd.sync()
        assert np.allclose(a.to_numpy(), v), f"iteration {iteration}"
        assert np.allclose(b.to_numpy(), v * 2.0), f"iteration {iteration}"
Loading