diff --git a/backend/npu.py b/backend/npu.py index fdac4bb4..21afec22 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -23,7 +23,13 @@ dump_ir = os.environ.get("DLC_DUMP_IR", "0") == "1" replace_ttshared_ir = os.environ.get("DLC_REPLACE_TTSHARED_IR_FILE", None) replace_linked_ir = os.environ.get("DLC_REPLACE_LINKED_IR_FILE", None) -if dump_ir or (replace_ttshared_ir is not None) or (replace_linked_ir is not None): +replace_commonir_linked_ir = os.environ.get("DLC_REPLACE_COMMONIR_LINKED_IR_FILE", None) +if ( + dump_ir + or (replace_ttshared_ir is not None) + or (replace_linked_ir is not None) + or (replace_commonir_linked_ir is not None) +): os.environ["TRITON_ALWAYS_COMPILE"] = "1" dump_dir = "./tmp" os.environ["TRITON_DUMP_DIR"] = os.environ.get("TRITON_DUMP_DIR", dump_dir) @@ -441,6 +447,7 @@ def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False): "--scalar-to-1d-tensor", f"--linalg-to-linked=global-kernel=false named-ops=true", "--linked-to-hivm", + "--vectorize-parallel-loop", "-o", dst_path, ] @@ -471,14 +478,15 @@ def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False): "--scalar-to-1d-tensor", f"--linalg-to-linked=global-kernel=false named-ops=true", "--linked-to-hivm", + "--vectorize-parallel-loop", ] dicp_utils._dump_stage_ir( content, metadata["hash"], "kernel.linkedir.mlir", cmd_list ) - if replace_linked_ir is not None: - print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}") - return Path(replace_linked_ir).read_text() + if replace_commonir_linked_ir is not None: + print(f"[DEBUG] Replace Linkedir with {replace_commonir_linked_ir}") + return Path(replace_commonir_linked_ir).read_text() return content @@ -489,6 +497,7 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): dicp_triton.passes.linked_npu.add_linalg_if_to_select(pm) dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm) dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) + # 
dicp_triton.passes.linked_npu.add_vectorize_kernel(pm) # 添加vectorize-kernel pass dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) dicp_triton.passes.linked_npu.add_linked_to_hivm(pm) pm.run(mod) @@ -700,9 +709,9 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): callback_path = os.path.join(tmpdir, "libkernel.so") multibuffer = metadata["multibuffer"] _compile_option_list = [] - if dump_ir: - _compile_option_list += [f"--mlir-print-ir-before-all"] - _compile_option_list += [f"--mlir-print-ir-after-all"] + # if dump_ir: + # _compile_option_list += [f"--mlir-print-ir-before-all"] + # _compile_option_list += [f"--mlir-print-ir-after-all"] _compile_option_list += [ f"--enable-auto-multi-buffer={multibuffer}", ] diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h index 7ae43b6c..ca63827d 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h @@ -23,6 +23,9 @@ std::unique_ptr> createScalarTo1DTensorPass(); std::unique_ptr> createNormalizeSliceOpsPass(); +std::unique_ptr> +createVectorizeParallelLoopPass(); + #define GEN_PASS_REGISTRATION #include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td index c486210a..04041fea 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td @@ -4,7 +4,8 @@ include "mlir/Pass/PassBase.td" def LinalgIfToSelect : Pass<"linalg-if-to-select", "mlir::ModuleOp"> { - let summary = "Convert scf.if inside parallel linalg.generic to arith.select and hoist selects."; + let summary = "Convert scf.if inside parallel linalg.generic to arith.select " + "and hoist selects."; let description = [{ This pass converts conditional logic 
(`scf.if`) inside a parallel `linalg.generic` operation into data-flow operations (`arith.select`). @@ -23,14 +24,13 @@ def LinalgIfToSelect : Pass<"linalg-if-to-select", "mlir::ModuleOp"> { simplify the body of the `linalg.generic` and enables further fusion. }]; let constructor = "mlir::dicp::LinalgExt::createLinalgIfToSelectPass()"; - let dependentDialects = ["mlir::linalg::LinalgDialect", - "mlir::arith::ArithDialect", - "mlir::tensor::TensorDialect", - "mlir::scf::SCFDialect", - "mlir::memref::MemRefDialect"]; + let dependentDialects = [ + "mlir::linalg::LinalgDialect", "mlir::arith::ArithDialect", + "mlir::tensor::TensorDialect", "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect" + ]; } - def LinalgGenericToSCF : Pass<"linalg-generic-to-scf", "mlir::ModuleOp"> { let summary = "Lower linalg.generic ops to SCF loops"; let description = [{ @@ -42,22 +42,20 @@ def LinalgGenericToSCF : Pass<"linalg-generic-to-scf", "mlir::ModuleOp"> { let dependentDialects = ["mlir::linalg::LinalgDialect"]; } - def ScalarTo1DTensor : Pass<"scalar-to-1d-tensor", "mlir::func::FuncOp"> { - let summary = "Convert scalar computations and memref load/store to tensor<1 x T> form"; - let description = [{ - The ScalarTo1DTensor pass targets scalar computations and - memory access within a function and rewrites them into an explicit - tensor<1 x T>-based form. This enables uniform handling of scalar - values in subsequent bufferization and lowering passes. 
- }]; - + let summary = + "Convert scalar computations and memref load/store to tensor<1 x T> form"; + let description = + [{The ScalarTo1DTensor pass targets scalar computations and memory access + within a function and rewrites them into an explicit tensor<1 x T>-based + form. + This enables uniform handling of scalar values in subsequent + bufferization and lowering passes.}]; + let constructor = "mlir::dicp::LinalgExt::createScalarTo1DTensorPass()"; let dependentDialects = [ - "mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::tensor::TensorDialect", - "mlir::bufferization::BufferizationDialect", + "mlir::arith::ArithDialect", "mlir::memref::MemRefDialect", + "mlir::tensor::TensorDialect", "mlir::bufferization::BufferizationDialect", "mlir::func::FuncDialect" ]; } @@ -68,4 +66,19 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> { let dependentDialects = ["mlir::tensor::TensorDialect"]; } +def VectorizeParallelLoop : Pass<"vectorize-parallel-loop", "func::FuncOp"> { + let summary = "Convert single-element parallel loops to vectorized batch " + "processing with step=size."; + let description = [{This pass transforms parallel loops that process single + elements into vectorized operations that can process + multiple elements simultaneously, + potentially increasing computational throughput.}]; + let constructor = "mlir::dicp::LinalgExt::createVectorizeParallelLoopPass()"; + let dependentDialects = [ + "mlir::arith::ArithDialect", "mlir::memref::MemRefDialect", + "mlir::tensor::TensorDialect", "mlir::scf::SCFDialect", + "mlir::bufferization::BufferizationDialect", "mlir::func::FuncDialect" + ]; +} + #endif diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt index 0b28548a..f5d52f63 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ 
add_triton_library(LinalgExtTransforms ScalarTo1DTensorPass.cpp RemoveSingleIterationLoop.cpp TensorTransform.cpp + VectorizeParallelLoopPass.cpp DEPENDS LinalgExtTransformsIncGen diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp new file mode 100644 index 00000000..e73ae778 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp @@ -0,0 +1,413 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace mlir { +namespace dicp { +namespace LinalgExt { +#define GEN_PASS_DEF_VECTORIZEPARALLELLOOPPASS +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + +#define DEBUG_TYPE "vectorize-parallel-loop-pass" + +namespace { + +// 核心 Pattern:将标量并行循环展开为向量化的顺序操作 +struct VectorizeParallelLoopPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ParallelOp op, + PatternRewriter &rewriter) const override { + LLVM_DEBUG( + llvm::dbgs() + << "\n[VectorizeParallelLoop] >>> Start matching scf.parallel at " + << op.getLoc() << "\n"); + + // 1. 
检查循环结构 + if (op.getNumLoops() != 1) { + LLVM_DEBUG(llvm::dbgs() + << "[VectorizeParallelLoop] Skip: Multi-dimensional loop.\n"); + return failure(); + } + + Value lowerBound = op.getLowerBound()[0]; + Value upperBound = op.getUpperBound()[0]; + + auto lowerOp = lowerBound.getDefiningOp(); + auto upperOp = upperBound.getDefiningOp(); + + if (!lowerOp || !upperOp) { + LLVM_DEBUG(llvm::dbgs() + << "[VectorizeParallelLoop] Skip: Bounds are not constant.\n"); + return failure(); + } + + int64_t lowerVal = lowerOp.value(); + int64_t upperVal = upperOp.value(); + int64_t size = upperVal - lowerVal; + + LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Loop Bounds: [" + << lowerVal << ", " << upperVal << ")\n"); + LLVM_DEBUG(llvm::dbgs() + << "[VectorizeParallelLoop] Calculated Vector Size: " << size + << "\n"); + + // 只有当有实际计算量时才处理 + if (size <= 0) { + LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Skip: Size <= 0.\n"); + return failure(); + } + + // 2. 准备映射表 + // mapper: 用于处理索引计算 (将 Loop IV 映射为常数 LowerBound) + IRMapping mapper; + Block *body = op.getBody(); + Value iv = body->getArgument(0); + + LLVM_DEBUG(llvm::dbgs() + << "[VectorizeParallelLoop] Mapping Induction Variable " << iv + << " -> Constant " << lowerBound << "\n"); + mapper.map(iv, lowerBound); // 关键修复:将 IV 替换为 Loop 起始值 + + // scalarToTensorMap: 用于数据流向量化 (标量 Value -> 向量 Tensor Value) + DenseMap scalarToTensorMap; + + LLVM_DEBUG( + llvm::dbgs() + << "[VectorizeParallelLoop] Starting to process body operations...\n"); + + // 3. 
遍历原循环体,按顺序生成向量化代码 + for (Operation &inst : body->getOperations()) { + LLVM_DEBUG(llvm::dbgs() + << " -> Visiting Op: " << inst.getName() << "\n"); + + // 跳过 terminator + if (isa(inst) || isa(inst)) { + LLVM_DEBUG(llvm::dbgs() << " Skipping terminator.\n"); + continue; + } + + // --- Case A: 索引计算 (Index Cast, Add, Mul 等) --- + // 直接克隆,但使用 mapper 将 IV 替换为常数 + if (isa(inst)) { + LLVM_DEBUG(llvm::dbgs() + << " [Action] Cloning index calculation.\n"); + Operation *newOp = rewriter.clone(inst, mapper); + LLVM_DEBUG(llvm::dbgs() + << " New Op result: " << newOp->getResult(0) << "\n"); + continue; + } + + // --- Case B: 读取内存 (Load -> Vectorize) --- + if (auto loadOp = dyn_cast(inst)) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Vectorizing LoadOp.\n"); + Value memref = loadOp.getMemRef(); + // 获取计算好的索引 (通过 mapper 查找) + Value index = mapper.lookup(loadOp.getIndices()[0]); + LLVM_DEBUG(llvm::dbgs() << " Base MemRef: " << memref << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Mapped Index: " << index << "\n"); + + // 1. Alloc Local Buffer + auto memrefType = dyn_cast(memref.getType()); + if (!memrefType) { + LLVM_DEBUG(llvm::dbgs() + << "[VectorizeParallelLoop] ERROR: MemRef type expected " + "but not found.\n"); + return failure(); + } + auto localType = MemRefType::get({size}, memrefType.getElementType()); + Value localAlloc = + rewriter.create(op.getLoc(), localType); + LLVM_DEBUG(llvm::dbgs() << " Created Local Alloc: " + << localAlloc.getType() << "\n"); + + // 2. Subview Global Memory + SmallVector offsets = {index}; + SmallVector sizes = {rewriter.getIndexAttr(size)}; + SmallVector strides = {rewriter.getIndexAttr(1)}; + Value subview = rewriter.create( + op.getLoc(), memref, offsets, sizes, strides); + LLVM_DEBUG(llvm::dbgs() << " Created Subview.\n"); + + // 3. Copy Global -> Local + rewriter.create(op.getLoc(), subview, localAlloc); + LLVM_DEBUG(llvm::dbgs() << " Created Copy (Global -> Local).\n"); + + // 4. 
Local Buffer -> Tensor + auto tensorType = + RankedTensorType::get({size}, memrefType.getElementType()); + auto toTensor = rewriter.create( + op.getLoc(), tensorType, localAlloc, /*restrict=*/true); + LLVM_DEBUG(llvm::dbgs() << " Created ToTensorOp (Result: " + << toTensor.getResult() << ").\n"); + + // 5. 注册映射:原 Load 的标量结果 -> 新的 Tensor 结果 + scalarToTensorMap[loadOp.getResult()] = toTensor.getResult(); + continue; + } + + // --- Case C: 计算逻辑 (Generic Binary Operations -> Vector Binary + // Operations) --- 检查是否为二元运算操作 + bool isBinaryOp = + inst.getNumOperands() == 2 && + (isa(inst)); + + if (isBinaryOp) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Processing Binary ArithOp: " + << inst.getName() << "\n"); + + Value lhs = inst.getOperand(0); + Value rhs = inst.getOperand(1); + + // 检查操作数是否已向量化 + Value vecLhs = + scalarToTensorMap.count(lhs) ? scalarToTensorMap[lhs] : nullptr; + Value vecRhs = + scalarToTensorMap.count(rhs) ? scalarToTensorMap[rhs] : nullptr; + + if (vecLhs) + LLVM_DEBUG(llvm::dbgs() << " LHS is vectorized.\n"); + if (vecRhs) + LLVM_DEBUG(llvm::dbgs() << " RHS is vectorized.\n"); + + // 如果两个输入都是向量,生成向量运算 + if (vecLhs && vecRhs) { + // 创建一个新的OperationState,使用与原操作相同的操作码 + OperationState state(op.getLoc(), inst.getName().getStringRef()); + + // 添加向量化的操作数 + state.addOperands({vecLhs, vecRhs}); + + // 从原操作复制结果类型,但转换为向量类型 + llvm::SmallVector resultTypes; + for (auto result : inst.getResults()) { + Type scalarType = result.getType(); + ShapedType vectorType; + + if (auto shapedType = dyn_cast(scalarType)) { + // 如果已经是shaped type,则保持形状但可能更新为tensor类型 + vectorType = RankedTensorType::get(shapedType.getShape(), + shapedType.getElementType()); + } else { + // 如果是标量类型,转换为对应元素类型的向量 + vectorType = RankedTensorType::get({size}, scalarType); + } + + resultTypes.push_back(vectorType); + } + state.addTypes(resultTypes); + + // 创建新的向量化操作 + auto newOp = rewriter.create(state); + + // 将新操作的结果映射到scalarToTensorMap + for (size_t i = 0; i < inst.getNumResults(); ++i) { + 
scalarToTensorMap[inst.getResult(i)] = newOp->getResult(i); + } + + LLVM_DEBUG({ + llvm::dbgs() << " Created Vector Operation: " << inst.getName() + << "\n"; + llvm::dbgs() << " Result Type: " + << newOp->getResult(0).getType() << "\n"; + }); + } else { + // 如果不是向量操作(可能是索引计算的一部分),则回退到普通 clone + LLVM_DEBUG( + llvm::dbgs() + << " WARNING: Operands not vectorized, cloning scalar op.\n"); + rewriter.clone(inst, mapper); + } + continue; + } + + // --- Case D: 写回逻辑 (Materialize) --- + if (auto matOp = + dyn_cast(inst)) { + LLVM_DEBUG(llvm::dbgs() + << " [Action] Processing MaterializeInDestinationOp.\n"); + Value source = matOp.getSource(); + Value destMemref = matOp.getDest(); + + Value vectorResult = nullptr; + + // 追踪数据来源 + if (auto insertOp = source.getDefiningOp()) { + LLVM_DEBUG( + llvm::dbgs() + << " Source is tensor.insert, tracing scalar input...\n"); + Value scalarInput = insertOp.getScalar(); + if (scalarToTensorMap.count(scalarInput)) { + vectorResult = scalarToTensorMap[scalarInput]; + LLVM_DEBUG(llvm::dbgs() << " Found vectorized source.\n"); + } + } else if (scalarToTensorMap.count(source)) { + vectorResult = scalarToTensorMap[source]; + LLVM_DEBUG(llvm::dbgs() + << " Found vectorized source directly.\n"); + } + + if (vectorResult) { + // 1. Alloc Output Buffer + auto tensorType = dyn_cast(vectorResult.getType()); + if (!tensorType) { + LLVM_DEBUG(llvm::dbgs() + << "[VectorizeParallelLoop] ERROR: Expected " + "RankedTensorType for vector result.\n"); + continue; + } + auto elemType = tensorType.getElementType(); + auto localOutType = MemRefType::get({size}, elemType); + Value localOut = + rewriter.create(op.getLoc(), localOutType); + LLVM_DEBUG(llvm::dbgs() << " Created Local Output Alloc: " + << localOutType << "\n"); + + // 2. 
Materialize Tensor -> Local Buffer + // Fix: capture operation and set writable to true + auto newMatOp = + rewriter.create( + op.getLoc(), vectorResult, localOut); + newMatOp.setWritable(true); + LLVM_DEBUG( + llvm::dbgs() + << " Created Vectorized Materialize (writable=true).\n"); + + // 3. 处理输出地址 (ReinterpretCast -> Subview) + Value baseMemref = destMemref; + Value writeOffset = nullptr; + + if (auto castOp = + destMemref.getDefiningOp()) { + LLVM_DEBUG( + llvm::dbgs() + << " Dest is ReinterpretCast, resolving offset...\n"); + baseMemref = castOp.getSource(); + if (!castOp.getOffsets().empty()) { + // Fix: Directly use the Value, do not use dyn_cast + Value loopOffset = castOp.getOffsets()[0]; + writeOffset = mapper.lookup(loopOffset); + LLVM_DEBUG(llvm::dbgs() << " Resolved write offset: " + << writeOffset << "\n"); + } + } else { + LLVM_DEBUG(llvm::dbgs() + << " Dest is not ReinterpretCast. Handling logic " + "might be incomplete for simple memrefs.\n"); + } + + // 如果找到了写入位置,执行 Copy Local -> Global + if (baseMemref && writeOffset) { + SmallVector offsets = {writeOffset}; + SmallVector sizes = {rewriter.getIndexAttr(size)}; + SmallVector strides = {rewriter.getIndexAttr(1)}; + + Value outSubview = rewriter.create( + op.getLoc(), baseMemref, offsets, sizes, strides); + + rewriter.create(op.getLoc(), localOut, outSubview); + LLVM_DEBUG(llvm::dbgs() + << " Created Copy (Local -> Global).\n"); + } + } else { + LLVM_DEBUG(llvm::dbgs() + << " WARNING: Could not find vectorized source for " + "materialize.\n"); + } + continue; + } + + // 忽略不需要的操作 + if (isa(inst)) { + LLVM_DEBUG( + llvm::dbgs() + << " Skipping tensor.insert (handled in materialize).\n"); + continue; + } + if (isa(inst)) { + LLVM_DEBUG(llvm::dbgs() << " Skipping tensor.empty.\n"); + continue; + } + + LLVM_DEBUG(llvm::dbgs() + << " [Unhandled] Operation not handled specifically: " + << inst.getName() << "\n"); + } + + // 打印当前op + LLVM_DEBUG({ + llvm::dbgs() << "[VectorizeParallelLoop] Current Op: "; + 
op.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + // 打印映射表 + LLVM_DEBUG({ + llvm::dbgs() << "[VectorizeParallelLoop] Scalar to Tensor Map:\n"; + for (auto &[scalar, tensor] : scalarToTensorMap) { + llvm::dbgs() << " " << scalar << " -> " << tensor << "\n"; + } + }); + + // 4. 删除原循环 + LLVM_DEBUG( + llvm::dbgs() + << "[VectorizeParallelLoop] Erasing original scf.parallel op.\n"); + rewriter.eraseOp(op); + + LLVM_DEBUG(llvm::dbgs() + << "[VectorizeParallelLoop] <<< MatchAndRewrite Done.\n\n"); + return success(); + } +}; + +struct VectorizeParallelLoopPass + : public PassWrapper> { + StringRef getArgument() const final { return "vectorize-parallel-loop"; } + StringRef getDescription() const final { + return "Vectorize scf.parallel loops by unrolling and using bulk memory " + "ops."; + } + + void runOnOperation() override { + LLVM_DEBUG(llvm::dbgs() + << "[Pass] Starting VectorizeParallelLoopPass on function...\n"); + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + LLVM_DEBUG(llvm::dbgs() << "[Pass] Pattern application failed.\n"); + signalPassFailure(); + } else { + LLVM_DEBUG(llvm::dbgs() << "[Pass] Pattern application succeeded.\n"); + } + } + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizeParallelLoopPass) +}; + +} // namespace + +namespace mlir::dicp::LinalgExt { +std::unique_ptr> createVectorizeParallelLoopPass() { + return std::make_unique(); +} +} // namespace mlir::dicp::LinalgExt \ No newline at end of file diff --git a/test/ascend/mlir/vectorize_parallel_loop.mlir b/test/ascend/mlir/vectorize_parallel_loop.mlir new file mode 100644 index 00000000..fca31593 --- /dev/null +++ b/test/ascend/mlir/vectorize_parallel_loop.mlir @@ -0,0 +1,45 @@ +// RUN: %dicp_opt %s --vectorize-parallel-loop | %FileCheck %s +// /opt/conda/envs/commonir/lib/python3.10/site-packages/triton/_C/dicp_opt + +module { + func.func @main(%arg0: memref, %arg1: memref, 
%arg2: memref, %arg3: memref, %arg4: memref, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c1024_i32 = arith.constant 1024 : i32 + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = tensor.empty() : tensor<1xf32> + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [0], sizes: [1048576], strides: [1] : memref to memref<1048576xf32, strided<[1]>> + %reinterpret_cast_0 = memref.reinterpret_cast %arg3 to offset: [0], sizes: [1048576], strides: [1] : memref to memref<1048576xf32, strided<[1]>> + %1 = arith.muli %arg8, %c1024_i32 : i32 + scf.parallel (%arg11) = (%c0) to (%c1024) step (%c1) { + %2 = arith.index_cast %arg11 : index to i32 + %3 = arith.addi %1, %2 : i32 + %4 = arith.index_cast %3 : i32 to index + %5 = memref.load %reinterpret_cast[%4] : memref<1048576xf32, strided<[1]>> + %6 = memref.load %reinterpret_cast_0[%4] : memref<1048576xf32, strided<[1]>> + %7 = arith.addf %5, %6 : f32 + %inserted = tensor.insert %7 into %0[%c0] : tensor<1xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %arg4 to offset: [%4], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1], offset: ?>> + bufferization.materialize_in_destination %inserted in writable %reinterpret_cast_1 : (tensor<1xf32>, memref<1xf32, strided<[1], offset: ?>>) -> () + scf.reduce + } + return + } +} + +// CHECK-LABEL: func.func @main +// CHECK-NOT: scf.parallel +// CHECK: %[[ALLOC0:.+]] = memref.alloc() : memref<1024xf32> +// CHECK: %[[SUBVIEW0:.+]] = memref.subview %{{.+}}[%{{.+}}] [1024] [1] : memref<1048576xf32, strided<[1]>> to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: memref.copy %[[SUBVIEW0]], %[[ALLOC0]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> +// CHECK: %[[TENSOR0:.+]] = bufferization.to_tensor %[[ALLOC0]] restrict : memref<1024xf32> to 
tensor<1024xf32> +// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<1024xf32> +// CHECK: %[[SUBVIEW1:.+]] = memref.subview %{{.+}}[%{{.+}}] [1024] [1] : memref<1048576xf32, strided<[1]>> to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: memref.copy %[[SUBVIEW1]], %[[ALLOC1]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> +// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] restrict : memref<1024xf32> to tensor<1024xf32> +// CHECK: %[[RESULT:.+]] = arith.addf %[[TENSOR0]], %[[TENSOR1]] : tensor<1024xf32> +// CHECK: %[[ALLOC_RESULT:.+]] = memref.alloc() : memref<1024xf32> +// CHECK: bufferization.materialize_in_destination %[[RESULT]] in writable %[[ALLOC_RESULT]] : (tensor<1024xf32>, memref<1024xf32>) -> () +// CHECK: %[[SUBVIEW_OUT:.+]] = memref.subview %{{.+}}[%{{.+}}] [1024] [1] : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: memref.copy %[[ALLOC_RESULT]], %[[SUBVIEW_OUT]] : memref<1024xf32> to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: return diff --git a/test/commonir/ascend/vector_add.py b/test/commonir/ascend/vector_add.py new file mode 100644 index 00000000..5157b618 --- /dev/null +++ b/test/commonir/ascend/vector_add.py @@ -0,0 +1,244 @@ +import os +import time + +import tilelang +import tilelang.language as T + +import triton +import triton.language as tl + +import torch + +dtype = "float32" +seq_len = 1024 * 1024 # 增大到1M个元素,更能体现kernel优化效果,1048576 +block = 1024 + + +def vec_add(N, block_N, dtype="float32"): + n_num = ( + N // block_N + ) # n_num是块的数量 32,block_N是每个块处理的元素数量 32k,N是总元素数量 1M + # print(f"zmz debug : n_num={n_num}, block_N={block_N}, N={N}") + + @T.prim_func + def main( + A: T.Tensor((N), dtype), + B: T.Tensor((N), dtype), + C: T.Tensor((N), dtype), + ): + with T.Kernel(n_num, 1) as (tid, _): + start_idx = tid * block_N + for local_y in T.Parallel(block_N): + y = start_idx + local_y + C[y] = A[y] + B[y] + + return main + + +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first 
input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def create_test_data(): + """创建测试数据 - 使用更大的张量""" + v1 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() + v2 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() + v3 = torch.zeros(size=[seq_len], dtype=eval("torch." 
+ dtype)).npu() + return v1, v2, v3 + + +def test_tilelang_add(): + """测试 TileLang 实现""" + print("Testing TileLang implementation...") + + # 创建测试数据 + v1, v2, v3 = create_test_data() + y_ref = v1 + v2 + + # 编译 TileLang kernel + # func = vec_add(seq_len, seq_len // 32) # 使用更合适的块大小 1M, 32K + func = vec_add(seq_len, seq_len // block) + compiled_kernel = tilelang.compile(func, target="commonir") + + # 执行 TileLang kernel + compiled_kernel(v1, v2, v3) + + # 验证结果 + max_diff = torch.max(torch.abs(y_ref - v3)) + print(f"The maximum difference between torch and TileLang is {max_diff}") + + torch.testing.assert_close(v3, y_ref, atol=1e-2, rtol=0) + print("TileLang test passed!\n") + + return v1, v2, v3, y_ref + + +def test_triton_add(): + """测试 Triton 实现""" + print("Testing Triton implementation...") + + # 创建测试数据 + v1, v2, v3 = create_test_data() + y_ref = v1 + v2 + + # 设置块大小和网格 - 适应更大的数据集 + block_size = block # Triton常用的块大小 + grid = (triton.cdiv(seq_len, block_size),) # 修正网格定义 + + # 执行 Triton kernel + add_kernel[grid](v1, v2, v3, seq_len, BLOCK_SIZE=block_size) + + # 验证结果 + max_diff = torch.max(torch.abs(y_ref - v3)) + print(f"The maximum difference between torch and Triton is {max_diff}") + + torch.testing.assert_close(v3, y_ref, atol=1e-2, rtol=0) + print("Triton test passed!\n") + + return v1, v2, v3, y_ref + + +def benchmark_function(func, *args, num_runs=100, warmup_runs=10): + """性能测试函数""" + # 预热运行 + for _ in range(warmup_runs): + func(*args) + + # 同步NPU + if torch.npu.is_available(): + torch.npu.synchronize() + + # 正式测试 + start_time = time.time() + for _ in range(num_runs): + func(*args) + + # 同步NPU + if torch.npu.is_available(): + torch.npu.synchronize() + + end_time = time.time() + + avg_time = (end_time - start_time) / num_runs + return avg_time * 1000 # 转换为毫秒 + + +def run_performance_tests(): + """运行性能测试""" + print("=" * 60) + print( + f"PERFORMANCE TESTS - Vector size: {seq_len:,} elements ({seq_len * 4 / 1e6:.2f} MB)" + ) + print("=" * 60) + + # 创建测试数据 + v1, v2, 
v3 = create_test_data() + + # TileLang kernel + func = vec_add(seq_len, seq_len // block) # 使用更合适的块大小 + compiled_tilelang_kernel = tilelang.compile(func, target="commonir") + + def tilelang_benchmark(): + temp_v3 = torch.zeros_like(v3) + compiled_tilelang_kernel(v1, v2, temp_v3) + + def triton_benchmark(): + temp_v3 = torch.zeros_like(v3) + block_size = block # Triton常用块大小 + grid = (triton.cdiv(seq_len, block_size),) + add_kernel[grid](v1, v2, temp_v3, seq_len, BLOCK_SIZE=block_size) + + def torch_benchmark(): + temp_result = v1 + v2 + + # 运行基准测试 + print("Running benchmarks... (this may take a moment)") + + tilelang_time = benchmark_function(tilelang_benchmark, num_runs=100, warmup_runs=10) + triton_time = benchmark_function(triton_benchmark, num_runs=100, warmup_runs=10) + torch_time = benchmark_function(torch_benchmark, num_runs=100, warmup_runs=10) + + print(f"\nAverage execution time over 100 runs:") + print(f" TileLang: {tilelang_time:.4f} ms") + print(f" Triton: {triton_time:.4f} ms") + print(f" PyTorch: {torch_time:.4f} ms") + + # 计算吞吐量 + total_elements = seq_len + triton_throughput = (total_elements * 4 * 3 / 1e9) / ( + triton_time / 1000 + ) # GB/s (read two + write one) + tilelang_throughput = (total_elements * 4 * 3 / 1e9) / ( + tilelang_time / 1000 + ) # GB/s + torch_throughput = (total_elements * 4 * 3 / 1e9) / (torch_time / 1000) # GB/s + + print(f"\nThroughput Analysis:") + print(f" Triton: {triton_throughput:.2f} GB/s") + print(f" TileLang: {tilelang_throughput:.2f} GB/s") + print(f" PyTorch: {torch_throughput:.2f} GB/s") + + print(f"\nPerformance comparison relative to PyTorch:") + if triton_time <= torch_time: + speedup = torch_time / triton_time + print(f" Triton is {speedup:.2f}x FASTER than PyTorch") + else: + slowdown = triton_time / torch_time + print(f" Triton is {slowdown:.2f}x SLOWER than PyTorch") + + if tilelang_time <= torch_time: + speedup = torch_time / tilelang_time + print(f" TileLang is {speedup:.2f}x FASTER than PyTorch") + else: 
+ slowdown = tilelang_time / torch_time + print(f" TileLang is {slowdown:.2f}x SLOWER than PyTorch") + + if triton_time <= tilelang_time: + speedup = tilelang_time / triton_time + print(f" Triton is {speedup:.2f}x FASTER than TileLang") + else: + slowdown = triton_time / tilelang_time + print(f" Triton is {slowdown:.2f}x SLOWER than TileLang") + + +def main(): + """主函数""" + print("Vector Addition Comparison: TileLang vs Triton") + print("=" * 60) + + # 运行功能测试 + print("FUNCTIONALITY TESTS") + print("-" * 20) + tilelang_data = test_tilelang_add() + triton_data = test_triton_add() + + # 运行性能测试 + run_performance_tests() + + +if __name__ == "__main__": + main() diff --git a/test/commonir/ascend/vector_sub.py b/test/commonir/ascend/vector_sub.py new file mode 100644 index 00000000..742f9a0e --- /dev/null +++ b/test/commonir/ascend/vector_sub.py @@ -0,0 +1,243 @@ +import os +import time + +import tilelang +import tilelang.language as T + +import triton +import triton.language as tl + +import torch + +dtype = "float32" +seq_len = 1024 * 1024 # 增大到1M个元素,更能体现kernel优化效果,1048576 +block = 1024 + + +def vec_sub(N, block_N, dtype="float32"): + n_num = ( + N // block_N + ) # n_num是块的数量 32,block_N是每个块处理的元素数量 32k,N是总元素数量 1M + # print(f"zmz debug : n_num={n_num}, block_N={block_N}, N={N}") + + @T.prim_func + def main( + A: T.Tensor((N), dtype), + B: T.Tensor((N), dtype), + C: T.Tensor((N), dtype), + ): + with T.Kernel(n_num, 1) as (tid, _): + start_idx = tid * block_N + for local_y in T.Parallel(block_N): + y = start_idx + local_y + C[y] = A[y] - B[y] + + return main + + +@triton.jit +def sub_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. +): + # There are multiple 'programs' processing different data. 
We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x - y + # Write x - y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def create_test_data(): + """创建测试数据 - 使用更大的张量""" + v1 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() + v2 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() + v3 = torch.zeros(size=[seq_len], dtype=eval("torch." 
+ dtype)).npu() + return v1, v2, v3 + + +def test_tilelang_sub(): + """测试 TileLang 实现""" + print("Testing TileLang implementation...") + + # 创建测试数据 + v1, v2, v3 = create_test_data() + y_ref = v1 - v2 + + # 编译 TileLang kernel + func = vec_sub(seq_len, seq_len // block) + compiled_kernel = tilelang.compile(func, target="commonir") + + # 执行 TileLang kernel + compiled_kernel(v1, v2, v3) + + # 验证结果 + max_diff = torch.max(torch.abs(y_ref - v3)) + print(f"The maximum difference between torch and TileLang is {max_diff}") + + torch.testing.assert_close(v3, y_ref, atol=1e-2, rtol=0) + print("TileLang test passed!\n") + + return v1, v2, v3, y_ref + + +def test_triton_sub(): + """测试 Triton 实现""" + print("Testing Triton implementation...") + + # 创建测试数据 + v1, v2, v3 = create_test_data() + y_ref = v1 - v2 + + # 设置块大小和网格 - 适应更大的数据集 + block_size = block # Triton常用的块大小 + grid = (triton.cdiv(seq_len, block_size),) # 修正网格定义 + + # 执行 Triton kernel + sub_kernel[grid](v1, v2, v3, seq_len, BLOCK_SIZE=block_size) + + # 验证结果 + max_diff = torch.max(torch.abs(y_ref - v3)) + print(f"The maximum difference between torch and Triton is {max_diff}") + + torch.testing.assert_close(v3, y_ref, atol=1e-2, rtol=0) + print("Triton test passed!\n") + + return v1, v2, v3, y_ref + + +def benchmark_function(func, *args, num_runs=100, warmup_runs=10): + """性能测试函数""" + # 预热运行 + for _ in range(warmup_runs): + func(*args) + + # 同步NPU + if torch.npu.is_available(): + torch.npu.synchronize() + + # 正式测试 + start_time = time.time() + for _ in range(num_runs): + func(*args) + + # 同步NPU + if torch.npu.is_available(): + torch.npu.synchronize() + + end_time = time.time() + + avg_time = (end_time - start_time) / num_runs + return avg_time * 1000 # 转换为毫秒 + + +def run_performance_tests(): + """运行性能测试""" + print("=" * 60) + print( + f"PERFORMANCE TESTS - Vector size: {seq_len:,} elements ({seq_len * 4 / 1e6:.2f} MB)" + ) + print("=" * 60) + + # 创建测试数据 + v1, v2, v3 = create_test_data() + + # TileLang kernel + func = 
vec_sub(seq_len, seq_len // block) # 使用更合适的块大小 + compiled_tilelang_kernel = tilelang.compile(func, target="commonir") + + def tilelang_benchmark(): + temp_v3 = torch.zeros_like(v3) + compiled_tilelang_kernel(v1, v2, temp_v3) + + def triton_benchmark(): + temp_v3 = torch.zeros_like(v3) + block_size = block # Triton常用块大小 + grid = (triton.cdiv(seq_len, block_size),) + sub_kernel[grid](v1, v2, temp_v3, seq_len, BLOCK_SIZE=block_size) + + def torch_benchmark(): + temp_result = v1 - v2 + + # 运行基准测试 + print("Running benchmarks... (this may take a moment)") + + tilelang_time = benchmark_function(tilelang_benchmark, num_runs=100, warmup_runs=10) + triton_time = benchmark_function(triton_benchmark, num_runs=100, warmup_runs=10) + torch_time = benchmark_function(torch_benchmark, num_runs=100, warmup_runs=10) + + print(f"\nAverage execution time over 100 runs:") + print(f" TileLang: {tilelang_time:.4f} ms") + print(f" Triton: {triton_time:.4f} ms") + print(f" PyTorch: {torch_time:.4f} ms") + + # 计算吞吐量 + total_elements = seq_len + triton_throughput = (total_elements * 4 * 3 / 1e9) / ( + triton_time / 1000 + ) # GB/s (read two + write one) + tilelang_throughput = (total_elements * 4 * 3 / 1e9) / ( + tilelang_time / 1000 + ) # GB/s + torch_throughput = (total_elements * 4 * 3 / 1e9) / (torch_time / 1000) # GB/s + + print(f"\nThroughput Analysis:") + print(f" Triton: {triton_throughput:.2f} GB/s") + print(f" TileLang: {tilelang_throughput:.2f} GB/s") + print(f" PyTorch: {torch_throughput:.2f} GB/s") + + print(f"\nPerformance comparison relative to PyTorch:") + if triton_time <= torch_time: + speedup = torch_time / triton_time + print(f" Triton is {speedup:.2f}x FASTER than PyTorch") + else: + slowdown = triton_time / torch_time + print(f" Triton is {slowdown:.2f}x SLOWER than PyTorch") + + if tilelang_time <= torch_time: + speedup = torch_time / tilelang_time + print(f" TileLang is {speedup:.2f}x FASTER than PyTorch") + else: + slowdown = tilelang_time / torch_time + print(f" 
TileLang is {slowdown:.2f}x SLOWER than PyTorch") + + if triton_time <= tilelang_time: + speedup = tilelang_time / triton_time + print(f" Triton is {speedup:.2f}x FASTER than TileLang") + else: + slowdown = triton_time / tilelang_time + print(f" Triton is {slowdown:.2f}x SLOWER than TileLang") + + +def main(): + """主函数""" + print("Vector Sub Comparison: TileLang vs Triton") + print("=" * 60) + + # 运行功能测试 + print("FUNCTIONALITY TESTS") + print("-" * 20) + tilelang_data = test_tilelang_sub() + triton_data = test_triton_sub() + + # 运行性能测试 + run_performance_tests() + + +if __name__ == "__main__": + main() diff --git a/tools/dicp_triton_opt/dicp_triton_opt.cpp b/tools/dicp_triton_opt/dicp_triton_opt.cpp index dccd884e..2a38c56b 100644 --- a/tools/dicp_triton_opt/dicp_triton_opt.cpp +++ b/tools/dicp_triton_opt/dicp_triton_opt.cpp @@ -105,6 +105,7 @@ inline void registerDICPDialects(mlir::DialectRegistry ®istry) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerVectorizeParallelLoopPass(); registry.insert( + dicp::LinalgExt::createVectorizeParallelLoopPass()); + }); } void init_triton_dicp_triton(py::module &&m) { @@ -107,6 +111,7 @@ void init_triton_dicp_triton(py::module &&m) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerVectorizeParallelLoopPass(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects();