Skip to content

Commit 6709060

Browse files
committed
[PyTorchSim] Fix indirect access lowering logic
1 parent 58618d3 commit 6709060

2 files changed

Lines changed: 31 additions & 22 deletions

File tree

mlir/test/lib/Analysis/TestTileOperationGraph.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ void TestTileOperationGraph::processDramIndices(mlir::Value value,
577577
if (auto applyOp = value.getDefiningOp<mlir::affine::AffineApplyOp>()) {
578578
mlir::AffineMap map = applyOp.getAffineMap();
579579
mlir::OperandRange applyOperands = applyOp.getOperands();
580-
indirect_mode = map.getNumSymbols() !=0 && !indirect_mode ? true : indirect_mode;
580+
Attribute indirectAccessAttr = applyOp->getAttr("indirect_access");
581+
indirect_mode = indirectAccessAttr ? true : false;
581582
for (unsigned i = 0; i < applyOperands.size(); ++i) {
582583
auto operand = applyOperands[i];
583584
if (auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(operand)) {

mlir/test/lib/Conversion/MemRefToGemmini/TestMemRefToGemminiConversion.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
10+
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
11+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
1012
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1113
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1214
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -138,38 +140,42 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> {
138140
for (auto index : indices) {
139141
if (auto applyOp = index.getDefiningOp<affine::AffineApplyOp>()) {
140142
index_map = applyOp.getAffineMap();
141-
indirect_access = index_map.getNumSymbols()!=0;
143+
Attribute indirectAccessAttr = applyOp->getAttr("indirect_access");
144+
indirect_access = indirectAccessAttr ? 1 : 0;
142145
parentIndices = applyOp.getOperands();
143146
if (indirect_access) {
144147
// FIXME. How to get converted type?
148+
bool found = false;
145149
for (mlir::Value operand : applyOp.getMapOperands()) {
150+
// Found index cast
146151
auto indexCastOp = operand.getDefiningOp<arith::IndexCastOp>();
147152
if (!indexCastOp)
148153
continue;
149-
// Found index cast
154+
150155
auto loadOp = indexCastOp.getIn().getDefiningOp<mlir::affine::AffineLoadOp>();
151156
if (!loadOp)
152157
continue;
158+
153159
// Found index spad memref
154-
bool found = false;
155-
Value indirectMemref = loadOp.getMemRef();
156-
auto indirectMemRefType = dyn_cast<MemRefType>(indirectMemref.getType());
157-
for (auto &use : indirectMemref.getUses()) {
158-
mlir::Operation* userOp = use.getOwner();
159-
if (auto castOp = llvm::dyn_cast<mlir::UnrealizedConversionCastOp>(userOp)) {
160-
indirectMemref = castOp->getResult(0);
161-
found = true;
162-
}
163-
}
164-
if (!found) {
165-
op.emitError("Failed to find converted memref...");
160+
Value rawMemref = loadOp.getMemRef();
161+
Type convertedTy = typeConverter->convertType(rawMemref.getType());
162+
if (!convertedTy)
166163
return failure();
167-
}
168-
indirect_spad_addr = getStridedElementPtr(loc, indirectMemRefType, indirectMemref, {}, rewriter);
164+
Value llvmMemref = rewriter.create<UnrealizedConversionCastOp>(op.getLoc(), convertedTy, rawMemref).getResult(0);
165+
166+
found = true;
167+
MemRefType originalType = loadOp.getMemRefType();
168+
MemRefDescriptor descriptor(llvmMemref);
169+
Value base_ptr = descriptor.allocatedPtr(rewriter, loc);
170+
indirect_spad_addr = getStridedElementPtr(loc, originalType, llvmMemref, {}, rewriter);
169171
indirect_spad_addr = rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), indirect_spad_addr);
170-
indirect_element_size = getElementBitWidth(indirectMemRefType.getElementType()) / 8;
172+
indirect_element_size = getElementBitWidth(originalType.getElementType()) / 8;
171173
break;
172174
}
175+
if (!found) {
176+
op.emitError("Failed to find indirect access for affine apply operation.");
177+
return failure();
178+
}
173179
}
174180
}
175181
}
@@ -386,21 +392,23 @@ struct TestMemRefToGemmini
386392
MLIRContext *ctx = &getContext();
387393
LowerToLLVMOptions options(ctx);
388394
LLVMTypeConverter typeConverter(ctx, options);
389-
395+
// vectorlane is passed to the pattern as an argument
390396
VECTOR_LANE = vectorlane;
397+
398+
// Lower dma_start and dma_wait operations to gemmini instructions
399+
LLVMConversionTarget target(getContext());
391400
RewritePatternSet patterns(ctx);
392-
// vectorlane is passed to the pattern as an argument
401+
target.addIllegalOp<memref::DmaStartOp>();
402+
target.addIllegalOp<memref::DmaWaitOp>();
393403
if (timing_mode)
394404
patterns.add<TimingDmaStartOpLowering>(typeConverter);
395405
else
396406
patterns.add<DmaStartOpLowering>(typeConverter);
397407
patterns.add<DmaWaitOpLowering>(typeConverter);
398-
LLVMConversionTarget target(getContext());
399408
if (failed(applyPartialConversion(getOperation(), target,
400409
std::move(patterns))))
401410
signalPassFailure();
402411
}
403-
404412
};
405413

406414
} // namespace

0 commit comments

Comments
 (0)