Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions docs/designs/tilebuf-planmemory-phase1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Tile Buffer -> PlanMemory (Phase-1)

## Scope
- Base: `origin/main`
- Phase-1 only: `tile_buffer -> PlanMemory`
- Explicitly out of scope:
- Sync migration
- New MultiBuffer capabilities

## Why
PlanMemory previously consumed mainly memref-centric alias/shape/space signals.
Tile metadata (`bind_tile/subset/bitcast/treshape`) was available but had not been
normalized into a reusable semantic layer.

This phase introduces a tile semantic input path while keeping the core planner
(`MultiSpecPlan`, rollback/reuse) unchanged.

## Changes
1. Unified tile semantic extraction in `Utils`:
- alias unification: `bind_tile/subset/bitcast/treshape` + memref view-like ops
- root traceback: `tracebackBufferRoot(...)`
- semantic record: `TileBufferSemantics` (root/scope/shape/valid/config/view-kind/bits)

2. PlanMemory liveness/buffer info wiring:
- `MemLivenessAnalysis` uses unified alias API
- local buffer definition accepts `memref.alloc` and `pto.alloc_tile`
- `GetBufferInfo` prefers tile-native semantic extraction and keeps a legacy fallback

3. No algorithm rewrite:
- Allocation/reuse/rollback algorithm unchanged
- Boundary fallback remains internal (no new user-visible switch)

## Capability -> Test Mapping
- Unified semantic smoke:
- `test/basic/tilebuf_semantic_smoke.pto`
- Alias/root trace across bind + view chain:
- `test/basic/tilebuf_root_trace.pto`
- View-like alias chain (`subset -> treshape -> bitcast`) stability:
- `test/samples/planmemory/tilebuf_alias_chain.py`
- `test/samples/runop.sh` check for `TRESHAPE` + `TASSIGN`
- PlanMemory auto-address reachability:
- `test/basic/tilebuf_auto_addr_assign.pto`
- `test/samples/planmemory/tilebuf_planmemory_auto_addr.py`
- `test/samples/runop.sh` check for `TASSIGN` + `TPRINT`
- Address contract (manual addr preserve):
- `test/basic/tilebuf_manual_addr_preserve.pto`

## Next (Phase-2)
- Sync analysis/input migration to consume the same tile semantic layer.
- Remove remaining internal fallback branches after boundary coverage is complete.
4 changes: 2 additions & 2 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def TMovOp : PTO_TOp<"tmov", [
//===----------------------------------------------------------------------===//

def PointerCastOp : PTO_Op<"pointer_cast", [AttrSizedOperandSegments, Pure]> {
let summary = "Casts an integer address to a MemRef with optional valid dims";
let summary = "Binds an integer address to a tile buffer descriptor";

// 参数定义 (保持 Optional)
let arguments = (ins
Expand All @@ -1017,7 +1017,7 @@ def PointerCastOp : PTO_Op<"pointer_cast", [AttrSizedOperandSegments, Pure]> {
OptionalAttr<TileBufConfigAttr>:$config
);

let results = (outs Res<AnyMemRef, "", [MemAlloc]>:$result);
let results = (outs TileBufType:$result);

// Assembly Format (去掉了 [])
let assemblyFormat = [{
Expand Down
125 changes: 59 additions & 66 deletions lib/PTO/Transforms/AllocToPointerCast.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AllocToPointerCast.cpp - convert memref.AllocOp to pto.pointercastOp.//
//===- AllocToPointerCast.cpp - convert alloc_tile to pto.pointer_cast. -------//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -8,7 +8,7 @@

#include "AllocToPointerCast.h"
#include "PTO/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
Expand All @@ -22,29 +22,21 @@ using namespace mlir::pto;

namespace {} // namespace

LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite(
memref::AllocOp op, PatternRewriter &rewriter) const {
const auto &currentMemRefType = cast<BaseMemRefType>(op.getType());

// Preserve tile config carried by the downstream bind_tile user. Losing this
// metadata here makes PointerCast lowering fall back to RowMajor defaults,
// which can generate illegal intermediate TRESHAPE sequences.
TileBufConfigAttr configAttr;
for (Operation *user : op.getResult().getUsers()) {
auto bind = dyn_cast<pto::BindTileOp>(user);
if (!bind || bind.getSource() != op.getResult())
continue;
if (!configAttr) {
configAttr = bind.getConfigAttr();
continue;
}
if (configAttr != bind.getConfigAttr()) {
op.emitWarning("alloc has multiple bind_tile users with different configs; "
"using the first one");
break;
}
}

LogicalResult
AllocTileOpToPointerCastOpPattern::matchAndRewrite(pto::AllocTileOp op,
PatternRewriter &rewriter) const {
// Manual-address alloc_tile is already fully bound and must not be remapped.
if (op.getAddr())
return failure();

auto tileType = dyn_cast<pto::TileBufType>(op.getResult().getType());
if (!tileType)
return failure();

// Keep config from the tile descriptor so lowering can generate the exact
// Tile<...> type token (layout/fractal/pad) without memref-side recovery.
TileBufConfigAttr configAttr = tileType.getConfigAttr();

constexpr uint64_t kAlign = 4096;
auto iter = buffer2Offsets.find(op.getResult());

Expand All @@ -55,30 +47,31 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite(
offsets = iter->second;

if (offsets.empty()) {
// Estimate buffer size (best-effort). Most PTO tile buffers are 32x32 and
// naturally align to 4096 bytes.
// Estimate tile size in bytes using the static tile descriptor.
uint64_t bytes = kAlign;
if (auto memrefTy = dyn_cast<MemRefType>(currentMemRefType)) {
uint64_t elemBytes = 0;
Type elemTy = memrefTy.getElementType();
if (elemTy.isF16()) elemBytes = 2;
else if (elemTy.isF32()) elemBytes = 4;
else if (auto it = dyn_cast<IntegerType>(elemTy)) elemBytes = it.getWidth() / 8;

if (elemBytes != 0) {
uint64_t numel = 1;
bool allStatic = true;
for (int64_t d : memrefTy.getShape()) {
if (d == ShapedType::kDynamic) {
allStatic = false;
break;
}
numel *= static_cast<uint64_t>(d);
uint64_t elemBytes = 0;
Type elemTy = tileType.getElementType();
if (elemTy.isF16() || elemTy.isBF16())
elemBytes = 2;
else if (elemTy.isF32())
elemBytes = 4;
else if (auto it = dyn_cast<IntegerType>(elemTy))
elemBytes = it.getWidth() / 8;

if (elemBytes != 0) {
uint64_t numel = 1;
bool allStatic = true;
for (int64_t d : tileType.getShape()) {
if (d == ShapedType::kDynamic) {
allStatic = false;
break;
}
if (allStatic && numel != 0)
bytes = numel * elemBytes;
numel *= static_cast<uint64_t>(d);
}
if (allStatic && numel != 0)
bytes = numel * elemBytes;
}

uint64_t stride = ((bytes + kAlign - 1) / kAlign) * kAlign;
uint64_t off = fallbackNextOffset;
fallbackNextOffset += std::max<uint64_t>(stride, kAlign);
Expand All @@ -93,34 +86,34 @@ LogicalResult MemrefAllocaOpToPointerCastOpPattern::matchAndRewrite(
addrs.push_back(constantIntOffsetOp);
}

// [修改 1] 从 ValueRange 中拆解出 row 和 col
// memref.alloc 的 getDynamicSizes() 返回的是变长列表。
// 既然我们只支持 2D Tile,且如果是动态 shape 通常两个维度都是动态的 (?x?),
// 我们直接按顺序提取。
// Preserve valid-shape contract:
// - dynamic valid dims: forward alloc_tile operands
// - static valid dims: materialize constants from TileBufType
// This keeps semantics identical to alloc_tile across PlanMemory rewrite.
Value vRow, vCol;
auto dynSizes = op.getDynamicSizes();

if (dynSizes.size() >= 2) {
vRow = dynSizes[0];
vCol = dynSizes[1];
} else if (dynSizes.size() == 1) {
// 极其罕见的混合情况 (例如 32x?),视具体需求处理,这里默认取第一个
// 或者根据维度索引判断是 row 还是 col,这里暂时从简
vCol = dynSizes[0];
vRow = op.getValidRow();
vCol = op.getValidCol();
auto validShape = tileType.getValidShape();
if (validShape.size() >= 2) {
auto indexType = rewriter.getIndexType();
Location loc = op.getLoc();
if (!vRow && validShape[0] >= 0) {
vRow = rewriter.create<arith::ConstantOp>(
loc, indexType, rewriter.getIndexAttr(validShape[0]));
}
if (!vCol && validShape[1] >= 0) {
vCol = rewriter.create<arith::ConstantOp>(
loc, indexType, rewriter.getIndexAttr(validShape[1]));
}
}

// [修改 2] 调用新的 Builder 签名
// 1. ValueRange(addrs) -> 传递物理地址列表
// 2. vRow ? vRow : Value() -> 传递 Value 对象(如果为空则传空 Value)
// 3. TileBufConfigAttr() -> 传递空 Attribute 对象 (不能传 nullptr)

// Build tile-native pointer_cast with assigned physical address.
auto ptoPointerCastOp = rewriter.create<pto::PointerCastOp>(
op.getLoc(),
currentMemRefType,
op.getLoc(), tileType,
ValueRange(addrs), // addrs
vRow ? vRow : Value(), // valid_row
vCol ? vCol : Value(), // valid_col
configAttr // preserve bind_tile config when available
configAttr // config from tile descriptor
);

rewriter.replaceOp(op, ptoPointerCastOp->getResults());
Expand Down
15 changes: 7 additions & 8 deletions lib/PTO/Transforms/AllocToPointerCast.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AllocToPointerCast.h --Convert memref.AllocOp to pto.pointercastOp-===//
//===- AllocToPointerCast.h --Convert pto.alloc_tile to pto.pointer_cast ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -8,26 +8,25 @@
#define LLVM_PROJECT_ALLOCTOPOINTERCAST_H
#include "PTO/IR/PTO.h"
#include "PTO/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
namespace pto {
class MemrefAllocaOpToPointerCastOpPattern
: public OpRewritePattern<memref::AllocOp> {
class AllocTileOpToPointerCastOpPattern
: public OpRewritePattern<pto::AllocTileOp> {
public:
using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
using OpRewritePattern<pto::AllocTileOp>::OpRewritePattern;

/// map from buffer to its allocated addresses
/// note: the buffer which does multibuffer n optimization will be allocated n
/// addresses.
DenseMap<Value, SmallVector<uint64_t>> buffer2Offsets;
mutable uint64_t fallbackNextOffset = 0;

explicit MemrefAllocaOpToPointerCastOpPattern(
explicit AllocTileOpToPointerCastOpPattern(
MLIRContext *context,
DenseMap<Value, SmallVector<uint64_t>> buffer2Offsets)
: OpRewritePattern<memref::AllocOp>(context),
: OpRewritePattern<pto::AllocTileOp>(context),
buffer2Offsets(std::move(buffer2Offsets)) {
// Seed fallback offsets above any known planned offsets to reduce collisions.
constexpr uint64_t kAlign = 4096;
Expand All @@ -38,7 +37,7 @@ class MemrefAllocaOpToPointerCastOpPattern
}
fallbackNextOffset = ((maxOff + kAlign - 1) / kAlign) * kAlign;
}
LogicalResult matchAndRewrite(memref::AllocOp op,
LogicalResult matchAndRewrite(pto::AllocTileOp op,
PatternRewriter &rewriter) const final;
};

Expand Down
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_mlir_dialect_library(PTOTransforms
PTOViewToMemref.cpp
PTOToEmitC.cpp
Utils.cpp
TileBufferSemantics.cpp
OptMemPlanForPipeline.cpp
AllocToPointerCast.cpp
InferPTOMemScope.cpp
Expand Down
21 changes: 15 additions & 6 deletions lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,24 +252,33 @@ LogicalResult PTOIRTranslator::UpdateAllocTileOpMemInfo(pto::AllocTileOp op) {

LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) {
Value res = op.getResult();
auto memRefType = dyn_cast<MemRefType>(res.getType());
if (!memRefType) return failure();
auto tileType = dyn_cast<pto::TileBufType>(res.getType());
if (!tileType)
return failure();

if (op.getAddrs().empty()) {
return op.emitError("PointerCast must have at least one address operand");
}
Value rootSrc = op.getAddrs().front();

uint64_t sizeInBytes = 0;
if (memRefType.hasStaticShape()) {
int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8;
bool isStatic = true;
for (auto dim : tileType.getShape()) {
if (dim == ShapedType::kDynamic) {
isStatic = false;
break;
}
}
if (isStatic) {
int64_t elemSize = tileType.getElementType().getIntOrFloatBitWidth() / 8;
int64_t numElements = 1;
for (auto dim : memRefType.getShape()) numElements *= dim;
for (auto dim : tileType.getShape())
numElements *= dim;
sizeInBytes = numElements * elemSize;
}

pto::AddressSpace space = pto::AddressSpace::GM;
if (auto attr = memRefType.getMemorySpace()) {
if (auto attr = tileType.getMemorySpace()) {
if (auto ptoAttr = dyn_cast<pto::AddressSpaceAttr>(attr)) {
space = ptoAttr.getAddressSpace();
}
Expand Down
Loading