Skip to content

Commit f88f2f2

Browse files
committed
Create the gpu outline pass
1 parent 7e1f45c commit f88f2f2

6 files changed

Lines changed: 377 additions & 10 deletions

File tree

mlir/cuda-tile/Toy/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ add_executable(
2828
mlir/LowerToAffineLoops.cpp
2929
mlir/LowerToLLVM.cpp
3030
mlir/ShapeInferencePass.cpp
31-
mlir/ToyCombine.cpp)
31+
mlir/ToyCombine.cpp
32+
mlir/LowerToGpu.cpp
33+
)
3234

3335
add_dependencies(toy-cuda
3436
ToyCudaShapeInferenceInterfaceIncGen

mlir/cuda-tile/Toy/include/toy/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ std::unique_ptr<mlir::Pass> createLowerToAffinePass();
2929
/// well as `Affine` and `Std`, to the LLVM dialect for codegen.
3030
std::unique_ptr<mlir::Pass> createLowerToLLVMPass();
3131

32+
std::unique_ptr<mlir::Pass> createGpuOutlinePass();
33+
3234
} // namespace toy
3335
} // namespace mlir
3436

mlir/cuda-tile/Toy/mlir/Dialect.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ llvm::LogicalResult ReturnOp::verify() {
379379
if (!function)
380380
return emitOpError() << "must be enclosed in a function-like op";
381381

382-
383382
/// ReturnOps can only have a single optional operand.
384383
if (getNumOperands() > 1)
385384
return emitOpError() << "expects at most 1 return operand";
@@ -498,7 +497,7 @@ llvm::LogicalResult MatMulOp::verify() {
498497
//===----------------------------------------------------------------------===//
499498

500499
void LaunchGpuOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
501-
StringRef callee, ArrayRef<mlir::Value> arguments) {
500+
StringRef callee, ArrayRef<mlir::Value> arguments) {
502501
// Generic call always returns an unranked Tensor initially.
503502
state.addTypes(UnrankedTensorType::get(builder.getF32Type()));
504503
state.addOperands(arguments);
@@ -529,21 +528,20 @@ MutableOperandRange LaunchGpuOp::getArgOperandsMutable() {
529528
return getInputsMutable();
530529
}
531530

532-
533531
//===----------------------------------------------------------------------===//
534532
// GPUFuncOp
535533
//===----------------------------------------------------------------------===//
536534

537535
void GPUFuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
538-
llvm::StringRef name, mlir::FunctionType type,
539-
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
536+
llvm::StringRef name, mlir::FunctionType type,
537+
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
540538
// FunctionOpInterface provides a convenient `build` method that will populate
541539
// the state of our GPUFuncOp, and create an entry block.
542540
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
543541
}
544542

545543
mlir::ParseResult GPUFuncOp::parse(mlir::OpAsmParser &parser,
546-
mlir::OperationState &result) {
544+
mlir::OperationState &result) {
547545
// Dispatch to the FunctionOpInterface provided utility method that parses the
548546
// function operation.
549547
auto buildFuncType =
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
#include "mlir/IR/Attributes.h"
2+
#include "mlir/IR/Block.h"
3+
#include "mlir/IR/Builders.h"
4+
#include "mlir/IR/BuiltinOps.h"
5+
#include "mlir/IR/BuiltinTypes.h"
6+
#include "mlir/IR/IRMapping.h"
7+
#include "mlir/IR/Operation.h"
8+
#include "mlir/IR/SymbolTable.h"
9+
#include "mlir/IR/Types.h"
10+
#include "mlir/IR/Value.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Support/LLVM.h"
13+
#include "mlir/Support/TypeID.h"
14+
#include "toy/Dialect.h"
15+
#include "toy/Passes.h"
16+
#include "llvm/ADT/STLExtras.h"
17+
#include "llvm/ADT/SmallPtrSet.h"
18+
#include "llvm/ADT/SmallSet.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
#include "llvm/ADT/StringExtras.h"
21+
#include "llvm/ADT/StringRef.h"
22+
#include "llvm/Support/Casting.h"
23+
#include "llvm/Support/DebugLog.h"
24+
25+
#include <memory>
26+
#include <string>
27+
28+
#define DEBUG_TYPE "toy-gpu-outline"
29+
30+
namespace {
31+
32+
/// Returns true when the portion of `op`'s registered name that follows the
/// first '.' is one of the names listed in `gpuOps`.
///
/// The text before the first '.' is ignored, so only the suffix is compared.
/// A name without any '.' yields an empty suffix.
static bool isGpuOperation(mlir::Operation *op,
                           const llvm::SmallSet<llvm::StringRef, 4> &gpuOps) {
  const llvm::StringRef qualifiedName = op->getName().getStringRef();
  const auto nameParts = qualifiedName.split('.');
  return gpuOps.count(nameParts.second) != 0;
}
37+
38+
/// Parses a launch-grid specification of the form "x,y,z".
///
/// The string is split on ','; every piece is trimmed of surrounding
/// whitespace and must parse as a strictly positive integer.
///
/// \param gridStr comma-separated grid dimensions, e.g. "4,4,1".
/// \returns the three parsed dimensions, or {1, 1, 1} when `gridStr` does not
///          consist of exactly three valid, positive integers.
static llvm::SmallVector<int64_t, 3> parseGrid(llvm::StringRef gridStr) {
  const llvm::SmallVector<int64_t, 3> fallback = {1, 1, 1};

  llvm::SmallVector<llvm::StringRef, 4> pieces;
  gridStr.split(pieces, ',');
  if (pieces.size() != 3)
    return fallback;

  llvm::SmallVector<int64_t, 3> dims;
  for (llvm::StringRef piece : pieces) {
    int64_t value = 0;
    // Reject the whole specification on the first bad piece instead of
    // silently dropping it: skipping pieces would let a malformed string such
    // as "4,x,4,1" be accepted as a different, valid-looking grid.
    if (!llvm::to_integer(piece.trim(), value) || value <= 0)
      return fallback;
    dims.push_back(value);
  }
  return dims;
}
51+
52+
struct GpuOutlinePass
53+
: public mlir::PassWrapper<GpuOutlinePass,
54+
mlir::OperationPass<mlir::toy::FuncOp>> {
55+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GpuOutlinePass)
56+
57+
llvm::StringRef getArgument() const override { return "toy-gpu-outline"; }
58+
59+
void runOnOperation() override {
60+
auto func = getOperation();
61+
if (func.getName() != "main")
62+
return;
63+
64+
llvm::SmallSet<llvm::StringRef, 4> gpuOperations = {"matmul", "add", "mul",
65+
"transpose"};
66+
67+
// // Collect GPU-eligible ops in block order for deterministic cloning.
68+
// llvm::SmallDenseSet<mlir::Operation *, 8> gpuOpSet;
69+
// llvm::SmallVector<mlir::Operation *> gpuOps;
70+
71+
// for (mlir::Operation &op : func.front()) {
72+
// if (isGpuOperation(&op, gpuOperations)) {
73+
// gpuOpSet.insert(&op);
74+
// gpuOps.push_back(&op);
75+
// }
76+
// }
77+
78+
// if (gpuOps.empty())
79+
// return;
80+
81+
llvm::SmallVector<int64_t, 3> gridDims = parseGrid("4,4,1");
82+
83+
llvm::SmallVector<llvm::SmallVector<mlir::Operation *>> gpuSubgraphs;
84+
85+
// Find a gpu subgraph like
86+
// [[gpuOps, ...], [gpuOps, ...], ...]
87+
// original sequence:
88+
// [..., non-gpu-op, [gpu-op, gpu-op], non-gpu-op, [gpu-op, ...]]
89+
func.walk([&](mlir::Operation *op) {
90+
if (isGpuOperation(op, gpuOperations)) {
91+
if (gpuSubgraphs.empty()) {
92+
gpuSubgraphs.push_back({op});
93+
} else {
94+
gpuSubgraphs.back().push_back(op);
95+
}
96+
} else {
97+
if (gpuSubgraphs.empty()) {
98+
gpuSubgraphs.push_back({});
99+
} else if (!gpuSubgraphs.back().empty()) {
100+
gpuSubgraphs.push_back({});
101+
}
102+
}
103+
});
104+
105+
if (gpuSubgraphs.empty())
106+
return;
107+
108+
bool allEmpty = llvm::all_of(
109+
gpuSubgraphs, [](const llvm::SmallVector<mlir::Operation *> &sg) {
110+
return sg.empty();
111+
});
112+
113+
if (allEmpty)
114+
return;
115+
116+
if (gpuSubgraphs.back().empty()) {
117+
gpuSubgraphs.pop_back();
118+
}
119+
120+
for (const auto &gpuSubgraph : gpuSubgraphs) {
121+
LDBG() << "----GPU subgraph----\n";
122+
for (const auto &op : gpuSubgraph) {
123+
LDBG() << *op << "\n";
124+
}
125+
LDBG() << "--------------------\n";
126+
}
127+
128+
llvm::SmallVector<std::string> outlinedFuncNames;
129+
llvm::SmallVector<mlir::Operation *> insertPoints;
130+
131+
// the logic to outline each gpu subgraph
132+
// 1. find operands or input for the subgraph (exclude the input inside
133+
// subgraph).
134+
// 2. find results or output for the subgraph (exclude the output inside
135+
// subgraph).
136+
// 3. create a new function with operands as input and results as output.
137+
// 4. insert a LaunchGpuOp to call the outlined function at the insert point
138+
139+
for (const auto &[index, gpuSubgraph] : llvm::enumerate(gpuSubgraphs)) {
140+
if (!gpuSubgraph.empty()) {
141+
LDBG() << "----GPU subgraph----\n";
142+
for (const auto &op : gpuSubgraph) {
143+
LDBG() << *op << "\n";
144+
}
145+
146+
// Identify its operands.
147+
llvm::SmallVector<mlir::Value, 8> Operands;
148+
llvm::SmallPtrSet<mlir::Value, 8> OperandSet;
149+
for (mlir::Operation *op : gpuSubgraph) {
150+
for (mlir::Value operand : op->getOperands()) {
151+
auto *def = operand.getDefiningOp();
152+
if (!def || !isGpuOperation(def, gpuOperations)) {
153+
if (OperandSet.insert(operand).second)
154+
Operands.push_back(operand);
155+
}
156+
}
157+
}
158+
159+
LDBG() << "Operands:\n";
160+
for (mlir::Value &operand : Operands) {
161+
LDBG() << " " << operand << "\n";
162+
}
163+
164+
llvm::SmallVector<mlir::Value, 2> Results;
165+
llvm::SmallPtrSet<mlir::Value, 2> ResultSet;
166+
167+
for (mlir::Operation *op : gpuSubgraph) {
168+
for (mlir::Value result : op->getResults()) {
169+
bool escapes =
170+
llvm::any_of(result.getUsers(), [&](mlir::Operation *user) {
171+
return !isGpuOperation(user, gpuOperations);
172+
});
173+
if (escapes && ResultSet.insert(result).second)
174+
Results.push_back(result);
175+
}
176+
}
177+
178+
LDBG() << "Results:\n";
179+
for (mlir::Value &result : Results) {
180+
LDBG() << " " << result << "\n";
181+
}
182+
183+
if (Results.size() != 1) {
184+
llvm::errs()
185+
<< "Currently only support single result GPU kernel "
186+
<< "Since the toy return op only supports single return value "
187+
<< "Found " << Results.size() << " results\n";
188+
return signalPassFailure();
189+
}
190+
191+
// buid the kernel for each subgraph
192+
llvm::SmallVector<mlir::Type, 8> argTypes;
193+
argTypes.reserve(Operands.size());
194+
for (mlir::Value v : Operands)
195+
argTypes.push_back(v.getType());
196+
197+
llvm::SmallVector<mlir::Type> resultTypes;
198+
resultTypes.reserve(Results.size());
199+
for (mlir::Value v : Results)
200+
resultTypes.push_back(v.getType());
201+
202+
mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
203+
mlir::SymbolTable symbolTable(module);
204+
std::string outline_func_name =
205+
"outlined_gpu_kernel_" + std::to_string(index);
206+
207+
unsigned suffix = 0;
208+
while (symbolTable.lookup(outline_func_name))
209+
outline_func_name =
210+
outline_func_name + "_" + std::to_string(++suffix);
211+
212+
insertPoints.push_back(gpuSubgraph.front());
213+
214+
{
215+
mlir::OpBuilder moduleBuilder(module.getContext());
216+
mlir::OpBuilder::InsertionGuard guard(moduleBuilder);
217+
moduleBuilder.setInsertionPointToEnd(module.getBody());
218+
auto funcType = moduleBuilder.getFunctionType(argTypes, resultTypes);
219+
auto gpuFunc = mlir::toy::GPUFuncOp::create(
220+
moduleBuilder, func.getLoc(), outline_func_name, funcType);
221+
222+
mlir::Block &kernelEntry = gpuFunc.getBody().front();
223+
mlir::OpBuilder kernelBuilder =
224+
mlir::OpBuilder::atBlockEnd(&kernelEntry);
225+
226+
mlir::IRMapping mapping;
227+
for (auto [blockArg, captured] :
228+
llvm::zip(kernelEntry.getArguments(), Operands))
229+
mapping.map(captured, blockArg);
230+
231+
for (mlir::Operation *op : gpuSubgraph) {
232+
kernelBuilder.clone(*op, mapping);
233+
}
234+
llvm::SmallVector<mlir::Value> mappedResults;
235+
mappedResults.reserve(Results.size());
236+
for (mlir::Value res : Results)
237+
mappedResults.push_back(mapping.lookup(res));
238+
mlir::toy::ReturnOp::create(kernelBuilder, func.getLoc(),
239+
mappedResults);
240+
241+
LDBG() << "Created GPU kernel: " << gpuFunc << "\n";
242+
}
243+
244+
outlinedFuncNames.push_back(outline_func_name);
245+
246+
{
247+
mlir::OpBuilder hostBuilder(func.getContext());
248+
mlir::OpBuilder::InsertionGuard guard(hostBuilder);
249+
// Insert the host launch in place of the first outlined op.
250+
hostBuilder.setInsertionPoint(gpuSubgraph.back()->getNextNode());
251+
252+
auto calleeAttr = mlir::SymbolRefAttr::get(
253+
func.getContext(), llvm::StringRef(outline_func_name));
254+
255+
auto gridAttr = hostBuilder.getDenseI64ArrayAttr(gridDims);
256+
257+
auto launch = mlir::toy::LaunchGpuOp::create(
258+
hostBuilder, func.getLoc(), resultTypes, Operands,
259+
{{"callee", calleeAttr}, {"grid", gridAttr}});
260+
261+
for (auto [idx, res] : llvm::enumerate(Results))
262+
res.replaceAllUsesWith(launch.getResult(idx));
263+
264+
for (mlir::Operation *op : llvm::reverse(gpuSubgraph))
265+
op->erase();
266+
LDBG() << "Inserted LaunchGpuOp: " << launch << "\n";
267+
}
268+
LDBG() << "--------------------\n";
269+
}
270+
}
271+
};
272+
};
273+
}; // namespace
274+
275+
namespace mlir::toy {
276+
277+
std::unique_ptr<mlir::Pass> createGpuOutlinePass() {
278+
return std::make_unique<GpuOutlinePass>();
279+
};
280+
281+
}; // namespace mlir::toy

0 commit comments

Comments
 (0)