Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ std::unique_ptr<Pass> createPTOLowerFrontendPipeOpsPass();
std::unique_ptr<Pass> createPTOResolveReservedBuffersPass();
std::unique_ptr<Pass> createPTOWrapFunctionsInSectionsPass();
std::unique_ptr<Pass> createPTOVerifyTFreePass();
std::unique_ptr<Pass> createPTORemoveIdentityTMovPass();

// Creates a pass for ...
std::unique_ptr<Pass> createPTOInsertSyncPass();
Expand Down
14 changes: 14 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ def PTOInsertSync : Pass<"pto-insert-sync", "func::FuncOp"> {
];
}

// Declares the `pto-remove-identity-tmov` pass, anchored on `func::FuncOp`.
// Pass instances are built by the C++ factory named in `constructor`, and the
// PTO dialect is listed as the only dependent dialect.
def PTORemoveIdentityTMov : Pass<"pto-remove-identity-tmov", "func::FuncOp"> {
let summary = "Remove identity pto.tmov before auto-sync on A5";
let description = [{
Erases `pto.tmov` operations where source and destination are the same SSA
value. The pass is gated by `pto.target_arch = "a5"` and is intended to run
before `pto-insert-sync` to avoid generating synchronization edges for a
no-op move.
}];
let constructor = "mlir::pto::createPTORemoveIdentityTMovPass()";
let dependentDialects = [
"mlir::pto::PTODialect"
];
}

def ConvertToPTOOp : Pass<"convert-to-pto-op"> {
let summary = "Convert Ops from other dialects to PTO Ops";
let constructor = "mlir::pto::createConvertToPTOOpPass()";
Expand Down
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_mlir_dialect_library(PTOTransforms
PTOPlanMemory.cpp
PTORemoveRedundantBarrier.cpp
InferPTOLayout.cpp
PTORemoveIdentityTMovPass.cpp
BufferizableOpInterfaceImpl.cpp
ConvertToPTOOp.cpp
PTOLowerFrontendPipeOpsPass.cpp
Expand Down
75 changes: 75 additions & 0 deletions lib/PTO/Transforms/PTORemoveIdentityTMovPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include "PTO/IR/PTO.h"
#include "PTO/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace pto {
namespace func = ::mlir::func;
#define GEN_PASS_DEF_PTOREMOVEIDENTITYTMOV
#include "PTO/Transforms/Passes.h.inc"
} // namespace pto
} // namespace mlir

using namespace mlir;
using namespace mlir::pto;

namespace {

/// Reports whether `funcOp` lives in a module tagged for the A5 architecture.
///
/// Looks up the nearest enclosing `ModuleOp` and reads its string attribute
/// `pto.target_arch`. Returns true only when that attribute exists and its
/// value is exactly "a5"; a missing module or missing attribute yields false.
static bool isA5Target(func::FuncOp funcOp) {
  if (auto enclosingModule = funcOp->getParentOfType<ModuleOp>()) {
    if (auto archAttr =
            enclosingModule->getAttrOfType<StringAttr>("pto.target_arch"))
      return archAttr.getValue() == "a5";
  }
  return false;
}

/// Decides whether `op` is an identity move that can be dropped.
///
/// The op qualifies only when its source and destination are the same value.
/// If the op additionally has a result that still has users, the result type
/// must equal the destination type so the users can be redirected to the
/// destination without a type change.
static bool canEraseIdentityTMov(TMovOp op) {
  auto destination = op.getDst();
  if (op.getSrc() != destination)
    return false;

  Value produced = op.getResult();
  // A result only constrains erasure when it exists and is still in use.
  const bool hasLiveResult = produced && !produced.use_empty();
  return !hasLiveResult || produced.getType() == destination.getType();
}

// Function-level pass that deletes identity `pto.tmov` ops, i.e. those accepted
// by canEraseIdentityTMov(). It does nothing unless isA5Target() reports that
// the enclosing module is tagged `pto.target_arch = "a5"`.
struct PTORemoveIdentityTMovPass
: public mlir::pto::impl::PTORemoveIdentityTMovBase<
PTORemoveIdentityTMovPass> {
// Collects every erasable identity tmov in the function, then erases them.
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
// Gate: leave the function untouched when the module is not an A5 target.
if (!isA5Target(funcOp))
return;

// Candidates are only recorded during the walk; the IR is mutated afterwards
// in the loop below.
SmallVector<TMovOp> toErase;
funcOp.walk([&](TMovOp op) {
if (canEraseIdentityTMov(op))
toErase.push_back(op);
});

for (TMovOp op : toErase) {
Value result = op.getResult();
// Redirect remaining users of the tmov result to its destination operand
// before the op itself is removed.
if (result && !result.use_empty())
result.replaceAllUsesWith(op.getDst());
op.erase();
}
Comment on lines +56 to +67
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation iterates over the function twice: once to collect the TMovOps to erase, and a second time to perform the erasure. This can be simplified into a single walk. MLIR's walk function is safe to use with in-place erasure of the visited operation, making the intermediate toErase vector unnecessary. This improves both efficiency and readability.

    funcOp.walk([&](TMovOp op) {
      if (canEraseIdentityTMov(op)) {
        Value result = op.getResult();
        if (result && !result.use_empty())
          result.replaceAllUsesWith(op.getDst());
        op.erase();
      }
    });

}
};

} // namespace

/// Factory for the identity-tmov removal pass; returns a freshly allocated
/// `PTORemoveIdentityTMovPass` owned by the caller.
std::unique_ptr<Pass> mlir::pto::createPTORemoveIdentityTMovPass() {
  auto identityTMovPass = std::make_unique<PTORemoveIdentityTMovPass>();
  return identityTMovPass;
}
33 changes: 33 additions & 0 deletions test/basic/identity_tmov_autosync_a5_only.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s | FileCheck %s --check-prefix=A5
// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s --check-prefix=A3

// Test input: loads %src into a vec-address-space buffer, issues a pto.tmov
// whose ins and outs are that same buffer, then stores the buffer to %dst.
// NOTE(review): the removal pass checks a module attribute named
// `pto.target_arch`, while this module only sets `pto.device-spec`.
// Presumably `--pto-arch` makes ptoas attach `pto.target_arch` -- confirm.
module attributes {"pto.device-spec" = "Ascend950"} {
func.func @identity_tmov_autosync_a5_only(
%src: memref<1x64xf16, #pto.address_space<gm>>,
%dst: memref<1x64xf16, #pto.address_space<gm>>) {
%ub = memref.alloc() : memref<1x64xf16, #pto.address_space<vec>>

pto.tload ins(%src : memref<1x64xf16, #pto.address_space<gm>>)
outs(%ub : memref<1x64xf16, #pto.address_space<vec>>)

// Identity move: should be removed on A5 before sync insertion.
pto.tmov ins(%ub : memref<1x64xf16, #pto.address_space<vec>>)
outs(%ub : memref<1x64xf16, #pto.address_space<vec>>)

pto.tstore ins(%ub : memref<1x64xf16, #pto.address_space<vec>>)
outs(%dst : memref<1x64xf16, #pto.address_space<gm>>)
return
}
}

// A5-LABEL: __global__ AICORE void identity_tmov_autosync_a5_only(
// A5: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
// A5-NOT: set_flag(PIPE_MTE2, PIPE_V
// A5-NOT: wait_flag(PIPE_MTE2, PIPE_V
// A5-NOT: set_flag(PIPE_V, PIPE_MTE3
// A5-NOT: wait_flag(PIPE_V, PIPE_MTE3
// A5-NOT: TMOV(
// A5: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);

// A3-LABEL: __global__ AICORE void identity_tmov_autosync_a5_only(
// A3: TMOV(
5 changes: 4 additions & 1 deletion tools/ptoas/ptoas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1111,8 +1111,11 @@ int main(int argc, char **argv) {
pm.addPass(pto::createPTOResolveReservedBuffersPass());

// Conditionally add Sync pass based on flag.
if (enableInsertSync)
if (enableInsertSync) {
pm.addNestedPass<mlir::func::FuncOp>(
pto::createPTORemoveIdentityTMovPass());
pm.addNestedPass<mlir::func::FuncOp>(pto::createPTOInsertSyncPass());
}

pm.addPass(createCSEPass());
if (arch == "a3") {
Expand Down
Loading