Skip to content

Commit f7801ff

Browse files
committed
feat: add compact mode to tile config
Signed-off-by: FangRui <fangrui_95@163.com>
1 parent ac41002 commit f7801ff

20 files changed

Lines changed: 299 additions & 52 deletions

docs/PTO_IR_manual.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,9 @@ Composite attribute and component enums for tile buffer configuration.
201201
| `sLayout` | `SLayoutAttr` | Secondary layout (NoneBox / RowMajor / ColMajor) |
202202
| `sFractalSize` | `IntegerAttr (i32)` | Secondary fractal size |
203203
| `pad` | `PadValueAttr` | Pad value policy |
204+
| `compact` | `CompactMode` mnemonic or integer literal | Tile compact mode (`null` / `normal`), default `null` |
204205

205-
**Syntax:** `#pto.tile_buf_config<row_major, none_box, 16, zero>`
206+
**Syntax:** `#pto.tile_buf_config<row_major, none_box, 16, zero, null>`
206207

207208
**BLayout** (Base layout):
208209

@@ -228,6 +229,13 @@ Composite attribute and component enums for tile buffer configuration.
228229
| `Max` | 2 | `max` |
229230
| `Min` | 3 | `min` |
230231

232+
**CompactMode** (Tile compact mode):
233+
234+
| Value | Int | Mnemonic |
235+
|-------|-----|----------|
236+
| `Null` | 0 | `null` |
237+
| `Normal` | 1 | `normal` |
238+
231239
---
232240

233241
### 3.5 Layout

include/PTO/IR/PTOAttrs.td

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,14 +425,25 @@ def PTO_PadValueAttr : PTO_Attr<"PadValue", "pad_value"> {
425425
let assemblyFormat = "`<` params `>`";
426426
}
427427

428+
def PTO_CompactMode_Enum : PTO_I32Enum<"CompactMode", "Tile compact mode", [
429+
I32EnumAttrCase<"Null", 0, "null">,
430+
I32EnumAttrCase<"Normal", 1, "normal">
431+
]>;
432+
433+
def PTO_CompactModeAttr : PTO_Attr<"CompactMode", "compact_mode"> {
434+
let parameters = (ins EnumParameter<PTO_CompactMode_Enum>:$value);
435+
let assemblyFormat = "`<` params `>`";
436+
}
437+
428438
// ---------- tile_buf_config (NO b_fractal / s_fractal) ----------
429439
def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
430440
let mnemonic = "tile_buf_config";
431441
let parameters = (ins
432442
"BLayoutAttr":$bLayout,
433443
"SLayoutAttr":$sLayout,
434444
"mlir::IntegerAttr":$sFractalSize, // i32
435-
"PadValueAttr":$pad
445+
"PadValueAttr":$pad,
446+
"CompactModeAttr":$compactMode
436447
);
437448

438449
let hasCustomAssemblyFormat = 1;
@@ -442,19 +453,26 @@ def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
442453
"BLayoutAttr":$bLayout,
443454
"SLayoutAttr":$sLayout,
444455
"mlir::IntegerAttr":$sFractalSize,
445-
"PadValueAttr":$pad
456+
"PadValueAttr":$pad,
457+
"CompactModeAttr":$compactMode
446458
)>
447459
];
448460

449461
let extraClassDeclaration = [{
462+
static TileBufConfigAttr get(MLIRContext *ctx,
463+
BLayoutAttr bLayout,
464+
SLayoutAttr sLayout,
465+
mlir::IntegerAttr sFractalSize,
466+
PadValueAttr pad);
450467
static TileBufConfigAttr getDefault(MLIRContext *ctx);
451468
bool isDefault() const;
452469

453470
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
454471
mlir::Attribute bLayout,
455472
mlir::Attribute sLayout,
456473
mlir::IntegerAttr sFractalSize,
457-
mlir::Attribute pad);
474+
mlir::Attribute pad,
475+
mlir::Attribute compactMode);
458476
}];
459477
}
460478

include/PTO/IR/PTOTypeDefs.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,13 @@ def TileBufType : TypeDef<PTO_Dialect, "TileBuf"> {
200200
mlir::Attribute getSLayoutAttr() const;
201201
int32_t getSFractalSizeI32() const;
202202
mlir::Attribute getPadValueAttr() const;
203+
mlir::Attribute getCompactModeAttr() const;
203204

204205
// 如果你仍然想要“数值枚举”,就提供 int getter(不会依赖 enum 类型)
205206
int32_t getBLayoutValueI32() const; // 0 row_major, 1 col_major
206207
int32_t getSLayoutValueI32() const; // 0 none_box, 1 row_major, 2 col_major
207208
int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min
209+
int32_t getCompactModeI32() const; // 0 null, 1 normal
208210
}];
209211
}
210212

include/pto-c/Dialect/PTO.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOSLayoutAttrGetValue(MlirAttribute attr);
8383
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAPadValueAttr(MlirAttribute attr);
8484
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOPadValueAttrGet(MlirContext ctx, int32_t value);
8585
MLIR_CAPI_EXPORTED int32_t mlirPTOPadValueAttrGetValue(MlirAttribute attr);
86+
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsACompactModeAttr(MlirAttribute attr);
87+
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOCompactModeAttrGet(MlirContext ctx, int32_t value);
88+
MLIR_CAPI_EXPORTED int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr);
8689
MLIR_CAPI_EXPORTED MlirAttribute mlirPTORoundModeAttrGet(MlirContext ctx, int32_t value);
8790
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsARoundModeAttr(MlirAttribute attr);
8891
MLIR_CAPI_EXPORTED int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr);
@@ -148,6 +151,11 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGet(
148151
MlirContext ctx,
149152
MlirAttribute bLayout, MlirAttribute sLayout,
150153
MlirAttribute sFractalSize, MlirAttribute pad);
154+
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode(
155+
MlirContext ctx,
156+
MlirAttribute bLayout, MlirAttribute sLayout,
157+
MlirAttribute sFractalSize, MlirAttribute pad,
158+
MlirAttribute compactMode);
151159
MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGetWithValidShape(
152160
MlirContext ctx, intptr_t rank, const int64_t *shape, MlirType elementType,
153161
MlirAttribute memorySpace, intptr_t validRank, const int64_t *validShape);

lib/Bindings/Python/PTOModule.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ PYBIND11_MODULE(_pto, m) {
9595
.value("Max", mlir::pto::PadValue::Max)
9696
.value("Min", mlir::pto::PadValue::Min);
9797

98+
py::enum_<mlir::pto::CompactMode>(m, "CompactMode")
99+
.value("Null", mlir::pto::CompactMode::Null)
100+
.value("Normal", mlir::pto::CompactMode::Normal);
101+
98102
py::enum_<mlir::pto::RoundMode>(m, "RoundMode")
99103
.value("NONE", mlir::pto::RoundMode::NONE)
100104
.value("RINT", mlir::pto::RoundMode::RINT)
@@ -210,6 +214,19 @@ PYBIND11_MODULE(_pto, m) {
210214
return cls(a);
211215
},
212216
py::arg("cls"), py::arg("value"), py::arg("context") = py::none());
217+
218+
mlir_attribute_subclass(m, "CompactModeAttr",
219+
[](MlirAttribute a) -> bool {
220+
return mlirPTOAttrIsACompactModeAttr(a);
221+
})
222+
.def_classmethod(
223+
"get",
224+
[](py::object cls, mlir::pto::CompactMode value, MlirContext ctx) -> py::object {
225+
MlirAttribute a = mlirPTOCompactModeAttrGet(ctx, static_cast<int32_t>(value));
226+
if (mlirAttributeIsNull(a)) return py::none();
227+
return cls(a);
228+
},
229+
py::arg("cls"), py::arg("value"), py::arg("context") = py::none());
213230
// [保留 HEAD]: AddressSpaceAttr 定义
214231
mlir_attribute_subclass(
215232
m, "AddressSpaceAttr",
@@ -571,11 +588,25 @@ PYBIND11_MODULE(_pto, m) {
571588
MlirAttribute slayout,
572589
int32_t s_fractal_size,
573590
MlirAttribute pad,
574-
MlirContext ctx) -> py::object {
591+
MlirContext ctx,
592+
py::object compactModeObj) -> py::object {
575593
MlirType i32 = mlirIntegerTypeGet(ctx, 32);
576594
MlirAttribute sz = mlirIntegerAttrGet(i32, s_fractal_size);
577-
578-
MlirAttribute a = mlirPTOTileBufConfigAttrGet(ctx, blayout, slayout, sz, pad);
595+
MlirAttribute compactMode = mlirPTOCompactModeAttrGet(
596+
ctx, static_cast<int32_t>(mlir::pto::CompactMode::Null));
597+
if (!compactModeObj.is_none()) {
598+
if (py::isinstance<py::int_>(compactModeObj)) {
599+
compactMode = mlirPTOCompactModeAttrGet(
600+
ctx, compactModeObj.cast<int32_t>());
601+
} else if (py::hasattr(compactModeObj, "value")) {
602+
compactMode = mlirPTOCompactModeAttrGet(
603+
ctx, compactModeObj.attr("value").cast<int32_t>());
604+
} else {
605+
compactMode = compactModeObj.cast<MlirAttribute>();
606+
}
607+
}
608+
MlirAttribute a = mlirPTOTileBufConfigAttrGetWithCompactMode(
609+
ctx, blayout, slayout, sz, pad, compactMode);
579610
if (mlirAttributeIsNull(a)) return py::none();
580611
return cls(a);
581612
},
@@ -584,7 +615,8 @@ PYBIND11_MODULE(_pto, m) {
584615
py::arg("slayout"),
585616
py::arg("s_fractal_size"),
586617
py::arg("pad"),
587-
py::arg("context") = py::none());
618+
py::arg("context") = py::none(),
619+
py::arg("compact_mode") = py::none());
588620

589621
// ---- TileBufType ----
590622
mlir_type_subclass(m, "TileBufType",

lib/CAPI/Dialect/PTO.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,24 +499,59 @@ static mlir::pto::PadValueAttr toPadValueAttr(mlir::MLIRContext *c, mlir::Attrib
499499
return mlir::pto::PadValueAttr::get(c, static_cast<mlir::pto::PadValue>(ia.getInt()));
500500
return {};
501501
}
502+
static mlir::pto::CompactModeAttr toCompactModeAttr(mlir::MLIRContext *c,
503+
mlir::Attribute a) {
504+
if (auto cm = mlir::dyn_cast<mlir::pto::CompactModeAttr>(a))
505+
return cm;
506+
if (auto ia = mlir::dyn_cast<mlir::IntegerAttr>(a))
507+
return mlir::pto::CompactModeAttr::get(
508+
c, static_cast<mlir::pto::CompactMode>(ia.getInt()));
509+
return {};
510+
}
511+
512+
bool mlirPTOAttrIsACompactModeAttr(MlirAttribute attr) {
513+
return unwrap(attr).isa<mlir::pto::CompactModeAttr>();
514+
}
515+
516+
MlirAttribute mlirPTOCompactModeAttrGet(MlirContext ctx, int32_t value) {
517+
auto *c = unwrap(ctx);
518+
return wrap(mlir::pto::CompactModeAttr::get(
519+
c, static_cast<mlir::pto::CompactMode>(value)));
520+
}
521+
522+
int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr) {
523+
auto a = mlir::cast<mlir::pto::CompactModeAttr>(unwrap(attr));
524+
return static_cast<int32_t>(a.getValue());
525+
}
502526

503527
MlirAttribute mlirPTOTileBufConfigAttrGet(MlirContext ctx,
504528
MlirAttribute bLayout,
505529
MlirAttribute sLayout,
506530
MlirAttribute sFractalSize,
507531
MlirAttribute pad) {
508532
auto *c = unwrap(ctx);
533+
auto compactMode =
534+
wrap(mlir::pto::CompactModeAttr::get(c, mlir::pto::CompactMode::Null));
535+
return mlirPTOTileBufConfigAttrGetWithCompactMode(
536+
ctx, bLayout, sLayout, sFractalSize, pad, compactMode);
537+
}
538+
539+
MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode(
540+
MlirContext ctx, MlirAttribute bLayout, MlirAttribute sLayout,
541+
MlirAttribute sFractalSize, MlirAttribute pad, MlirAttribute compactMode) {
542+
auto *c = unwrap(ctx);
509543
auto blA = toBLayoutAttr(c, unwrap(bLayout));
510544
auto slA = toSLayoutAttr(c, unwrap(sLayout));
511545
auto pvA = toPadValueAttr(c, unwrap(pad));
512-
if (!blA || !slA || !pvA)
546+
auto cmA = toCompactModeAttr(c, unwrap(compactMode));
547+
if (!blA || !slA || !pvA || !cmA)
513548
return MlirAttribute{nullptr};
514549

515550
auto sz = mlir::dyn_cast<mlir::IntegerAttr>(unwrap(sFractalSize));
516551
if (!sz || !sz.getType().isInteger(32))
517552
return MlirAttribute{nullptr};
518553

519-
return wrap(mlir::pto::TileBufConfigAttr::get(c, blA, slA, sz, pvA));
554+
return wrap(mlir::pto::TileBufConfigAttr::get(c, blA, slA, sz, pvA, cmA));
520555
}
521556

522557
MlirType mlirPTOGMTypeGet(MlirContext ctx, intptr_t rank, const int64_t *shape,

lib/PTO/IR/PTO.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5587,6 +5587,24 @@ mlir::LogicalResult mlir::pto::TReshapeOp::verify() {
55875587
if (srcBoxed != dstBoxed)
55885588
return emitOpError("cannot reshape between boxed and non-boxed tile layouts");
55895589

5590+
auto getCompactModeValue = [&](TileBufType tbTy) -> int32_t {
5591+
auto cfg = tbTy.getConfigAttr();
5592+
if (!cfg)
5593+
cfg = TileBufConfigAttr::getDefault(getContext());
5594+
5595+
Attribute compact = cfg.getCompactMode();
5596+
if (auto mode = dyn_cast<CompactModeAttr>(compact))
5597+
return static_cast<int32_t>(mode.getValue());
5598+
if (auto mode = dyn_cast<IntegerAttr>(compact))
5599+
return static_cast<int32_t>(mode.getInt());
5600+
return static_cast<int32_t>(pto::CompactMode::Null);
5601+
};
5602+
5603+
if (getCompactModeValue(srcTb) != getCompactModeValue(dstTb))
5604+
return emitOpError(
5605+
"cannot reshape between different compact modes because they imply "
5606+
"different physical storage layouts");
5607+
55905608
return success();
55915609
}
55925610

0 commit comments

Comments
 (0)