Skip to content

Commit bb633c5

Browse files
committed
Add implementation and tests related to fusion feasibility analysis.
1 parent 3e6d6e2 commit bb633c5

File tree

9 files changed

+711
-10
lines changed

9 files changed

+711
-10
lines changed

mlir/optimization/scheduler/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ add_executable(
1515
lib/BufferAnalysisPass.cpp
1616
lib/LivenessAnalysisPass.cpp
1717
lib/MemrefLifetime.cpp
18+
lib/FusionFeasibility.cpp
19+
lib/LivenessAdapter.cpp
1820
)
1921

2022
# add_dependencies(lab-scheduler ToyCh6ShapeInferenceInterfaceIncGen
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#pragma once
2+
3+
#include "mlir/Analysis/Liveness.h"
4+
#include "mlir/IR/Operation.h"
5+
#include "mlir/IR/Value.h"
6+
#include "llvm/ADT/SmallVector.h"
7+
#include <cstdint>
8+
9+
#include "mlir/Analysis/Liveness.h"
10+
#include "mlir/IR/Operation.h"
11+
#include "mlir/IR/Value.h"
12+
#include "llvm/ADT/DenseMap.h"
13+
#include <cstdint>
14+
#include <optional>
15+
16+
namespace mlir {
17+
namespace lab {
18+
19+
enum class FusionReasonKind {
20+
// positive
21+
DirectUse,
22+
SingleUse,
23+
SupportedProducer,
24+
SupportedConsumer,
25+
ElementwiseConsumer,
26+
StaticShapeCompatible,
27+
NoSideEffect,
28+
IntermediateEliminable,
29+
PeakMemoryAcceptable,
30+
TrafficReductionExpected,
31+
32+
// negative
33+
NullProducer,
34+
MultiUseProducer,
35+
UnsupportedProducer,
36+
UnsupportedConsumer,
37+
NonElementwiseConsumer,
38+
DynamicShapeUnsupported,
39+
ShapeMismatch,
40+
SideEffectingProducer,
41+
SideEffectingConsumer,
42+
IntermediateNotEliminable,
43+
PeakMemoryTooHigh
44+
};
45+
46+
struct FusionFeasibilityResult {
47+
Operation *producer = nullptr;
48+
Operation *consumer = nullptr;
49+
OpOperand *fusedOperand = nullptr;
50+
51+
bool isFusable = false;
52+
bool isProfitable = false;
53+
54+
int64_t eliminatedIntermediateBytes = 0;
55+
int64_t estimatedTrafficSavedBytes = 0;
56+
int64_t estimatedPeakBeforeBytes = 0;
57+
int64_t estimatedPeakAfterBytes = 0;
58+
int64_t extraLivenessGrowthBytes = 0;
59+
60+
double score = 0.0;
61+
62+
llvm::SmallVector<FusionReasonKind> reasons;
63+
};
64+
65+
const char *toString(FusionReasonKind kind);
66+
67+
} // namespace lab
68+
} // namespace mlir

mlir/optimization/scheduler/include/lab/LabPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ std::unique_ptr<Pass> createLabMatmulTilePass();
1414
std::unique_ptr<Pass> createLabPipelinePlanPass();
1515
std::unique_ptr<Pass> createLabLivenessPass();
1616
std::unique_ptr<Pass> createLabMemrefLifetimePass();
17+
std::unique_ptr<Pass> createLabFusionFeasibilityPass();
1718

1819
} // namespace mlir
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#pragma once
2+
3+
#include "mlir/Analysis/Liveness.h"
4+
#include "mlir/IR/Operation.h"
5+
#include "mlir/IR/Value.h"
6+
#include "llvm/ADT/DenseMap.h"
7+
#include <cstdint>
8+
#include <optional>
9+
10+
namespace mlir {
11+
namespace lab {
12+
13+
struct ValueLifetimeInfo {
14+
int64_t start = -1;
15+
int64_t end = -1;
16+
int64_t sizeBytes = 0;
17+
};
18+
19+
class LivenessAdapter {
20+
public:
21+
virtual ~LivenessAdapter() = default;
22+
virtual std::optional<ValueLifetimeInfo> lookup(Value v) const = 0;
23+
virtual int64_t getPeakLiveBytes(Operation *scope) const = 0;
24+
};
25+
26+
class MlirLivenessAdapter final : public LivenessAdapter {
27+
public:
28+
explicit MlirLivenessAdapter(Operation *scope);
29+
30+
std::optional<ValueLifetimeInfo> lookup(Value v) const override;
31+
int64_t getPeakLiveBytes(Operation *scope) const override;
32+
33+
private:
34+
Operation *topScope;
35+
mlir::Liveness liveness;
36+
37+
llvm::DenseMap<Operation *, int64_t> opOrder;
38+
llvm::DenseMap<Value, ValueLifetimeInfo> valueInfo;
39+
int64_t cachedPeakLiveBytes = 0;
40+
41+
void buildOperationOrder(Operation *scope);
42+
void buildValueLifetimeInfo(Operation *scope);
43+
void buildPeakLiveBytes(Operation *scope);
44+
45+
static int64_t getValueSizeBytes(Value v);
46+
};
47+
48+
} // namespace lab
49+
} // namespace mlir

mlir/optimization/scheduler/lab-opt.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,31 @@ int main(int argc, char **argv) {
1919
registry.insert<mlir::func::FuncDialect, mlir::linalg::LinalgDialect,
2020
mlir::arith::ArithDialect, mlir::tensor::TensorDialect,
2121
mlir::memref::MemRefDialect, mlir::scf::SCFDialect,
22-
mlir::affine::AffineDialect,
23-
mlir::cf::ControlFlowDialect>();
22+
mlir::affine::AffineDialect, mlir::cf::ControlFlowDialect>();
2423

2524
mlir::registerAllPasses();
2625
mlir::PassPipelineRegistration<>("lab-op-stats", "Lab Op Stats Pass",
2726
[](mlir::OpPassManager &pm) {
2827
pm.addPass(mlir::createLabOpStatsPass());
2928
});
30-
mlir::PassPipelineRegistration<>("lab-buffer-stats", "Lab Buffer Stats Pass",
31-
[](mlir::OpPassManager &pm) {
32-
pm.addPass(mlir::createLabBufferStatsPass());
33-
});
29+
mlir::PassPipelineRegistration<>(
30+
"lab-buffer-stats", "Lab Buffer Stats Pass", [](mlir::OpPassManager &pm) {
31+
pm.addPass(mlir::createLabBufferStatsPass());
32+
});
3433
mlir::PassPipelineRegistration<>("lab-liveness", "Lab Liveness Pass",
3534
[](mlir::OpPassManager &pm) {
3635
pm.addPass(mlir::createLabLivenessPass());
3736
});
38-
mlir::PassPipelineRegistration<>("lab-memref-lifetime", "Lab Memref Lifetime Pass",
39-
[](mlir::OpPassManager &pm) {
40-
pm.addPass(mlir::createLabMemrefLifetimePass());
41-
});
37+
mlir::PassPipelineRegistration<>(
38+
"lab-memref-lifetime", "Lab Memref Lifetime Pass",
39+
[](mlir::OpPassManager &pm) {
40+
pm.addPass(mlir::createLabMemrefLifetimePass());
41+
});
42+
mlir::PassPipelineRegistration<>(
43+
"lab-fusion-feasibility", "Lab Fusion Feasibility Pass",
44+
[](mlir::OpPassManager &pm) {
45+
pm.addPass(mlir::createLabFusionFeasibilityPass());
46+
});
4247

4348
return mlir::asMainReturnCode(
4449
mlir::MlirOptMain(argc, argv, "Lab optimizer\n", registry));

0 commit comments

Comments
 (0)