Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,19 @@ jobs:
export PTOAS_OUT_DIR="${PAYLOAD_DIR}/test/samples"
bash test/samples/runop.sh --enablebc all

- name: Run issue828 branch-fixed basic pto regressions
shell: bash
env:
PTOAS_BIN: ${{ github.workspace }}/build/tools/ptoas/ptoas
run: |
set -euo pipefail
for case in \
test/basic/issue828_softmax_rescale_incore_1_a5_if.pto \
test/basic/issue828_softmax_rescale_incore_1_a5_else.pto; do
"${PTOAS_BIN}" --pto-level=level3 --pto-arch=a5 --enable-insert-sync "${case}" >/dev/null
"${PTOAS_BIN}" --pto-level=level3 --pto-arch=a5 --enable-insert-sync --disable-identity-tmov-cleanup "${case}" >/dev/null
done

- name: Build payload artifact
if: >-
${{
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ ptoas tests/input.pto
# 运行 AutoSyncInsert Pass
ptoas tests/input.pto --enable-insert-sync -o outputfile.cpp

# 调试开关:关闭 A5 identity tmov cleanup(用于 A/B 验证)
ptoas tests/input.pto --enable-insert-sync --disable-identity-tmov-cleanup -o outputfile.cpp

# 指定目标硬件架构(A3 / A5)
ptoas tests/input.pto --pto-arch=a3 -o outputfile.cpp

Expand Down
1 change: 1 addition & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ std::unique_ptr<Pass> createPTOLowerFrontendPipeOpsPass();
std::unique_ptr<Pass> createPTOResolveReservedBuffersPass();
std::unique_ptr<Pass> createPTOWrapFunctionsInSectionsPass();
std::unique_ptr<Pass> createPTOVerifyTFreePass();
std::unique_ptr<Pass> createPTORemoveIdentityTMovPass();

// Creates the `pto-insert-sync` pass.
std::unique_ptr<Pass> createPTOInsertSyncPass();
Expand Down
17 changes: 17 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ def PTOInsertSync : Pass<"pto-insert-sync", "func::FuncOp"> {
];
}

def PTORemoveIdentityTMov : Pass<"pto-remove-identity-tmov", "func::FuncOp"> {
let summary = "Remove identity pto.tmov before auto-sync on A5";
let description = [{
Erases provably-identity `pto.tmov` operations on A5 before
`pto-insert-sync`. Besides `src == dst`, this also handles distinct SSA
values when alias analysis proves source and destination have the exact same
address range and type/layout. `memref.reinterpret_cast` views are treated
conservatively (not elided), and when roots differ the pass requires
concrete-equal root addresses. Dynamic/unknown address cases are not
removed.
}];
let constructor = "mlir::pto::createPTORemoveIdentityTMovPass()";
let dependentDialects = [
"mlir::pto::PTODialect"
];
}

def ConvertToPTOOp : Pass<"convert-to-pto-op"> {
let summary = "Convert Ops from other dialects to PTO Ops";
let constructor = "mlir::pto::createConvertToPTOOpPass()";
Expand Down
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_mlir_dialect_library(PTOTransforms
PTOPlanMemory.cpp
PTORemoveRedundantBarrier.cpp
InferPTOLayout.cpp
PTORemoveIdentityTMovPass.cpp
BufferizableOpInterfaceImpl.cpp
ConvertToPTOOp.cpp
PTOLowerFrontendPipeOpsPass.cpp
Expand Down
245 changes: 245 additions & 0 deletions lib/PTO/Transforms/PTORemoveIdentityTMovPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include "PTO/IR/PTO.h"
#include "PTO/Transforms/Passes.h"
#include "PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h"
#include "PTO/Transforms/InsertSync/PTOIRTranslator.h"
#include "PTO/Transforms/InsertSync/SyncCommon.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include <optional>

namespace mlir {
namespace pto {
namespace func = ::mlir::func;
#define GEN_PASS_DEF_PTOREMOVEIDENTITYTMOV
#include "PTO/Transforms/Passes.h.inc"
} // namespace pto
} // namespace mlir

using namespace mlir;
using namespace mlir::pto;

namespace {

/// Reports whether `funcOp` lives in a module whose `pto.target_arch` string
/// attribute equals "a5" (compared case-insensitively). A function without an
/// enclosing module, or a module without that attribute, is not an A5 target.
static bool isA5Target(func::FuncOp funcOp) {
  if (ModuleOp parentModule = funcOp->getParentOfType<ModuleOp>()) {
    if (auto archAttr =
            parentModule->getAttrOfType<StringAttr>("pto.target_arch"))
      return archAttr.getValue().equals_insensitive("a5");
  }
  return false;
}

/// Looks up `value` in `buffer2MemInfoMap` and returns its memory info only
/// when the map holds exactly one entry for it. Returns nullptr when the value
/// is unknown to the map or is associated with zero or several entries.
static const BaseMemInfo *
getSingleMemInfo(const Buffer2MemInfoMap &buffer2MemInfoMap, Value value) {
  auto entry = buffer2MemInfoMap.find(value);
  const bool hasExactlyOne =
      entry != buffer2MemInfoMap.end() && entry->second.size() == 1;
  return hasExactlyOne ? entry->second.front().get() : nullptr;
}

/// Tries to fold `value` to a compile-time signed 64-bit integer.
///
/// Matches integer constants directly and looks through a small set of arith
/// integer casts. A cast is only looked through when it provably leaves the
/// numeric value unchanged; otherwise the result is "unknown" so that callers
/// comparing addresses never see two different values as equal.
///
/// \param value SSA value to evaluate; a null value yields std::nullopt.
/// \returns the constant, or std::nullopt when it cannot be determined.
static std::optional<int64_t> tryEvalI64Constant(Value value) {
  if (!value)
    return std::nullopt;

  APInt constant;
  if (matchPattern(value, m_ConstantInt(&constant))) {
    // getSExtValue() asserts when the value needs more than 64 bits.
    if (!constant.isSignedIntN(64))
      return std::nullopt;
    return constant.getSExtValue();
  }

  Operation *defOp = value.getDefiningOp();
  if (!defOp)
    return std::nullopt;

  // Sign extension preserves the signed value. index_cast is kept as a
  // pass-through as well.
  // NOTE(review): index_cast to a narrower integer type can truncate; that
  // case is not distinguished here — confirm it cannot occur for addresses.
  if (auto castOp = dyn_cast<arith::IndexCastOp>(defOp))
    return tryEvalI64Constant(castOp.getIn());
  if (auto castOp = dyn_cast<arith::ExtSIOp>(defOp))
    return tryEvalI64Constant(castOp.getIn());

  // Zero extension equals the sign-extended operand only when the operand is
  // non-negative; a negative operand would become a large positive value.
  if (auto castOp = dyn_cast<arith::ExtUIOp>(defOp)) {
    std::optional<int64_t> inner = tryEvalI64Constant(castOp.getIn());
    if (inner && *inner >= 0)
      return inner;
    return std::nullopt;
  }

  // Truncation keeps the value only when it already fits in the (signed)
  // result width; otherwise high bits are dropped and the value changes.
  if (auto castOp = dyn_cast<arith::TruncIOp>(defOp)) {
    std::optional<int64_t> inner = tryEvalI64Constant(castOp.getIn());
    Type resultType = castOp.getType();
    if (!inner || !resultType.isSignlessInteger())
      return std::nullopt;
    const unsigned width = resultType.getIntOrFloatBitWidth();
    if (width == 0)
      return std::nullopt;
    if (width >= 64)
      return inner;
    const int64_t bound = int64_t{1} << (width - 1);
    if (*inner < -bound || *inner >= bound)
      return std::nullopt;
    return inner;
  }

  return std::nullopt;
}

static std::optional<int64_t>
tryGetConcreteRootAddress(const BaseMemInfo *info) {
if (!info)
return std::nullopt;

if (auto direct = tryEvalI64Constant(info->rootBuffer))
return direct;

Operation *defOp = info->rootBuffer.getDefiningOp();
if (!defOp)
return std::nullopt;

if (auto alloc = dyn_cast<pto::AllocTileOp>(defOp))
return tryEvalI64Constant(alloc.getAddr());

if (auto cast = dyn_cast<pto::PointerCastOp>(defOp)) {
if (!cast.getAddrs().empty())
return tryEvalI64Constant(cast.getAddrs().front());
}

return std::nullopt;
}

/// Returns true when any entry of `values` is the `ShapedType::kDynamic`
/// sentinel, i.e. the corresponding offset/size/stride is not a static
/// constant.
static bool hasDynamicStaticList(ArrayRef<int64_t> values) {
  for (int64_t entry : values) {
    if (entry == ShapedType::kDynamic)
      return true;
  }
  return false;
}

/// Decides whether the view chain that produces `value` is fully static, so
/// that its address range can be reasoned about at compile time.
///
/// Walks up through memref view-like ops (at most 32 steps):
///  - `memref.subview` is followed only when all offsets, sizes and strides
///    are static;
///  - `memref.reinterpret_cast` is always rejected (treated conservatively);
///  - `memref.cast`, `memref.collapse_shape`, `memref.expand_shape` are
///    followed unconditionally;
///  - `memref.view` is followed only when its byte shift is a constant zero
///    and it has no dynamic size operands.
/// The walk succeeds on reaching any other defining op, and fails on a value
/// without a defining op (e.g. a block argument) or when the step limit is
/// exhausted.
static bool isStaticallyAddressableValue(Value value) {
  constexpr int kMaxDepth = 32;
  for (int depth = 0; value && depth < kMaxDepth; ++depth) {
    Operation *defOp = value.getDefiningOp();
    if (!defOp)
      return false;

    if (auto subView = dyn_cast<memref::SubViewOp>(defOp)) {
      if (hasDynamicStaticList(subView.getStaticOffsets()) ||
          hasDynamicStaticList(subView.getStaticSizes()) ||
          hasDynamicStaticList(subView.getStaticStrides()))
        return false;
      value = subView.getSource();
      continue;
    }

    if (isa<memref::ReinterpretCastOp>(defOp))
      return false;

    if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
      value = cast.getSource();
      continue;
    }
    if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
      value = collapse.getSrc();
      continue;
    }
    if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
      value = expand.getSrc();
      continue;
    }
    if (auto view = dyn_cast<memref::ViewOp>(defOp)) {
      // The byte shift is a mandatory operand of memref.view, so testing the
      // Value itself for null rejected every view and left the pass-through
      // below unreachable. Only a shift that is not a constant zero moves the
      // base address.
      if (!matchPattern(view.getByteShift(), m_Zero()))
        return false;
      // Dynamic size operands make the extent unknown; reject them, in line
      // with the subview handling above.
      if (!view.getSizes().empty())
        return false;
      value = view.getSource();
      continue;
    }

    // Any other producer is treated as the root of the view chain.
    return true;
  }

  return false;
}

static bool hasExactSameAddressRange(const BaseMemInfo *srcInfo,
const BaseMemInfo *dstInfo) {
if (!srcInfo || !dstInfo)
return false;

if (srcInfo->scope != dstInfo->scope)
return false;
if (srcInfo->allocateSize == 0 || dstInfo->allocateSize == 0)
return false;
if (srcInfo->allocateSize != dstInfo->allocateSize)
return false;
if (srcInfo->baseAddresses.empty() || dstInfo->baseAddresses.empty())
return false;
if (srcInfo->baseAddresses != dstInfo->baseAddresses)
return false;

return true;
}

/// Decides whether `op` is a provably-identity move that may be erased.
///
/// A move whose source and destination are the same SSA value qualifies
/// immediately. Otherwise both operands must have the same type, be
/// statically addressable, each map to exactly one memory info, and those
/// infos must describe the same address range. Finally the root buffers must
/// either be the same value or resolve to equal constant addresses.
static bool canEraseIdentityTMov(
    TMovOp op, const Buffer2MemInfoMap &buffer2MemInfoMap) {
  Value source = op.getSrc();
  Value dest = op.getDst();

  if (source == dest)
    return true;

  const bool comparable = source.getType() == dest.getType() &&
                          isStaticallyAddressableValue(source) &&
                          isStaticallyAddressableValue(dest);
  if (!comparable)
    return false;

  const BaseMemInfo *sourceInfo = getSingleMemInfo(buffer2MemInfoMap, source);
  const BaseMemInfo *destInfo = getSingleMemInfo(buffer2MemInfoMap, dest);
  if (!sourceInfo || !destInfo)
    return false;
  if (!hasExactSameAddressRange(sourceInfo, destInfo))
    return false;

  // Same root buffer: the identical ranges refer to the same storage.
  if (sourceInfo->rootBuffer == destInfo->rootBuffer)
    return true;

  // Different roots: only accept when both resolve to the same constant
  // address.
  auto sourceRootAddr = tryGetConcreteRootAddress(sourceInfo);
  auto destRootAddr = tryGetConcreteRootAddress(destInfo);
  return sourceRootAddr && destRootAddr && *sourceRootAddr == *destRootAddr;
}

// Function pass that erases `pto.tmov` ops accepted by canEraseIdentityTMov.
// It is a no-op unless the enclosing module targets A5 (see isA5Target).
struct PTORemoveIdentityTMovPass
: public mlir::pto::impl::PTORemoveIdentityTMovBase<
PTORemoveIdentityTMovPass> {
// Pass entry point: analyses the function and erases qualifying TMovOps.
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
// Nothing to do for functions that are not compiled for A5.
if (!isA5Target(funcOp))
return;

// Cheap pre-scan: stop at the first TMovOp so the analysis below is only
// built when the function contains at least one candidate.
bool hasTMov = false;
funcOp.walk([&](TMovOp) {
hasTMov = true;
return WalkResult::interrupt();
});
if (!hasTMov)
return;

// Build the translator over the function in NORMALSYNC mode.
// NOTE(review): presumably Build() populates buffer2MemInfoMap with the
// per-buffer memory info consulted below — confirm against PTOIRTranslator.
MemoryDependentAnalyzer memAnalyzer;
SyncIRs syncIR;
Buffer2MemInfoMap buffer2MemInfoMap;
PTOIRTranslator translator(syncIR, memAnalyzer, buffer2MemInfoMap, funcOp,
SyncAnalysisMode::NORMALSYNC);
translator.Build();

// Collect first and erase afterwards, so no op is removed while the walk is
// still in progress.
SmallVector<TMovOp> toErase;
funcOp.walk([&](TMovOp op) {
if (canEraseIdentityTMov(op, buffer2MemInfoMap))
toErase.push_back(op);
});

for (TMovOp op : toErase) {
// If the op has a result that is still used, forward those uses to the
// destination operand before removing the op.
Value result = op.getResult();
if (result && !result.use_empty())
result.replaceAllUsesWith(op.getDst());
op.erase();
}
Comment on lines +226 to +237
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation first collects all TMovOps to be erased into a SmallVector and then iterates over this vector to perform the erasure. This two-step process can be simplified and made more efficient. You can perform the erasure directly within the walk callback. Since TMovOp has no regions, it's safe to erase it during the walk, which avoids the need for intermediate storage and a second loop.

    funcOp.walk([&](TMovOp op) {
      if (canEraseIdentityTMov(op)) {
        Value result = op.getResult();
        if (result && !result.use_empty()) {
          result.replaceAllUsesWith(op.getDst());
        }
        op.erase();
      }
    });

}
};

} // namespace

/// Factory for the identity-`pto.tmov` removal pass; returns a newly
/// allocated pass instance owned by the caller.
std::unique_ptr<Pass> mlir::pto::createPTORemoveIdentityTMovPass() {
  auto pass = std::make_unique<PTORemoveIdentityTMovPass>();
  return pass;
}
40 changes: 40 additions & 0 deletions test/basic/identity_tmov_alias_deadlock_signature_a5.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s | FileCheck %s --check-prefix=SAFE
// RUN: ptoas --pto-arch=a5 --enable-insert-sync --disable-identity-tmov-cleanup %s | FileCheck %s --check-prefix=UNSAFE

module attributes {"pto.device-spec" = "Ascend950"} {
func.func @identity_tmov_alias_deadlock_signature_a5(
%src: memref<1x64xf16, #pto.address_space<gm>>,
%dst: memref<1x64xf16, #pto.address_space<gm>>) {
%ub = memref.alloc() : memref<1x64xf16, #pto.address_space<vec>>
%src_alias = memref.subview %ub[0, 0] [1, 64] [1, 1]
: memref<1x64xf16, #pto.address_space<vec>>
to memref<1x64xf16, strided<[64, 1]>, #pto.address_space<vec>>
%dst_alias = memref.subview %ub[0, 0] [1, 64] [1, 1]
: memref<1x64xf16, #pto.address_space<vec>>
to memref<1x64xf16, strided<[64, 1]>, #pto.address_space<vec>>

pto.tload ins(%src : memref<1x64xf16, #pto.address_space<gm>>)
outs(%src_alias : memref<1x64xf16, strided<[64, 1]>, #pto.address_space<vec>>)
pto.tmov ins(%src_alias : memref<1x64xf16, strided<[64, 1]>, #pto.address_space<vec>>)
outs(%dst_alias : memref<1x64xf16, strided<[64, 1]>, #pto.address_space<vec>>)
pto.tstore ins(%dst_alias : memref<1x64xf16, strided<[64, 1]>, #pto.address_space<vec>>)
outs(%dst : memref<1x64xf16, #pto.address_space<gm>>)
return
}
}

// SAFE-LABEL: __global__ AICORE void identity_tmov_alias_deadlock_signature_a5(
// SAFE-NOT: TMOV(
// SAFE: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
// SAFE: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
// SAFE-NOT: set_flag(PIPE_MTE2, PIPE_V
// SAFE-NOT: wait_flag(PIPE_MTE2, PIPE_V
// SAFE-NOT: set_flag(PIPE_V, PIPE_MTE3
// SAFE-NOT: wait_flag(PIPE_V, PIPE_MTE3

// UNSAFE-LABEL: __global__ AICORE void identity_tmov_alias_deadlock_signature_a5(
// UNSAFE: TMOV(
// UNSAFE: set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
// UNSAFE: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
// UNSAFE: set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
// UNSAFE: wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
33 changes: 33 additions & 0 deletions test/basic/identity_tmov_autosync_a5_only.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s | FileCheck %s --check-prefix=A5
// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s --check-prefix=A3

module attributes {"pto.device-spec" = "Ascend950"} {
func.func @identity_tmov_autosync_a5_only(
%src: memref<1x64xf16, #pto.address_space<gm>>,
%dst: memref<1x64xf16, #pto.address_space<gm>>) {
%ub = memref.alloc() : memref<1x64xf16, #pto.address_space<vec>>

pto.tload ins(%src : memref<1x64xf16, #pto.address_space<gm>>)
outs(%ub : memref<1x64xf16, #pto.address_space<vec>>)

// Identity move: should be removed on A5 before sync insertion.
pto.tmov ins(%ub : memref<1x64xf16, #pto.address_space<vec>>)
outs(%ub : memref<1x64xf16, #pto.address_space<vec>>)

pto.tstore ins(%ub : memref<1x64xf16, #pto.address_space<vec>>)
outs(%dst : memref<1x64xf16, #pto.address_space<gm>>)
return
}
}

// A5-LABEL: __global__ AICORE void identity_tmov_autosync_a5_only(
// A5: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
// A5-NOT: set_flag(PIPE_MTE2, PIPE_V
// A5-NOT: wait_flag(PIPE_MTE2, PIPE_V
// A5-NOT: set_flag(PIPE_V, PIPE_MTE3
// A5-NOT: wait_flag(PIPE_V, PIPE_MTE3
// A5-NOT: TMOV(
// A5: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);

// A3-LABEL: __global__ AICORE void identity_tmov_autosync_a5_only(
// A3: TMOV(
Loading
Loading