diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index 75f15c54d..cdbd4f6a6 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - //===- DialectPTO.cpp -----------------------------------------------------===// // // Python bindings for the PTO dialect types (pybind11 version). @@ -52,7 +47,8 @@ void populatePTODialectSubmodule(pybind11::module &m); void populatePTODialectSubmodule(pybind11::module &m) { (void)m; } -PYBIND11_MODULE(_pto, m) { + +static void bindPTOModule(pybind11::module &m) { m.doc() = "PTO dialect Python bindings (pybind11)."; // -------------------------------------------------------------------------- @@ -744,3 +740,7 @@ PYBIND11_MODULE(_pto, m) { populatePTODialectSubmodule(m); } + +PYBIND11_MODULE(_pto, m) { + bindPTOModule(m); +} diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 651daa101..6a46604b7 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -35,6 +35,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "llvm/Support/ErrorHandling.h" #include #include @@ -1205,9 +1206,11 @@ void PTODialect::initialize() { AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { auto memRefType = dyn_cast(type); - assert(memRefType && "input type must be a memref type"); + if (!memRefType) + return {}; auto scopeAttr = dyn_cast(memRefType.getMemorySpace()); - assert(scopeAttr && "memory scope should be a pto address scope"); + if (!scopeAttr) + return {}; return scopeAttr; } @@ -5151,14 +5154,15 @@ void mlir::pto::TMrgSortOp::print(OpAsmPrinter &p) { p << " ins(" << getSrc() << ", " << getBlockLen() << " : " << getSrc().getType() << ", " << getBlockLen().getType() << ") outs(" << getDst() << " : " << getDst().getType() << ")"; - } else { - assert(isFormat2()); + } else if (isFormat2()) { p << " ins(" << getSrcs()[0] << ", " << getSrcs()[1] << ", " << getSrcs()[2] << ", " << getSrcs()[3] << " {exhausted = " << (getExhausted() ? "true" : "false") << "} : " << getSrcs()[0].getType() << ", " << getSrcs()[1].getType() << ", " << getSrcs()[2].getType() << ", " << getSrcs()[3].getType() << ") outs(" << getDst() << ", " << getTmp() << ", " << getExcuted() << " : " << getDst().getType() << ", " << getTmp().getType() << ", " << getExcuted().getType() << ")"; + } else { + llvm::report_fatal_error("TMrgSortOp print expects format1 or format2"); } p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes", "exhausted"}); } diff --git a/lib/PTO/IR/PTOAttrs.cpp b/lib/PTO/IR/PTOAttrs.cpp index 04a55fef3..be89fa648 100644 --- a/lib/PTO/IR/PTOAttrs.cpp +++ b/lib/PTO/IR/PTOAttrs.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - //===- PTOAttrs.cpp ------------------------------------------------*- C++ -*-===// #include "PTO/IR/PTO.h" #include "mlir/IR/Builders.h" @@ -132,7 +127,8 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { if (succeeded(p.parseOptionalGreater())) return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact); - while (true) { + bool parsedGreater = false; + while (!parsedGreater) { StringRef key; if (p.parseKeyword(&key)) return {}; if (p.parseEqual()) return {}; @@ -166,7 +162,8 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { return {}; } - if (succeeded(p.parseOptionalGreater())) + parsedGreater = succeeded(p.parseOptionalGreater()); + if (parsedGreater) break; if (p.parseComma()) return {}; } diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index 503774abc..295f01f4d 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -6,14 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - //===- PTOTypeDefs.cpp --------------------------------------------*- C++ -*-===// #include "PTO/IR/PTO.h" #include "mlir/IR/DialectImplementation.h" +#include using namespace mlir; using namespace mlir::pto; @@ -49,6 +45,119 @@ static SmallVector canonicalizeTileBufValidShape(ArrayRef v return canonical; } +static LogicalResult parseTileBufKeyEq(AsmParser &parser, + StringRef expectedKey) { + if (failed(parser.parseKeyword(expectedKey))) + return failure(); + return parser.parseEqual(); +} + +static LogicalResult parseTileBufComma(AsmParser &parser) { + return parser.parseComma(); +} + +static LogicalResult parseTileBufKeywordField(AsmParser &parser, StringRef key, + std::string &value) { + if (failed(parseTileBufKeyEq(parser, key))) + return failure(); + if (failed(parser.parseKeywordOrString(&value))) + return failure(); + return parseTileBufComma(parser); +} + +static LogicalResult parseTileBufTypeField(AsmParser &parser, StringRef key, + Type &value) { + if (failed(parseTileBufKeyEq(parser, key))) + return failure(); + if (failed(parser.parseType(value))) + return failure(); + return parseTileBufComma(parser); +} + +static LogicalResult parseTileBufIntegerField(AsmParser &parser, StringRef key, + int64_t &value) { + if (failed(parseTileBufKeyEq(parser, key))) + return failure(); + if (failed(parser.parseInteger(value))) + return failure(); + return parseTileBufComma(parser); +} + +static LogicalResult parseTileBufValidDim(AsmParser &parser, StringRef key, + int64_t &value) { + if (failed(parseTileBufKeyEq(parser, key))) + return failure(); + + if (succeeded(parser.parseOptionalQuestion())) { + value = -1; + return success(); + } + + if (failed(parser.parseInteger(value))) + return failure(); + if (value < -1) { + parser.emitError(parser.getCurrentLocation(), + key + " must be '?', -1, or a non-negative integer"); + return failure(); + } + return success(); +} + +static LogicalResult parseTileBufValidShapeFields(AsmParser &parser, + int64_t &vrow, + int64_t &vcol) { + if (failed(parseTileBufValidDim(parser, "v_row", vrow))) + return failure(); + if (failed(parseTileBufComma(parser))) + return failure(); + if (failed(parseTileBufValidDim(parser, "v_col", vcol))) + return failure(); + return parseTileBufComma(parser); +} + +static LogicalResult parseTileBufPadField(AsmParser &parser, uint32_t &padInt) { + int64_t parsedPad = 0; + if (failed(parseTileBufKeyEq(parser, "pad"))) + return failure(); + if (failed(parser.parseInteger(parsedPad))) + return failure(); + if (parsedPad < 0 || parsedPad > std::numeric_limits::max()) { + parser.emitError(parser.getCurrentLocation(), + "pad must be a non-negative 32-bit integer"); + return failure(); + } + padInt = static_cast(parsedPad); + return success(); +} + +static std::optional resolveTileBufMemorySpace(StringRef locStr) { + return ::llvm::StringSwitch<::std::optional>(locStr) + .Case("mat", AddressSpace::MAT) + .Case("left", AddressSpace::LEFT) + .Case("right", AddressSpace::RIGHT) + .Case("acc", AddressSpace::ACC) + .Case("vec", AddressSpace::VEC) + .Case("bias", AddressSpace::BIAS) + .Case("scaling", AddressSpace::SCALING) + .Default(::std::nullopt); +} + +static BLayout resolveTileBufBLayout(AddressSpace memorySpace, + BLayout parsedLayout) { + if (memorySpace != AddressSpace::LEFT) + return parsedLayout; + + switch (getPTOParserTargetArch()) { + case PTOParserTargetArch::A3: + return BLayout::RowMajor; + case PTOParserTargetArch::A5: + return BLayout::ColMajor; + case PTOParserTargetArch::Unspecified: + return parsedLayout; + } + return parsedLayout; +} + TileBufConfigAttr TileBufType::getConfigAttr() const { // 情况 A:getConfig() 已经是 TileBufConfigAttr if constexpr (std::is_same_v) { @@ -119,108 +228,31 @@ Type TileBufType::parse(AsmParser &parser) { uint32_t padInt; uint32_t compactInt = 0; - auto parseKeyEq = [&](StringRef expectedKey) -> LogicalResult { - if (failed(parser.parseKeyword(expectedKey))) - return failure(); - if (failed(parser.parseEqual())) - return failure(); - return success(); - }; - - // loc=Vec - { - if (failed(parseKeyEq("loc"))) return Type(); - // Vec/Mat/Acc 不是类型/属性,直接当 keyword/string 读 - if (failed(parser.parseKeywordOrString(&locStr))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - // dtype=f16 - { - if (failed(parseKeyEq("dtype"))) return Type(); - if (failed(parser.parseType(dtype))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - // rows=16 - { - if (failed(parseKeyEq("rows"))) return Type(); - if (failed(parser.parseInteger(rows))) return Type(); - if (failed(parser.parseComma())) return Type(); + if (failed(parseTileBufKeywordField(parser, "loc", locStr)) || + failed(parseTileBufTypeField(parser, "dtype", dtype)) || + failed(parseTileBufIntegerField(parser, "rows", rows)) || + failed(parseTileBufIntegerField(parser, "cols", cols)) || + failed(parseTileBufValidShapeFields(parser, vrow, vcol)) || + failed(parseTileBufKeywordField(parser, "blayout", blayoutStr)) || + failed(parseTileBufKeywordField(parser, "slayout", slayoutStr)) || + failed(parseTileBufIntegerField(parser, "fractal", fractal)) || + failed(parseTileBufPadField(parser, padInt))) { + return Type(); } - // cols=16 - { - if (failed(parseKeyEq("cols"))) return Type(); - if (failed(parser.parseInteger(cols))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - { - // v_row=?/-1/16 , v_col=?/-1/8 (支持半动态) - if (failed(parseKeyEq("v_row"))) return Type(); - - // 解析 v_row:'?' -> -1,否则整数(允许 -1 兼容) - if (succeeded(parser.parseOptionalQuestion())) { - vrow = -1; - } else { - if (failed(parser.parseInteger(vrow))) return Type(); - if (vrow < -1) { - parser.emitError(parser.getCurrentLocation(), - "v_row must be '?', -1, or a non-negative integer"); - return Type(); - } + if (succeeded(parser.parseOptionalComma())) { + int64_t parsedCompact = 0; + if (failed(parseTileBufKeyEq(parser, "compact")) || + failed(parser.parseInteger(parsedCompact))) { + return Type(); } - - if (failed(parser.parseComma())) return Type(); - - if (failed(parseKeyEq("v_col"))) return Type(); - - // 解析 v_col:'?' -> -1,否则整数(允许 -1 兼容) - if (succeeded(parser.parseOptionalQuestion())) { - vcol = -1; - } else { - if (failed(parser.parseInteger(vcol))) return Type(); - if (vcol < -1) { - parser.emitError(parser.getCurrentLocation(), - "v_col must be '?', -1, or a non-negative integer"); - return Type(); - } + if (parsedCompact < 0 || + parsedCompact > std::numeric_limits::max()) { + parser.emitError(parser.getCurrentLocation(), + "compact must be a non-negative 32-bit integer"); + return Type(); } - if (failed(parser.parseComma())) return Type(); - } - - // blayout=RowMajor - { - if (failed(parseKeyEq("blayout"))) return Type(); - if (failed(parser.parseKeywordOrString(&blayoutStr))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - - // slayout=NoneBox - { - if (failed(parseKeyEq("slayout"))) return Type(); - if (failed(parser.parseKeywordOrString(&slayoutStr))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - // fractal=512 - { - if (failed(parseKeyEq("fractal"))) return Type(); - if (failed(parser.parseInteger(fractal))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - // pad=0 - { - if (failed(parseKeyEq("pad"))) return Type(); - if (failed(parser.parseInteger(padInt))) return Type(); - } - - if (succeeded(parser.parseOptionalComma())) { - if (failed(parseKeyEq("compact"))) return Type(); - if (failed(parser.parseInteger(compactInt))) return Type(); + compactInt = static_cast(parsedCompact); } if (failed(parser.parseGreater())) @@ -232,15 +264,7 @@ Type TileBufType::parse(AsmParser &parser) { return Type(); } - auto memorySpace = ::llvm::StringSwitch<::std::optional>(locStr) - .Case("mat", AddressSpace::MAT) - .Case("left", AddressSpace::LEFT) - .Case("right", AddressSpace::RIGHT) - .Case("acc", AddressSpace::ACC) - .Case("vec", AddressSpace::VEC) - .Case("bias", AddressSpace::BIAS) - .Case("scaling", AddressSpace::SCALING) - .Default(::std::nullopt); + auto memorySpace = resolveTileBufMemorySpace(locStr); if (!memorySpace.has_value()) { parser.emitError(parser.getNameLoc(), "unknown loc: ") << locStr; return Type(); @@ -267,19 +291,8 @@ Type TileBufType::parse(AsmParser &parser) { return Type(); } - BLayout effectiveBLayout = bl.value(); - if (memorySpace.value() == AddressSpace::LEFT) { - switch (getPTOParserTargetArch()) { - case PTOParserTargetArch::A3: - effectiveBLayout = BLayout::RowMajor; - break; - case PTOParserTargetArch::A5: - effectiveBLayout = BLayout::ColMajor; - break; - case PTOParserTargetArch::Unspecified: - break; - } - } + BLayout effectiveBLayout = + resolveTileBufBLayout(memorySpace.value(), bl.value()); auto blAttr = BLayoutAttr::get(ctx, effectiveBLayout); auto slAttr = SLayoutAttr::get(ctx, sl.value()); diff --git a/lib/PTO/Transforms/AllocToPointerCast.cpp b/lib/PTO/Transforms/AllocToPointerCast.cpp index b7dc44bd9..ba9be008a 100644 --- a/lib/PTO/Transforms/AllocToPointerCast.cpp +++ b/lib/PTO/Transforms/AllocToPointerCast.cpp @@ -25,13 +25,8 @@ using namespace mlir::pto; namespace {} // namespace -LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( - memref::AllocOp op, PatternRewriter &rewriter) const { - const auto ¤tMemRefType = cast(op.getType()); - - // Preserve tile config carried by the downstream bind_tile user. Losing this - // metadata here makes PointerCast lowering fall back to RowMajor defaults, - // which can generate illegal intermediate TRESHAPE sequences. +namespace { +static TileBufConfigAttr inferBindTileConfig(memref::AllocOp op) { TileBufConfigAttr configAttr; for (Operation *user : op.getResult().getUsers()) { auto bind = dyn_cast(user); @@ -47,12 +42,15 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( break; } } - + return configAttr; +} + +static SmallVector getAllocatedOffsets(memref::AllocOp op, + BaseMemRefType memRefType, + const DenseMap> &buffer2Offsets, + uint64_t &fallbackNextOffset) { constexpr uint64_t kAlign = 4096; auto iter = buffer2Offsets.find(op.getResult()); - - // If MemPlan didn't assign an address, synthesize a unique, aligned offset so - // downstream PointerCast lowering won't crash on empty addrs. SmallVector offsets; if (iter != buffer2Offsets.end()) offsets = iter->second; @@ -61,12 +59,15 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( // Estimate buffer size (best-effort). Most PTO tile buffers are 32x32 and // naturally align to 4096 bytes. uint64_t bytes = kAlign; - if (auto memrefTy = dyn_cast(currentMemRefType)) { + if (auto memrefTy = dyn_cast(memRefType)) { uint64_t elemBytes = 0; Type elemTy = memrefTy.getElementType(); - if (elemTy.isF16()) elemBytes = 2; - else if (elemTy.isF32()) elemBytes = 4; - else if (auto it = dyn_cast(elemTy)) elemBytes = it.getWidth() / 8; + if (elemTy.isF16()) + elemBytes = 2; + else if (elemTy.isF32()) + elemBytes = 4; + else if (auto it = dyn_cast(elemTy)) + elemBytes = it.getWidth() / 8; if (elemBytes != 0) { uint64_t numel = 1; @@ -87,7 +88,29 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( fallbackNextOffset += std::max(stride, kAlign); offsets.push_back(off); } + return offsets; +} + +static std::pair getDynamicValidShapeValues(memref::AllocOp op) { + Value vRow; + Value vCol; + auto dynSizes = op.getDynamicSizes(); + if (dynSizes.size() >= 2) { + vRow = dynSizes[0]; + vCol = dynSizes[1]; + } else if (dynSizes.size() == 1) { + vCol = dynSizes[0]; + } + return {vRow, vCol}; +} +} // namespace +LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( + memref::AllocOp op, PatternRewriter &rewriter) const { + const auto ¤tMemRefType = cast(op.getType()); + TileBufConfigAttr configAttr = inferBindTileConfig(op); + SmallVector offsets = getAllocatedOffsets( + op, currentMemRefType, buffer2Offsets, fallbackNextOffset); SmallVector addrs; addrs.reserve(offsets.size()); for (uint64_t offset : offsets) { @@ -96,35 +119,10 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( addrs.push_back(constantIntOffsetOp); } - // [修改 1] 从 ValueRange 中拆解出 row 和 col - // memref.alloc 的 getDynamicSizes() 返回的是变长列表。 - // 既然我们只支持 2D Tile,且如果是动态 shape 通常两个维度都是动态的 (?x?), - // 我们直接按顺序提取。 - Value vRow, vCol; - auto dynSizes = op.getDynamicSizes(); - - if (dynSizes.size() >= 2) { - vRow = dynSizes[0]; - vCol = dynSizes[1]; - } else if (dynSizes.size() == 1) { - // 极其罕见的混合情况 (例如 32x?),视具体需求处理,这里默认取第一个 - // 或者根据维度索引判断是 row 还是 col,这里暂时从简 - vCol = dynSizes[0]; - } - - // [修改 2] 调用新的 Builder 签名 - // 1. ValueRange(addrs) -> 传递物理地址列表 - // 2. vRow ? vRow : Value() -> 传递 Value 对象(如果为空则传空 Value) - // 3. TileBufConfigAttr() -> 传递空 Attribute 对象 (不能传 nullptr) - + auto [vRow, vCol] = getDynamicValidShapeValues(op); auto ptoPointerCastOp = rewriter.create( - op.getLoc(), - currentMemRefType, - ValueRange(addrs), // addrs - vRow ? vRow : Value(), // valid_row - vCol ? vCol : Value(), // valid_col - configAttr // preserve bind_tile config when available - ); + op.getLoc(), currentMemRefType, ValueRange(addrs), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); rewriter.replaceOp(op, ptoPointerCastOp->getResults()); return success(); diff --git a/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp b/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp index 217735d8e..8f2001dbc 100644 --- a/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/lib/PTO/Transforms/BufferizableOpInterfaceImpl.cpp @@ -23,6 +23,22 @@ using namespace mlir::bufferization; namespace { +template +struct PTOReadWriteDpsOpInterfaceBase + : public DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + auto dpsOp = cast(op); + return dpsOp.isDpsInput(&opOperand); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + auto dpsOp = cast(op); + return dpsOp.isDpsInit(&opOperand); + } +}; + /// Generic conversion for any DestinationStyleOpInterface on tensors. static LogicalResult bufferizeDestinationStyleOpInterface( RewriterBase &rewriter, DestinationStyleOpInterface op, @@ -146,22 +162,7 @@ struct PTOMrgSortDpsOpInterface }; struct PTOAddOpInterface - : public DstBufferizableOpInterfaceExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - // Operand is read if it is used in the computation. - auto dpsOp = cast(op); - return dpsOp.isDpsInput(&opOperand); - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - // Operand is written to if it is not an input/init. - auto dpsOp = cast(op); - return dpsOp.isDpsInit(&opOperand); - } - + : public PTOReadWriteDpsOpInterfaceBase { bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state, ArrayRef opOperands) const { return true; @@ -175,20 +176,8 @@ struct PTOAddOpInterface }; struct PTOMatmulOpInterface - : public DstBufferizableOpInterfaceExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - auto dpsOp = cast(op); - return dpsOp.isDpsInput(&opOperand); - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - auto dpsOp = cast(op); - return dpsOp.isDpsInit(&opOperand); - } - + : public PTOReadWriteDpsOpInterfaceBase { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { return bufferizeDestinationStyleOpInterface( diff --git a/lib/PTO/Transforms/InferPTOLayout.cpp b/lib/PTO/Transforms/InferPTOLayout.cpp index 9e5bd7b7d..20b7216c1 100644 --- a/lib/PTO/Transforms/InferPTOLayout.cpp +++ b/lib/PTO/Transforms/InferPTOLayout.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - //===- InferPTOLayout.cpp - Infer layout for global tensor views -----------===// // // The pto-isa GlobalTensor ABI expects shape/stride to be represented in a 5D @@ -116,6 +111,43 @@ static std::optional rightAlignTo5D(ArrayRef shape, return out; } +static bool matchesNDMinor2D(int64_t rows, int64_t cols, int64_t rowStride, + int64_t colStride) { + if (cols != 1 && colStride != 1) + return false; + if (rows == 1) + return true; + return cols == 1 ? rowStride == 1 : rowStride == cols; +} + +static bool matchesDNMinor2D(int64_t rows, int64_t cols, int64_t rowStride, + int64_t colStride) { + if (rows != 1 && rowStride != 1) + return false; + if (cols == 1) + return true; + return rows == 1 ? colStride == 1 : colStride == rows; +} + +static std::optional inferMinor2DLayout( + int64_t rows, int64_t cols, int64_t rowStride, int64_t colStride, + std::optional preferredMinor2D, bool *isMinor2DAmbiguous) { + const bool nd = matchesNDMinor2D(rows, cols, rowStride, colStride); + const bool dn = matchesDNMinor2D(rows, cols, rowStride, colStride); + if (!nd && !dn) + return Layout::ND; + if (nd && dn) { + if (isMinor2DAmbiguous) + *isMinor2DAmbiguous = true; + if (preferredMinor2D && + (*preferredMinor2D == Layout::ND || *preferredMinor2D == Layout::DN)) { + return *preferredMinor2D; + } + return (cols == 1 && rows != 1) ? Layout::DN : Layout::ND; + } + return dn ? Layout::DN : Layout::ND; +} + static std::optional inferLayout5D(ArrayRef shape, ArrayRef strides, unsigned elemBytes, @@ -149,45 +181,8 @@ static std::optional inferLayout5D(ArrayRef shape, const int64_t cols = sh[4]; const int64_t rowStride = st[3]; const int64_t colStride = st[4]; - - bool nd = true; - if (cols != 1 && colStride != 1) - nd = false; - if (rows != 1) { - if (cols == 1) { - nd &= (rowStride == 1); - } else { - nd &= (rowStride == cols); - } - } - - bool dn = true; - if (rows != 1 && rowStride != 1) - dn = false; - if (cols != 1) { - if (rows == 1) { - dn &= (colStride == 1); - } else { - dn &= (colStride == rows); - } - } - - if (nd && dn) { - if (isMinor2DAmbiguous) - *isMinor2DAmbiguous = true; - if (preferredMinor2D && - (*preferredMinor2D == Layout::ND || *preferredMinor2D == Layout::DN)) - return *preferredMinor2D; - if (cols == 1 && rows != 1) - return Layout::DN; - return Layout::ND; - } - if (dn) - return Layout::DN; - if (nd) - return Layout::ND; - - return Layout::ND; // fallback + return inferMinor2DLayout(rows, cols, rowStride, colStride, + preferredMinor2D, isMinor2DAmbiguous); } return std::nullopt; } @@ -220,6 +215,115 @@ static bool isMinorColsOne(ArrayRef shape) { return !shape.empty() && shape.back() == 1; } +struct ResolvedLayoutInfo { + Operation *owner = nullptr; + std::optional layout; + bool inferred = false; +}; + +static bool getStaticShapeAndStride(MakeTensorViewOp op, + SmallVectorImpl &shape, + SmallVectorImpl &strides); +static ResolvedLayoutInfo resolveLayoutFromViewValue(Value v); + +static void setLayoutAttr(Operation *op, Layout layout, bool inferred) { + op->setAttr(kLayoutAttrName, LayoutAttr::get(op->getContext(), layout)); + if (inferred) + op->setAttr(kInferredLayoutAttrName, BoolAttr::get(op->getContext(), true)); + else + op->removeAttr(kInferredLayoutAttrName); +} + +template +static void verifyOrSetLayoutAttr(Operation *op, + std::optional inferred, + SignalFailureFn signalFailure, + bool isMinor2DAmbiguous = false) { + auto existing = op->getAttrOfType(kLayoutAttrName); + if (existing) { + if (inferred && existing.getLayout() != *inferred) { + if (isMinor2DAmbiguous && isMinor2DLayout(existing.getLayout()) && + isMinor2DLayout(*inferred)) { + return; + } + op->emitError() << "layout mismatch: user-specified layout=" + << stringifyLayout(existing.getLayout()) + << " but inferred=" << stringifyLayout(*inferred); + signalFailure(); + } + return; + } + setLayoutAttr(op, inferred.value_or(Layout::ND), /*inferred=*/true); +} + +static std::optional inferFromStaticMemRefTy(MemRefType mrTy) { + if (!mrTy.hasStaticShape() || mrTy.getRank() == 0 || mrTy.getRank() > 5) + return std::nullopt; + SmallVector strideInts; + int64_t offset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(mrTy, strideInts, offset))) + return std::nullopt; + if (offset == ShapedType::kDynamic || + llvm::any_of(strideInts, + [](int64_t s) { return s == ShapedType::kDynamic; })) { + return std::nullopt; + } + return inferLayout5D(mrTy.getShape(), strideInts, + elemByteSize(mrTy.getElementType())); +} + +template +static void maybeRepairMinor2DLoadStoreLayout(LoadStoreOp op, ViewGetter getView, + TileGetter getTile) { + auto tilePref = isVectorTileType(getTile(op).getType()) + ? tileBLayoutToGlobalLayout(getTile(op).getType()) + : std::nullopt; + if (!tilePref || (*tilePref != Layout::ND && *tilePref != Layout::DN)) + return; + + auto viewInfo = resolveLayoutFromViewValue(getView(op)); + if (!viewInfo.owner || !viewInfo.layout || !viewInfo.inferred || + *viewInfo.layout == *tilePref) { + return; + } + auto tv = dyn_cast(viewInfo.owner); + if (!tv) + return; + + SmallVector shape, strides; + bool ambiguous = false; + if (!getStaticShapeAndStride(tv, shape, strides)) + return; + (void)inferLayout5D( + shape, strides, + elemByteSize(cast(tv.getResult().getType()).getElementType()), + std::nullopt, &ambiguous); + if (ambiguous && isMinorColsOne(shape)) { + setLayoutAttr(viewInfo.owner, *tilePref, /*inferred=*/true); + setLayoutAttr(op.getOperation(), *tilePref, /*inferred=*/true); + } +} + +template +static void attachLoadStoreLayout(LoadStoreOp op, ViewGetter getView, + TileGetter getTile) { + if (op->template getAttrOfType(kLayoutAttrName)) { + maybeRepairMinor2DLoadStoreLayout(op, getView, getTile); + return; + } + + auto viewInfo = resolveLayoutFromViewValue(getView(op)); + if (viewInfo.layout) { + setLayoutAttr(op.getOperation(), *viewInfo.layout, viewInfo.inferred); + } else if (auto memTy = dyn_cast(getView(op).getType()); + memTy && isGlobalMemRef(memTy)) { + setLayoutAttr(op.getOperation(), inferFromStaticMemRefTy(memTy).value_or(Layout::ND), + /*inferred=*/true); + } + + maybeRepairMinor2DLoadStoreLayout(op, getView, getTile); +} + struct LayoutPreference { std::optional preferred; bool conflict = false; @@ -304,12 +408,6 @@ static bool getStaticShapeAndStride(MakeTensorViewOp op, return true; } -struct ResolvedLayoutInfo { - Operation *owner = nullptr; - std::optional layout; - bool inferred = false; -}; - static ResolvedLayoutInfo resolveLayoutFromViewValue(Value v) { ResolvedLayoutInfo info; Operation *def = v.getDefiningOp(); @@ -338,43 +436,14 @@ struct InferPTOLayoutPass void runOnOperation() override { func::FuncOp func = getOperation(); - auto setLayout = [&](Operation *op, Layout layout, bool inferred) { - op->setAttr(kLayoutAttrName, LayoutAttr::get(op->getContext(), layout)); - if (inferred) { - op->setAttr(kInferredLayoutAttrName, - BoolAttr::get(op->getContext(), true)); - } else { - op->removeAttr(kInferredLayoutAttrName); - } - }; - - auto verifyOrSetLayout = [&](Operation *op, std::optional inferred, - bool isMinor2DAmbiguous = false) -> void { - auto existing = op->getAttrOfType(kLayoutAttrName); - if (existing) { - if (inferred && existing.getLayout() != *inferred) { - // For minor-2D ambiguous cases, ND/DN are both legal and should be - // treated as equivalent ABI hints. Keep user-specified layout. - if (isMinor2DAmbiguous && isMinor2DLayout(existing.getLayout()) && - isMinor2DLayout(*inferred)) - return; - op->emitError() << "layout mismatch: user-specified layout=" - << stringifyLayout(existing.getLayout()) - << " but inferred=" << stringifyLayout(*inferred); - signalPassFailure(); - } - return; - } - setLayout(op, inferred.value_or(Layout::ND), /*inferred=*/true); - }; - // ------------------------------------------------------------------ // 1) pto.make_tensor_view (only if it still exists in the pipeline) // ------------------------------------------------------------------ func.walk([&](MakeTensorViewOp op) { SmallVector shape, strides; if (!getStaticShapeAndStride(op, shape, strides)) { - verifyOrSetLayout(op.getOperation(), std::nullopt); + verifyOrSetLayoutAttr(op.getOperation(), std::nullopt, + [this] { signalPassFailure(); }); return; } @@ -392,7 +461,8 @@ struct InferPTOLayoutPass elemByteSize(cast(op.getResult().getType()) .getElementType()), preferredForAmbiguous, &isAmbiguous); - verifyOrSetLayout(op.getOperation(), inferred, isAmbiguous); + verifyOrSetLayoutAttr(op.getOperation(), inferred, + [this] { signalPassFailure(); }, isAmbiguous); // If this make_tensor_view layout was inferred in an ambiguous ND/DN // shape and a downstream tile has a clear BLayout preference, force-align @@ -401,7 +471,7 @@ struct InferPTOLayoutPass op->getAttrOfType(kInferredLayoutAttrName)) { auto cur = op->getAttrOfType(kLayoutAttrName); if (cur && pref.preferred && *pref.preferred != cur.getLayout()) - setLayout(op.getOperation(), *pref.preferred, /*inferred=*/true); + setLayoutAttr(op.getOperation(), *pref.preferred, /*inferred=*/true); } }); @@ -415,7 +485,8 @@ struct InferPTOLayoutPass const size_t rank = op.getMixedSizes().size(); if (rank == 0 || rank > 5) { - verifyOrSetLayout(op.getOperation(), std::nullopt); + verifyOrSetLayoutAttr(op.getOperation(), std::nullopt, + [this] { signalPassFailure(); }); return; } @@ -424,7 +495,8 @@ struct InferPTOLayoutPass for (OpFoldResult s : op.getMixedSizes()) { auto v = getConstInt(s); if (!v) { - verifyOrSetLayout(op.getOperation(), std::nullopt); + verifyOrSetLayoutAttr(op.getOperation(), std::nullopt, + [this] { signalPassFailure(); }); return; } shape.push_back(*v); @@ -435,7 +507,8 @@ struct InferPTOLayoutPass for (OpFoldResult s : op.getMixedStrides()) { auto v = getConstInt(s); if (!v) { - verifyOrSetLayout(op.getOperation(), std::nullopt); + verifyOrSetLayoutAttr(op.getOperation(), std::nullopt, + [this] { signalPassFailure(); }); return; } strides.push_back(*v); @@ -445,7 +518,9 @@ struct InferPTOLayoutPass auto inferred = inferLayout5D(shape, strides, elemByteSize(mrTy.getElementType()), std::nullopt, &isMinor2DAmbiguous); - verifyOrSetLayout(op.getOperation(), inferred, isMinor2DAmbiguous); + verifyOrSetLayoutAttr(op.getOperation(), inferred, + [this] { signalPassFailure(); }, + isMinor2DAmbiguous); }); // ------------------------------------------------------------------ @@ -473,7 +548,7 @@ struct InferPTOLayoutPass // Fallback: if source memref type is fully static, infer from it. auto srcTy = dyn_cast(op.getSource().getType()); if (!srcTy || !srcTy.hasStaticShape()) { - setLayout(op.getOperation(), Layout::ND, /*inferred=*/true); + setLayoutAttr(op.getOperation(), Layout::ND, /*inferred=*/true); return; } @@ -483,126 +558,28 @@ struct InferPTOLayoutPass offset == ShapedType::kDynamic || llvm::any_of(strideInts, [](int64_t s) { return s == ShapedType::kDynamic; })) { - setLayout(op.getOperation(), Layout::ND, /*inferred=*/true); + setLayoutAttr(op.getOperation(), Layout::ND, /*inferred=*/true); return; } auto inferred = inferLayout5D(srcTy.getShape(), strideInts, elemByteSize(srcTy.getElementType())); - setLayout(op.getOperation(), inferred.value_or(Layout::ND), - /*inferred=*/true); + setLayoutAttr(op.getOperation(), inferred.value_or(Layout::ND), + /*inferred=*/true); }); // ------------------------------------------------------------------ // 4) pto.tload / pto.tstore: attach layout for static GM memrefs so EmitC // doesn't need to infer again in buildGlobalTensorFromMemref(). // ------------------------------------------------------------------ - auto inferFromStaticMemRefTy = [&](MemRefType mrTy) -> std::optional { - if (!mrTy.hasStaticShape() || mrTy.getRank() == 0 || mrTy.getRank() > 5) - return std::nullopt; - SmallVector strideInts; - int64_t offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(mrTy, strideInts, offset))) - return std::nullopt; - if (offset == ShapedType::kDynamic || - llvm::any_of(strideInts, - [](int64_t s) { return s == ShapedType::kDynamic; })) - return std::nullopt; - return inferLayout5D(mrTy.getShape(), strideInts, - elemByteSize(mrTy.getElementType())); - }; - func.walk([&](pto::TLoadOp op) { - bool hasLayout = - static_cast(op->getAttrOfType(kLayoutAttrName)); - if (!hasLayout) { - auto viewInfo = resolveLayoutFromViewValue(op.getSrc()); - if (viewInfo.layout) { - setLayout(op.getOperation(), *viewInfo.layout, viewInfo.inferred); - hasLayout = true; - } - } - if (!hasLayout) { - auto srcTy = dyn_cast(op.getSrc().getType()); - if (srcTy && isGlobalMemRef(srcTy)) { - setLayout(op.getOperation(), - inferFromStaticMemRefTy(srcTy).value_or(Layout::ND), - /*inferred=*/true); - } - } - - // Consistency check and repair (inferred + ambiguous only): if source view - // layout conflicts with the consumer tile BLayout, retarget to tile - // preference to keep emitted GlobalTensor/Tile compatible. - auto tilePref = isVectorTileType(op.getDst().getType()) - ? tileBLayoutToGlobalLayout(op.getDst().getType()) - : std::nullopt; - if (tilePref && (*tilePref == Layout::ND || *tilePref == Layout::DN)) { - auto viewInfo = resolveLayoutFromViewValue(op.getSrc()); - if (viewInfo.owner && viewInfo.layout && - *viewInfo.layout != *tilePref && viewInfo.inferred) { - if (auto tv = dyn_cast(viewInfo.owner)) { - SmallVector shape, strides; - bool ambiguous = false; - if (getStaticShapeAndStride(tv, shape, strides)) { - (void)inferLayout5D( - shape, strides, - elemByteSize(cast(tv.getResult().getType()) - .getElementType()), - std::nullopt, &ambiguous); - if (ambiguous && isMinorColsOne(shape)) { - setLayout(viewInfo.owner, *tilePref, /*inferred=*/true); - setLayout(op.getOperation(), *tilePref, /*inferred=*/true); - } - } - } - } - } + attachLoadStoreLayout(op, [](auto load) { return load.getSrc(); }, + [](auto load) { return load.getDst(); }); }); func.walk([&](pto::TStoreOp op) { - bool hasLayout = - static_cast(op->getAttrOfType(kLayoutAttrName)); - if (!hasLayout) { - auto viewInfo = resolveLayoutFromViewValue(op.getDst()); - if (viewInfo.layout) { - setLayout(op.getOperation(), *viewInfo.layout, viewInfo.inferred); - hasLayout = true; - } - } - if (!hasLayout) { - auto dstTy = dyn_cast(op.getDst().getType()); - if (dstTy && isGlobalMemRef(dstTy)) { - setLayout(op.getOperation(), - inferFromStaticMemRefTy(dstTy).value_or(Layout::ND), - /*inferred=*/true); - } - } - - auto tilePref = isVectorTileType(op.getSrc().getType()) - ? tileBLayoutToGlobalLayout(op.getSrc().getType()) - : std::nullopt; - if (tilePref && (*tilePref == Layout::ND || *tilePref == Layout::DN)) { - auto viewInfo = resolveLayoutFromViewValue(op.getDst()); - if (viewInfo.owner && viewInfo.layout && - *viewInfo.layout != *tilePref && viewInfo.inferred) { - if (auto tv = dyn_cast(viewInfo.owner)) { - SmallVector shape, strides; - bool ambiguous = false; - if (getStaticShapeAndStride(tv, shape, strides)) { - (void)inferLayout5D( - shape, strides, - elemByteSize(cast(tv.getResult().getType()) - .getElementType()), - std::nullopt, &ambiguous); - if (ambiguous && isMinorColsOne(shape)) { - setLayout(viewInfo.owner, *tilePref, /*inferred=*/true); - setLayout(op.getOperation(), *tilePref, /*inferred=*/true); - } - } - } - } - } + attachLoadStoreLayout(op, [](auto store) { return store.getDst(); }, + [](auto store) { return store.getSrc(); }); }); } }; diff --git a/lib/PTO/Transforms/InferPTOMemScope.cpp b/lib/PTO/Transforms/InferPTOMemScope.cpp index e707e586b..7e6f65b7a 100644 --- a/lib/PTO/Transforms/InferPTOMemScope.cpp +++ b/lib/PTO/Transforms/InferPTOMemScope.cpp @@ -33,13 +33,33 @@ namespace mlir { using namespace mlir; using namespace pto; +namespace { +static std::optional requireRootAlloc(Operation *op, Value value, + StringRef valueName) { + auto alloc = tracebackMemRefToAlloc(value); + if (!alloc.has_value()) + emitError(op->getLoc()) << "Cannot find root memref.alloc for " << valueName + << " of this op."; + return alloc; +} + +static LogicalResult propagateAllocScope(Operation *op, Value value, + StringRef valueName, + const AddressSpaceAttr &targetScope, + MemScopeInferAndPropagateHelper &helper) { + auto alloc = requireRootAlloc(op, value, valueName); + if (!alloc.has_value()) + return failure(); + if (failed(helper.Run(*alloc, targetScope))) + return op->emitOpError() + << "Failed to infer/propagate memory scope for " << valueName; + return success(); +} +} // namespace + LogicalResult MemScopeInferAndPropagateHelper::propagateMemScopeToUsers(Value val) { - // Get new memory scope from result. auto memrefScope = getPTOAddressSpaceAttr(val.getType()); - // This function propagates the type change of an SSA result to the operation - // that uses it. The result type of the updated operation might be affected, - // so we need to cascade the change. auto propagateFn = [&](OpOperand &user) -> LogicalResult { Operation *userDefiningOp = user.getOwner(); return TypeSwitch(userDefiningOp) @@ -47,22 +67,13 @@ MemScopeInferAndPropagateHelper::propagateMemScopeToUsers(Value val) { Operation *parentOp = op->getParentOp(); auto yieldResult = op.getOperand(user.getOperandNumber()); auto parentResult = parentOp->getResult(user.getOperandNumber()); - - Type yieldType = yieldResult.getType(); - Type valType = val.getType(); - if (!isa(yieldType)) - return success(); - if (!isa(valType)) - return success(); - auto mtype = dyn_cast(yieldType); - auto vtype = dyn_cast(valType); - if (mtype.getElementType() != vtype.getElementType()) + auto yieldType = dyn_cast(yieldResult.getType()); + auto valType = dyn_cast(val.getType()); + if (!yieldType || !valType || + yieldType.getElementType() != valType.getElementType()) return success(); setBaseMemRefTypeScope(parentResult, memrefScope); - if (failed(propagateMemScopeToUsers(parentResult))) { - return failure(); - } - return success(); + return propagateMemScopeToUsers(parentResult); }) .Case([&](scf::ForOp op) { auto result = op.getTiedLoopResult(&user); @@ -151,279 +162,114 @@ struct InferPTOMemScopePass } // namespace LogicalResult pto::inferAndPropagateMemScopeForMovDps(pto::TMovOp op) { - // 替换 hasPureBufferSemantics() - // 在 PTO 的语义中,如果 Op 没有返回值 (Result),就意味着它是 Buffer 语义(操作的是 TileBuf 或 MemRef) - if (op.getNumResults() != 0) { - return op->emitOpError("Run infer memory scope after bufferization (Op must have 0 results)."); - } - - Value mA = op.getSrc(); - Value mB = op.getDst(); - - // 直接使用 Value,不需要再调 ->get() - // mA, mB, mC 现在已经是 Value 类型了 - auto allocA = tracebackMemRefToAlloc(mA); - auto allocB = tracebackMemRefToAlloc(mB); + if (op.getNumResults() != 0) + return op->emitOpError( + "Run infer memory scope after bufferization (Op must have 0 results)."); - if (!allocA.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mA of this op."; + auto dstAlloc = requireRootAlloc(op, op.getDst(), "mB"); + if (!dstAlloc.has_value()) return failure(); - } - if (!allocB.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mB of this op."; - return failure(); - } - auto memRefType = dyn_cast(allocB.value().getType()); - if (!memRefType) { + + auto memRefType = dyn_cast(dstAlloc->getType()); + if (!memRefType) return op->emitOpError("Failed to infer/propagate memory scope for mA"); - } auto memSpace = memRefType.getMemorySpace(); - if (!memSpace) { + if (!memSpace) return success(); - } - auto l0aSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); - auto l0bSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); - auto l0cSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); - auto l1SpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::MAT); - auto ubSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::VEC); - auto biasSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::BIAS); + auto l0aSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); + auto l0bSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); + auto l0cSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); + auto l1SpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::MAT); + auto ubSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::VEC); + auto biasSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::BIAS); MemScopeInferAndPropagateHelper helper; - - if (memSpace == ubSpaceAttr) { - // For MmadL1Op, operand mA should be in L1. - if (failed(helper.Run(*allocA, ubSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mA"); - } - return success(); - } - - if (memSpace == l1SpaceAttr) { - // For MmadL1Op, operand mA should be in L1. - if (failed(helper.Run(*allocA, l0cSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mA"); - } - return success(); - } - - if (memSpace == l0aSpaceAttr || - memSpace == l0bSpaceAttr || + if (memSpace == ubSpaceAttr) + return propagateAllocScope(op, op.getSrc(), "mA", ubSpaceAttr, helper); + if (memSpace == l1SpaceAttr) + return propagateAllocScope(op, op.getSrc(), "mA", l0cSpaceAttr, helper); + if (memSpace == l0aSpaceAttr || memSpace == l0bSpaceAttr || memSpace == biasSpaceAttr) { - // For MmadL1Op, operand mA should be in L1. - if (failed(helper.Run(*allocA, l1SpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mA"); - } - return success(); + return propagateAllocScope(op, op.getSrc(), "mA", l1SpaceAttr, helper); } - return success(); } LogicalResult pto::inferAndPropagateMemScopeForMatmulAccDps(pto::TMatmulAccOp op) { - // 替换 hasPureBufferSemantics() - // 在 PTO 的语义中,如果 Op 没有返回值 (Result),就意味着它是 Buffer 语义(操作的是 TileBuf 或 MemRef) - if (op.getNumResults() != 0) { - return op->emitOpError("Run infer memory scope after bufferization (Op must have 0 results)."); - } + if (op.getNumResults() != 0) + return op->emitOpError( + "Run infer memory scope after bufferization (Op must have 0 results)."); - // 替换 getDpsInputOperand/getDpsInitOperand - // 直接使用 ODS 生成的命名函数,更直观且安全 - // 原逻辑: Input(0)->LHS, Input(1)->RHS, Init(0)->DST - Value mAcc = op.getAccIn(); - Value mA = op.getLhs(); - Value mB = op.getRhs(); - Value mC = op.getDst(); - - // 直接使用 Value,不需要再调 ->get() - // mA, mB, mC 现在已经是 Value 类型了 - auto allocAcc = tracebackMemRefToAlloc(mAcc); - auto allocA = tracebackMemRefToAlloc(mA); - auto allocB = tracebackMemRefToAlloc(mB); - auto allocC = tracebackMemRefToAlloc(mC); - - - if (!allocAcc.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mAcc of this op."; - return failure(); - } - if (!allocA.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mA of this op."; - return failure(); - } - if (!allocB.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mB of this op."; - return failure(); - } - if (!allocC.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mC of this op."; - return failure(); - } - - auto l0aSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); - auto l0bSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); - auto l0cSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); + auto l0aSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); + auto l0bSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); + auto l0cSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); MemScopeInferAndPropagateHelper helper; - - // For MmadL1Op, operand mA should be in L1. - if (failed(helper.Run(*allocAcc, l0cSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mAcc"); - } - - // For MmadL1Op, operand mA should be in L1. - if (failed(helper.Run(*allocA, l0aSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mA"); - } - LDBG("IR after setting mem scope for mA:\n" << *(op->getParentOfType())); - - // For MmadL1Op, operand mB should be in L1. - if (failed(helper.Run(*allocB, l0bSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mB"); - } - LDBG("IR after setting mem scope for mB:\n" << *(op->getParentOfType())); - - // For MmadL1Op, operand mC should be in L0C. - if (failed(helper.Run(*allocC, l0cSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mC"); + if (failed(propagateAllocScope(op, op.getAccIn(), "mAcc", l0cSpaceAttr, + helper)) || + failed(propagateAllocScope(op, op.getLhs(), "mA", l0aSpaceAttr, helper)) || + failed(propagateAllocScope(op, op.getRhs(), "mB", l0bSpaceAttr, helper)) || + failed(propagateAllocScope(op, op.getDst(), "mC", l0cSpaceAttr, helper))) { + return failure(); } - LDBG("IR after setting mem scope for mC:\n" << *(op->getParentOfType())); - return success(); } LogicalResult pto::inferAndPropagateMemScopeForMatmulBiasDps(pto::TMatmulBiasOp op) { - // 替换 hasPureBufferSemantics() - // 在 PTO 的语义中,如果 Op 没有返回值 (Result),就意味着它是 Buffer 语义(操作的是 TileBuf 或 MemRef) - if (op.getNumResults() != 0) { - return op->emitOpError("Run infer memory scope after bufferization (Op must have 0 results)."); - } - - // 替换 getDpsInputOperand/getDpsInitOperand - // 直接使用 ODS 生成的命名函数,更直观且安全 - // 原逻辑: Input(0)->LHS, Input(1)->RHS, Init(0)->DST - Value mA = op.getA(); - Value mB = op.getB(); - Value mC = op.getDst(); - Value mD = op.getBias(); - - // 直接使用 Value,不需要再调 ->get() - // mA, mB, mC 现在已经是 Value 类型了 - auto allocA = tracebackMemRefToAlloc(mA); - auto allocB = tracebackMemRefToAlloc(mB); - auto allocC = tracebackMemRefToAlloc(mC); - auto allocD = tracebackMemRefToAlloc(mD); - - if (!allocA.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mA of this op."; - return failure(); - } - if (!allocB.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mB of this op."; - return failure(); - } - if (!allocC.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mC of this op."; - return failure(); - } - if (!allocD.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mD of this op."; - return failure(); - } - - auto l0aSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); - auto l0bSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); - auto l0cSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); - auto l0dSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::BIAS); + if (op.getNumResults() != 0) + return op->emitOpError( + "Run infer memory scope after bufferization (Op must have 0 results)."); + + auto l0aSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); + auto l0bSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); + auto l0cSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); + auto biasSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::BIAS); MemScopeInferAndPropagateHelper helper; - - // For MmadL1Op, operand mA should be in L1. - if (failed(helper.Run(*allocA, l0aSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mA"); - } - LDBG("IR after setting mem scope for mA:\n" << *(op->getParentOfType())); - - // For MmadL1Op, operand mB should be in L1. - if (failed(helper.Run(*allocB, l0bSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mB"); - } - LDBG("IR after setting mem scope for mB:\n" << *(op->getParentOfType())); - - // For MmadL1Op, operand mC should be in L0C. - if (failed(helper.Run(*allocC, l0cSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mC"); - } - LDBG("IR after setting mem scope for mC:\n" << *(op->getParentOfType())); - - // For MmadL1Op, operand mD should be in BIAS. - if (failed(helper.Run(*allocD, l0dSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mC"); + if (failed(propagateAllocScope(op, op.getA(), "mA", l0aSpaceAttr, helper)) || + failed(propagateAllocScope(op, op.getB(), "mB", l0bSpaceAttr, helper)) || + failed(propagateAllocScope(op, op.getDst(), "mC", l0cSpaceAttr, helper)) || + failed( + propagateAllocScope(op, op.getBias(), "mD", biasSpaceAttr, helper))) { + return failure(); } - LDBG("IR after setting mem scope for mC:\n" << *(op->getParentOfType())); - return success(); } LogicalResult pto::inferAndPropagateMemScopeForMatmulDps(pto::TMatmulOp op) { - // 替换 hasPureBufferSemantics() - // 在 PTO 的语义中,如果 Op 没有返回值 (Result),就意味着它是 Buffer 语义(操作的是 TileBuf 或 MemRef) - if (op.getNumResults() != 0) { - return op->emitOpError("Run infer memory scope after bufferization (Op must have 0 results)."); - } - - // 替换 getDpsInputOperand/getDpsInitOperand - // 直接使用 ODS 生成的命名函数,更直观且安全 - // 原逻辑: Input(0)->LHS, Input(1)->RHS, Init(0)->DST - Value mA = op.getLhs(); - Value mB = op.getRhs(); - Value mC = op.getDst(); - - // 直接使用 Value,不需要再调 ->get() - // mA, mB, mC 现在已经是 Value 类型了 - auto allocA = tracebackMemRefToAlloc(mA); - auto allocB = tracebackMemRefToAlloc(mB); - auto allocC = tracebackMemRefToAlloc(mC); - - if (!allocA.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mA of this op."; - return failure(); - } - if (!allocB.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mB of this op."; - return failure(); - } - if (!allocC.has_value()) { - emitError(op.getLoc()) << "Cannot find root memref.alloc for mC of this op."; - return failure(); - } + if (op.getNumResults() != 0) + return op->emitOpError( + "Run infer memory scope after bufferization (Op must have 0 results)."); - auto l0aSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); - auto l0bSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); - auto l0cSpaceAttr = AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); + auto l0aSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::LEFT); + auto l0bSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::RIGHT); + auto l0cSpaceAttr = + AddressSpaceAttr::get(op->getContext(), pto::AddressSpace::ACC); MemScopeInferAndPropagateHelper helper; - - // For MmadL1Op, operand mA should be in L1. - if (failed(helper.Run(*allocA, l0aSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mA"); - } - LDBG("IR after setting mem scope for mA:\n" << *(op->getParentOfType())); - - // For MmadL1Op, operand mB should be in L1. - if (failed(helper.Run(*allocB, l0bSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mB"); - } - LDBG("IR after setting mem scope for mB:\n" << *(op->getParentOfType())); - - // For MmadL1Op, operand mC should be in L0C. - if (failed(helper.Run(*allocC, l0cSpaceAttr))) { - return op->emitOpError("Failed to infer/propagate memory scope for mC"); + if (failed(propagateAllocScope(op, op.getLhs(), "mA", l0aSpaceAttr, helper)) || + failed(propagateAllocScope(op, op.getRhs(), "mB", l0bSpaceAttr, helper)) || + failed(propagateAllocScope(op, op.getDst(), "mC", l0cSpaceAttr, helper))) { + return failure(); } - LDBG("IR after setting mem scope for mC:\n" << *(op->getParentOfType())); - return success(); } @@ -610,16 +456,9 @@ LogicalResult pto::inferAndPropagateUbufMemScope(memref::AllocOp op) { } void InferPTOMemScopePass::runOnOperation() { - llvm::errs() << "Hello PTO Infer Mem Scope!\n"; - auto op = getOperation(); - op->dump(); - SmallVector deviceFuncList; - SetVector deviceFuncNames; - SmallVector hostFuncList; getOperation()->walk([&](func::FuncOp func) { deviceFuncList.push_back(func); - deviceFuncNames.insert(func.getSymName()); return; }); @@ -673,15 +512,6 @@ void InferPTOMemScopePass::runOnOperation() { if (failed(fixDeviceCallSite(func))) signalPassFailure(); } - - for (auto func : hostFuncList) { - if (failed(fixHostFuncSignature(func))) - signalPassFailure(); - } - - llvm::errs() << "end PTO Infer Mem Scope!\n"; - op = getOperation(); - op->dump(); } std::unique_ptr mlir::pto::createInferPTOMemScopePass() { diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index a66e40878..e9164320c 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/Transforms/InsertSync/InsertSyncAnalysis.h" #include "PTO/Transforms/InsertSync/SyncCommon.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -81,7 +76,8 @@ void InsertSyncAnalysis::DealWithLoopSync(LoopInstanceElement *nowElement) { } SyncIRs backSyncIr; - assert(syncIR_.size() >= nowElement->endId); + if (syncIR_.size() < nowElement->endId) + return; for (unsigned i = nowElement->beginId; i < nowElement->endId; i++) { if (auto *compound = dyn_cast(syncIR_[i].get())) { InsertBackForSync(compound, backSyncIr, nowElement); @@ -165,8 +161,8 @@ void InsertSyncAnalysis::InsertSeqSync( for (int i = end - 1; i >= begin; i--) { auto &frontPtr = syncElement[i]; unsigned frontIndex = frontPtr->GetIndex(); - assert(frontIndex < syncIR_.size()); - assert(syncIR_[frontIndex] != nullptr); + if (frontIndex >= syncIR_.size() || syncIR_[frontIndex] == nullptr) + continue; if (auto *frontCompound = dyn_cast(frontPtr.get())) { @@ -243,7 +239,8 @@ unsigned InsertSyncAnalysis::InsertBranchSync( return (branchElement->endId - branchElement->beginId); } else if (branchElement->getBranchKind() == KindOfBranch::ELSE_BEGIN && index != begin) { - assert(nowCompound->GetIndex() > branchElement->branchId); + if (nowCompound->GetIndex() <= branchElement->branchId) + return 0; return (branchElement->branchId - branchElement->beginId); } return 0; @@ -384,7 +381,8 @@ void InsertSyncAnalysis::InsertSyncOperation( } syncIndex_++; - assert(syncOperations_.size() == syncIndex_); + if (syncOperations_.size() != syncIndex_) + syncIndex_ = syncOperations_.size(); } // ============================================================================== @@ -464,9 +462,11 @@ void InsertSyncAnalysis::UpdateSyncRecord(const SyncOperation *sync, void InsertSyncAnalysis::UpdateSyncRecordInfo( CompoundInstanceElement *frontCompound, SyncRecordList &syncRecordList) { (void)frontCompound; - assert(!syncOperations_.empty()); + if (syncOperations_.empty()) + return; auto &syncPair = syncOperations_.back(); - assert(!syncPair.empty()); + if (syncPair.empty()) + return; auto *newSync = syncPair[0].get(); for (size_t bufferIdx = 0; bufferIdx < syncRecordList.size(); bufferIdx++) { diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncDebug.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncDebug.cpp index 63bf84491..24bb246d2 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncDebug.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncDebug.cpp @@ -12,6 +12,7 @@ #include "PTO/Transforms/InsertSync/InsertSyncDebug.h" #include "mlir/IR/AsmState.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" @@ -180,6 +181,92 @@ static void dumpMemInfoList(llvm::raw_ostream &os, llvm::StringRef tag, os << "]"; } +static void adjustIndentBeforeElement(const std::unique_ptr &element, + int &indent) { + if (auto *loop = dyn_cast(element.get())) { + if (loop->getLoopKind() == KindOfLoop::LOOP_END) + indent = std::max(0, indent - 1); + } + if (auto *branch = dyn_cast(element.get())) { + if (branch->getBranchKind() == KindOfBranch::IF_END || + branch->getBranchKind() == KindOfBranch::ELSE_BEGIN) { + indent = std::max(0, indent - 1); + } + } +} + +static void adjustIndentAfterElement(const std::unique_ptr &element, + int &indent) { + if (auto *loop = dyn_cast(element.get())) { + if (loop->getLoopKind() == KindOfLoop::LOOP_BEGIN) + ++indent; + } + if (auto *branch = dyn_cast(element.get())) { + if (branch->getBranchKind() == KindOfBranch::IF_BEGIN || + branch->getBranchKind() == KindOfBranch::ELSE_BEGIN) { + ++indent; + } + } +} + +static void dumpInstanceElementSummary(llvm::raw_ostream &os, + const InstanceElement *element, + llvm::function_ref indentBy, + bool showMemInfo, + mlir::AsmState *state) { + switch (element->GetKind()) { + case InstanceElement::KindTy::COMPOUND: { + auto *comp = cast(element); + os << "COMPOUND " << comp->opName.getStringRef() << " [" + << getPipelineName(comp->kPipeValue) << "]\n"; + if (!showMemInfo) + return; + os.indent(indentBy(2)); + dumpMemInfoList(os, "def", comp->defVec, state); + os << "\n"; + os.indent(indentBy(2)); + dumpMemInfoList(os, "use", comp->useVec, state); + os << "\n"; + return; + } + case InstanceElement::KindTy::LOOP: { + auto *loop = cast(element); + os << "LOOP " << getLoopKindName(loop->getLoopKind()) << " (begin=" + << loop->beginId << ", end=" << loop->endId << ")\n"; + return; + } + case InstanceElement::KindTy::BRANCH: { + auto *branch = cast(element); + os << "BRANCH " << getBranchKindName(branch->getBranchKind()) << " (begin=" + << branch->beginId << ", branch=" << branch->branchId + << ", end=" << branch->endId << ")\n"; + return; + } + case InstanceElement::KindTy::PLACE_HOLDER: { + auto *placeHolder = cast(element); + os << "PLACE_HOLDER (parentScopeId=" << placeHolder->parentScopeId; + if (placeHolder->isVirtualElse) + os << ", virtualElse"; + os << ")\n"; + return; + } + } +} + +static void dumpSyncOpsWithPrefix(llvm::raw_ostream &os, llvm::StringRef prefix, + const SyncOps &ops, + llvm::function_ref indentBy, + InsertSyncDumpOptions options) { + for (const auto *syncOp : ops) { + if (!syncOp || (syncOp->uselessSync && !options.showUselessSync)) + continue; + os.indent(indentBy(2)); + os << prefix << ": "; + dumpSyncOp(os, syncOp, options.showUselessSync); + os << "\n"; + } +} + static void dumpSyncIR(llvm::raw_ostream &os, const SyncIRs &syncIR, Operation *opForPrinting, InsertSyncDumpOptions options, bool showMemInfo) { @@ -196,136 +283,81 @@ static void dumpSyncIR(llvm::raw_ostream &os, const SyncIRs &syncIR, if (!e) continue; - if (auto *loop = dyn_cast(e.get())) { - if (loop->getLoopKind() == KindOfLoop::LOOP_END) - indent = std::max(0, indent - 1); - } - if (auto *branch = dyn_cast(e.get())) { - if (branch->getBranchKind() == KindOfBranch::IF_END || - branch->getBranchKind() == KindOfBranch::ELSE_BEGIN) - indent = std::max(0, indent - 1); - } + adjustIndentBeforeElement(e, indent); os.indent(indentBy()); os << llvm::formatv("[{0,4}] ", e->GetIndex()); - - switch (e->GetKind()) { - case InstanceElement::KindTy::COMPOUND: { - auto *comp = cast(e.get()); - os << "COMPOUND " << comp->opName.getStringRef() << " [" - << getPipelineName(comp->kPipeValue) << "]"; - os << "\n"; - if (showMemInfo) { - os.indent(indentBy(2)); - dumpMemInfoList(os, "def", comp->defVec, state ? &*state : nullptr); - os << "\n"; - os.indent(indentBy(2)); - dumpMemInfoList(os, "use", comp->useVec, state ? &*state : nullptr); - os << "\n"; - } - break; - } - case InstanceElement::KindTy::LOOP: { - auto *loop = cast(e.get()); - os << "LOOP " << getLoopKindName(loop->getLoopKind()) - << " (begin=" << loop->beginId << ", end=" << loop->endId << ")\n"; - break; - } - case InstanceElement::KindTy::BRANCH: { - auto *branch = cast(e.get()); - os << "BRANCH " << getBranchKindName(branch->getBranchKind()) - << " (begin=" << branch->beginId << ", branch=" << branch->branchId - << ", end=" << branch->endId << ")\n"; - break; - } - case InstanceElement::KindTy::PLACE_HOLDER: { - auto *ph = cast(e.get()); - os << "PLACE_HOLDER (parentScopeId=" << ph->parentScopeId; - if (ph->isVirtualElse) - os << ", virtualElse"; - os << ")\n"; - break; - } - } - - auto dumpOps = [&](llvm::StringRef prefix, const SyncOps &ops) { - for (const auto *op : ops) { - if (!op) - continue; - if (op->uselessSync && !options.showUselessSync) - continue; - os.indent(indentBy(2)); - os << prefix << ": "; - dumpSyncOp(os, op, options.showUselessSync); - os << "\n"; - } - }; - - dumpOps("PRE ", e->pipeBefore); - dumpOps("POST", e->pipeAfter); - - if (auto *loop = dyn_cast(e.get())) { - if (loop->getLoopKind() == KindOfLoop::LOOP_BEGIN) - indent += 1; - } - if (auto *branch = dyn_cast(e.get())) { - if (branch->getBranchKind() == KindOfBranch::IF_BEGIN || - branch->getBranchKind() == KindOfBranch::ELSE_BEGIN) - indent += 1; - } + dumpInstanceElementSummary(os, e.get(), indentBy, showMemInfo, + state ? &*state : nullptr); + dumpSyncOpsWithPrefix(os, "PRE ", e->pipeBefore, indentBy, options); + dumpSyncOpsWithPrefix(os, "POST", e->pipeAfter, indentBy, options); + adjustIndentAfterElement(e, indent); } } -void mlir::pto::dumpInsertSyncPhase(llvm::StringRef phase, const SyncIRs &syncIR, - const SyncOperations &syncOperations, - Operation *opForPrinting, - llvm::raw_ostream &os) { - const unsigned level = getInsertSyncDebugLevel(); - if (level < static_cast(InsertSyncDebugLevel::Phase)) - return; - +struct InsertSyncPhaseStats { unsigned activeOps = 0; - unsigned setCnt = 0, waitCnt = 0, barrierCnt = 0; - unsigned blockSetCnt = 0, blockWaitCnt = 0, blockAllCnt = 0; + unsigned setCnt = 0; + unsigned waitCnt = 0; + unsigned barrierCnt = 0; + unsigned blockSetCnt = 0; + unsigned blockWaitCnt = 0; + unsigned blockAllCnt = 0; +}; + +static InsertSyncPhaseStats collectPhaseStats( + const SyncOperations &syncOperations) { + InsertSyncPhaseStats stats; for (const auto &group : syncOperations) { for (const auto &op : group) { - if (!op) - continue; - if (op->uselessSync) + if (!op || op->uselessSync) continue; - activeOps++; + ++stats.activeOps; switch (op->GetType()) { case SyncOperation::TYPE::SET_EVENT: - setCnt++; + ++stats.setCnt; break; case SyncOperation::TYPE::WAIT_EVENT: - waitCnt++; + ++stats.waitCnt; break; case SyncOperation::TYPE::PIPE_BARRIER: case SyncOperation::TYPE::PIPE_BARRIER_CUBE: case SyncOperation::TYPE::PIPE_BARRIER_VECTOR: - barrierCnt++; + ++stats.barrierCnt; break; case SyncOperation::TYPE::SYNC_BLOCK_SET: - blockSetCnt++; + ++stats.blockSetCnt; break; case SyncOperation::TYPE::SYNC_BLOCK_WAIT: - blockWaitCnt++; + ++stats.blockWaitCnt; break; case SyncOperation::TYPE::SYNC_BLOCK_ALL: - blockAllCnt++; + ++stats.blockAllCnt; break; } } } + return stats; +} + +void mlir::pto::dumpInsertSyncPhase(llvm::StringRef phase, const SyncIRs &syncIR, + const SyncOperations &syncOperations, + Operation *opForPrinting, + llvm::raw_ostream &os) { + const unsigned level = getInsertSyncDebugLevel(); + if (level < static_cast(InsertSyncDebugLevel::Phase)) + return; + + InsertSyncPhaseStats stats = collectPhaseStats(syncOperations); os << "\n// === [PTOInsertSync Debug] " << phase << " === //\n"; os << llvm::formatv("// nodes={0}, syncGroups={1}, activeOps={2} " "(set={3}, wait={4}, barrier={5}, blockSet={6}, " "blockWait={7}, blockAll={8})\n", - syncIR.size(), syncOperations.size(), activeOps, setCnt, - waitCnt, barrierCnt, blockSetCnt, blockWaitCnt, - blockAllCnt); + syncIR.size(), syncOperations.size(), stats.activeOps, + stats.setCnt, stats.waitCnt, stats.barrierCnt, + stats.blockSetCnt, stats.blockWaitCnt, + stats.blockAllCnt); if (level < static_cast(InsertSyncDebugLevel::SyncIR)) { os << "// ========================================= //\n"; diff --git a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp index 08c529246..33d859eb9 100644 --- a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp +++ b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h" #include "PTO/Transforms/InsertSync/InsertSyncDebug.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -22,12 +17,9 @@ using namespace mlir; using namespace mlir::pto; -static bool isTraceEnabled() { - return isInsertSyncDebugEnabled(InsertSyncDebugLevel::Trace); -} - -// [Debug] 打印 Value 详细信息 -static void printValueDebug(const char* tag, Value v) { +static bool isTraceEnabled() { return isInsertSyncDebugEnabled(InsertSyncDebugLevel::Trace); } + +static void printValueDebug(const char *tag, Value v) { if (!isTraceEnabled()) return; llvm::errs() << tag << ": "; @@ -43,64 +35,44 @@ static void printValueDebug(const char* tag, Value v) { } llvm::errs() << " | Type: " << v.getType() << "\n"; } - -// [Fix & Debug] 增强版 GetRealRoot + +static Value peelRootValue(Value value, bool trace) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + if (trace) + llvm::errs() << " -> Reached BlockArgument. Stop.\n"; + return Value(); + } + if (auto op = dyn_cast(defOp)) + return op.getSrc(); + if (auto op = dyn_cast(defOp)) + return op.getSrc(); + if (auto op = dyn_cast(defOp)) + return op.getSource(); + if (auto view = dyn_cast(defOp)) + return view.getViewSource(); + if (auto cast = dyn_cast(defOp)) + return cast.getSource(); + if (auto reCast = dyn_cast(defOp)) + return reCast.getSource(); + if (trace) + llvm::errs() << " -> Hit Alloc/Other [" << defOp->getName() << "]. Stop.\n"; + return Value(); +} + static Value GetRealRoot(Value v) { const bool trace = isTraceEnabled(); if (trace) { llvm::errs() << " [Trace] GetRealRoot Start:\n"; printValueDebug(" Current", v); } - - int depth = 0; - const int maxDepth = 20; - - while (v && depth++ < maxDepth) { - Operation *defOp = v.getDefiningOp(); - if (!defOp) { - if (trace) - llvm::errs() << " -> Reached BlockArgument. Stop.\n"; - break; - } - - if (auto op = dyn_cast(defOp)) { - if (trace) - llvm::errs() << " -> Hit CollapseShapeOp. Peel off.\n"; - v = op.getSrc(); - continue; - } - if (auto op = dyn_cast(defOp)) { - if (trace) - llvm::errs() << " -> Hit ExpandShapeOp. Peel off.\n"; - v = op.getSrc(); - continue; - } - if (auto op = dyn_cast(defOp)) { - if (trace) - llvm::errs() << " -> Hit ViewOp. Peel off.\n"; - v = op.getSource(); - continue; - } - if (auto view = dyn_cast(defOp)) { - if (trace) - llvm::errs() << " -> Hit ViewLikeInterface. Peel off.\n"; - v = view.getViewSource(); - continue; - } - if (auto cast = dyn_cast(defOp)) { - v = cast.getSource(); - continue; - } - if (auto reCast = dyn_cast(defOp)) { - v = reCast.getSource(); - continue; - } - - if (trace) { - llvm::errs() << " -> Hit Alloc/Other [" << defOp->getName() - << "]. Stop.\n"; - } - break; + + constexpr int kMaxDepth = 20; + for (int depth = 0; v && depth < kMaxDepth; ++depth) { + Value peeled = peelRootValue(v, trace); + if (!peeled) + break; + v = peeled; } return v; } diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 8ba4f265b..4eea476cc 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -6,16 +6,12 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/Transforms/InsertSync/PTOIRTranslator.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/IR/AsmState.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -428,10 +424,12 @@ void PTOIRTranslator::UpdateForOpInfo(scf::ForOp forOp) { index++; auto *forBeginPtr = dyn_cast(forElement.get()); - assert(forBeginPtr != nullptr && "Sync IR Construction failed."); + if (!forBeginPtr) + llvm::report_fatal_error("failed to build loop sync IR node"); if (!forOp.getInitArgs().empty()) { - assert(forOp.getInitArgs().size() == forOp.getRegionIterArgs().size()); + if (forOp.getInitArgs().size() != forOp.getRegionIterArgs().size()) + return; for (auto [i, arg] : llvm::enumerate(forOp.getInitArgs())) { UpdateAliasBufferInfo(forOp.getRegionIterArgs()[i], arg); } @@ -539,7 +537,8 @@ void PTOIRTranslator::UpdateYieldOpInfo(scf::YieldOp yieldOp) { auto *parentOp = yieldOp->getParentOp(); if (!parentOp || isa(parentOp)) return; - assert(parentOp->getResults().size() == yieldOp->getOpOperands().size()); + if (parentOp->getResults().size() != yieldOp->getOpOperands().size()) + return; for (auto [yieldVal, resultVal] : llvm::zip(yieldOp->getOpOperands(), parentOp->getResults())) { UpdateAliasBufferInfo(resultVal, yieldVal.get()); } diff --git a/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp b/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp index 863dbd6f6..2dfb40fb7 100644 --- a/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp +++ b/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp @@ -6,23 +6,20 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/Transforms/InsertSync/RemoveRedundantSync.h" #include "llvm/ADT/STLExtras.h" #include #include - + #define DEBUG_TYPE "pto-inject-sync" - + using namespace mlir; using namespace mlir::pto; namespace { +using SyncPair = std::pair; + SmallVector canonicalizeDepRoots(const SmallVector &roots) { SmallVector result; result.reserve(roots.size()); @@ -50,85 +47,69 @@ bool hasSameDepRoots(const SyncOperation *lhs, const SyncOperation *rhs) { return lhsRoots == rhsRoots; } +std::vector collectPairedSyncOps(SyncOperations &syncOperations) { + std::vector syncOps; + for (auto &syncPair : syncOperations) { + if (syncPair.size() != 2) + continue; + syncOps.emplace_back(syncPair[0].get(), syncPair[1].get()); + } + return syncOps; +} + +bool compareSyncPairPriority(const SyncPair &lhs, const SyncPair &rhs) { + auto *syncOp1 = lhs.first; + auto *syncOp2 = rhs.first; + + bool hasLoop1 = syncOp1->GetForEndIndex().has_value(); + bool hasLoop2 = syncOp2->GetForEndIndex().has_value(); + if (hasLoop1 && hasLoop2) { + if (syncOp1->GetForEndIndex().value() != syncOp2->GetForEndIndex().value()) + return syncOp1->GetForEndIndex().value() > syncOp2->GetForEndIndex().value(); + return syncOp1->GetSyncIndex() > syncOp2->GetSyncIndex(); + } + if (hasLoop1 != hasLoop2) + return hasLoop1; + return syncOp1->GetSyncIndex() > syncOp2->GetSyncIndex(); +} + +bool shouldKeepSyncPair(SyncOperation *setFlag, SyncOperation *waitFlag) { + if (setFlag->eventIdNum != 1 || waitFlag->eventIdNum != 1) + return true; + if (setFlag->isCompensation || waitFlag->isCompensation) + return true; + return !hasSameDepRoots(setFlag, waitFlag); +} + +template +void eraseSyncFromList(SyncContainer &syncs, SyncOperation *target) { + auto it = std::find(syncs.begin(), syncs.end(), target); + if (it != syncs.end()) + syncs.erase(it); +} + +void markSyncPairUseless(SyncIRs &syncIR, SyncOperation *setFlag, + SyncOperation *waitFlag) { + eraseSyncFromList(syncIR[setFlag->GetSyncIRIndex()]->pipeAfter, setFlag); + eraseSyncFromList(syncIR[waitFlag->GetSyncIRIndex()]->pipeBefore, waitFlag); + setFlag->uselessSync = true; + waitFlag->uselessSync = true; +} + } // namespace - + void RemoveRedundantSync::Run() { - // 1. 收集所有成对的同步指令 (Set/Wait) - std::vector> syncOps; - for (auto &syncPair : syncOperations_) { - // 只有成对的 (Set, Wait) 才能进行此类消除,Barrier 不适用 - if (syncPair.size() == 2) { - auto *setFlag = syncPair[0].get(); - auto *waitFlag = syncPair[1].get(); - syncOps.push_back(std::make_pair(setFlag, waitFlag)); - } - } - - // 2. 排序:优先处理范围较小的或者是 Loop 内部的, - // 这样如果它们被保留,可以用来消除外部更大的。 - // (这里采用简单且稳定的排序策略,确保处理顺序可预测) - std::sort(syncOps.begin(), syncOps.end(), - [](std::pair syncPair1, - std::pair syncPair2) { - auto *syncOp1 = syncPair1.first; - auto *syncOp2 = syncPair2.first; - - bool hasLoop1 = syncOp1->GetForEndIndex().has_value(); - bool hasLoop2 = syncOp2->GetForEndIndex().has_value(); - - if (hasLoop1 && hasLoop2) { - if (syncOp1->GetForEndIndex().value() != syncOp2->GetForEndIndex().value()) { - return syncOp1->GetForEndIndex().value() > syncOp2->GetForEndIndex().value(); - } else { - return syncOp1->GetSyncIndex() > syncOp2->GetSyncIndex(); - } - } - if (hasLoop1 || hasLoop2) { - return hasLoop1 > hasLoop2; - } - return syncOp1->GetSyncIndex() > syncOp2->GetSyncIndex(); - }); - - // 3. 逐个检查并移除冗余 + std::vector syncOps = collectPairedSyncOps(syncOperations_); + std::sort(syncOps.begin(), syncOps.end(), compareSyncPairPriority); + for (auto [setFlag, waitFlag] : syncOps) { - // Conservative mode: - // 1) keep multi-buffer and compensation syncs - // 2) only prune syncs that carry concrete dependency signatures - if (setFlag->eventIdNum != 1 || waitFlag->eventIdNum != 1) { - continue; - } - if (setFlag->isCompensation || waitFlag->isCompensation) { + if (shouldKeepSyncPair(setFlag, waitFlag)) continue; - } - if (!hasSameDepRoots(setFlag, waitFlag)) { - continue; - } - - bool useless = CheckAllSync(setFlag, waitFlag); - if (useless) { - // 标记为冗余 (虽然这里是物理移除) - - // 从 SyncIR 中移除 Set - auto &pipeAfter = syncIR_[setFlag->GetSyncIRIndex()]->pipeAfter; - auto it0 = std::find(pipeAfter.begin(), pipeAfter.end(), setFlag); - if (it0 != pipeAfter.end()) { - pipeAfter.erase(it0); - } - - // 从 SyncIR 中移除 Wait - auto &pipeBefore = syncIR_[waitFlag->GetSyncIRIndex()]->pipeBefore; - auto it1 = std::find(pipeBefore.begin(), pipeBefore.end(), waitFlag); - if (it1 != pipeBefore.end()) { - pipeBefore.erase(it1); - } - - // 标记对象本身,避免 EventID 分配时分配给它 - setFlag->uselessSync = true; - waitFlag->uselessSync = true; - } + if (CheckAllSync(setFlag, waitFlag)) + markSyncPairUseless(syncIR_, setFlag, waitFlag); } } - + bool RemoveRedundantSync::CheckAllSync(SyncOperation *setFlag, SyncOperation *waitFlag) { // syncFinder 用于跟踪在当前范围内,哪些 SyncIndex 的 Set 已经被看到了。 @@ -138,7 +119,7 @@ bool RemoveRedundantSync::CheckAllSync(SyncOperation *setFlag, unsigned int begin = setFlag->GetSyncIRIndex(); unsigned int end = waitFlag->GetSyncIRIndex(); auto forEndIndex = setFlag->GetForEndIndex(); - + if (begin < end) { // 普通的前向依赖 return CheckRepeatSync(begin, end, syncFinder, setFlag); @@ -148,7 +129,7 @@ bool RemoveRedundantSync::CheckAllSync(SyncOperation *setFlag, return false; } } - + bool RemoveRedundantSync::CheckRepeatSync(unsigned int begin, unsigned int end, SmallVector &syncFinder, SyncOperation *setFlag) { @@ -198,27 +179,27 @@ bool RemoveRedundantSync::CheckRepeatSync(unsigned int begin, unsigned int end, return res; } - + bool RemoveRedundantSync::CheckBranchBetween( BranchInstanceElement *branchElement, SmallVector syncFinder, SyncOperation *setFlag, unsigned endId, unsigned &i) { - // 只处理 IF_BEGIN if (branchElement->getBranchKind() != KindOfBranch::IF_BEGIN) { i = branchElement->endId; return false; } - + bool hasElseBranch = branchElement->branchId < branchElement->endId; - + // 检查 waitFlag (endId) 是否在分支内部。如果是,我们不能简单跳过分支。 // 这里逻辑是:如果当前的冗余检查范围跨越了整个分支(即 begin 在 if 前,end 在 if 后), // 那么我们需要检查是否在 THEN 和 ELSE 两个路径上都找到了内部同步。 bool endIsInsideThenBranch = (!hasElseBranch && endId < branchElement->endId) || (hasElseBranch && endId < branchElement->branchId); - if (endIsInsideThenBranch) return false; - + if (endIsInsideThenBranch) + return false; + bool endIsInsideElseBranch = hasElseBranch && endId >= branchElement->branchId && endId < branchElement->endId; @@ -229,20 +210,21 @@ bool RemoveRedundantSync::CheckBranchBetween( // 核心:如果两个分支都存在内部覆盖,则整体覆盖 if (hasElseBranch) { - bool coveredInThen = CheckRepeatSync(branchElement->beginId, branchElement->branchId, syncFinder, setFlag); - bool coveredInElse = CheckRepeatSync(branchElement->branchId, branchElement->endId, syncFinder, setFlag); - - if (coveredInThen && coveredInElse) { + bool coveredInThen = CheckRepeatSync( + branchElement->beginId, branchElement->branchId, syncFinder, setFlag); + bool coveredInElse = CheckRepeatSync( + branchElement->branchId, branchElement->endId, syncFinder, setFlag); + + if (coveredInThen && coveredInElse) return true; - } } // 如果只有 Then 分支 (Implicit Else),除非我们在 Else (空路径) 上也能找到同步(不可能), // 否则无法断定冗余。所以单 If 分支通常无法帮助消除跨越它的外部同步。 - + i = branchElement->endId; // 跳过整个分支块 return false; } - + bool RemoveRedundantSync::CheckLoopBetween(LoopInstanceElement *loopElement, SyncOperation *setFlag, unsigned &i) { @@ -252,7 +234,7 @@ bool RemoveRedundantSync::CheckLoopBetween(LoopInstanceElement *loopElement, i = loopElement->endId; return false; } - + bool RemoveRedundantSync::CanMatchedSync(SmallVector &syncFinder, SyncOperation *relatedSync, SyncOperation *setFlag) { diff --git a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp index c638cfab2..09f971f2a 100644 --- a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/Transforms/InsertSync/SyncCodegen.h" #include "PTO/IR/PTO.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -58,6 +53,25 @@ static void MergeSyncList(SyncOps &dstList, const SyncOps &srcList) { } } } + +static void setSyncInsertionPoint(IRRewriter &rewriter, Operation *op, + bool beforeInsert) { + if (beforeInsert || op->hasTrait()) + rewriter.setInsertionPoint(op); + else + rewriter.setInsertionPointAfter(op); +} + +static void emitSetOrWaitOp(IRRewriter &rewriter, Operation *op, + SyncOperation *sync) { + auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe()); + auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe()); + auto eventId = getEventAttr(rewriter, sync->eventIds[0]); + if (sync->isSyncWaitType()) + rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); + else + rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); +} // ============================================================================== // 2. SyncCodegen Implementation @@ -315,56 +329,19 @@ void SyncCodegen::CreateSetWaitOpForSingleBuffer(IRRewriter &rewriter, Operation *op, SyncOperation *sync, bool beforeInsert) { - // [Fix] Terminator 强制前置插入 - if (beforeInsert || op->hasTrait()) { - rewriter.setInsertionPoint(op); - } else { - rewriter.setInsertionPointAfter(op); - } - - auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe()); - auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe()); - auto eventId = getEventAttr(rewriter, sync->eventIds[0]); - - if (sync->isSyncWaitType()) { - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); - } else { - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); - } + setSyncInsertionPoint(rewriter, op, beforeInsert); + emitSetOrWaitOp(rewriter, op, sync); } void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, Operation *op, SyncOperation *sync, bool beforeInsert) { - // 注意:GetBufferSelected 可能需要在插入 Set/Wait 之前调用,以确保 SSA 顺序 - // 但这里只是获取 Value,不影响 InsertionPoint 的设定 Value bufferSelected = GetBufferSelected(rewriter, op, sync); (void)bufferSelected; - - // [Fix] Terminator 强制前置插入 - if (beforeInsert || op->hasTrait()) { - rewriter.setInsertionPoint(op); - } else { - rewriter.setInsertionPointAfter(op); - } - - auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe()); - auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe()); - auto eventId = getEventAttr(rewriter, sync->eventIds[0]); // 注意:MultiBuffer可能需要特殊处理Attr - - // 这里假设 SetFlagOp/WaitFlagOp 支持动态 Value 作为 EventID,或者您有特殊的 Op - // 如果 PTO 定义只支持 Attribute,那么上面的 GetBufferSelected 逻辑需要配合修改 Op 定义 - // 假设目前的 Op 定义如下: - if (sync->isSyncWaitType()) { - // 假设 WaitFlagOp 有支持 Value eventId 的重载或变体 - // 如果没有,这行代码可能需要调整。但在您之前的 Double Buffer 测试中,看起来它是工作的? - // 或者您是否使用了 UpdateFlagOp (带 Value)? - // 这里保持原样,只修改 InsertionPoint - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); - } else { - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); - } + + setSyncInsertionPoint(rewriter, op, beforeInsert); + emitSetOrWaitOp(rewriter, op, sync); } Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op, diff --git a/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp b/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp index e877b6eab..dcbb736b8 100644 --- a/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp @@ -6,13 +6,9 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/Transforms/InsertSync/SyncEventIdAllocation.h" #include "PTO/Transforms/InsertSync/SyncCommon.h" +#include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "pto-inject-sync" @@ -137,8 +133,10 @@ void SyncEventIdAllocation::SetEventId(SyncOperation *sync) { SmallVector eventIdLifetimeAvailableStatus = GetEventPool(sync, poolSize); SmallVector eventIdIdleStatus = GetEventIdIdleStatus(sync, poolSize); - assert(eventIdLifetimeAvailableStatus.size() == poolSize); - assert(eventIdIdleStatus.size() == poolSize); + if (eventIdLifetimeAvailableStatus.size() != poolSize || + eventIdIdleStatus.size() != poolSize) { + llvm::report_fatal_error("invalid event-id pool state"); + } // Apply per-(src,dst) reservations by marking the "reserved tail" as // unavailable. Historically this pass treated reserved IDs as being at the @@ -161,7 +159,6 @@ void SyncEventIdAllocation::SetEventId(SyncOperation *sync) { } else if (reallocatedPipePair.count(ScopePair(sync)) && (canAllocaEventId.size() < idSize)) { // Reallocate strategy: reduce usage to 1 - assert(canAllocaEventId.size() > 0); SetEventPool(sync, canAllocaEventId[0]); sync->eventIdNum = 1; } @@ -232,7 +229,8 @@ SmallVector SyncEventIdAllocation::GetEventIdIdleStatus(SyncOperation *syn SmallVector SyncEventIdAllocation::GetEventPool(const SyncOperation *sync, size_t eventIdNum) { SmallVector eventIdPool(eventIdNum, true); - assert(sync->GetSyncIndex() < syncOperations_.size()); + if (sync->GetSyncIndex() >= syncOperations_.size()) + return eventIdPool; auto &syncPair = syncOperations_[sync->GetSyncIndex()]; auto *setFlag = syncPair[0].get(); auto *waitFlag = syncPair[1].get(); @@ -241,8 +239,8 @@ SmallVector SyncEventIdAllocation::GetEventPool(const SyncOperation *sync, if (reallocatedPipePair.count(ScopePair(sync))) { auto *ptr = dyn_cast( syncIR_[setFlag->GetForEndIndex().value()].get()); - assert(ptr != nullptr); - FindUseEventID(ptr->beginId, ptr->endId, setFlag, eventIdPool); + if (ptr) + FindUseEventID(ptr->beginId, ptr->endId, setFlag, eventIdPool); } else { FindUseEventID(0, syncIR_.size() - 1, setFlag, eventIdPool); } @@ -276,7 +274,8 @@ void SyncEventIdAllocation::FindUseEventID(unsigned int begin, unsigned int end, const SyncOperation *s, SmallVector &eventId) { const auto eventIdSize = eventId.size(); - assert(begin < end); + if (begin >= end) + return; int scopePair = ScopePair(s); eventCyclePool.try_emplace(scopePair, EventCyclePool(eventIdSize)); EventCyclePool &seqPool = eventCyclePool[scopePair]; @@ -297,7 +296,8 @@ void SyncEventIdAllocation::FindUseEventID(unsigned int begin, unsigned int end, bool SyncEventIdAllocation::CheckSyncLifeCycleConflict( SmallVector &syncLifeCycle, unsigned int begin, unsigned int end, SmallVector &eventId, unsigned i) const { - assert((syncLifeCycle.size() & 0x1) == 0 && "sync_life_cycle error."); + if ((syncLifeCycle.size() & 0x1) != 0) + return true; if (syncLifeCycle[0] <= begin) { return true; // Conflict! } @@ -317,7 +317,10 @@ void SyncEventIdAllocation::UpdateEventId( eventId[index] = false; // Conflict } } else if (j == syncLifeCycle.size() - 1) { - assert((j & 0x1) == 1); + if ((j & 0x1) == 0) { + eventId[index] = false; + break; + } if (syncLifeCycle[j] >= end) { break; // Safe } else { @@ -329,7 +332,8 @@ void SyncEventIdAllocation::UpdateEventId( void SyncEventIdAllocation::SetEventPool(const SyncOperation *sync, unsigned eventId) { - assert(sync->GetSyncIndex() < syncOperations_.size()); + if (sync->GetSyncIndex() >= syncOperations_.size()) + return; auto &syncPair = syncOperations_[sync->GetSyncIndex()]; // [Fix] 遍历组内所有 SyncOperation,为它们统一分配 Event ID @@ -348,8 +352,8 @@ void SyncEventIdAllocation::SetEventPool(const SyncOperation *sync, if (reallocatedPipePair.count(ScopePair(sync))) { auto *ptr = dyn_cast( syncIR_[setFlag->GetForEndIndex().value()].get()); - assert(ptr != nullptr); - SetUseEventID(ptr->beginId, ptr->endId, setFlag.get(), eventId); + if (ptr) + SetUseEventID(ptr->beginId, ptr->endId, setFlag.get(), eventId); } else { SetUseEventID(0, syncIR_.size(), setFlag.get(), eventId); } @@ -389,13 +393,19 @@ void SyncEventIdAllocation::UpdateBackwardMatchSync( if (reallocatedPipePair.count(ScopePair(setFlag))) { auto *ptr = dyn_cast( syncIR_[setFlag->GetForEndIndex().value()].get()); - assert(ptr != nullptr); - syncFront->SetSyncIRIndex(ptr->beginId); - syncEnd->SetSyncIRIndex(ptr->endId); - syncFront->reallocatedLoopHeadTailSync = true; - syncEnd->reallocatedLoopHeadTailSync = true; - syncIR_[ptr->beginId]->pipeBefore.push_back(syncFront.get()); - syncIR_[ptr->endId]->pipeAfter.push_back(syncEnd.get()); + if (ptr) { + syncFront->SetSyncIRIndex(ptr->beginId); + syncEnd->SetSyncIRIndex(ptr->endId); + syncFront->reallocatedLoopHeadTailSync = true; + syncEnd->reallocatedLoopHeadTailSync = true; + syncIR_[ptr->beginId]->pipeBefore.push_back(syncFront.get()); + syncIR_[ptr->endId]->pipeAfter.push_back(syncEnd.get()); + } else { + syncFront->SetSyncIRIndex(0); + syncEnd->SetSyncIRIndex(syncIR_.size() - 1); + syncIR_[0]->pipeBefore.push_back(syncFront.get()); + syncIR_[syncIR_.size() - 1]->pipeAfter.push_back(syncEnd.get()); + } } else { syncFront->SetSyncIRIndex(0); syncEnd->SetSyncIRIndex(syncIR_.size() - 1); @@ -415,7 +425,8 @@ void SyncEventIdAllocation::UpdateBackwardMatchSync( void SyncEventIdAllocation::SetUseEventID(unsigned int begin, unsigned int end, const SyncOperation *setFlag, unsigned int eventId) { - assert(begin < end); + if (begin >= end) + return; int scopePair = ScopePair(setFlag); const size_t poolSize = getEventIdPoolSize(setFlag, reservedBlockSyncEventIdNum); @@ -650,7 +661,8 @@ void SyncEventIdAllocation::IgnoreBackHeadAndTailSync() { } bool SyncEventIdAllocation::TryWidenByOtherSync(const SyncOperation *sync) { - assert(!sync->isBarrierType()); + if (sync->isBarrierType()) + return false; auto &syncPair = syncOperations_[sync->GetSyncIndex()]; SyncOperation *setSync = syncPair[0].get(); SyncOperation *waitSync = syncPair[1].get(); @@ -680,7 +692,8 @@ bool SyncEventIdAllocation::TryWidenByOtherSync(const SyncOperation *sync) { } } widenSetSyncIR->pipeAfter = newPipeAfter; - if (!removeSync) llvm_unreachable("in widen fun, remove sync failed"); + if (!removeSync) + llvm::report_fatal_error("failed to remove widened sync from original position"); } return true; } diff --git a/lib/PTO/Transforms/LoweringSyncToPipe.cpp b/lib/PTO/Transforms/LoweringSyncToPipe.cpp index dc2bb04d6..bc2672a0b 100644 --- a/lib/PTO/Transforms/LoweringSyncToPipe.cpp +++ b/lib/PTO/Transforms/LoweringSyncToPipe.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/Transforms/Passes.h" #include "PTO/IR/PTO.h" #include "PTO/IR/PTOSyncUtils.h" @@ -43,29 +38,43 @@ static FailureOr getSyncOpTypeFromAttr(Attribute attr, Operation *op return failure(); } +static FailureOr> getConcretePipePair(Operation *op, + Attribute srcAttr, + Attribute dstAttr) { + auto srcTypeOr = getSyncOpTypeFromAttr(srcAttr, op, "src_op"); + if (failed(srcTypeOr)) + return failure(); + auto dstTypeOr = getSyncOpTypeFromAttr(dstAttr, op, "dst_op"); + if (failed(dstTypeOr)) + return failure(); + + PIPE srcPipe = mapSyncOpTypeToPipe(*srcTypeOr); + PIPE dstPipe = mapSyncOpTypeToPipe(*dstTypeOr); + if (!isConcreteSyncPipe(srcPipe) || !isConcreteSyncPipe(dstPipe)) { + op->emitError("Failed to map SyncOpType to hardware pipe during lowering."); + return failure(); + } + return std::make_pair(srcPipe, dstPipe); +} + +template +static LogicalResult lowerEventSyncOp(HighLevelOp op, PatternRewriter &rewriter) { + auto pipes = getConcretePipePair(op.getOperation(), op.getSrcOpAttr(), + op.getDstOpAttr()); + if (failed(pipes)) + return failure(); + rewriter.replaceOpWithNewOp( + op, PipeAttr::get(op.getContext(), pipes->first), + PipeAttr::get(op.getContext(), pipes->second), op.getEventIdAttr()); + return success(); +} + struct RecordEventLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(RecordEventOp op, PatternRewriter &rewriter) const override { - auto srcTypeOr = getSyncOpTypeFromAttr(op.getSrcOpAttr(), op, "src_op"); - if (failed(srcTypeOr)) - return failure(); - auto dstTypeOr = getSyncOpTypeFromAttr(op.getDstOpAttr(), op, "dst_op"); - if (failed(dstTypeOr)) - return failure(); - SyncOpType srcType = *srcTypeOr; - SyncOpType dstType = *dstTypeOr; - - PIPE srcPipe = mapSyncOpTypeToPipe(srcType); - PIPE dstPipe = mapSyncOpTypeToPipe(dstType); - if (!isConcreteSyncPipe(srcPipe) || !isConcreteSyncPipe(dstPipe)) - return op.emitError("Failed to map SyncOpType to hardware pipe during lowering."); - - rewriter.replaceOpWithNewOp( - op, PipeAttr::get(op.getContext(), srcPipe), - PipeAttr::get(op.getContext(), dstPipe), op.getEventIdAttr()); - return success(); + return lowerEventSyncOp(op, rewriter); } }; @@ -74,24 +83,7 @@ struct WaitEventLowering : public OpRewritePattern { LogicalResult matchAndRewrite(WaitEventOp op, PatternRewriter &rewriter) const override { - auto srcTypeOr = getSyncOpTypeFromAttr(op.getSrcOpAttr(), op, "src_op"); - if (failed(srcTypeOr)) - return failure(); - auto dstTypeOr = getSyncOpTypeFromAttr(op.getDstOpAttr(), op, "dst_op"); - if (failed(dstTypeOr)) - return failure(); - SyncOpType srcType = *srcTypeOr; - SyncOpType dstType = *dstTypeOr; - - PIPE srcPipe = mapSyncOpTypeToPipe(srcType); - PIPE dstPipe = mapSyncOpTypeToPipe(dstType); - if (!isConcreteSyncPipe(srcPipe) || !isConcreteSyncPipe(dstPipe)) - return op.emitError("Failed to map SyncOpType to hardware pipe during lowering."); - - rewriter.replaceOpWithNewOp( - op, PipeAttr::get(op.getContext(), srcPipe), - PipeAttr::get(op.getContext(), dstPipe), op.getEventIdAttr()); - return success(); + return lowerEventSyncOp(op, rewriter); } }; diff --git a/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp b/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp index d0ab9f34c..647478aae 100644 --- a/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp +++ b/lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -38,155 +33,161 @@ struct FrontendPipeHandles { }; template -static FailureOr lowerFrontendInitOp(InitOpT initOp, - IRRewriter &rewriter) { - FrontendPipeHandles handles; +static LogicalResult requireFrontendGmSlotBuffer(InitOpT initOp) { + if (initOp.getGmSlotBuffer()) + return success(); + return initOp.emitOpError("requires 'gm_slot_buffer' when lowering to a2/a3"); +} + +template +static FailureOr createFrontendPipe(InitOpT initOp, IRRewriter &rewriter, + PTOArch arch, Type pipeTy, + int8_t dirMask, int32_t slotNum, + Value localAddr, + Value peerLocalAddr = Value{}) { Location loc = initOp.getLoc(); - MLIRContext *ctx = initOp.getContext(); - auto pipeTy = PipeType::get(ctx); - PTOArch arch = getTargetArch(initOp.getOperation()); + auto dirAttr = rewriter.getI8IntegerAttr(dirMask); + auto slotSizeAttr = rewriter.getI32IntegerAttr(initOp.getSlotSize()); + auto slotNumAttr = rewriter.getI32IntegerAttr(slotNum); + + if (arch == PTOArch::A5) { + auto pipe = rewriter.create( + loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, IntegerAttr{}, + localAddr, peerLocalAddr); + return pipe.getPipe(); + } - auto createPipe = [&](int8_t dirMask, int32_t slotNum, - Value localAddr) -> FailureOr { - auto dirAttr = rewriter.getI8IntegerAttr(dirMask); - auto slotSizeAttr = rewriter.getI32IntegerAttr(initOp.getSlotSize()); - auto slotNumAttr = rewriter.getI32IntegerAttr(slotNum); - - if (arch == PTOArch::A5) { - auto pipe = rewriter.create( - loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, IntegerAttr{}, - localAddr, /*peer_local_addr=*/Value{}); - return pipe.getPipe(); - } + if (failed(requireFrontendGmSlotBuffer(initOp))) + return failure(); - if (!initOp.getGmSlotBuffer()) { - initOp.emitOpError("requires 'gm_slot_buffer' when lowering to a2/a3"); - return failure(); - } + auto localSlotNumAttr = rewriter.getI32IntegerAttr(slotNum); + auto pipe = rewriter.create( + loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, localSlotNumAttr, + IntegerAttr{}, initOp.getGmSlotBuffer(), localAddr, peerLocalAddr); + return pipe.getPipe(); +} - auto localSlotNumAttr = rewriter.getI32IntegerAttr(slotNum); - auto pipe = rewriter.create( - loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, localSlotNumAttr, - IntegerAttr{}, initOp.getGmSlotBuffer(), localAddr, - /*peer_local_addr=*/Value{}); - return pipe.getPipe(); - }; +template +static FailureOr +lowerSingleDirectionFrontendInit(InitOpT initOp, IRRewriter &rewriter, + PTOArch arch, Type pipeTy, int8_t dirMask, + Value localAddr) { + auto pipeOr = + createFrontendPipe(initOp, rewriter, arch, pipeTy, dirMask, /*slotNum=*/8, + localAddr); + if (failed(pipeOr)) + return failure(); - switch (initOp.getDirMask()) { - case 1: { - auto pipeOr = - createPipe(/*dirMask=*/1, /*slotNum=*/8, initOp.getC2vConsumerBuf()); - if (failed(pipeOr)) - return failure(); + FrontendPipeHandles handles; + if (dirMask == 1) handles.c2vPipe = *pipeOr; - handles.anchorOp = handles.c2vPipe.getDefiningOp(); - break; - } - case 2: { - auto pipeOr = - createPipe(/*dirMask=*/2, /*slotNum=*/8, initOp.getV2cConsumerBuf()); - if (failed(pipeOr)) - return failure(); + else handles.v2cPipe = *pipeOr; - handles.anchorOp = handles.v2cPipe.getDefiningOp(); - break; - } - case 3: { - auto dirAttr = rewriter.getI8IntegerAttr(3); - auto slotSizeAttr = rewriter.getI32IntegerAttr(initOp.getSlotSize()); - auto slotNumAttr = rewriter.getI32IntegerAttr(4); - Value c2vAddr = initOp.getC2vConsumerBuf(); - Value v2cAddr = initOp.getV2cConsumerBuf(); - - if (arch == PTOArch::A5) { - auto pipe = rewriter.create( - loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, IntegerAttr{}, - c2vAddr, v2cAddr); - handles.c2vPipe = pipe.getPipe(); - handles.v2cPipe = pipe.getPipe(); - handles.anchorOp = pipe.getOperation(); - } else { - if (!initOp.getGmSlotBuffer()) { - initOp.emitOpError("requires 'gm_slot_buffer' when lowering to a2/a3"); - return failure(); - } - auto localSlotNumAttr = rewriter.getI32IntegerAttr(4); - auto pipe = rewriter.create( - loc, pipeTy, dirAttr, slotSizeAttr, slotNumAttr, localSlotNumAttr, - IntegerAttr{}, initOp.getGmSlotBuffer(), c2vAddr, v2cAddr); - handles.c2vPipe = pipe.getPipe(); - handles.v2cPipe = pipe.getPipe(); - handles.anchorOp = pipe.getOperation(); - } - break; - } - default: - break; - } - + handles.anchorOp = pipeOr->getDefiningOp(); return handles; } -static FailureOr lowerInitIfPresent(func::FuncOp funcOp, - IRRewriter &rewriter) { +template +static FailureOr +lowerBidirectionalFrontendInit(InitOpT initOp, IRRewriter &rewriter, + PTOArch arch, Type pipeTy) { + auto pipeOr = createFrontendPipe(initOp, rewriter, arch, pipeTy, + /*dirMask=*/3, /*slotNum=*/4, + initOp.getC2vConsumerBuf(), + initOp.getV2cConsumerBuf()); + if (failed(pipeOr)) + return failure(); + FrontendPipeHandles handles; + handles.c2vPipe = *pipeOr; + handles.v2cPipe = *pipeOr; + handles.anchorOp = pipeOr->getDefiningOp(); + return handles; +} + +template +static FailureOr lowerFrontendInitOp(InitOpT initOp, + IRRewriter &rewriter) { + MLIRContext *ctx = initOp.getContext(); + auto pipeTy = PipeType::get(ctx); + PTOArch arch = getTargetArch(initOp.getOperation()); + + switch (initOp.getDirMask()) { + case 1: + return lowerSingleDirectionFrontendInit(initOp, rewriter, arch, pipeTy, + /*dirMask=*/1, + initOp.getC2vConsumerBuf()); + case 2: + return lowerSingleDirectionFrontendInit(initOp, rewriter, arch, pipeTy, + /*dirMask=*/2, + initOp.getV2cConsumerBuf()); + case 3: + return lowerBidirectionalFrontendInit(initOp, rewriter, arch, pipeTy); + default: + return FrontendPipeHandles{}; + } +} + +struct FrontendInitOps { AicInitializePipeOp aicInit; AivInitializePipeOp aivInit; unsigned aicInitCount = 0; unsigned aivInitCount = 0; +}; +static FrontendInitOps collectFrontendInitOps(func::FuncOp funcOp) { + FrontendInitOps initOps; funcOp.walk([&](Operation *op) { if (auto init = dyn_cast(op)) { - ++aicInitCount; - if (!aicInit) - aicInit = init; + ++initOps.aicInitCount; + if (!initOps.aicInit) + initOps.aicInit = init; return WalkResult::advance(); } if (auto init = dyn_cast(op)) { - ++aivInitCount; - if (!aivInit) - aivInit = init; - return WalkResult::advance(); + ++initOps.aivInitCount; + if (!initOps.aivInit) + initOps.aivInit = init; } return WalkResult::advance(); }); + return initOps; +} - if (aicInitCount > 1) { - funcOp.emitOpError("requires at most one pto.aic_initialize_pipe"); - return failure(); +static LogicalResult validateFrontendInitOps(func::FuncOp funcOp, + const FrontendInitOps &initOps) { + if (initOps.aicInitCount > 1) + return funcOp.emitOpError("requires at most one pto.aic_initialize_pipe"); + if (initOps.aivInitCount > 1) + return funcOp.emitOpError("requires at most one pto.aiv_initialize_pipe"); + if (initOps.aicInit && initOps.aivInit) { + return funcOp.emitOpError("cannot mix pto.aic_initialize_pipe and " + "pto.aiv_initialize_pipe in one function"); } + return success(); +} - if (aivInitCount > 1) { - funcOp.emitOpError("requires at most one pto.aiv_initialize_pipe"); +template +static FailureOr lowerAndEraseFrontendInit(InitOpT initOp, + IRRewriter &rewriter) { + rewriter.setInsertionPoint(initOp); + auto loweredOr = lowerFrontendInitOp(initOp, rewriter); + if (failed(loweredOr)) return failure(); - } + rewriter.eraseOp(initOp); + return *loweredOr; +} - if (aicInit && aivInit) { - funcOp.emitOpError( - "cannot mix pto.aic_initialize_pipe and pto.aiv_initialize_pipe in one function"); +static FailureOr lowerInitIfPresent(func::FuncOp funcOp, + IRRewriter &rewriter) { + FrontendInitOps initOps = collectFrontendInitOps(funcOp); + if (failed(validateFrontendInitOps(funcOp, initOps))) return failure(); - } - - if (!aicInit && !aivInit) - return handles; - - if (aicInit) { - rewriter.setInsertionPoint(aicInit); - auto loweredOr = lowerFrontendInitOp(aicInit, rewriter); - if (failed(loweredOr)) - return failure(); - handles = *loweredOr; - rewriter.eraseOp(aicInit); - } else { - rewriter.setInsertionPoint(aivInit); - auto loweredOr = lowerFrontendInitOp(aivInit, rewriter); - if (failed(loweredOr)) - return failure(); - handles = *loweredOr; - rewriter.eraseOp(aivInit); - } - - return handles; + if (initOps.aicInit) + return lowerAndEraseFrontendInit(initOps.aicInit, rewriter); + if (initOps.aivInit) + return lowerAndEraseFrontendInit(initOps.aivInit, rewriter); + return FrontendPipeHandles{}; } static bool hasFrontendPipeOps(func::FuncOp funcOp) { @@ -213,88 +214,82 @@ static LogicalResult lowerFrontendDataOps(func::FuncOp funcOp, frontendOps.push_back(op); }); - for (Operation *op : frontendOps) { - if (!handles.anchorOp) { - op->emitOpError("requires a frontend initialize_pipe op in the same function"); - return failure(); - } + auto requireDominatingFrontendInit = [&](Operation *op) -> LogicalResult { + if (!handles.anchorOp) + return op->emitOpError( + "requires a frontend initialize_pipe op in the same function"); if (!dom.dominates(handles.anchorOp, op)) { - op->emitOpError( + return op->emitOpError( "requires a dominating frontend initialize_pipe op"); - return failure(); } + return success(); + }; - rewriter.setInsertionPoint(op); - - if (auto push = dyn_cast(op)) { + auto getRequiredPipe = [&](Operation *op) -> FailureOr { + if (isa(op)) { if (!handles.c2vPipe) { op->emitOpError( "requires the dominating initialize_pipe op to enable C2V"); return failure(); } - rewriter.replaceOpWithNewOp(push, push.getTile(), handles.c2vPipe, - push.getSplitAttr()); - continue; + return handles.c2vPipe; + } + + if (!handles.v2cPipe) { + op->emitOpError( + "requires the dominating initialize_pipe op to enable V2C"); + return failure(); } + return handles.v2cPipe; + }; + auto lowerFrontendDataOp = [&](Operation *op) -> LogicalResult { + if (failed(requireDominatingFrontendInit(op))) + return failure(); + auto pipeOr = getRequiredPipe(op); + if (failed(pipeOr)) + return failure(); + + Value pipe = *pipeOr; + rewriter.setInsertionPoint(op); + if (auto push = dyn_cast(op)) { + rewriter.replaceOpWithNewOp(push, push.getTile(), pipe, + push.getSplitAttr()); + return success(); + } if (auto push = dyn_cast(op)) { - if (!handles.v2cPipe) { - op->emitOpError( - "requires the dominating initialize_pipe op to enable V2C"); - return failure(); - } - rewriter.replaceOpWithNewOp(push, push.getTile(), handles.v2cPipe, + rewriter.replaceOpWithNewOp(push, push.getTile(), pipe, push.getSplitAttr()); - continue; + return success(); } - if (auto pop = dyn_cast(op)) { - if (!handles.c2vPipe) { - op->emitOpError( - "requires the dominating initialize_pipe op to enable C2V"); - return failure(); - } auto decl = rewriter.create(pop.getLoc(), pop.getTile().getType()); - rewriter.create(pop.getLoc(), decl.getTile(), handles.c2vPipe, + rewriter.create(pop.getLoc(), decl.getTile(), pipe, pop.getSplitAttr()); rewriter.replaceOp(pop, decl.getTile()); - continue; + return success(); } - if (auto pop = dyn_cast(op)) { - if (!handles.v2cPipe) { - op->emitOpError( - "requires the dominating initialize_pipe op to enable V2C"); - return failure(); - } auto decl = rewriter.create(pop.getLoc(), pop.getTile().getType()); - rewriter.create(pop.getLoc(), decl.getTile(), handles.v2cPipe, + rewriter.create(pop.getLoc(), decl.getTile(), pipe, pop.getSplitAttr()); rewriter.replaceOp(pop, decl.getTile()); - continue; + return success(); } - if (auto free = dyn_cast(op)) { - if (!handles.c2vPipe) { - op->emitOpError( - "requires the dominating initialize_pipe op to enable C2V"); - return failure(); - } - rewriter.replaceOpWithNewOp(free, handles.c2vPipe, - free.getSplitAttr()); - continue; + rewriter.replaceOpWithNewOp(free, pipe, free.getSplitAttr()); + return success(); } - auto free = cast(op); - if (!handles.v2cPipe) { - op->emitOpError( - "requires the dominating initialize_pipe op to enable V2C"); + rewriter.replaceOpWithNewOp(free, pipe, free.getSplitAttr()); + return success(); + }; + + for (Operation *op : frontendOps) { + if (failed(lowerFrontendDataOp(op))) return failure(); - } - rewriter.replaceOpWithNewOp(free, handles.v2cPipe, - free.getSplitAttr()); } return success(); diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index c0389f8bf..a5a2df897 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -19,6 +19,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include #include @@ -47,12 +48,13 @@ struct LocalMemSpec { static int64_t ceilDivBitsToBytes(int64_t bits) { return (bits + 7) / 8; } static int64_t alignUpBytes(int64_t value, int64_t align) { - if (align <= 1) + int64_t safeAlign = std::max(align, 1); + if (safeAlign == 1) return value; - int64_t rem = value % align; + int64_t rem = value % safeAlign; if (rem == 0) return value; - return value + (align - rem); + return value + (safeAlign - rem); } static LocalMemSpec getLocalMemSpec(Operation *op, AddressSpace as) { @@ -354,14 +356,17 @@ void MemLivenessAnalysis::UpdateForOpBufferAlias(scf::ForOp forOp) { return; } if (!forOp.getRegionIterArgs().empty()) { - assert(forOp.getYieldedValues().size() == forOp.getRegionIterArgs().size()); - assert(forOp.getInitArgs().size() == forOp.getRegionIterArgs().size()); + if (forOp.getYieldedValues().size() != forOp.getRegionIterArgs().size() || + forOp.getInitArgs().size() != forOp.getRegionIterArgs().size()) { + llvm::report_fatal_error("scf.for alias sizes are inconsistent"); + } for (auto [i, arg] : llvm::enumerate(forOp.getRegionIterArgs())) { // yielded values alias region iter args. UpdateBufferAlias(forOp.getYieldedValues()[i], arg); } } - assert(forOp->getResults().size() == forOp.getYieldedValues().size()); + if (forOp->getResults().size() != forOp.getYieldedValues().size()) + llvm::report_fatal_error("scf.for result/yield sizes are inconsistent"); for (auto [i, arg] : llvm::enumerate(forOp.getYieldedValues())) { // forOp result values alias region iter yielded values. UpdateBufferAlias(forOp->getResult(i), arg); @@ -390,7 +395,8 @@ void MemLivenessAnalysis::UpdateForOpInitArgsAlias(scf::ForOp forOp) { if (forOp.getInitArgs().empty()) { return; } - assert(forOp.getInitArgs().size() == forOp.getRegionIterArgs().size()); + if (forOp.getInitArgs().size() != forOp.getRegionIterArgs().size()) + llvm::report_fatal_error("scf.for init/iter-arg sizes are inconsistent"); for (auto [i, arg] : llvm::enumerate(forOp.getInitArgs())) { // init args alias region iter args. UpdateBufferAlias(forOp.getRegionIterArgs()[i], arg); @@ -402,7 +408,8 @@ void MemLivenessAnalysis::UpdateIfOpBufferAlias(scf::IfOp ifOp, if (ifOp.getResults().empty()) { return; } - assert(ifOp->getResults().size() == yieldOp->getOperands().size()); + if (ifOp->getResults().size() != yieldOp->getOperands().size()) + llvm::report_fatal_error("scf.if result/yield sizes are inconsistent"); for (auto [i, arg] : llvm::enumerate(yieldOp->getOperands())) { // Multiple buffers involved, requiring one-to-one correspondence. UpdateBufferAlias(ifOp->getResult(i), arg); @@ -458,7 +465,8 @@ SmallVector MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp, LogicalResult MemLivenessAnalysis::CheckLocalBufferAllocOp(Operation *op) const { auto allocOp = dyn_cast(op); - assert(allocOp && "must be alloc op"); + if (!allocOp) + return op->emitError("must be alloc op"), failure(); auto memorySpaceAttr = GetBufferSpaceAttr(allocOp.getResult()); if (isLocalBuffer(memorySpaceAttr)) { return success(); @@ -630,7 +638,8 @@ BufferInfo MemLivenessAnalysis::GenerateBufferInfo(Operation *op, Value operand) { auto memorySpaceAttr = GetBufferSpaceAttr(operand); if (isLocalMemPlan() && isLocalBuffer(memorySpaceAttr)) { - assert(memorySpaceAttr.has_value() && "buffer must has space!"); + if (!memorySpaceAttr.has_value()) + llvm::report_fatal_error("local buffer must have memory space"); return GetBufferInfo(op, operand, memorySpaceAttr.value().getAddressSpace()); } @@ -648,8 +657,8 @@ BufferInfo MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand, bufferInfo.bufferType = memRefType.getElementType(); std::optional totalStaticSize = getStaticTotalSize(memRefType.getShape()); - assert(totalStaticSize.has_value() && - "Failed to obtain op buffer shape size!"); + if (!totalStaticSize.has_value()) + llvm::report_fatal_error("failed to obtain buffer static shape size"); bufferInfo.constBits = totalStaticSize.value() * static_cast(memRefType.getElementTypeBitWidth()); @@ -674,8 +683,8 @@ void MemLivenessAnalysis::GenerateBufferLife() { // Time given to buffer end. for (const Value &killBuffer : it->second.kill) { auto iter = buffer2Life.find(killBuffer); - assert(iter != buffer2Life.end() && - "buffer has not been generated before! "); + if (iter == buffer2Life.end()) + llvm::report_fatal_error("buffer lifetime killed before generation"); iter->second->freeTime = scopeTime; } scopeTime++; @@ -714,21 +723,21 @@ SmallVector MemPlan::GenerateInplaceList() { continue; if (hasTouchOp[operationSeq->operation]) { continue; - } - for (const Value &genBuffer : it->second.gen) { - auto genBufferIter = bufferInfos.find(genBuffer); - assert(genBufferIter != bufferInfos.end() && - "genBuffer should be find in bufferInfos"); - if (genBufferIter->second.ignoreInplace) { - continue; } - for (const Value &killBuffer : it->second.kill) { - auto killBufferIter = bufferInfos.find(killBuffer); - assert(killBufferIter != bufferInfos.end() && - "killBuffer should be find in bufferInfos"); - if (killBufferIter->second.ignoreInplace) { + for (const Value &genBuffer : it->second.gen) { + auto genBufferIter = bufferInfos.find(genBuffer); + if (genBufferIter == bufferInfos.end()) + llvm::report_fatal_error("gen buffer missing from buffer info map"); + if (genBufferIter->second.ignoreInplace) { continue; } + for (const Value &killBuffer : it->second.kill) { + auto killBufferIter = bufferInfos.find(killBuffer); + if (killBufferIter == bufferInfos.end()) + llvm::report_fatal_error("kill buffer missing from buffer info map"); + if (killBufferIter->second.ignoreInplace) { + continue; + } bool bufferSizeMatch = killBufferIter->second.constBits >= genBufferIter->second.constBits; @@ -843,7 +852,8 @@ void MemPlan::GenerateStorageEntry() { void MemPlan::PrintSuccessfulAllocatedMaxBits() { auto it = memscope2rootStorageEntry.find(pto::AddressSpace::VEC); if (it != memscope2rootStorageEntry.end()) { - assert(it->second != nullptr); + if (!it->second) + llvm::report_fatal_error("missing root storage entry for VEC scope"); uint64_t ubAllocBits = it->second->alignedConstBits + it->second->bitsOffset; for (auto& child : it->second->mergedChildren) { ubAllocBits = std::max(ubAllocBits, child->bitsOffset + child->alignedConstBits); @@ -854,9 +864,12 @@ void MemPlan::PrintSuccessfulAllocatedMaxBits() { } void MemPlan::ValidateParameters(std::unique_ptr &e) const { - assert(e->bufInfo->operation && "Unrecognized legal define operation !"); - assert(e->bufInfo->constBits >= 0U && "recognized illegal memory sizes !"); - assert(!e->bufferLifeVec.empty() && "Unrecognized buffer's life time !"); + if (!e->bufInfo->operation) + llvm::report_fatal_error("storage entry missing defining operation"); + if (e->bufInfo->constBits < 0U) + llvm::report_fatal_error("storage entry has invalid memory size"); + if (e->bufferLifeVec.empty()) + llvm::report_fatal_error("storage entry missing lifetime information"); } void MemPlan::UpdateBuffer2Offsets() { @@ -898,8 +911,8 @@ void MemPlan::MergeInplaceSE() { // already same storageEntry, no need to inplace. continue; } - assert(genSE != nullptr && killSE != nullptr && - " genSE and killSE should be valid"); + if (genSE == nullptr || killSE == nullptr) + llvm::report_fatal_error("invalid storage entry during inplace merge"); BufferLifeVec mergedBufferLifeVec; mergedBufferLifeVec.insert(mergedBufferLifeVec.end(), genSE->bufferLifeVec.begin(), @@ -1015,7 +1028,8 @@ bool MemPlan::IsEnoughForBuffersNoReuse(StorageEntry *rootStorageEntry, size_t alignUnit) { auto iter = bufferScope2RequiredSize.find(rootStorageEntry->bufInfo->bufferScope); - assert(iter != bufferScope2RequiredSize.end()); + if (iter == bufferScope2RequiredSize.end()) + llvm::report_fatal_error("missing required-size entry for buffer scope"); if (iter->second < restBufferSize) { PlanBuffersWithoutReuse(rootStorageEntry, alignUnit); return true; @@ -1394,7 +1408,8 @@ MemPlan::GetBufferParentLoop(const SmallVector &buffers) { llvm::SmallSet parentLoopVec; for (auto buffer : buffers) { if (!buffer.getDefiningOp()) { - assert(isa(buffer.getParentBlock()->getParentOp())); + if (!isa(buffer.getParentBlock()->getParentOp())) + llvm::report_fatal_error("expected loop-carried block argument"); // Init args and region iter arg are inplace, ignore Region Iter Arg // without DefineOp. continue; @@ -1496,7 +1511,8 @@ void MemPlan::SpecAllocRelationPongEntry(MemBoundList &outline, PlanRecHis &his, if (e->multiBufferNum == 2 && e->relationPongEntry) { pongStorageEntry = e->relationPongEntry; } - assert(pongStorageEntry && "PongStorage Entry not found!"); + if (!pongStorageEntry) + llvm::report_fatal_error("pong storage entry not found"); UpdateOutline(outline, his, pongStorageEntry, OutlineSectionInfo(start, end, size, true), SPEC_LEVEL_1); return; @@ -1817,7 +1833,8 @@ void MemPlan::ReportAllocatedEntryDebugInfo(StorageEntry *rootStorageEntry) { LDBG("\n"); } size_t num = allocatedEntry.size() - 1; - assert(rootStorageEntry->mergedChildren.size() > num); + if (rootStorageEntry->mergedChildren.size() <= num) + llvm::report_fatal_error("missing failed storage entry"); const StorageEntry *failedSe = rootStorageEntry->mergedChildren[num]; printRecord(failedSe); LDBG("alloc fail,because exceed bound of memory \n" diff --git a/lib/PTO/Transforms/PTORemoveRedundantBarrier.cpp b/lib/PTO/Transforms/PTORemoveRedundantBarrier.cpp index a10fc0690..f349e42fd 100644 --- a/lib/PTO/Transforms/PTORemoveRedundantBarrier.cpp +++ b/lib/PTO/Transforms/PTORemoveRedundantBarrier.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/IR/PTO.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -62,6 +57,21 @@ bool isPipeUsedInRegion(Region ®ion, Attribute targetPipe) { // 向后扫描:检查 targetPipe 在当前 Block 后续是否"真正"活跃 // WaitOp 不再被视为活跃标志。 // 如果一个 Pipe 后面只剩 Wait,说明它已经完成了工作,发给它的信号是多余的。 +static bool hasPipelineActivityAfterOp(Operation *parentOp, Attribute targetPipe) { + Block *parentBlock = parentOp ? parentOp->getBlock() : nullptr; + if (!parentBlock) + return false; + for (auto it = std::next(parentOp->getIterator()); it != parentBlock->end(); ++it) { + if (isResourceOp(&*it, targetPipe)) + return true; + if (it->getNumRegions() > 0) + return true; + if (isa(&*it)) + return false; + } + return false; +} + bool isPipelineActiveFuture(Block *block, Block::iterator startIt, Attribute targetPipe) { for (auto it = startIt; it != block->end(); ++it) { Operation *op = &*it; @@ -81,25 +91,7 @@ bool isPipelineActiveFuture(Block *block, Block::iterator startIt, Attribute tar if (op->hasTrait()) { // 如果是 Return,肯定死了 if (isa(op)) return false; - - // 如果是 Yield (scf.if / scf.for),我们需要看 Parent Block 的后续 - // 这是一个简单的单层 Lookahead,防止 Set 被误删 - if (auto parentOp = block->getParentOp()) { - Block *parentBlock = parentOp->getBlock(); - if (parentBlock) { - // 从 Parent Op 的下一条指令开始查 - for (auto pIt = std::next(parentOp->getIterator()); pIt != parentBlock->end(); ++pIt) { - if (isResourceOp(&*pIt, targetPipe)) return true; - - // 如果外面还有嵌套,理论上要继续递归,这里保守返回 true - if (pIt->getNumRegions() > 0) return true; - - // 如果遇到 Return,说明后面真没了 - if (isa(&*pIt)) return false; - } - } - } - return false; // 如果 Parent Block 后面也没东西,那就是死了 + return hasPipelineActivityAfterOp(block->getParentOp(), targetPipe); } } return false; diff --git a/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp b/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp index 95fc17936..39a8a7a1d 100644 --- a/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp +++ b/lib/PTO/Transforms/PTOResolveReservedBuffersPass.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" @@ -57,6 +52,9 @@ struct PipeInitInfo { int8_t dirMask = 0; }; +using PipeInitGroups = std::map>; +using PipeParticipants = std::map>; + template static Value getLocalAddrOperand(InitOpT op) { // Hide the concrete init-op type and expose the local address operand // through one helper used by the shared peer-grouping logic. @@ -123,54 +121,105 @@ static bool hasCompletePeerInitPair(const SmallVector &inits, return initFuncs.size() == 2; } +template +static LogicalResult collectPeerAwareInit(InitOpT initOp, + PipeInitGroups &keyedInits, + PipeParticipants &keyedParticipants) { + PipeInitInfo info; + info.op = initOp.getOperation(); + info.funcOp = initOp->template getParentOfType(); + info.dirMask = initOp.getDirMask(); + + auto recordAddr = [&](Value addr, int8_t effectiveDirMask) { + auto key = getPipePeerKey(addr, info.funcOp); + if (!key) + return false; + key->dirMask = effectiveDirMask; + keyedInits[*key].push_back(info); + keyedParticipants[*key].insert(getFuncSymbol(info.funcOp)); + keyedParticipants[*key].insert(key->ownerFunc); + return true; + }; + + bool recorded = false; + if (info.dirMask == 3) { + Value peerAddr = initOp.getPeerLocalAddr(); + recorded = recordAddr(getLocalAddrOperand(initOp), /*c2v=*/1); + recorded = (peerAddr && recordAddr(peerAddr, /*v2c=*/2)) || recorded; + } else { + recorded = recordAddr(getLocalAddrOperand(initOp), info.dirMask); + } + + if (recorded || getFlagBaseAttr(initOp)) + return success(); + + return initOp.emitOpError( + "requires local_addr to come from pto.reserve_buffer or " + "pto.import_reserved_buffer when 'flag_base' is not explicit"); +} + +static LogicalResult validatePeerInitGroups(const PipeInitGroups &keyedInits, + const PipeParticipants &keyedParticipants) { + for (const auto &it : keyedInits) { + if (hasCompletePeerInitPair(it.second, keyedParticipants.at(it.first))) + continue; + return it.second.front().op->emitOpError( + "requires a complete peer init pair when local_addr comes from " + "pto.reserve_buffer or pto.import_reserved_buffer"); + } + return success(); +} + +static FailureOr chooseFlagBaseForPeerGroup( + const SmallVector &inits) { + std::optional chosenBase; + for (const PipeInitInfo &info : inits) { + IntegerAttr flagBaseAttr; + if (auto initOp = dyn_cast(info.op)) + flagBaseAttr = getFlagBaseAttr(initOp); + else + flagBaseAttr = getFlagBaseAttr(cast(info.op)); + + if (!flagBaseAttr) + continue; + if (chosenBase && *chosenBase != flagBaseAttr.getInt()) { + return info.op->emitOpError( + "conflicting explicit flag_base across peer pipe inits"); + } + chosenBase = flagBaseAttr.getInt(); + } + return chosenBase.value_or(0); +} + +static void assignMissingFlagBases(const SmallVector &inits, + IntegerAttr flagBaseAttr) { + for (const PipeInitInfo &info : inits) { + if (auto initOp = dyn_cast(info.op)) { + if (!getFlagBaseAttr(initOp)) + setFlagBaseAttr(initOp, flagBaseAttr); + continue; + } + + auto initOp = cast(info.op); + if (!getFlagBaseAttr(initOp)) + setFlagBaseAttr(initOp, flagBaseAttr); + } +} + struct PTOResolveReservedBuffersPass : public mlir::pto::impl::PTOResolveReservedBuffersBase< PTOResolveReservedBuffersPass> { LogicalResult assignPeerAwareFlagBases(ModuleOp moduleOp) { // Group internal pipe init ops by their logical pipe identity, then fill // missing flag_base attrs so both sides of the same logical pipe agree. - std::map> keyedInits; - std::map> keyedParticipants; + PipeInitGroups keyedInits; + PipeParticipants keyedParticipants; LogicalResult status = success(); auto collectInit = [&](auto initOp) { if (failed(status)) return; - PipeInitInfo info; - info.op = initOp.getOperation(); - info.funcOp = initOp->template getParentOfType(); - info.dirMask = initOp.getDirMask(); - - // Record one address into the keyed maps. Returns true when the - // address comes from reserve_buffer / import_reserved_buffer. - auto recordAddr = [&](Value addr, int8_t effectiveDirMask) -> bool { - auto key = getPipePeerKey(addr, info.funcOp); - if (!key) - return false; - key->dirMask = effectiveDirMask; - keyedInits[*key].push_back(info); - keyedParticipants[*key].insert(getFuncSymbol(info.funcOp)); - keyedParticipants[*key].insert(key->ownerFunc); - return true; - }; - - if (info.dirMask == 3) { - // DIR_BOTH: treat as two logical pipes keyed by direction. - bool c2vOk = recordAddr(getLocalAddrOperand(initOp), /*c2v=*/1); - Value peerAddr = initOp.getPeerLocalAddr(); - bool v2cOk = peerAddr && recordAddr(peerAddr, /*v2c=*/2); - if (c2vOk || v2cOk) - return; - } else { - Value localAddr = getLocalAddrOperand(initOp); - if (recordAddr(localAddr, info.dirMask)) - return; - } - if (getFlagBaseAttr(initOp)) - return; - status = initOp.emitOpError( - "requires local_addr to come from pto.reserve_buffer or " - "pto.import_reserved_buffer when 'flag_base' is not explicit"); + status = collectPeerAwareInit(initOp, keyedInits, keyedParticipants); }; moduleOp.walk([&](InitializeL2LPipeOp initOp) { collectInit(initOp); }); @@ -178,59 +227,17 @@ struct PTOResolveReservedBuffersPass if (failed(status)) return failure(); - for (const auto &it : keyedInits) { - if (hasCompletePeerInitPair(it.second, keyedParticipants[it.first])) - continue; - return it.second.front().op->emitOpError( - "requires a complete peer init pair when local_addr comes from " - "pto.reserve_buffer or pto.import_reserved_buffer"); - } + if (failed(validatePeerInitGroups(keyedInits, keyedParticipants))) + return failure(); OpBuilder builder(moduleOp.getContext()); for (const auto &it : keyedInits) { const auto &inits = it.second; - // flag_base is always 0: single-direction pipes use flag pair 0/1; - // DIR_BOTH pipes internally manage 0/1 for C2V and 2/3 for V2C. - int32_t desiredBase = 0; - - std::optional chosenBase; - for (const PipeInitInfo &info : inits) { - // Respect any explicit flag_base already present on one side, but make - // sure all peers resolve to the same value. - if (auto initOp = dyn_cast(info.op)) { - if (auto flagBaseAttr = getFlagBaseAttr(initOp)) { - if (chosenBase && *chosenBase != flagBaseAttr.getInt()) { - return info.op->emitOpError( - "conflicting explicit flag_base across peer pipe inits"); - } - chosenBase = flagBaseAttr.getInt(); - } - continue; - } - - auto initOp = cast(info.op); - if (auto flagBaseAttr = getFlagBaseAttr(initOp)) { - if (chosenBase && *chosenBase != flagBaseAttr.getInt()) { - return info.op->emitOpError( - "conflicting explicit flag_base across peer pipe inits"); - } - chosenBase = flagBaseAttr.getInt(); - } - } - if (!chosenBase) - chosenBase = desiredBase; - - auto flagBaseAttr = builder.getI32IntegerAttr(*chosenBase); - for (const PipeInitInfo &info : inits) { - if (auto initOp = dyn_cast(info.op)) { - if (!getFlagBaseAttr(initOp)) - setFlagBaseAttr(initOp, flagBaseAttr); - continue; - } - auto initOp = cast(info.op); - if (!getFlagBaseAttr(initOp)) - setFlagBaseAttr(initOp, flagBaseAttr); - } + auto chosenBaseOr = chooseFlagBaseForPeerGroup(inits); + if (failed(chosenBaseOr)) + return failure(); + auto flagBaseAttr = builder.getI32IntegerAttr(*chosenBaseOr); + assignMissingFlagBases(inits, flagBaseAttr); } return success(); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index ae9756216..c203ae46b 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -265,7 +265,7 @@ class PTOToEmitCTypeConverter : public TypeConverter { return emitc::OpaqueType::get(Ctx, "pto::MrgSortExecutedNumList"); return Type{}; }); - + // --------------------------------------------------------- // 2. PTO 特殊类型 (透传或转换) // --------------------------------------------------------- @@ -306,7 +306,7 @@ class PTOToEmitCTypeConverter : public TypeConverter { std::string tok = "PTOAS_EventIdArray<" + std::to_string(type.getSize()) + ">"; return emitc::OpaqueType::get(Ctx, tok); }); - + addConversion([Ctx](pto::AsyncSessionType type) -> Type { (void)type; return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncSession"); @@ -1403,6 +1403,11 @@ struct ArithBitcastToEmitC : public OpConversionPattern { struct ArithCmpFToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; + struct CmpFConfig { + bool unordered = false; + emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; + }; + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, Value v) { return rewriter @@ -1419,114 +1424,99 @@ struct ArithCmpFToEmitC : public OpConversionPattern { .getResult(); } - LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getLhs().getType())) - return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); - - auto loc = op.getLoc(); - auto i1Ty = rewriter.getI1Type(); - - bool unordered = false; - emitc::CmpPredicate pred = emitc::CmpPredicate::eq; - - switch (op.getPredicate()) { - case arith::CmpFPredicate::AlwaysFalse: { - auto cst = makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); - rewriter.replaceOp(op, cst); - return success(); - } - case arith::CmpFPredicate::AlwaysTrue: { - auto cst = makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); - rewriter.replaceOp(op, cst); - return success(); + static std::optional buildSpecialCmpFResult( + arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + switch (predicate) { + case arith::CmpFPredicate::AlwaysFalse: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); + case arith::CmpFPredicate::AlwaysTrue: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); + case arith::CmpFPredicate::ORD: + return rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), + isNotNaN(rewriter, loc, rhs)) + .getResult(); + case arith::CmpFPredicate::UNO: + return rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), + isNaN(rewriter, loc, rhs)) + .getResult(); + default: + return std::nullopt; } + } + + static std::optional + getCmpFConfig(arith::CmpFPredicate predicate) { + switch (predicate) { case arith::CmpFPredicate::OEQ: - unordered = false; - pred = emitc::CmpPredicate::eq; - break; + return CmpFConfig{false, emitc::CmpPredicate::eq}; case arith::CmpFPredicate::OGT: - unordered = false; - pred = emitc::CmpPredicate::gt; - break; + return CmpFConfig{false, emitc::CmpPredicate::gt}; case arith::CmpFPredicate::OGE: - unordered = false; - pred = emitc::CmpPredicate::ge; - break; + return CmpFConfig{false, emitc::CmpPredicate::ge}; case arith::CmpFPredicate::OLT: - unordered = false; - pred = emitc::CmpPredicate::lt; - break; + return CmpFConfig{false, emitc::CmpPredicate::lt}; case arith::CmpFPredicate::OLE: - unordered = false; - pred = emitc::CmpPredicate::le; - break; + return CmpFConfig{false, emitc::CmpPredicate::le}; case arith::CmpFPredicate::ONE: - unordered = false; - pred = emitc::CmpPredicate::ne; - break; - case arith::CmpFPredicate::ORD: { - Value ordered = rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, adaptor.getLhs()), - isNotNaN(rewriter, loc, adaptor.getRhs())); - rewriter.replaceOp(op, ordered); - return success(); - } + return CmpFConfig{false, emitc::CmpPredicate::ne}; case arith::CmpFPredicate::UEQ: - unordered = true; - pred = emitc::CmpPredicate::eq; - break; + return CmpFConfig{true, emitc::CmpPredicate::eq}; case arith::CmpFPredicate::UGT: - unordered = true; - pred = emitc::CmpPredicate::gt; - break; + return CmpFConfig{true, emitc::CmpPredicate::gt}; case arith::CmpFPredicate::UGE: - unordered = true; - pred = emitc::CmpPredicate::ge; - break; + return CmpFConfig{true, emitc::CmpPredicate::ge}; case arith::CmpFPredicate::ULT: - unordered = true; - pred = emitc::CmpPredicate::lt; - break; + return CmpFConfig{true, emitc::CmpPredicate::lt}; case arith::CmpFPredicate::ULE: - unordered = true; - pred = emitc::CmpPredicate::le; - break; + return CmpFConfig{true, emitc::CmpPredicate::le}; case arith::CmpFPredicate::UNE: - unordered = true; - pred = emitc::CmpPredicate::ne; - break; - case arith::CmpFPredicate::UNO: { - Value unord = rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, adaptor.getLhs()), - isNaN(rewriter, loc, adaptor.getRhs())); - rewriter.replaceOp(op, unord); - return success(); - } + return CmpFConfig{true, emitc::CmpPredicate::ne}; + default: + return std::nullopt; } + } + static Value buildCmpFResult(const CmpFConfig &config, + ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { Value cmp = rewriter - .create(loc, i1Ty, pred, adaptor.getLhs(), - adaptor.getRhs()) + .create(loc, i1Ty, config.predicate, lhs, rhs) .getResult(); - Value unord = rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, adaptor.getLhs()), - isNaN(rewriter, loc, adaptor.getRhs())); + loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); + if (config.unordered) + return rewriter + .create(loc, i1Ty, unord, cmp) + .getResult(); Value ord = rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, adaptor.getLhs()), - isNotNaN(rewriter, loc, adaptor.getRhs())); + loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); + return rewriter + .create(loc, i1Ty, ord, cmp) + .getResult(); + } - if (unordered) { - Value res = - rewriter.create(loc, i1Ty, unord, cmp).getResult(); - rewriter.replaceOp(op, res); + LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); + + auto loc = op.getLoc(); + auto i1Ty = rewriter.getI1Type(); + if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, + i1Ty, adaptor.getLhs(), + adaptor.getRhs())) { + rewriter.replaceOp(op, *special); return success(); } - Value res = - rewriter.create(loc, i1Ty, ord, cmp).getResult(); - rewriter.replaceOp(op, res); + auto config = getCmpFConfig(op.getPredicate()); + if (!config) + return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); + rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, + adaptor.getLhs(), adaptor.getRhs())); return success(); } }; @@ -1881,102 +1871,114 @@ struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, ArithFloatMinMaxToEmitCBase { using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getType())) - return rewriter.notifyMatchFailure(op, "expected scalar float type"); - - auto loc = op.getLoc(); - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - // Basic compare-based min/max. - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value candidate = rewriter - .create( - loc, dstTy, cmpLt, - isMaximum ? adaptor.getRhs() : adaptor.getLhs(), - isMaximum ? adaptor.getLhs() : adaptor.getRhs()) - .getResult(); - - // Fix signed zero tie-breaking for equal zeros. - Value zero = makeFZero(rewriter, loc, dstTy); - Value eq = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value lhsZero = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, - adaptor.getLhs(), zero) - .getResult(); - Value bothZero = rewriter - .create(loc, rewriter.getI1Type(), - eq, lhsZero) - .getResult(); + static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs) { + Value cmpLt = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhs, rhs) + .getResult(); + return rewriter + .create( + loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? lhs : rhs) + .getResult(); + } - auto floatTy = cast(op.getType()); - auto bitsTy = getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - cast(bitsTy).getValue())}); + static Value buildSignBitValue(ConversionPatternRewriter &rewriter, + Location loc, Value lhs, FloatType floatTy) { + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + rewriter.getContext(), cast(bitsTy).getValue())}); Value lhsBits = rewriter .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", - ValueRange{adaptor.getLhs()}, - /*args=*/ArrayAttr{}, - /*template_args=*/templateArgs) + ValueRange{lhs}, ArrayAttr{}, + templateArgs) .getResult(0); - Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); - Value shAmt = makeEmitCIntConstant(rewriter, loc, bitsTy, - floatTy.getWidth() - 1); + Value shiftAmount = + makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); Value signMask = rewriter .create(loc, bitsTy, oneBits, - shAmt) + shiftAmount) .getResult(); - Value signBit = rewriter - .create(loc, bitsTy, lhsBits, signMask) + return rewriter + .create(loc, bitsTy, lhsBits, signMask) + .getResult(); + } + + static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value zero = makeFZero(rewriter, loc, dstTy); + Value equal = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, rhs) + .getResult(); + Value lhsZero = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, + zero) .getResult(); + Value bothZero = rewriter + .create(loc, rewriter.getI1Type(), + equal, lhsZero) + .getResult(); + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); Value lhsIsNegZero = rewriter .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, signBit, zeroBits) + emitc::CmpPredicate::ne, + buildSignBitValue(rewriter, loc, lhs, floatTy), + zeroBits) .getResult(); + Value tie = rewriter + .create( + loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, + isMaximum ? lhs : rhs) + .getResult(); + return rewriter + .create(loc, dstTy, bothZero, tie, + buildPrimaryCandidate(rewriter, loc, dstTy, + lhs, rhs)) + .getResult(); + } - Value tie = - rewriter - .create( - loc, dstTy, lhsIsNegZero, - isMaximum ? adaptor.getRhs() : adaptor.getLhs(), - isMaximum ? adaptor.getLhs() : adaptor.getRhs()) - .getResult(); - Value noNaN = rewriter - .create(loc, dstTy, bothZero, tie, - candidate) - .getResult(); - - // Propagate NaN: if lhs is NaN return lhs, else if rhs is NaN return rhs. + static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value lhsNaN = isNaN(rewriter, loc, lhs); + Value rhsNaN = isNaN(rewriter, loc, rhs); + Value noNaN = + buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); Value rhsOrNoNaN = rewriter - .create(loc, dstTy, rhsNaN, - adaptor.getRhs(), noNaN) + .create(loc, dstTy, rhsNaN, rhs, + noNaN) .getResult(); - Value res = rewriter - .create(loc, dstTy, lhsNaN, - adaptor.getLhs(), rhsOrNoNaN) - .getResult(); - rewriter.replaceOp(op, res); + return rewriter + .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) + .getResult(); + } + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getType())) + return rewriter.notifyMatchFailure(op, "expected scalar float type"); + + auto loc = op.getLoc(); + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto floatTy = cast(op.getType()); + rewriter.replaceOp(op, buildNaNPropagatingResult( + rewriter, loc, dstTy, adaptor.getLhs(), + adaptor.getRhs(), floatTy)); return success(); } }; @@ -2569,27 +2571,50 @@ struct FuncToEmitC : public OpConversionPattern { enum class Role { A, B, C, Unknown }; -static Role inferSubviewRole(memref::SubViewOp sv) { - for (Operation *u : sv.getResult().getUsers()) { - if (auto ld = dyn_cast(u)) { - Value ub = ld.getDst(); - if (!ub) continue; - for (Operation *uu : ub.getUsers()) { - if (auto mm = dyn_cast(uu)) { - if (mm.getLhs() == ub) return Role::A; - if (mm.getRhs() == ub) return Role::B; - } - if (auto mmacc = dyn_cast(uu)) { - if (mmacc.getLhs() == ub) return Role::A; - if (mmacc.getRhs() == ub) return Role::B; - } - } - } +template +static std::optional inferMatmulLikeSubviewRole(MatmulLikeOp op, + Value buffer) { + if (op.getLhs() == buffer) + return Role::A; + if (op.getRhs() == buffer) + return Role::B; + return std::nullopt; +} - if (auto st = dyn_cast(u)) { - if (st.getDst() == sv.getResult()) return Role::C; +static std::optional inferSubviewRoleFromLoadUser(mlir::pto::TLoadOp load) { + Value buffer = load.getDst(); + if (!buffer) + return std::nullopt; + for (Operation *user : buffer.getUsers()) { + if (auto matmul = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmul, buffer)) + return role; + continue; + } + if (auto matmulAcc = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmulAcc, buffer)) + return role; } } + return std::nullopt; +} + +static std::optional inferSubviewRoleFromUser(Operation *user, Value result) { + if (auto load = dyn_cast(user)) + return inferSubviewRoleFromLoadUser(load); + if (auto store = dyn_cast(user)) { + if (store.getDst() == result) + return Role::C; + } + return std::nullopt; +} + +static Role inferSubviewRole(memref::SubViewOp sv) { + Value result = sv.getResult(); + for (Operation *user : result.getUsers()) { + if (auto role = inferSubviewRoleFromUser(user, result)) + return *role; + } return Role::Unknown; } @@ -3094,159 +3119,212 @@ struct SubviewToEmitCPattern : public OpConversionPattern { } }; -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, - Operation *anchor) { - auto *ctx = rewriter.getContext(); +//===----------------------------------------------------------------------===// +// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) +//===----------------------------------------------------------------------===// - // Only handle fully static shapes/strides for now. - auto shape = mrTy.getShape(); - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic) - return Value(); +static std::string getElemTypeStringForGT(Type elemTy) { + if (elemTy.isF16()) return "half"; + if (elemTy.isBF16()) return "bfloat16_t"; + if (elemTy.isF32()) return "float"; + if (elemTy.isF64()) return "double"; + if (elemTy.isInteger(8)) { + if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) + return "int8_t"; + return "uint8_t"; } - - SmallVector strides; - int64_t offset = 0; - if (failed(getStridesAndOffset(mrTy, strides, offset))) { - // Fallback: compact row-major - strides.resize(shape.size()); - int64_t s = 1; - for (int i = (int)shape.size() - 1; i >= 0; --i) { - strides[i] = s; - s *= shape[i]; - } - offset = 0; + if (elemTy.isInteger(16)) { + if (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) + return "int16_t"; + return "uint16_t"; } - if (offset == ShapedType::kDynamic) - return Value(); - for (int64_t s : strides) { - if (s == ShapedType::kDynamic) - return Value(); + if (elemTy.isInteger(32)) { + if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) + return "int32_t"; + return "uint32_t"; + } + if (elemTy.isInteger(64)) { + return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; } + return "float"; +} - // Apply static base offset if needed. - Value ptr = basePtr; - if (offset != 0) { - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - auto offVal = rewriter.create( - loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); - ptr = rewriter.create(loc, basePtr.getType(), basePtr, - offVal); +static bool hasStaticShape(MemRefType mrTy) { + return llvm::none_of(mrTy.getShape(), [](int64_t dim) { + return dim == ShapedType::kDynamic; + }); +} + +static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, + int64_t &offset) { + if (failed(getStridesAndOffset(mrTy, strides, offset))) { + strides.clear(); + int64_t stride = 1; + ArrayRef shape = mrTy.getShape(); + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides.push_back(stride); + stride *= shape[i]; + } + std::reverse(strides.begin(), strides.end()); + offset = 0; } + return offset != ShapedType::kDynamic && + llvm::none_of(strides, [](int64_t strideValue) { + return strideValue == ShapedType::kDynamic; + }); +} - std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); - std::string shapeTypeName = "GTShape" + suffix; - std::string strideTypeName = "GTStride" + suffix; - std::string gtTypeName = "GT" + suffix; +static Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + int64_t offset) { + if (offset == 0) + return basePtr; + auto *ctx = rewriter.getContext(); + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + auto offVal = rewriter.create( + loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); + return rewriter.create(loc, basePtr.getType(), basePtr, offVal); +} - std::string elemTypeStr = getEmitCScalarTypeToken(mrTy.getElementType()); +static int getGlobalTensorElementBytes(StringRef elemTypeStr) { + if (elemTypeStr.contains("half") || elemTypeStr.contains("bf16")) + return 2; + if (elemTypeStr.contains("double")) + return 8; + return 4; +} - SmallVector shapeParamsVec; - SmallVector strideParamsVec; - for (int i = 0, e = (int)shape.size(); i < e; ++i) { - shapeParamsVec.push_back(std::to_string(shape[i])); - strideParamsVec.push_back(std::to_string(strides[i])); - } +static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { + if (lhs < 0 || rhs < 0) + return -1; + return lhs * rhs; +} - // Right-align to 5D (pad leading dims with 1). - SmallVector finalShape(5, "1"); - SmallVector finalStride(5, "1"); - int rank = (int)shape.size(); +static void buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D) { + shape5D.assign(5, 1); + stride5D.assign(5, 1); + int rank = static_cast(shape.size()); int shift = 5 - rank; for (int i = 0; i < rank && i < 5; ++i) { - finalShape[shift + i] = shapeParamsVec[i]; - finalStride[shift + i] = strideParamsVec[i]; - } - auto mulOrDyn = [](const std::string &a, const std::string &b) -> std::string { - if (a == "-1" || b == "-1") - return "-1"; - int64_t va = 1, vb = 1; - (void)llvm::to_integer(a, va); - (void)llvm::to_integer(b, vb); - return std::to_string(va * vb); - }; + shape5D[shift + i] = shape[i]; + stride5D[shift + i] = strides[i]; + } for (int i = 3; i >= 0; --i) { if (i >= shift) continue; - finalStride[i] = mulOrDyn(finalShape[i + 1], finalStride[i + 1]); + stride5D[i] = multiplyOrDynamic(shape5D[i + 1], stride5D[i + 1]); } +} - auto joinParams = [](llvm::ArrayRef vec) { - std::string out; - for (size_t i = 0; i < vec.size(); ++i) { - if (i > 0) out += ", "; - out += vec[i]; - } - return out; +static std::string joinIntTemplateParams(ArrayRef values) { + std::string result; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) + result += ", "; + result += std::to_string(values[i]); + } + return result; +} + +static std::string inferFallbackGlobalTensorLayout(ArrayRef shape5D, + ArrayRef stride5D, + StringRef elemTypeStr) { + int elemBytes = getGlobalTensorElementBytes(elemTypeStr); + if (shape5D[2] == 16 && multiplyOrDynamic(shape5D[2], shape5D[3]) * elemBytes == 512 && + stride5D[4] == 1 && stride5D[3] == shape5D[4]) { + return "pto::Layout::NZ"; + } + + bool isRowMajor = stride5D[4] == 1; + for (int i = 3; i >= 0 && isRowMajor; --i) + isRowMajor = stride5D[i] == multiplyOrDynamic(stride5D[i + 1], shape5D[i + 1]); + + bool isColMajor = stride5D[0] == 1; + for (int i = 0; i < 4 && isColMajor; ++i) + isColMajor = stride5D[i + 1] == multiplyOrDynamic(stride5D[i], shape5D[i]); + + if (isColMajor) + return "pto::Layout::DN"; + return isRowMajor ? "pto::Layout::ND" : "pto::Layout::ND"; +} + +static std::string resolveGlobalTensorLayout(Operation *anchor, Value basePtr, + ArrayRef shape5D, + ArrayRef stride5D, + StringRef elemTypeStr) { + if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) + return layoutToEmitCString(*layout); + return inferFallbackGlobalTensorLayout(shape5D, stride5D, elemTypeStr); +} + +struct GlobalTensorTypeNames { + std::string shapeTypeName; + std::string strideTypeName; + std::string tensorTypeName; + std::string layoutConstName; +}; + +static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { + std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); + return { + "GTShape" + suffix, + "GTStride" + suffix, + "GT" + suffix, + "GT" + suffix + "_layout", }; +} +static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, + Operation *anchor) { + auto *ctx = rewriter.getContext(); - std::string shapeParams = joinParams(finalShape); - std::string strideParams = joinParams(finalStride); + ArrayRef shape = mrTy.getShape(); + if (!hasStaticShape(mrTy)) + return Value(); + + SmallVector strides; + int64_t offset = 0; + if (!getStaticMemrefLayout(mrTy, strides, offset)) + return Value(); + + Value ptr = applyStaticMemrefOffset(rewriter, loc, basePtr, offset); + GlobalTensorTypeNames names = getGlobalTensorTypeNames(anchor); + std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); rewriter.create( - loc, "using " + shapeTypeName + " = pto::Shape<" + shapeParams + ">;"); - rewriter.create( - loc, "using " + strideTypeName + " = pto::Stride<" + strideParams + ">;"); - - // Layout: prefer the attribute from InferPTOLayout; only fall back to local - // inference when the pass is disabled. - std::string layoutEnum = "pto::Layout::ND"; - bool hasLayoutAttr = false; - if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) { - layoutEnum = layoutToEmitCString(*layout); - hasLayoutAttr = true; - } - if (!hasLayoutAttr) { - SmallVector shapeInt(5, -1), strideInt(5, -1); - for (int i = 0; i < 5; ++i) { - (void)llvm::to_integer(finalShape[i], shapeInt[i]); - (void)llvm::to_integer(finalStride[i], strideInt[i]); - } - int layoutTag = 0; // ND - int elemBytes = 4; - if (elemTypeStr.find("half") != std::string::npos || - elemTypeStr.find("bf16") != std::string::npos) - elemBytes = 2; - else if (elemTypeStr.find("double") != std::string::npos) - elemBytes = 8; - if (shapeInt[2] == 16 && shapeInt[2] * shapeInt[3] * elemBytes == 512 && - strideInt[4] == 1 && strideInt[3] == shapeInt[4]) { - layoutTag = 2; // NZ - } else { - bool isRow = strideInt[4] == 1; - for (int i = 3; i >= 0; --i) - isRow &= (strideInt[i] == strideInt[i + 1] * shapeInt[i + 1]); - bool isCol = strideInt[0] == 1; - for (int i = 0; i < 4; ++i) - isCol &= (strideInt[i + 1] == strideInt[i] * shapeInt[i]); - if (isCol) layoutTag = 1; // DN - else layoutTag = isRow ? 0 : 0; // fallback ND - } - if (layoutTag == 1) - layoutEnum = "pto::Layout::DN"; - else if (layoutTag == 2) - layoutEnum = "pto::Layout::NZ"; - } - std::string layoutConstName = gtTypeName + "_layout"; + loc, "using " + names.shapeTypeName + " = pto::Shape<" + + joinIntTemplateParams(shape5D) + ">;"); rewriter.create( - loc, "constexpr pto::Layout " + layoutConstName + " = " + layoutEnum + ";"); + loc, "using " + names.strideTypeName + " = pto::Stride<" + + joinIntTemplateParams(stride5D) + ">;"); - auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeTypeName); - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideTypeName); + std::string layoutEnum = resolveGlobalTensorLayout(anchor, basePtr, shape5D, + stride5D, elemTypeStr); + rewriter.create(loc, "constexpr pto::Layout " + + names.layoutConstName + " = " + + layoutEnum + ";"); + + auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, names.shapeTypeName); + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, names.strideTypeName); auto shapeInstOp = rewriter.create( - loc, shapeTypeOpaque, shapeTypeName, ArrayAttr{}, ArrayAttr{}, + loc, shapeTypeOpaque, names.shapeTypeName, ArrayAttr{}, ArrayAttr{}, ValueRange{}); auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, strideTypeName, ArrayAttr{}, ArrayAttr{}, + loc, strideTypeOpaque, names.strideTypeName, ArrayAttr{}, ArrayAttr{}, ValueRange{}); rewriter.create( - loc, "using " + gtTypeName + " = GlobalTensor<" + elemTypeStr + ", " + - shapeTypeName + ", " + strideTypeName + ", " + - layoutConstName + ">;"); - auto gtType = emitc::OpaqueType::get(ctx, gtTypeName); + loc, "using " + names.tensorTypeName + " = GlobalTensor<" + elemTypeStr + + ", " + names.shapeTypeName + ", " + names.strideTypeName + + ", " + names.layoutConstName + ">;"); + auto gtType = emitc::OpaqueType::get(ctx, names.tensorTypeName); SmallVector gtArgs; gtArgs.push_back(ptr); @@ -3254,7 +3332,8 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, gtArgs.push_back(strideInstOp.getResult(0)); auto gtInst = rewriter.create( - loc, gtType, gtTypeName, ArrayAttr{}, ArrayAttr{}, ValueRange(gtArgs)); + loc, gtType, names.tensorTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange(gtArgs)); return gtInst.getResult(0); } @@ -4154,98 +4233,122 @@ struct PTOBarrierToEmitC : public OpConversionPattern { // Replace your PTOSyncToRuntimeCall with the code below. //===----------------------------------------------------------------------===// -static LogicalResult extractSyncTripletTokens(Operation *op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - auto *ctx = rewriter.getContext(); - - auto pipeToTok = [](mlir::Attribute a, std::string &out) -> bool { - if (!a) return false; - if (auto p = dyn_cast(a)) { - out = mlir::pto::stringifyPIPE(p.getPipe()).str(); - return true; - } - if (auto s = dyn_cast(a)) { - out = s.getValue().str(); // expects already like "PIPE_MTE2" - return true; - } +static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { + if (!attr) return false; - }; + if (auto pipe = dyn_cast(attr)) { + token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} - auto evtToTok = [](mlir::Attribute a, std::string &out) -> bool { - if (!a) return false; - if (auto e = dyn_cast(a)) { - out = mlir::pto::stringifyEVENT(e.getEvent()).str(); - return true; - } - if (auto s = dyn_cast(a)) { - out = s.getValue().str(); // expects already like "EVENT_ID0" - return true; - } +static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { + if (!attr) return false; - }; - - auto tryNamed = [&](StringRef s0, StringRef s1, StringRef e0) -> bool { - std::string st, dt, et; - if (!pipeToTok(op->getAttr(s0), st)) return false; - if (!pipeToTok(op->getAttr(s1), dt)) return false; - if (!evtToTok(op->getAttr(e0), et)) return false; - srcTok = std::move(st); - dstTok = std::move(dt); - evtTok = std::move(et); + if (auto event = dyn_cast(attr)) { + token = mlir::pto::stringifyEVENT(event.getEvent()).str(); return true; - }; - - // 1) Most common named-attr encodings - if (tryNamed("src_pipe", "dst_pipe", "event_id")) return success(); - if (tryNamed("srcPipe", "dstPipe", "eventId")) return success(); - if (tryNamed("src", "dst", "event")) return success(); - - // 2) Bracket-form / custom-asm often packs them into an ArrayAttr under some key - auto tryArrayKey = [&](StringRef key) -> bool { - auto arr = op->getAttrOfType(key); - if (!arr || arr.size() < 3) return false; - - std::string st, dt, et; - if (!pipeToTok(arr[0], st)) return false; - if (!pipeToTok(arr[1], dt)) return false; - if (!evtToTok(arr[2], et)) return false; - srcTok = std::move(st); - dstTok = std::move(dt); - evtTok = std::move(et); + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); return true; - }; + } + return false; +} - if (tryArrayKey("args") || tryArrayKey("pipes") || tryArrayKey("sync") || - tryArrayKey("triplet") || tryArrayKey("attrs")) - return success(); +static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, + Attribute evtAttr, std::string &srcTok, + std::string &dstTok, std::string &evtTok) { + std::string localSrc; + std::string localDst; + std::string localEvt; + if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || + !tryConvertPipeAttrToToken(dstAttr, localDst) || + !tryConvertEventAttrToToken(evtAttr, localEvt)) { + return false; + } + srcTok = std::move(localSrc); + dstTok = std::move(localDst); + evtTok = std::move(localEvt); + return true; +} - // 3) Last resort: scan everything and pick 2 Pipe + 1 Event in encounter order. - std::vector pipes; +static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, + StringRef srcName, + StringRef dstName, + StringRef evtName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), + op->getAttr(evtName), srcTok, dstTok, evtTok); +} + +static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + auto arrayAttr = op->getAttrOfType(attrName); + if (!arrayAttr || arrayAttr.size() < 3) + return false; + return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, + dstTok, evtTok); +} + +static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + SmallVector pipes; std::string event; - for (auto &na : op->getAttrs()) { - Attribute a = na.getValue(); - std::string tok; - if (pipeToTok(a, tok)) { - pipes.push_back(std::move(tok)); + for (NamedAttribute namedAttr : op->getAttrs()) { + std::string token; + if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { + pipes.push_back(std::move(token)); continue; } - if (evtToTok(a, tok)) { - event = std::move(tok); - continue; + if (event.empty() && + tryConvertEventAttrToToken(namedAttr.getValue(), token)) { + event = std::move(token); } } + if (pipes.size() < 2 || event.empty()) + return false; + srcTok = pipes[0]; + dstTok = pipes[1]; + evtTok = event; + return true; +} - if (pipes.size() >= 2 && !event.empty()) { - srcTok = pipes[0]; - dstTok = pipes[1]; - evtTok = event; +static LogicalResult extractSyncTripletTokens(Operation *op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, + dstTok, evtTok)) { return success(); } - return rewriter.notifyMatchFailure(op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); + for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { + if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, + evtTok)) { + return success(); + } + } + + if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) + return success(); + return rewriter.notifyMatchFailure( + op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); } static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { return mlir::pto::stringifyPIPE(p).str(); @@ -9044,6 +9147,55 @@ struct SCFIndexSwitchToCF : public OpRewritePattern { return success(); } + static Block *splitBlockForContinuation(PatternRewriter &rewriter, + scf::IndexSwitchOp op) { + auto switchIt = Block::iterator(op.getOperation()); + return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); + } + + static void addContinuationArguments(PatternRewriter &rewriter, + scf::IndexSwitchOp op, Location loc, + Block *continueBlock) { + SmallVector contArgs; + contArgs.reserve(op.getNumResults()); + for (Type type : op.getResultTypes()) + contArgs.push_back(continueBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) + result.value().replaceAllUsesWith(contArgs[result.index()]); + } + + static void createIndexSwitchBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Region::iterator insertPt, + unsigned numCases, + SmallVectorImpl &checkBlocks, + Block *&defaultBlock, + SmallVectorImpl &caseBlocks) { + checkBlocks.reserve(numCases); + caseBlocks.reserve(numCases); + for (unsigned i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + defaultBlock = rewriter.createBlock(parentRegion, insertPt); + for (unsigned i = 0; i < numCases; ++i) + caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + } + + static void populateIndexSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value selector, + ArrayRef cases, ArrayRef checkBlocks, + ArrayRef caseBlocks, Block *defaultBlock) { + for (unsigned i = 0; i < checkBlocks.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + Value caseVal = rewriter.create(loc, cases[i]); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, selector, caseVal); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultBlock; + rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, + falseDest, ValueRange{}); + } + } + LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ -9055,50 +9207,22 @@ struct SCFIndexSwitchToCF : public OpRewritePattern { Block *curBlock = op->getBlock(); Region *parentRegion = curBlock->getParent(); - - // Split the parent block so we can branch to a continuation block with phi - // arguments for the switch results. - auto switchIt = Block::iterator(op.getOperation()); - Block *continueBlock = rewriter.splitBlock(curBlock, std::next(switchIt)); - - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type t : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(t, loc)); - - for (auto it : llvm::enumerate(op.getResults())) - it.value().replaceAllUsesWith(contArgs[it.index()]); + Block *continueBlock = splitBlockForContinuation(rewriter, op); + addContinuationArguments(rewriter, op, loc, continueBlock); unsigned numCases = op.getCases().size(); auto insertPt = continueBlock->getIterator(); SmallVector checkBlocks; SmallVector caseBlocks; - checkBlocks.reserve(numCases); - caseBlocks.reserve(numCases); - - // Create check blocks for each case: check_i compares selector to case_i. - for (unsigned i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - - // Create one block for default and one block per case to execute the body. - Block *defaultBlock = rewriter.createBlock(parentRegion, insertPt); - for (unsigned i = 0; i < numCases; ++i) - caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + Block *defaultBlock = nullptr; + createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, + checkBlocks, defaultBlock, caseBlocks); Value selector = op.getArg(); auto cases = op.getCases(); - - // Fill check blocks with chained comparisons. - for (unsigned i = 0; i < numCases; ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - Value caseVal = rewriter.create(loc, cases[i]); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, selector, caseVal); - Block *falseDest = (i + 1 < numCases) ? checkBlocks[i + 1] : defaultBlock; - rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, - falseDest, ValueRange{}); - } + populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, + caseBlocks, defaultBlock); // Fill case blocks and default block with cloned bodies + branch to cont. for (unsigned i = 0; i < numCases; ++i) { @@ -9127,85 +9251,100 @@ struct SCFIndexSwitchToCF : public OpRewritePattern { struct SCFWhileToCF : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(scf::WhileOp op, - PatternRewriter &rewriter) const override { - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.while inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - - // Only support the common structured form where the while results are used - // in the same block after the op. - for (Value res : op.getResults()) { - for (auto &use : res.getUses()) { - if (use.getOwner()->getBlock() != curBlock) - return rewriter.notifyMatchFailure( - op, "unsupported: while results used outside the parent block"); + static LogicalResult validateWhileResultUses(scf::WhileOp op) { + Block *parentBlock = op->getBlock(); + for (Value result : op.getResults()) { + for (OpOperand &use : result.getUses()) { + if (use.getOwner()->getBlock() != parentBlock) + return failure(); } } + return success(); + } - auto loc = op.getLoc(); + static Block *splitAfterWhileBlock(PatternRewriter &rewriter, + scf::WhileOp op) { auto whileIt = Block::iterator(op.getOperation()); - Block *afterWhileBlock = rewriter.splitBlock(curBlock, std::next(whileIt)); + return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); + } - // Add block args to carry while results into the continuation block. + static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { SmallVector exitArgs; exitArgs.reserve(op.getNumResults()); - for (Type t : op.getResultTypes()) - exitArgs.push_back(afterWhileBlock->addArgument(t, loc)); - - for (auto it : llvm::enumerate(op.getResults())) - it.value().replaceAllUsesWith(exitArgs[it.index()]); - - // Create the CFG blocks before the continuation block. - Region *parentRegion = curBlock->getParent(); - auto insertPt = afterWhileBlock->getIterator(); + for (Type type : op.getResultTypes()) + exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) + result.value().replaceAllUsesWith(exitArgs[result.index()]); + } - // Header block arguments match the while init operands. + static Block *createWhileHeaderBlock(PatternRewriter &rewriter, + scf::WhileOp op, Location loc, + Block *afterWhileBlock) { SmallVector headerArgTypes; - for (Value v : op.getInits()) - headerArgTypes.push_back(v.getType()); + for (Value init : op.getInits()) + headerArgTypes.push_back(init.getType()); SmallVector headerArgLocs(headerArgTypes.size(), loc); - Block *headerBlock = - rewriter.createBlock(parentRegion, insertPt, headerArgTypes, - headerArgLocs); + return rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), headerArgTypes, + headerArgLocs); + } - // Body block arguments match the "after" region arguments. + static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { Block &afterRegionBlock = op.getAfter().front(); SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), - afterRegionBlock.getArgumentTypes().end()); + afterRegionBlock.getArgumentTypes().end()); SmallVector bodyArgLocs(bodyArgTypes.size(), loc); - insertPt = afterWhileBlock->getIterator(); - Block *bodyBlock = - rewriter.createBlock(parentRegion, insertPt, bodyArgTypes, bodyArgLocs); + return rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), bodyArgTypes, + bodyArgLocs); + } + + static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, + Block *headerBlock, Block *bodyBlock, + Block *afterWhileBlock) { + auto condOp = cast(headerBlock->getTerminator()); + rewriter.setInsertionPoint(condOp); + rewriter.create(loc, condOp.getCondition(), + /*trueDest=*/bodyBlock, + /*trueOperands=*/condOp.getArgs(), + /*falseDest=*/afterWhileBlock, + /*falseOperands=*/condOp.getArgs()); + rewriter.eraseOp(condOp); + + auto yieldOp = cast(bodyBlock->getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.create(loc, headerBlock, yieldOp.getOperands()); + rewriter.eraseOp(yieldOp); + } + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.while inside a single-block parent region"); + } + + if (failed(validateWhileResultUses(op))) + return rewriter.notifyMatchFailure( + op, "unsupported: while results used outside the parent block"); + + auto loc = op.getLoc(); + Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); + addWhileExitArguments(rewriter, op, loc, afterWhileBlock); + Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, + afterWhileBlock); + Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); // Move the before/after region bodies into the new CFG blocks. + Block &afterRegionBlock = op.getAfter().front(); rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, headerBlock->getArguments()); rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); - - // Replace scf.condition in the header with cf.cond_br. - { - auto condOp = cast(headerBlock->getTerminator()); - rewriter.setInsertionPoint(condOp); - rewriter.create(loc, condOp.getCondition(), - /*trueDest=*/bodyBlock, - /*trueOperands=*/condOp.getArgs(), - /*falseDest=*/afterWhileBlock, - /*falseOperands=*/condOp.getArgs()); - rewriter.eraseOp(condOp); - } - - // Replace scf.yield in the body with cf.br back to the header. - { - auto yieldOp = cast(bodyBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - rewriter.create(loc, headerBlock, yieldOp.getOperands()); - rewriter.eraseOp(yieldOp); - } + rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, + afterWhileBlock); // Replace scf.while itself with a branch to the header. rewriter.setInsertionPoint(op); @@ -9221,6 +9360,65 @@ struct SCFWhileToCF : public OpRewritePattern { struct CFSwitchToCondBr : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + static SmallVector> + collectSwitchCaseOperands(cf::SwitchOp op) { + SmallVector> caseOperands; + caseOperands.reserve(op.getCaseDestinations().size()); + for (auto range : op.getCaseOperands()) + caseOperands.emplace_back(range.begin(), range.end()); + return caseOperands; + } + + static SmallVector getSwitchCaseValues(cf::SwitchOp op) { + SmallVector caseValues; + if (auto caseValuesAttr = op.getCaseValues()) { + for (APInt value : caseValuesAttr->getValues()) + caseValues.push_back(value); + } + return caseValues; + } + + static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Block *curBlock, + size_t numCases) { + auto insertPt = std::next(curBlock->getIterator()); + SmallVector checkBlocks; + checkBlocks.reserve(numCases); + for (size_t i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + return checkBlocks; + } + + static LogicalResult populateSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, + ArrayRef caseValues, ArrayRef caseDests, + ArrayRef> caseOperands, Block *defaultDest, + ValueRange defaultOperands, ArrayRef checkBlocks, + cf::SwitchOp op) { + for (size_t i = 0; i < caseDests.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + APInt caseVal = caseValues[i]; + if (caseVal.getBitWidth() != flagTy.getWidth()) { + return rewriter.notifyMatchFailure( + op, "case value bitwidth doesn't match flag type"); + } + + Value caseConst = rewriter.create( + loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, flag, caseConst); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; + ValueRange falseOperands = + (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; + rewriter.create(loc, cond, caseDests[i], + caseOperands[i], falseDest, + falseOperands); + } + return success(); + } + LogicalResult matchAndRewrite(cf::SwitchOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ -9244,62 +9442,30 @@ struct CFSwitchToCondBr : public OpRewritePattern { SmallVector caseDests(op.getCaseDestinations().begin(), op.getCaseDestinations().end()); - SmallVector> caseOperands; - caseOperands.reserve(caseDests.size()); - for (auto range : op.getCaseOperands()) - caseOperands.emplace_back(range.begin(), range.end()); + SmallVector> caseOperands = collectSwitchCaseOperands(op); if (caseDests.empty()) { rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); return success(); } - std::optional caseValuesAttr = op.getCaseValues(); - if (!caseValuesAttr) + if (!op.getCaseValues()) return rewriter.notifyMatchFailure(op, "missing case_values"); - - SmallVector caseValues; - for (APInt v : caseValuesAttr->getValues()) - caseValues.push_back(v); + SmallVector caseValues = getSwitchCaseValues(op); if (caseValues.size() != caseDests.size()) return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); if (caseOperands.size() != caseDests.size()) return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); - // Insert check blocks right after the current block. - auto insertPt = std::next(curBlock->getIterator()); - SmallVector checkBlocks; - checkBlocks.reserve(caseDests.size()); - for (size_t i = 0; i < caseDests.size(); ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - - // Fill each check block with: - // if (flag == caseVal_i) goto caseDest_i else goto nextCheck/default. - for (size_t i = 0; i < caseDests.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - - APInt caseVal = caseValues[i]; - if (caseVal.getBitWidth() != flagTy.getWidth()) { - return rewriter.notifyMatchFailure( - op, "case value bitwidth doesn't match flag type"); - } - - Value caseConst = rewriter.create( - loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, flag, caseConst); - - Block *falseDest = - (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; - ValueRange falseOperands = - (i + 1 < checkBlocks.size()) ? ValueRange{} : ValueRange(defaultOperands); - - rewriter.create(loc, cond, - /*trueDest=*/caseDests[i], - /*trueOperands=*/caseOperands[i], - /*falseDest=*/falseDest, - /*falseOperands=*/falseOperands); + SmallVector checkBlocks = + createSwitchCheckBlocks(rewriter, parentRegion, curBlock, + caseDests.size()); + if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, + caseValues, caseDests, caseOperands, + defaultDest, defaultOperands, + checkBlocks, op))) { + return failure(); } // Replace the switch terminator with a branch into the first check block. diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index de31262ba..f9d77d0a1 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -130,6 +130,13 @@ struct TileLayoutInfo { bool boxed = false; // slayout != NoneBox }; +struct TileLayoutConfig { + int32_t bLayout = 0; + int32_t sLayout = 0; + int32_t fractalSize = 512; + int32_t compactMode = 0; +}; + static int64_t getElemBytes(Type elemTy) { if (auto ft = elemTy.dyn_cast()) { if (ft.isF16() || ft.isBF16()) return 2; @@ -179,152 +186,232 @@ static bool readCompactModeI32(Attribute attr, int32_t &out) { return false; } -static bool getConstIndexValue(Value v, int64_t &out) { - if (auto cOp = v.getDefiningOp()) { - out = cOp.value(); +static Value peelIndexLikeCast(Value value) { + while (true) { + if (auto castOp = value.getDefiningOp()) { + value = castOp.getIn(); + continue; + } + if (auto extOp = value.getDefiningOp()) { + value = extOp.getIn(); + continue; + } + if (auto extOp = value.getDefiningOp()) { + value = extOp.getIn(); + continue; + } + if (auto truncOp = value.getDefiningOp()) { + value = truncOp.getIn(); + continue; + } + return value; + } +} + +static bool getConstIndexValue(Value value, int64_t &out) { + value = peelIndexLikeCast(value); + if (auto constIndex = value.getDefiningOp()) { + out = constIndex.value(); + return true; + } + if (auto constInt = value.getDefiningOp()) { + out = constInt.value(); return true; } - if (auto cInt = v.getDefiningOp()) { - out = cInt.value(); + auto constOp = value.getDefiningOp(); + auto intAttr = + constOp ? dyn_cast(constOp.getValue()) : IntegerAttr(); + if (!intAttr) + return false; + out = intAttr.getInt(); + return true; +} + +static TileLayoutConfig getTileLayoutConfig(mlir::pto::TileBufConfigAttr cfg) { + TileLayoutConfig config; + (void)readBLayoutI32(cfg.getBLayout(), config.bLayout); + (void)readSLayoutI32(cfg.getSLayout(), config.sLayout); + if (auto attr = dyn_cast(cfg.getSFractalSize())) + config.fractalSize = static_cast(attr.getInt()); + (void)readCompactModeI32(cfg.getCompactMode(), config.compactMode); + return config; +} + +static bool computeBoxInnerShape(const TileLayoutConfig &config, Type elemTy, + TileLayoutInfo &info) { + info.boxed = config.sLayout != 0; + if (!info.boxed) { + info.innerRows = 1; + info.innerCols = 1; return true; } - if (auto cOp = v.getDefiningOp()) { - if (auto ia = dyn_cast(cOp.getValue())) { - out = ia.getInt(); + + int64_t elemBytes = getElemBytes(elemTy); + if (elemBytes <= 0) + return false; + + switch (config.fractalSize) { + case 1024: + info.innerRows = 16; + info.innerCols = 16; + return true; + case 32: + info.innerRows = 16; + info.innerCols = 2; + return true; + case 512: + if (config.sLayout == 1) { + info.innerRows = 16; + info.innerCols = 32 / elemBytes; + return true; + } + if (config.sLayout == 2) { + info.innerRows = 32 / elemBytes; + info.innerCols = 16; return true; } + return false; + default: + return false; } - if (auto castOp = v.getDefiningOp()) - return getConstIndexValue(castOp.getIn(), out); - if (auto extOp = v.getDefiningOp()) - return getConstIndexValue(extOp.getIn(), out); - if (auto extOp = v.getDefiningOp()) - return getConstIndexValue(extOp.getIn(), out); - if (auto truncOp = v.getDefiningOp()) - return getConstIndexValue(truncOp.getIn(), out); - return false; } -static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, - ArrayRef shape, - TileLayoutInfo &info) { - if (shape.size() != 2) return false; - if (shape[0] == ShapedType::kDynamic || shape[1] == ShapedType::kDynamic) - return false; - +static bool computeTilePointerStrides(const TileLayoutConfig &config, + ArrayRef shape, + TileLayoutInfo &info) { int64_t rows = shape[0]; int64_t cols = shape[1]; - - int32_t bl = 0; // RowMajor - int32_t sl = 0; // NoneBox - int32_t fr = 512; - int32_t compact = 0; // Null - (void)readBLayoutI32(cfg.getBLayout(), bl); - (void)readSLayoutI32(cfg.getSLayout(), sl); - if (auto attr = dyn_cast(cfg.getSFractalSize())) fr = (int32_t)attr.getInt(); - (void)readCompactModeI32(cfg.getCompactMode(), compact); - - // CompactMode::RowPlusOne means adding one padded element in the major-stride - // dimension (the physically contiguous "row pitch" in row-major, or "column - // pitch" in col-major) to reduce bank conflicts on some vector paths. auto applyCompactToMajorStride = [&](int64_t majorStride) -> int64_t { - if (compact == 2) // CompactMode::RowPlusOne + if (config.compactMode == 2) return majorStride + 1; return majorStride; }; - - // Inner shape - if (sl == 0) { - info.innerRows = 1; - info.innerCols = 1; - info.boxed = false; - } else { - info.boxed = true; - int64_t elemBytes = getElemBytes(elemTy); - if (elemBytes <= 0) return false; - if (fr == 1024) { - info.innerRows = 16; - info.innerCols = 16; - } else if (fr == 32) { - info.innerRows = 16; - info.innerCols = 2; - } else if (fr == 512) { - if (sl == 1) { - info.innerRows = 16; - info.innerCols = 32 / elemBytes; - } else if (sl == 2) { - info.innerRows = 32 / elemBytes; - info.innerCols = 16; - } else { - return false; - } - } else { - return false; - } - } - - // Strides for pointer offset (block-aligned for boxed layouts). - if (sl == 0) { - if (bl == 1) { + if (!info.boxed) { + if (config.bLayout == 1) { info.rowStride = 1; info.colStride = applyCompactToMajorStride(rows); - } else { - info.rowStride = applyCompactToMajorStride(cols); - info.colStride = 1; - } - } else { - if (bl == 1) { - // ColMajor + InnerRowMajor (NZ) is supported. InnerColMajor is unsupported. - if (sl != 1) return false; - info.rowStride = info.innerCols; - info.colStride = applyCompactToMajorStride(rows); - } else { - // RowMajor (ZZ/ZN) - info.rowStride = applyCompactToMajorStride(cols); - info.colStride = info.innerRows; + return true; } + info.rowStride = applyCompactToMajorStride(cols); + info.colStride = 1; + return true; + } + + if (config.bLayout == 1) { + if (config.sLayout != 1) + return false; + info.rowStride = info.innerCols; + info.colStride = applyCompactToMajorStride(rows); + return true; } + info.rowStride = applyCompactToMajorStride(cols); + info.colStride = info.innerRows; return true; } -// Helper: 递归拆解 AffineExpr -static void flattenAddExpr(AffineExpr expr, SmallVectorImpl &terms) { - if (auto add = expr.dyn_cast()) { - if (add.getKind() == AffineExprKind::Add) { - flattenAddExpr(add.getLHS(), terms); - flattenAddExpr(add.getRHS(), terms); - return; +static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, + ArrayRef shape, + TileLayoutInfo &info) { + if (shape.size() != 2 || llvm::is_contained(shape, ShapedType::kDynamic)) + return false; + + TileLayoutConfig config = getTileLayoutConfig(cfg); + return computeBoxInnerShape(config, elemTy, info) && + computeTilePointerStrides(config, shape, info); +} + +static void collectAffineAddTerms(AffineExpr root, + SmallVectorImpl &terms) { + SmallVector pending{root}; + while (!pending.empty()) { + AffineExpr current = pending.pop_back_val(); + auto addExpr = current.dyn_cast(); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) { + terms.push_back(current); + continue; } + pending.push_back(addExpr.getRHS()); + pending.push_back(addExpr.getLHS()); + } +} + +static bool tryAssignAffineStride(AffineExpr expr, + MutableArrayRef strides) { + if (auto dim = expr.dyn_cast()) { + strides[dim.getPosition()] = 1; + return true; } - terms.push_back(expr); + + auto mulExpr = expr.dyn_cast(); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return false; + + auto assignStride = [&](AffineExpr dimExpr, + AffineExpr constantExpr) -> bool { + auto dim = dimExpr.dyn_cast(); + auto constant = constantExpr.dyn_cast(); + if (!dim || !constant) + return false; + strides[dim.getPosition()] = constant.getValue(); + return true; + }; + return assignStride(mulExpr.getLHS(), mulExpr.getRHS()) || + assignStride(mulExpr.getRHS(), mulExpr.getLHS()); } -// Helper: 从 AffineMap 提取 Strides -static void decomposeStridedLayout(AffineMap map, SmallVectorImpl &strides) { +static void decomposeStridedLayout(AffineMap map, + SmallVectorImpl &strides) { strides.assign(map.getNumDims(), 0); - if (map.getNumResults() != 1) return; - + if (map.getNumResults() != 1) + return; + SmallVector terms; - flattenAddExpr(map.getResult(0), terms); - - for (auto term : terms) { - if (auto mul = term.dyn_cast()) { - if (mul.getKind() == AffineExprKind::Mul) { - AffineExpr lhs = mul.getLHS(); - AffineExpr rhs = mul.getRHS(); - if (auto dim = lhs.dyn_cast()) { - if (auto cst = rhs.dyn_cast()) - strides[dim.getPosition()] = cst.getValue(); - } else if (auto dim = rhs.dyn_cast()) { - if (auto cst = lhs.dyn_cast()) - strides[dim.getPosition()] = cst.getValue(); - } - } - } else if (auto dim = term.dyn_cast()) { - strides[dim.getPosition()] = 1; - } + collectAffineAddTerms(map.getResult(0), terms); + for (AffineExpr term : terms) + (void)tryAssignAffineStride(term, strides); +} + +static Value makeIndexConstant(IRRewriter &rewriter, Location loc, + int64_t value) { + return rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(value)); +} + +static SmallVector computeCompactStrides(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + int64_t stride = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides[i] = stride; + if (shape[i] != ShapedType::kDynamic) + stride *= shape[i]; } + return strides; +} + +static void materializeStaticValidDims(IRRewriter &rewriter, Location loc, + mlir::pto::TileBufType tbTy, Value &vRow, + Value &vCol) { + ArrayRef validShape = tbTy.getValidShape(); + if (tbTy.hasDynamicValid()) + return; + if (validShape.size() >= 1 && validShape[0] >= 0) + vRow = makeIndexConstant(rewriter, loc, validShape[0]); + if (validShape.size() >= 2 && validShape[1] >= 0) + vCol = makeIndexConstant(rewriter, loc, validShape[1]); +} + +static bool checkMultipleOf(Operation *op, int64_t value, int64_t divisor, + StringRef label) { + if (divisor <= 0) { + op->emitError("boxed layout requires positive divisor for ") << label; + return false; + } + if (value % divisor == 0) + return true; + op->emitError("boxed layout requires ") + << label << " multiple of " << divisor << ", got " << value; + return false; } // 确保 Value 是 Index 类型 @@ -517,822 +604,610 @@ static void markForceDynamicValidShape(Operation *op, bool force, op->removeAttr(kForceDynamicValidShapeAttrName); } -// ============================================================================= -// The Pass Implementation -// ============================================================================= +static void rewriteFunctionSignature(func::FuncOp func, MLIRContext *ctx) { + Block &entry = func.front(); + auto fnTy = func.getFunctionType(); -struct PTOViewToMemrefPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) + SmallVector newInputs; + for (Type type : fnTy.getInputs()) + newInputs.push_back(convertPTOTypeToMemRef(type)); - StringRef getArgument() const final { return "pto-view-to-memref"; } - StringRef getDescription() const final { - return "Lower PTO views to memref with Metadata Binding"; - } + SmallVector newResults; + for (Type type : fnTy.getResults()) + newResults.push_back(convertPTOTypeToMemRef(type)); - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (entry.getArgument(i).getType() != newInputs[i]) + entry.getArgument(i).setType(newInputs[i]); } + func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); +} - void runOnOperation() override { - ModuleOp mod = getOperation(); - MLIRContext *ctx = &getContext(); - - for (auto func : mod.getOps()) { - if (func.isExternal()) continue; - - Block &entry = func.front(); - auto fnTy = func.getFunctionType(); - - // ------------------------------------------------------------------ - // Stage 0: Rewrite Function Signature - // ------------------------------------------------------------------ - SmallVector newInputs; - for (Type t : fnTy.getInputs()) newInputs.push_back(convertPTOTypeToMemRef(t)); - - SmallVector newResults; - for (Type t : fnTy.getResults()) newResults.push_back(convertPTOTypeToMemRef(t)); +static LogicalResult lowerAllocTileOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector allocTiles; + func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); - // Update entry block arguments - for (unsigned i = 0; i < entry.getNumArguments(); ++i) { - if (entry.getArgument(i).getType() != newInputs[i]) { - entry.getArgument(i).setType(newInputs[i]); - } - } + for (auto op : allocTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); - // Update function type - func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) + continue; - // ------------------------------------------------------------------ - // Stage 0.5: lower pto.alloc_tile -> memref.alloc + pto.bind_tile - // ------------------------------------------------------------------ - SmallVector allocTiles; - func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); + SmallVector shape(tbTy.getShape().begin(), tbTy.getShape().end()); + Type elemTy = tbTy.getElementType(); - for (auto op : allocTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); + SmallVector strides; + TileLayoutInfo info; + if (computeTileLayoutInfo(tbTy.getConfigAttr(), elemTy, shape, info)) + strides = {info.rowStride, info.colStride}; + else + strides = computeCompactStrides(shape); + + auto targetLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); + auto targetType = + MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + if (Value addr = op.getAddr()) { + auto pc = rewriter.create( + loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); + auto bindOp = rewriter.create( + loc, targetType, pc.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + continue; + } - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) continue; + auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); + auto allocType = + MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); + Value alloc = rewriter.create(loc, allocType); + auto bindOp = rewriter.create( + loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} - // 1. 获取 Shape 和 ElementType - SmallVector shape(tbTy.getShape().begin(), tbTy.getShape().end()); - Type elemTy = tbTy.getElementType(); +static LogicalResult lowerDeclareTileOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector declaredTiles; + func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); - // 2. 计算 Strides (layout-aware when possible) - SmallVector strides; - TileLayoutInfo info; - if (computeTileLayoutInfo(tbTy.getConfigAttr(), elemTy, shape, info)) { - strides = {info.rowStride, info.colStride}; - } else { - strides.resize(shape.size()); - int64_t s = 1; - for (int i = (int)shape.size() - 1; i >= 0; --i) { - strides[i] = s; - if (shape[i] != ShapedType::kDynamic) s *= shape[i]; - } - } + for (auto op : declaredTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); - // 3. 构造 [BindTile 输出] 的动态类型 (Offset: ?) - // 这必须与 convertPTOTypeToMemRef 返回的类型一致,以便与 Subview 兼容 - auto targetLayout = - StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); // offset = ? - auto targetType = - MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); - - // 4. Preserve tile valid dims (v_row / v_col). - // - // `pto.alloc_tile` encodes the valid shape in the result TileBufType - // (e.g. acc tile may be rows=16 but v_row=1). The alloc op itself does - // not necessarily carry explicit operands for static valid dims, so we - // must materialize them from the type to keep them through - // tile_buf -> memref lowering. - // - // For dynamically valid tiles (validShape == [-1, -1]), preserve the - // runtime operands if present. - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - ArrayRef validShape = tbTy.getValidShape(); - if (!tbTy.hasDynamicValid()) { - // TileBuf valid dims use a negative sentinel (e.g. '?' / -1), which is - // distinct from MLIR's ShapedType::kDynamic (INT64_MIN). Treat any - // negative value as dynamic here. - if (validShape.size() >= 1 && validShape[0] >= 0) { - vRow = rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(validShape[0])) - .getResult(); - } - if (validShape.size() >= 2 && validShape[1] >= 0) { - vCol = rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(validShape[1])) - .getResult(); - } - } - - // 5. 获取 Config (保持不变) - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - // 6. If alloc_tile provides an explicit address, keep the original - // pointer_cast lowering intact and additionally rebind through - // pto.bind_tile. PointerCastOp continues to carry the tile metadata - // used by existing lowering paths, while BindTileOp provides the - // unified anchor EmitC uses to recover tile_buf information. - if (Value addr = op.getAddr()) { - auto pc = rewriter.create( - loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); - auto bindOp = rewriter.create( - loc, targetType, pc.getResult(), vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - continue; - } + auto tbTy = dyn_cast(op.getTile().getType()); + if (!tbTy) { + op.emitError("declare_tile result must be tile_buf type"); + return failure(); + } - // 7. Otherwise allocate a concrete memref buffer and bind tile. - // memref.alloc 要求明确的 layout,不能是动态 offset。 - auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); // offset = 0 - auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); - Value alloc = rewriter.create(loc, allocType); + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + op.emitError("failed to convert declare_tile result to memref type"); + return failure(); + } - // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 - auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + Value vRow; + Value vCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto declaredMemRef = + rewriter.create(loc, targetType); + auto bindOp = rewriter.create( + loc, targetType, declaredMemRef.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} - rewriter.replaceOp(op, bindOp.getResult()); +static LogicalResult lowerMakeTensorViewOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector makeViews; + func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); + + for (auto op : makeViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value baseBuf = op.getOperand(0); + OpFoldResult off0 = rewriter.getIndexAttr(0); + bool foldedAddPtr = false; + { + Value cur = baseBuf; + Value totalOffset; + while (auto add = cur.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + totalOffset = totalOffset ? rewriter.create(loc, totalOffset, off) + : off; + cur = add.getOperand(0); + } + if (cur != baseBuf) { + baseBuf = cur; + off0 = totalOffset ? OpFoldResult(totalOffset) : off0; } + } - // ------------------------------------------------------------------ - // Stage 0.75: lower pto.declare_tile -> pto.declare_tile_memref + - // pto.bind_tile - // ------------------------------------------------------------------ - SmallVector declaredTiles; - func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); + auto baseMr = dyn_cast(baseBuf.getType()); + if (!baseMr) { + op.emitError("make_tensor_view base must be memref"); + return failure(); + } - for (auto op : declaredTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); + size_t rank = op.getShape().size(); + int64_t dyn = ShapedType::kDynamic; + SmallVector dynStrides(rank, dyn); + auto layout = + StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); + SmallVector dynShape(rank, dyn); + auto mrTy = MemRefType::get(dynShape, baseMr.getElementType(), layout, + baseMr.getMemorySpace()); + + SmallVector sizes; + for (Value value : op.getShape()) + sizes.push_back(ensureIndex(rewriter, loc, value, op)); + SmallVector strides; + for (Value value : op.getStrides()) + strides.push_back(ensureIndex(rewriter, loc, value, op)); + + auto rc = rewriter.create(loc, mrTy, baseBuf, off0, + sizes, strides); + if (foldedAddPtr) + rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); + if (auto layoutAttr = op.getLayoutAttr()) + rc->setAttr("layout", layoutAttr); + rewriter.replaceOp(op, rc.getResult()); + } + return success(); +} - auto tbTy = dyn_cast(op.getTile().getType()); - if (!tbTy) { - op.emitError("declare_tile result must be tile_buf type"); - signalPassFailure(); - return; - } +static LogicalResult lowerTensorViewDimOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector tvDims; + func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - op.emitError("failed to convert declare_tile result to memref type"); - signalPassFailure(); - return; - } + for (auto op : tvDims) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; + Value dim = rewriter.create(op.getLoc(), view, op.getDimIndex()); + rewriter.replaceOp(op, dim); + } + return success(); +} - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - Value vRow; - Value vCol; - ArrayRef validShape = tbTy.getValidShape(); - if (!tbTy.hasDynamicValid()) { - if (validShape.size() >= 1 && validShape[0] >= 0) { - vRow = rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(validShape[0])) - .getResult(); - } - if (validShape.size() >= 2 && validShape[1] >= 0) { - vCol = rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(validShape[1])) - .getResult(); - } - } +static LogicalResult foldAddPtrIntoScalarOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector loadScalars; + func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); + for (auto op : loadScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + bool foldedAddPtr = false; + while (auto add = base.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + totalOffset = totalOffset ? rewriter.create(loc, totalOffset, off) + : off; + base = add.getOperand(0); + } + if (foldedAddPtr) { + auto newOp = + rewriter.create(loc, op.getValue().getType(), base, + totalOffset); + rewriter.replaceOp(op, newOp.getValue()); + } + } - auto declaredMemRef = - rewriter.create(loc, targetType); - auto bindOp = rewriter.create( - loc, targetType, declaredMemRef.getResult(), - vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + SmallVector storeScalars; + func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); + for (auto op : storeScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + bool foldedAddPtr = false; + while (auto add = base.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + totalOffset = totalOffset ? rewriter.create(loc, totalOffset, off) + : off; + base = add.getOperand(0); + } + if (foldedAddPtr) { + rewriter.create(loc, base, totalOffset, op.getValue()); + rewriter.eraseOp(op); + } + } - rewriter.replaceOp(op, bindOp.getResult()); + SmallVector addPtrs; + func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); + bool changed = true; + while (changed) { + changed = false; + for (auto &op : addPtrs) { + if (!op) + continue; + if (op->use_empty()) { + op->erase(); + op = nullptr; + changed = true; } + } + } + for (Operation *op : addPtrs) { + if (!op) + continue; + op->emitError( + "addptr must feed make_tensor_view or load/store_scalar for lowering"); + return failure(); + } + return success(); +} - // ------------------------------------------------------------------ - // Stage 0.8: normalize pto.tassign result type to match tile operand - // after tile_buf -> memref lowering (required for verifier consistency). - // ------------------------------------------------------------------ - SmallVector tassignOps; - func.walk([&](mlir::pto::TAssignOp op) { tassignOps.push_back(op); }); - for (auto op : tassignOps) { - Type targetTy = op.getTile().getType(); - if (op.getResult().getType() == targetTy) - continue; - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - auto normalized = - rewriter.create(op.getLoc(), targetTy, op.getTile(), - op.getAddr()); - rewriter.replaceOp(op, normalized.getResult()); +static LogicalResult lowerPartitionViewOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector partitionViews; + func.walk([&](mlir::pto::PartitionViewOp op) { partitionViews.push_back(op); }); + + for (auto op : partitionViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + Value src = op.getOperand(0); + auto srcMrTy = dyn_cast(src.getType()); + int64_t rank = srcMrTy.getRank(); + + SmallVector staticSizes; + SmallVector mixedSizes; + for (Value size : op.getSizes()) { + IntegerAttr constAttr; + bool isStatic = false; + if (auto cOp = size.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cOp.value()); + isStatic = true; + } else if (auto cInt = size.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cInt.value()); + isStatic = true; + } + + if (isStatic) { + mixedSizes.push_back(constAttr); + staticSizes.push_back(constAttr.getInt()); + } else { + mixedSizes.push_back(ensureIndex(rewriter, loc, size, op)); + staticSizes.push_back(ShapedType::kDynamic); } + } - // ------------------------------------------------------------------ - // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast - // ------------------------------------------------------------------ - SmallVector makeViews; - func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); - - for (auto op : makeViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value baseBuf = op.getOperand(0); - OpFoldResult off0 = rewriter.getIndexAttr(0); - - // Fold pto.addptr chains into the view base to avoid nested reinterpret_cast. - bool foldedAddPtr = false; - { - Value cur = baseBuf; - Value totalOffset; - while (auto add = cur.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - cur = add.getOperand(0); - } - if (cur != baseBuf) { - baseBuf = cur; - off0 = totalOffset ? OpFoldResult(totalOffset) : off0; - } - } + SmallVector mixedOffsets; + for (Value offset : op.getOffsets()) { + IntegerAttr constAttr; + bool isStatic = false; + if (auto cOp = offset.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cOp.value()); + isStatic = true; + } else if (auto cInt = offset.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cInt.value()); + isStatic = true; + } + mixedOffsets.push_back(isStatic ? OpFoldResult(constAttr) + : OpFoldResult(ensureIndex(rewriter, loc, + offset, op))); + } - auto baseMr = dyn_cast(baseBuf.getType()); - if (!baseMr) { - op.emitError("make_tensor_view base must be memref"); signalPassFailure(); return; - } + int64_t dyn = ShapedType::kDynamic; + SmallVector dynStrides(rank, dyn); + auto layout = StridedLayoutAttr::get(ctx, dyn, dynStrides); + auto resTy = MemRefType::get(staticSizes, srcMrTy.getElementType(), layout, + srcMrTy.getMemorySpace()); + + SmallVector mixedStrides(rank, rewriter.getIndexAttr(1)); + auto sv = rewriter.create(loc, resTy, src, mixedOffsets, + mixedSizes, mixedStrides); + if (Operation *srcDef = src.getDefiningOp()) { + if (auto layoutAttr = srcDef->getAttrOfType("layout")) + sv->setAttr("layout", layoutAttr); + } + rewriter.replaceOp(op, sv.getResult()); + } + return success(); +} - // [修复] 获取动态 Rank (根据 shape 输入的数量) - size_t rank = op.getShape().size(); +static LogicalResult lowerSubsetOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector subsets; + func.walk([&](mlir::pto::SubsetOp op) { subsets.push_back(op); }); + + for (auto op : subsets) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + auto resultTileTy = dyn_cast(op.getResult().getType()); + Value src = op->getOperand(0); + auto srcMrTy = dyn_cast(src.getType()); + if (!srcMrTy) { + op.emitError("pto.subset source must be lowered to memref first"); + return failure(); + } - // Construct target type with dynamic offset/strides - Type elemTy = baseMr.getElementType(); - int64_t dyn = ShapedType::kDynamic; - - // [修复] 构建 N 维 Strided Layout - // strides 数组长度必须等于 rank - SmallVector dynStrides(rank, dyn); - auto layout = StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); - - // [修复] 构建 N 维 Shape - SmallVector dynShape(rank, dyn); - auto mrTy = MemRefType::get(dynShape, elemTy, layout, baseMr.getMemorySpace()); + ArrayAttr sizeAttr = op.getSizes(); + SmallVector staticSizes; + SmallVector mixedSizes; + for (Attribute attr : sizeAttr) { + int64_t size = cast(attr).getInt(); + staticSizes.push_back(size); + mixedSizes.push_back(rewriter.getIndexAttr(size)); + } - SmallVector sizes; - for (Value v : op.getShape()) sizes.push_back(ensureIndex(rewriter, loc, v, op)); + SmallVector mixedOffsets; + for (Value offset : op.getOffsets()) { + IntegerAttr constAttr; + bool isStatic = false; + if (auto cOp = offset.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cOp.value()); + isStatic = true; + } else if (auto cInt = offset.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cInt.value()); + isStatic = true; + } + mixedOffsets.push_back(isStatic ? OpFoldResult(constAttr) + : OpFoldResult(ensureIndex(rewriter, loc, + offset, op))); + } - SmallVector strides; - for (Value v : op.getStrides()) strides.push_back(ensureIndex(rewriter, loc, v, op)); + auto configAttr = lookupConfig(src); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); - auto rc = rewriter.create( - loc, mrTy, baseBuf, off0, sizes, strides); - if (foldedAddPtr) { - rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); - } - if (auto layoutAttr = op.getLayoutAttr()) { - rc->setAttr("layout", layoutAttr); - } + TileLayoutInfo layoutInfo; + if (!computeTileLayoutInfo(configAttr, srcMrTy.getElementType(), + srcMrTy.getShape(), layoutInfo)) { + op.emitError("unsupported tile layout for pto.subset"); + return failure(); + } - rewriter.replaceOp(op, rc.getResult()); + if (layoutInfo.boxed) { + if (staticSizes.size() != 2 || op.getOffsets().size() != 2) { + op.emitError("boxed layout subset expects 2D sizes/offsets"); + return failure(); } - - // ------------------------------------------------------------------ - // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim - // ------------------------------------------------------------------ - SmallVector tvDims; - func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - - for (auto op : tvDims) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value view = op.getTensorView(); - auto mrTy = dyn_cast(view.getType()); - if (!mrTy) - continue; // leave it to later passes if it hasn't been lowered yet - - Value dimIdx = op.getDimIndex(); - Value dim = rewriter.create(loc, view, dimIdx); - rewriter.replaceOp(op, dim); + if (!checkMultipleOf(op, staticSizes[0], layoutInfo.innerRows, "row size") || + !checkMultipleOf(op, staticSizes[1], layoutInfo.innerCols, "col size")) { + return failure(); } - // ------------------------------------------------------------------ - // Stage 1.5: Fold pto.addptr chains into load/store_scalar. - // ------------------------------------------------------------------ - SmallVector loadScalars; - func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); - - for (auto op : loadScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - - bool foldedAddPtr = false; - while (auto add = base.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - base = add.getOperand(0); - } - - if (foldedAddPtr) { - auto newOp = rewriter.create( - loc, op.getValue().getType(), base, totalOffset); - rewriter.replaceOp(op, newOp.getValue()); - } + int64_t off0 = 0; + int64_t off1 = 0; + bool off0Const = getConstIndexValue(op.getOffsets()[0], off0); + bool off1Const = getConstIndexValue(op.getOffsets()[1], off1); + if (off0Const && + !checkMultipleOf(op, off0, layoutInfo.innerRows, "row offset")) { + return failure(); } - - SmallVector storeScalars; - func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); - - for (auto op : storeScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - - bool foldedAddPtr = false; - while (auto add = base.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - base = add.getOperand(0); - } - - if (foldedAddPtr) { - rewriter.create( - loc, base, totalOffset, op.getValue()); - rewriter.eraseOp(op); - } + if (off1Const && + !checkMultipleOf(op, off1, layoutInfo.innerCols, "col offset")) { + return failure(); } - // Clean up: addptr should be folded into make_tensor_view. - SmallVector addPtrs; - func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); - bool changed = true; - while (changed) { - changed = false; - for (auto &op : addPtrs) { - if (!op) - continue; - if (op->use_empty()) { - op->erase(); - op = nullptr; - changed = true; + int32_t bl = 0; + (void)readBLayoutI32(configAttr.getBLayout(), bl); + auto srcShape = srcMrTy.getShape(); + if (srcShape.size() == 2) { + if (bl == 0) { + if (staticSizes[1] != srcShape[1]) { + op.emitError("boxed RowMajor subset must keep full cols"); + return failure(); } - } - } - for (auto *op : addPtrs) { - if (!op) - continue; - op->emitError("addptr must feed make_tensor_view or load/store_scalar for lowering"); - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 2: Lower pto.partition_tensor_view -> memref.subview - // ------------------------------------------------------------------ - SmallVector partitiontensorviews; - func.walk([&](mlir::pto::PartitionViewOp op) { partitiontensorviews.push_back(op); }); - - for (auto op : partitiontensorviews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - Value src = op.getOperand(0); - auto srcMrTy = dyn_cast(src.getType()); - int64_t rank = srcMrTy.getRank(); - - // ===================================================================== - // 1. 处理 Sizes (智能区分 Static/Dynamic) - // ===================================================================== - ValueRange sizeValues = op.getSizes(); - SmallVector staticSizes; // 用于构建 Result MemRefType - SmallVector mixedSizes; // 用于传给 memref.subview - - for (Value s : sizeValues) { - // [关键修改] 检查 Value 是否源自常量 Op - IntegerAttr constAttr; - bool isStatic = false; - - // 检查 arith.constant (index or int) - if (auto cOp = s.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cOp.value()); - isStatic = true; - } else if (auto cInt = s.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cInt.value()); - isStatic = true; - } - - if (isStatic) { - // Case A: 静态常量 -> 存 Attribute - mixedSizes.push_back(constAttr); - staticSizes.push_back(constAttr.getInt()); - } else { - // Case B: 动态变量 -> 存 Value - mixedSizes.push_back(ensureIndex(rewriter, loc, s, op)); - staticSizes.push_back(ShapedType::kDynamic); - } - } - - // ===================================================================== - // 2. 处理 Offsets (同样应用智能区分) - // ===================================================================== - // Offsets 也需要同样的逻辑,否则也会报类似的 mismatch - ValueRange offsValues = op.getOffsets(); - SmallVector mixedOffsets; - - for (Value o : offsValues) { - IntegerAttr constAttr; - bool isStatic = false; - - if (auto cOp = o.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cOp.value()); - isStatic = true; - } else if (auto cInt = o.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cInt.value()); - isStatic = true; - } - - if (isStatic) { - mixedOffsets.push_back(constAttr); - } else { - mixedOffsets.push_back(ensureIndex(rewriter, loc, o, op)); - } - } - - // ===================================================================== - // 3. 构建 Result MemRefType - // ===================================================================== - int64_t dyn = ShapedType::kDynamic; - SmallVector dynStrides(rank, dyn); - auto layout = StridedLayoutAttr::get(ctx, dyn, dynStrides); - - auto resTy = MemRefType::get(staticSizes, srcMrTy.getElementType(), layout, srcMrTy.getMemorySpace()); - - // ===================================================================== - // 4. 处理 Strides (默认全 1) - // ===================================================================== - SmallVector mixedStrides; - for (int i = 0; i < rank; ++i) { - mixedStrides.push_back(rewriter.getIndexAttr(1)); - } - - // ===================================================================== - // 5. 创建 memref.subview - // ===================================================================== - auto sv = rewriter.create( - loc, - resTy, - src, - mixedOffsets, - mixedSizes, - mixedStrides - ); - if (Operation *srcDef = src.getDefiningOp()) { - if (auto layoutAttr = srcDef->getAttrOfType("layout")) { - sv->setAttr("layout", layoutAttr); + if (!off1Const || off1 != 0) { + op.emitError("boxed RowMajor subset requires static col offset = 0"); + return failure(); } - } - - rewriter.replaceOp(op, sv.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 2.4: lower pto.subset -> memref.subview + bind_tile - // ------------------------------------------------------------------ - SmallVector subsets; - func.walk([&](mlir::pto::SubsetOp op) { subsets.push_back(op); }); - - for (auto op : subsets) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - auto resultTileTy = - dyn_cast(op.getResult().getType()); - - // 1. Source must be memref already - Value src = op->getOperand(0); - auto srcMrTy = dyn_cast(src.getType()); - if (!srcMrTy) { - op.emitError("pto.subset source must be lowered to memref first"); - signalPassFailure(); - return; - } - - // 2. Sizes (static) - ArrayAttr sizeAttr = op.getSizes(); - SmallVector staticSizes; - SmallVector mixedSizes; - staticSizes.reserve(sizeAttr.size()); - mixedSizes.reserve(sizeAttr.size()); - for (Attribute attr : sizeAttr) { - int64_t s = cast(attr).getInt(); - staticSizes.push_back(s); - mixedSizes.push_back(rewriter.getIndexAttr(s)); - } - - // 3. Offsets (mixed) - SmallVector mixedOffsets; - for (Value o : op.getOffsets()) { - IntegerAttr constAttr; - bool isStatic = false; - if (auto cOp = o.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cOp.value()); - isStatic = true; - } else if (auto cInt = o.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cInt.value()); - isStatic = true; + } else { + if (staticSizes[0] != srcShape[0]) { + op.emitError("boxed ColMajor subset must keep full rows"); + return failure(); } - if (isStatic) - mixedOffsets.push_back(constAttr); - else - mixedOffsets.push_back(ensureIndex(rewriter, loc, o, op)); - } - - // 3.1 Layout-aware checks for boxed tiles (SLayout != NoneBox) - auto configAttr = lookupConfig(src); - if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - TileLayoutInfo layoutInfo; - bool hasLayout = - computeTileLayoutInfo(configAttr, srcMrTy.getElementType(), - srcMrTy.getShape(), layoutInfo); - if (!hasLayout) { - op.emitError("unsupported tile layout for pto.subset"); - signalPassFailure(); - return; - } - - if (layoutInfo.boxed) { - if (staticSizes.size() != 2 || op.getOffsets().size() != 2) { - op.emitError("boxed layout subset expects 2D sizes/offsets"); - signalPassFailure(); - return; + if (!off0Const || off0 != 0) { + op.emitError("boxed ColMajor subset requires static row offset = 0"); + return failure(); } + } + } + } - auto checkMul = [&](int64_t v, int64_t m, StringRef name) -> bool { - if (m <= 0) return false; - if (v % m != 0) { - op.emitError("boxed layout requires ") << name << " multiple of " - << m << ", got " << v; - return false; - } - return true; - }; - - if (!checkMul(staticSizes[0], layoutInfo.innerRows, "row size") || - !checkMul(staticSizes[1], layoutInfo.innerCols, "col size")) { - signalPassFailure(); - return; - } + SmallVector srcStrides; + int64_t srcOffset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) + srcStrides = computeCompactStrides(srcMrTy.getShape()); + + auto resultLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides); + auto resultMemRefType = + MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout, + srcMrTy.getMemorySpace()); + SmallVector mixedStrides(staticSizes.size(), + rewriter.getIndexAttr(1)); + auto sv = rewriter.create(loc, resultMemRefType, src, + mixedOffsets, mixedSizes, + mixedStrides); + + Value parentVRow; + Value parentVCol; + lookupValidDims(src, parentVRow, parentVCol); + Value vRow; + Value vCol; + if (!staticSizes.empty()) + vRow = computeSubsetValidDim(rewriter, loc, parentVRow, op.getOffsets()[0], + staticSizes[0], op); + if (staticSizes.size() > 1) + vCol = computeSubsetValidDim(rewriter, loc, parentVCol, op.getOffsets()[1], + staticSizes[1], op); + + auto bindOp = rewriter.create( + loc, resultMemRefType, sv.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, + resultTileTy && resultTileTy.hasDynamicValid(), + ctx); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} - int64_t off0 = 0, off1 = 0; - bool off0Const = getConstIndexValue(op.getOffsets()[0], off0); - bool off1Const = getConstIndexValue(op.getOffsets()[1], off1); - if (off0Const) { - if (!checkMul(off0, layoutInfo.innerRows, "row offset")) { - signalPassFailure(); - return; - } - } - if (off1Const) { - if (!checkMul(off1, layoutInfo.innerCols, "col offset")) { - signalPassFailure(); - return; - } - } +static Value buildTileBufViewLikeValue(Operation *anchorOp, Value src, + mlir::pto::TileBufType tbTy, + StringRef viewSemantics, + MLIRContext *ctx) { + Location loc = anchorOp->getLoc(); + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(anchorOp); + + auto srcMrTy = dyn_cast(src.getType()); + if (!srcMrTy) { + anchorOp->emitError("tile_buf view op src must be lowered to memref first"); + return Value(); + } - int32_t bl = 0; - (void)readBLayoutI32(configAttr.getBLayout(), bl); - - auto srcShape = srcMrTy.getShape(); - if (srcShape.size() == 2) { - if (bl == 0) { - if (staticSizes[1] != srcShape[1]) { - op.emitError("boxed RowMajor subset must keep full cols"); - signalPassFailure(); - return; - } - if (!off1Const || off1 != 0) { - op.emitError("boxed RowMajor subset requires static col offset = 0"); - signalPassFailure(); - return; - } - } else { - if (staticSizes[0] != srcShape[0]) { - op.emitError("boxed ColMajor subset must keep full rows"); - signalPassFailure(); - return; - } - if (!off0Const || off0 != 0) { - op.emitError("boxed ColMajor subset requires static row offset = 0"); - signalPassFailure(); - return; - } - } - } - } + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + anchorOp->emitError("failed to convert tile_buf type to memref type"); + return Value(); + } + for (int64_t dim : targetType.getShape()) { + if (dim == ShapedType::kDynamic) { + anchorOp->emitError("dynamic shapes are not supported for tile_buf view ops"); + return Value(); + } + } - // 4. Result layout inherits source strides (offset is dynamic) - SmallVector srcStrides; - int64_t srcOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) { - // Fallback: compact row-major - auto shape = srcMrTy.getShape(); - srcStrides.resize(shape.size()); - int64_t s = 1; - for (int i = shape.size() - 1; i >= 0; --i) { - srcStrides[i] = s; - if (shape[i] != ShapedType::kDynamic) s *= shape[i]; - } - } - (void)srcOffset; - - auto resultLayout = StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides); - auto resultMemRefType = - MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout, - srcMrTy.getMemorySpace()); - - // 5. Strides for subview: keep same stride (use 1) - SmallVector mixedStrides; - mixedStrides.reserve(staticSizes.size()); - for (size_t i = 0; i < staticSizes.size(); ++i) - mixedStrides.push_back(rewriter.getIndexAttr(1)); - - auto sv = rewriter.create( - loc, resultMemRefType, src, mixedOffsets, mixedSizes, mixedStrides); - - // 6. Re-bind tile metadata (config + valid dims) - Value parentVRow; - Value parentVCol; - lookupValidDims(src, parentVRow, parentVCol); - - Value vRow; - Value vCol; - if (!staticSizes.empty()) - vRow = computeSubsetValidDim(rewriter, loc, parentVRow, - op.getOffsets()[0], staticSizes[0], op); - if (staticSizes.size() > 1) - vCol = computeSubsetValidDim(rewriter, loc, parentVCol, - op.getOffsets()[1], staticSizes[1], op); - - auto bindOp = rewriter.create( - loc, resultMemRefType, sv.getResult(), - vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, - resultTileTy && resultTileTy.hasDynamicValid(), - ctx); - - rewriter.replaceOp(op, bindOp.getResult()); - } + Value parentVRow; + Value parentVCol; + lookupValidDims(src, parentVRow, parentVCol); + Value vRow = parentVRow; + Value vCol = parentVCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + auto bindOp = rewriter.create( + loc, targetType, src, vRow ? vRow : Value(), vCol ? vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + if (!viewSemantics.empty()) + bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr(viewSemantics)); + return bindOp.getResult(); +} - // ------------------------------------------------------------------ - // Stage 2.75: Lower SSA tile_buf view ops (pto.treshape / pto.bitcast) - // ------------------------------------------------------------------ - auto lowerTileBufViewLike = [&](Operation *anchorOp, Value src, - mlir::pto::TileBufType tbTy, - StringRef viewSemantics) -> Value { - Location loc = anchorOp->getLoc(); - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(anchorOp); +static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx) { + SmallVector reshapes; + func.walk([&](mlir::pto::TReshapeOp op) { reshapes.push_back(op); }); + for (auto op : reshapes) { + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) { + op.emitError("treshape result must be tile_buf type"); + return failure(); + } + Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, + "treshape", ctx); + if (!lowered) + return failure(); + IRRewriter rewriter(ctx); + rewriter.replaceOp(op, lowered); + } - auto srcMrTy = dyn_cast(src.getType()); - if (!srcMrTy) { - anchorOp->emitError("tile_buf view op src must be lowered to memref first"); - signalPassFailure(); - return Value(); - } + SmallVector bitcasts; + func.walk([&](mlir::pto::BitcastOp op) { bitcasts.push_back(op); }); + for (auto op : bitcasts) { + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) { + op.emitError("bitcast result must be tile_buf type"); + return failure(); + } + Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, + "bitcast", ctx); + if (!lowered) + return failure(); + IRRewriter rewriter(ctx); + rewriter.replaceOp(op, lowered); + } + return success(); +} - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - anchorOp->emitError("failed to convert tile_buf type to memref type"); - signalPassFailure(); - return Value(); - } +// ============================================================================= +// The Pass Implementation +// ============================================================================= - // Require static shape for now (alloc_tile lowering also requires this). - for (int64_t d : targetType.getShape()) { - if (d == ShapedType::kDynamic) { - anchorOp->emitError("dynamic shapes are not supported for tile_buf view ops"); - signalPassFailure(); - return Value(); - } - } +struct PTOViewToMemrefPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) - // Re-bind (possibly-updated) tile metadata. - Value parentVRow; - Value parentVCol; - lookupValidDims(src, parentVRow, parentVCol); - - Value vRow = parentVRow; - Value vCol = parentVCol; - ArrayRef validShape = tbTy.getValidShape(); - if (!tbTy.hasDynamicValid()) { - if (validShape.size() >= 1 && validShape[0] >= 0) { - vRow = rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(validShape[0])) - .getResult(); - } - if (validShape.size() >= 2 && validShape[1] >= 0) { - vCol = rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(validShape[1])) - .getResult(); - } - } + StringRef getArgument() const final { return "pto-view-to-memref"; } + StringRef getDescription() const final { + return "Lower PTO views to memref with Metadata Binding"; + } - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - auto bindOp = rewriter.create( - loc, targetType, src, - vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - if (!viewSemantics.empty()) - bindOp->setAttr("pto.view_semantics", - rewriter.getStringAttr(viewSemantics)); - return bindOp.getResult(); - }; - - SmallVector reshapes; - func.walk([&](mlir::pto::TReshapeOp op) { reshapes.push_back(op); }); - - for (auto op : reshapes) { - Value src = op->getOperand(0); - auto tbTy = dyn_cast(op->getResult(0).getType()); - if (!tbTy) { - op.emitError("treshape result must be tile_buf type"); - signalPassFailure(); - return; - } - Value lowered = lowerTileBufViewLike(op, src, tbTy, "treshape"); - if (!lowered) - return; - IRRewriter rewriter(ctx); - rewriter.replaceOp(op, lowered); - } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } - SmallVector bitcasts; - func.walk([&](mlir::pto::BitcastOp op) { bitcasts.push_back(op); }); + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); - for (auto op : bitcasts) { - Value src = op->getOperand(0); - auto tbTy = dyn_cast(op->getResult(0).getType()); - if (!tbTy) { - op.emitError("bitcast result must be tile_buf type"); - signalPassFailure(); - return; - } - Value lowered = lowerTileBufViewLike(op, src, tbTy, "bitcast"); - if (!lowered) - return; - IRRewriter rewriter(ctx); - rewriter.replaceOp(op, lowered); + for (auto func : mod.getOps()) { + if (func.isExternal()) + continue; + rewriteFunctionSignature(func, ctx); + if (failed(lowerAllocTileOps(func, ctx)) || + failed(lowerDeclareTileOps(func, ctx)) || + failed(lowerMakeTensorViewOps(func, ctx)) || + failed(lowerTensorViewDimOps(func, ctx)) || + failed(foldAddPtrIntoScalarOps(func, ctx)) || + failed(lowerPartitionViewOps(func, ctx)) || + failed(lowerSubsetOps(func, ctx)) || + failed(lowerTileBufViewLikeOps(func, ctx))) { + signalPassFailure(); + return; } // ------------------------------------------------------------------ - // Stage 3: Rewrite Compute Ops + // Stage 3: Rewrite Compute Ops // [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash // ------------------------------------------------------------------ diff --git a/lib/PTO/Transforms/Utils.cpp b/lib/PTO/Transforms/Utils.cpp index ec1dc7181..58e68c77e 100644 --- a/lib/PTO/Transforms/Utils.cpp +++ b/lib/PTO/Transforms/Utils.cpp @@ -6,13 +6,9 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "PTO/IR/PTO.h" #include "Utils.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -59,7 +55,8 @@ void setBaseMemRefTypeScope(Value val, AddressSpaceAttr targetMemScope) { if (auto curMemScope = dyn_cast_if_present( dyn_cast(type).getMemorySpace())) { - assert(curMemScope == targetMemScope); + if (curMemScope != targetMemScope) + llvm::report_fatal_error("memref scope mismatch while propagating PTO address space"); return; } @@ -207,7 +204,8 @@ std::optional getStaticTotalSize(const ArrayRef &shapes) { } uint64_t AlignUp(uint64_t lhs, uint64_t rhs) { - assert(rhs != 0); + if (rhs == 0) + return lhs; if (lhs % rhs != 0) { lhs += rhs - (lhs % rhs); } @@ -301,7 +299,8 @@ std::optional getYieldValueIdx(Value targetVal, ValueRange yieldedValues) { } LoopLikeOpInterface getParentLoop(Value val) { - assert(val.getDefiningOp() && "val should have defining op."); + if (!val.getDefiningOp()) + return nullptr; // Firstly, get parent loop LoopLikeOpInterface parentLoop = diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 474ad9f74..a5a5cd0c8 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -6,17 +6,22 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -from . import _pto_ops_gen as _pto_ops_gen -from ._pto_ops_gen import * +import importlib +import importlib.util +from pathlib import Path + from mlir import ir as _ods_ir +from . import _pto_ops_gen as _pto_ops_gen + + def _load_local_pto_ext(): - import importlib.util - from pathlib import Path lib_dir = Path(__file__).resolve().parent.parent / "_mlir_libs" for suffix in ("*.so", "*.pyd", "*.dll", "*.dylib"): for so_path in lib_dir.glob(f"_pto{suffix}"): - spec = importlib.util.spec_from_file_location("mlir._mlir_libs._pto", so_path) + spec = importlib.util.spec_from_file_location( + "mlir._mlir_libs._pto", so_path + ) if spec and spec.loader: mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -27,7 +32,21 @@ def _load_local_pto_ext(): try: _pto_mod = _load_local_pto_ext() except Exception: - from .._mlir_libs import _pto as _pto_mod + _pto_mod = importlib.import_module(".._mlir_libs._pto", __package__) + + +def _export_generated_symbols(): + for name, obj in _pto_ops_gen.__dict__.items(): + if name.startswith("_"): + continue + globals()[name] = obj + + +def get_op_result_or_value(value): + return getattr(_pto_ops_gen, "_get_op_result_or_value")(value) + + +_export_generated_symbols() register_dialect = _pto_mod.register_dialect PtrType = _pto_mod.PtrType @@ -68,7 +87,6 @@ def _load_local_pto_ext(): __all__ = [ # Dialect utilities "register_dialect", - # Types "PtrType", "AsyncSessionType", @@ -77,44 +95,83 @@ def _load_local_pto_ext(): "PartitionTensorViewType", "TileType", "TileBufType", - "AddressSpace", "AddressSpaceAttr", - "BLayout","BLayoutAttr", - "SLayout","SLayoutAttr", - "PadValue","PadValueAttr", - "CompactMode", "CompactModeAttr", - "RoundMode", "RoundModeAttr", - "CmpMode", "CmpModeAttr", - "PIPE", "PipeAttr", - "Layout", "LayoutAttr", - "SyncOpType", "SyncOpTypeAttr", - "EVENT", "EventAttr", - "MaskPattern", "MaskPatternAttr", - "QuantType", "QuantTypeAttr", + "AddressSpace", + "AddressSpaceAttr", + "BLayout", + "BLayoutAttr", + "SLayout", + "SLayoutAttr", + "PadValue", + "PadValueAttr", + "CompactMode", + "CompactModeAttr", + "RoundMode", + "RoundModeAttr", + "CmpMode", + "CmpModeAttr", + "PIPE", + "PipeAttr", + "Layout", + "LayoutAttr", + "SyncOpType", + "SyncOpTypeAttr", + "EVENT", + "EventAttr", + "MaskPattern", + "MaskPatternAttr", + "QuantType", + "QuantTypeAttr", "TileBufConfigAttr", "TileConfig", # High-level sync helpers - "record_event", "wait_event", "barrier", + "record_event", + "wait_event", + "barrier", # Low-level sync helpers (static/dynamic event id unified API) - "set_flag", "wait_flag", "set_flag_dyn", "wait_flag_dyn", + "set_flag", + "wait_flag", + "set_flag_dyn", + "wait_flag_dyn", # Inter-core sync helpers - "sync_set", "sync_wait", "sync_set_dyn", "sync_wait_dyn", "set_ffts", + "sync_set", + "sync_wait", + "sync_set_dyn", + "sync_wait_dyn", + "set_ffts", # A5 buffer-id sync helpers - "get_buf", "rls_buf", + "get_buf", + "rls_buf", # Scalar pointer helpers - "load_scalar", "store_scalar" - + "load_scalar", + "store_scalar", # Aliases for SyncOpType enums (for terse calls) - ,"TLOAD","TSTORE_ACC","TSTORE_VEC","TMOV_M2L","TMOV_M2S", - "TMOV_M2B","TMOV_M2V","TMOV_V2M","TMATMUL","TVEC","TVECWAIT_EVENT" + "TLOAD", + "TSTORE_ACC", + "TSTORE_VEC", + "TMOV_M2L", + "TMOV_M2S", + "TMOV_M2B", + "TMOV_M2V", + "TMOV_V2M", + "TMATMUL", + "TVEC", + "TVECWAIT_EVENT", # Aliases for EVENT enums - ,"EVENT_ID0","EVENT_ID1","EVENT_ID2","EVENT_ID3", - "EVENT_ID4","EVENT_ID5","EVENT_ID6","EVENT_ID7" + "EVENT_ID0", + "EVENT_ID1", + "EVENT_ID2", + "EVENT_ID3", + "EVENT_ID4", + "EVENT_ID5", + "EVENT_ID6", + "EVENT_ID7", ] # ----------------------------------------------------------------------------- # Convenience wrappers for high-level sync to allow passing enums directly # ----------------------------------------------------------------------------- + def _ensure_sync_attr(val, ctx): # Accept SyncOpType enum, SyncOpTypeAttr, or string name ("TMATMUL"/"tmatmul"). if isinstance(val, SyncOpType): @@ -123,11 +180,12 @@ def _ensure_sync_attr(val, ctx): name = val.upper() try: enum_val = getattr(SyncOpType, name) - except AttributeError: - raise ValueError(f"Unknown SyncOpType name: {val}") + except AttributeError as exc: + raise ValueError(f"Unknown SyncOpType name: {val}") from exc return SyncOpTypeAttr.get(enum_val, ctx) return val + def _ensure_event_attr(val, ctx): if isinstance(val, EVENT): return EventAttr.get(val, ctx) @@ -137,18 +195,19 @@ def _ensure_event_attr(val, ctx): enum_name = f"EVENT_ID{val}" try: enum_val = getattr(EVENT, enum_name) - except AttributeError: - raise ValueError(f"Unknown EVENT integer id: {val}") + except AttributeError as exc: + raise ValueError(f"Unknown EVENT integer id: {val}") from exc return EventAttr.get(enum_val, ctx) if isinstance(val, str): name = val.upper() try: enum_val = getattr(EVENT, name) - except AttributeError: - raise ValueError(f"Unknown EVENT name: {val}") + except AttributeError as exc: + raise ValueError(f"Unknown EVENT name: {val}") from exc return EventAttr.get(enum_val, ctx) return val + def _ensure_pipe_attr(val, ctx): if isinstance(val, PipeAttr): return val @@ -158,11 +217,12 @@ def _ensure_pipe_attr(val, ctx): name = val.upper() try: enum_val = getattr(PIPE, name) - except AttributeError: - raise ValueError(f"Unknown PIPE name: {val}") + except AttributeError as exc: + raise ValueError(f"Unknown PIPE name: {val}") from exc return PipeAttr.get(enum_val, ctx) return val + def _ensure_i32_attr(val, name, ctx): if isinstance(val, _ods_ir.IntegerAttr): return val @@ -171,13 +231,17 @@ def _ensure_i32_attr(val, name, ctx): return _ods_ir.IntegerAttr.get(i32, val) raise TypeError(f"{name} must be int or IntegerAttr, got {type(val).__name__}") + def record_event(src_op, dst_op, event_id, *, loc=None, ip=None): ctx = loc.context if loc else _ods_ir.Context.current return _pto_ops_gen.record_event( _ensure_sync_attr(src_op, ctx), _ensure_sync_attr(dst_op, ctx), _ensure_event_attr(event_id, ctx), - loc=loc, ip=ip) + loc=loc, + ip=ip, + ) + def wait_event(src_op, dst_op, event_id, *, loc=None, ip=None): ctx = loc.context if loc else _ods_ir.Context.current @@ -185,7 +249,10 @@ def wait_event(src_op, dst_op, event_id, *, loc=None, ip=None): _ensure_sync_attr(src_op, ctx), _ensure_sync_attr(dst_op, ctx), _ensure_event_attr(event_id, ctx), - loc=loc, ip=ip) + loc=loc, + ip=ip, + ) + def barrier(op, *, loc=None, ip=None): ctx = loc.context if loc else _ods_ir.Context.current @@ -209,22 +276,31 @@ def _is_static_i32_event_id(event_id): return False +def _create_pipe_event_op(op_name, src_attr, dst_attr, event_id, *, loc=None, ip=None): + return _ods_ir.Operation.create( + op_name, + attributes={"src_pipe": src_attr, "dst_pipe": dst_attr}, + operands=[get_op_result_or_value(event_id)], + loc=loc, + ip=ip, + ) + + def set_flag_dyn(src_pipe, dst_pipe, event_id, *, loc=None, ip=None): """Low-level dynamic event-id set_flag helper.""" ctx = loc.context if loc else _ods_ir.Context.current src_attr = _ensure_pipe_attr(src_pipe, ctx) dst_attr = _ensure_pipe_attr(dst_pipe, ctx) - event_val = _pto_ops_gen._get_op_result_or_value(event_id) if hasattr(_pto_ops_gen, "set_flag_dyn"): return _pto_ops_gen.set_flag_dyn( - src_attr, dst_attr, event_val, loc=loc, ip=ip + src_attr, + dst_attr, + get_op_result_or_value(event_id), + loc=loc, + ip=ip, ) - return _ods_ir.Operation.create( - "pto.set_flag_dyn", - attributes={"src_pipe": src_attr, "dst_pipe": dst_attr}, - operands=[event_val], - loc=loc, - ip=ip, + return _create_pipe_event_op( + "pto.set_flag_dyn", src_attr, dst_attr, event_id, loc=loc, ip=ip ) @@ -233,17 +309,16 @@ def wait_flag_dyn(src_pipe, dst_pipe, event_id, *, loc=None, ip=None): ctx = loc.context if loc else _ods_ir.Context.current src_attr = _ensure_pipe_attr(src_pipe, ctx) dst_attr = _ensure_pipe_attr(dst_pipe, ctx) - event_val = _pto_ops_gen._get_op_result_or_value(event_id) if hasattr(_pto_ops_gen, "wait_flag_dyn"): return _pto_ops_gen.wait_flag_dyn( - src_attr, dst_attr, event_val, loc=loc, ip=ip + src_attr, + dst_attr, + get_op_result_or_value(event_id), + loc=loc, + ip=ip, ) - return _ods_ir.Operation.create( - "pto.wait_flag_dyn", - attributes={"src_pipe": src_attr, "dst_pipe": dst_attr}, - operands=[event_val], - loc=loc, - ip=ip, + return _create_pipe_event_op( + "pto.wait_flag_dyn", src_attr, dst_attr, event_id, loc=loc, ip=ip ) @@ -278,13 +353,16 @@ def wait_flag(src_pipe, dst_pipe, event_id, *, loc=None, ip=None): ) return wait_flag_dyn(src_attr, dst_attr, event_id, loc=loc, ip=ip) + # ----------------------------------------------------------------------------- # Inter-core sync helpers (pto.sync.set / pto.sync.wait / pto.set_ffts) # ----------------------------------------------------------------------------- + + def sync_set_dyn(pipe, event_id, ffts_mode=2, *, loc=None, ip=None): ctx = loc.context if loc else _ods_ir.Context.current pipe_attr = _ensure_pipe_attr(pipe, ctx) - event_val = _pto_ops_gen._get_op_result_or_value(event_id) + event_val = get_op_result_or_value(event_id) mode_attr = None if ffts_mode != 2: mode_attr = _ensure_i32_attr(ffts_mode, "ffts_mode", ctx) @@ -340,7 +418,7 @@ def sync_set(pipe, event_id, ffts_mode=2, *, loc=None, ip=None): def sync_wait_dyn(pipe, event_id, *, loc=None, ip=None): ctx = loc.context if loc else _ods_ir.Context.current pipe_attr = _ensure_pipe_attr(pipe, ctx) - event_val = _pto_ops_gen._get_op_result_or_value(event_id) + event_val = get_op_result_or_value(event_id) try: return _pto_ops_gen.sync_wait( pipe_attr, event_id=None, event_id_dyn=event_val, loc=loc, ip=ip @@ -369,17 +447,21 @@ def sync_wait(pipe, event_id, *, loc=None, ip=None): ) return sync_wait_dyn(pipe_attr, event_id, loc=loc, ip=ip) + def set_ffts(ffts, *, loc=None, ip=None): return _ods_ir.Operation.create( "pto.set_ffts", - operands=[_pto_ops_gen._get_op_result_or_value(ffts)], + operands=[get_op_result_or_value(ffts)], loc=loc, ip=ip, ) + # ----------------------------------------------------------------------------- # A5 buffer-id sync helpers # ----------------------------------------------------------------------------- + + def get_buf(op_type, buf_id, mode=0, *, loc=None, ip=None): ctx = loc.context if loc else _ods_ir.Context.current if isinstance(op_type, (PipeAttr, PIPE)): @@ -413,13 +495,16 @@ def rls_buf(op_type, buf_id, mode=0, *, loc=None, ip=None): ip=ip, ) + # ----------------------------------------------------------------------------- # Scalar pointer helpers (manual wrappers until python ops are regenerated) # ----------------------------------------------------------------------------- + + def load_scalar(result_type, ptr, offset, *, loc=None, ip=None): operands = [ - _pto_ops_gen._get_op_result_or_value(ptr), - _pto_ops_gen._get_op_result_or_value(offset), + get_op_result_or_value(ptr), + get_op_result_or_value(offset), ] op = _ods_ir.Operation.create( "pto.load_scalar", @@ -433,9 +518,9 @@ def load_scalar(result_type, ptr, offset, *, loc=None, ip=None): def store_scalar(ptr, offset, value, *, loc=None, ip=None): operands = [ - _pto_ops_gen._get_op_result_or_value(ptr), - _pto_ops_gen._get_op_result_or_value(offset), - _pto_ops_gen._get_op_result_or_value(value), + get_op_result_or_value(ptr), + get_op_result_or_value(offset), + get_op_result_or_value(value), ] return _ods_ir.Operation.create( "pto.store_scalar", @@ -444,6 +529,7 @@ def store_scalar(ptr, offset, value, *, loc=None, ip=None): ip=ip, ) + # ----------------------------------------------------------------------------- # Export enum aliases for terse calls: pto.record_event(TLOAD, TLOAD, EVENT_ID0) # ----------------------------------------------------------------------------- @@ -478,9 +564,12 @@ class TileConfig: fractalCSize = 1024 fractalMxSize = 32 + # ----------------------------------------------------------------------------- # Op aliases without "Op" suffix (user-facing) # ----------------------------------------------------------------------------- + + def _install_op_aliases(): added = [] for name, obj in _pto_ops_gen.__dict__.items(): @@ -499,4 +588,5 @@ def _install_op_aliases(): added.append(alias) return added + __all__.extend(_install_op_aliases()) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e0c49c4cd..5bf085d67 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -270,87 +270,111 @@ static bool parseAutoSyncTailHint(llvm::StringRef hintStr, std::string &normaliz // PTOAS__EVENTID_ARRAY_LOAD(arr, idx) -> arr[idx] // PTOAS__EVENTID_ARRAY_STORE(arr, idx, v) -> arr[idx] = v // -------------------------------------------------------------------------- -static bool rewriteMarkerCallToMember(std::string &cpp, llvm::StringRef marker, - llvm::StringRef memberName, - unsigned expectedNumArgs) { - size_t searchPos = 0; - bool changed = false; - while (true) { - size_t markerPos = cpp.find(marker.str(), searchPos); - if (markerPos == std::string::npos) - break; +struct ParsedMarkerCall { + size_t markerPos = std::string::npos; + size_t rparenPos = std::string::npos; + llvm::SmallVector args; +}; - size_t lparenPos = markerPos + marker.size(); - if (lparenPos >= cpp.size() || cpp[lparenPos] != '(') { - searchPos = markerPos + marker.size(); +static bool parseMarkerArgs(llvm::StringRef argsRef, + llvm::SmallVectorImpl &args) { + size_t partBegin = 0; + int parenDepth = 0; + for (size_t i = 0; i < argsRef.size(); ++i) { + char c = argsRef[i]; + if (c == '(') { + ++parenDepth; continue; } - - // Find the matching ')' for this call, tracking nested parentheses. - size_t argsBegin = lparenPos + 1; - int parenDepth = 0; - size_t rparenPos = std::string::npos; - for (size_t i = argsBegin; i < cpp.size(); ++i) { - char c = cpp[i]; - if (c == '(') { - ++parenDepth; - } else if (c == ')') { - if (parenDepth == 0) { - rparenPos = i; - break; - } + if (c == ')') { + if (parenDepth > 0) --parenDepth; - } + continue; } - if (rparenPos == std::string::npos) { - // Unbalanced parentheses; stop trying to rewrite. - break; + if (c == ',' && parenDepth == 0) { + args.push_back(argsRef.slice(partBegin, i).trim()); + partBegin = i + 1; } + } + if (partBegin > argsRef.size()) + return false; + args.push_back(argsRef.drop_front(partBegin).trim()); + return true; +} - llvm::StringRef argsRef(cpp.data() + argsBegin, rparenPos - argsBegin); - llvm::SmallVector args; - size_t partBegin = 0; - parenDepth = 0; - for (size_t i = 0; i < argsRef.size(); ++i) { - char c = argsRef[i]; - if (c == '(') { - ++parenDepth; - } else if (c == ')') { - if (parenDepth > 0) - --parenDepth; - } else if (c == ',' && parenDepth == 0) { - args.push_back(argsRef.slice(partBegin, i).trim()); - partBegin = i + 1; - } +static std::optional +findNextMarkerCall(const std::string &cpp, llvm::StringRef marker, + size_t searchPos) { + ParsedMarkerCall call; + call.markerPos = cpp.find(marker.str(), searchPos); + if (call.markerPos == std::string::npos) + return std::nullopt; + + size_t lparenPos = call.markerPos + marker.size(); + if (lparenPos >= cpp.size() || cpp[lparenPos] != '(') + return ParsedMarkerCall{call.markerPos, std::string::npos, {}}; + + size_t argsBegin = lparenPos + 1; + int parenDepth = 0; + for (size_t i = argsBegin; i < cpp.size(); ++i) { + char c = cpp[i]; + if (c == '(') { + ++parenDepth; + continue; } - if (partBegin <= argsRef.size()) - args.push_back(argsRef.drop_front(partBegin).trim()); + if (c != ')') + continue; + if (parenDepth == 0) { + call.rparenPos = i; + break; + } + --parenDepth; + } + if (call.rparenPos == std::string::npos) + return call; - if (args.size() != expectedNumArgs) { - searchPos = rparenPos + 1; + llvm::StringRef argsRef(cpp.data() + argsBegin, call.rparenPos - argsBegin); + if (!parseMarkerArgs(argsRef, call.args)) + call.args.clear(); + return call; +} + +static bool rewriteMarkerCallToMember(std::string &cpp, llvm::StringRef marker, + llvm::StringRef memberName, + unsigned expectedNumArgs) { + size_t searchPos = 0; + bool changed = false; + for (auto call = findNextMarkerCall(cpp, marker, searchPos); call; + call = findNextMarkerCall(cpp, marker, searchPos)) { + if (call->rparenPos == std::string::npos) { + searchPos = call->markerPos + marker.size(); + continue; + } + if (call->args.size() != expectedNumArgs) { + searchPos = call->rparenPos + 1; continue; } std::string replacement; - replacement.reserve(marker.size() + argsRef.size() + 16); - replacement.append(args[0].str()); + replacement.reserve(marker.size() + 16); + replacement.append(call->args[0].str()); replacement.push_back('.'); replacement.append(memberName.str()); replacement.push_back('('); if (expectedNumArgs == 1) { - // no args } else if (expectedNumArgs == 2) { - replacement.append(args[1].str()); + replacement.append(call->args[1].str()); } else if (expectedNumArgs == 3) { - replacement.append(args[1].str()); + replacement.append(call->args[1].str()); replacement.append(", "); - replacement.append(args[2].str()); + replacement.append(call->args[2].str()); } replacement.push_back(')'); - cpp.replace(markerPos, (rparenPos - markerPos) + 1, replacement); + cpp.replace(call->markerPos, (call->rparenPos - call->markerPos) + 1, + replacement); changed = true; - searchPos = markerPos + replacement.size(); + searchPos = call->markerPos + replacement.size(); } return changed; } @@ -462,70 +486,29 @@ static bool rewriteMarkerCallToSubscript(std::string &cpp, llvm::StringRef marke bool isStore) { size_t searchPos = 0; bool changed = false; - while (true) { - size_t markerPos = cpp.find(marker.str(), searchPos); - if (markerPos == std::string::npos) - break; - - size_t lparenPos = markerPos + marker.size(); - if (lparenPos >= cpp.size() || cpp[lparenPos] != '(') { - searchPos = markerPos + marker.size(); + for (auto call = findNextMarkerCall(cpp, marker, searchPos); call; + call = findNextMarkerCall(cpp, marker, searchPos)) { + if (call->rparenPos == std::string::npos) { + searchPos = call->markerPos + marker.size(); continue; } - - size_t argsBegin = lparenPos + 1; - int parenDepth = 0; - size_t rparenPos = std::string::npos; - for (size_t i = argsBegin; i < cpp.size(); ++i) { - char c = cpp[i]; - if (c == '(') { - ++parenDepth; - } else if (c == ')') { - if (parenDepth == 0) { - rparenPos = i; - break; - } - --parenDepth; - } - } - if (rparenPos == std::string::npos) { - break; - } - - llvm::StringRef argsRef(cpp.data() + argsBegin, rparenPos - argsBegin); - llvm::SmallVector args; - size_t partBegin = 0; - parenDepth = 0; - for (size_t i = 0; i < argsRef.size(); ++i) { - char c = argsRef[i]; - if (c == '(') { - ++parenDepth; - } else if (c == ')') { - if (parenDepth > 0) - --parenDepth; - } else if (c == ',' && parenDepth == 0) { - args.push_back(argsRef.slice(partBegin, i).trim()); - partBegin = i + 1; - } - } - if (partBegin <= argsRef.size()) - args.push_back(argsRef.drop_front(partBegin).trim()); - - if (args.size() != expectedNumArgs) { - searchPos = rparenPos + 1; + if (call->args.size() != expectedNumArgs) { + searchPos = call->rparenPos + 1; continue; } std::string replacement; if (isStore) { - replacement = (args[0] + "[" + args[1] + "] = " + args[2]).str(); + replacement = + (call->args[0] + "[" + call->args[1] + "] = " + call->args[2]).str(); } else { - replacement = (args[0] + "[" + args[1] + "]").str(); + replacement = (call->args[0] + "[" + call->args[1] + "]").str(); } - cpp.replace(markerPos, (rparenPos - markerPos) + 1, replacement); + cpp.replace(call->markerPos, (call->rparenPos - call->markerPos) + 1, + replacement); changed = true; - searchPos = markerPos + replacement.size(); + searchPos = call->markerPos + replacement.size(); } return changed; } @@ -557,88 +540,65 @@ static void rewriteEventIdArrayMarkers(std::string &cpp) { static bool rewriteAddPtrTraceMarkers(std::string &cpp, bool showTrace) { size_t searchPos = 0; bool changed = false; - while (true) { - size_t markerPos = cpp.find("PTOAS__ADDPTR_TRACE", searchPos); - if (markerPos == std::string::npos) - break; - - size_t lparenPos = markerPos + (sizeof("PTOAS__ADDPTR_TRACE") - 1); - if (lparenPos >= cpp.size() || cpp[lparenPos] != '(') { - searchPos = markerPos + 1; + for (auto call = findNextMarkerCall(cpp, "PTOAS__ADDPTR_TRACE", searchPos); + call; call = findNextMarkerCall(cpp, "PTOAS__ADDPTR_TRACE", searchPos)) { + if (call->rparenPos == std::string::npos) { + searchPos = call->markerPos + 1; continue; } - - size_t argsBegin = lparenPos + 1; - int parenDepth = 0; - size_t rparenPos = std::string::npos; - for (size_t i = argsBegin; i < cpp.size(); ++i) { - char c = cpp[i]; - if (c == '(') { - ++parenDepth; - } else if (c == ')') { - if (parenDepth == 0) { - rparenPos = i; - break; - } - --parenDepth; - } - } - if (rparenPos == std::string::npos) { - break; - } - - llvm::StringRef argsRef(cpp.data() + argsBegin, rparenPos - argsBegin); - llvm::SmallVector args; - size_t partBegin = 0; - parenDepth = 0; - for (size_t i = 0; i < argsRef.size(); ++i) { - char c = argsRef[i]; - if (c == '(') { - ++parenDepth; - } else if (c == ')') { - if (parenDepth > 0) - --parenDepth; - } else if (c == ',' && parenDepth == 0) { - args.push_back(argsRef.slice(partBegin, i).trim()); - partBegin = i + 1; - } - } - if (partBegin <= argsRef.size()) - args.push_back(argsRef.drop_front(partBegin).trim()); - - if (args.size() != 3) { - searchPos = rparenPos + 1; + if (call->args.size() != 3) { + searchPos = call->rparenPos + 1; continue; } std::string replacement; if (showTrace) { - replacement.reserve(64 + argsRef.size()); + replacement.reserve(64); replacement.append("/* ADDPTR_TRACE: "); - replacement.append(args[0].str()); + replacement.append(call->args[0].str()); replacement.append(" = "); - replacement.append(args[1].str()); + replacement.append(call->args[1].str()); replacement.append(" + "); - replacement.append(args[2].str()); + replacement.append(call->args[2].str()); replacement.append(" */"); } - size_t replaceEnd = rparenPos; + size_t replaceEnd = call->rparenPos; if (!showTrace) { - size_t i = rparenPos + 1; + size_t i = call->rparenPos + 1; while (i < cpp.size() && std::isspace(static_cast(cpp[i]))) ++i; if (i < cpp.size() && cpp[i] == ';') replaceEnd = i; } - cpp.replace(markerPos, (replaceEnd - markerPos) + 1, replacement); + cpp.replace(call->markerPos, (replaceEnd - call->markerPos) + 1, + replacement); changed = true; - searchPos = markerPos + replacement.size(); + searchPos = call->markerPos + replacement.size(); } return changed; } +static bool isGeneratedGlobalTensorDecl(llvm::StringRef trimmed, + llvm::StringRef &decl, + llvm::StringRef &varName) { + if (!trimmed.starts_with("GlobalTensor<") || !trimmed.ends_with(";") || + trimmed.contains('=') || trimmed.contains('(')) { + return false; + } + + decl = trimmed.drop_back().rtrim(); + size_t lastWs = decl.find_last_of(" \t"); + if (lastWs == llvm::StringRef::npos) + return false; + varName = decl.drop_front(lastWs + 1); + if (!varName.starts_with("v") || varName.size() <= 1) + return false; + return llvm::all_of(varName.drop_front(1), + [](char c) { return std::isdigit(c); }); +} + static void rewriteHoistedGlobalTensorDecls(std::string &cpp) { // When `declareVariablesAtTop` is enabled, the C++ emitter hoists SSA value // declarations to the top of the function and emits assignments later. This @@ -663,33 +623,18 @@ static void rewriteHoistedGlobalTensorDecls(std::string &cpp) { llvm::StringRef trimmed = line.trim(); bool rewritten = false; - if (trimmed.starts_with("GlobalTensor<") && trimmed.ends_with(";") && - !trimmed.contains('=') && !trimmed.contains('(')) { - llvm::StringRef decl = trimmed.drop_back().rtrim(); - size_t lastWs = decl.find_last_of(" \t"); - if (lastWs != llvm::StringRef::npos) { - llvm::StringRef varName = decl.drop_front(lastWs + 1); - if (varName.starts_with("v") && varName.size() > 1) { - bool allDigits = true; - for (char c : varName.drop_front(1)) { - if (c < '0' || c > '9') { - allDigits = false; - break; - } - } - if (allDigits) { - size_t indentLen = line.find_first_not_of(" \t"); - if (indentLen == std::string::npos) - indentLen = 0; - llvm::StringRef indent = line.take_front(indentLen); - - out.append(indent.str()); - out.append(decl.str()); - out.append("(nullptr);"); - rewritten = true; - } - } - } + llvm::StringRef decl; + llvm::StringRef varName; + if (isGeneratedGlobalTensorDecl(trimmed, decl, varName)) { + size_t indentLen = line.find_first_not_of(" \t"); + if (indentLen == std::string::npos) + indentLen = 0; + llvm::StringRef indent = line.take_front(indentLen); + + out.append(indent.str()); + out.append(decl.str()); + out.append("(nullptr);"); + rewritten = true; } if (!rewritten) @@ -838,13 +783,9 @@ static bool parseGeneratedValueAssignment(llvm::StringRef line, static void rewriteScalarConstantDecls(std::string &cpp) { llvm::SmallVector lines; - llvm::StringRef ref(cpp); - while (true) { + for (llvm::StringRef ref(cpp); !ref.empty(); ref = ref.split('\n').second) { auto split = ref.split('\n'); lines.push_back(split.first.str()); - if (split.second.empty()) - break; - ref = split.second; } llvm::SmallVector eraseLine(lines.size(), false); @@ -928,6 +869,12 @@ static void rewriteScalarConstantDecls(std::string &cpp) { cpp.swap(out); } +static bool shouldDeclareVariablesAtTop(ModuleOp module) { + auto hasMultiBlockFunc = [](auto func) { return func.getBlocks().size() > 1; }; + return llvm::any_of(module.getOps(), hasMultiBlockFunc) || + llvm::any_of(module.getOps(), hasMultiBlockFunc); +} + int main(int argc, char **argv) { DialectRegistry registry; registry.insert(); @@ -940,10 +887,8 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); - //mlir::registerAllDialects(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); - //func::registerBufferizableOpInterfaceExternalModels(registry); pto::registerBufferizableOpInterfaceExternalModels(registry); registry.insert(); @@ -1141,7 +1086,6 @@ int main(int argc, char **argv) { if (!disableInferLayout) pm.addNestedPass(pto::createInferPTOLayoutPass()); pm.addPass(pto::createPTOViewToMemrefPass()); - //pm.addPass(createInferPTOMemScopePass()); if (effectiveLevel != PTOBuildLevel::Level3) { PlanMemoryOptions planMemoryOption; @@ -1183,21 +1127,7 @@ int main(int argc, char **argv) { // CFG-style lowering (e.g. scf.while -> cf.br/cf.cond_br) may introduce // multiple blocks, requiring variables to be declared at the top for valid // C++ emission. - bool declareVariablesAtTop = false; - for (auto func : module->getOps()) { - if (func.getBlocks().size() > 1) { - declareVariablesAtTop = true; - break; - } - } - if (!declareVariablesAtTop) { - for (auto func : module->getOps()) { - if (func.getBlocks().size() > 1) { - declareVariablesAtTop = true; - break; - } - } - } + bool declareVariablesAtTop = shouldDeclareVariablesAtTop(*module); if (failed(emitc::translateToCpp(*module, cppOS, /*declareVariablesAtTop=*/declareVariablesAtTop))) { llvm::errs() << "Error: Failed to emit C++.\n"; diff --git a/tools/ptobc/CMakeLists.txt b/tools/ptobc/CMakeLists.txt index 8224cd637..61fc5480c 100644 --- a/tools/ptobc/CMakeLists.txt +++ b/tools/ptobc/CMakeLists.txt @@ -21,6 +21,7 @@ add_library(ptobc_lib STATIC src/ptobc_format.cpp src/mlir_helpers.cpp src/mlir_encode.cpp + src/ptobc_opcodes_v0.cpp src/canonical_printer.cpp src/ptobc_decode_print.cpp ) diff --git a/tools/ptobc/src/canonical_printer.cpp b/tools/ptobc/src/canonical_printer.cpp index e2b2c908e..b3c0397d0 100644 --- a/tools/ptobc/src/canonical_printer.cpp +++ b/tools/ptobc/src/canonical_printer.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "ptobc/canonical_printer.h" #include @@ -189,6 +184,45 @@ static std::string canonicalConstBaseName(const std::string &imm, const std::str return base; } +static bool findConstantDefinition(const std::vector &lines, + const std::string &name, std::string &imm, + std::string &ty) { + for (const auto &line : lines) { + if (line.find('%' + name) == std::string::npos) + continue; + if (line.find("= arith.constant") == std::string::npos) + continue; + size_t pos = line.find('%'); + if (pos == std::string::npos) + continue; + size_t end = pos + 1; + while (end < line.size() && isSSAIdentChar(line[end])) + ++end; + if (line.substr(pos + 1, end - (pos + 1)) != name) + continue; + return parseConstantLine(line, imm, ty); + } + return false; +} + +static std::string getCanonicalSSAName(const std::vector &lines, + const std::string &oldName, + std::unordered_map &constCounts, + uint64_t &nextNonConst) { + std::string imm; + std::string ty; + if (!findConstantDefinition(lines, oldName, imm, ty)) + return std::to_string(nextNonConst++); + + std::string base = canonicalConstBaseName(imm, ty); + int &count = constCounts[base]; + std::string newName = base; + if (count > 0) + newName += "_" + std::to_string(count); + ++count; + return newName; +} + static std::string canonicalizeSSANames(const std::string &printed) { auto lines = splitLinesPreserveEmpty(printed); @@ -221,44 +255,10 @@ static std::string canonicalizeSSANames(const std::string &printed) { std::unordered_map ren; ren.reserve(defs.size() * 2); - // Pre-scan constants for nicer `%c...` aliases. std::unordered_map constCounts; - - // Assign names in definition order, but keep constants named via their immediates. uint64_t nextNonConst = 0; - - for (const auto &old : defs) { - // Find the line that defines this value to see if it is a constant. - // (Linear scan; ok for now.) - bool isConst = false; - std::string imm, ty; - for (const auto &ln : lines) { - // quick filter - if (ln.find('%' + old) == std::string::npos) continue; - if (ln.find("= arith.constant") == std::string::npos) continue; - // Must be the definition line. - size_t pos = ln.find('%'); - if (pos == std::string::npos) continue; - size_t j = pos + 1; - while (j < ln.size() && isSSAIdentChar(ln[j])) ++j; - if (ln.substr(pos + 1, j - (pos + 1)) != old) continue; - if (parseConstantLine(ln, imm, ty)) { - isConst = true; - } - break; - } - - if (isConst) { - std::string base = canonicalConstBaseName(imm, ty); - int &n = constCounts[base]; - std::string name = base; - if (n > 0) name += "_" + std::to_string(n); - ++n; - ren.emplace(old, name); - } else { - ren.emplace(old, std::to_string(nextNonConst++)); - } - } + for (const auto &old : defs) + ren.emplace(old, getCanonicalSSAName(lines, old, constCounts, nextNonConst)); return renameSSAInText(printed, ren); } diff --git a/tools/ptobc/src/main.cpp b/tools/ptobc/src/main.cpp index ac816909b..bd8b2034b 100644 --- a/tools/ptobc/src/main.cpp +++ b/tools/ptobc/src/main.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "ptobc/ptobc_format.h" #include @@ -24,6 +19,7 @@ #include #include +#include namespace ptobc { mlir::OwningOpRef parsePTOFile(mlir::MLIRContext& ctx, const std::string& path); @@ -38,68 +34,79 @@ static void usage() { << " ptobc decode -o \n"; } -int main(int argc, char** argv) { - if (argc < 2) { - usage(); - return 2; +struct CommandLineOptions { + std::string cmd; + std::string input; + std::string output; +}; + +static std::optional parseCommandLine(int argc, char **argv) { + if (argc < 2) + return std::nullopt; + + CommandLineOptions options{argv[1], "", ""}; + if (options.cmd != "encode" && options.cmd != "decode") + return options; + if (argc < 5) + return std::nullopt; + + options.input = argv[2]; + for (int i = 3; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "-o" && i + 1 < argc) + options.output = argv[++i]; } + if (options.output.empty()) + return std::nullopt; + return options; +} - std::string cmd = argv[1]; - std::string in; - std::string out; - - if (cmd == "encode" || cmd == "decode") { - if (argc < 5) { - usage(); - return 2; - } - in = argv[2]; - for (int i = 3; i < argc; ++i) { - std::string a = argv[i]; - if (a == "-o" && i + 1 < argc) { - out = argv[++i]; - } - } - if (out.empty()) { - std::cerr << "Missing -o\n"; - return 2; - } +static mlir::DialectRegistry buildRegistry() { + mlir::DialectRegistry registry; + registry.insert(); + return registry; +} + +static void preloadDialects(mlir::MLIRContext &ctx) { + (void)ctx.getOrLoadDialect(); + (void)ctx.getOrLoadDialect(); + (void)ctx.getOrLoadDialect(); + (void)ctx.getOrLoadDialect(); + (void)ctx.getOrLoadDialect(); + (void)ctx.getOrLoadDialect(); +} + +static int runEncode(const CommandLineOptions &options) { + mlir::MLIRContext ctx(buildRegistry()); + ctx.allowUnregisteredDialects(true); + preloadDialects(ctx); + + auto module = ptobc::parsePTOFile(ctx, options.input); + auto file = ptobc::encodeFromMLIRModule(*module); + auto bytes = file.serialize(); + ptobc::writeFile(options.output, bytes); + return 0; +} + +static int runDecode(const CommandLineOptions &options) { + ptobc::decodeFileToPTO(options.input, options.output); + return 0; +} + +int main(int argc, char **argv) { + auto options = parseCommandLine(argc, argv); + if (!options) { + usage(); + return 2; } try { - if (cmd == "encode") { - mlir::DialectRegistry registry; - // ptobc needs to parse sample .pto files that may include core MLIR - // dialects (affine/memref) in addition to PTO + a few basics. - registry.insert(); - mlir::MLIRContext ctx(registry); - ctx.allowUnregisteredDialects(true); - - // Preload dialects so custom op/type parsing is available. - (void)ctx.getOrLoadDialect(); - (void)ctx.getOrLoadDialect(); - (void)ctx.getOrLoadDialect(); - (void)ctx.getOrLoadDialect(); - (void)ctx.getOrLoadDialect(); - (void)ctx.getOrLoadDialect(); - - auto module = ptobc::parsePTOFile(ctx, in); - auto file = ptobc::encodeFromMLIRModule(*module); - auto bytes = file.serialize(); - ptobc::writeFile(out, bytes); - return 0; - } - - if (cmd == "decode") { - ptobc::decodeFileToPTO(in, out); - return 0; - } - + if (options->cmd == "encode") + return runEncode(*options); + if (options->cmd == "decode") + return runDecode(*options); usage(); return 2; } catch (const std::exception& e) { diff --git a/tools/ptobc/src/mlir_encode.cpp b/tools/ptobc/src/mlir_encode.cpp index 8388b4693..9361977ed 100644 --- a/tools/ptobc/src/mlir_encode.cpp +++ b/tools/ptobc/src/mlir_encode.cpp @@ -86,6 +86,53 @@ static std::string apIntToSignedDecimal(const llvm::APInt &v) { return std::string(digits.data(), digits.size()); } +static llvm::SmallVector copyAPIntWords(const llvm::APInt &bits) { + return llvm::SmallVector(bits.getRawData(), + bits.getRawData() + bits.getNumWords()); +} + +static void appendAPIntBytesLE(Buffer &buffer, const llvm::APInt &bits) { + const unsigned byteLen = (bits.getBitWidth() + 7) / 8; + writeULEB128(byteLen, buffer.bytes); + + llvm::SmallVector words = copyAPIntWords(bits); + for (unsigned i = 0; i < byteLen; ++i) { + unsigned word = i / 8; + unsigned off = (i % 8) * 8; + uint8_t byte = uint8_t((words[word] >> off) & 0xFFu); + buffer.bytes.push_back(byte); + } +} + +static std::optional +buildScalarConstantDebugName(mlir::Value value, + std::unordered_map &constCounts) { + auto cst = llvm::dyn_cast_or_null( + value.getDefiningOp()); + if (!cst) + return std::nullopt; + + mlir::Attribute attr = cst.getValue(); + std::string typeName = printType(value.getType()); + std::string baseName; + if (auto floatAttr = llvm::dyn_cast(attr)) { + baseName = "c" + hexFloatLiteral(floatAttr) + "_" + typeName; + } else if (auto intAttr = llvm::dyn_cast(attr)) { + baseName = "c" + apIntToSignedDecimal(intAttr.getValue()); + if (typeName != "index") + baseName += "_" + typeName; + } else { + return std::nullopt; + } + + int &count = constCounts[baseName]; + std::string name = baseName; + if (count > 0) + name += "_" + std::to_string(count); + ++count; + return name; +} + struct Encoder { PTOBCFile file; @@ -154,37 +201,8 @@ struct Encoder { for (uint64_t vid = 0; vid < valueById.size(); ++vid) { mlir::Value v = valueById[vid]; - std::string name; - - if (auto *def = v.getDefiningOp()) { - if (auto cst = llvm::dyn_cast(def)) { - mlir::Attribute a = cst.getValue(); - std::string ty = printType(v.getType()); - - // Only generate special names for scalar ints/floats. - if (auto fa = llvm::dyn_cast(a)) { - std::string imm = hexFloatLiteral(fa); - std::string base = "c" + imm + "_" + ty; - int &n = constCounts[base]; - name = base; - if (n > 0) name += "_" + std::to_string(n); - ++n; - } else if (auto ia = llvm::dyn_cast(a)) { - std::string imm = apIntToSignedDecimal(ia.getValue()); - std::string base = "c" + imm; - if (ty != "index") base += "_" + ty; - int &n = constCounts[base]; - name = base; - if (n > 0) name += "_" + std::to_string(n); - ++n; - } - } - } - - if (name.empty()) { - // Non-constant (or non-scalar-constant) value. - name = std::to_string(vid); - } + std::string name = buildScalarConstantDebugName(v, constCounts) + .value_or(std::to_string(vid)); uint64_t nameSid = file.strings.intern(name); file.dbgValueNames.push_back(DebugValueNameEntry{funcId, vid, nameSid}); @@ -216,43 +234,14 @@ struct Encoder { uint64_t internConstIntBits(uint64_t typeId, const llvm::APInt &bits) { Buffer p; writeULEB128(typeId, p.bytes); - - const unsigned byteLen = (bits.getBitWidth() + 7) / 8; - writeULEB128(byteLen, p.bytes); - - // little-endian bytes - llvm::SmallVector words; - words.resize(bits.getNumWords()); - std::memcpy(words.data(), bits.getRawData(), words.size() * sizeof(uint64_t)); - - for (unsigned i = 0; i < byteLen; ++i) { - unsigned word = i / 8; - unsigned off = (i % 8) * 8; - uint8_t b = uint8_t((words[word] >> off) & 0xFFu); - p.bytes.push_back(b); - } - + appendAPIntBytesLE(p, bits); return internConst(/*tag=*/0x04, p.bytes); } uint64_t internConstFloatBits(uint64_t dtypeId, const llvm::APInt &bits) { Buffer p; writeULEB128(dtypeId, p.bytes); - const unsigned byteLen = (bits.getBitWidth() + 7) / 8; - writeULEB128(byteLen, p.bytes); - - // little-endian bytes - llvm::SmallVector words; - words.resize(bits.getNumWords()); - std::memcpy(words.data(), bits.getRawData(), words.size() * sizeof(uint64_t)); - - for (unsigned i = 0; i < byteLen; ++i) { - unsigned word = i / 8; - unsigned off = (i % 8) * 8; - uint8_t b = uint8_t((words[word] >> off) & 0xFFu); - p.bytes.push_back(b); - } - + appendAPIntBytesLE(p, bits); return internConst(/*tag=*/0x02, p.bytes); } @@ -263,6 +252,18 @@ struct Encoder { valueById.clear(); } + void encodeKnownOpImmediates(mlir::Operation &op, Buffer &out, + const ptobc::v0::OpInfo &info, + const ptobc::v0::OpcodeAndVariant &variantInfo, + llvm::SmallVectorImpl &imms); + void encodeKnownOpOperands(mlir::Operation &op, Buffer &out, + const ptobc::v0::OpInfo &info, + const ptobc::v0::OpcodeAndVariant &variantInfo, + llvm::ArrayRef imms); + void encodeKnownOp(mlir::Operation &op, Buffer &out, + const ptobc::v0::OpInfo &info, + const ptobc::v0::OpcodeAndVariant &variantInfo); + void encodeGenericOp(mlir::Operation &op, Buffer &out); void encodeRegion(mlir::Region& region, Buffer& out); void encodeBlock(mlir::Block& block, Buffer& out); void encodeOp(mlir::Operation& op, Buffer& out); @@ -293,241 +294,262 @@ void Encoder::encodeBlock(mlir::Block& block, Buffer& out) { } } -void Encoder::encodeOp(mlir::Operation& op, Buffer& out) { - if (emitDebugInfo) { - // op_id (preorder DFS, per-function) - uint64_t opId = nextOpId++; - recordOpLocation(opId, op); - } - - // Try compact known-op encoding first (PTO-BC v0). - auto fullName = op.getName().getStringRef(); - auto ov = ptobc::v0::lookupOpcodeAndVariantByFullName(fullName); - if (ov) { - const auto *info = ptobc::v0::lookupByOpcode(ov->opcode); - if (!info) throw std::runtime_error("missing v0 opcode schema for op: " + fullName.str()); - - // Allocate value IDs for results first so nested regions can reference them. - const uint64_t resStart = valueId.size(); - for (auto res : op.getResults()) { - allocValueId(res); - } - - // u16 opcode - out.appendU16LE(ov->opcode); - - // attr_id (allow per-op stripping) - mlir::DictionaryAttr dict = op.getAttrDictionary(); - - // arith.constant: value is encoded via CONSTPOOL (imm_kind=0x05) - if (auto cst = llvm::dyn_cast(&op)) { - dict = stripAttr(op.getContext(), dict, "value"); - } - - auto attrId = internAttr(file, dict); - writeULEB128(attrId, out.bytes); - - // variant u8 - if (info->has_variant_u8) { - out.appendU8(ov->variant); +void Encoder::encodeKnownOpImmediates( + mlir::Operation &op, Buffer &out, const ptobc::v0::OpInfo &info, + const ptobc::v0::OpcodeAndVariant &variantInfo, + llvm::SmallVectorImpl &imms) { + switch (info.imm_kind) { + case 0x00: + return; + case 0x01: { + auto cmp = llvm::dyn_cast(&op); + if (!cmp) + throw std::runtime_error("imm_kind=cmpi_pred but op is not arith.cmpi"); + uint8_t predicate = 0; + switch (cmp.getPredicate()) { + case mlir::arith::CmpIPredicate::eq: + predicate = 0; + break; + case mlir::arith::CmpIPredicate::ne: + predicate = 1; + break; + case mlir::arith::CmpIPredicate::slt: + predicate = 2; + break; + case mlir::arith::CmpIPredicate::sle: + predicate = 3; + break; + case mlir::arith::CmpIPredicate::sgt: + predicate = 4; + break; + case mlir::arith::CmpIPredicate::sge: + predicate = 5; + break; + default: + throw std::runtime_error( + "unsupported arith.cmpi predicate (v0 supports only eq/ne/slt/sle/sgt/sge)"); } - - // immediates - llvm::SmallVector imms; - imms.clear(); - - if (info->imm_kind == 0x00) { - // none - } else if (info->imm_kind == 0x01) { - // arith.cmpi predicate - auto cmp = llvm::dyn_cast(&op); - if (!cmp) throw std::runtime_error("imm_kind=cmpi_pred but op is not arith.cmpi"); - uint8_t p; - switch (cmp.getPredicate()) { - case mlir::arith::CmpIPredicate::eq: p = 0; break; - case mlir::arith::CmpIPredicate::ne: p = 1; break; - case mlir::arith::CmpIPredicate::slt: p = 2; break; - case mlir::arith::CmpIPredicate::sle: p = 3; break; - case mlir::arith::CmpIPredicate::sgt: p = 4; break; - case mlir::arith::CmpIPredicate::sge: p = 5; break; - default: - throw std::runtime_error("unsupported arith.cmpi predicate (v0 supports only eq/ne/slt/sle/sgt/sge)"); - } - out.appendU8(p); - imms.push_back(p); - } else if (info->imm_kind == 0x02) { - // record_event/wait_event: event3(u8,u8,u8) - auto src = op.getAttrOfType("src_op"); - auto dst = op.getAttrOfType("dst_op"); - auto eid = op.getAttrOfType("event_id"); - if (!src || !dst || !eid) throw std::runtime_error("event op missing src_op/dst_op/event_id attrs"); - uint8_t a = uint8_t(src.getOpType()); - uint8_t b = uint8_t(dst.getOpType()); - uint8_t c = uint8_t(eid.getEvent()); - out.appendU8(a); - out.appendU8(b); - out.appendU8(c); - imms.push_back(a); - imms.push_back(b); - imms.push_back(c); - } else if (info->imm_kind == 0x05) { - // arith.constant: const_id(uLEB128) - auto cst = llvm::dyn_cast(&op); - if (!cst) throw std::runtime_error("imm_kind=const_id but op is not arith.constant"); - - mlir::Attribute a = cst.getValue(); - uint64_t cid = 0; - if (auto ia = llvm::dyn_cast(a)) { - uint64_t typeId = internType(file, cst.getType()); - const llvm::APInt &v = ia.getValue(); - if (v.getBitWidth() <= 64) { - cid = internConstInt64(typeId, v.getSExtValue()); - } else { - cid = internConstIntBits(typeId, v); - } - } else if (auto fa = llvm::dyn_cast(a)) { - uint64_t dtypeId = internType(file, cst.getType()); - cid = internConstFloatBits(dtypeId, fa.getValue().bitcastToAPInt()); - } else { - throw std::runtime_error("unsupported arith.constant attribute kind for compact v0"); - } - writeULEB128(cid, out.bytes); - imms.push_back(cid); - } else if (info->imm_kind == 0x06) { - // make_tensor_view: list_mode(u8), nshape(uLEB), nstrides(uLEB) - auto mtv = llvm::dyn_cast(&op); - if (!mtv) throw std::runtime_error("imm_kind=make_tensor_view but op is not pto.make_tensor_view"); - uint8_t lm = 0; // list_mode=0 (inline value_ids) - out.appendU8(lm); - writeULEB128(mtv.getShape().size(), out.bytes); - writeULEB128(mtv.getStrides().size(), out.bytes); - imms.push_back(lm); - imms.push_back(mtv.getShape().size()); - imms.push_back(mtv.getStrides().size()); - } else if (info->imm_kind == 0x07) { - // partition_view: list_mode(u8), noffsets(uLEB), nsizes(uLEB) - auto pv = llvm::dyn_cast(&op); - if (!pv) throw std::runtime_error("imm_kind=partition_view but op is not pto.partition_view"); - uint8_t lm = 0; - out.appendU8(lm); - writeULEB128(pv.getOffsets().size(), out.bytes); - writeULEB128(pv.getSizes().size(), out.bytes); - imms.push_back(lm); - imms.push_back(pv.getOffsets().size()); - imms.push_back(pv.getSizes().size()); - } else if (info->imm_kind == 0x08) { - // alloc_tile: optmask(u8) - auto at = llvm::dyn_cast(&op); - if (!at) throw std::runtime_error("imm_kind=alloc_tile but op is not pto.alloc_tile"); - uint8_t mask = 0; - if (at.getValidRow()) mask |= 0x1; - if (at.getValidCol()) mask |= 0x2; - out.appendU8(mask); - imms.push_back(mask); + out.appendU8(predicate); + imms.push_back(predicate); + return; + } + case 0x02: { + auto src = op.getAttrOfType("src_op"); + auto dst = op.getAttrOfType("dst_op"); + auto event = op.getAttrOfType("event_id"); + if (!src || !dst || !event) + throw std::runtime_error("event op missing src_op/dst_op/event_id attrs"); + uint8_t srcValue = uint8_t(src.getOpType()); + uint8_t dstValue = uint8_t(dst.getOpType()); + uint8_t eventValue = uint8_t(event.getEvent()); + out.appendU8(srcValue); + out.appendU8(dstValue); + out.appendU8(eventValue); + imms.append({srcValue, dstValue, eventValue}); + return; + } + case 0x05: { + auto cst = llvm::dyn_cast(&op); + if (!cst) + throw std::runtime_error("imm_kind=const_id but op is not arith.constant"); + + mlir::Attribute attr = cst.getValue(); + uint64_t constId = 0; + if (auto intAttr = llvm::dyn_cast(attr)) { + uint64_t typeId = internType(file, cst.getType()); + const llvm::APInt &value = intAttr.getValue(); + constId = value.getBitWidth() <= 64 ? internConstInt64(typeId, value.getSExtValue()) + : internConstIntBits(typeId, value); + } else if (auto floatAttr = llvm::dyn_cast(attr)) { + uint64_t typeId = internType(file, cst.getType()); + constId = internConstFloatBits(typeId, + floatAttr.getValue().bitcastToAPInt()); } else { - throw std::runtime_error("unknown imm_kind in v0 schema"); + throw std::runtime_error( + "unsupported arith.constant attribute kind for compact v0"); } + writeULEB128(constId, out.bytes); + imms.push_back(constId); + return; + } + case 0x06: { + auto mtv = llvm::dyn_cast(&op); + if (!mtv) + throw std::runtime_error( + "imm_kind=make_tensor_view but op is not pto.make_tensor_view"); + uint8_t listMode = 0; + out.appendU8(listMode); + writeULEB128(mtv.getShape().size(), out.bytes); + writeULEB128(mtv.getStrides().size(), out.bytes); + imms.append({listMode, uint64_t(mtv.getShape().size()), + uint64_t(mtv.getStrides().size())}); + return; + } + case 0x07: { + auto pv = llvm::dyn_cast(&op); + if (!pv) + throw std::runtime_error( + "imm_kind=partition_view but op is not pto.partition_view"); + uint8_t listMode = 0; + out.appendU8(listMode); + writeULEB128(pv.getOffsets().size(), out.bytes); + writeULEB128(pv.getSizes().size(), out.bytes); + imms.append({listMode, uint64_t(pv.getOffsets().size()), + uint64_t(pv.getSizes().size())}); + return; + } + case 0x08: { + auto at = llvm::dyn_cast(&op); + if (!at) + throw std::runtime_error( + "imm_kind=alloc_tile but op is not pto.alloc_tile"); + uint8_t mask = 0; + if (at.getValidRow()) + mask |= 0x1; + if (at.getValidCol()) + mask |= 0x2; + out.appendU8(mask); + imms.push_back(mask); + return; + } + default: + (void)variantInfo; + throw std::runtime_error("unknown imm_kind in v0 schema"); + } +} - // operands - auto emitOperands = [&](size_t n) { - if (op.getNumOperands() != n) { - throw std::runtime_error("operand count mismatch for op: " + fullName.str()); - } - for (auto v : op.getOperands()) { - writeULEB128(getValueId(v), out.bytes); - } - }; - - if (info->operand_mode == 0x00) { - emitOperands(info->num_operands); - - } else if (info->operand_mode == 0x01) { - auto n = ptobc::v0::lookupOperandsByVariant(ov->opcode, ov->variant); - if (!n) throw std::runtime_error("missing by-variant operand count"); - emitOperands(*n); - - } else if (info->operand_mode == 0x02) { - writeULEB128(op.getNumOperands(), out.bytes); - for (auto v : op.getOperands()) { - writeULEB128(getValueId(v), out.bytes); - } - - } else if (info->operand_mode == 0x03) { - // segmented (inline list_mode=0 only) - if (imms.size() < 3) throw std::runtime_error("segmented operands missing immediates"); - if (imms[0] != 0) throw std::runtime_error("list_mode=1 not implemented in ptobc encoder yet"); - const size_t base = info->num_operands; - const size_t n1 = size_t(imms[1]); - const size_t n2 = size_t(imms[2]); - emitOperands(base + n1 + n2); - - } else if (info->operand_mode == 0x04) { - // optional mask2 - if (imms.empty()) throw std::runtime_error("optmask operands missing immediate"); - uint8_t mask = uint8_t(imms[0]); - size_t n = ((mask & 0x1) ? 1 : 0) + ((mask & 0x2) ? 1 : 0); - emitOperands(n); - - } else { - throw std::runtime_error("unknown operand_mode in v0 schema"); +void Encoder::encodeKnownOpOperands( + mlir::Operation &op, Buffer &out, const ptobc::v0::OpInfo &info, + const ptobc::v0::OpcodeAndVariant &variantInfo, + llvm::ArrayRef imms) { + auto emitOperands = [&](size_t count) { + if (op.getNumOperands() != count) { + throw std::runtime_error("operand count mismatch for op: " + + op.getName().getStringRef().str()); } + for (auto value : op.getOperands()) + writeULEB128(getValueId(value), out.bytes); + }; - // explicit result type ids - if (info->result_type_mode == 0x01) { - if (op.getNumResults() != info->num_results) { - throw std::runtime_error("result count mismatch for op: " + fullName.str()); - } - for (auto res : op.getResults()) { - writeULEB128(internType(file, res.getType()), out.bytes); - } - } + switch (info.operand_mode) { + case 0x00: + emitOperands(info.num_operands); + return; + case 0x01: { + auto count = + ptobc::v0::lookupOperandsByVariant(variantInfo.opcode, variantInfo.variant); + if (!count) + throw std::runtime_error("missing by-variant operand count"); + emitOperands(*count); + return; + } + case 0x02: + writeULEB128(op.getNumOperands(), out.bytes); + for (auto value : op.getOperands()) + writeULEB128(getValueId(value), out.bytes); + return; + case 0x03: { + if (imms.size() < 3) + throw std::runtime_error("segmented operands missing immediates"); + if (imms[0] != 0) + throw std::runtime_error( + "list_mode=1 not implemented in ptobc encoder yet"); + emitOperands(size_t(info.num_operands) + size_t(imms[1]) + size_t(imms[2])); + return; + } + case 0x04: { + if (imms.empty()) + throw std::runtime_error("optmask operands missing immediate"); + uint8_t mask = uint8_t(imms.front()); + emitOperands(((mask & 0x1) ? 1 : 0) + ((mask & 0x2) ? 1 : 0)); + return; + } + default: + throw std::runtime_error("unknown operand_mode in v0 schema"); + } +} - // regions - if (op.getNumRegions() != info->num_regions) { - throw std::runtime_error("region count mismatch for op: " + fullName.str()); +void Encoder::encodeKnownOp(mlir::Operation &op, Buffer &out, + const ptobc::v0::OpInfo &info, + const ptobc::v0::OpcodeAndVariant &variantInfo) { + for (auto result : op.getResults()) + allocValueId(result); + + out.appendU16LE(variantInfo.opcode); + mlir::DictionaryAttr dict = op.getAttrDictionary(); + if (llvm::isa(&op)) + dict = stripAttr(op.getContext(), dict, "value"); + writeULEB128(internAttr(file, dict), out.bytes); + + if (info.has_variant_u8) + out.appendU8(variantInfo.variant); + + llvm::SmallVector imms; + encodeKnownOpImmediates(op, out, info, variantInfo, imms); + encodeKnownOpOperands(op, out, info, variantInfo, imms); + + if (info.result_type_mode == 0x01) { + if (op.getNumResults() != info.num_results) { + throw std::runtime_error("result count mismatch for op: " + + op.getName().getStringRef().str()); } - for (auto &r : op.getRegions()) { - encodeRegion(r, out); - } - - (void)resStart; - return; + for (auto result : op.getResults()) + writeULEB128(internType(file, result.getType()), out.bytes); } - if (!allowGeneric) { - throw std::runtime_error("op is not in v0 opcode table (and PTOBC_ALLOW_GENERIC is not set): " + fullName.str()); + if (op.getNumRegions() != info.num_regions) { + throw std::runtime_error("region count mismatch for op: " + + op.getName().getStringRef().str()); } + for (auto ®ion : op.getRegions()) + encodeRegion(region, out); +} - // === Generic op escape === +void Encoder::encodeGenericOp(mlir::Operation &op, Buffer &out) { out.appendU16LE(kOpcodeGeneric); + writeULEB128(internAttr(file, op.getAttrDictionary()), out.bytes); - // attr_id - auto attrId = internAttr(file, op.getAttrDictionary()); - writeULEB128(attrId, out.bytes); - - // op-name - auto opName = op.getName().getStringRef().str(); - auto opNameSid = file.strings.intern(opName); + auto opNameSid = file.strings.intern(op.getName().getStringRef().str()); writeULEB128(opNameSid, out.bytes); - // results writeULEB128(op.getNumResults(), out.bytes); - for (auto res : op.getResults()) { - allocValueId(res); - writeULEB128(internType(file, res.getType()), out.bytes); + for (auto result : op.getResults()) { + allocValueId(result); + writeULEB128(internType(file, result.getType()), out.bytes); } - // operands writeULEB128(op.getNumOperands(), out.bytes); - for (auto operand : op.getOperands()) { + for (auto operand : op.getOperands()) writeULEB128(getValueId(operand), out.bytes); - } - // regions writeULEB128(op.getNumRegions(), out.bytes); - for (auto& r : op.getRegions()) { - encodeRegion(r, out); + for (auto ®ion : op.getRegions()) + encodeRegion(region, out); +} + +void Encoder::encodeOp(mlir::Operation& op, Buffer& out) { + if (emitDebugInfo) { + uint64_t opId = nextOpId++; + recordOpLocation(opId, op); + } + + auto fullName = op.getName().getStringRef(); + auto variantInfo = ptobc::v0::lookupOpcodeAndVariantByFullName(fullName); + if (variantInfo) { + const auto *info = ptobc::v0::lookupByOpcode(variantInfo->opcode); + if (!info) + throw std::runtime_error("missing v0 opcode schema for op: " + + fullName.str()); + encodeKnownOp(op, out, *info, *variantInfo); + return; + } + + if (!allowGeneric) { + throw std::runtime_error( + "op is not in v0 opcode table (and PTOBC_ALLOW_GENERIC is not set): " + + fullName.str()); } + encodeGenericOp(op, out); } PTOBCFile encodeFromMLIRModule(mlir::ModuleOp module) { diff --git a/tools/ptobc/src/ptobc_decode_print.cpp b/tools/ptobc/src/ptobc_decode_print.cpp index 5a9e3348a..b5199c7c0 100644 --- a/tools/ptobc/src/ptobc_decode_print.cpp +++ b/tools/ptobc/src/ptobc_decode_print.cpp @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - #include "ptobc/mlir_helpers.h" #include "ptobc/ptobc_format.h" #include "ptobc/leb128.h" @@ -305,6 +300,51 @@ static mlir::DictionaryAttr getAttrDict(BuildCtx& bc, uint64_t aid) { static void buildRegionInto(BuildCtx& bc, Reader& r, mlir::Region& region); +static llvm::APInt rebuildAPIntFromBytes(llvm::ArrayRef bytes, + unsigned bitWidth) { + const unsigned numWords = (bitWidth + 63) / 64; + llvm::SmallVector words(numWords, 0); + for (unsigned i = 0; i < bytes.size(); ++i) { + unsigned word = i / 8; + unsigned off = (i % 8) * 8; + words[word] |= (uint64_t(bytes[i]) << off); + } + return llvm::APInt(bitWidth, words); +} + +static mlir::Attribute buildFloatConstAttr(BuildCtx &bc, + const ConstEntryParsed &entry) { + auto ty = getType(bc, entry.typeId); + auto floatType = mlir::dyn_cast(ty); + if (!floatType) + throw std::runtime_error("ConstFloatBits type is not FloatType"); + + unsigned bitWidth = floatType.getWidth(); + unsigned byteLen = (bitWidth + 7) / 8; + if (entry.floatBytes.size() != byteLen) + throw std::runtime_error("ConstFloatBits byte_len mismatch"); + + llvm::APInt bits = rebuildAPIntFromBytes(entry.floatBytes, bitWidth); + llvm::APFloat value(floatType.getFloatSemantics(), bits); + return mlir::FloatAttr::get(floatType, value); +} + +static mlir::Attribute buildIntegerConstAttr(BuildCtx &bc, + const ConstEntryParsed &entry) { + auto ty = getType(bc, entry.typeId); + auto intType = mlir::dyn_cast(ty); + if (!intType) + throw std::runtime_error("ConstIntBits type is not IntegerType"); + + unsigned bitWidth = intType.getWidth(); + unsigned byteLen = (bitWidth + 7) / 8; + if (entry.intBytes.size() != byteLen) + throw std::runtime_error("ConstIntBits byte_len mismatch"); + + llvm::APInt bits = rebuildAPIntFromBytes(entry.intBytes, bitWidth); + return mlir::IntegerAttr::get(intType, bits); +} + static mlir::Attribute buildConstAttr(BuildCtx &bc, uint64_t constId) { if (!bc.consts) throw std::runtime_error("constpool not available"); if (constId >= bc.consts->size()) throw std::runtime_error("const_id out of range"); @@ -320,54 +360,238 @@ static mlir::Attribute buildConstAttr(BuildCtx &bc, uint64_t constId) { return mlir::IntegerAttr::get(ty, e.intValue); } - if (e.tag == 0x02) { - auto ty = getType(bc, e.typeId); - auto ft = mlir::dyn_cast(ty); - if (!ft) throw std::runtime_error("ConstFloatBits type is not FloatType"); + if (e.tag == 0x02) + return buildFloatConstAttr(bc, e); - unsigned bitWidth = ft.getWidth(); - unsigned byteLen = (bitWidth + 7) / 8; - if (e.floatBytes.size() != byteLen) { - throw std::runtime_error("ConstFloatBits byte_len mismatch"); - } + if (e.tag == 0x04) + return buildIntegerConstAttr(bc, e); - const unsigned numWords = (bitWidth + 63) / 64; - llvm::SmallVector words(numWords, 0); - for (unsigned i = 0; i < byteLen; ++i) { - unsigned w = i / 8; - unsigned off = (i % 8) * 8; - words[w] |= (uint64_t(e.floatBytes[i]) << off); - } + throw std::runtime_error("unsupported const tag"); +} + +static void addAttrDictionary(mlir::OperationState &state, + mlir::DictionaryAttr dict) { + for (auto attr : dict) + state.addAttribute(attr.getName(), attr.getValue()); +} + +static void registerDecodedOp(BuildCtx &bc, uint64_t opId, mlir::Operation *op) { + if (!bc.opsById) + return; + if (opId >= bc.opsById->size()) + bc.opsById->resize(opId + 1, nullptr); + (*bc.opsById)[opId] = op; +} + +static void assignDecodedResults(BuildCtx &bc, size_t resStart, + mlir::Operation *op, size_t numResults) { + for (size_t i = 0; i < numResults; ++i) + bc.values[resStart + i] = op->getResult(i); +} + +static llvm::SmallVector readValueIds(Reader &r, size_t count) { + llvm::SmallVector ids; + ids.reserve(count); + for (size_t i = 0; i < count; ++i) + ids.push_back(r.readULEB()); + return ids; +} - llvm::APInt bits(bitWidth, words); - llvm::APFloat f(ft.getFloatSemantics(), bits); - return mlir::FloatAttr::get(ft, f); +struct KnownOpImmediates { + uint8_t cmpPred = 0; + uint8_t evA = 0; + uint8_t evB = 0; + uint8_t evC = 0; + uint64_t constId = 0; + uint8_t listMode = 0; + uint64_t n1 = 0; + uint64_t n2 = 0; + uint8_t optMask = 0; +}; + +static KnownOpImmediates readKnownOpImmediates(Reader &r, + const ptobc::v0::OpInfo &info) { + KnownOpImmediates imms; + switch (info.imm_kind) { + case 0x00: + return imms; + case 0x01: + imms.cmpPred = r.readU8(); + return imms; + case 0x02: + imms.evA = r.readU8(); + imms.evB = r.readU8(); + imms.evC = r.readU8(); + return imms; + case 0x05: + imms.constId = r.readULEB(); + return imms; + case 0x06: + case 0x07: + imms.listMode = r.readU8(); + imms.n1 = r.readULEB(); + imms.n2 = r.readULEB(); + return imms; + case 0x08: + imms.optMask = r.readU8(); + return imms; + default: + throw std::runtime_error("unknown imm_kind"); } +} - if (e.tag == 0x04) { - auto ty = getType(bc, e.typeId); - auto it = mlir::dyn_cast(ty); - if (!it) throw std::runtime_error("ConstIntBits type is not IntegerType"); +static llvm::SmallVector +readKnownOperandIds(BuildCtx &bc, Reader &r, uint16_t opcode, uint8_t variant, + const ptobc::v0::OpInfo &info, + const KnownOpImmediates &imms) { + switch (info.operand_mode) { + case 0x00: + return readValueIds(r, info.num_operands); + case 0x01: { + auto count = ptobc::v0::lookupOperandsByVariant(opcode, variant); + if (!count) + throw std::runtime_error("missing by-variant operand count"); + return readValueIds(r, *count); + } + case 0x02: + return readValueIds(r, r.readULEB()); + case 0x03: + if (imms.listMode != 0) + throw std::runtime_error("list_mode=1 not supported yet"); + return readValueIds(r, size_t(info.num_operands) + size_t(imms.n1) + + size_t(imms.n2)); + case 0x04: + return readValueIds(r, ((imms.optMask & 0x1) ? 1 : 0) + + ((imms.optMask & 0x2) ? 1 : 0)); + default: + (void)bc; + throw std::runtime_error("unknown operand_mode"); + } +} - unsigned bitWidth = it.getWidth(); - unsigned byteLen = (bitWidth + 7) / 8; - if (e.intBytes.size() != byteLen) { - throw std::runtime_error("ConstIntBits byte_len mismatch"); - } +static llvm::SmallVector +materializeOperands(BuildCtx &bc, llvm::ArrayRef operandIds) { + llvm::SmallVector operands; + operands.reserve(operandIds.size()); + for (uint64_t valueId : operandIds) { + if (valueId >= bc.values.size()) + throw std::runtime_error("operand value_id out of range"); + operands.push_back(bc.values[valueId]); + } + return operands; +} - const unsigned numWords = (bitWidth + 63) / 64; - llvm::SmallVector words(numWords, 0); - for (unsigned i = 0; i < byteLen; ++i) { - unsigned w = i / 8; - unsigned off = (i % 8) * 8; - words[w] |= (uint64_t(e.intBytes[i]) << off); - } +static mlir::Operation *buildGenericOpFromReader(BuildCtx &bc, Reader &r, + mlir::Block &block, + uint64_t opId, + uint64_t attrId) { + uint64_t nameSid = r.readULEB(); + if (nameSid >= bc.strings->size()) + throw std::runtime_error("bad op_name sid"); + std::string opName = (*bc.strings)[nameSid]; + + uint64_t nres = r.readULEB(); + llvm::SmallVector resultTypes; + resultTypes.reserve(nres); + + const size_t resStart = bc.values.size(); + for (uint64_t i = 0; i < nres; ++i) { + resultTypes.push_back(getType(bc, r.readULEB())); + bc.values.push_back(mlir::Value()); + } - llvm::APInt bits(bitWidth, words); - return mlir::IntegerAttr::get(it, bits); + auto operandIds = readValueIds(r, r.readULEB()); + auto operands = materializeOperands(bc, operandIds); + uint64_t nreg = r.readULEB(); + + mlir::OperationState state(mlir::UnknownLoc::get(bc.ctx), opName); + state.addOperands(operands); + state.addTypes(resultTypes); + addAttrDictionary(state, getAttrDict(bc, attrId)); + for (uint64_t i = 0; i < nreg; ++i) + (void)state.addRegion(); + + mlir::Operation *op = mlir::Operation::create(state); + block.getOperations().push_back(op); + registerDecodedOp(bc, opId, op); + assignDecodedResults(bc, resStart, op, nres); + for (uint64_t i = 0; i < nreg; ++i) + buildRegionInto(bc, r, op->getRegion(i)); + return op; +} + +static void addImmediateAttrs(BuildCtx &bc, mlir::OperationState &state, + const ptobc::v0::OpInfo &info, + const KnownOpImmediates &imms) { + switch (info.imm_kind) { + case 0x01: + state.addAttribute("predicate", + mlir::arith::CmpIPredicateAttr::get( + bc.ctx, mlir::arith::CmpIPredicate(imms.cmpPred))); + return; + case 0x02: + state.addAttribute("src_op", mlir::pto::SyncOpTypeAttr::get( + bc.ctx, mlir::pto::SyncOpType(imms.evA))); + state.addAttribute("dst_op", mlir::pto::SyncOpTypeAttr::get( + bc.ctx, mlir::pto::SyncOpType(imms.evB))); + state.addAttribute("event_id", + mlir::pto::EventAttr::get(bc.ctx, + mlir::pto::EVENT(imms.evC))); + return; + case 0x05: + state.addAttribute("value", buildConstAttr(bc, imms.constId)); + return; + default: + return; } +} - throw std::runtime_error("unsupported const tag"); +static mlir::Operation *buildKnownOpFromReader(BuildCtx &bc, Reader &r, + mlir::Block &block, uint64_t opId, + uint16_t opcode, + uint64_t attrId) { + const auto *info = ptobc::v0::lookupByOpcode(opcode); + if (!info) + throw std::runtime_error("missing opcode schema"); + + uint8_t variant = info->has_variant_u8 ? r.readU8() : 0; + KnownOpImmediates imms = readKnownOpImmediates(r, *info); + auto operandIds = readKnownOperandIds(bc, r, opcode, variant, *info, imms); + auto operands = materializeOperands(bc, operandIds); + + llvm::SmallVector resultTypes; + resultTypes.reserve(info->num_results); + if (info->result_type_mode == 0x01) { + for (unsigned i = 0; i < info->num_results; ++i) + resultTypes.push_back(getType(bc, r.readULEB())); + } else { + for (unsigned i = 0; i < info->num_results; ++i) + resultTypes.push_back(mlir::NoneType::get(bc.ctx)); + } + + const size_t resStart = bc.values.size(); + for (unsigned i = 0; i < info->num_results; ++i) + bc.values.push_back(mlir::Value()); + + const char *opNameC = ptobc::v0::fullNameFromOpcodeVariant(opcode, variant); + if (!opNameC) + throw std::runtime_error("failed to map opcode->name"); + + mlir::OperationState state(mlir::UnknownLoc::get(bc.ctx), opNameC); + state.addOperands(operands); + state.addTypes(resultTypes); + addAttrDictionary(state, getAttrDict(bc, attrId)); + addImmediateAttrs(bc, state, *info, imms); + for (unsigned i = 0; i < info->num_regions; ++i) + (void)state.addRegion(); + + mlir::Operation *op = mlir::Operation::create(state); + block.getOperations().push_back(op); + registerDecodedOp(bc, opId, op); + assignDecodedResults(bc, resStart, op, info->num_results); + for (unsigned i = 0; i < info->num_regions; ++i) + buildRegionInto(bc, r, op->getRegion(i)); + return op; } static void buildOpList(BuildCtx& bc, Reader& r, mlir::Block& block) { @@ -382,208 +606,11 @@ static void buildOpList(BuildCtx& bc, Reader& r, mlir::Block& block) { uint16_t opcode = r.readU16LE(); uint64_t attrId = r.readULEB(); - // Generic escape. if (opcode == kOpcodeGeneric) { - uint64_t nameSid = r.readULEB(); - if (nameSid >= bc.strings->size()) throw std::runtime_error("bad op_name sid"); - std::string opName = (*bc.strings)[nameSid]; - - uint64_t nres = r.readULEB(); - llvm::SmallVector resTypes; - resTypes.reserve(nres); - - const size_t resStart = bc.values.size(); - for (uint64_t i = 0; i < nres; ++i) { - uint64_t tid = r.readULEB(); - resTypes.push_back(getType(bc, tid)); - bc.values.push_back(mlir::Value()); - } - - uint64_t nops = r.readULEB(); - llvm::SmallVector operands; - operands.reserve(nops); - for (uint64_t i = 0; i < nops; ++i) { - uint64_t vid = r.readULEB(); - if (vid >= bc.values.size()) throw std::runtime_error("operand value_id out of range"); - operands.push_back(bc.values[vid]); - } - - uint64_t nreg = r.readULEB(); - - mlir::OperationState st(mlir::UnknownLoc::get(bc.ctx), opName); - st.addOperands(operands); - st.addTypes(resTypes); - - auto dict = getAttrDict(bc, attrId); - for (auto na : dict) { - st.addAttribute(na.getName(), na.getValue()); - } - - for (uint64_t ri = 0; ri < nreg; ++ri) (void)st.addRegion(); - - mlir::Operation* op = mlir::Operation::create(st); - block.getOperations().push_back(op); - - if (bc.opsById) { - if (opId >= bc.opsById->size()) bc.opsById->resize(opId + 1, nullptr); - (*bc.opsById)[opId] = op; - } - - for (uint64_t i = 0; i < nres; ++i) { - bc.values[resStart + i] = op->getResult(i); - } - - for (uint64_t ri = 0; ri < nreg; ++ri) { - buildRegionInto(bc, r, op->getRegion(ri)); - } + buildGenericOpFromReader(bc, r, block, opId, attrId); continue; } - - // Known compact op. - const auto *info = ptobc::v0::lookupByOpcode(opcode); - if (!info) throw std::runtime_error("missing opcode schema"); - - uint8_t variant = 0; - if (info->has_variant_u8) { - variant = r.readU8(); - } - - // immediates - uint8_t cmpPred = 0; - uint8_t evA = 0, evB = 0, evC = 0; - uint64_t constId = 0; - uint8_t listMode = 0; - uint64_t n1 = 0, n2 = 0; - uint8_t optMask = 0; - - switch (info->imm_kind) { - case 0x00: - break; - case 0x01: - cmpPred = r.readU8(); - break; - case 0x02: - evA = r.readU8(); - evB = r.readU8(); - evC = r.readU8(); - break; - case 0x05: - constId = r.readULEB(); - break; - case 0x06: - case 0x07: - listMode = r.readU8(); - n1 = r.readULEB(); - n2 = r.readULEB(); - break; - case 0x08: - optMask = r.readU8(); - break; - default: - throw std::runtime_error("unknown imm_kind"); - } - - auto readValueIds = [&](size_t n) { - llvm::SmallVector ids; - ids.reserve(n); - for (size_t i = 0; i < n; ++i) ids.push_back(r.readULEB()); - return ids; - }; - - llvm::SmallVector operandIds; - - if (info->operand_mode == 0x00) { - operandIds = readValueIds(info->num_operands); - } else if (info->operand_mode == 0x01) { - auto n = ptobc::v0::lookupOperandsByVariant(opcode, variant); - if (!n) throw std::runtime_error("missing by-variant operand count"); - operandIds = readValueIds(*n); - } else if (info->operand_mode == 0x02) { - uint64_t n = r.readULEB(); - operandIds = readValueIds(n); - } else if (info->operand_mode == 0x03) { - if (listMode != 0) throw std::runtime_error("list_mode=1 not supported yet"); - size_t n = size_t(info->num_operands) + size_t(n1) + size_t(n2); - operandIds = readValueIds(n); - } else if (info->operand_mode == 0x04) { - size_t n = ((optMask & 0x1) ? 1 : 0) + ((optMask & 0x2) ? 1 : 0); - operandIds = readValueIds(n); - } else { - throw std::runtime_error("unknown operand_mode"); - } - - llvm::SmallVector operands; - operands.reserve(operandIds.size()); - for (auto vid : operandIds) { - if (vid >= bc.values.size()) throw std::runtime_error("operand value_id out of range"); - operands.push_back(bc.values[vid]); - } - - // result types - llvm::SmallVector resTypes; - resTypes.reserve(info->num_results); - if (info->result_type_mode == 0x01) { - for (unsigned i = 0; i < info->num_results; ++i) { - uint64_t tid = r.readULEB(); - resTypes.push_back(getType(bc, tid)); - } - } else { - // v0 currently expects explicit for all result-producing ops. - for (unsigned i = 0; i < info->num_results; ++i) { - resTypes.push_back(mlir::NoneType::get(bc.ctx)); - } - } - - // Reserve value ids for results. - const size_t resStart = bc.values.size(); - for (unsigned i = 0; i < info->num_results; ++i) { - bc.values.push_back(mlir::Value()); - } - - // op name - const char *opNameC = ptobc::v0::fullNameFromOpcodeVariant(opcode, variant); - if (!opNameC) throw std::runtime_error("failed to map opcode->name"); - llvm::StringRef opName(opNameC); - - mlir::OperationState st(mlir::UnknownLoc::get(bc.ctx), opName); - st.addOperands(operands); - st.addTypes(resTypes); - - auto dict = getAttrDict(bc, attrId); - for (auto na : dict) { - st.addAttribute(na.getName(), na.getValue()); - } - - // immediate-derived attributes - if (info->imm_kind == 0x01) { - auto pred = (mlir::arith::CmpIPredicate)cmpPred; - st.addAttribute("predicate", mlir::arith::CmpIPredicateAttr::get(bc.ctx, pred)); - } else if (info->imm_kind == 0x02) { - st.addAttribute("src_op", mlir::pto::SyncOpTypeAttr::get(bc.ctx, (mlir::pto::SyncOpType)evA)); - st.addAttribute("dst_op", mlir::pto::SyncOpTypeAttr::get(bc.ctx, (mlir::pto::SyncOpType)evB)); - st.addAttribute("event_id", mlir::pto::EventAttr::get(bc.ctx, (mlir::pto::EVENT)evC)); - } else if (info->imm_kind == 0x05) { - st.addAttribute("value", buildConstAttr(bc, constId)); - } - - // regions - for (unsigned ri = 0; ri < info->num_regions; ++ri) (void)st.addRegion(); - - mlir::Operation *op = mlir::Operation::create(st); - block.getOperations().push_back(op); - - if (bc.opsById) { - if (opId >= bc.opsById->size()) bc.opsById->resize(opId + 1, nullptr); - (*bc.opsById)[opId] = op; - } - - for (unsigned i = 0; i < info->num_results; ++i) { - bc.values[resStart + i] = op->getResult(i); - } - - for (unsigned ri = 0; ri < info->num_regions; ++ri) { - buildRegionInto(bc, r, op->getRegion(ri)); - } + buildKnownOpFromReader(bc, r, block, opId, opcode, attrId); } } @@ -611,90 +638,121 @@ static void buildRegionInto(BuildCtx& bc, Reader& r, mlir::Region& region) { } } -static mlir::ModuleOp decodeToModule(mlir::MLIRContext& ctx, - const std::vector& strings, - const std::vector& types, - const std::vector& attrs, - const std::vector& constPool, - const std::vector& moduleBytes, - std::vector>* opsByFuncOut) { - const bool dbg = debugEnabled(); +struct FuncDecl { + std::string name; + mlir::FunctionType type; + mlir::DictionaryAttr attrs; + uint8_t flags = 0; +}; - Reader r{moduleBytes.data(), moduleBytes.data() + moduleBytes.size()}; +static uint64_t readModuleHeader(Reader &r, bool dbg) { uint8_t profile = r.readU8(); uint8_t indexWidth = r.readU8(); - if (dbg) llvm::errs() << "[ptobc] module: profile=" << unsigned(profile) << " indexWidth=" << unsigned(indexWidth) << "\n"; + if (dbg) { + llvm::errs() << "[ptobc] module: profile=" << unsigned(profile) + << " indexWidth=" << unsigned(indexWidth) << "\n"; + } uint64_t moduleAttrId = r.readULEB(); - uint64_t gcnt = r.readULEB(); - if (dbg) llvm::errs() << "[ptobc] module: moduleAttrId=" << moduleAttrId << " globals=" << gcnt << "\n"; - for (uint64_t i = 0; i < gcnt; ++i) { - throw std::runtime_error("globals not supported"); + uint64_t globalCount = r.readULEB(); + if (dbg) { + llvm::errs() << "[ptobc] module: moduleAttrId=" << moduleAttrId + << " globals=" << globalCount << "\n"; } + if (globalCount != 0) + throw std::runtime_error("globals not supported"); + return moduleAttrId; +} - uint64_t fcnt = r.readULEB(); - if (dbg) llvm::errs() << "[ptobc] module: funcs=" << fcnt << "\n"; +static std::vector readFunctionDecls(BuildCtx &bc, Reader &r, + bool dbg) { + uint64_t funcCount = r.readULEB(); + if (dbg) + llvm::errs() << "[ptobc] module: funcs=" << funcCount << "\n"; - struct FuncDecl { std::string name; mlir::FunctionType type; mlir::DictionaryAttr attrs; uint8_t flags; }; std::vector decls; - decls.reserve(fcnt); - - std::vector consts; - parseConstPoolSection(constPool, consts); - - BuildCtx bc{&ctx, &strings, &types, &attrs, &consts, {}, nullptr, nullptr}; - - for (uint64_t i = 0; i < fcnt; ++i) { + decls.reserve(funcCount); + for (uint64_t i = 0; i < funcCount; ++i) { uint64_t nameSid = r.readULEB(); uint64_t ftypeId = r.readULEB(); uint8_t flags = r.readU8(); uint64_t fattrId = r.readULEB(); - if (nameSid >= strings.size()) throw std::runtime_error("bad func name sid"); - if (ftypeId >= types.size()) throw std::runtime_error("bad func type id"); - - if (dbg) llvm::errs() << "[ptobc] func[" << i << "]: nameSid=" << nameSid << " ftypeId=" << ftypeId << " flags=" << unsigned(flags) << " fattrId=" << fattrId << "\n"; - - auto ty = parseType(ctx, types.at(ftypeId).asmStr); - auto fty = mlir::dyn_cast(ty); - if (!fty) throw std::runtime_error("func type parse failed"); + if (nameSid >= bc.strings->size()) + throw std::runtime_error("bad func name sid"); + if (ftypeId >= bc.types->size()) + throw std::runtime_error("bad func type id"); + + if (dbg) { + llvm::errs() << "[ptobc] func[" << i << "]: nameSid=" << nameSid + << " ftypeId=" << ftypeId << " flags=" << unsigned(flags) + << " fattrId=" << fattrId << "\n"; + } - decls.push_back({strings[nameSid], fty, getAttrDict(bc, fattrId), flags}); + auto type = parseType(*bc.ctx, bc.types->at(ftypeId).asmStr); + auto funcType = mlir::dyn_cast(type); + if (!funcType) + throw std::runtime_error("func type parse failed"); + decls.push_back( + {bc.strings->at(nameSid), funcType, getAttrDict(bc, fattrId), flags}); } + return decls; +} - auto module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx)); +static void applyAttrDictionary(mlir::Operation *op, mlir::DictionaryAttr dict) { + for (auto attr : dict) + op->setAttr(attr.getName(), attr.getValue()); +} - // Apply module attrs - auto modDict = getAttrDict(bc, moduleAttrId); - for (auto na : modDict) { - module->setAttr(na.getName(), na.getValue()); +static void buildFunctionBody(BuildCtx &bc, Reader &r, mlir::func::FuncOp fn, + uint8_t flags, bool dbg, + std::vector> *opsByFuncOut) { + if ((flags & 0x1) != 0) { + if (opsByFuncOut) + opsByFuncOut->push_back({}); + return; } - for (uint64_t i = 0; i < fcnt; ++i) { - if (dbg) llvm::errs() << "[ptobc] building func body: " << decls[i].name << "\n"; - - auto fn = mlir::func::FuncOp::create(mlir::UnknownLoc::get(&ctx), decls[i].name, decls[i].type); - if (dbg) llvm::errs() << "[ptobc] created func op\n"; - for (auto na : decls[i].attrs) { - fn->setAttr(na.getName(), na.getValue()); - } - - if ((decls[i].flags & 0x1) == 0) { - // decode body region - bc.values.clear(); + bc.values.clear(); + uint64_t nextOpId = 0; + std::vector opsById; + bc.nextOpId = &nextOpId; + bc.opsById = &opsById; + buildRegionInto(bc, r, fn.getBody()); + if (dbg) { + llvm::errs() << "[ptobc] func body built ok: values=" << bc.values.size() + << " ops=" << opsById.size() << "\n"; + } + if (opsByFuncOut) + opsByFuncOut->push_back(std::move(opsById)); +} - uint64_t nextOpId = 0; - std::vector opsById; - bc.nextOpId = &nextOpId; - bc.opsById = &opsById; +static mlir::ModuleOp decodeToModule(mlir::MLIRContext& ctx, + const std::vector& strings, + const std::vector& types, + const std::vector& attrs, + const std::vector& constPool, + const std::vector& moduleBytes, + std::vector>* opsByFuncOut) { + const bool dbg = debugEnabled(); - buildRegionInto(bc, r, fn.getBody()); - if (dbg) llvm::errs() << "[ptobc] func body built ok: values=" << bc.values.size() << " ops=" << opsById.size() << "\n"; + Reader r{moduleBytes.data(), moduleBytes.data() + moduleBytes.size()}; + std::vector consts; + parseConstPoolSection(constPool, consts); + BuildCtx bc{&ctx, &strings, &types, &attrs, &consts, {}, nullptr, nullptr}; + uint64_t moduleAttrId = readModuleHeader(r, dbg); + std::vector decls = readFunctionDecls(bc, r, dbg); - if (opsByFuncOut) opsByFuncOut->push_back(std::move(opsById)); - } else { - if (opsByFuncOut) opsByFuncOut->push_back({}); - } + auto module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx)); + applyAttrDictionary(module.getOperation(), getAttrDict(bc, moduleAttrId)); + for (const auto &decl : decls) { + if (dbg) + llvm::errs() << "[ptobc] building func body: " << decl.name << "\n"; + auto fn = mlir::func::FuncOp::create(mlir::UnknownLoc::get(&ctx), decl.name, + decl.type); + if (dbg) llvm::errs() << "[ptobc] created func op\n"; + applyAttrDictionary(fn, decl.attrs); + buildFunctionBody(bc, r, fn, decl.flags, dbg, opsByFuncOut); module.push_back(fn); } @@ -702,6 +760,38 @@ static mlir::ModuleOp decodeToModule(mlir::MLIRContext& ctx, return module; } +static std::pair> readSection(Reader &r, bool dbg) { + uint8_t sid = r.readU8(); + uint32_t sectionLen = r.readU32LE(); + auto bytes = r.readBytes(sectionLen); + if (dbg) + llvm::errs() << "[ptobc] section id=" << unsigned(sid) + << " len=" << sectionLen << "\n"; + return {sid, bytes}; +} + +static void applyDebugLocations(mlir::MLIRContext &ctx, + const std::vector &strings, + const DebugInfo &dbgInfo, + const std::vector> &opsByFunc) { + for (const auto &location : dbgInfo.locations) { + if (location.funcId >= opsByFunc.size()) + continue; + const auto &ops = opsByFunc[location.funcId]; + if (location.opId >= ops.size()) + continue; + mlir::Operation *op = ops[location.opId]; + if (!op || location.fileId >= dbgInfo.files.size()) + continue; + const auto &file = dbgInfo.files[location.fileId]; + if (file.pathSid >= strings.size()) + continue; + op->setLoc(mlir::FileLineColLoc::get(&ctx, strings[file.pathSid], + unsigned(location.sl), + unsigned(location.sc))); + } +} + mlir::OwningOpRef decodePTOBCToModule(llvm::ArrayRef fileBytes, mlir::MLIRContext &ctx) { const bool dbg = debugEnabled(); @@ -716,25 +806,16 @@ decodePTOBCToModule(llvm::ArrayRef fileBytes, mlir::MLIRContext &ctx) { if (payloadLen != fileBytes.size() - 14) throw std::runtime_error("payload_len mismatch"); Reader r{fileBytes.data() + 14, fileBytes.data() + fileBytes.size()}; - - auto readSection = [&]() -> std::pair> { - uint8_t sid = r.readU8(); - uint32_t slen = r.readU32LE(); - auto bytes = r.readBytes(slen); - if (dbg) llvm::errs() << "[ptobc] section id=" << unsigned(sid) << " len=" << slen << "\n"; - return {sid, bytes}; - }; - - auto [s1, d1] = readSection(); - auto [s2, d2] = readSection(); - auto [s3, d3] = readSection(); - auto [s4, d4] = readSection(); - auto [s6, d6] = readSection(); + auto [s1, d1] = readSection(r, dbg); + auto [s2, d2] = readSection(r, dbg); + auto [s3, d3] = readSection(r, dbg); + auto [s4, d4] = readSection(r, dbg); + auto [s6, d6] = readSection(r, dbg); std::optional dbgInfo; // Optional trailing sections: DEBUGINFO, EXTRA. while (r.p != r.end) { - auto [sid, sec] = readSection(); + auto [sid, sec] = readSection(r, dbg); if (sid == kSectionDebugInfo) { if (dbgInfo) throw std::runtime_error("duplicate DEBUGINFO section"); dbgInfo = parseDebugInfoSection(sec); @@ -774,22 +855,8 @@ decodePTOBCToModule(llvm::ArrayRef fileBytes, mlir::MLIRContext &ctx) { auto module = decodeToModule(ctx, strings, types, attrs, d4, d6, dbgInfo ? &opsByFunc : nullptr); // Apply op locations from DEBUGINFO (best-effort). - if (dbgInfo) { - for (const auto &l : dbgInfo->locations) { - if (l.funcId >= opsByFunc.size()) continue; - auto &ops = opsByFunc[l.funcId]; - if (l.opId >= ops.size()) continue; - mlir::Operation *op = ops[l.opId]; - if (!op) continue; - if (l.fileId >= dbgInfo->files.size()) continue; - const auto &f = dbgInfo->files[l.fileId]; - if (f.pathSid >= strings.size()) continue; - - auto path = strings[f.pathSid]; - auto loc = mlir::FileLineColLoc::get(&ctx, path, (unsigned)l.sl, (unsigned)l.sc); - op->setLoc(loc); - } - } + if (dbgInfo) + applyDebugLocations(ctx, strings, *dbgInfo, opsByFunc); return module; } diff --git a/tools/ptobc/src/ptobc_opcodes_v0.cpp b/tools/ptobc/src/ptobc_opcodes_v0.cpp new file mode 100644 index 000000000..64d4364f7 --- /dev/null +++ b/tools/ptobc/src/ptobc_opcodes_v0.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "ptobc_opcodes_v0.h" + +namespace ptobc::v0 { +namespace { + +struct FullNameEntry { + const char *name; + OpcodeAndVariant value; +}; + +struct VariantOperandEntry { + uint16_t opcode; + uint8_t variant; + int operandCount; +}; + +inline constexpr FullNameEntry kVariantNames[] = { + {"pto.section.cube", {0x0006, 1, 0}}, + {"pto.section.vector", {0x0006, 1, 1}}, + {"pto.tgemv", {0x102A, 1, 0}}, + {"pto.tgemv.acc", {0x102A, 1, 1}}, + {"pto.tgemv.bias", {0x102A, 1, 2}}, + {"pto.tgemv.mx", {0x102A, 1, 3}}, + {"pto.tmatmul", {0x1032, 1, 0}}, + {"pto.tmatmul.acc", {0x1032, 1, 1}}, + {"pto.tmatmul.bias", {0x1032, 1, 2}}, + {"pto.tmatmul.mx", {0x1033, 1, 0}}, + {"pto.tmatmul.mx.acc", {0x1033, 1, 1}}, + {"pto.tmatmul.mx.bias", {0x1033, 1, 2}}, +}; + +inline constexpr VariantOperandEntry kVariantOperands[] = { + {0x102A, 0, 3}, {0x102A, 1, 4}, {0x102A, 2, 4}, {0x102A, 3, 5}, + {0x1032, 0, 3}, {0x1032, 1, 4}, {0x1032, 2, 4}, + {0x1033, 0, 5}, {0x1033, 1, 6}, {0x1033, 2, 6}, +}; + +} // namespace + +const OpInfo *lookupByOpcode(uint16_t opcode) { + size_t lo = 0; + size_t hi = sizeof(kOpTable) / sizeof(kOpTable[0]); + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + uint16_t value = kOpTable[mid].opcode; + if (value == opcode) + return &kOpTable[mid]; + if (value < opcode) + lo = mid + 1; + else + hi = mid; + } + return nullptr; +} + +std::optional lookupOpcodeByName(llvm::StringRef name) { + for (const OpInfo &info : kOpTable) { + if (name == info.name) + return info.opcode; + } + return std::nullopt; +} + +const OpInfo *lookupByName(llvm::StringRef name) { + auto opcode = lookupOpcodeByName(name); + if (!opcode) + return nullptr; + return lookupByOpcode(*opcode); +} + +std::optional +lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { + for (const FullNameEntry &entry : kVariantNames) { + if (fullName == entry.name) + return entry.value; + } + + auto opcode = lookupOpcodeByName(fullName); + if (!opcode) + return std::nullopt; + return OpcodeAndVariant{*opcode, 0, 0}; +} + +const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { + const OpInfo *info = lookupByOpcode(opcode); + if (!info) + return nullptr; + if (!info->has_variant_u8) + return info->name; + + for (const FullNameEntry &entry : kVariantNames) { + if (entry.value.opcode == opcode && entry.value.variant == variant) + return entry.name; + } + return info->name; +} + +std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { + for (const VariantOperandEntry &entry : kVariantOperands) { + if (entry.opcode == opcode && entry.variant == variant) + return entry.operandCount; + } + return std::nullopt; +} + +} // namespace ptobc::v0 diff --git a/tools/ptobc/testdata/add_static_multicore.pto b/tools/ptobc/testdata/add_static_multicore.pto index 7f1524101..69c9482ca 100644 --- a/tools/ptobc/testdata/add_static_multicore.pto +++ b/tools/ptobc/testdata/add_static_multicore.pto @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + module { func.func @vec_add_kernel_2d_dynamic(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32) { %c0 = arith.constant 0 : index @@ -31,4 +39,3 @@ module { return } } - diff --git a/tools/ptobc/testdata/matmul_static_singlecore.pto b/tools/ptobc/testdata/matmul_static_singlecore.pto index 0d14dcba9..f3a200166 100644 --- a/tools/ptobc/testdata/matmul_static_singlecore.pto +++ b/tools/ptobc/testdata/matmul_static_singlecore.pto @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + module attributes {"pto.device-spec" = "Ascend910B1"} { func.func @RunTMATMULSplitK(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: i1) { pto.section.cube { @@ -64,4 +72,3 @@ module attributes {"pto.device-spec" = "Ascend910B1"} { return } } - diff --git a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto index 220414408..424c72ff7 100644 --- a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto +++ b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto @@ -1,3 +1,11 @@ +// Copyright (c) 2025 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + module { func.func @recent_ops_v0() { %c0 = arith.constant 0 : index diff --git a/tools/ptobc/testdata/tprint_v0_roundtrip.pto b/tools/ptobc/testdata/tprint_v0_roundtrip.pto index 9167d9b48..e61ccda5d 100644 --- a/tools/ptobc/testdata/tprint_v0_roundtrip.pto +++ b/tools/ptobc/testdata/tprint_v0_roundtrip.pto @@ -1,3 +1,11 @@ +// Copyright (c) 2025 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + module { func.func @tprint_v0() { %0 = pto.alloc_tile : !pto.tile_buf diff --git a/tools/ptobc/testdata/trowexpandsub_v0_roundtrip.pto b/tools/ptobc/testdata/trowexpandsub_v0_roundtrip.pto index 00968e3fe..a7f368b47 100644 --- a/tools/ptobc/testdata/trowexpandsub_v0_roundtrip.pto +++ b/tools/ptobc/testdata/trowexpandsub_v0_roundtrip.pto @@ -1,3 +1,11 @@ +// Copyright (c) 2025 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + module { func.func @trowexpandsub_v0(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { %c0 = arith.constant 0 : index