Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions backend/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
dump_ir = os.environ.get("DLC_DUMP_IR", "0") == "1"
replace_ttshared_ir = os.environ.get("DLC_REPLACE_TTSHARED_IR_FILE", None)
replace_linked_ir = os.environ.get("DLC_REPLACE_LINKED_IR_FILE", None)
if dump_ir or (replace_ttshared_ir is not None) or (replace_linked_ir is not None):
replace_commonir_linked_ir = os.environ.get("DLC_REPLACE_COMMONIR_LINKED_IR_FILE", None)
if (
dump_ir
or (replace_ttshared_ir is not None)
or (replace_linked_ir is not None)
or (replace_commonir_linked_ir is not None)
):
os.environ["TRITON_ALWAYS_COMPILE"] = "1"
dump_dir = "./tmp"
os.environ["TRITON_DUMP_DIR"] = os.environ.get("TRITON_DUMP_DIR", dump_dir)
Expand Down Expand Up @@ -441,6 +447,7 @@ def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False):
"--scalar-to-1d-tensor",
f"--linalg-to-linked=global-kernel=false named-ops=true",
"--linked-to-hivm",
"--vectorize-parallel-loop",
"-o",
dst_path,
]
Expand Down Expand Up @@ -471,14 +478,15 @@ def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False):
"--scalar-to-1d-tensor",
f"--linalg-to-linked=global-kernel=false named-ops=true",
"--linked-to-hivm",
"--vectorize-parallel-loop",
]
dicp_utils._dump_stage_ir(
content, metadata["hash"], "kernel.linkedir.mlir", cmd_list
)

if replace_linked_ir is not None:
print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
return Path(replace_linked_ir).read_text()
if replace_commonir_linked_ir is not None:
print(f"[DEBUG] Replace Linkedir with {replace_commonir_linked_ir}")
return Path(replace_commonir_linked_ir).read_text()
return content


Expand All @@ -489,6 +497,7 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
dicp_triton.passes.linked_npu.add_linalg_if_to_select(pm)
dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm)
dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm)
# dicp_triton.passes.linked_npu.add_vectorize_kernel(pm) # add the vectorize-kernel pass (currently disabled)
dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True)
dicp_triton.passes.linked_npu.add_linked_to_hivm(pm)
pm.run(mod)
Expand Down Expand Up @@ -700,9 +709,9 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt):
callback_path = os.path.join(tmpdir, "libkernel.so")
multibuffer = metadata["multibuffer"]
_compile_option_list = []
if dump_ir:
_compile_option_list += [f"--mlir-print-ir-before-all"]
_compile_option_list += [f"--mlir-print-ir-after-all"]
# if dump_ir:
# _compile_option_list += [f"--mlir-print-ir-before-all"]
# _compile_option_list += [f"--mlir-print-ir-after-all"]
_compile_option_list += [
f"--enable-auto-multi-buffer={multibuffer}",
]
Expand Down
3 changes: 3 additions & 0 deletions compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ std::unique_ptr<OperationPass<mlir::func::FuncOp>> createScalarTo1DTensorPass();
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
createNormalizeSliceOpsPass();

/// Factory for the pass registered as `vectorize-parallel-loop` in Passes.td,
/// whose summary reads "Convert single-element parallel loops to vectorized
/// batch processing with step=size."
///
/// \returns a newly created pass instance that operates on
/// `mlir::func::FuncOp`.
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
createVectorizeParallelLoopPass();

#define GEN_PASS_REGISTRATION
#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc"

Expand Down
53 changes: 33 additions & 20 deletions compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
include "mlir/Pass/PassBase.td"

def LinalgIfToSelect : Pass<"linalg-if-to-select", "mlir::ModuleOp"> {
let summary = "Convert scf.if inside parallel linalg.generic to arith.select and hoist selects.";
let summary = "Convert scf.if inside parallel linalg.generic to arith.select "
"and hoist selects.";
let description = [{
This pass converts conditional logic (`scf.if`) inside a parallel `linalg.generic`
operation into data-flow operations (`arith.select`).
Expand All @@ -23,14 +24,13 @@ def LinalgIfToSelect : Pass<"linalg-if-to-select", "mlir::ModuleOp"> {
simplify the body of the `linalg.generic` and enables further fusion.
}];
let constructor = "mlir::dicp::LinalgExt::createLinalgIfToSelectPass()";
let dependentDialects = ["mlir::linalg::LinalgDialect",
"mlir::arith::ArithDialect",
"mlir::tensor::TensorDialect",
"mlir::scf::SCFDialect",
"mlir::memref::MemRefDialect"];
let dependentDialects = [
"mlir::linalg::LinalgDialect", "mlir::arith::ArithDialect",
"mlir::tensor::TensorDialect", "mlir::scf::SCFDialect",
"mlir::memref::MemRefDialect"
];
}


def LinalgGenericToSCF : Pass<"linalg-generic-to-scf", "mlir::ModuleOp"> {
let summary = "Lower linalg.generic ops to SCF loops";
let description = [{
Expand All @@ -42,22 +42,20 @@ def LinalgGenericToSCF : Pass<"linalg-generic-to-scf", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}


def ScalarTo1DTensor : Pass<"scalar-to-1d-tensor", "mlir::func::FuncOp"> {
let summary = "Convert scalar computations and memref load/store to tensor<1 x T> form";
let description = [{
The ScalarTo1DTensor pass targets scalar computations and
memory access within a function and rewrites them into an explicit
tensor<1 x T>-based form. This enables uniform handling of scalar
values in subsequent bufferization and lowering passes.
}];

let summary =
"Convert scalar computations and memref load/store to tensor<1 x T> form";
let description = [{
The ScalarTo1DTensor pass targets scalar computations and
memory access within a function and rewrites them into an explicit
tensor<1 x T>-based form. This enables uniform handling of scalar
values in subsequent bufferization and lowering passes.
}];

let constructor = "mlir::dicp::LinalgExt::createScalarTo1DTensorPass()";
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::memref::MemRefDialect",
"mlir::tensor::TensorDialect",
"mlir::bufferization::BufferizationDialect",
"mlir::arith::ArithDialect", "mlir::memref::MemRefDialect",
"mlir::tensor::TensorDialect", "mlir::bufferization::BufferizationDialect",
"mlir::func::FuncDialect"
];
}
Expand All @@ -68,4 +66,19 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> {
let dependentDialects = ["mlir::tensor::TensorDialect"];
}

// Pass record for the `vectorize-parallel-loop` command-line flag, anchored on
// func::FuncOp. Instances are produced by the C++ factory named in
// `constructor`.
def VectorizeParallelLoop : Pass<"vectorize-parallel-loop", "func::FuncOp"> {
  // Short summary, written as a single literal; the text is the same as the
  // two adjacent literals it replaces.
  let summary = "Convert single-element parallel loops to vectorized batch processing with step=size.";

  // Long-form description. The literal's text is kept exactly as it was.
  let description = [{This pass transforms parallel loops that process single
elements into vectorized operations that can process
multiple elements simultaneously,
potentially increasing computational throughput.}];

  let constructor = "mlir::dicp::LinalgExt::createVectorizeParallelLoopPass()";

  // Dialects this pass declares as dependencies, one per line.
  let dependentDialects = [
    "mlir::arith::ArithDialect",
    "mlir::memref::MemRefDialect",
    "mlir::tensor::TensorDialect",
    "mlir::scf::SCFDialect",
    "mlir::bufferization::BufferizationDialect",
    "mlir::func::FuncDialect"
  ];
}

#endif
1 change: 1 addition & 0 deletions compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_triton_library(LinalgExtTransforms
ScalarTo1DTensorPass.cpp
RemoveSingleIterationLoop.cpp
TensorTransform.cpp
VectorizeParallelLoopPass.cpp

DEPENDS
LinalgExtTransformsIncGen
Expand Down
Loading