Skip to content

Commit b9cfa7e

Browse files
committed
Sync: added the getglobal memref
1 parent 5007d48 commit b9cfa7e

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

mlir/cuda-tile/Toy/mlir/LowerToAffineLoops.cpp

Lines changed: 19 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -433,17 +433,16 @@ memref::GlobalOp createGlobalForStringAttr(mlir::PatternRewriter &rewriter,
433433
arith::IndexCastOp getIndexFromGlobalMemref(mlir::PatternRewriter &rewriter,
434434
Location loc,
435435
memref::GlobalOp global) {
436+
436437
auto getGlobalOp = memref::GetGlobalOp::create(
437-
rewriter, loc, global->getResult(0).getType(), global.getName());
438-
memref::ExtractAlignedPointerAsIndexOp extractOp =
439-
memref::ExtractAlignedPointerAsIndexOp::create(
440-
rewriter, loc, rewriter.getIndexType(), getGlobalOp.getResult());
441-
442-
auto globalType = llvm::cast<MemRefType>(global.getType());
443-
auto size = globalType.getShape()[0];
444-
auto sizeValue = rewriter.create<arith::ConstantIndexOp>(loc, size);
445-
return rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(),
446-
sizeValue);
438+
rewriter, loc, global.getType(), global.getName());
439+
auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
440+
rewriter, loc, rewriter.getIndexType(), getGlobalOp.getResult());
441+
442+
auto indexCastOp = arith::IndexCastOp::create(
443+
rewriter, loc, rewriter.getI64Type(), extractOp.getResult());
444+
445+
return indexCastOp;
447446
}
448447

449448
struct LanchGpuLowering : public ConversionPattern {
@@ -513,11 +512,16 @@ struct LanchGpuLowering : public ConversionPattern {
513512
rewriter, launchGpuOp, "kname", rewriter.getStringAttr(kernelName));
514513

515514
// load the cuda binary path from the global memref.
516-
auto cuda_blob_loaded = memref::GetGlobalOp::create(
517-
rewriter, loc, cuda_blob_memref->getResult(0).getType(), "cuda_blob");
518-
519-
auto kname_loaded = memref::GetGlobalOp::create(
520-
rewriter, loc, kernel_name_memref->getResult(0).getType(), "kname");
515+
auto cuda_blob_index =
516+
getIndexFromGlobalMemref(rewriter, loc, cuda_blob_memref);
517+
auto kname_loaded_index =
518+
getIndexFromGlobalMemref(rewriter, loc, kernel_name_memref);
519+
520+
// Added blob size.
521+
auto blob_size =
522+
llvm::cast<MemRefType>(cuda_blob_memref.getType()).getShape()[0];
523+
arith::ConstantIndexOp blob_size_index =
524+
arith::ConstantIndexOp::create(rewriter, loc, blob_size);
521525

522526
// handle the input of the launch op, we will create a cuda allocation for
523527
// each input tensor.

0 commit comments

Comments
 (0)