From 8be3bdae1dcc82fa040cbdcdf96e33ed5dfeaba3 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Tue, 24 Mar 2026 20:29:49 +0800 Subject: [PATCH 1/2] WIP: tile-native pointer_cast pipeline and PlanMemory rewiring --- docs/designs/tilebuf-planmemory-phase1.md | 50 +++ include/PTO/IR/PTOOps.td | 4 +- lib/PTO/Transforms/AllocToPointerCast.cpp | 125 ++++---- lib/PTO/Transforms/AllocToPointerCast.h | 15 +- lib/PTO/Transforms/CMakeLists.txt | 1 + .../Transforms/InsertSync/PTOIRTranslator.cpp | 21 +- lib/PTO/Transforms/PTOPlanMemory.cpp | 125 ++++++-- lib/PTO/Transforms/PTOPlanMemory.h | 16 +- lib/PTO/Transforms/PTOToEmitC.cpp | 202 ++++++++++++- lib/PTO/Transforms/TileBufferSemantics.cpp | 284 ++++++++++++++++++ lib/PTO/Transforms/TileBufferSemantics.h | 72 +++++ lib/PTO/Transforms/Utils.cpp | 56 ++-- lib/PTO/Transforms/Utils.h | 18 +- test/basic/set_validshape_local_lowering.pto | 6 +- test/basic/tilebuf_auto_addr_assign.pto | 23 ++ test/basic/tilebuf_manual_addr_preserve.pto | 19 ++ test/basic/tilebuf_root_trace.pto | 33 ++ test/basic/tilebuf_semantic_smoke.pto | 27 ++ .../samples/planmemory/tilebuf_alias_chain.py | 36 +++ .../tilebuf_planmemory_auto_addr.py | 16 + test/samples/runop.sh | 30 +- tools/ptoas/ptoas.cpp | 3 +- 22 files changed, 1034 insertions(+), 148 deletions(-) create mode 100644 docs/designs/tilebuf-planmemory-phase1.md create mode 100644 lib/PTO/Transforms/TileBufferSemantics.cpp create mode 100644 lib/PTO/Transforms/TileBufferSemantics.h create mode 100644 test/basic/tilebuf_auto_addr_assign.pto create mode 100644 test/basic/tilebuf_manual_addr_preserve.pto create mode 100644 test/basic/tilebuf_root_trace.pto create mode 100644 test/basic/tilebuf_semantic_smoke.pto create mode 100644 test/samples/planmemory/tilebuf_alias_chain.py create mode 100644 test/samples/planmemory/tilebuf_planmemory_auto_addr.py diff --git a/docs/designs/tilebuf-planmemory-phase1.md b/docs/designs/tilebuf-planmemory-phase1.md new file mode 100644 index 00000000..e205688d --- /dev/null +++ b/docs/designs/tilebuf-planmemory-phase1.md @@ -0,0 +1,50 @@ +# Tile Buffer -> PlanMemory (Phase-1) + +## Scope +- Base: `origin/main` +- Phase-1 only: `tile_buffer -> PlanMemory` +- Explicitly out of scope: + - Sync migration + - New MultiBuffer capabilities + +## Why +PlanMemory previously consumed mainly memref-centric alias/shape/space signals. +Tile metadata (`bind_tile/subset/bitcast/treshape`) was available but not normalized +as a reusable semantic layer. + +This phase introduces a tile semantic input path while keeping the core planner +(`MultiSpecPlan`, rollback/reuse) unchanged. + +## Changes +1. Unified tile semantic extraction in `Utils`: +- alias unification: `bind_tile/subset/bitcast/treshape` + memref view-like ops +- root traceback: `tracebackBufferRoot(...)` +- semantic record: `TileBufferSemantics` (root/scope/shape/valid/config/view-kind/bits) + +2. PlanMemory liveness/buffer info wiring: +- `MemLivenessAnalysis` uses unified alias API +- local buffer definition accepts `memref.alloc` and `pto.alloc_tile` +- `GetBufferInfo` prefers tile-native semantic extraction and keeps a legacy fallback + +3. No algorithm rewrite: +- Allocation/reuse/rollback algorithm unchanged +- Boundary fallback remains internal (no new user-visible switch) + +## Capability -> Test Mapping +- Unified semantic smoke: + - `test/basic/tilebuf_semantic_smoke.pto` +- Alias/root trace across bind + view chain: + - `test/basic/tilebuf_root_trace.pto` +- View-like alias chain (`subset -> treshape -> bitcast`) stability: + - `test/samples/planmemory/tilebuf_alias_chain.py` + - `test/samples/runop.sh` check for `TRESHAPE` + `TASSIGN` +- PlanMemory auto-address reachability: + - `test/basic/tilebuf_auto_addr_assign.pto` + - `test/samples/planmemory/tilebuf_planmemory_auto_addr.py` + - `test/samples/runop.sh` check for `TASSIGN` + `TPRINT` +- Address contract (manual addr preserve): + - `test/basic/tilebuf_manual_addr_preserve.pto` + +## Next (Phase-2) +- Sync analysis/input migration to consume the same tile semantic layer. +- Remove remaining internal fallback branches after boundary coverage is complete. diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index a5d0f9cd..feb38624 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1007,7 +1007,7 @@ def TMovOp : PTO_TOp<"tmov", [ //===----------------------------------------------------------------------===// def PointerCastOp : PTO_Op<"pointer_cast", [AttrSizedOperandSegments, Pure]> { - let summary = "Casts an integer address to a MemRef with optional valid dims"; + let summary = "Binds an integer address to a tile buffer descriptor"; // 参数定义 (保持 Optional) let arguments = (ins @@ -1017,7 +1017,7 @@ def PointerCastOp : PTO_Op<"pointer_cast", [AttrSizedOperandSegments, Pure]> { OptionalAttr:$config ); - let results = (outs Res:$result); + let results = (outs TileBufType:$result); // Assembly Format (去掉了 []) let assemblyFormat = [{ diff --git a/lib/PTO/Transforms/AllocToPointerCast.cpp b/lib/PTO/Transforms/AllocToPointerCast.cpp index c28e9e4b..d5af5a67 100644 --- a/lib/PTO/Transforms/AllocToPointerCast.cpp +++ b/lib/PTO/Transforms/AllocToPointerCast.cpp @@ -1,4 +1,4 @@ -//===- AllocToPointerCast.cpp - convert memref.AllocOp to pto.pointercastOp.// +//===- AllocToPointerCast.cpp - convert alloc_tile to pto.pointer_cast. -------// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,7 +8,7 @@ #include "AllocToPointerCast.h" #include "PTO/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -22,29 +22,21 @@ using namespace mlir::pto; namespace {} // namespace -LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( - memref::AllocOp op, PatternRewriter &rewriter) const { - const auto ¤tMemRefType = cast(op.getType()); - - // Preserve tile config carried by the downstream bind_tile user. Losing this - // metadata here makes PointerCast lowering fall back to RowMajor defaults, - // which can generate illegal intermediate TRESHAPE sequences. - TileBufConfigAttr configAttr; - for (Operation *user : op.getResult().getUsers()) { - auto bind = dyn_cast(user); - if (!bind || bind.getSource() != op.getResult()) - continue; - if (!configAttr) { - configAttr = bind.getConfigAttr(); - continue; - } - if (configAttr != bind.getConfigAttr()) { - op.emitWarning("alloc has multiple bind_tile users with different configs; " - "using the first one"); - break; - } - } - +LogicalResult +AllocTileOpToPointerCastOpPattern::matchAndRewrite(pto::AllocTileOp op, + PatternRewriter &rewriter) const { + // Manual-address alloc_tile is already fully bound and must not be remapped. + if (op.getAddr()) + return failure(); + + auto tileType = dyn_cast(op.getResult().getType()); + if (!tileType) + return failure(); + + // Keep config from the tile descriptor so lowering can generate the exact + // Tile<...> type token (layout/fractal/pad) without memref-side recovery. + TileBufConfigAttr configAttr = tileType.getConfigAttr(); + constexpr uint64_t kAlign = 4096; auto iter = buffer2Offsets.find(op.getResult()); @@ -55,30 +47,31 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( offsets = iter->second; if (offsets.empty()) { - // Estimate buffer size (best-effort). Most PTO tile buffers are 32x32 and - // naturally align to 4096 bytes. + // Estimate tile size in bytes using the static tile descriptor. uint64_t bytes = kAlign; - if (auto memrefTy = dyn_cast(currentMemRefType)) { - uint64_t elemBytes = 0; - Type elemTy = memrefTy.getElementType(); - if (elemTy.isF16()) elemBytes = 2; - else if (elemTy.isF32()) elemBytes = 4; - else if (auto it = dyn_cast(elemTy)) elemBytes = it.getWidth() / 8; - - if (elemBytes != 0) { - uint64_t numel = 1; - bool allStatic = true; - for (int64_t d : memrefTy.getShape()) { - if (d == ShapedType::kDynamic) { - allStatic = false; - break; - } - numel *= static_cast(d); + uint64_t elemBytes = 0; + Type elemTy = tileType.getElementType(); + if (elemTy.isF16() || elemTy.isBF16()) + elemBytes = 2; + else if (elemTy.isF32()) + elemBytes = 4; + else if (auto it = dyn_cast(elemTy)) + elemBytes = it.getWidth() / 8; + + if (elemBytes != 0) { + uint64_t numel = 1; + bool allStatic = true; + for (int64_t d : tileType.getShape()) { + if (d == ShapedType::kDynamic) { + allStatic = false; + break; } - if (allStatic && numel != 0) - bytes = numel * elemBytes; + numel *= static_cast(d); } + if (allStatic && numel != 0) + bytes = numel * elemBytes; } + uint64_t stride = ((bytes + kAlign - 1) / kAlign) * kAlign; uint64_t off = fallbackNextOffset; fallbackNextOffset += std::max(stride, kAlign); @@ -93,34 +86,34 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite( addrs.push_back(constantIntOffsetOp); } - // [修改 1] 从 ValueRange 中拆解出 row 和 col - // memref.alloc 的 getDynamicSizes() 返回的是变长列表。 - // 既然我们只支持 2D Tile,且如果是动态 shape 通常两个维度都是动态的 (?x?), - // 我们直接按顺序提取。 + // Preserve valid-shape contract: + // - dynamic valid dims: forward alloc_tile operands + // - static valid dims: materialize constants from TileBufType + // This keeps semantics identical to alloc_tile across PlanMemory rewrite. Value vRow, vCol; - auto dynSizes = op.getDynamicSizes(); - - if (dynSizes.size() >= 2) { - vRow = dynSizes[0]; - vCol = dynSizes[1]; - } else if (dynSizes.size() == 1) { - // 极其罕见的混合情况 (例如 32x?),视具体需求处理,这里默认取第一个 - // 或者根据维度索引判断是 row 还是 col,这里暂时从简 - vCol = dynSizes[0]; + vRow = op.getValidRow(); + vCol = op.getValidCol(); + auto validShape = tileType.getValidShape(); + if (validShape.size() >= 2) { + auto indexType = rewriter.getIndexType(); + Location loc = op.getLoc(); + if (!vRow && validShape[0] >= 0) { + vRow = rewriter.create( + loc, indexType, rewriter.getIndexAttr(validShape[0])); + } + if (!vCol && validShape[1] >= 0) { + vCol = rewriter.create( + loc, indexType, rewriter.getIndexAttr(validShape[1])); + } } - // [修改 2] 调用新的 Builder 签名 - // 1. ValueRange(addrs) -> 传递物理地址列表 - // 2. vRow ? vRow : Value() -> 传递 Value 对象(如果为空则传空 Value) - // 3. TileBufConfigAttr() -> 传递空 Attribute 对象 (不能传 nullptr) - + // Build tile-native pointer_cast with assigned physical address. auto ptoPointerCastOp = rewriter.create( - op.getLoc(), - currentMemRefType, + op.getLoc(), tileType, ValueRange(addrs), // addrs vRow ? vRow : Value(), // valid_row vCol ? vCol : Value(), // valid_col - configAttr // preserve bind_tile config when available + configAttr // config from tile descriptor ); rewriter.replaceOp(op, ptoPointerCastOp->getResults()); diff --git a/lib/PTO/Transforms/AllocToPointerCast.h b/lib/PTO/Transforms/AllocToPointerCast.h index 8b49d82e..a8abb18d 100644 --- a/lib/PTO/Transforms/AllocToPointerCast.h +++ b/lib/PTO/Transforms/AllocToPointerCast.h @@ -1,4 +1,4 @@ -//===- AllocToPointerCast.h --Convert memref.AllocOp to pto.pointercastOp-===// +//===- AllocToPointerCast.h --Convert pto.alloc_tile to pto.pointer_cast ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,15 +8,14 @@ #define LLVM_PROJECT_ALLOCTOPOINTERCAST_H #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "llvm/ADT/SmallSet.h" namespace mlir { namespace pto { -class MemrefAllocaOpToPointerCastOpPattern - : public OpRewritePattern { +class AllocTileOpToPointerCastOpPattern + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; /// map from buffer to its allocated addresses /// note: the buffer which does multibuffer n optimization will be allocated n @@ -24,10 +23,10 @@ class MemrefAllocaOpToPointerCastOpPattern DenseMap> buffer2Offsets; mutable uint64_t fallbackNextOffset = 0; - explicit MemrefAllocaOpToPointerCastOpPattern( + explicit AllocTileOpToPointerCastOpPattern( MLIRContext *context, DenseMap> buffer2Offsets) - : OpRewritePattern(context), + : OpRewritePattern(context), buffer2Offsets(std::move(buffer2Offsets)) { // Seed fallback offsets above any known planned offsets to reduce collisions. constexpr uint64_t kAlign = 4096; @@ -38,7 +37,7 @@ class MemrefAllocaOpToPointerCastOpPattern } fallbackNextOffset = ((maxOff + kAlign - 1) / kAlign) * kAlign; } - LogicalResult matchAndRewrite(memref::AllocOp op, + LogicalResult matchAndRewrite(pto::AllocTileOp op, PatternRewriter &rewriter) const final; }; diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 378690b4..fc297aea 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(PTOTransforms PTOViewToMemref.cpp PTOToEmitC.cpp Utils.cpp + TileBufferSemantics.cpp OptMemPlanForPipeline.cpp AllocToPointerCast.cpp InferPTOMemScope.cpp diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 33aec28b..1e0a58db 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -252,8 +252,9 @@ LogicalResult PTOIRTranslator::UpdateAllocTileOpMemInfo(pto::AllocTileOp op) { LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) { Value res = op.getResult(); - auto memRefType = dyn_cast(res.getType()); - if (!memRefType) return failure(); + auto tileType = dyn_cast(res.getType()); + if (!tileType) + return failure(); if (op.getAddrs().empty()) { return op.emitError("PointerCast must have at least one address operand"); @@ -261,15 +262,23 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) Value rootSrc = op.getAddrs().front(); uint64_t sizeInBytes = 0; - if (memRefType.hasStaticShape()) { - int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8; + bool isStatic = true; + for (auto dim : tileType.getShape()) { + if (dim == ShapedType::kDynamic) { + isStatic = false; + break; + } + } + if (isStatic) { + int64_t elemSize = tileType.getElementType().getIntOrFloatBitWidth() / 8; int64_t numElements = 1; - for (auto dim : memRefType.getShape()) numElements *= dim; + for (auto dim : tileType.getShape()) + numElements *= dim; sizeInBytes = numElements * elemSize; } pto::AddressSpace space = pto::AddressSpace::GM; - if (auto attr = memRefType.getMemorySpace()) { + if (auto attr = tileType.getMemorySpace()) { if (auto ptoAttr = dyn_cast(attr)) { space = ptoAttr.getAddressSpace(); } diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index afba16ee..1c75979e 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -271,6 +271,7 @@ static LogicalResult verifyManualReserveBufferMode(func::FuncOp funcOp) { } // namespace +// Entry point that builds linear op order, alias map and lifetime intervals. void MemLivenessAnalysis::build() { Region &funcRegion = func_.getBody(); Liveness live(func_); @@ -281,15 +282,21 @@ void MemLivenessAnalysis::build() { //InitializeInplacePairList(); } +// True when planning mode is local on-chip memory allocation. bool MemLivenessAnalysis::isLocalMemPlan() const { return planMode == MemPlanMode::LOCAL_MEM_PLAN; } +// True when planning mode is global-workspace allocation. bool MemLivenessAnalysis::isGlobalWorkSpaceMemPlan() const { return planMode == MemPlanMode::GLOBAL_WORKSPACE_PLAN; } void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { + // Traverse region operations and collect: + // 1) alias relation, + // 2) local-buffer definitions, + // 3) gen/kill events used by memory planning. auto result = region->walk([&](Operation *op) { // recursive control flow if (auto ifOp = dyn_cast(op)) { @@ -302,17 +309,15 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { // process operation auto curOpInfo = UpdateLinearOperation(op); - auto mayAliasOp = getOperationAliasInfo(op); + auto mayAliasOp = getBufferAliasInfo(op); if (mayAliasOp.has_value()) { auto aliasPair = mayAliasOp.value(); UpdateBufferAlias(aliasPair.first, aliasPair.second); - } else if (auto bindOp = dyn_cast(op)) { - // BindTile result is only an alias of the source buffer. Treat every use - // of the result as a use of the source in liveness analysis. - UpdateBufferAlias(bindOp.getResult(), bindOp.getSource()); - return WalkResult::advance(); - } else if (isLocalMemPlan() && dyn_cast(op)) { - if (failed(CheckLocalBufferAllocOp(op))) { + // Local-memory planning now accepts both legacy memref.alloc and + // tile-native pto.alloc_tile as defining points. + } else if (isLocalMemPlan() && + (isa(op))) { + if (failed(CheckLocalBufferDefOp(op))) { return WalkResult::interrupt(); } UpdateOpBufferInfo(op, op->getResults()); @@ -532,15 +537,25 @@ SmallVector MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp, // buffer2MultiNum[markOp.getSrc()] = static_cast(valAttr.getInt()); // } -LogicalResult -MemLivenessAnalysis::CheckLocalBufferAllocOp(Operation *op) const { - auto allocOp = dyn_cast(op); - assert(allocOp && "must be alloc op"); - auto memorySpaceAttr = GetBufferSpaceAttr(allocOp.getResult()); +// Validates local buffer defining ops and rejects non-local address-space. +LogicalResult MemLivenessAnalysis::CheckLocalBufferDefOp(Operation *op) const { + // Validate the defining op shape: this helper is intentionally limited to + // ops that create local buffers participating in PlanMemory. + Value defBuffer; + if (auto allocOp = dyn_cast(op)) { + defBuffer = allocOp.getResult(); + } else if (auto allocTileOp = dyn_cast(op)) { + defBuffer = allocTileOp.getResult(); + } else { + op->emitError("expects local buffer defining op"); + return failure(); + } + + auto memorySpaceAttr = getPlanningBufferSpaceAttr(defBuffer); if (isLocalBuffer(memorySpaceAttr)) { return success(); } - allocOp.getOperation()->emitError("Alloc buffer not at UB space! "); + op->emitError("Alloc buffer not at local memory space!"); return failure(); } @@ -557,7 +572,7 @@ MemLivenessAnalysis::CheckIfUnknownOpTouchBuffer(Operation *op) const { // This scene can be ignored. return success(); } - if (isOpTouchLocalBuffer(op)) { + if (isOpTouchPlannableLocalBuffer(op)) { op->emitError("PlanMemory Fail : Unrecognized type of Operation touches " "local buffer!"); return failure(); @@ -722,9 +737,10 @@ bool MemLivenessAnalysis::AllDeadAfter(Operation *op, SetVector aliasVec, return true; } +// Dispatches to local/global buffer-info builders based on memory scope. BufferInfo MemLivenessAnalysis::GenerateBufferInfo(Operation *op, Value operand) { - auto memorySpaceAttr = GetBufferSpaceAttr(operand); + auto memorySpaceAttr = getPlanningBufferSpaceAttr(operand); if (isLocalMemPlan() && isLocalBuffer(memorySpaceAttr)) { assert(memorySpaceAttr.has_value() && "buffer must has space!"); return GetBufferInfo(op, operand, @@ -740,21 +756,72 @@ BufferInfo MemLivenessAnalysis::GenerateBufferInfo(Operation *op, BufferInfo MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand, pto::AddressSpace bufferScope) { + // Build normalized buffer metadata consumed by PlanMemory without coupling + // to a memref-only representation. BufferInfo bufferInfo; bufferInfo.operation = op; bufferInfo.bufferScope = bufferScope; - // get buffer size, now for static shape + + // Prefer tile-native semantic extraction. This keeps PlanMemory input + // independent from a specific memref-only view chain. + TileBufferSemantics semantics; + if (succeeded(inferTileBufferSemantics(operand, semantics)) && + semantics.constBits > 0) { + bufferInfo.rootBuffer = semantics.root; + bufferInfo.bufferScope = semantics.scope; + bufferInfo.bufferType = semantics.elementType; + bufferInfo.bufferShape = semantics.shape; + bufferInfo.bufferValidShape = semantics.validShape; + bufferInfo.tileConfig = semantics.config; + bufferInfo.viewKind = semantics.viewKind; + bufferInfo.constBits = semantics.constBits; + return bufferInfo; + } + + // Fallback path: keep legacy sizing behavior for boundary cases where + // tile semantics are not fully recoverable in this phase. Value traceValue = tracebackMemRef(operand); - auto memRefType = cast(traceValue.getType()); - bufferInfo.bufferType = memRefType.getElementType(); - std::optional totalStaticSize = - getStaticTotalSize(memRefType.getShape()); - assert(totalStaticSize.has_value() && - "Failed to obtain op buffer shape size!"); - bufferInfo.constBits = - totalStaticSize.value() * - static_cast(memRefType.getElementTypeBitWidth()); - return bufferInfo; + if (auto memRefType = dyn_cast(traceValue.getType())) { + bufferInfo.rootBuffer = traceValue; + bufferInfo.bufferType = memRefType.getElementType(); + bufferInfo.bufferShape.assign(memRefType.getShape().begin(), + memRefType.getShape().end()); + bufferInfo.bufferValidShape = bufferInfo.bufferShape; + std::optional totalStaticSize = + getStaticTotalSize(memRefType.getShape()); + assert(totalStaticSize.has_value() && + "Failed to obtain op buffer shape size!"); + bufferInfo.constBits = + totalStaticSize.value() * + static_cast(memRefType.getElementTypeBitWidth()); + return bufferInfo; + } + + if (auto tileType = dyn_cast(traceValue.getType())) { + bufferInfo.rootBuffer = traceValue; + bufferInfo.bufferType = tileType.getElementType(); + bufferInfo.bufferShape.assign(tileType.getShape().begin(), + tileType.getShape().end()); + bufferInfo.bufferValidShape.assign(tileType.getValidShape().begin(), + tileType.getValidShape().end()); + bufferInfo.tileConfig = tileType.getConfigAttr(); + std::optional totalStaticSize = + getStaticTotalSize(tileType.getShape()); + assert(totalStaticSize.has_value() && + "Failed to obtain tile buffer shape size!"); + int64_t elemBits = 0; + if (auto intTy = dyn_cast(bufferInfo.bufferType)) + elemBits = intTy.getWidth(); + else if (auto floatTy = dyn_cast(bufferInfo.bufferType)) + elemBits = floatTy.getWidth(); + else if (isa(bufferInfo.bufferType)) + elemBits = 64; + assert(elemBits > 0 && "Unsupported element type for tile buffer sizing"); + bufferInfo.constBits = totalStaticSize.value() * elemBits; + return bufferInfo; + } + + llvm_unreachable("Failed to infer buffer info"); } // void MemLivenessAnalysis::InitializeInplacePairList() { @@ -2148,8 +2215,8 @@ struct PlanMemoryPass : public mlir::pto::impl::PlanMemoryBase { RewritePatternSet &patterns, DenseMap> buffer2Offsets) { if (this->memMode == MemPlanMode::LOCAL_MEM_PLAN) { - patterns.add(patterns.getContext(), - buffer2Offsets); + patterns.add(patterns.getContext(), + buffer2Offsets); } // } else { // assert(this->memMode == MemPlanMode::GLOBAL_WORKSPACE_PLAN); diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index 6089087c..132444da 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -10,6 +10,7 @@ #include "PTO/IR/PTO.h" #include "OptMemPlanForPipeline.h" +#include "TileBufferSemantics.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Analysis/Liveness.h" @@ -65,12 +66,22 @@ constexpr const int SPEC_LEVEL_2 = 2; struct BufferInfo { /// Alloc operation of buffer. Operation *operation{nullptr}; + /// Root storage value traced from alias/view chains. + Value rootBuffer; /// Space corresponding to buffer. pto::AddressSpace bufferScope; /// The size required for the buffer. int64_t constBits{0}; /// The type of element in the buffer. Type bufferType; + /// Logical shape used for memory planning/debug. + SmallVector bufferShape; + /// Logical valid shape if tile metadata is available. + SmallVector bufferValidShape; + /// Tile config when the buffer comes from tile semantics. + TileBufConfigAttr tileConfig; + /// View-like kind of the queried buffer. + TileViewKind viewKind{TileViewKind::Unknown}; /// Alias buffer does not participate in inplace. /// e.g : /// alloc A @@ -355,8 +366,9 @@ class MemLivenessAnalysis { /// Update store op information. void UpdateStoreOpInfo(OpInfo *opInfo, const Value storeValue, Liveness live); - /// Check if it is local buffer with memory space - LogicalResult CheckLocalBufferAllocOp(Operation *op) const; + /// Check whether a local-buffer defining op (memref.alloc / pto.alloc_tile) + /// is placed in a supported local address space. + LogicalResult CheckLocalBufferDefOp(Operation *op) const; /// kill buffer handle. void OpKillHandle(OpInfo *opInfo, Liveness live, Block *block); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 97dcc2e2..1a8dc9ae 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -226,6 +226,75 @@ static Value peelUnrealized(Value v) { return v; } +// Returns true if `value` is a constant integer-like zero. +static bool isConstZeroIndexLike(Value value) { + if (!value) + return false; + if (auto cst = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt() == 0; + } + if (auto cst = value.getDefiningOp()) + return cst.value() == 0; + if (auto castOp = value.getDefiningOp()) + return isConstZeroIndexLike(castOp.getIn()); + return false; +} + +// Materializes valid_row/valid_col constants from a static tile descriptor. +static std::pair +materializeStaticValidDims(ConversionPatternRewriter &rewriter, Location loc, + pto::TileBufType tileType) { + Value vRow; + Value vCol; + auto validShape = tileType.getValidShape(); + if (validShape.size() >= 2) { + if (validShape[0] >= 0) { + vRow = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(validShape[0])); + } + if (validShape[1] >= 0) { + vCol = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(validShape[1])); + } + } + return {vRow, vCol}; +} + +// Traces view-like source to the defining pointer_cast and returns its addrs. +// WIP constraint: subset hops are only supported when all offsets are zero. +static FailureOr> +tracePointerCastAddrsFromSource(Value source) { + source = peelUnrealized(source); + int depthGuard = 64; + while (source && depthGuard-- > 0) { + if (auto srcCast = source.getDefiningOp()) { + SmallVector addrs(srcCast.getAddrs().begin(), srcCast.getAddrs().end()); + if (addrs.empty()) + return failure(); + return addrs; + } + if (auto subsetOp = source.getDefiningOp()) { + for (Value offset : subsetOp.getOffsets()) { + if (!isConstZeroIndexLike(offset)) + return failure(); + } + source = peelUnrealized(subsetOp.getSource()); + continue; + } + if (auto bitcastOp = source.getDefiningOp()) { + source = peelUnrealized(bitcastOp.getSrc()); + continue; + } + if (auto reshapeOp = source.getDefiningOp()) { + source = peelUnrealized(reshapeOp.getSrc()); + continue; + } + break; + } + return failure(); +} + static std::optional getLayoutAttrFromOp(Operation *op) { if (!op) return std::nullopt; @@ -3238,6 +3307,130 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, return gtInst.getResult(0); } +//===----------------------------------------------------------------------===// +// pto.alloc_tile -> pto.pointer_cast (tile-native pre-lowering) +//===----------------------------------------------------------------------===// +struct PTOAllocTileToPointerCast : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tileType = dyn_cast(op.getResult().getType()); + if (!tileType) + return rewriter.notifyMatchFailure(op, "expected tile_buf result type"); + + auto loc = op.getLoc(); + Value addr = adaptor.getAddr(); + if (!addr) { + // Keep EmitC resilient for non-PlanMemory paths by assigning a default + // base address when alloc_tile reaches this stage unexpectedly. + addr = rewriter.create(loc, 0, 64); + } + + Value vRow = adaptor.getValidRow(); + Value vCol = adaptor.getValidCol(); + auto validShape = tileType.getValidShape(); + if (validShape.size() >= 2) { + if (!vRow && validShape[0] >= 0) { + vRow = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(validShape[0])); + } + if (!vCol && validShape[1] >= 0) { + vCol = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(validShape[1])); + } + } + + auto castOp = rewriter.create( + loc, tileType, ValueRange{addr}, vRow ? vRow : Value(), + vCol ? vCol : Value(), tileType.getConfigAttr()); + + if (op->hasAttr(kForceDynamicValidShapeAttrName)) + castOp->setAttr(kForceDynamicValidShapeAttrName, + op->getAttr(kForceDynamicValidShapeAttrName)); + + rewriter.replaceOp(op, castOp.getResult()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.bitcast/pto.treshape/pto.subset -> pto.pointer_cast (tile-native views) +//===----------------------------------------------------------------------===// +struct PTOBitcastToPointerCast : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = dyn_cast(op.getResult().getType()); + if (!dstType) + return failure(); + + FailureOr> addrs = tracePointerCastAddrsFromSource(op.getSrc()); + if (failed(addrs)) + return rewriter.notifyMatchFailure(op, "expects bitcast source from pointer_cast"); + + auto [vRow, vCol] = materializeStaticValidDims(rewriter, op.getLoc(), dstType); + auto newCast = rewriter.create( + op.getLoc(), dstType, *addrs, vRow ? vRow : Value(), vCol ? vCol : Value(), + dstType.getConfigAttr()); + rewriter.replaceOp(op, newCast.getResult()); + return success(); + } +}; + +struct PTOTReshapeToPointerCast : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = dyn_cast(op.getResult().getType()); + if (!dstType) + return failure(); + + FailureOr> addrs = tracePointerCastAddrsFromSource(op.getSrc()); + if (failed(addrs)) + return rewriter.notifyMatchFailure(op, "expects treshape source from pointer_cast"); + + auto [vRow, vCol] = materializeStaticValidDims(rewriter, op.getLoc(), dstType); + auto newCast = rewriter.create( + op.getLoc(), dstType, *addrs, vRow ? vRow : Value(), vCol ? vCol : Value(), + dstType.getConfigAttr()); + rewriter.replaceOp(op, newCast.getResult()); + return success(); + } +}; + +struct PTOSubsetToPointerCast : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SubsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = dyn_cast(op.getResult().getType()); + if (!dstType) + return failure(); + + // WIP scope: support zero-offset subset by rebinding to the same base addr. + for (Value offset : op.getOffsets()) { + if (!isConstZeroIndexLike(offset)) + return rewriter.notifyMatchFailure( + op, "subset->pointer_cast currently supports only zero offsets"); + } + + FailureOr> addrs = + tracePointerCastAddrsFromSource(op.getSource()); + if (failed(addrs)) + return rewriter.notifyMatchFailure(op, "expects subset source from pointer_cast"); + + auto [vRow, vCol] = materializeStaticValidDims(rewriter, op.getLoc(), dstType); + auto newCast = rewriter.create( + op.getLoc(), dstType, *addrs, vRow ? vRow : Value(), vCol ? vCol : Value(), + dstType.getConfigAttr()); + rewriter.replaceOp(op, newCast.getResult()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // pto.pointer_cast lowering //===----------------------------------------------------------------------=== @@ -3276,8 +3469,8 @@ struct PointerCastConversion : public OpConversionPattern { static TileRole inferRole(pto::PointerCastOp op) { // 1. 优先检查 AddressSpace - if (auto memRefTy = dyn_cast(op.getType())) { - Attribute memorySpace = memRefTy.getMemorySpace(); + if (auto tileTy = dyn_cast(op.getType())) { + Attribute memorySpace = tileTy.getMemorySpace(); if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { switch (ptoAttr.getAddressSpace()) { case pto::AddressSpace::LEFT: return TileRole::Left; @@ -3328,7 +3521,7 @@ struct PointerCastConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto *ctx = rewriter.getContext(); - auto selfType = mlir::cast(op.getType()); + auto selfType = mlir::cast(op.getType()); ArrayRef shape = selfType.getShape(); Type elemType = selfType.getElementType(); @@ -8342,6 +8535,9 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/lib/PTO/Transforms/TileBufferSemantics.cpp b/lib/PTO/Transforms/TileBufferSemantics.cpp new file mode 100644 index 00000000..68e3e74b --- /dev/null +++ b/lib/PTO/Transforms/TileBufferSemantics.cpp @@ -0,0 +1,284 @@ +#include "TileBufferSemantics.h" + +#include "Utils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#define DEBUG_TYPE "pto-tile-buffer-semantics" + +namespace mlir { +namespace pto { + +// Reads planning address-space from either memref or tile-buffer values. +std::optional getPlanningBufferSpaceAttr(Value operand) { + if (auto memRefType = dyn_cast(operand.getType())) { + auto memorySpace = memRefType.getMemorySpace(); + if (!memorySpace) + return std::nullopt; + return dyn_cast(memorySpace); + } + + if (auto tileBufType = dyn_cast(operand.getType())) { + auto memorySpace = tileBufType.getMemorySpace(); + if (!memorySpace) + return std::nullopt; + return dyn_cast(memorySpace); + } + return std::nullopt; +} + +// Returns (alias_result, source) for planning aliases, including tile views. +std::optional> getBufferAliasInfo(Operation *op) { + if (auto genericAlias = getOperationAliasInfo(op)) + return genericAlias; + + if (auto bindOp = dyn_cast(op)) + return std::make_pair(bindOp.getResult(), bindOp.getSource()); + if (auto subsetOp = dyn_cast(op)) + return std::make_pair(subsetOp.getResult(), subsetOp.getSource()); + if (auto bitcastOp = dyn_cast(op)) + return std::make_pair(bitcastOp.getResult(), bitcastOp.getSrc()); + if (auto treshapeOp = dyn_cast(op)) + return std::make_pair(treshapeOp.getResult(), treshapeOp.getSrc()); + return std::nullopt; +} + +// Classifies view semantics of planning-relevant operations. +TileViewKind getTileViewKind(Operation *op) { + if (!op) + return TileViewKind::Unknown; + + if (auto bindOp = dyn_cast(op)) { + if (auto semantics = + bindOp->getAttrOfType("pto.view_semantics")) { + if (semantics.getValue() == "subset") + return TileViewKind::Subset; + if (semantics.getValue() == "bitcast") + return TileViewKind::Bitcast; + if (semantics.getValue() == "treshape") + return TileViewKind::TReshape; + } + return TileViewKind::BindTile; + } + + if (isa(op)) + return TileViewKind::Subset; + if (isa(op)) + return TileViewKind::Bitcast; + if (isa(op)) + return TileViewKind::TReshape; + if (isa(op)) + return TileViewKind::MemRefViewLike; + return TileViewKind::Unknown; +} + +// Best-effort constant folding for index-like integers used by valid-shape +// propagation. +static std::optional getConstIndexLike(Value v) { + if (!v) + return std::nullopt; + if (auto cOp = v.getDefiningOp()) + return cOp.value(); + if (auto cInt = v.getDefiningOp()) + return cInt.value(); + if (auto cOp = v.getDefiningOp()) { + if (auto ia = dyn_cast(cOp.getValue())) + return ia.getInt(); + } + if (auto castOp = v.getDefiningOp()) + return getConstIndexLike(castOp.getIn()); + if (auto extOp = v.getDefiningOp()) + return getConstIndexLike(extOp.getIn()); + if (auto extOp = v.getDefiningOp()) + return getConstIndexLike(extOp.getIn()); + if (auto truncOp = v.getDefiningOp()) + return getConstIndexLike(truncOp.getIn()); + return std::nullopt; +} + +// Single-step traceback through planning alias/view-like constructs. +static Value tracebackRootOneStep(Value value) { + // Case 1: value is the iter_arg of a scf.for. + if (auto arg = dyn_cast(value)) { + if (auto forOp = + dyn_cast(arg.getParentRegion()->getParentOp())) { + if (arg.getArgNumber() > 0 && + forOp.getInitArgs().size() > arg.getArgNumber() - 1) { + return forOp.getInitArgs()[arg.getArgNumber() - 1]; + } + } + } + + Operation *def = value.getDefiningOp(); + if (!def) + return Value{}; + + if (auto aliasPair = getBufferAliasInfo(def)) { + auto [aliasValue, sourceValue] = *aliasPair; + if (aliasValue == value) + return sourceValue; + } + + // Case 2: cast-like memref ops. + if (auto op = dyn_cast(def)) + return op.getSource(); + if (auto op = dyn_cast(def)) + return op.getIn(); + if (auto op = dyn_cast(def)) + return op.getOperand(cast(value).getResultNumber()); + if (auto op = dyn_cast(def)) + return op.getInitArgs()[cast(value).getResultNumber()]; + + return Value{}; +} + +// Traces to the ultimate storage root used by planning and alias checks. +Value tracebackBufferRoot(Value value) { + int loopBound = 256; + while (value) { + auto upward = tracebackRootOneStep(value); + if (!upward) + break; + value = upward; + if (loopBound-- < 0) { + LLVM_DEBUG(llvm::dbgs() << "tracebackBufferRoot exceeds loopBound(" + << loopBound << ")!"); + break; + } + } + return value; +} + +// Converts element type to bit-width for static-size estimation. +static int64_t getElemBitWidth(Type elemTy) { + if (!elemTy) + return -1; + if (auto intTy = dyn_cast(elemTy)) + return intTy.getWidth(); + if (auto floatTy = dyn_cast(elemTy)) + return floatTy.getWidth(); + if (isa(elemTy)) + return 64; + return -1; +} + +// Decodes shape/valid/config from either memref or tile type. +static bool decodeTypeSemantics(Type type, Type &elemTy, + SmallVectorImpl &shape, + SmallVectorImpl &validShape, + TileBufConfigAttr &config) { + if (auto memRefType = dyn_cast(type)) { + elemTy = memRefType.getElementType(); + shape.assign(memRefType.getShape().begin(), memRefType.getShape().end()); + validShape = shape; + return true; + } + if (auto tileBufType = dyn_cast(type)) { + elemTy = tileBufType.getElementType(); + shape.assign(tileBufType.getShape().begin(), tileBufType.getShape().end()); + validShape.assign(tileBufType.getValidShape().begin(), + tileBufType.getValidShape().end()); + config = tileBufType.getConfigAttr(); + return true; + } + return false; +} + +// Overrides valid-shape from bind_tile operands when constants are available. +static void applyBindTileValidShape(pto::BindTileOp bindOp, + SmallVectorImpl &validShape) { + auto ensureRank2 = [&]() { + if (validShape.size() < 2) + validShape.resize(2, ShapedType::kDynamic); + }; + + if (bindOp.getValidRow()) { + ensureRank2(); + validShape[0] = getConstIndexLike(bindOp.getValidRow()) + .value_or(ShapedType::kDynamic); + } + if (bindOp.getValidCol()) { + ensureRank2(); + validShape[1] = getConstIndexLike(bindOp.getValidCol()) + .value_or(ShapedType::kDynamic); + } +} + +// Collects all SSA buffers that an op reads/writes/produces. +static SmallVector getOpTouchBuffer(Operation *op) { + SmallVector touchBuffer; + touchBuffer.insert(touchBuffer.end(), op->getResults().begin(), + op->getResults().end()); + for (OpOperand &operand : op->getOpOperands()) + touchBuffer.push_back(operand.get()); + return touchBuffer; +} + +// Returns true when any touched SSA value resolves to local planning space. +bool isOpTouchPlannableLocalBuffer(Operation *op) { + auto touchBuffer = getOpTouchBuffer(op); + for (Value buffer : touchBuffer) { + auto bufferSpace = getPlanningBufferSpaceAttr(buffer); + if (isLocalBuffer(bufferSpace)) + return true; + } + return false; +} + +// Builds normalized planning semantics from a value: +// - root: traced storage owner +// - scope: local memory space +// - shape/valid/config/view-kind +// - constBits: static bytes in bits +LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out) { + if (!value) + return failure(); + + out = TileBufferSemantics{}; + out.value = value; + out.root = tracebackBufferRoot(value); + if (auto as = getPlanningBufferSpaceAttr(value)) { + out.scope = as->getAddressSpace(); + } else if (auto as = getPlanningBufferSpaceAttr(out.root)) { + out.scope = as->getAddressSpace(); + } else { + return failure(); + } + + // Prefer root storage type for size calculation and keep queried type as + // fallback when root cannot provide a shaped type. + bool decoded = decodeTypeSemantics(out.root ? out.root.getType() : Type{}, + out.elementType, out.shape, + out.validShape, out.config); + if (!decoded) { + decoded = decodeTypeSemantics(value.getType(), out.elementType, out.shape, + out.validShape, out.config); + } + if (!decoded) + return failure(); + + if (auto def = value.getDefiningOp()) { + out.viewKind = getTileViewKind(def); + if (auto bindOp = dyn_cast(def)) { + if (!out.config) + out.config = bindOp.getConfigAttr(); + applyBindTileValidShape(bindOp, out.validShape); + } + } + + auto staticSize = getStaticTotalSize(out.shape); + int64_t elemBits = getElemBitWidth(out.elementType); + if (staticSize.has_value() && elemBits > 0) { + out.constBits = staticSize.value() * elemBits; + return success(); + } + return failure(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/TileBufferSemantics.h b/lib/PTO/Transforms/TileBufferSemantics.h new file mode 100644 index 00000000..ac2ace96 --- /dev/null +++ b/lib/PTO/Transforms/TileBufferSemantics.h @@ -0,0 +1,72 @@ +#ifndef PTO_TILE_BUFFER_SEMANTICS_H +#define PTO_TILE_BUFFER_SEMANTICS_H + +#include "PTO/IR/PTO.h" + +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/SmallVector.h" + +#include +#include + +namespace mlir { +namespace pto { + +/// Unified view-like category used by tile-buffer semantic normalization. +enum class TileViewKind { + Unknown = 0, + BindTile, + Subset, + Bitcast, + TReshape, + MemRefViewLike +}; + +/// Normalized semantic payload consumed by PlanMemory. +struct TileBufferSemantics { + /// queried value. + Value value; + /// traced root value (alloc/pointer-cast/function arg/...). + Value root; + /// storage scope resolved from value/root type. + pto::AddressSpace scope{pto::AddressSpace::Zero}; + /// element type used for byte/bits calculation. + Type elementType; + /// logical shape. + SmallVector shape; + /// logical valid shape. + SmallVector validShape; + /// optional tile config carried by tile/bind semantics. + TileBufConfigAttr config; + /// view-kind of the queried value's defining op. + TileViewKind viewKind{TileViewKind::Unknown}; + /// static size in bits, if computable. + int64_t constBits{0}; +}; + +/// Reads PTO address-space from memref or tile values for planning. +std::optional getPlanningBufferSpaceAttr(Value operand); + +/// Returns (result, source) when `op` is an alias/view-like op for planning. +std::optional> getBufferAliasInfo(Operation *op); + +/// Classifies the view semantics of an op for tracing/debug/planning. +TileViewKind getTileViewKind(Operation *op); + +/// Traces through alias/view-like chains to the storage root value. +Value tracebackBufferRoot(Value value); + +/// Returns true when an operation touches any local plannable buffer. +bool isOpTouchPlannableLocalBuffer(Operation *op); + +/// Infers normalized tile semantics (scope/shape/valid/config/root/bytes). +/// Returns failure when static bits cannot be proven. +LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out); + +} // namespace pto +} // namespace mlir + +#endif // PTO_TILE_BUFFER_SEMANTICS_H diff --git a/lib/PTO/Transforms/Utils.cpp b/lib/PTO/Transforms/Utils.cpp index 2fe6a88f..aa2efe91 100644 --- a/lib/PTO/Transforms/Utils.cpp +++ b/lib/PTO/Transforms/Utils.cpp @@ -13,6 +13,7 @@ namespace mlir { namespace pto { +// Returns the unique return op of a function when it exists. func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { func::ReturnOp returnOp; for (Block &b : funcOp.getBody()) { @@ -25,7 +26,7 @@ func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { return returnOp; } -// New helper function to get the updated BaseMemRefType +// Returns a memref type identical to `type` except for memory-space. BaseMemRefType getBaseMemRefTypeWithNewScope(BaseMemRefType type, AddressSpaceAttr targetMemScope) { if (auto memRefType = dyn_cast(type)) { @@ -38,6 +39,8 @@ BaseMemRefType getBaseMemRefTypeWithNewScope(BaseMemRefType type, return type; } +// Updates an SSA memref value type with a target memory-space, keeping shape +// and element type unchanged. void setBaseMemRefTypeScope(Value val, AddressSpaceAttr targetMemScope) { Type type = val.getType(); if (!isa(type)) { @@ -56,22 +59,21 @@ void setBaseMemRefTypeScope(Value val, AddressSpaceAttr targetMemScope) { val.setType(newMemRefType); } - +// Resolve local memory space from memref SSA values. std::optional GetBufferSpaceAttr(Value operand) { - if (!llvm::isa(operand.getType())) { - return std::nullopt; - } - auto memRefType = cast(operand.getType()); - auto memorySpace = memRefType.getMemorySpace(); - if (!memorySpace) - return std::nullopt; - auto memorySpaceAttr = dyn_cast(memorySpace); - if (!memorySpaceAttr) { + if (auto memRefType = dyn_cast(operand.getType())) { + auto memorySpace = memRefType.getMemorySpace(); + if (!memorySpace) + return std::nullopt; + if (auto memorySpaceAttr = dyn_cast(memorySpace)) + return memorySpaceAttr; return std::nullopt; } - return memorySpaceAttr; + + return std::nullopt; } +// Return (alias_result, source) for generic view/alias ops. std::optional> getOperationAliasInfo(Operation *op) { if (auto subViewOp = dyn_cast(op)) { return std::make_pair(subViewOp.getResult(), subViewOp.getViewSource()); @@ -100,16 +102,12 @@ std::optional> getOperationAliasInfo(Operation *op) { return std::make_pair(toMemrefOp.getResult(), toMemrefOp.getOperand()); } else if (auto toTensorOp = dyn_cast(op)) { return std::make_pair(toTensorOp.getResult(), toTensorOp.getOperand()); - } else if (auto toMemrefOp = dyn_cast(op)) { - return std::make_pair(toMemrefOp.getResult(), toMemrefOp.getOperand()); } -// } else if (auto bitCastOp = dyn_cast(op)) { -// return std::make_pair(bitCastOp.getResult(), bitCastOp.getSrc()); -// } return std::nullopt; } -Value tracebackImpl(Value memrefVal) { +// Single-step traceback through alias/view-like constructs. +static Value tracebackImpl(Value memrefVal) { // case 1: v is the iter_arg of a scf.for if (auto arg = dyn_cast(memrefVal)) { if (auto forOp = @@ -128,6 +126,13 @@ Value tracebackImpl(Value memrefVal) { return result; } + if (auto aliasPair = getOperationAliasInfo(def)) { + auto [aliasValue, sourceValue] = *aliasPair; + if (aliasValue == memrefVal) { + return sourceValue; + } + } + // case 2: v is the result of cast-like ops // - memref.cast // - memref.collapse_shape @@ -155,8 +160,6 @@ Value tracebackImpl(Value memrefVal) { } else if (auto op = dyn_cast(def)) { // trace back memref.alloc support scf.for result = op.getInitArgs()[cast(memrefVal).getResultNumber()]; - } else if (auto op = dyn_cast(def)) { - result = op.getSource(); } if (result) { @@ -175,16 +178,19 @@ Value tracebackImpl(Value memrefVal) { return result; } +// Checks whether an operation is a heap/stack memref allocation. bool isAllocLikeOp(Operation *op) { if (!op) return false; return isa(op) || isa(op); } +// Convenience overload for value-based alloc-like checks. bool isAllocLikeOp(Value val) { return isAllocLikeOp(val.getDefiningOp()); } +// Computes total static element count for a ranked shape. std::optional getStaticTotalSize(const ArrayRef &shapes) { int64_t totalSize = 1; for (const auto &shape : shapes) { @@ -196,6 +202,7 @@ std::optional getStaticTotalSize(const ArrayRef &shapes) { return totalSize; } +// Aligns `lhs` upward to the nearest multiple of `rhs`. uint64_t AlignUp(uint64_t lhs, uint64_t rhs) { assert(rhs != 0); if (lhs % rhs != 0) { @@ -204,6 +211,7 @@ uint64_t AlignUp(uint64_t lhs, uint64_t rhs) { return lhs; } +// Traces memref values through view/cast/for chains until alloc-like root. Value tracebackMemRef(Value memrefVal) { int loopBound = 256; while (memrefVal && !isAllocLikeOp(memrefVal)) { @@ -225,6 +233,7 @@ Value tracebackMemRef(Value memrefVal) { return memrefVal; } +// Traces a memref to `memref.alloc` when the root is alloc-like. std::optional tracebackMemRefToAlloc(Value memrefVal) { auto tracedValue = tracebackMemRef(memrefVal); return isAllocLikeOp(tracedValue) @@ -237,6 +246,7 @@ bool isFromFunctionArg(mlir::Value v) { return tracebackMemRef(v).getDefiningOp() == nullptr; } +// Returns true when an address-space belongs to local on-chip memories. bool isLocalBuffer(std::optional memorySpaceAttr) { if (!memorySpaceAttr.has_value()) { return false; @@ -251,6 +261,7 @@ bool isLocalBuffer(std::optional memorySpaceAttr) { llvm_unreachable("Currently only support (UB | L1 | L0C) allocation"); } +// Collects all SSA buffers that an op reads/writes/produces. SmallVector getOpTouchBuffer(Operation *op) { SmallVector touchBuffer; touchBuffer.insert(touchBuffer.end(), op->getResults().begin(), @@ -261,6 +272,7 @@ SmallVector getOpTouchBuffer(Operation *op) { return touchBuffer; } +// True when any touched SSA value resolves to local memory space. bool isOpTouchLocalBuffer(Operation *op) { auto touchBuffer = getOpTouchBuffer(op); for (Value buffer : touchBuffer) { @@ -272,6 +284,7 @@ bool isOpTouchLocalBuffer(Operation *op) { return false; } +// Returns the outermost module op that contains `op`. ModuleOp getTopLevelModuleOp(Operation *op) { ModuleOp moduleOp = op->getParentOfType(); while (moduleOp && moduleOp->getParentOp()) { @@ -290,6 +303,7 @@ std::optional getYieldValueIdx(Value targetVal, ValueRange yieldedValues) { return std::nullopt; } +// Finds the nearest loop owner for a value, accounting for yielded results. LoopLikeOpInterface getParentLoop(Value val) { assert(val.getDefiningOp() && "val should have defining op."); @@ -339,4 +353,4 @@ LoopLikeOpInterface getParentLoop(Value val) { } } -} \ No newline at end of file +} diff --git a/lib/PTO/Transforms/Utils.h b/lib/PTO/Transforms/Utils.h index 2c8c7c43..256aeef8 100644 --- a/lib/PTO/Transforms/Utils.h +++ b/lib/PTO/Transforms/Utils.h @@ -23,24 +23,40 @@ namespace mlir { namespace pto { + /// Address spaces treated as local reusable buffer scopes in planning. const std::set LocalBufferSpace{ pto::AddressSpace::VEC, pto::AddressSpace::MAT, pto::AddressSpace::ACC, pto::AddressSpace::LEFT, pto::AddressSpace::RIGHT, pto::AddressSpace::BIAS, pto::AddressSpace::SCALING}; constexpr const uint8_t kBitsToByte = 8; + /// Returns the only `func.return` in a function, otherwise null. func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp); + /// Returns (result, source) when `op` is a generic alias/view-like op. + /// TileBuffer-specific alias semantics are provided by TileBufferSemantics. std::optional> getOperationAliasInfo(Operation *op); + /// Reads PTO address-space from memref values. std::optional GetBufferSpaceAttr(Value operand); + /// Returns true when the address-space belongs to local memory. bool isLocalBuffer(std::optional memorySpaceAttr); + /// Traces memref aliases/views to alloc-like roots. Value tracebackMemRef(Value memrefVal); + /// Computes static product of shape dims; returns nullopt on dynamic dims. std::optional getStaticTotalSize(const ArrayRef &shapes); + /// Rounds `lhs` up to `rhs` alignment. uint64_t AlignUp(uint64_t lhs, uint64_t rhs); + /// Returns nearest parent loop that semantically owns the value. LoopLikeOpInterface getParentLoop(Value val); + /// Gets the top-most module containing `op`. ModuleOp getTopLevelModuleOp(Operation *op); + /// Rewrites memref value's address-space while preserving shape/element type. void setBaseMemRefTypeScope(Value val, AddressSpaceAttr targetMemScope); + /// Builds a memref type with the same payload and a new address-space. BaseMemRefType getBaseMemRefTypeWithNewScope(BaseMemRefType type, AddressSpaceAttr targetMemScope); + /// Traces a memref value to `memref.alloc` when possible. std::optional tracebackMemRefToAlloc(Value memrefVal); + /// Returns true when the value ultimately comes from function arguments. bool isFromFunctionArg(mlir::Value v); + /// Returns true when an operation touches any local buffer value. bool isOpTouchLocalBuffer(Operation *op); } } -#endif \ No newline at end of file +#endif diff --git a/test/basic/set_validshape_local_lowering.pto b/test/basic/set_validshape_local_lowering.pto index 6ec0496f..59d41da4 100644 --- a/test/basic/set_validshape_local_lowering.pto +++ b/test/basic/set_validshape_local_lowering.pto @@ -17,8 +17,4 @@ module { // CHECK: Tile [[BASE:v[0-9]+]]; // CHECK: TASSIGN([[BASE]], [[ADDR:v[0-9]+]]); -// CHECK: Tile [[TILE:v[0-9]+]] = Tile({{.*}}) -// CHECK: __ubuf__ float* [[DATA:v[0-9]+]] = [[BASE]].data(); -// CHECK: uint64_t [[TILE_ADDR:v[0-9]+]] = reinterpret_cast([[DATA]]); -// CHECK: TASSIGN([[TILE]], [[TILE_ADDR]]); -// CHECK: [[TILE]].SetValidShape([[ROW:v[0-9]+]], [[COL:v[0-9]+]]) +// CHECK: [[BASE]].SetValidShape([[ROW:v[0-9]+]], [[COL:v[0-9]+]]) diff --git a/test/basic/tilebuf_auto_addr_assign.pto b/test/basic/tilebuf_auto_addr_assign.pto new file mode 100644 index 00000000..b5cd850c --- /dev/null +++ b/test/basic/tilebuf_auto_addr_assign.pto @@ -0,0 +1,23 @@ +// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=PM +// RUN: ptoas %s | FileCheck %s --check-prefix=EMITC + +module { + func.func @tilebuf_auto_addr_assign() attributes {pto.entry} { + %buf = pto.alloc_tile + : !pto.tile_buf + pto.tprint ins(%buf : !pto.tile_buf) + return + } +} + +// PM: end PTO plan Mem! +// PM: func.func @tilebuf_auto_addr_assign +// PM-NOT: memref.alloc +// PM: pto.pointer_cast( + +// EMITC-LABEL: tilebuf_auto_addr_assign +// EMITC: TASSIGN( +// EMITC: TPRINT( diff --git a/test/basic/tilebuf_manual_addr_preserve.pto b/test/basic/tilebuf_manual_addr_preserve.pto new file mode 100644 index 00000000..e53b3d07 --- /dev/null +++ b/test/basic/tilebuf_manual_addr_preserve.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-level=level3 --pto-arch a5 %s | FileCheck %s + +module { + func.func @tilebuf_manual_addr_preserve() attributes {pto.entry} { + %c4096_i64 = arith.constant 4096 : i64 + %buf = pto.alloc_tile addr = %c4096_i64 + : !pto.tile_buf + pto.tprint ins(%buf : !pto.tile_buf) + return + } +} + +// CHECK-LABEL: tilebuf_manual_addr_preserve +// CHECK: 4096 +// CHECK: TASSIGN( +// CHECK: TPRINT( diff --git a/test/basic/tilebuf_root_trace.pto b/test/basic/tilebuf_root_trace.pto new file mode 100644 index 00000000..4e73f33c --- /dev/null +++ b/test/basic/tilebuf_root_trace.pto @@ -0,0 +1,33 @@ +// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s + +module { + func.func @tilebuf_root_trace(%arg0: memref<16x16xf16, #pto.address_space>, + %arg1: memref<16x16xi16, #pto.address_space>) { + %c0 = arith.constant 0 : index + + %base = pto.alloc_tile + : !pto.tile_buf + %sub = pto.subset %base[%c0, %c0] sizes [16, 16] + : !pto.tile_buf + %cast = pto.bitcast %sub + : !pto.tile_buf + -> !pto.tile_buf + + pto.tload ins(%arg0 : memref<16x16xf16, #pto.address_space>) + outs(%base : !pto.tile_buf) + pto.tstore ins(%cast : !pto.tile_buf) + outs(%arg1 : memref<16x16xi16, #pto.address_space>) + return + } +} + +// CHECK: end PTO plan Mem! +// CHECK: func.func @tilebuf_root_trace +// CHECK-NOT: PlanMemory Fail : Unrecognized type of Operation touches local buffer! +// CHECK: pto.pointer_cast( diff --git a/test/basic/tilebuf_semantic_smoke.pto b/test/basic/tilebuf_semantic_smoke.pto new file mode 100644 index 00000000..702c12df --- /dev/null +++ b/test/basic/tilebuf_semantic_smoke.pto @@ -0,0 +1,27 @@ +// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s + +module { + func.func @tilebuf_semantic_smoke(%arg0: memref<32x32xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + + %buf = pto.alloc_tile + : !pto.tile_buf + %sub = pto.subset %buf[%c0, %c0] sizes [16, 32] + : !pto.tile_buf + + pto.tload ins(%arg0 : memref<32x32xf16, #pto.address_space>) + outs(%buf : !pto.tile_buf) + pto.tprint ins(%sub : !pto.tile_buf) + return + } +} + +// CHECK: end PTO plan Mem! +// CHECK: func.func @tilebuf_semantic_smoke +// CHECK-NOT: PlanMemory Fail : Unrecognized type of Operation touches local buffer! +// CHECK-NOT: memref.alloc +// CHECK: pto.pointer_cast( diff --git a/test/samples/planmemory/tilebuf_alias_chain.py b/test/samples/planmemory/tilebuf_alias_chain.py new file mode 100644 index 00000000..23009f83 --- /dev/null +++ b/test/samples/planmemory/tilebuf_alias_chain.py @@ -0,0 +1,36 @@ +PTO_IR = r""" +module { + func.func @tilebuf_alias_chain(%arg0: memref<32x32xf16, #pto.address_space>, + %arg1: memref<32x16xi16, #pto.address_space>) { + %c0 = arith.constant 0 : index + + %base = pto.alloc_tile + : !pto.tile_buf + %sub = pto.subset %base[%c0, %c0] sizes [16, 32] + : !pto.tile_buf + %reshape = pto.treshape %sub + : !pto.tile_buf + -> !pto.tile_buf + %cast = pto.bitcast %reshape + : !pto.tile_buf + -> !pto.tile_buf + + pto.tload ins(%arg0 : memref<32x32xf16, #pto.address_space>) + outs(%base : !pto.tile_buf) + pto.tstore ins(%cast : !pto.tile_buf) + outs(%arg1 : memref<32x16xi16, #pto.address_space>) + return + } +} +""" + +if __name__ == "__main__": + print(PTO_IR) diff --git a/test/samples/planmemory/tilebuf_planmemory_auto_addr.py b/test/samples/planmemory/tilebuf_planmemory_auto_addr.py new file mode 100644 index 00000000..7e012c8c --- /dev/null +++ b/test/samples/planmemory/tilebuf_planmemory_auto_addr.py @@ -0,0 +1,16 @@ +PTO_IR = r""" +module { + func.func @tilebuf_planmemory_auto_addr() attributes {pto.entry} { + %buf = pto.alloc_tile + : !pto.tile_buf + pto.tprint ins(%buf : !pto.tile_buf) + return + } +} +""" + +if __name__ == "__main__": + print(PTO_IR) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 3c21e022..68dedc9d 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -515,11 +515,12 @@ process_one_dir() { fi # Regression guard for Issue #207: - # SSA `pto.treshape` (lowered into `pto.bind_tile`) must lower to a single - # `TRESHAPE(dst, src)` instead of an invalid Tile-to-pointer cast sequence. + # SSA view-like ops must preserve tile alias semantics in EmitC. Depending on + # the lowering path, this may appear as `TRESHAPE(dst, src)` or as + # pointer-cast-style sibling tile rebinding (`TASSIGN`). if [[ "$base" == "reshape" ]]; then - if ! grep -Fq "TRESHAPE(" "$cpp"; then - echo -e "${A}(${base}.py) FAIL missing TRESHAPE() lowering for SSA treshape" + if ! grep -Fq "TRESHAPE(" "$cpp" && ! grep -Fq "TASSIGN(" "$cpp"; then + echo -e "${A}(${base}.py) FAIL missing alias-preserving lowering (TRESHAPE/TASSIGN) for SSA treshape" overall=1 continue fi @@ -530,6 +531,14 @@ process_one_dir() { fi fi + if [[ "$base" == "tilebuf_alias_chain" ]]; then + if ! grep -Fq "TASSIGN(" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing TASSIGN() lowering in alias chain" + overall=1 + continue + fi + fi + if [[ "$base" == "bitcast_dtype_alias" ]]; then if ! grep -Eq "Tile<[^>]*, int32_t," "$cpp"; then echo -e "${A}(${base}.py) FAIL missing int32_t Tile declaration for pto.bitcast" @@ -553,6 +562,19 @@ process_one_dir() { fi fi + if [[ "$base" == "tilebuf_planmemory_auto_addr" ]]; then + if ! grep -Fq "TASSIGN(" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing TASSIGN() lowering for auto-address tilebuf" + overall=1 + continue + fi + if ! grep -Fq "TPRINT(" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing TPRINT() lowering for auto-address tilebuf" + overall=1 + continue + fi + fi + # Regression guard for Issue #207 follow-up: # `pto.bitcast` must alias the original tile storage via # `TASSIGN(dst, reinterpret_cast(src.data()))`. diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 017018d2..998e8105 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -880,7 +880,8 @@ int main(int argc, char **argv) { if (!disableInferLayout) pm.addNestedPass(pto::createInferPTOLayoutPass()); - pm.addPass(pto::createPTOViewToMemrefPass()); + // Tile-native pipeline: keep tile_buf descriptors through PlanMemory/EmitC. + // The legacy memref bridge pass is intentionally disabled here. // bufferizationPipeline(pm); //pm.addPass(createInferPTOMemScopePass()); From 57a88831d79a67fd503aa2982a6e5753f462685d Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 26 Mar 2026 15:30:43 +0800 Subject: [PATCH 2/2] PlanMemory: enforce tilebuf-only local planning and diagnostics --- lib/PTO/Transforms/PTOPlanMemory.cpp | 255 ++++++++++-------- lib/PTO/Transforms/PTOPlanMemory.h | 29 +- lib/PTO/Transforms/TileBufferSemantics.cpp | 53 +++- lib/PTO/Transforms/TileBufferSemantics.h | 13 +- ...erve_buffer_manual_reject_nested_alloc.pto | 2 +- ...al_reserve_reject_unplanned_tile_alloc.pto | 22 ++ .../planmemory_reject_local_memref_alloc.pto | 11 + .../planmemory_semantics_infer_fail_diag.pto | 15 ++ 8 files changed, 269 insertions(+), 131 deletions(-) create mode 100644 test/basic/planmemory_manual_reserve_reject_unplanned_tile_alloc.pto create mode 100644 test/basic/planmemory_reject_local_memref_alloc.pto create mode 100644 test/basic/planmemory_semantics_infer_fail_diag.pto diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 1c75979e..329574ec 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #define DEBUG_TYPE "pto-plan-memory" @@ -217,17 +218,40 @@ static LogicalResult assignAutoReserveBufferBase( return success(); } +// Verifies that manual reserve-buffer mode has no unresolved local allocations. +// +// In `auto=false` mode, local addresses are already fixed by contract. +// Any local `alloc_tile` without explicit `addr` still needs PlanMemory +// assignment and must be rejected. Legacy local `memref.alloc` is also +// rejected because Phase-2 PlanMemory is tilebuf-only. static LogicalResult verifyManualReserveBufferMode(func::FuncOp funcOp) { LogicalResult result = success(); - funcOp.walk([&](memref::AllocOp allocOp) { - auto memorySpaceAttr = GetBufferSpaceAttr(allocOp.getResult()); + auto walkResult = funcOp.walk([&](Operation *op) { + if (auto allocTileOp = dyn_cast(op)) { + auto memorySpaceAttr = getPlanningBufferSpaceAttr(allocTileOp.getResult()); + if (!isLocalBuffer(memorySpaceAttr)) + return WalkResult::advance(); + if (allocTileOp.getAddr()) + return WalkResult::advance(); + result = allocTileOp.emitOpError( + "cannot use pto.reserve_buffer with auto = false when local " + "pto.alloc_tile still requires PlanMemory address assignment"); + return WalkResult::interrupt(); + } + + auto allocOp = dyn_cast(op); + if (!allocOp) + return WalkResult::advance(); + auto memorySpaceAttr = getPlanningBufferSpaceAttr(allocOp.getResult()); if (!isLocalBuffer(memorySpaceAttr)) return WalkResult::advance(); - result = allocOp.emitOpError("cannot use pto.reserve_buffer with auto = " - "false when local memref.alloc " - "still requires PlanMemory allocation"); + result = allocOp.emitOpError( + "PlanMemory is tilebuf-only: local memref.alloc is unsupported; use " + "pto.alloc_tile"); return WalkResult::interrupt(); }); + if (walkResult.wasInterrupted()) + return failure(); return result; } @@ -272,14 +296,20 @@ static LogicalResult verifyManualReserveBufferMode(func::FuncOp funcOp) { } // namespace // Entry point that builds linear op order, alias map and lifetime intervals. -void MemLivenessAnalysis::build() { +// Returns failure when IR traversal already emitted a validation diagnostic. +LogicalResult MemLivenessAnalysis::build() { Region &funcRegion = func_.getBody(); Liveness live(func_); + hasAnalysisError = false; // Recursively obtaining IR information. RecursionIR(&funcRegion, live); + if (hasAnalysisError) { + return failure(); + } // the lifetime of the buffer. GenerateBufferLife(); //InitializeInplacePairList(); + return success(); } // True when planning mode is local on-chip memory allocation. @@ -298,12 +328,22 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { // 2) local-buffer definitions, // 3) gen/kill events used by memory planning. auto result = region->walk([&](Operation *op) { + if (hasAnalysisError) { + return WalkResult::interrupt(); + } + // recursive control flow if (auto ifOp = dyn_cast(op)) { RecursiveIfOp(ifOp, live); + if (hasAnalysisError) { + return WalkResult::interrupt(); + } return WalkResult::skip(); } else if (auto forOp = dyn_cast(op)) { RecursiveForOp(forOp, live); + if (hasAnalysisError) { + return WalkResult::interrupt(); + } return WalkResult::skip(); } @@ -313,14 +353,24 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { if (mayAliasOp.has_value()) { auto aliasPair = mayAliasOp.value(); UpdateBufferAlias(aliasPair.first, aliasPair.second); - // Local-memory planning now accepts both legacy memref.alloc and - // tile-native pto.alloc_tile as defining points. - } else if (isLocalMemPlan() && - (isa(op))) { + // Local-memory planning only accepts tile-native defining points. + } else if (isLocalMemPlan() && isa(op)) { if (failed(CheckLocalBufferDefOp(op))) { return WalkResult::interrupt(); } - UpdateOpBufferInfo(op, op->getResults()); + if (failed(UpdateOpBufferInfo(op, op->getResults()))) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } else if (isLocalMemPlan() && isa(op)) { + auto allocOp = cast(op); + auto memorySpaceAttr = getPlanningBufferSpaceAttr(allocOp.getResult()); + if (isLocalBuffer(memorySpaceAttr)) { + allocOp.emitOpError( + "PlanMemory is tilebuf-only: local memref.alloc is unsupported; " + "use pto.alloc_tile"); + return WalkResult::interrupt(); + } return WalkResult::advance(); // } else if (isGlobalWorkSpaceMemPlan() && // dyn_cast(op)) { @@ -382,7 +432,8 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { return WalkResult::advance(); }); if (result == WalkResult::interrupt()) { - llvm_unreachable("PlanMemory Traverse IR Failed! "); + hasAnalysisError = true; + return; } } @@ -444,6 +495,9 @@ void MemLivenessAnalysis::RecursiveForOp(scf::ForOp forOp, Liveness live) { UpdateOpGenInfo(forBeginSeq, GetLiveBuffersInLoop(forOp, live)); UpdateForOpInitArgsAlias(forOp); RecursionIR(&forOp.getRegion(), live); + if (hasAnalysisError) { + return; + } UpdateForOpBufferAlias(forOp); auto forEndSeq = UpdateLinearOperation(forOp.getOperation()); OpKillHandle(forEndSeq, live, forOp->getBlock()); @@ -478,14 +532,20 @@ void MemLivenessAnalysis::RecursiveIfOp(scf::IfOp ifOp, Liveness live) { // scf.yield %alloc0: memref<16xf16, #pto.address_space> // else: // scf.yield %alloc1 : memref<16xf16, #pto.address_space> - auto curIfThen = UpdateLinearOperation(ifOp.getOperation()); + UpdateLinearOperation(ifOp.getOperation()); RecursionIR(&ifOp.getThenRegion(), live); + if (hasAnalysisError) { + return; + } auto curIfElse = UpdateLinearOperation(ifOp.getOperation()); UpdateIfOpBufferAlias(ifOp, ifOp.thenYield()); auto curIfEnd = curIfElse; if (ifOp.elseBlock()) { RecursionIR(&ifOp.getElseRegion(), live); + if (hasAnalysisError) { + return; + } curIfEnd = UpdateLinearOperation(ifOp.getOperation()); UpdateIfOpBufferAlias(ifOp, ifOp.elseYield()); } @@ -537,25 +597,21 @@ SmallVector MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp, // buffer2MultiNum[markOp.getSrc()] = static_cast(valAttr.getInt()); // } -// Validates local buffer defining ops and rejects non-local address-space. +// Validates that a local PlanMemory defining op is tile-native alloc_tile and +// that its result is in a supported local address space. LogicalResult MemLivenessAnalysis::CheckLocalBufferDefOp(Operation *op) const { - // Validate the defining op shape: this helper is intentionally limited to - // ops that create local buffers participating in PlanMemory. - Value defBuffer; - if (auto allocOp = dyn_cast(op)) { - defBuffer = allocOp.getResult(); - } else if (auto allocTileOp = dyn_cast(op)) { - defBuffer = allocTileOp.getResult(); - } else { - op->emitError("expects local buffer defining op"); + auto allocTileOp = dyn_cast(op); + if (!allocTileOp) { + op->emitError("expects local buffer defining op pto.alloc_tile"); return failure(); } + Value defBuffer = allocTileOp.getResult(); auto memorySpaceAttr = getPlanningBufferSpaceAttr(defBuffer); if (isLocalBuffer(memorySpaceAttr)) { return success(); } - op->emitError("Alloc buffer not at local memory space!"); + op->emitError("expects pto.alloc_tile result in local address space"); return failure(); } @@ -617,8 +673,7 @@ void MemLivenessAnalysis::UpdateBufferAlias(Value buffer, Value aliasBuffer, // buffer2status[buffer] = BufferStatus::UNDEFFINED; // } - // mark the alias buffer as ignoring Inplace if it is not generated by - // memref.alloc. + // Mark alias values as ignore-inplace when they are transient view results. auto it = bufferInfos.find(aliasBuffer); if (isIgnoreInplace && it != bufferInfos.end()) { it->second.ignoreInplace = true; @@ -652,16 +707,23 @@ void MemLivenessAnalysis::UpdateStoreOpInfo(OpInfo *opInfo, OpKillHandle(opInfo, live, opInfo->operation->getBlock()); } -void MemLivenessAnalysis::UpdateOpBufferInfo(Operation *op, - const ValueRange &results) { +// Builds BufferInfo records for op results and marks them as defined. +// Fails when any result cannot be expressed as tile-buffer semantics. +LogicalResult MemLivenessAnalysis::UpdateOpBufferInfo( + Operation *op, const ValueRange &results) { for (const Value &operand : results) { auto it = buffer2status.find(operand); if (it != buffer2status.end()) { continue; } - bufferInfos[operand] = GenerateBufferInfo(op, operand); + BufferInfo bufferInfo; + if (failed(GenerateBufferInfo(op, operand, bufferInfo))) { + return failure(); + } + bufferInfos[operand] = std::move(bufferInfo); buffer2status[operand] = BufferStatus::DEFFINED; } + return success(); } void MemLivenessAnalysis::UpdateOpGenInfo(OpInfo *opInfo, @@ -737,91 +799,70 @@ bool MemLivenessAnalysis::AllDeadAfter(Operation *op, SetVector aliasVec, return true; } -// Dispatches to local/global buffer-info builders based on memory scope. -BufferInfo MemLivenessAnalysis::GenerateBufferInfo(Operation *op, - Value operand) { +// Dispatches to the tilebuf-only buffer-info builder based on memory scope. +// Fails if the value is not a local plannable tile buffer. +LogicalResult MemLivenessAnalysis::GenerateBufferInfo(Operation *op, + Value operand, + BufferInfo &out) { auto memorySpaceAttr = getPlanningBufferSpaceAttr(operand); if (isLocalMemPlan() && isLocalBuffer(memorySpaceAttr)) { assert(memorySpaceAttr.has_value() && "buffer must has space!"); - return GetBufferInfo(op, operand, - memorySpaceAttr.value().getAddressSpace()); + return GetBufferInfo(op, operand, memorySpaceAttr.value().getAddressSpace(), + out); } // } else if (isGlobalWorkSpaceMemPlan() && // isa( // operand.getDefiningOp())) { // return GetBufferInfo(op, operand, pto::AddressSpace::GM); // } - llvm_unreachable("buffer must has BufferInfo !"); + op->emitError("expects local tile buffer result for PlanMemory"); + return failure(); } -BufferInfo MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand, - pto::AddressSpace bufferScope) { - // Build normalized buffer metadata consumed by PlanMemory without coupling - // to a memref-only representation. +// Resolves normalized tile-buffer semantics and materializes one BufferInfo. +// Fails with a location-aware diagnostic when semantic inference is incomplete. +LogicalResult MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand, + pto::AddressSpace bufferScope, + BufferInfo &out) { BufferInfo bufferInfo; bufferInfo.operation = op; bufferInfo.bufferScope = bufferScope; - // Prefer tile-native semantic extraction. This keeps PlanMemory input - // independent from a specific memref-only view chain. + std::string failureReason; TileBufferSemantics semantics; - if (succeeded(inferTileBufferSemantics(operand, semantics)) && - semantics.constBits > 0) { - bufferInfo.rootBuffer = semantics.root; - bufferInfo.bufferScope = semantics.scope; - bufferInfo.bufferType = semantics.elementType; - bufferInfo.bufferShape = semantics.shape; - bufferInfo.bufferValidShape = semantics.validShape; - bufferInfo.tileConfig = semantics.config; - bufferInfo.viewKind = semantics.viewKind; - bufferInfo.constBits = semantics.constBits; - return bufferInfo; - } - - // Fallback path: keep legacy sizing behavior for boundary cases where - // tile semantics are not fully recoverable in this phase. - Value traceValue = tracebackMemRef(operand); - if (auto memRefType = dyn_cast(traceValue.getType())) { - bufferInfo.rootBuffer = traceValue; - bufferInfo.bufferType = memRefType.getElementType(); - bufferInfo.bufferShape.assign(memRefType.getShape().begin(), - memRefType.getShape().end()); - bufferInfo.bufferValidShape = bufferInfo.bufferShape; - std::optional totalStaticSize = - getStaticTotalSize(memRefType.getShape()); - assert(totalStaticSize.has_value() && - "Failed to obtain op buffer shape size!"); - bufferInfo.constBits = - totalStaticSize.value() * - static_cast(memRefType.getElementTypeBitWidth()); - return bufferInfo; - } - - if (auto tileType = dyn_cast(traceValue.getType())) { - bufferInfo.rootBuffer = traceValue; - bufferInfo.bufferType = tileType.getElementType(); - bufferInfo.bufferShape.assign(tileType.getShape().begin(), - tileType.getShape().end()); - bufferInfo.bufferValidShape.assign(tileType.getValidShape().begin(), - tileType.getValidShape().end()); - bufferInfo.tileConfig = tileType.getConfigAttr(); - std::optional totalStaticSize = - getStaticTotalSize(tileType.getShape()); - assert(totalStaticSize.has_value() && - "Failed to obtain tile buffer shape size!"); - int64_t elemBits = 0; - if (auto intTy = dyn_cast(bufferInfo.bufferType)) - elemBits = intTy.getWidth(); - else if (auto floatTy = dyn_cast(bufferInfo.bufferType)) - elemBits = floatTy.getWidth(); - else if (isa(bufferInfo.bufferType)) - elemBits = 64; - assert(elemBits > 0 && "Unsupported element type for tile buffer sizing"); - bufferInfo.constBits = totalStaticSize.value() * elemBits; - return bufferInfo; - } - - llvm_unreachable("Failed to infer buffer info"); + if (failed(inferTileBufferSemantics(operand, semantics, &failureReason))) { + auto diag = + op->emitOpError("failed to infer tile buffer semantics for PlanMemory"); + if (!failureReason.empty()) { + diag << " (reason: " << failureReason << ")"; + } + diag.attachNote(operand.getLoc()) << "buffer value: " << operand; + if (Operation *def = operand.getDefiningOp()) { + diag.attachNote(def->getLoc()) << "defining op: " << def->getName(); + } else if (auto arg = dyn_cast(operand)) { + diag.attachNote(arg.getLoc()) + << "value is block argument #" << arg.getArgNumber(); + } + return failure(); + } + + if (semantics.constBits <= 0) { + op->emitOpError( + "failed to infer tile buffer semantics for PlanMemory: " + "constBits must be positive"); + return failure(); + } + + bufferInfo.rootBuffer = semantics.root; + bufferInfo.bufferScope = semantics.scope; + bufferInfo.bufferType = semantics.elementType; + bufferInfo.bufferShape = semantics.shape; + bufferInfo.bufferValidShape = semantics.validShape; + bufferInfo.tileConfig = semantics.config; + bufferInfo.viewKind = semantics.viewKind; + bufferInfo.constBits = semantics.constBits; + out = std::move(bufferInfo); + return success(); } // void MemLivenessAnalysis::InitializeInplacePairList() { @@ -1429,11 +1470,7 @@ void MemPlan::ReportMemLifeDebugInfo(StorageEntry *rootStorageEntry) { void MemPlan::MemLifeDebugInfo(StorageEntry *storageEntry) { for (auto &buffer : storageEntry->inplaceBuffers) { - if (buffer.getDefiningOp()) { - if (auto allocOp = dyn_cast(buffer.getDefiningOp())) { - LDBG("Buffer : " << allocOp.getResult() << "\n"); - } - } + LDBG("Buffer : " << buffer << "\n"); } for (auto &bufferLife : storageEntry->bufferLifeVec) { LDBG("bufferLife : " @@ -1445,12 +1482,8 @@ void MemPlan::MemLifeDebugInfo(StorageEntry *storageEntry) { void MemPlan::ReportCurEntryDebugInfo(const StorageEntry *curEntry) { for (auto &buffer : curEntry->inplaceBuffers) { - if (buffer.getDefiningOp()) { - if (auto allocOp = dyn_cast(buffer.getDefiningOp())) { - LDBG("buffer : "); - LDBG(allocOp.getResult()); - } - } + LDBG("buffer : "); + LDBG(buffer); } } @@ -2261,7 +2294,9 @@ void PlanMemoryPass::runOnOperation() { } MemLivenessAnalysis memLiveness(funcOp, this->memMode); - memLiveness.build(); + if (failed(memLiveness.build())) { + return signalPassFailure(); + } MemPlan memPlan(this->memMode, this->enableGlobalReuse, this->enablePrintMemoryAllocatedSize, diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index 132444da..c2efe024 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -268,7 +268,9 @@ class MemLivenessAnalysis { MemLivenessAnalysis(func::FuncOp func, MemPlanMode planMode) : func_(func), planMode(planMode) {} - void build(); + /// Builds alias/gen-kill/lifetime data used by PlanMemory. + /// Returns failure if traversal encounters unsupported local-memory patterns. + LogicalResult build(); /// linear operation info. SmallVector> linearOperation; @@ -324,15 +326,19 @@ class MemLivenessAnalysis { /// Update and obtain op info information. OpInfo *UpdateLinearOperation(Operation *op); - /// Obtain all information about the buffer. - void UpdateOpBufferInfo(Operation *op, const ValueRange &results); + /// Materialize planning buffer-info for op results. + /// Returns failure when any result cannot be expressed as tilebuf semantics. + LogicalResult UpdateOpBufferInfo(Operation *op, const ValueRange &results); - /// Generate buffer info. - BufferInfo GenerateBufferInfo(Operation *op, Value operand); + /// Build planning metadata for one operand. + /// Returns failure when tilebuf semantic inference fails. + LogicalResult GenerateBufferInfo(Operation *op, Value operand, + BufferInfo &out); - /// Obtain the buffer info of plan operation. - BufferInfo GetBufferInfo(Operation *op, Value operand, - pto::AddressSpace bufferScope); + /// Populate `out` from tilebuf semantic inference in the target scope. + /// No memref fallback is allowed in tilebuf-only PlanMemory mode. + LogicalResult GetBufferInfo(Operation *op, Value operand, + pto::AddressSpace bufferScope, BufferInfo &out); /// Process gen buffer based on the result value of op. void UpdateOpGenInfo(OpInfo *opInfo, const ValueRange &results); @@ -366,8 +372,8 @@ class MemLivenessAnalysis { /// Update store op information. void UpdateStoreOpInfo(OpInfo *opInfo, const Value storeValue, Liveness live); - /// Check whether a local-buffer defining op (memref.alloc / pto.alloc_tile) - /// is placed in a supported local address space. + /// Check whether a local-buffer defining op (`pto.alloc_tile`) is placed in + /// a supported local address space. LogicalResult CheckLocalBufferDefOp(Operation *op) const; /// kill buffer handle. @@ -402,6 +408,9 @@ class MemLivenessAnalysis { /// map on buffer alias DenseMap> buffer2AliasVec; + /// Set when IR traversal already emitted a semantic/validation diagnostic. + bool hasAnalysisError{false}; + int seqIndex{0}; }; diff --git a/lib/PTO/Transforms/TileBufferSemantics.cpp b/lib/PTO/Transforms/TileBufferSemantics.cpp index 68e3e74b..9b7c5644 100644 --- a/lib/PTO/Transforms/TileBufferSemantics.cpp +++ b/lib/PTO/Transforms/TileBufferSemantics.cpp @@ -230,14 +230,24 @@ bool isOpTouchPlannableLocalBuffer(Operation *op) { return false; } +// Records semantic inference failure reason when callers request diagnostics. +static LogicalResult failSemantics(std::string *failureReason, + llvm::StringRef message) { + if (failureReason) { + *failureReason = message.str(); + } + return failure(); +} + // Builds normalized planning semantics from a value: // - root: traced storage owner // - scope: local memory space // - shape/valid/config/view-kind // - constBits: static bytes in bits -LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out) { +LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out, + std::string *failureReason) { if (!value) - return failure(); + return failSemantics(failureReason, "value is null"); out = TileBufferSemantics{}; out.value = value; @@ -247,7 +257,8 @@ LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out) { } else if (auto as = getPlanningBufferSpaceAttr(out.root)) { out.scope = as->getAddressSpace(); } else { - return failure(); + return failSemantics(failureReason, + "failed to resolve address-space from value/root"); } // Prefer root storage type for size calculation and keep queried type as @@ -260,7 +271,24 @@ LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out) { out.validShape, out.config); } if (!decoded) - return failure(); + return failSemantics( + failureReason, + "failed to decode shape/element/config from root/value type"); + + if (out.shape.empty()) { + return failSemantics(failureReason, "decoded shape is empty"); + } + for (int64_t dim : out.shape) { + if (ShapedType::isDynamic(dim)) { + return failSemantics( + failureReason, + "dynamic shape is unsupported for PlanMemory static sizing"); + } + if (dim <= 0) { + return failSemantics(failureReason, + "shape dimensions must be positive"); + } + } if (auto def = value.getDefiningOp()) { out.viewKind = getTileViewKind(def); @@ -273,11 +301,20 @@ LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out) { auto staticSize = getStaticTotalSize(out.shape); int64_t elemBits = getElemBitWidth(out.elementType); - if (staticSize.has_value() && elemBits > 0) { - out.constBits = staticSize.value() * elemBits; - return success(); + if (!staticSize.has_value()) { + return failSemantics(failureReason, + "failed to compute static element count from shape"); } - return failure(); + if (staticSize.value() <= 0) { + return failSemantics(failureReason, + "static element count must be positive"); + } + if (elemBits <= 0) { + return failSemantics(failureReason, + "unsupported element type bit-width for sizing"); + } + out.constBits = staticSize.value() * elemBits; + return success(); } } // namespace pto diff --git a/lib/PTO/Transforms/TileBufferSemantics.h b/lib/PTO/Transforms/TileBufferSemantics.h index ac2ace96..a8ccff4c 100644 --- a/lib/PTO/Transforms/TileBufferSemantics.h +++ b/lib/PTO/Transforms/TileBufferSemantics.h @@ -10,6 +10,7 @@ #include "llvm/ADT/SmallVector.h" #include +#include #include namespace mlir { @@ -63,8 +64,16 @@ Value tracebackBufferRoot(Value value); bool isOpTouchPlannableLocalBuffer(Operation *op); /// Infers normalized tile semantics (scope/shape/valid/config/root/bytes). -/// Returns failure when static bits cannot be proven. -LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out); +/// Returns failure when static bits cannot be proven. If `failureReason` is +/// provided, it will be filled with a concise cause. +LogicalResult inferTileBufferSemantics(Value value, TileBufferSemantics &out, + std::string *failureReason); + +/// Convenience wrapper for callers that do not need failure diagnostics. +inline LogicalResult inferTileBufferSemantics(Value value, + TileBufferSemantics &out) { + return inferTileBufferSemantics(value, out, nullptr); +} } // namespace pto } // namespace mlir diff --git a/test/basic/plan_memory_reserve_buffer_manual_reject_nested_alloc.pto b/test/basic/plan_memory_reserve_buffer_manual_reject_nested_alloc.pto index c78d954b..d262ab18 100644 --- a/test/basic/plan_memory_reserve_buffer_manual_reject_nested_alloc.pto +++ b/test/basic/plan_memory_reserve_buffer_manual_reject_nested_alloc.pto @@ -22,4 +22,4 @@ module { } } -// CHECK: error: 'memref.alloc' op cannot use pto.reserve_buffer with auto = false when local memref.alloc still requires PlanMemory allocation +// CHECK: error: 'memref.alloc' op PlanMemory is tilebuf-only: local memref.alloc is unsupported; use pto.alloc_tile diff --git a/test/basic/planmemory_manual_reserve_reject_unplanned_tile_alloc.pto b/test/basic/planmemory_manual_reserve_reject_unplanned_tile_alloc.pto new file mode 100644 index 00000000..ff91e9de --- /dev/null +++ b/test/basic/planmemory_manual_reserve_reject_unplanned_tile_alloc.pto @@ -0,0 +1,22 @@ +// RUN: not ptoas %s 2>&1 | FileCheck %s + +module { + func.func @manual_reserve_reject_unplanned_tile_alloc() { + %fifo = pto.reserve_buffer { + name = "fifo", + size = 8192, + location = #pto.address_space, + auto = false, + base = 0 + } -> i32 + %buf = pto.alloc_tile + : !pto.tile_buf + pto.tprint ins(%buf : !pto.tile_buf) + return + } +} + +// CHECK: error: 'pto.alloc_tile' op cannot use pto.reserve_buffer with auto = false when local pto.alloc_tile still requires PlanMemory address assignment diff --git a/test/basic/planmemory_reject_local_memref_alloc.pto b/test/basic/planmemory_reject_local_memref_alloc.pto new file mode 100644 index 00000000..b05678f9 --- /dev/null +++ b/test/basic/planmemory_reject_local_memref_alloc.pto @@ -0,0 +1,11 @@ +// RUN: not ptoas %s 2>&1 | FileCheck %s + +module { + func.func @reject_local_memref_alloc() { + %ub = memref.alloc() : memref<16x16xf16, #pto.address_space> + memref.dealloc %ub : memref<16x16xf16, #pto.address_space> + return + } +} + +// CHECK: error: 'memref.alloc' op PlanMemory is tilebuf-only: local memref.alloc is unsupported; use pto.alloc_tile diff --git a/test/basic/planmemory_semantics_infer_fail_diag.pto b/test/basic/planmemory_semantics_infer_fail_diag.pto new file mode 100644 index 00000000..6f217305 --- /dev/null +++ b/test/basic/planmemory_semantics_infer_fail_diag.pto @@ -0,0 +1,15 @@ +// RUN: not ptoas %s 2>&1 | FileCheck %s + +module { + func.func @semantics_infer_fail_diag() { + %buf = pto.alloc_tile + : !pto.tile_buf + pto.tprint ins(%buf : !pto.tile_buf) + return + } +} + +// CHECK: error: 'pto.alloc_tile' op failed to infer tile buffer semantics for PlanMemory (reason: shape dimensions must be positive)