|
12 | 12 | // |
13 | 13 | //===----------------------------------------------------------------------===// |
14 | 14 |
|
| 15 | +#include "cuda_shim/CudaShimBuilder.hpp" |
| 16 | +#include "mlir/IR/Builders.h" |
15 | 17 | #include "mlir/IR/BuiltinAttributes.h" |
16 | 18 | #include "mlir/IR/BuiltinDialect.h" |
17 | 19 | #include "mlir/IR/BuiltinOps.h" |
18 | 20 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
19 | 21 | #include "mlir/IR/BuiltinTypes.h" |
20 | 22 | #include "mlir/IR/Diagnostics.h" |
21 | 23 | #include "mlir/IR/DialectRegistry.h" |
| 24 | +#include "mlir/IR/Operation.h" |
22 | 25 | #include "mlir/IR/PatternMatch.h" |
23 | 26 | #include "mlir/IR/ValueRange.h" |
24 | 27 | #include "mlir/Support/LLVM.h" |
25 | 28 | #include "mlir/Support/TypeID.h" |
26 | 29 | #include "toy/Dialect.h" |
27 | 30 | #include "toy/Passes.h" |
| 31 | +#include "llvm/ADT/StringRef.h" |
| 32 | +#include "llvm/Support/DebugLog.h" |
28 | 33 |
|
29 | 34 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
30 | 35 | #include "mlir/Dialect/Arith/IR/Arith.h" |
|
36 | 41 | #include "llvm/ADT/ArrayRef.h" |
37 | 42 | #include "llvm/ADT/STLExtras.h" |
38 | 43 | #include "llvm/ADT/Sequence.h" |
39 | | -#include "llvm/ADT/StringExtras.h" |
40 | 44 | #include "llvm/Support/Casting.h" |
41 | | -#include <algorithm> |
| 45 | +#include "llvm/Support/Debug.h" |
42 | 46 | #include <cstdint> |
43 | 47 | #include <functional> |
44 | 48 | #include <memory> |
@@ -390,6 +394,92 @@ struct MatMulOpLowering : public ConversionPattern { |
390 | 394 | } |
391 | 395 | }; |
392 | 396 |
|
| 397 | +memref::GlobalOp createGlobalForStringAttr(mlir::PatternRewriter &rewriter, |
| 398 | + Operation *op, |
| 399 | + llvm::StringRef sym_name, |
| 400 | + StringAttr attr) { |
| 401 | + auto loc = op->getLoc(); |
| 402 | + auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 403 | + |
| 404 | + if (auto global = moduleOp.lookupSymbol<memref::GlobalOp>(sym_name); global) { |
| 405 | + return global; |
| 406 | + } |
| 407 | + |
| 408 | + OpBuilder::InsertionGuard guard(rewriter); |
| 409 | + rewriter.setInsertionPointToStart(moduleOp.getBody()); |
| 410 | + |
| 411 | + auto str = attr.getValue(); |
| 412 | + std::vector<uint8_t> bytes(str.begin(), str.end()); |
| 413 | + bytes.push_back(0); |
| 414 | + |
| 415 | + auto memrefType = |
| 416 | + MemRefType::get({(int64_t)bytes.size()}, rewriter.getIntegerType(8)); |
| 417 | + |
| 418 | + auto denseAttr = |
| 419 | + DenseElementsAttr::get(memrefType, llvm::ArrayRef<uint8_t>(bytes)); |
| 420 | + |
| 421 | + auto global = memref::GlobalOp::create( |
| 422 | + rewriter, loc, sym_name, |
| 423 | + /*sym_visibility=*/rewriter.getStringAttr("private"), memrefType, |
| 424 | + denseAttr, |
| 425 | + /*constant=*/true, |
| 426 | + /*alignment=*/nullptr); |
| 427 | + |
| 428 | + return global; |
| 429 | +} |
| 430 | + |
| 431 | +struct LanchGpuLowering : public ConversionPattern { |
| 432 | + LanchGpuLowering(MLIRContext *ctx) |
| 433 | + : ConversionPattern(toy::LaunchGpuOp::getOperationName(), 1, ctx) {} |
| 434 | + |
| 435 | + LogicalResult |
| 436 | + matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 437 | + ConversionPatternRewriter &rewriter) const final { |
| 438 | + auto loc = op->getLoc(); |
| 439 | + CudaShimRegistry registry(op->getParentOfType<ModuleOp>()); |
| 440 | + |
| 441 | + toy::LaunchGpuOp launchGpuOp = llvm::cast<toy::LaunchGpuOp>(op); |
| 442 | + for (auto ranked_tensor_type : launchGpuOp->getOperands()) { |
| 443 | + if (!llvm::isa<RankedTensorType>(ranked_tensor_type.getType())) { |
| 444 | + return rewriter.notifyMatchFailure(op, "expected operand to be a " |
| 445 | + "ranked tensor type"); |
| 446 | + } |
| 447 | + } |
| 448 | + |
| 449 | + auto cudaBinaryPathAttr = |
| 450 | + launchGpuOp->getDiscardableAttr("cuda_binary_path"); |
| 451 | + if (!cudaBinaryPathAttr) { |
| 452 | + return rewriter.notifyMatchFailure( |
| 453 | + op, "expected 'cuda_binary_path' attribute to be present"); |
| 454 | + } |
| 455 | + |
| 456 | + auto cudaBinaryPathStr = llvm::dyn_cast<StringAttr>(cudaBinaryPathAttr); |
| 457 | + if (!cudaBinaryPathStr) { |
| 458 | + return rewriter.notifyMatchFailure( |
| 459 | + op, "expected 'cuda_binary_path' attribute to be a string"); |
| 460 | + } |
| 461 | + |
| 462 | + auto cuda_blob_memref = createGlobalForStringAttr( |
| 463 | + rewriter, launchGpuOp, "cuda_blob", cudaBinaryPathStr); |
| 464 | + |
| 465 | + auto kernelName = launchGpuOp.getCallee(); |
| 466 | + |
| 467 | + auto kernel_name_memref = createGlobalForStringAttr( |
| 468 | + rewriter, launchGpuOp, "kname", rewriter.getStringAttr(kernelName)); |
| 469 | + |
| 470 | + auto nbytesVal = arith::ConstantIndexOp::create(rewriter, loc, 1); |
| 471 | + auto streamVal = arith::ConstantIndexOp::create(rewriter, loc, 0); |
| 472 | + auto isHostSharedVal = arith::ConstantIntOp::create(rewriter, loc, 0, 1); |
| 473 | + |
| 474 | + auto callee = |
| 475 | + registry.call(rewriter, launchGpuOp, CudaShimFn::Malloc, |
| 476 | + ValueRange{nbytesVal, streamVal, isHostSharedVal}); |
| 477 | + |
| 478 | + rewriter.replaceOp(op, callee); |
| 479 | + return success(); |
| 480 | + } |
| 481 | +}; |
| 482 | + |
393 | 483 | } // namespace |
394 | 484 |
|
395 | 485 | //===----------------------------------------------------------------------===// |
@@ -442,7 +532,7 @@ void ToyToAffineLoweringPass::runOnOperation() { |
442 | 532 | RewritePatternSet patterns(&getContext()); |
443 | 533 | patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering, |
444 | 534 | PrintOpLowering, ReturnOpLowering, TransposeOpLowering, |
445 | | - MatMulOpLowering>(&getContext()); |
| 535 | + MatMulOpLowering, LanchGpuLowering>(&getContext()); |
446 | 536 |
|
447 | 537 | // With the target and rewrite patterns defined, we can now attempt the |
448 | 538 | // conversion. The conversion will signal failure if any of our `illegal` |
|
0 commit comments