Skip to content

Commit cd56fa2

Browse files
committed
添加简单循环交换实现及相关测试
1 parent 676f9b5 commit cd56fa2

5 files changed

Lines changed: 135 additions & 1 deletion

File tree

mlir/optimization/scheduler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_executable(
1818
lib/FusionFeasibility.cpp
1919
lib/LivenessAdapter.cpp
2020
lib/LocalListScheduling.cpp
21+
lib/SimpleLoopInterchange.cpp
2122
)
2223

2324
# add_dependencies(lab-scheduler ToyCh6ShapeInferenceInterfaceIncGen

mlir/optimization/scheduler/include/lab/LabPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ std::unique_ptr<Pass> createLabLivenessPass();
1616
std::unique_ptr<Pass> createLabMemrefLifetimePass();
1717
std::unique_ptr<Pass> createLabFusionFeasibilityPass();
1818
std::unique_ptr<Pass> createAsyncLocalSchedulePass();
19+
std::unique_ptr<Pass> createSimpleLoopInterchangePass();
1920

2021
} // namespace mlir

mlir/optimization/scheduler/lab-opt.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ int main(int argc, char **argv) {
5252
[](mlir::OpPassManager &pm) {
5353
pm.addPass(mlir::createAsyncLocalSchedulePass());
5454
});
55-
55+
mlir::PassPipelineRegistration<>(
56+
"simple-loop-interchange", "Simple Loop Interchange Pass",
57+
[](mlir::OpPassManager &pm) {
58+
pm.addPass(mlir::createSimpleLoopInterchangePass());
59+
});
5660
return mlir::asMainReturnCode(
5761
mlir::MlirOptMain(argc, argv, "Lab optimizer\n", registry));
5862
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2+
#include "mlir/Dialect/Affine/LoopUtils.h"
3+
#include "mlir/Dialect/Func/IR/FuncOps.h"
4+
#include "mlir/IR/Visitors.h"
5+
#include "mlir/Pass/Pass.h"
6+
7+
using namespace mlir;
8+
using namespace mlir::affine;
9+
10+
static bool isPerfectTwoLevelNest(AffineForOp outer, AffineForOp &inner) {
11+
Block &body = outer.getRegion().front();
12+
13+
Operation *firstNonTerminator = nullptr;
14+
for (Operation &op : body.without_terminator()) {
15+
if (firstNonTerminator)
16+
return false; // 外层 body 里不止一个非 terminator op
17+
firstNonTerminator = &op;
18+
}
19+
20+
if (!firstNonTerminator)
21+
return false;
22+
23+
inner = dyn_cast<AffineForOp>(firstNonTerminator);
24+
return inner != nullptr;
25+
}
26+
27+
static bool shouldInterchangeByLastIndexHeuristic(AffineForOp outer,
28+
AffineForOp inner) {
29+
Value outerIV = outer.getInductionVar();
30+
Value innerIV = inner.getInductionVar();
31+
32+
bool outerUsedAsLastIndex = false;
33+
bool innerUsedAsLastIndex = false;
34+
35+
inner.walk([&](Operation *op) {
36+
if (auto load = dyn_cast<AffineLoadOp>(op)) {
37+
auto indices = load.getIndices();
38+
if (!indices.empty()) {
39+
if (indices.back() == outerIV)
40+
outerUsedAsLastIndex = true;
41+
if (indices.back() == innerIV)
42+
innerUsedAsLastIndex = true;
43+
}
44+
}
45+
if (auto store = dyn_cast<AffineStoreOp>(op)) {
46+
auto indices = store.getIndices();
47+
if (!indices.empty()) {
48+
if (indices.back() == outerIV)
49+
outerUsedAsLastIndex = true;
50+
if (indices.back() == innerIV)
51+
innerUsedAsLastIndex = true;
52+
}
53+
}
54+
});
55+
56+
// 如果外层 iv 作为最右索引更常见,而内层不是,则值得尝试交换
57+
return outerUsedAsLastIndex && !innerUsedAsLastIndex;
58+
}
59+
60+
namespace {
61+
struct SimpleLoopInterchangePass
62+
: public PassWrapper<SimpleLoopInterchangePass,
63+
OperationPass<func::FuncOp>> {
64+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleLoopInterchangePass)
65+
66+
StringRef getArgument() const final { return "lab-simple-loop-interchange"; }
67+
StringRef getDescription() const final {
68+
return "A simple affine loop interchange pass for perfect 2-level nests";
69+
}
70+
71+
void runOnOperation() override;
72+
};
73+
74+
void SimpleLoopInterchangePass::runOnOperation() {
75+
func::FuncOp func = getOperation();
76+
77+
SmallVector<AffineForOp> candidates;
78+
func.walk([&](AffineForOp forOp) { candidates.push_back(forOp); });
79+
80+
for (AffineForOp outer : candidates) {
81+
AffineForOp inner;
82+
if (!isPerfectTwoLevelNest(outer, inner))
83+
continue;
84+
85+
if (!shouldInterchangeByLastIndexHeuristic(outer, inner))
86+
continue;
87+
88+
SmallVector<AffineForOp> loops = {outer, inner};
89+
SmallVector<unsigned> perm = {1, 0}; // 交换两层
90+
91+
if (!isValidLoopInterchangePermutation(loops, perm))
92+
continue;
93+
94+
interchangeLoops(outer, inner);
95+
}
96+
}
97+
} // namespace
98+
99+
namespace mlir {
100+
std::unique_ptr<Pass> createSimpleLoopInterchangePass() {
101+
return std::make_unique<SimpleLoopInterchangePass>();
102+
}
103+
} // namespace mlir
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
func.func @interchange_me(%A: memref<64x64xf32>, %B: memref<64x64xf32>) {
2+
affine.for %i = 0 to 64 {
3+
affine.for %j = 0 to 64 {
4+
%v = affine.load %A[%j, %i] : memref<64x64xf32>
5+
%c = arith.constant 1.0 : f32
6+
%r = arith.addf %v, %c : f32
7+
affine.store %r, %B[%j, %i] : memref<64x64xf32>
8+
}
9+
}
10+
return
11+
}
12+
13+
// we assume the tensor is stored in row-major order, so the original loop order is i-j.
14+
// expected to be transformed to:
15+
// func.func @interchange_me(%A: memref<64x64xf32>, %B: memref<64x64xf32>) {
16+
// affine.for %j = 0 to 64 {
17+
// affine.for %i = 0 to 64 {
18+
// %v = affine.load %A[%j, %i] : memref<64x64xf32>
19+
// %c = arith.constant 1.0 : f32
20+
// %r = arith.addf %v, %c : f32
21+
// affine.store %r, %B[%j, %i] : memref<64x64xf32>
22+
// }
23+
// }
24+
// return
25+
// }

0 commit comments

Comments
 (0)