Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions lib/PTO/Transforms/PTOViewToMemref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -617,21 +617,16 @@ struct PTOViewToMemrefPass
auto configAttr = tbTy.getConfigAttr();
if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx);

// 6. If alloc_tile provides an explicit address, keep the original
// pointer_cast lowering intact and additionally rebind through
// pto.bind_tile. PointerCastOp continues to carry the tile metadata
// used by existing lowering paths, while BindTileOp provides the
// unified anchor EmitC uses to recover tile_buf information.
// 6. If alloc_tile provides an explicit address, lower directly to
// pto.pointer_cast. Rebinding through pto.bind_tile here is redundant
// and can produce an extra tile rewrap in EmitC for dynamic valid
// shapes (double TASSIGN pattern).
if (Value addr = op.getAddr()) {
auto pc = rewriter.create<pto::PointerCastOp>(
loc, targetType, ValueRange{addr}, vRow ? vRow : Value(),
vCol ? vCol : Value(), configAttr);
markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx);
auto bindOp = rewriter.create<pto::BindTileOp>(
loc, targetType, pc.getResult(), vRow ? vRow : Value(),
vCol ? vCol : Value(), configAttr);
markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx);
rewriter.replaceOp(op, bindOp.getResult());
rewriter.replaceOp(op, pc.getResult());
continue;
}

Expand Down
23 changes: 23 additions & 0 deletions test/basic/alloc_tile_addr_dynamic_no_rebind.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: ptoas --pto-level=level3 %s | FileCheck %s

module {
func.func @print_alloc_addr_dyn(%arg0: index, %arg1: index) attributes {pto.entry} {
%c0_i64 = arith.constant 0 : i64

%0 = pto.alloc_tile addr = %c0_i64 valid_row = %arg0 valid_col = %arg1
: !pto.tile_buf<loc=vec, dtype=f16, rows=64, cols=128, v_row=?, v_col=?,
blayout=row_major, slayout=none_box, fractal=512, pad=0>

pto.tprint ins(%0 : !pto.tile_buf<loc=vec, dtype=f16, rows=64, cols=128,
v_row=?, v_col=?, blayout=row_major,
slayout=none_box, fractal=512, pad=0>)
return
}
}

// CHECK-LABEL: __global__ AICORE void print_alloc_addr_dyn(
// CHECK: Tile<TileType::Vec, half, 64, 128, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadValue::Null> [[TILE:v[0-9]+]] = Tile<TileType::Vec, half, 64, 128, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadValue::Null>(
// CHECK: TASSIGN([[TILE]], [[ADDR:v[0-9]+]]);
// CHECK-NOT: .data()
// CHECK-NOT: reinterpret_cast<uint64_t>
// CHECK: TPRINT([[TILE]]);
Loading