|
6 | 6 | #include "mlir/IR/Value.h" |
7 | 7 | #include "llvm/ADT/DenseMap.h" |
8 | 8 |
|
| 9 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 10 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 11 | +#include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 12 | +#include "mlir/IR/Builders.h" |
| 13 | +#include "mlir/IR/BuiltinAttributes.h" |
| 14 | +#include "mlir/IR/BuiltinOps.h" |
| 15 | +#include "mlir/IR/BuiltinTypes.h" |
| 16 | +#include "mlir/IR/Diagnostics.h" |
| 17 | +#include "mlir/IR/Operation.h" |
| 18 | +#include "mlir/IR/PatternMatch.h" |
| 19 | +#include "mlir/IR/Types.h" |
| 20 | +#include "mlir/IR/Value.h" |
| 21 | +#include "mlir/IR/ValueRange.h" |
| 22 | +#include "mlir/Support/LLVM.h" |
| 23 | +#include "mlir/Transforms/DialectConversion.h" |
| 24 | + |
9 | 25 | enum class CudaShimFn { |
10 | 26 | // ----- Module ----- |
11 | 27 | LoadModuleFromImage, |
@@ -199,3 +215,84 @@ class CudaShimRegistry { |
199 | 215 | mlir::ModuleOp module; |
200 | 216 | llvm::DenseMap<unsigned, mlir::func::FuncOp> cache; |
201 | 217 | }; |
| 218 | + |
| 219 | +inline mlir::memref::GlobalOp |
| 220 | +createGlobalForStringAttr(mlir::PatternRewriter &rewriter, mlir::Operation *op, |
| 221 | + llvm::StringRef sym_name, mlir::StringAttr attr) { |
| 222 | + auto loc = op->getLoc(); |
| 223 | + auto moduleOp = op->getParentOfType<mlir::ModuleOp>(); |
| 224 | + |
| 225 | + if (auto global = moduleOp.lookupSymbol<mlir::memref::GlobalOp>(sym_name); |
| 226 | + global) { |
| 227 | + return global; |
| 228 | + } |
| 229 | + |
| 230 | + mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 231 | + rewriter.setInsertionPointToStart(moduleOp.getBody()); |
| 232 | + |
| 233 | + auto str = attr.getValue(); |
| 234 | + std::vector<uint8_t> bytes(str.begin(), str.end()); |
| 235 | + bytes.push_back(0); |
| 236 | + |
| 237 | + auto type = mlir::RankedTensorType::get({(int64_t)bytes.size()}, |
| 238 | + rewriter.getIntegerType(8)); |
| 239 | + |
| 240 | + auto memrefType = mlir::MemRefType::get({(int64_t)bytes.size()}, |
| 241 | + rewriter.getIntegerType(8)); |
| 242 | + |
| 243 | + auto denseAttr = |
| 244 | + mlir::DenseElementsAttr::get(type, llvm::ArrayRef<uint8_t>(bytes)); |
| 245 | + |
| 246 | + auto global = mlir::memref::GlobalOp::create( |
| 247 | + rewriter, loc, sym_name, |
| 248 | + /*sym_visibility=*/rewriter.getStringAttr("private"), memrefType, |
| 249 | + denseAttr, |
| 250 | + /*constant=*/true, |
| 251 | + /*alignment=*/nullptr); |
| 252 | + |
| 253 | + return global; |
| 254 | +} |
| 255 | + |
| 256 | +inline mlir::arith::IndexCastOp |
| 257 | +getIndexFromValue(mlir::PatternRewriter &rewriter, mlir::Location loc, |
| 258 | + mlir::Value value) { |
| 259 | + auto extractOp = mlir::memref::ExtractAlignedPointerAsIndexOp::create( |
| 260 | + rewriter, loc, rewriter.getIndexType(), value); |
| 261 | + auto indexCastOp = mlir::arith::IndexCastOp::create( |
| 262 | + rewriter, loc, rewriter.getI64Type(), extractOp.getResult()); |
| 263 | + return indexCastOp; |
| 264 | +} |
| 265 | + |
| 266 | +inline mlir::arith::IndexCastOp |
| 267 | +getIndexFromGlobalMemref(mlir::PatternRewriter &rewriter, mlir::Location loc, |
| 268 | + mlir::memref::GlobalOp global) { |
| 269 | + |
| 270 | + auto getGlobalOp = mlir::memref::GetGlobalOp::create( |
| 271 | + rewriter, loc, global.getType(), global.getName()); |
| 272 | + |
| 273 | + return getIndexFromValue(rewriter, loc, getGlobalOp.getResult()); |
| 274 | +} |
| 275 | + |
| 276 | +inline mlir::func::CallOp createCallToCudaShimMalloc( |
| 277 | + mlir::PatternRewriter &rewriter, mlir::Location loc, |
| 278 | + CudaShimRegistry ®istry, mlir::func::CallOp stream, |
| 279 | + mlir::arith::ConstantIntOp nbytesVal, bool isHostShared) { |
| 280 | + mlir::arith::ConstantIntOp isHostSharedVal; |
| 281 | + if (isHostShared) { |
| 282 | + isHostSharedVal = mlir::arith::ConstantIntOp::create(rewriter, loc, 1, 1); |
| 283 | + } else { |
| 284 | + isHostSharedVal = mlir::arith::ConstantIntOp::create(rewriter, loc, 0, 1); |
| 285 | + } |
| 286 | + auto sreamVal = stream.getResult(0); |
| 287 | + auto callee = |
| 288 | + registry.call(rewriter, stream, CudaShimFn::Malloc, |
| 289 | + mlir::ValueRange{nbytesVal, sreamVal, isHostSharedVal}); |
| 290 | + return callee; |
| 291 | +} |
| 292 | + |
| 293 | +inline unsigned long getNbytes(mlir::Type tensorType) { |
| 294 | + auto ranked_tensor_type = llvm::cast<mlir::MemRefType>(tensorType); |
| 295 | + return llvm::divideCeil(ranked_tensor_type.getNumElements() * |
| 296 | + ranked_tensor_type.getElementTypeBitWidth(), |
| 297 | + 8); |
| 298 | +} |
0 commit comments