Skip to content

Commit b2e34d0

Browse files
committed
Move the helper functions into the CUDA shim builder
1 parent 96a122c commit b2e34d0

File tree

3 files changed

+103
-84
lines changed

3 files changed

+103
-84
lines changed

mlir/cuda-tile/Toy/include/cuda_shim/CudaShimBuilder.hpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,22 @@
66
#include "mlir/IR/Value.h"
77
#include "llvm/ADT/DenseMap.h"
88

9+
#include "mlir/Dialect/Arith/IR/Arith.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
12+
#include "mlir/IR/Builders.h"
13+
#include "mlir/IR/BuiltinAttributes.h"
14+
#include "mlir/IR/BuiltinOps.h"
15+
#include "mlir/IR/BuiltinTypes.h"
16+
#include "mlir/IR/Diagnostics.h"
17+
#include "mlir/IR/Operation.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/IR/Types.h"
20+
#include "mlir/IR/Value.h"
21+
#include "mlir/IR/ValueRange.h"
22+
#include "mlir/Support/LLVM.h"
23+
#include "mlir/Transforms/DialectConversion.h"
24+
925
enum class CudaShimFn {
1026
// ----- Module -----
1127
LoadModuleFromImage,
@@ -199,3 +215,84 @@ class CudaShimRegistry {
199215
mlir::ModuleOp module;
200216
llvm::DenseMap<unsigned, mlir::func::FuncOp> cache;
201217
};
218+
219+
/// Creates (or reuses) a module-level private constant memref global holding
/// the bytes of `attr` as a NUL-terminated i8 array.
///
/// If a global named `sym_name` already exists in the enclosing module it is
/// returned unchanged; otherwise a new one is inserted at the start of the
/// module body. The rewriter's insertion point is restored before returning.
///
/// \param rewriter  pattern rewriter used to create the ops.
/// \param op        any op inside the module that should own the global.
/// \param sym_name  symbol name for the global (deduplication key).
/// \param attr      string contents to materialize.
/// \return the existing or newly created memref::GlobalOp.
inline mlir::memref::GlobalOp
createGlobalForStringAttr(mlir::PatternRewriter &rewriter, mlir::Operation *op,
                          llvm::StringRef sym_name, mlir::StringAttr attr) {
  auto loc = op->getLoc();
  auto moduleOp = op->getParentOfType<mlir::ModuleOp>();

  // Reuse an existing global with the same symbol name, if any.
  if (auto global = moduleOp.lookupSymbol<mlir::memref::GlobalOp>(sym_name);
      global) {
    return global;
  }

  // Globals live at module scope; the guard restores the caller's insertion
  // point when this function returns.
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());

  // Copy the string contents and append a trailing NUL byte.
  auto str = attr.getValue();
  std::vector<uint8_t> bytes(str.begin(), str.end());
  bytes.push_back(0);

  const auto size = static_cast<int64_t>(bytes.size());

  // DenseElementsAttr requires a shaped (tensor) type; the global itself is
  // typed as a memref of the same shape.
  auto type = mlir::RankedTensorType::get({size}, rewriter.getIntegerType(8));
  auto memrefType = mlir::MemRefType::get({size}, rewriter.getIntegerType(8));

  auto denseAttr =
      mlir::DenseElementsAttr::get(type, llvm::ArrayRef<uint8_t>(bytes));

  auto global = mlir::memref::GlobalOp::create(
      rewriter, loc, sym_name,
      /*sym_visibility=*/rewriter.getStringAttr("private"), memrefType,
      denseAttr,
      /*constant=*/true,
      /*alignment=*/nullptr);

  return global;
}
255+
256+
/// Extracts the aligned base pointer of memref `value` as an `index` and
/// casts it to i64, returning the index_cast op.
inline mlir::arith::IndexCastOp
getIndexFromValue(mlir::PatternRewriter &rewriter, mlir::Location loc,
                  mlir::Value value) {
  auto ptrAsIndex = mlir::memref::ExtractAlignedPointerAsIndexOp::create(
      rewriter, loc, rewriter.getIndexType(), value);
  return mlir::arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                          ptrAsIndex.getResult());
}
265+
266+
/// Materializes `global` as an SSA value via memref.get_global, then returns
/// its aligned base pointer cast to i64 (see getIndexFromValue).
inline mlir::arith::IndexCastOp
getIndexFromGlobalMemref(mlir::PatternRewriter &rewriter, mlir::Location loc,
                         mlir::memref::GlobalOp global) {
  auto globalRef = mlir::memref::GetGlobalOp::create(
      rewriter, loc, global.getType(), global.getName());
  return getIndexFromValue(rewriter, loc, globalRef.getResult());
}
275+
276+
/// Emits a call to the CUDA shim's Malloc entry point.
///
/// \param rewriter      pattern rewriter used to create the ops.
/// \param loc           location for the created constant.
/// \param registry      shim registry used to resolve/emit the call.
/// \param stream        call op whose first result is the stream handle.
/// \param nbytesVal     constant holding the allocation size in bytes.
/// \param isHostShared  whether the allocation is host-shared.
/// \return the emitted func::CallOp for the Malloc shim.
inline mlir::func::CallOp createCallToCudaShimMalloc(
    mlir::PatternRewriter &rewriter, mlir::Location loc,
    CudaShimRegistry &registry, mlir::func::CallOp stream,
    mlir::arith::ConstantIntOp nbytesVal, bool isHostShared) {
  // Lower the host-shared flag to an i1 constant (branches collapsed).
  auto isHostSharedVal = mlir::arith::ConstantIntOp::create(
      rewriter, loc, isHostShared ? 1 : 0, 1);
  // Fixed typo: was `sreamVal`.
  auto streamVal = stream.getResult(0);
  return registry.call(rewriter, stream, CudaShimFn::Malloc,
                       mlir::ValueRange{nbytesVal, streamVal, isHostSharedVal});
}
292+
293+
/// Returns the size in bytes of the buffer described by `tensorType`.
///
/// NOTE(review): despite the parameter name, the type is cast to
/// mlir::MemRefType — llvm::cast asserts if anything else is passed.
inline unsigned long getNbytes(mlir::Type tensorType) {
  auto memrefType = llvm::cast<mlir::MemRefType>(tensorType);
  // Round up so sub-byte element widths (e.g. i1) still occupy whole bytes.
  return llvm::divideCeil(memrefType.getNumElements() *
                              memrefType.getElementTypeBitWidth(),
                          8);
}

mlir/cuda-tile/Toy/mlir/LowerToAffineLoops.cpp

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -397,84 +397,6 @@ struct MatMulOpLowering : public ConversionPattern {
397397
}
398398
};
399399

400-
memref::GlobalOp createGlobalForStringAttr(mlir::PatternRewriter &rewriter,
401-
Operation *op,
402-
llvm::StringRef sym_name,
403-
StringAttr attr) {
404-
auto loc = op->getLoc();
405-
auto moduleOp = op->getParentOfType<ModuleOp>();
406-
407-
if (auto global = moduleOp.lookupSymbol<memref::GlobalOp>(sym_name); global) {
408-
return global;
409-
}
410-
411-
OpBuilder::InsertionGuard guard(rewriter);
412-
rewriter.setInsertionPointToStart(moduleOp.getBody());
413-
414-
auto str = attr.getValue();
415-
std::vector<uint8_t> bytes(str.begin(), str.end());
416-
bytes.push_back(0);
417-
418-
auto type = RankedTensorType::get({(int64_t)bytes.size()},
419-
rewriter.getIntegerType(8));
420-
421-
auto memrefType =
422-
MemRefType::get({(int64_t)bytes.size()}, rewriter.getIntegerType(8));
423-
424-
auto denseAttr = DenseElementsAttr::get(type, llvm::ArrayRef<uint8_t>(bytes));
425-
426-
auto global = memref::GlobalOp::create(
427-
rewriter, loc, sym_name,
428-
/*sym_visibility=*/rewriter.getStringAttr("private"), memrefType,
429-
denseAttr,
430-
/*constant=*/true,
431-
/*alignment=*/nullptr);
432-
433-
return global;
434-
}
435-
436-
arith::IndexCastOp getIndexFromValue(mlir::PatternRewriter &rewriter,
437-
Location loc, Value value) {
438-
auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
439-
rewriter, loc, rewriter.getIndexType(), value);
440-
auto indexCastOp = arith::IndexCastOp::create(
441-
rewriter, loc, rewriter.getI64Type(), extractOp.getResult());
442-
return indexCastOp;
443-
}
444-
445-
arith::IndexCastOp getIndexFromGlobalMemref(mlir::PatternRewriter &rewriter,
446-
Location loc,
447-
memref::GlobalOp global) {
448-
449-
auto getGlobalOp = memref::GetGlobalOp::create(
450-
rewriter, loc, global.getType(), global.getName());
451-
452-
return getIndexFromValue(rewriter, loc, getGlobalOp.getResult());
453-
}
454-
455-
func::CallOp
456-
createCallToCudaShimMalloc(mlir::PatternRewriter &rewriter, Location loc,
457-
CudaShimRegistry &registry, func::CallOp stream,
458-
arith::ConstantIntOp nbytesVal, bool isHostShared) {
459-
arith::ConstantIntOp isHostSharedVal;
460-
if (isHostShared) {
461-
isHostSharedVal = arith::ConstantIntOp::create(rewriter, loc, 1, 1);
462-
} else {
463-
isHostSharedVal = arith::ConstantIntOp::create(rewriter, loc, 0, 1);
464-
}
465-
auto sreamVal = stream.getResult(0);
466-
auto callee = registry.call(rewriter, stream, CudaShimFn::Malloc,
467-
ValueRange{nbytesVal, sreamVal, isHostSharedVal});
468-
return callee;
469-
}
470-
471-
unsigned long getNbytes(Type tensorType) {
472-
auto ranked_tensor_type = llvm::cast<MemRefType>(tensorType);
473-
return llvm::divideCeil(ranked_tensor_type.getNumElements() *
474-
ranked_tensor_type.getElementTypeBitWidth(),
475-
8);
476-
}
477-
478400
struct LanchGpuLowering : public OpConversionPattern<toy::LaunchGpuOp> {
479401
using OpConversionPattern<toy::LaunchGpuOp>::OpConversionPattern;
480402

mlir/cuda-tile/sample/test.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,15 @@ module {
127127
memref.store %12, %alloc_35[%c3_52] : memref<4xi64>
128128
%c8_i64_53 = arith.constant 8 : i64
129129
memref.store %c8_i64_53, %alloc_36[%c3_52] : memref<4xi64>
130+
%c8_i32 = arith.constant 8 : i32
130131
%c1_i32 = arith.constant 1 : i32
131132
%c1_i32_54 = arith.constant 1 : i32
132-
%c1_i32_55 = arith.constant 1 : i32
133133
%c4_i32 = arith.constant 4 : i32
134-
%intptr_56 = memref.extract_aligned_pointer_as_index %alloc_35 : memref<4xi64> -> index
135-
%14 = arith.index_cast %intptr_56 : index to i64
136-
%intptr_57 = memref.extract_aligned_pointer_as_index %alloc_36 : memref<4xi64> -> index
137-
%15 = arith.index_cast %intptr_57 : index to i64
138-
call @cuda_shim_launch_block_packed(%4, %3, %c1_i32, %c1_i32_54, %c1_i32_55, %5, %14, %15, %c4_i32) : (i64, i64, i32, i32, i32, i64, i64, i64, i32) -> ()
134+
%intptr_55 = memref.extract_aligned_pointer_as_index %alloc_35 : memref<4xi64> -> index
135+
%14 = arith.index_cast %intptr_55 : index to i64
136+
%intptr_56 = memref.extract_aligned_pointer_as_index %alloc_36 : memref<4xi64> -> index
137+
%15 = arith.index_cast %intptr_56 : index to i64
138+
call @cuda_shim_launch_block_packed(%4, %3, %c8_i32, %c1_i32, %c1_i32_54, %5, %14, %15, %c4_i32) : (i64, i64, i32, i32, i32, i64, i64, i64, i32) -> ()
139139
call @cuda_shim_stream_synchronize(%5) : (i64) -> ()
140140
call @cuda_shim_memcpy_d2h(%13, %12, %c32_i64_49) : (i64, i64, i64) -> ()
141141
memref.dealloc %alloc_35 : memref<4xi64>

0 commit comments

Comments
 (0)