diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h
index cafdb784c..4861b155d 100644
--- a/include/PTO/Transforms/Passes.h
+++ b/include/PTO/Transforms/Passes.h
@@ -38,6 +38,7 @@ std::unique_ptr<Pass> createPTOLowerFrontendPipeOpsPass();
 std::unique_ptr<Pass> createPTOResolveReservedBuffersPass();
 std::unique_ptr<Pass> createPTOWrapFunctionsInSectionsPass();
 std::unique_ptr<Pass> createPTOVerifyTFreePass();
+std::unique_ptr<Pass> createPTORemoveIdentityTMovPass();
 
 // Creates a pass for ...
 std::unique_ptr<Pass> createPTOInsertSyncPass();
diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td
index 37979bf21..96521685b 100644
--- a/include/PTO/Transforms/Passes.td
+++ b/include/PTO/Transforms/Passes.td
@@ -38,6 +38,20 @@ def PTOInsertSync : Pass<"pto-insert-sync", "func::FuncOp"> {
   ];
 }
 
+def PTORemoveIdentityTMov : Pass<"pto-remove-identity-tmov", "func::FuncOp"> {
+  let summary = "Remove identity pto.tmov before auto-sync on A5";
+  let description = [{
+    Erases `pto.tmov` operations where source and destination are the same SSA
+    value. The pass is gated by `pto.target_arch = "a5"` and is intended to run
+    before `pto-insert-sync` to avoid generating synchronization edges for a
+    no-op move.
+  }];
+  let constructor = "mlir::pto::createPTORemoveIdentityTMovPass()";
+  let dependentDialects = [
+    "mlir::pto::PTODialect"
+  ];
+}
+
 def ConvertToPTOOp : Pass<"convert-to-pto-op"> {
   let summary = "Convert Ops from other dialects to PTO Ops";
   let constructor = "mlir::pto::createConvertToPTOOpPass()";
diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt
index b82d227fe..755c2a1bb 100644
--- a/lib/PTO/Transforms/CMakeLists.txt
+++ b/lib/PTO/Transforms/CMakeLists.txt
@@ -23,6 +23,7 @@ add_mlir_dialect_library(PTOTransforms
   PTOPlanMemory.cpp
   PTORemoveRedundantBarrier.cpp
   InferPTOLayout.cpp
+  PTORemoveIdentityTMovPass.cpp
   BufferizableOpInterfaceImpl.cpp
   ConvertToPTOOp.cpp
   PTOLowerFrontendPipeOpsPass.cpp
diff --git a/lib/PTO/Transforms/PTORemoveIdentityTMovPass.cpp b/lib/PTO/Transforms/PTORemoveIdentityTMovPass.cpp
new file mode 100644
index 000000000..88f359ba7
--- /dev/null
+++ b/lib/PTO/Transforms/PTORemoveIdentityTMovPass.cpp
@@ -0,0 +1,96 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License.
+
+#include "PTO/IR/PTO.h"
+#include "PTO/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace pto {
+namespace func = ::mlir::func;
+#define GEN_PASS_DEF_PTOREMOVEIDENTITYTMOV
+#include "PTO/Transforms/Passes.h.inc"
+} // namespace pto
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::pto;
+
+namespace {
+
+/// Returns true when the module enclosing `funcOp` carries the string
+/// attribute `pto.target_arch` with the value "a5". A function without an
+/// enclosing module, or a module without that attribute, is not A5.
+static bool isA5Target(func::FuncOp funcOp) {
+  ModuleOp module = funcOp->getParentOfType<ModuleOp>();
+  if (!module)
+    return false;
+  auto arch = module->getAttrOfType<StringAttr>("pto.target_arch");
+  return arch && arch.getValue() == "a5";
+}
+
+/// Returns true when `op` moves a value onto itself and can be dropped.
+///
+/// If the op produces a result that still has uses, those uses are later
+/// redirected to the destination operand, so the two types must be identical.
+static bool canEraseIdentityTMov(TMovOp op) {
+  if (op.getSrc() != op.getDst())
+    return false;
+
+  Value result = op.getResult();
+  if (!result || result.use_empty())
+    return true;
+
+  return result.getType() == op.getDst().getType();
+}
+
+/// Function pass that erases identity `pto.tmov` ops on A5 targets.
+struct PTORemoveIdentityTMovPass
+    : public mlir::pto::impl::PTORemoveIdentityTMovBase<
+          PTORemoveIdentityTMovPass> {
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    if (!isA5Target(funcOp)) {
+      // Nothing is rewritten off A5, so cached analyses remain valid.
+      markAllAnalysesPreserved();
+      return;
+    }
+
+    // Collect the candidates first; erasure happens only after the walk has
+    // completed.
+    SmallVector<TMovOp> toErase;
+    funcOp.walk([&](TMovOp op) {
+      if (canEraseIdentityTMov(op))
+        toErase.push_back(op);
+    });
+
+    if (toErase.empty()) {
+      markAllAnalysesPreserved();
+      return;
+    }
+
+    for (TMovOp op : toErase) {
+      // Forward any remaining users of the result to the destination value
+      // before the op is erased.
+      Value result = op.getResult();
+      if (result && !result.use_empty())
+        result.replaceAllUsesWith(op.getDst());
+      op.erase();
+    }
+  }
+};
+
+} // namespace
+
+/// Creates the `pto-remove-identity-tmov` pass.
+std::unique_ptr<Pass> mlir::pto::createPTORemoveIdentityTMovPass() {
+  return std::make_unique<PTORemoveIdentityTMovPass>();
+}
diff --git a/test/basic/identity_tmov_autosync_a5_only.pto b/test/basic/identity_tmov_autosync_a5_only.pto
new file mode 100644
index 000000000..e02a33014
--- /dev/null
+++ b/test/basic/identity_tmov_autosync_a5_only.pto
@@ -0,0 +1,33 @@
+// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s | FileCheck %s --check-prefix=A5
+// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s --check-prefix=A3
+
+module attributes {"pto.device-spec" = "Ascend950"} {
+  func.func @identity_tmov_autosync_a5_only(
+      %src: memref<1x64xf16, #pto.address_space<gm>>,
+      %dst: memref<1x64xf16, #pto.address_space<gm>>) {
+    %ub = memref.alloc() : memref<1x64xf16, #pto.address_space<vec>>
+
+    pto.tload ins(%src : memref<1x64xf16, #pto.address_space<gm>>)
+              outs(%ub : memref<1x64xf16, #pto.address_space<vec>>)
+
+    // Identity move: should be removed on A5 before sync insertion.
+    pto.tmov ins(%ub : memref<1x64xf16, #pto.address_space<vec>>)
+             outs(%ub : memref<1x64xf16, #pto.address_space<vec>>)
+
+    pto.tstore ins(%ub : memref<1x64xf16, #pto.address_space<vec>>)
+               outs(%dst : memref<1x64xf16, #pto.address_space<gm>>)
+    return
+  }
+}
+
+// A5-LABEL: __global__ AICORE void identity_tmov_autosync_a5_only(
+// A5: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
+// A5-NOT: set_flag(PIPE_MTE2, PIPE_V
+// A5-NOT: wait_flag(PIPE_MTE2, PIPE_V
+// A5-NOT: set_flag(PIPE_V, PIPE_MTE3
+// A5-NOT: wait_flag(PIPE_V, PIPE_MTE3
+// A5-NOT: TMOV(
+// A5: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
+
+// A3-LABEL: __global__ AICORE void identity_tmov_autosync_a5_only(
+// A3: TMOV(
diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp
index e7034ecab..1d71acd87 100644
--- a/tools/ptoas/ptoas.cpp
+++ b/tools/ptoas/ptoas.cpp
@@ -1111,8 +1111,12 @@ int main(int argc, char **argv) {
   pm.addPass(pto::createPTOResolveReservedBuffersPass());
 
   // Conditionally add Sync pass based on flag.
-  if (enableInsertSync)
+  if (enableInsertSync) {
+    // Drop identity tmovs first so no sync edges are generated for them.
+    pm.addNestedPass<mlir::func::FuncOp>(
+        pto::createPTORemoveIdentityTMovPass());
     pm.addNestedPass<mlir::func::FuncOp>(pto::createPTOInsertSyncPass());
+  }
 
   pm.addPass(createCSEPass());
   if (arch == "a3") {