Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions include/PTO/IR/PTOAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -425,14 +425,26 @@ def PTO_PadValueAttr : PTO_Attr<"PadValue", "pad_value"> {
let assemblyFormat = "`<` params `>`";
}

def PTO_CompactMode_Enum : PTO_I32Enum<"CompactMode", "Tile compact mode", [
I32EnumAttrCase<"Null", 0, "null">,
I32EnumAttrCase<"Normal", 1, "normal">,
I32EnumAttrCase<"RowPlusOne", 2, "row_plus_one">
]>;

def PTO_CompactModeAttr : PTO_Attr<"CompactMode", "compact_mode"> {
let parameters = (ins EnumParameter<PTO_CompactMode_Enum>:$value);
let assemblyFormat = "`<` params `>`";
}

// ---------- tile_buf_config (NO b_fractal / s_fractal) ----------
def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
let mnemonic = "tile_buf_config";
let parameters = (ins
"BLayoutAttr":$bLayout,
"SLayoutAttr":$sLayout,
"mlir::IntegerAttr":$sFractalSize, // i32
"PadValueAttr":$pad
"PadValueAttr":$pad,
"CompactModeAttr":$compactMode
);

let hasCustomAssemblyFormat = 1;
Expand All @@ -442,7 +454,8 @@ def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
"BLayoutAttr":$bLayout,
"SLayoutAttr":$sLayout,
"mlir::IntegerAttr":$sFractalSize,
"PadValueAttr":$pad
"PadValueAttr":$pad,
"CompactModeAttr":$compactMode
)>
];

Expand All @@ -454,7 +467,8 @@ def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
mlir::Attribute bLayout,
mlir::Attribute sLayout,
mlir::IntegerAttr sFractalSize,
mlir::Attribute pad);
mlir::Attribute pad,
mlir::Attribute compactMode);
}];
}

Expand Down
2 changes: 2 additions & 0 deletions include/PTO/IR/PTOTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,13 @@ def TileBufType : TypeDef<PTO_Dialect, "TileBuf"> {
mlir::Attribute getSLayoutAttr() const;
int32_t getSFractalSizeI32() const;
mlir::Attribute getPadValueAttr() const;
mlir::Attribute getCompactModeAttr() const;

// 如果你仍然想要“数值枚举”,就提供 int getter(不会依赖 enum 类型)
int32_t getBLayoutValueI32() const; // 0 row_major, 1 col_major
int32_t getSLayoutValueI32() const; // 0 none_box, 1 row_major, 2 col_major
int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min
int32_t getCompactModeI32() const; // 0 null, 1 normal, 2 row_plus_one
}];
}

Expand Down
8 changes: 8 additions & 0 deletions include/pto-c/Dialect/PTO.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOSLayoutAttrGetValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAPadValueAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOPadValueAttrGet(MlirContext ctx, int32_t value);
MLIR_CAPI_EXPORTED int32_t mlirPTOPadValueAttrGetValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsACompactModeAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOCompactModeAttrGet(MlirContext ctx, int32_t value);
MLIR_CAPI_EXPORTED int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirPTORoundModeAttrGet(MlirContext ctx, int32_t value);
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsARoundModeAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr);
Expand Down Expand Up @@ -148,6 +151,11 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGet(
MlirContext ctx,
MlirAttribute bLayout, MlirAttribute sLayout,
MlirAttribute sFractalSize, MlirAttribute pad);
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode(
MlirContext ctx,
MlirAttribute bLayout, MlirAttribute sLayout,
MlirAttribute sFractalSize, MlirAttribute pad,
MlirAttribute compactMode);
MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGetWithValidShape(
MlirContext ctx, intptr_t rank, const int64_t *shape, MlirType elementType,
MlirAttribute memorySpace, intptr_t validRank, const int64_t *validShape);
Expand Down
41 changes: 37 additions & 4 deletions lib/Bindings/Python/PTOModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ PYBIND11_MODULE(_pto, m) {
.value("Max", mlir::pto::PadValue::Max)
.value("Min", mlir::pto::PadValue::Min);

py::enum_<mlir::pto::CompactMode>(m, "CompactMode")
.value("Null", mlir::pto::CompactMode::Null)
.value("Normal", mlir::pto::CompactMode::Normal)
.value("RowPlusOne", mlir::pto::CompactMode::RowPlusOne);

py::enum_<mlir::pto::RoundMode>(m, "RoundMode")
.value("NONE", mlir::pto::RoundMode::NONE)
.value("RINT", mlir::pto::RoundMode::RINT)
Expand Down Expand Up @@ -210,6 +215,19 @@ PYBIND11_MODULE(_pto, m) {
return cls(a);
},
py::arg("cls"), py::arg("value"), py::arg("context") = py::none());

mlir_attribute_subclass(m, "CompactModeAttr",
[](MlirAttribute a) -> bool {
return mlirPTOAttrIsACompactModeAttr(a);
})
.def_classmethod(
"get",
[](py::object cls, mlir::pto::CompactMode value, MlirContext ctx) -> py::object {
MlirAttribute a = mlirPTOCompactModeAttrGet(ctx, static_cast<int32_t>(value));
if (mlirAttributeIsNull(a)) return py::none();
return cls(a);
},
py::arg("cls"), py::arg("value"), py::arg("context") = py::none());
// [保留 HEAD]: AddressSpaceAttr 定义
mlir_attribute_subclass(
m, "AddressSpaceAttr",
Expand Down Expand Up @@ -601,11 +619,25 @@ PYBIND11_MODULE(_pto, m) {
MlirAttribute slayout,
int32_t s_fractal_size,
MlirAttribute pad,
MlirContext ctx) -> py::object {
MlirContext ctx,
py::object compactModeObj) -> py::object {
MlirType i32 = mlirIntegerTypeGet(ctx, 32);
MlirAttribute sz = mlirIntegerAttrGet(i32, s_fractal_size);

MlirAttribute a = mlirPTOTileBufConfigAttrGet(ctx, blayout, slayout, sz, pad);
MlirAttribute compactMode = mlirPTOCompactModeAttrGet(
ctx, static_cast<int32_t>(mlir::pto::CompactMode::Null));
if (!compactModeObj.is_none()) {
if (py::isinstance<py::int_>(compactModeObj)) {
compactMode = mlirPTOCompactModeAttrGet(
ctx, compactModeObj.cast<int32_t>());
} else if (py::hasattr(compactModeObj, "value")) {
compactMode = mlirPTOCompactModeAttrGet(
ctx, compactModeObj.attr("value").cast<int32_t>());
} else {
compactMode = compactModeObj.cast<MlirAttribute>();
}
}
MlirAttribute a = mlirPTOTileBufConfigAttrGetWithCompactMode(
ctx, blayout, slayout, sz, pad, compactMode);
if (mlirAttributeIsNull(a)) return py::none();
return cls(a);
},
Expand All @@ -614,7 +646,8 @@ PYBIND11_MODULE(_pto, m) {
py::arg("slayout"),
py::arg("s_fractal_size"),
py::arg("pad"),
py::arg("context") = py::none());
py::arg("context") = py::none(),
py::arg("compact_mode") = py::none());

// ---- TileBufType ----
mlir_type_subclass(m, "TileBufType",
Expand Down
39 changes: 37 additions & 2 deletions lib/CAPI/Dialect/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,24 +514,59 @@ static mlir::pto::PadValueAttr toPadValueAttr(mlir::MLIRContext *c, mlir::Attrib
return mlir::pto::PadValueAttr::get(c, static_cast<mlir::pto::PadValue>(ia.getInt()));
return {};
}
static mlir::pto::CompactModeAttr toCompactModeAttr(mlir::MLIRContext *c,
mlir::Attribute a) {
if (auto cm = mlir::dyn_cast<mlir::pto::CompactModeAttr>(a))
return cm;
if (auto ia = mlir::dyn_cast<mlir::IntegerAttr>(a))
return mlir::pto::CompactModeAttr::get(
c, static_cast<mlir::pto::CompactMode>(ia.getInt()));
return {};
}

bool mlirPTOAttrIsACompactModeAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::pto::CompactModeAttr>();
}

MlirAttribute mlirPTOCompactModeAttrGet(MlirContext ctx, int32_t value) {
auto *c = unwrap(ctx);
return wrap(mlir::pto::CompactModeAttr::get(
c, static_cast<mlir::pto::CompactMode>(value)));
}

int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr) {
auto a = mlir::cast<mlir::pto::CompactModeAttr>(unwrap(attr));
return static_cast<int32_t>(a.getValue());
}

MlirAttribute mlirPTOTileBufConfigAttrGet(MlirContext ctx,
MlirAttribute bLayout,
MlirAttribute sLayout,
MlirAttribute sFractalSize,
MlirAttribute pad) {
auto *c = unwrap(ctx);
auto compactMode =
wrap(mlir::pto::CompactModeAttr::get(c, mlir::pto::CompactMode::Null));
return mlirPTOTileBufConfigAttrGetWithCompactMode(
ctx, bLayout, sLayout, sFractalSize, pad, compactMode);
}

MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode(
MlirContext ctx, MlirAttribute bLayout, MlirAttribute sLayout,
MlirAttribute sFractalSize, MlirAttribute pad, MlirAttribute compactMode) {
auto *c = unwrap(ctx);
auto blA = toBLayoutAttr(c, unwrap(bLayout));
auto slA = toSLayoutAttr(c, unwrap(sLayout));
auto pvA = toPadValueAttr(c, unwrap(pad));
if (!blA || !slA || !pvA)
auto cmA = toCompactModeAttr(c, unwrap(compactMode));
if (!blA || !slA || !pvA || !cmA)
return MlirAttribute{nullptr};

auto sz = mlir::dyn_cast<mlir::IntegerAttr>(unwrap(sFractalSize));
if (!sz || !sz.getType().isInteger(32))
return MlirAttribute{nullptr};

return wrap(mlir::pto::TileBufConfigAttr::get(c, blA, slA, sz, pvA));
return wrap(mlir::pto::TileBufConfigAttr::get(c, blA, slA, sz, pvA, cmA));
}

MlirType mlirPTOGMTypeGet(MlirContext ctx, intptr_t rank, const int64_t *shape,
Expand Down
35 changes: 30 additions & 5 deletions lib/PTO/IR/PTOAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,25 @@ TileBufConfigAttr TileBufConfigAttr::getDefault(MLIRContext *ctx) {
BLayoutAttr bl = BLayoutAttr::get(ctx, BLayout::RowMajor);
SLayoutAttr sl = SLayoutAttr::get(ctx, SLayout::NoneBox);
PadValueAttr pv = PadValueAttr::get(ctx, PadValue::Null);
CompactModeAttr compact = CompactModeAttr::get(ctx, CompactMode::Null);
IntegerAttr sz = b.getI32IntegerAttr(512);
return TileBufConfigAttr::get(ctx, bl, sl, sz, pv);
return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact);
}

bool TileBufConfigAttr::isDefault() const {
auto d = getDefault(getContext());
return getBLayout() == d.getBLayout() &&
getSLayout() == d.getSLayout() &&
getSFractalSize() == d.getSFractalSize() &&
getPad() == d.getPad();
getPad() == d.getPad() &&
getCompactMode() == d.getCompactMode();
}

static int32_t getLayoutInt(Attribute a, int32_t def) {
if (auto bl = mlir::dyn_cast<BLayoutAttr>(a)) return static_cast<int32_t>(bl.getValue());
if (auto sl = mlir::dyn_cast<SLayoutAttr>(a)) return static_cast<int32_t>(sl.getValue());
if (auto pv = mlir::dyn_cast<PadValueAttr>(a)) return static_cast<int32_t>(pv.getValue());
if (auto cm = mlir::dyn_cast<CompactModeAttr>(a)) return static_cast<int32_t>(cm.getValue());
if (auto ia = mlir::dyn_cast<IntegerAttr>(a)) return static_cast<int32_t>(ia.getInt());
return def;
}
Expand All @@ -53,13 +56,18 @@ LogicalResult TileBufConfigAttr::verify(function_ref<InFlightDiagnostic()> emitE
Attribute bLayout,
Attribute sLayout,
IntegerAttr sFractalSize,
Attribute pad) {
Attribute pad,
Attribute compactMode) {
if (!bLayout || (!mlir::isa<BLayoutAttr>(bLayout) && !mlir::isa<IntegerAttr>(bLayout)))
return emitError() << "blayout must be BLayoutAttr or i32 integer attr", failure();
if (!sLayout || (!mlir::isa<SLayoutAttr>(sLayout) && !mlir::isa<IntegerAttr>(sLayout)))
return emitError() << "slayout must be SLayoutAttr or i32 integer attr", failure();
if (!pad || (!mlir::isa<PadValueAttr>(pad) && !mlir::isa<IntegerAttr>(pad)))
return emitError() << "pad must be PadValueAttr or i32 integer attr", failure();
if (!compactMode ||
(!mlir::isa<CompactModeAttr>(compactMode) &&
!mlir::isa<IntegerAttr>(compactMode)))
return emitError() << "compact_mode must be CompactModeAttr or i32 integer attr", failure();

if (!sFractalSize || !sFractalSize.getType().isInteger(32))
return emitError() << "s_fractal_size must be i32", failure();
Expand All @@ -80,6 +88,10 @@ LogicalResult TileBufConfigAttr::verify(function_ref<InFlightDiagnostic()> emitE
if (pvv < 0 || pvv > 3)
return emitError() << "unsupported pad value: " << pvv, failure();

int32_t cmv = getLayoutInt(compactMode, -1);
if (cmv < 0 || cmv > 2)
return emitError() << "unsupported compact_mode value: " << cmv, failure();

return success();
}

Expand All @@ -99,6 +111,12 @@ static PadValueAttr toPadValueAttr(MLIRContext *ctx, Attribute a) {
if (auto ia = mlir::dyn_cast<IntegerAttr>(a)) return PadValueAttr::get(ctx, static_cast<PadValue>(ia.getInt()));
return {};
}
static CompactModeAttr toCompactModeAttr(MLIRContext *ctx, Attribute a) {
if (auto cm = mlir::dyn_cast<CompactModeAttr>(a)) return cm;
if (auto ia = mlir::dyn_cast<IntegerAttr>(a))
return CompactModeAttr::get(ctx, static_cast<CompactMode>(ia.getInt()));
return {};
}

Attribute TileBufConfigAttr::parse(AsmParser &p, Type) {
MLIRContext *ctx = p.getContext();
Expand All @@ -107,11 +125,12 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) {
SLayoutAttr sl = def.getSLayout();
IntegerAttr sz = def.getSFractalSize();
PadValueAttr pv = def.getPad();
CompactModeAttr compact = def.getCompactMode();

if (p.parseLess()) return {};

if (succeeded(p.parseOptionalGreater()))
return TileBufConfigAttr::get(ctx, bl, sl, sz, pv);
return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact);

while (true) {
StringRef key;
Expand All @@ -137,6 +156,11 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) {
if (p.parseAttribute(a)) return {};
pv = toPadValueAttr(ctx, a);
if (!pv) return {};
} else if (key == "compact") {
Attribute a;
if (p.parseAttribute(a)) return {};
compact = toCompactModeAttr(ctx, a);
if (!compact) return {};
} else {
p.emitError(p.getCurrentLocation(), "unknown key in tile_buf_config: ") << key;
return {};
Expand All @@ -147,7 +171,7 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) {
if (p.parseComma()) return {};
}

return TileBufConfigAttr::get(ctx, bl, sl, sz, pv);
return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact);
}

void TileBufConfigAttr::print(AsmPrinter &p) const {
Expand All @@ -156,5 +180,6 @@ void TileBufConfigAttr::print(AsmPrinter &p) const {
p << ", slayout=" << getSLayout();
p << ", s_fractal_size=" << (int32_t)getSFractalSize().getInt();
p << ", pad=" << getPad();
p << ", compact=" << getCompactMode();
p << ">";
}
Loading
Loading