Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions python/quadrants/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from quadrants.lang.ast.ast_transformers.function_def_transformer import (
FunctionDefTransformer,
)
from quadrants.lang.ast.symbol_resolver import ASTResolver
from quadrants.lang.exception import (
QuadrantsIndexError,
QuadrantsRuntimeTypeError,
Expand All @@ -39,6 +40,7 @@
from quadrants.lang.field import Field
from quadrants.lang.matrix import Matrix, MatrixType
from quadrants.lang.snode import append, deactivate, length
from quadrants.lang.stream import stream_parallel
from quadrants.lang.struct import Struct, StructType
from quadrants.types import primitive_types
from quadrants.types.utils import is_integral
Expand Down Expand Up @@ -108,7 +110,11 @@ def build_AnnAssign(ctx: ASTTransformerFuncContext, node: ast.AnnAssign):

@staticmethod
def build_assign_annotated(
ctx: ASTTransformerFuncContext, target: ast.Name, value, is_static_assign: bool, annotation: Type
ctx: ASTTransformerFuncContext,
target: ast.Name,
value,
is_static_assign: bool,
annotation: Type,
):
"""Build an annotated assignment like this: target: annotation = value.

Expand Down Expand Up @@ -156,7 +162,10 @@ def build_Assign(ctx: ASTTransformerFuncContext, node: ast.Assign) -> None:

@staticmethod
def build_assign_unpack(
ctx: ASTTransformerFuncContext, node_target: list | ast.Tuple, values, is_static_assign: bool
ctx: ASTTransformerFuncContext,
node_target: list | ast.Tuple,
values,
is_static_assign: bool,
):
"""Build the unpack assignments like this: (target1, target2) = (value1, value2).
The function should be called only if the node target is a tuple.
Expand Down Expand Up @@ -538,7 +547,8 @@ def build_Return(ctx: ASTTransformerFuncContext, node: ast.Return) -> None:
else:
raise QuadrantsSyntaxError("The return type is not supported now!")
ctx.ast_builder.create_kernel_exprgroup_return(
expr.make_expr_group(return_exprs), _qd_core.DebugInfo(ctx.get_pos_info(node))
expr.make_expr_group(return_exprs),
_qd_core.DebugInfo(ctx.get_pos_info(node)),
)
else:
ctx.return_data = node.value.ptr
Expand Down Expand Up @@ -1381,6 +1391,22 @@ def build_Continue(ctx: ASTTransformerFuncContext, node: ast.Continue) -> None:
ctx.ast_builder.insert_continue_stmt(_qd_core.DebugInfo(ctx.get_pos_info(node)))
return None

@staticmethod
def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None:
    """Build a 'with' statement inside a Quadrants kernel.

    Only the form ``with qd.stream_parallel():`` is accepted: exactly one
    context manager, no ``as`` target, and a context expression that is a
    call resolving to ``stream_parallel``.

    Args:
        ctx: The per-function transformer context.
        node: The ``ast.With`` node being lowered.

    Raises:
        QuadrantsSyntaxError: If the statement has multiple context
            managers, an ``as`` clause, a non-call context expression, or
            a context manager other than ``qd.stream_parallel()``.
    """
    # node.items holds one `withitem` per context manager in the header,
    # e.g. `with a(), b():` yields two items; only a single manager is
    # supported in kernels.
    if len(node.items) != 1:
        raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports a single context manager")
    item = node.items[0]
    if item.optional_vars is not None:
        raise QuadrantsSyntaxError("'with ... as ...' is not supported in Quadrants kernels")
    if not isinstance(item.context_expr, ast.Call):
        raise QuadrantsSyntaxError("'with' in Quadrants kernels requires a call expression")
    if not ASTResolver.resolve_to(item.context_expr.func, stream_parallel, ctx.global_vars):
        raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports qd.stream_parallel()")
    ctx.ast_builder.begin_stream_parallel()
    try:
        build_stmts(ctx, node.body)
    finally:
        # Always close the region: if building a body statement raises,
        # the builder must not be left inside a stream-parallel group.
        ctx.ast_builder.end_stream_parallel()
    return None

@staticmethod
def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None:
    """Lower a ``pass`` statement: a no-op, so nothing is emitted."""
    return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
from quadrants.lang.ast.ast_transformer_utils import (
ASTTransformerFuncContext,
)
from quadrants.lang.ast.symbol_resolver import ASTResolver
from quadrants.lang.exception import (
QuadrantsSyntaxError,
)
from quadrants.lang.matrix import MatrixType
from quadrants.lang.stream import stream_parallel
from quadrants.lang.struct import StructType
from quadrants.lang.util import to_quadrants_type
from quadrants.types import annotations, ndarray_type, primitive_types
Expand Down Expand Up @@ -295,7 +297,34 @@ def build_FunctionDef(
else:
FunctionDefTransformer._transform_as_func(ctx, node, args)

if ctx.is_kernel:
FunctionDefTransformer._validate_stream_parallel_exclusivity(node.body, ctx.global_vars)

with ctx.variable_scope_guard():
build_stmts(ctx, node.body)

return None

@staticmethod
def _is_stream_parallel_with(stmt: ast.stmt, global_vars: dict[str, Any]) -> bool:
    """Return True if ``stmt`` is exactly ``with qd.stream_parallel():``.

    Args:
        stmt: A top-level statement from the kernel body.
        global_vars: Globals used to resolve the context expression.

    Returns:
        True only for a ``with`` statement with a single context manager
        that is a call resolving to ``stream_parallel``. ``stmt.items``
        holds one ``withitem`` per manager in the ``with`` header (e.g.
        ``with a(), b():`` yields two), so any multi-manager form is not
        treated as a stream-parallel block.
    """
    if not isinstance(stmt, ast.With):
        return False
    if len(stmt.items) != 1:
        return False
    item = stmt.items[0]
    if not isinstance(item.context_expr, ast.Call):
        return False
    # Normalize to bool: resolve_to's result is used as a truth value and
    # this helper is annotated to return bool.
    return bool(ASTResolver.resolve_to(item.context_expr.func, stream_parallel, global_vars))

@staticmethod
def _validate_stream_parallel_exclusivity(body: list[ast.stmt], global_vars: dict[str, Any]) -> None:
    """Ensure stream-parallel blocks are not mixed with other statements.

    If any top-level statement of the kernel body is a
    ``with qd.stream_parallel():`` block, then every top-level statement
    must be one; kernels that use stream parallelism may not interleave
    sequential top-level code.

    Args:
        body: The top-level statements of the kernel body.
        global_vars: Globals used to resolve context expressions.

    Raises:
        QuadrantsSyntaxError: If stream-parallel and non-stream-parallel
            statements are mixed at the top level of the kernel.
    """
    # Kernels that never open a stream_parallel block need no validation.
    if not any(FunctionDefTransformer._is_stream_parallel_with(s, global_vars) for s in body):
        return
    for stmt in body:
        if not FunctionDefTransformer._is_stream_parallel_with(stmt, global_vars):
            raise QuadrantsSyntaxError(
                "When using qd.stream_parallel(), all top-level statements "
                "in the kernel must be 'with qd.stream_parallel():' blocks. "
                "Move non-parallel code to a separate kernel."
            )
Comment on lines +327 to +330
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't understand why you are moving to the next line before you have to. This is weird to me. But I don't care much.

14 changes: 13 additions & 1 deletion python/quadrants/lang/stream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import weakref
from contextlib import contextmanager

from quadrants.lang import impl

Expand Down Expand Up @@ -121,4 +122,15 @@ def create_event() -> Event:
return Event(handle, _get_prog_weakref())


__all__ = ["Stream", "Event", "create_stream", "create_event"]
@contextmanager
def stream_parallel():
    """Mark a block whose top-level for loops run on separate GPU streams.

    Intended for use inside ``@qd.kernel``. Executed as plain Python
    (outside a kernel) the context manager does nothing. During kernel
    compilation the AST transformer recognizes this block and has the C++
    ASTBuilder tag the enclosed loops with a stream-parallel group ID.
    """
    # No runtime behavior is needed here: the kernel compiler matches this
    # call syntactically, so the Python object only has to be a valid
    # (empty) context manager.
    yield


__all__ = ["Stream", "Event", "create_stream", "create_event", "stream_parallel"]
1 change: 1 addition & 0 deletions quadrants/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(stmt->strictly_serialized);
emit(stmt->mem_access_opt);
emit(stmt->block_dim);
emit(stmt->stream_parallel_group_id);
emit(stmt->body.get());
}

Expand Down
1 change: 1 addition & 0 deletions quadrants/codegen/amdgpu/codegen_amdgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
current_task->grid_dim = num_SMs * query_max_block_per_sm;
}
current_task->block_dim = stmt->block_dim;
current_task->stream_parallel_group_id = stmt->stream_parallel_group_id;
QD_ASSERT(current_task->grid_dim != 0);
QD_ASSERT(current_task->block_dim != 0);
offloaded_tasks.push_back(*current_task);
Expand Down
1 change: 1 addition & 0 deletions quadrants/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
}
current_task->block_dim = stmt->block_dim;
current_task->dynamic_shared_array_bytes = dynamic_shared_array_bytes;
current_task->stream_parallel_group_id = stmt->stream_parallel_group_id;
QD_ASSERT(current_task->grid_dim != 0);
QD_ASSERT(current_task->block_dim != 0);
offloaded_tasks.push_back(*current_task);
Expand Down
13 changes: 10 additions & 3 deletions quadrants/codegen/llvm/llvm_compiled_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,23 @@ class OffloadedTask {
int block_dim{0};
int grid_dim{0};
int dynamic_shared_array_bytes{0};
int stream_parallel_group_id{0};

explicit OffloadedTask(const std::string &name = "",
int block_dim = 0,
int grid_dim = 0,
int dynamic_shared_array_bytes = 0)
int dynamic_shared_array_bytes = 0,
int stream_parallel_group_id = 0)
: name(name),
block_dim(block_dim),
grid_dim(grid_dim),
dynamic_shared_array_bytes(dynamic_shared_array_bytes) {};
QD_IO_DEF(name, block_dim, grid_dim, dynamic_shared_array_bytes);
dynamic_shared_array_bytes(dynamic_shared_array_bytes),
stream_parallel_group_id(stream_parallel_group_id) {};
QD_IO_DEF(name,
block_dim,
grid_dim,
dynamic_shared_array_bytes,
stream_parallel_group_id);
};

struct LLVMCompiledTask {
Expand Down
12 changes: 11 additions & 1 deletion quadrants/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,16 @@ FrontendForStmt::FrontendForStmt(const FrontendForStmt &o)
num_cpu_threads(o.num_cpu_threads),
strictly_serialized(o.strictly_serialized),
mem_access_opt(o.mem_access_opt),
block_dim(o.block_dim) {
block_dim(o.block_dim),
stream_parallel_group_id(o.stream_parallel_group_id) {
}

void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) {
is_bit_vectorized = config.is_bit_vectorized;
strictly_serialized = config.strictly_serialized;
mem_access_opt = config.mem_access_opt;
block_dim = config.block_dim;
stream_parallel_group_id = config.stream_parallel_group_id;
if (arch == Arch::cuda || arch == Arch::amdgpu) {
num_cpu_threads = 1;
QD_ASSERT(block_dim <= quadrants_max_gpu_block_dim);
Expand Down Expand Up @@ -1542,6 +1544,8 @@ void ASTBuilder::begin_frontend_range_for(const Expr &i,
const Expr &s,
const Expr &e,
const DebugInfo &dbg_info) {
for_loop_dec_.config.stream_parallel_group_id =
current_stream_parallel_group_id_;
auto stmt_unique = std::make_unique<FrontendForStmt>(
i, s, e, arch_, for_loop_dec_.config, dbg_info);
auto stmt = stmt_unique.get();
Expand All @@ -1558,6 +1562,8 @@ void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars,
for_loop_dec_.config.strictly_serialized,
"ti.loop_config(serialize=True) does not have effect on the struct for. "
"The execution order is not guaranteed.");
for_loop_dec_.config.stream_parallel_group_id =
current_stream_parallel_group_id_;
auto stmt_unique = std::make_unique<FrontendForStmt>(
loop_vars, snode, arch_, for_loop_dec_.config, dbg_info);
for_loop_dec_.reset();
Expand All @@ -1574,6 +1580,8 @@ void ASTBuilder::begin_frontend_struct_for_on_external_tensor(
for_loop_dec_.config.strictly_serialized,
"ti.loop_config(serialize=True) does not have effect on the struct for. "
"The execution order is not guaranteed.");
for_loop_dec_.config.stream_parallel_group_id =
current_stream_parallel_group_id_;
auto stmt_unique = std::make_unique<FrontendForStmt>(
loop_vars, external_tensor, arch_, for_loop_dec_.config, dbg_info);
for_loop_dec_.reset();
Expand All @@ -1591,6 +1599,8 @@ void ASTBuilder::begin_frontend_mesh_for(
for_loop_dec_.config.strictly_serialized,
"ti.loop_config(serialize=True) does not have effect on the mesh for. "
"The execution order is not guaranteed.");
for_loop_dec_.config.stream_parallel_group_id =
current_stream_parallel_group_id_;
auto stmt_unique =
std::make_unique<FrontendForStmt>(ExprGroup(i), mesh_ptr, element_type,
arch_, for_loop_dec_.config, dbg_info);
Expand Down
15 changes: 15 additions & 0 deletions quadrants/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct ForLoopConfig {
MemoryAccessOptions mem_access_opt;
int block_dim{0};
bool uniform{false};
int stream_parallel_group_id{0};
};

#define QD_DEFINE_CLONE_FOR_FRONTEND_IR \
Expand Down Expand Up @@ -207,6 +208,7 @@ class FrontendForStmt : public Stmt {
bool strictly_serialized;
MemoryAccessOptions mem_access_opt;
int block_dim;
int stream_parallel_group_id{0};

FrontendForStmt(const ExprGroup &loop_vars,
SNode *snode,
Expand Down Expand Up @@ -952,6 +954,7 @@ class ASTBuilder {
config.mem_access_opt.clear();
config.block_dim = 0;
config.strictly_serialized = false;
config.stream_parallel_group_id = 0;
}
};

Expand All @@ -961,6 +964,8 @@ class ASTBuilder {
Arch arch_;
ForLoopDecoratorRecorder for_loop_dec_;
int id_counter_{0};
int stream_parallel_group_counter_{0};
int current_stream_parallel_group_id_{0};

public:
ASTBuilder(Block *initial, Arch arch, bool is_kernel)
Expand Down Expand Up @@ -1107,6 +1112,16 @@ class ASTBuilder {
for_loop_dec_.reset();
}

// Enter a stream-parallel region: allocate a fresh nonzero group ID that
// subsequently-built for loops are tagged with. Group ID 0 means "not in a
// stream-parallel block", so a nonzero current ID here indicates nesting,
// which is rejected.
void begin_stream_parallel() {
  QD_ERROR_IF(current_stream_parallel_group_id_ != 0,
              "Nested stream_parallel blocks are not supported");
  current_stream_parallel_group_id_ = ++stream_parallel_group_counter_;
}

// Leave the stream-parallel region; loops built afterwards get group ID 0.
void end_stream_parallel() {
  current_stream_parallel_group_id_ = 0;
}

Identifier get_next_id(const std::string &name = "") {
return Identifier(id_counter_++, name);
}
Expand Down
3 changes: 3 additions & 0 deletions quadrants/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ std::unique_ptr<Stmt> RangeForStmt::clone() const {
begin, end, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim,
strictly_serialized);
new_stmt->reversed = reversed;
new_stmt->stream_parallel_group_id = stream_parallel_group_id;
return new_stmt;
}

Expand All @@ -265,6 +266,7 @@ std::unique_ptr<Stmt> StructForStmt::clone() const {
auto new_stmt = std::make_unique<StructForStmt>(
snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim);
new_stmt->mem_access_opt = mem_access_opt;
new_stmt->stream_parallel_group_id = stream_parallel_group_id;
return new_stmt;
}

Expand Down Expand Up @@ -439,6 +441,7 @@ std::unique_ptr<Stmt> OffloadedStmt::clone() const {
new_stmt->tls_size = tls_size;
new_stmt->bls_size = bls_size;
new_stmt->mem_access_opt = mem_access_opt;
new_stmt->stream_parallel_group_id = stream_parallel_group_id;
return new_stmt;
}

Expand Down
3 changes: 3 additions & 0 deletions quadrants/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,7 @@ class RangeForStmt : public Stmt {
int block_dim;
bool strictly_serialized;
std::string range_hint;
int stream_parallel_group_id{0};

RangeForStmt(Stmt *begin,
Stmt *end,
Expand Down Expand Up @@ -1061,6 +1062,7 @@ class StructForStmt : public Stmt {
int num_cpu_threads;
int block_dim;
MemoryAccessOptions mem_access_opt;
int stream_parallel_group_id{0};

StructForStmt(SNode *snode,
std::unique_ptr<Block> &&body,
Expand Down Expand Up @@ -1443,6 +1445,7 @@ class OffloadedStmt : public Stmt {
std::size_t tls_size{1}; // avoid allocating dynamic memory with 0 byte
std::size_t bls_size{0};
MemoryAccessOptions mem_access_opt;
int stream_parallel_group_id{0};

OffloadedStmt(TaskType task_type, Arch arch, Kernel *kernel);

Expand Down
4 changes: 3 additions & 1 deletion quadrants/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,9 @@ void export_lang(py::module &m) {
.def("strictly_serialize", &ASTBuilder::strictly_serialize)
.def("block_dim", &ASTBuilder::block_dim)
.def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag)
.def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag);
.def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag)
.def("begin_stream_parallel", &ASTBuilder::begin_stream_parallel)
.def("end_stream_parallel", &ASTBuilder::end_stream_parallel);

auto device_capability_config =
py::class_<DeviceCapabilityConfig>(m, "DeviceCapabilityConfig")
Expand Down
Loading
Loading