From a40ed4ccd03a1162cf40a5f4fa35ee6ee7979abc Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 16:47:44 -0700 Subject: [PATCH 1/4] Add qd.stream_parallel() context manager for implicit stream parallelism Introduces stream_parallel() for running top-level for-loop blocks on separate GPU streams. The AST transformer maps 'with qd.stream_parallel()' blocks to stream-parallel group IDs, which propagate through IR lowering and offloading to the CUDA/AMDGPU kernel launchers. Each unique group ID gets its own stream at launch time. Includes validation that all top-level kernel statements must be stream_parallel blocks (no mixing), and offline cache key support. --- python/quadrants/lang/ast/ast_transformer.py | 32 +++- .../function_def_transformer.py | 29 +++ python/quadrants/lang/stream.py | 15 +- quadrants/analysis/gen_offline_cache_key.cpp | 1 + quadrants/codegen/amdgpu/codegen_amdgpu.cpp | 1 + quadrants/codegen/cuda/codegen_cuda.cpp | 1 + quadrants/codegen/llvm/llvm_compiled_data.h | 13 +- quadrants/ir/frontend_ir.cpp | 12 +- quadrants/ir/frontend_ir.h | 12 ++ quadrants/ir/statements.cpp | 3 + quadrants/ir/statements.h | 3 + quadrants/python/export_lang.cpp | 4 +- quadrants/runtime/amdgpu/kernel_launcher.cpp | 52 ++++- quadrants/runtime/cuda/kernel_launcher.cpp | 52 ++++- quadrants/transforms/lower_ast.cpp | 3 + quadrants/transforms/offload.cpp | 3 + tests/python/test_api.py | 1 + tests/python/test_streams.py | 178 ++++++++++++++++-- 18 files changed, 377 insertions(+), 38 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 1b13ead0f..f5cfbeef1 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -28,6 +28,7 @@ from quadrants.lang.ast.ast_transformers.function_def_transformer import ( FunctionDefTransformer, ) +from quadrants.lang.ast.symbol_resolver import ASTResolver from quadrants.lang.exception import ( 
QuadrantsIndexError, QuadrantsRuntimeTypeError, @@ -39,6 +40,7 @@ from quadrants.lang.field import Field from quadrants.lang.matrix import Matrix, MatrixType from quadrants.lang.snode import append, deactivate, length +from quadrants.lang.stream import stream_parallel from quadrants.lang.struct import Struct, StructType from quadrants.types import primitive_types from quadrants.types.utils import is_integral @@ -108,7 +110,11 @@ def build_AnnAssign(ctx: ASTTransformerFuncContext, node: ast.AnnAssign): @staticmethod def build_assign_annotated( - ctx: ASTTransformerFuncContext, target: ast.Name, value, is_static_assign: bool, annotation: Type + ctx: ASTTransformerFuncContext, + target: ast.Name, + value, + is_static_assign: bool, + annotation: Type, ): """Build an annotated assignment like this: target: annotation = value. @@ -156,7 +162,10 @@ def build_Assign(ctx: ASTTransformerFuncContext, node: ast.Assign) -> None: @staticmethod def build_assign_unpack( - ctx: ASTTransformerFuncContext, node_target: list | ast.Tuple, values, is_static_assign: bool + ctx: ASTTransformerFuncContext, + node_target: list | ast.Tuple, + values, + is_static_assign: bool, ): """Build the unpack assignments like this: (target1, target2) = (value1, value2). The function should be called only if the node target is a tuple. 
@@ -538,7 +547,8 @@ def build_Return(ctx: ASTTransformerFuncContext, node: ast.Return) -> None: else: raise QuadrantsSyntaxError("The return type is not supported now!") ctx.ast_builder.create_kernel_exprgroup_return( - expr.make_expr_group(return_exprs), _qd_core.DebugInfo(ctx.get_pos_info(node)) + expr.make_expr_group(return_exprs), + _qd_core.DebugInfo(ctx.get_pos_info(node)), ) else: ctx.return_data = node.value.ptr @@ -1381,6 +1391,22 @@ def build_Continue(ctx: ASTTransformerFuncContext, node: ast.Continue) -> None: ctx.ast_builder.insert_continue_stmt(_qd_core.DebugInfo(ctx.get_pos_info(node))) return None + @staticmethod + def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: + if len(node.items) != 1: + raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports a single context manager") + item = node.items[0] + if item.optional_vars is not None: + raise QuadrantsSyntaxError("'with ... as ...' is not supported in Quadrants kernels") + if not isinstance(item.context_expr, ast.Call): + raise QuadrantsSyntaxError("'with' in Quadrants kernels requires a call expression") + if not ASTResolver.resolve_to(item.context_expr.func, stream_parallel, ctx.global_vars): + raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports qd.stream_parallel()") + ctx.ast_builder.begin_stream_parallel() + build_stmts(ctx, node.body) + ctx.ast_builder.end_stream_parallel() + return None + @staticmethod def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None: return None diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 6d000b69f..dacbac4c9 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -21,10 +21,12 @@ from quadrants.lang.ast.ast_transformer_utils import ( ASTTransformerFuncContext, ) +from 
quadrants.lang.ast.symbol_resolver import ASTResolver from quadrants.lang.exception import ( QuadrantsSyntaxError, ) from quadrants.lang.matrix import MatrixType +from quadrants.lang.stream import stream_parallel from quadrants.lang.struct import StructType from quadrants.lang.util import to_quadrants_type from quadrants.types import annotations, ndarray_type, primitive_types @@ -295,7 +297,34 @@ def build_FunctionDef( else: FunctionDefTransformer._transform_as_func(ctx, node, args) + if ctx.is_kernel: + FunctionDefTransformer._validate_stream_parallel_exclusivity(node.body, ctx.global_vars) + with ctx.variable_scope_guard(): build_stmts(ctx, node.body) return None + + @staticmethod + def _is_stream_parallel_with(stmt: ast.stmt, global_vars: dict[str, Any]) -> bool: + if not isinstance(stmt, ast.With): + return False + if len(stmt.items) != 1: + return False + item = stmt.items[0] + if not isinstance(item.context_expr, ast.Call): + return False + return ASTResolver.resolve_to(item.context_expr.func, stream_parallel, global_vars) + + @staticmethod + def _validate_stream_parallel_exclusivity(body: list[ast.stmt], global_vars: dict[str, Any]) -> None: + has_sp = any(FunctionDefTransformer._is_stream_parallel_with(s, global_vars) for s in body) + if not has_sp: + return + for stmt in body: + if not FunctionDefTransformer._is_stream_parallel_with(stmt, global_vars): + raise QuadrantsSyntaxError( + "When using qd.stream_parallel(), all top-level statements " + "in the kernel must be 'with qd.stream_parallel():' blocks. " + "Move non-parallel code to a separate kernel." 
+ ) diff --git a/python/quadrants/lang/stream.py b/python/quadrants/lang/stream.py index 853098245..77979184d 100644 --- a/python/quadrants/lang/stream.py +++ b/python/quadrants/lang/stream.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + from quadrants.lang import impl @@ -93,4 +95,15 @@ def create_event() -> Event: return Event(handle) -__all__ = ["Stream", "Event", "create_stream", "create_event"] +@contextmanager +def stream_parallel(): + """Run top-level for loops in this block on separate GPU streams. + + Used inside @qd.kernel. At Python runtime (outside kernels), this is a + no-op. During kernel compilation, the AST transformer calls into the C++ + ASTBuilder to tag loops with a stream-parallel group ID. + """ + yield + + +__all__ = ["Stream", "Event", "create_stream", "create_event", "stream_parallel"] diff --git a/quadrants/analysis/gen_offline_cache_key.cpp b/quadrants/analysis/gen_offline_cache_key.cpp index f9eb5dc32..9a38eb9ac 100644 --- a/quadrants/analysis/gen_offline_cache_key.cpp +++ b/quadrants/analysis/gen_offline_cache_key.cpp @@ -382,6 +382,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(stmt->strictly_serialized); emit(stmt->mem_access_opt); emit(stmt->block_dim); + emit(stmt->stream_parallel_group_id); emit(stmt->body.get()); } diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index bba1c87f2..e0fcca575 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -396,6 +396,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { current_task->grid_dim = num_SMs * query_max_block_per_sm; } current_task->block_dim = stmt->block_dim; + current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); offloaded_tasks.push_back(*current_task); diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp 
b/quadrants/codegen/cuda/codegen_cuda.cpp index 8395f7adc..4795db23d 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -692,6 +692,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } current_task->block_dim = stmt->block_dim; current_task->dynamic_shared_array_bytes = dynamic_shared_array_bytes; + current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); offloaded_tasks.push_back(*current_task); diff --git a/quadrants/codegen/llvm/llvm_compiled_data.h b/quadrants/codegen/llvm/llvm_compiled_data.h index 16d4978bd..f496e6fa3 100644 --- a/quadrants/codegen/llvm/llvm_compiled_data.h +++ b/quadrants/codegen/llvm/llvm_compiled_data.h @@ -14,16 +14,23 @@ class OffloadedTask { int block_dim{0}; int grid_dim{0}; int dynamic_shared_array_bytes{0}; + int stream_parallel_group_id{0}; explicit OffloadedTask(const std::string &name = "", int block_dim = 0, int grid_dim = 0, - int dynamic_shared_array_bytes = 0) + int dynamic_shared_array_bytes = 0, + int stream_parallel_group_id = 0) : name(name), block_dim(block_dim), grid_dim(grid_dim), - dynamic_shared_array_bytes(dynamic_shared_array_bytes) {}; - QD_IO_DEF(name, block_dim, grid_dim, dynamic_shared_array_bytes); + dynamic_shared_array_bytes(dynamic_shared_array_bytes), + stream_parallel_group_id(stream_parallel_group_id) {}; + QD_IO_DEF(name, + block_dim, + grid_dim, + dynamic_shared_array_bytes, + stream_parallel_group_id); }; struct LLVMCompiledTask { diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp index ae2e3ebe7..6cf308764 100644 --- a/quadrants/ir/frontend_ir.cpp +++ b/quadrants/ir/frontend_ir.cpp @@ -119,7 +119,8 @@ FrontendForStmt::FrontendForStmt(const FrontendForStmt &o) num_cpu_threads(o.num_cpu_threads), strictly_serialized(o.strictly_serialized), mem_access_opt(o.mem_access_opt), - block_dim(o.block_dim) { + block_dim(o.block_dim), + 
stream_parallel_group_id(o.stream_parallel_group_id) { } void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) { @@ -127,6 +128,7 @@ void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) { strictly_serialized = config.strictly_serialized; mem_access_opt = config.mem_access_opt; block_dim = config.block_dim; + stream_parallel_group_id = config.stream_parallel_group_id; if (arch == Arch::cuda || arch == Arch::amdgpu) { num_cpu_threads = 1; QD_ASSERT(block_dim <= quadrants_max_gpu_block_dim); @@ -1542,6 +1544,8 @@ void ASTBuilder::begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e, const DebugInfo &dbg_info) { + for_loop_dec_.config.stream_parallel_group_id = + current_stream_parallel_group_id_; auto stmt_unique = std::make_unique<FrontendForStmt>( i, s, e, arch_, for_loop_dec_.config, dbg_info); auto stmt = stmt_unique.get(); @@ -1558,6 +1562,8 @@ void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars, for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = + current_stream_parallel_group_id_; auto stmt_unique = std::make_unique<FrontendForStmt>( loop_vars, snode, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); @@ -1574,6 +1580,8 @@ void ASTBuilder::begin_frontend_struct_for_on_external_tensor( for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = + current_stream_parallel_group_id_; auto stmt_unique = std::make_unique<FrontendForStmt>( loop_vars, external_tensor, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); @@ -1591,6 +1599,8 @@ void ASTBuilder::begin_frontend_mesh_for( for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the mesh for. 
" "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = + current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(ExprGroup(i), mesh_ptr, element_type, arch_, for_loop_dec_.config, dbg_info); diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index bce009f9e..693a7f461 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -23,6 +23,7 @@ struct ForLoopConfig { MemoryAccessOptions mem_access_opt; int block_dim{0}; bool uniform{false}; + int stream_parallel_group_id{0}; }; #define QD_DEFINE_CLONE_FOR_FRONTEND_IR \ @@ -207,6 +208,7 @@ class FrontendForStmt : public Stmt { bool strictly_serialized; MemoryAccessOptions mem_access_opt; int block_dim; + int stream_parallel_group_id{0}; FrontendForStmt(const ExprGroup &loop_vars, SNode *snode, @@ -961,6 +963,8 @@ class ASTBuilder { Arch arch_; ForLoopDecoratorRecorder for_loop_dec_; int id_counter_{0}; + int stream_parallel_group_counter_{0}; + int current_stream_parallel_group_id_{0}; public: ASTBuilder(Block *initial, Arch arch, bool is_kernel) @@ -1107,6 +1111,14 @@ class ASTBuilder { for_loop_dec_.reset(); } + void begin_stream_parallel() { + current_stream_parallel_group_id_ = ++stream_parallel_group_counter_; + } + + void end_stream_parallel() { + current_stream_parallel_group_id_ = 0; + } + Identifier get_next_id(const std::string &name = "") { return Identifier(id_counter_++, name); } diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index 14c55be85..79b469a22 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -244,6 +244,7 @@ std::unique_ptr RangeForStmt::clone() const { begin, end, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim, strictly_serialized); new_stmt->reversed = reversed; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; return new_stmt; } @@ -265,6 +266,7 @@ std::unique_ptr StructForStmt::clone() const { auto new_stmt = 
std::make_unique<StructForStmt>( snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim); new_stmt->mem_access_opt = mem_access_opt; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; return new_stmt; } @@ -439,6 +441,7 @@ std::unique_ptr<Stmt> OffloadedStmt::clone() const { new_stmt->tls_size = tls_size; new_stmt->bls_size = bls_size; new_stmt->mem_access_opt = mem_access_opt; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; return new_stmt; } diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index e06bb6d4d..3f440fe4e 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -1016,6 +1016,7 @@ class RangeForStmt : public Stmt { int block_dim; bool strictly_serialized; std::string range_hint; + int stream_parallel_group_id{0}; RangeForStmt(Stmt *begin, Stmt *end, @@ -1061,6 +1062,7 @@ class StructForStmt : public Stmt { int num_cpu_threads; int block_dim; MemoryAccessOptions mem_access_opt; + int stream_parallel_group_id{0}; StructForStmt(SNode *snode, std::unique_ptr<Block> &&body, @@ -1443,6 +1445,7 @@ class OffloadedStmt : public Stmt { std::size_t tls_size{1}; // avoid allocating dynamic memory with 0 byte std::size_t bls_size{0}; MemoryAccessOptions mem_access_opt; + int stream_parallel_group_id{0}; OffloadedStmt(TaskType task_type, Arch arch, Kernel *kernel); diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index 2f5da8b1b..d134464d4 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -357,7 +357,9 @@ void export_lang(py::module &m) { .def("strictly_serialize", &ASTBuilder::strictly_serialize) .def("block_dim", &ASTBuilder::block_dim) .def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag) - .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag); + .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag) + .def("begin_stream_parallel", &ASTBuilder::begin_stream_parallel) + .def("end_stream_parallel", 
&ASTBuilder::end_stream_parallel); auto device_capability_config = py::class_<DeviceCapabilityConfig>(m, "DeviceCapabilityConfig") diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp b/quadrants/runtime/amdgpu/kernel_launcher.cpp index 1d8430d35..1b82b3345 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -1,3 +1,5 @@ +#include <map> + #include "quadrants/runtime/amdgpu/kernel_launcher.h" #include "quadrants/rhi/amdgpu/amdgpu_context.h" #include "quadrants/rhi/amdgpu/amdgpu_driver.h" @@ -108,12 +110,50 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, AMDGPUContext::get_instance().push_back_kernel_arg_pointer(context_pointer); - for (auto &task : offloaded_tasks) { - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, - task.block_dim); - amdgpu_module->launch(task.name, task.grid_dim, task.block_dim, - task.dynamic_shared_array_bytes, - {(void *)&context_pointer}, {arg_size}); + for (size_t i = 0; i < offloaded_tasks.size();) { + auto &task = offloaded_tasks[i]; + if (task.stream_parallel_group_id == 0) { + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, + task.block_dim); + amdgpu_module->launch(task.name, task.grid_dim, task.block_dim, + task.dynamic_shared_array_bytes, + {(void *)&context_pointer}, {arg_size}); + i++; + } else { + size_t group_start = i; + while (i < offloaded_tasks.size() && + offloaded_tasks[i].stream_parallel_group_id != 0) { + i++; + } + + std::map<int, void *> stream_by_id; + for (size_t j = group_start; j < i; j++) { + int sid = offloaded_tasks[j].stream_parallel_group_id; + if (stream_by_id.find(sid) == stream_by_id.end()) { + void *s = nullptr; + AMDGPUDriver::get_instance().stream_create(&s, 0); + stream_by_id[sid] = s; + } + } + + for (size_t j = group_start; j < i; j++) { + auto &t = offloaded_tasks[j]; + AMDGPUContext::get_instance().set_stream( + stream_by_id[t.stream_parallel_group_id]); + amdgpu_module->launch(t.name, t.grid_dim, t.block_dim, +
t.dynamic_shared_array_bytes, + {(void *)&context_pointer}, {arg_size}); + } + + for (auto &[sid, s] : stream_by_id) { + AMDGPUDriver::get_instance().stream_synchronize(s); + } + for (auto &[sid, s] : stream_by_id) { + AMDGPUDriver::get_instance().stream_destroy(s); + } + + AMDGPUContext::get_instance().set_stream(active_stream); + } } QD_TRACE("Launching kernel"); if (ctx.arg_buffer_size > 0) { diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index 13845d5a9..94aa786b5 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -1,3 +1,5 @@ +#include <map> + #include "quadrants/runtime/cuda/kernel_launcher.h" #include "quadrants/rhi/cuda/cuda_context.h" #include "quadrants/rhi/cuda/cuda_driver.h" @@ -139,12 +141,50 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, ctx.get_context().arg_buffer = device_arg_buffer; } - for (auto task : offloaded_tasks) { - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, - task.block_dim); - cuda_module->launch(task.name, task.grid_dim, task.block_dim, - task.dynamic_shared_array_bytes, {&ctx.get_context()}, - {}); + for (size_t i = 0; i < offloaded_tasks.size();) { + auto &task = offloaded_tasks[i]; + if (task.stream_parallel_group_id == 0) { + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, + task.block_dim); + cuda_module->launch(task.name, task.grid_dim, task.block_dim, + task.dynamic_shared_array_bytes, {&ctx.get_context()}, + {}); + i++; + } else { + size_t group_start = i; + while (i < offloaded_tasks.size() && + offloaded_tasks[i].stream_parallel_group_id != 0) { + i++; + } + + std::map<int, void *> stream_by_id; + for (size_t j = group_start; j < i; j++) { + int sid = offloaded_tasks[j].stream_parallel_group_id; + if (stream_by_id.find(sid) == stream_by_id.end()) { + void *s = nullptr; + CUDADriver::get_instance().stream_create(&s, 0); + stream_by_id[sid] = s; + } + } + + for (size_t j = 
group_start; j < i; j++) { + auto &t = offloaded_tasks[j]; + CUDAContext::get_instance().set_stream( + stream_by_id[t.stream_parallel_group_id]); + cuda_module->launch(t.name, t.grid_dim, t.block_dim, + t.dynamic_shared_array_bytes, {&ctx.get_context()}, + {}); + } + + for (auto &[sid, s] : stream_by_id) { + CUDADriver::get_instance().stream_synchronize(s); + } + for (auto &[sid, s] : stream_by_id) { + CUDADriver::get_instance().stream_destroy(s); + } + + CUDAContext::get_instance().set_stream(active_stream); + } } if (ctx.arg_buffer_size > 0) { CUDADriver::get_instance().mem_free_async(device_arg_buffer, active_stream); diff --git a/quadrants/transforms/lower_ast.cpp b/quadrants/transforms/lower_ast.cpp index 74b698a9e..ef1bb6f06 100644 --- a/quadrants/transforms/lower_ast.cpp +++ b/quadrants/transforms/lower_ast.cpp @@ -232,6 +232,7 @@ class LowerAST : public IRVisitor { snode, std::move(stmt->body), stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim); new_for->index_offsets = offsets; + new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; VecStatement new_statements; for (int i = 0; i < (int)stmt->loop_var_ids.size(); i++) { Stmt *loop_index = new_statements.push_back<LoopIndexStmt>( @@ -270,6 +271,7 @@ class LowerAST : public IRVisitor { begin, end, std::move(stmt->body), stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized, /*range_hint=*/fmt::format("arg ({})", fmt::join(arg_id, ", "))); + new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; VecStatement new_statements; Stmt *loop_index = new_statements.push_back<LoopIndexStmt>(new_for.get(), 0); @@ -311,6 +313,7 @@ class LowerAST : public IRVisitor { begin_stmt, end_stmt, std::move(stmt->body), stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized); + new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; new_for->body->insert(std::make_unique<LoopIndexStmt>(new_for.get(), 0), 0); 
new_for->body->local_var_to_stmt[stmt->loop_var_ids[0]] = diff --git a/quadrants/transforms/offload.cpp b/quadrants/transforms/offload.cpp index 2f2024736..f3e254a88 100644 --- a/quadrants/transforms/offload.cpp +++ b/quadrants/transforms/offload.cpp @@ -134,6 +134,7 @@ class Offloader { offloaded->body->insert(std::move(s->body->statements[j])); } offloaded->range_hint = s->range_hint; + offloaded->stream_parallel_group_id = s->stream_parallel_group_id; root_block->insert(std::move(offloaded)); } else if (auto st = stmt->cast<StructForStmt>()) { assemble_serial_statements(); @@ -257,6 +258,8 @@ class Offloader { offloaded_struct_for->num_cpu_threads = std::min(for_stmt->num_cpu_threads, config.cpu_max_num_threads); offloaded_struct_for->mem_access_opt = mem_access_opt; + offloaded_struct_for->stream_parallel_group_id = + for_stmt->stream_parallel_group_id; root_block->insert(std::move(offloaded_struct_for)); } diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 002014c96..241f3143d 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -218,6 +218,7 @@ def _get_expected_matrix_apis(): "static_assert", "static_print", "stop_grad", + "stream_parallel", "svd", "sym_eig", "sync", diff --git a/tests/python/test_streams.py b/tests/python/test_streams.py index 073d383c2..4c28b6f58 100644 --- a/tests/python/test_streams.py +++ b/tests/python/test_streams.py @@ -180,23 +180,6 @@ def fill(): e.destroy() -@test_utils.test() -def test_stream_with_ndarray(): - N = 1024 - - @qd.kernel - def fill(arr: qd.types.ndarray(dtype=qd.f32, ndim=1)): - for i in range(N): - arr[i] = 99.0 - - arr = qd.ndarray(qd.f32, shape=(N,)) - s = qd.create_stream() - fill(arr, qd_stream=s) - s.synchronize() - assert np.allclose(arr.to_numpy(), 99.0) - s.destroy() - - @test_utils.test() def test_concurrent_streams_with_events(): """Two slow kernels on separate streams run concurrently (~1s on GPU), @@ -275,3 +258,164 @@ def add_first_two(a: qd.types.ndarray(dtype=qd.f32, ndim=1)): 
s2.destroy() e1.destroy() e2.destroy() + + +@test_utils.test() +def test_stream_parallel_basic(): + """Each with qd.stream_parallel() block runs on its own stream (serial fallback on CPU/Metal).""" + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_parallel(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + with qd.stream_parallel(): + for j in range(N): + b[j] = 2.0 + + fill_parallel() + qd.sync() + assert np.allclose(a.to_numpy(), 1.0) + assert np.allclose(b.to_numpy(), 2.0) + + +@test_utils.test() +def test_stream_parallel_multiple_loops_per_stream(): + """Multiple for loops inside one stream_parallel block share a stream (serial fallback on CPU/Metal).""" + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + c = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def parallel_phase(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + for i in range(N): + a[i] = a[i] + 1.0 + with qd.stream_parallel(): + for j in range(N): + b[j] = 10.0 + + @qd.kernel + def combine(): + for i in range(N): + c[i] = a[i] + b[i] + + parallel_phase() + combine() + qd.sync() + assert np.allclose(a.to_numpy(), 2.0) + assert np.allclose(b.to_numpy(), 10.0) + assert np.allclose(c.to_numpy(), 12.0) + + +@test_utils.test() +def test_stream_parallel_timing(): + """stream_parallel achieves speedup on GPU, serial fallback elsewhere.""" + SPIN_ITERS = 5_000_000 + + a = qd.field(qd.i32, shape=(2,)) + b = qd.field(qd.i32, shape=(2,)) + + @qd.kernel + def serial_spin(): + for _ in range(1): + x = a[0] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + a[0] = x + for _ in range(1): + x = a[1] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + a[1] = x + + @qd.kernel + def parallel_spin(): + with qd.stream_parallel(): + for _ in range(1): + x = b[0] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + b[0] = 
x + with qd.stream_parallel(): + for _ in range(1): + x = b[1] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + b[1] = x + + import time + + # Warm up + serial_spin() + parallel_spin() + qd.sync() + + qd.sync() + t0 = time.perf_counter() + serial_spin() + qd.sync() + serial_time = time.perf_counter() - t0 + + qd.sync() + t0 = time.perf_counter() + parallel_spin() + qd.sync() + stream_time = time.perf_counter() - t0 + + speedup = serial_time / stream_time + if qd.lang.impl.current_cfg().arch in (qd.cuda, qd.amdgpu): + assert speedup > 1.5, ( + f"Expected >1.5x speedup, got {speedup:.2f}x " f"(serial={serial_time:.3f}s, stream={stream_time:.3f}s)" + ) + else: + assert speedup > 0.75, ( + f"Expected >=0.75x (serial fallback), got {speedup:.2f}x " + f"(serial={serial_time:.3f}s, stream={stream_time:.3f}s)" + ) + + +@test_utils.test() +def test_stream_parallel_rejects_mixed_top_level(): + """Mixing stream_parallel and non-stream_parallel at top level is an error.""" + import pytest # noqa: I001 + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + + with pytest.raises(QuadrantsSyntaxError, match="all top-level statements"): + + @qd.kernel + def bad_kernel(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + for i in range(N): + a[i] = 2.0 + + bad_kernel() + + +@test_utils.test() +def test_stream_with_ndarray(): + N = 1024 + + @qd.kernel + def fill(arr: qd.types.ndarray(dtype=qd.f32, ndim=1)): + for i in range(N): + arr[i] = 99.0 + + arr = qd.ndarray(qd.f32, shape=(N,)) + s = qd.create_stream() + fill(arr, qd_stream=s) + s.synchronize() + assert np.allclose(arr.to_numpy(), 99.0) + s.destroy() From be7ad924c333a589f13bbbe34f2d9583649007f5 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 17:29:24 -0700 Subject: [PATCH 2/4] Clear stream_parallel_group_id in ForLoopDecoratorRecorder::reset() Prevents stale group IDs from leaking if insert_for is called after 
a path that set a non-zero stream_parallel_group_id, matching the reset pattern of all other ForLoopConfig fields. --- quadrants/ir/frontend_ir.h | 1 + 1 file changed, 1 insertion(+) diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index 693a7f461..38226ca1b 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -954,6 +954,7 @@ class ASTBuilder { config.mem_access_opt.clear(); config.block_dim = 0; config.strictly_serialized = false; + config.stream_parallel_group_id = 0; } }; From ce8328102ae0b18f0b29d661b4dc4026edf3c4a8 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 17:29:36 -0700 Subject: [PATCH 3/4] Reject nested stream_parallel blocks Add an error check in begin_stream_parallel() to prevent nesting, which would produce undefined group ID semantics. --- quadrants/ir/frontend_ir.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index 38226ca1b..46d7a3ec7 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -1113,6 +1113,8 @@ class ASTBuilder { } void begin_stream_parallel() { + QD_ERROR_IF(current_stream_parallel_group_id_ != 0, + "Nested stream_parallel blocks are not supported"); current_stream_parallel_group_id_ = ++stream_parallel_group_counter_; } From 880abc7e74cc8be0979d54747ff753929f00221d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 17:30:08 -0700 Subject: [PATCH 4/4] Document stream_parallel launcher design: per-launch streams, shared context safety Add comments explaining that streams are created/destroyed per launch (stream pooling as future optimization), and that RuntimeContext sharing across concurrent streams is safe because kernels only read from it. 
--- quadrants/runtime/amdgpu/kernel_launcher.cpp | 5 +++++ quadrants/runtime/cuda/kernel_launcher.cpp | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp b/quadrants/runtime/amdgpu/kernel_launcher.cpp index f859bb116..6abd0778e 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -127,6 +129,8 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, i++; } + // Create one stream per unique group ID. Streams are created/destroyed + // per launch; a stream pool could reduce overhead for hot loops. std::map<int, void *> stream_by_id; for (size_t j = group_start; j < i; j++) { int sid = offloaded_tasks[j].stream_parallel_group_id; @@ -137,6 +139,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, } } + // Launch tasks concurrently on their respective streams. The shared + // RuntimeContext is safe here: kernels only read from it (args/runtime + // pointers); result_buffer writes are to disjoint offsets per task. for (size_t j = group_start; j < i; j++) { auto &t = offloaded_tasks[j]; AMDGPUContext::get_instance().set_stream( diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index 2e10226a1..9cf24915a 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -159,6 +161,8 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, i++; } + // Create one stream per unique group ID. Streams are created/destroyed + // per launch; a stream pool could reduce overhead for hot loops. std::map<int, void *> stream_by_id; for (size_t j = group_start; j < i; j++) { int sid = offloaded_tasks[j].stream_parallel_group_id; @@ -169,6 +171,9 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, } } + // Launch tasks concurrently on their respective streams. 
The shared + // RuntimeContext is safe here: kernels only read from it (args/runtime + // pointers); result_buffer writes are to disjoint offsets per task. for (size_t j = group_start; j < i; j++) { auto &t = offloaded_tasks[j]; CUDAContext::get_instance().set_stream(