|
7 | 7 | //===----------------------------------------------------------------------===// |
8 | 8 |
|
9 | 9 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| 10 | +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" |
| 11 | +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
10 | 12 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
11 | 13 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
12 | 14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
@@ -138,38 +140,42 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> { |
138 | 140 | for (auto index : indices) { |
139 | 141 | if (auto applyOp = index.getDefiningOp<affine::AffineApplyOp>()) { |
140 | 142 | index_map = applyOp.getAffineMap(); |
141 | | - indirect_access = index_map.getNumSymbols()!=0; |
| 143 | + Attribute indirectAccessAttr = applyOp->getAttr("indirect_access"); |
| 144 | + indirect_access = indirectAccessAttr ? 1 : 0; |
142 | 145 | parentIndices = applyOp.getOperands(); |
143 | 146 | if (indirect_access) { |
144 | 147 | // FIXME. How to get converted type? |
| 148 | + bool found = false; |
145 | 149 | for (mlir::Value operand : applyOp.getMapOperands()) { |
| 150 | + // Found index cast |
146 | 151 | auto indexCastOp = operand.getDefiningOp<arith::IndexCastOp>(); |
147 | 152 | if (!indexCastOp) |
148 | 153 | continue; |
149 | | - // Found index cast |
| 154 | + |
150 | 155 | auto loadOp = indexCastOp.getIn().getDefiningOp<mlir::affine::AffineLoadOp>(); |
151 | 156 | if (!loadOp) |
152 | 157 | continue; |
| 158 | + |
153 | 159 | // Found index spad memref |
154 | | - bool found = false; |
155 | | - Value indirectMemref = loadOp.getMemRef(); |
156 | | - auto indirectMemRefType = dyn_cast<MemRefType>(indirectMemref.getType()); |
157 | | - for (auto &use : indirectMemref.getUses()) { |
158 | | - mlir::Operation* userOp = use.getOwner(); |
159 | | - if (auto castOp = llvm::dyn_cast<mlir::UnrealizedConversionCastOp>(userOp)) { |
160 | | - indirectMemref = castOp->getResult(0); |
161 | | - found = true; |
162 | | - } |
163 | | - } |
164 | | - if (!found) { |
165 | | - op.emitError("Failed to find converted memref..."); |
| 160 | + Value rawMemref = loadOp.getMemRef(); |
| 161 | + Type convertedTy = typeConverter->convertType(rawMemref.getType()); |
| 162 | + if (!convertedTy) |
166 | 163 | return failure(); |
167 | | - } |
168 | | - indirect_spad_addr = getStridedElementPtr(loc, indirectMemRefType, indirectMemref, {}, rewriter); |
| 164 | + Value llvmMemref = rewriter.create<UnrealizedConversionCastOp>(op.getLoc(), convertedTy, rawMemref).getResult(0); |
| 165 | + |
| 166 | + found = true; |
| 167 | + MemRefType originalType = loadOp.getMemRefType(); |
| 168 | + MemRefDescriptor descriptor(llvmMemref); |
| 169 | + Value base_ptr = descriptor.allocatedPtr(rewriter, loc); |
| 170 | + indirect_spad_addr = getStridedElementPtr(loc, originalType, llvmMemref, {}, rewriter); |
169 | 171 | indirect_spad_addr = rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), indirect_spad_addr); |
170 | | - indirect_element_size = getElementBitWidth(indirectMemRefType.getElementType()) / 8; |
| 172 | + indirect_element_size = getElementBitWidth(originalType.getElementType()) / 8; |
171 | 173 | break; |
172 | 174 | } |
| 175 | + if (!found) { |
| 176 | + op.emitError("Failed to find indirect access for affine apply operation."); |
| 177 | + return failure(); |
| 178 | + } |
173 | 179 | } |
174 | 180 | } |
175 | 181 | } |
@@ -386,21 +392,23 @@ struct TestMemRefToGemmini |
386 | 392 | MLIRContext *ctx = &getContext(); |
387 | 393 | LowerToLLVMOptions options(ctx); |
388 | 394 | LLVMTypeConverter typeConverter(ctx, options); |
389 | | - |
| 395 | + // vectorlane is passed to the pattern as an argument |
390 | 396 | VECTOR_LANE = vectorlane; |
| 397 | + |
| 398 | + // Lower dma_start and dma_wait operations to gemmini instructions |
| 399 | + LLVMConversionTarget target(getContext()); |
391 | 400 | RewritePatternSet patterns(ctx); |
392 | | - // vectorlane is passed to the pattern as an argument |
| 401 | + target.addIllegalOp<memref::DmaStartOp>(); |
| 402 | + target.addIllegalOp<memref::DmaWaitOp>(); |
393 | 403 | if (timing_mode) |
394 | 404 | patterns.add<TimingDmaStartOpLowering>(typeConverter); |
395 | 405 | else |
396 | 406 | patterns.add<DmaStartOpLowering>(typeConverter); |
397 | 407 | patterns.add<DmaWaitOpLowering>(typeConverter); |
398 | | - LLVMConversionTarget target(getContext()); |
399 | 408 | if (failed(applyPartialConversion(getOperation(), target, |
400 | 409 | std::move(patterns)))) |
401 | 410 | signalPassFailure(); |
402 | 411 | } |
403 | | - |
404 | 412 | }; |
405 | 413 |
|
406 | 414 | } // namespace |
|
0 commit comments