
Commit 80e2354

sync the progress (not finished)
1 parent 99596b0 commit 80e2354

4 files changed

Lines changed: 106 additions & 15 deletions

File tree:
  mlir/cuda-tile/Toy/cuda_wrapper/cuda_shim.cpp
  mlir/cuda-tile/Toy/mlir/LowerToAffineLoops.cpp
  mlir/cuda-tile/Toy/toyc.cpp
  mlir/cuda-tile/sample/matmul.toy

mlir/cuda-tile/Toy/cuda_wrapper/cuda_shim.cpp

Lines changed: 6 additions & 6 deletions
@@ -518,11 +518,11 @@ cuda_shim_launch_block_packed(uint64_t module_handle, uint64_t kernel_name_ptr,
 extern "C" void cuda_shim_ctx_synchronize(void) { mgpuCtxSynchronize(); }
 
 // only for debugging
-// extern "C" void cuda_debug_dump_float(uint64_t dptr, int n) {
-//   auto *p = reinterpret_cast<const float *>(static_cast<uintptr_t>(dptr));
-//   for (uint32_t i = 0; i < n; ++i) {
-//     fprintf(stderr, "i=%u v=%f\n", i, p[i]);
-//   }
-// }
+extern "C" void cuda_debug_dump_float(uint64_t dptr, int n) {
+  auto *p = reinterpret_cast<const float *>(static_cast<uintptr_t>(dptr));
+  for (uint32_t i = 0; i < n; ++i) {
+    fprintf(stderr, "i=%u v=%f\n", i, p[i]);
+  }
+}
 
 #endif
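Note: the commit re-enables the previously commented-out `cuda_debug_dump_float` helper. It reinterprets the 64-bit value as a host pointer and prints `n` floats, so it is only safe when the pointer refers to host-accessible memory (for example a host-shared allocation made through the shim). A minimal, hypothetical smoke test, assuming the shim library is linked in (the buffer and its contents below are made up for illustration):

// Hypothetical caller; `buf` is ordinary host memory, which satisfies the
// helper's assumption that the pointer can be dereferenced on the CPU.
#include <cstdint>

extern "C" void cuda_debug_dump_float(uint64_t dptr, int n);

int main() {
  float buf[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  cuda_debug_dump_float(reinterpret_cast<uint64_t>(buf), 4); // prints i=0..3
  return 0;
}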

mlir/cuda-tile/Toy/mlir/LowerToAffineLoops.cpp

Lines changed: 93 additions & 3 deletions
@@ -12,19 +12,24 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "cuda_shim/CudaShimBuilder.hpp"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
 #include "toy/Dialect.h"
 #include "toy/Passes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/DebugLog.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -36,9 +41,8 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Casting.h"
-#include <algorithm>
+#include "llvm/Support/Debug.h"
 #include <cstdint>
 #include <functional>
 #include <memory>
@@ -390,6 +394,92 @@ struct MatMulOpLowering : public ConversionPattern {
   }
 };
 
+memref::GlobalOp createGlobalForStringAttr(mlir::PatternRewriter &rewriter,
+                                           Operation *op,
+                                           llvm::StringRef sym_name,
+                                           StringAttr attr) {
+  auto loc = op->getLoc();
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  if (auto global = moduleOp.lookupSymbol<memref::GlobalOp>(sym_name); global) {
+    return global;
+  }
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+  auto str = attr.getValue();
+  std::vector<uint8_t> bytes(str.begin(), str.end());
+  bytes.push_back(0);
+
+  auto memrefType =
+      MemRefType::get({(int64_t)bytes.size()}, rewriter.getIntegerType(8));
+
+  auto denseAttr =
+      DenseElementsAttr::get(memrefType, llvm::ArrayRef<uint8_t>(bytes));
+
+  auto global = memref::GlobalOp::create(
+      rewriter, loc, sym_name,
+      /*sym_visibility=*/rewriter.getStringAttr("private"), memrefType,
+      denseAttr,
+      /*constant=*/true,
+      /*alignment=*/nullptr);
+
+  return global;
+}
+
+struct LanchGpuLowering : public ConversionPattern {
+  LanchGpuLowering(MLIRContext *ctx)
+      : ConversionPattern(toy::LaunchGpuOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    CudaShimRegistry registry(op->getParentOfType<ModuleOp>());
+
+    toy::LaunchGpuOp launchGpuOp = llvm::cast<toy::LaunchGpuOp>(op);
+    for (auto ranked_tensor_type : launchGpuOp->getOperands()) {
+      if (!llvm::isa<RankedTensorType>(ranked_tensor_type.getType())) {
+        return rewriter.notifyMatchFailure(op, "expected operand to be a "
+                                               "ranked tensor type");
+      }
+    }
+
+    auto cudaBinaryPathAttr =
+        launchGpuOp->getDiscardableAttr("cuda_binary_path");
+    if (!cudaBinaryPathAttr) {
+      return rewriter.notifyMatchFailure(
+          op, "expected 'cuda_binary_path' attribute to be present");
+    }
+
+    auto cudaBinaryPathStr = llvm::dyn_cast<StringAttr>(cudaBinaryPathAttr);
+    if (!cudaBinaryPathStr) {
+      return rewriter.notifyMatchFailure(
+          op, "expected 'cuda_binary_path' attribute to be a string");
+    }
+
+    auto cuda_blob_memref = createGlobalForStringAttr(
+        rewriter, launchGpuOp, "cuda_blob", cudaBinaryPathStr);
+
+    auto kernelName = launchGpuOp.getCallee();
+
+    auto kernel_name_memref = createGlobalForStringAttr(
+        rewriter, launchGpuOp, "kname", rewriter.getStringAttr(kernelName));
+
+    auto nbytesVal = arith::ConstantIndexOp::create(rewriter, loc, 1);
+    auto streamVal = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    auto isHostSharedVal = arith::ConstantIntOp::create(rewriter, loc, 0, 1);
+
+    auto callee =
+        registry.call(rewriter, launchGpuOp, CudaShimFn::Malloc,
+                      ValueRange{nbytesVal, streamVal, isHostSharedVal});
+
+    rewriter.replaceOp(op, callee);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -442,7 +532,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
                PrintOpLowering, ReturnOpLowering, TransposeOpLowering,
-               MatMulOpLowering>(&getContext());
+               MatMulOpLowering, LanchGpuLowering>(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
mlir/cuda-tile/Toy/toyc.cpp

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ static int loadAndProcessMLIRGPU(mlir::MLIRContext &context,
 
   // mlir::OpPassManager &gpuOptPM = pm.nest<mlir::toy::FuncOp>();
   // // Partially lower the toy dialect.
-  // pm.addPass(mlir::toy::createLowerToAffinePass());
+  pm.addPass(mlir::toy::createLowerToAffinePass());
 
   // // Add a few cleanups post lowering.
   // mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
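The surrounding commented-out lines hint at the standard post-lowering cleanups from the Toy tutorial pipeline. Re-enabling them would presumably look like the sketch below, using the stock MLIR canonicalizer and CSE passes; this is not part of the commit, only the conventional follow-up:

  // Cleanups after the partial lowering (hypothetical; not in this commit).
  mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
  optPM.addPass(mlir::createCanonicalizerPass()); // fold/clean up lowered ops
  optPM.addPass(mlir::createCSEPass());           // common-subexpression elimination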

mlir/cuda-tile/sample/matmul.toy

Lines changed: 6 additions & 5 deletions
@@ -1,16 +1,17 @@
 def main() {
   # Define a variable `a` with shape <2, 3>, initialized with the literal value.
   # The shape is inferred from the supplied literal.
-  var a = [[1, 2, 3], [4, 5, 6]];
+  var a = [[1, 2, 3, 9], [4, 5, 6, 10]];
 
   # b is identical to a, the literal tensor is implicitly reshaped: defining new
   # variables is the way to reshape tensors (element count must match).
-  var b<2, 3> = [11, 12, 13, 14, 15, 16];
+  var b<2, 4> = [11, 12, 13, 14, 15, 16, 17, 18];
 
   # transpose() and print() are the only builtin, the following will transpose
   # a and b and perform an element-wise multiplication before printing the result.
   # print(a * b + b);
-  print(matmul(a, transpose(b)));
-  var c<2, 3> = [[7, 8, 9], [10, 11, 12]];
-  print(a * c + b);
+  # print(matmul(a, transpose(b)));
+  var c<2, 4> = [[7, 8, 9, 13], [10, 11, 12, 14]];
+  var d<2, 4> = [[7, 8, 9, 13], [10, 11, 12, 14]];
+  print(a * c + b * d);
 }
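With these shapes every operand of the final expression is a <2, 4> tensor, so the element-wise `*` and `+` are well defined. Assuming Toy's usual element-wise semantics, the printed value of `a * c + b * d` works out to (shown as integers; Toy prints the elements as doubles):

  a * c       = [[7, 16, 27, 117], [40, 55, 72, 140]]
  b * d       = [[77, 96, 117, 182], [150, 176, 204, 252]]
  a*c + b*d   = [[84, 112, 144, 299], [190, 231, 276, 392]]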
