From c34a254bca088734e0bc25b40570fffa129e7ef7 Mon Sep 17 00:00:00 2001 From: FangRui Date: Mon, 30 Mar 2026 17:14:16 +0800 Subject: [PATCH] feat: add compact mode to tile config --- include/PTO/IR/PTOAttrs.td | 20 ++++++-- include/PTO/IR/PTOTypeDefs.td | 2 + include/pto-c/Dialect/PTO.h | 8 ++++ lib/Bindings/Python/PTOModule.cpp | 41 ++++++++++++++-- lib/CAPI/Dialect/PTO.cpp | 39 ++++++++++++++- lib/PTO/IR/PTOAttrs.cpp | 35 ++++++++++++-- lib/PTO/IR/PTOTypeDefs.cpp | 47 +++++++++++++++++-- lib/PTO/Transforms/PTOToEmitC.cpp | 34 +++++++++++++- lib/PTO/Transforms/PTOViewToMemref.cpp | 31 ++++++++++-- python/pto/dialects/pto.py | 3 ++ test/basic/set_validshape_if.pto | 2 +- test/basic/set_validshape_local_lowering.pto | 4 +- test/basic/tci_i16_emitc.pto | 4 +- test/basic/tgather_three_forms_emitc.pto | 6 +-- test/basic/tile_compact_mode_emitc.pto | 33 +++++++++++++ test/basic/tprint_alloc_tile_no_rebind.pto | 2 +- test/basic/tpush_tpop_emitc.pto | 6 +-- .../basic/tpush_tpop_frontend_lowering_a3.pto | 22 ++++----- .../basic/tpush_tpop_frontend_lowering_a5.pto | 16 +++---- 19 files changed, 300 insertions(+), 55 deletions(-) create mode 100644 test/basic/tile_compact_mode_emitc.pto diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index b79994d4..21098d07 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -425,6 +425,17 @@ def PTO_PadValueAttr : PTO_Attr<"PadValue", "pad_value"> { let assemblyFormat = "`<` params `>`"; } +def PTO_CompactMode_Enum : PTO_I32Enum<"CompactMode", "Tile compact mode", [ + I32EnumAttrCase<"Null", 0, "null">, + I32EnumAttrCase<"Normal", 1, "normal">, + I32EnumAttrCase<"RowPlusOne", 2, "row_plus_one"> + ]>; + +def PTO_CompactModeAttr : PTO_Attr<"CompactMode", "compact_mode"> { + let parameters = (ins EnumParameter:$value); + let assemblyFormat = "`<` params `>`"; +} + // ---------- tile_buf_config (NO b_fractal / s_fractal) ---------- def TileBufConfigAttr : AttrDef { let mnemonic = "tile_buf_config"; @@ -432,7 +443,8 @@ def TileBufConfigAttr : AttrDef { "BLayoutAttr":$bLayout, "SLayoutAttr":$sLayout, "mlir::IntegerAttr":$sFractalSize, // i32 - "PadValueAttr":$pad + "PadValueAttr":$pad, + "CompactModeAttr":$compactMode ); let hasCustomAssemblyFormat = 1; @@ -442,7 +454,8 @@ def TileBufConfigAttr : AttrDef { "BLayoutAttr":$bLayout, "SLayoutAttr":$sLayout, "mlir::IntegerAttr":$sFractalSize, - "PadValueAttr":$pad + "PadValueAttr":$pad, + "CompactModeAttr":$compactMode )> ]; @@ -454,7 +467,8 @@ def TileBufConfigAttr : AttrDef { mlir::Attribute bLayout, mlir::Attribute sLayout, mlir::IntegerAttr sFractalSize, - mlir::Attribute pad); + mlir::Attribute pad, + mlir::Attribute compactMode); }]; } diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 3e507fa2..c04bd209 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -200,11 +200,13 @@ def TileBufType : TypeDef { mlir::Attribute getSLayoutAttr() const; int32_t getSFractalSizeI32() const; mlir::Attribute getPadValueAttr() const; + mlir::Attribute getCompactModeAttr() const; // 如果你仍然想要“数值枚举”,就提供 int getter(不会依赖 enum 类型) int32_t getBLayoutValueI32() const; // 0 row_major, 1 col_major int32_t getSLayoutValueI32() const; // 0 none_box, 1 row_major, 2 col_major int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min + int32_t getCompactModeI32() const; // 0 null, 1 normal, 2 row_plus_one }]; } diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h index 907ff1dd..7b0cdada 100644 --- a/include/pto-c/Dialect/PTO.h +++ b/include/pto-c/Dialect/PTO.h @@ -83,6 +83,9 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOSLayoutAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAPadValueAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTOPadValueAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED int32_t mlirPTOPadValueAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirPTOAttrIsACompactModeAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOCompactModeAttrGet(MlirContext ctx, int32_t value); +MLIR_CAPI_EXPORTED int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTORoundModeAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsARoundModeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr); @@ -148,6 +151,11 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGet( MlirContext ctx, MlirAttribute bLayout, MlirAttribute sLayout, MlirAttribute sFractalSize, MlirAttribute pad); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode( + MlirContext ctx, + MlirAttribute bLayout, MlirAttribute sLayout, + MlirAttribute sFractalSize, MlirAttribute pad, + MlirAttribute compactMode); MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGetWithValidShape( MlirContext ctx, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute memorySpace, intptr_t validRank, const int64_t *validShape); diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index 860e3619..df7c1cf3 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -95,6 +95,11 @@ PYBIND11_MODULE(_pto, m) { .value("Max", mlir::pto::PadValue::Max) .value("Min", mlir::pto::PadValue::Min); + py::enum_(m, "CompactMode") + .value("Null", mlir::pto::CompactMode::Null) + .value("Normal", mlir::pto::CompactMode::Normal) + .value("RowPlusOne", mlir::pto::CompactMode::RowPlusOne); + py::enum_(m, "RoundMode") .value("NONE", mlir::pto::RoundMode::NONE) .value("RINT", mlir::pto::RoundMode::RINT) @@ -210,6 +215,19 @@ PYBIND11_MODULE(_pto, m) { return cls(a); }, py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); + + mlir_attribute_subclass(m, "CompactModeAttr", + [](MlirAttribute a) -> bool { + return mlirPTOAttrIsACompactModeAttr(a); + }) + .def_classmethod( + "get", + [](py::object cls, mlir::pto::CompactMode value, MlirContext ctx) -> py::object { + MlirAttribute a = mlirPTOCompactModeAttrGet(ctx, static_cast(value)); + if (mlirAttributeIsNull(a)) return py::none(); + return cls(a); + }, + py::arg("cls"), py::arg("value"), py::arg("context") = py::none()); // [保留 HEAD]: AddressSpaceAttr 定义 mlir_attribute_subclass( m, "AddressSpaceAttr", @@ -601,11 +619,25 @@ PYBIND11_MODULE(_pto, m) { MlirAttribute slayout, int32_t s_fractal_size, MlirAttribute pad, - MlirContext ctx) -> py::object { + MlirContext ctx, + py::object compactModeObj) -> py::object { MlirType i32 = mlirIntegerTypeGet(ctx, 32); MlirAttribute sz = mlirIntegerAttrGet(i32, s_fractal_size); - - MlirAttribute a = mlirPTOTileBufConfigAttrGet(ctx, blayout, slayout, sz, pad); + MlirAttribute compactMode = mlirPTOCompactModeAttrGet( + ctx, static_cast(mlir::pto::CompactMode::Null)); + if (!compactModeObj.is_none()) { + if (py::isinstance(compactModeObj)) { + compactMode = mlirPTOCompactModeAttrGet( + ctx, compactModeObj.cast()); + } else if (py::hasattr(compactModeObj, "value")) { + compactMode = mlirPTOCompactModeAttrGet( + ctx, compactModeObj.attr("value").cast()); + } else { + compactMode = compactModeObj.cast(); + } + } + MlirAttribute a = mlirPTOTileBufConfigAttrGetWithCompactMode( + ctx, blayout, slayout, sz, pad, compactMode); if (mlirAttributeIsNull(a)) return py::none(); return cls(a); }, @@ -614,7 +646,8 @@ PYBIND11_MODULE(_pto, m) { py::arg("slayout"), py::arg("s_fractal_size"), py::arg("pad"), - py::arg("context") = py::none()); + py::arg("context") = py::none(), + py::arg("compact_mode") = py::none()); // ---- TileBufType ---- mlir_type_subclass(m, "TileBufType", diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp index e09bf35e..53e7bc7e 100644 --- a/lib/CAPI/Dialect/PTO.cpp +++ b/lib/CAPI/Dialect/PTO.cpp @@ -514,6 +514,30 @@ static mlir::pto::PadValueAttr toPadValueAttr(mlir::MLIRContext *c, mlir::Attrib return mlir::pto::PadValueAttr::get(c, static_cast(ia.getInt())); return {}; } +static mlir::pto::CompactModeAttr toCompactModeAttr(mlir::MLIRContext *c, + mlir::Attribute a) { + if (auto cm = mlir::dyn_cast(a)) + return cm; + if (auto ia = mlir::dyn_cast(a)) + return mlir::pto::CompactModeAttr::get( + c, static_cast(ia.getInt())); + return {}; +} + +bool mlirPTOAttrIsACompactModeAttr(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirPTOCompactModeAttrGet(MlirContext ctx, int32_t value) { + auto *c = unwrap(ctx); + return wrap(mlir::pto::CompactModeAttr::get( + c, static_cast(value))); +} + +int32_t mlirPTOCompactModeAttrGetValue(MlirAttribute attr) { + auto a = mlir::cast(unwrap(attr)); + return static_cast(a.getValue()); +} MlirAttribute mlirPTOTileBufConfigAttrGet(MlirContext ctx, MlirAttribute bLayout, @@ -521,17 +545,28 @@ MlirAttribute mlirPTOTileBufConfigAttrGet(MlirContext ctx, MlirAttribute sFractalSize, MlirAttribute pad) { auto *c = unwrap(ctx); + auto compactMode = + wrap(mlir::pto::CompactModeAttr::get(c, mlir::pto::CompactMode::Null)); + return mlirPTOTileBufConfigAttrGetWithCompactMode( + ctx, bLayout, sLayout, sFractalSize, pad, compactMode); +} + +MlirAttribute mlirPTOTileBufConfigAttrGetWithCompactMode( + MlirContext ctx, MlirAttribute bLayout, MlirAttribute sLayout, + MlirAttribute sFractalSize, MlirAttribute pad, MlirAttribute compactMode) { + auto *c = unwrap(ctx); auto blA = toBLayoutAttr(c, unwrap(bLayout)); auto slA = toSLayoutAttr(c, unwrap(sLayout)); auto pvA = toPadValueAttr(c, unwrap(pad)); - if (!blA || !slA || !pvA) + auto cmA = toCompactModeAttr(c, unwrap(compactMode)); + if (!blA || !slA || !pvA || !cmA) return MlirAttribute{nullptr}; auto sz = mlir::dyn_cast(unwrap(sFractalSize)); if (!sz || !sz.getType().isInteger(32)) return MlirAttribute{nullptr}; - return wrap(mlir::pto::TileBufConfigAttr::get(c, blA, slA, sz, pvA)); + return wrap(mlir::pto::TileBufConfigAttr::get(c, blA, slA, sz, pvA, cmA)); } MlirType mlirPTOGMTypeGet(MlirContext ctx, intptr_t rank, const int64_t *shape, diff --git a/lib/PTO/IR/PTOAttrs.cpp b/lib/PTO/IR/PTOAttrs.cpp index 9c01a852..04a55fef 100644 --- a/lib/PTO/IR/PTOAttrs.cpp +++ b/lib/PTO/IR/PTOAttrs.cpp @@ -29,8 +29,9 @@ TileBufConfigAttr TileBufConfigAttr::getDefault(MLIRContext *ctx) { BLayoutAttr bl = BLayoutAttr::get(ctx, BLayout::RowMajor); SLayoutAttr sl = SLayoutAttr::get(ctx, SLayout::NoneBox); PadValueAttr pv = PadValueAttr::get(ctx, PadValue::Null); + CompactModeAttr compact = CompactModeAttr::get(ctx, CompactMode::Null); IntegerAttr sz = b.getI32IntegerAttr(512); - return TileBufConfigAttr::get(ctx, bl, sl, sz, pv); + return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact); } bool TileBufConfigAttr::isDefault() const { @@ -38,13 +39,15 @@ bool TileBufConfigAttr::isDefault() const { return getBLayout() == d.getBLayout() && getSLayout() == d.getSLayout() && getSFractalSize() == d.getSFractalSize() && - getPad() == d.getPad(); + getPad() == d.getPad() && + getCompactMode() == d.getCompactMode(); } static int32_t getLayoutInt(Attribute a, int32_t def) { if (auto bl = mlir::dyn_cast(a)) return static_cast(bl.getValue()); if (auto sl = mlir::dyn_cast(a)) return static_cast(sl.getValue()); if (auto pv = mlir::dyn_cast(a)) return static_cast(pv.getValue()); + if (auto cm = mlir::dyn_cast(a)) return static_cast(cm.getValue()); if (auto ia = mlir::dyn_cast(a)) return static_cast(ia.getInt()); return def; } @@ -53,13 +56,18 @@ LogicalResult TileBufConfigAttr::verify(function_ref emitE Attribute bLayout, Attribute sLayout, IntegerAttr sFractalSize, - Attribute pad) { + Attribute pad, + Attribute compactMode) { if (!bLayout || (!mlir::isa(bLayout) && !mlir::isa(bLayout))) return emitError() << "blayout must be BLayoutAttr or i32 integer attr", failure(); if (!sLayout || (!mlir::isa(sLayout) && !mlir::isa(sLayout))) return emitError() << "slayout must be SLayoutAttr or i32 integer attr", failure(); if (!pad || (!mlir::isa(pad) && !mlir::isa(pad))) return emitError() << "pad must be PadValueAttr or i32 integer attr", failure(); + if (!compactMode || + (!mlir::isa(compactMode) && + !mlir::isa(compactMode))) + return emitError() << "compact_mode must be CompactModeAttr or i32 integer attr", failure(); if (!sFractalSize || !sFractalSize.getType().isInteger(32)) return emitError() << "s_fractal_size must be i32", failure(); @@ -80,6 +88,10 @@ LogicalResult TileBufConfigAttr::verify(function_ref emitE if (pvv < 0 || pvv > 3) return emitError() << "unsupported pad value: " << pvv, failure(); + int32_t cmv = getLayoutInt(compactMode, -1); + if (cmv < 0 || cmv > 2) + return emitError() << "unsupported compact_mode value: " << cmv, failure(); + return success(); } @@ -99,6 +111,12 @@ static PadValueAttr toPadValueAttr(MLIRContext *ctx, Attribute a) { if (auto ia = mlir::dyn_cast(a)) return PadValueAttr::get(ctx, static_cast(ia.getInt())); return {}; } +static CompactModeAttr toCompactModeAttr(MLIRContext *ctx, Attribute a) { + if (auto cm = mlir::dyn_cast(a)) return cm; + if (auto ia = mlir::dyn_cast(a)) + return CompactModeAttr::get(ctx, static_cast(ia.getInt())); + return {}; +} Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { MLIRContext *ctx = p.getContext(); @@ -107,11 +125,12 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { SLayoutAttr sl = def.getSLayout(); IntegerAttr sz = def.getSFractalSize(); PadValueAttr pv = def.getPad(); + CompactModeAttr compact = def.getCompactMode(); if (p.parseLess()) return {}; if (succeeded(p.parseOptionalGreater())) - return TileBufConfigAttr::get(ctx, bl, sl, sz, pv); + return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact); while (true) { StringRef key; @@ -137,6 +156,11 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { if (p.parseAttribute(a)) return {}; pv = toPadValueAttr(ctx, a); if (!pv) return {}; + } else if (key == "compact") { + Attribute a; + if (p.parseAttribute(a)) return {}; + compact = toCompactModeAttr(ctx, a); + if (!compact) return {}; } else { p.emitError(p.getCurrentLocation(), "unknown key in tile_buf_config: ") << key; return {}; @@ -147,7 +171,7 @@ Attribute TileBufConfigAttr::parse(AsmParser &p, Type) { if (p.parseComma()) return {}; } - return TileBufConfigAttr::get(ctx, bl, sl, sz, pv); + return TileBufConfigAttr::get(ctx, bl, sl, sz, pv, compact); } void TileBufConfigAttr::print(AsmPrinter &p) const { @@ -156,5 +180,6 @@ void TileBufConfigAttr::print(AsmPrinter &p) const { p << ", slayout=" << getSLayout(); p << ", s_fractal_size=" << (int32_t)getSFractalSize().getInt(); p << ", pad=" << getPad(); + p << ", compact=" << getCompactMode(); p << ">"; } diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index 219c9abd..503774ab 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -69,6 +69,9 @@ bool TileBufType::hasNonDefaultConfig() const { mlir::Attribute TileBufType::getBLayoutAttr() const { return getConfigAttr().getBLayout(); } mlir::Attribute TileBufType::getSLayoutAttr() const { return getConfigAttr().getSLayout(); } mlir::Attribute TileBufType::getPadValueAttr() const { return getConfigAttr().getPad(); } +mlir::Attribute TileBufType::getCompactModeAttr() const { + return getConfigAttr().getCompactMode(); +} // ✅ numeric getters(可选) int32_t TileBufType::getSFractalSizeI32() const { @@ -93,6 +96,12 @@ int32_t TileBufType::getPadValueI32() const { return 0; } +int32_t TileBufType::getCompactModeI32() const { + if (auto a = llvm::dyn_cast(getCompactModeAttr())) + return static_cast(a.getValue()); + return 0; +} + // ---- TileBufType custom asm ---- // !pto.tile_buf<> Type TileBufType::parse(AsmParser &parser) { @@ -108,6 +117,7 @@ Type TileBufType::parse(AsmParser &parser) { std::string blayoutStr, slayoutStr; int64_t fractal = 0; uint32_t padInt; + uint32_t compactInt = 0; auto parseKeyEq = [&](StringRef expectedKey) -> LogicalResult { if (failed(parser.parseKeyword(expectedKey))) @@ -202,12 +212,17 @@ Type TileBufType::parse(AsmParser &parser) { if (failed(parser.parseComma())) return Type(); } - // pad=Null + // pad=0 { if (failed(parseKeyEq("pad"))) return Type(); if (failed(parser.parseInteger(padInt))) return Type(); } + if (succeeded(parser.parseOptionalComma())) { + if (failed(parseKeyEq("compact"))) return Type(); + if (failed(parser.parseInteger(compactInt))) return Type(); + } + if (failed(parser.parseGreater())) return Type(); @@ -234,6 +249,7 @@ Type TileBufType::parse(AsmParser &parser) { auto bl = symbolizeBLayout(blayoutStr); auto sl = symbolizeSLayout(slayoutStr); auto pv = symbolizePadValue(padInt); + auto compact = symbolizeCompactMode(compactInt); if (!bl.has_value()) { parser.emitError(parser.getNameLoc(), "unknown blayout: ") << blayoutStr; return Type(); @@ -246,6 +262,10 @@ Type TileBufType::parse(AsmParser &parser) { parser.emitError(parser.getNameLoc(), "unknown pad: ") << padInt; return Type(); } + if (!compact.has_value()) { + parser.emitError(parser.getNameLoc(), "unknown compact: ") << compactInt; + return Type(); + } BLayout effectiveBLayout = bl.value(); if (memorySpace.value() == AddressSpace::LEFT) { @@ -266,8 +286,10 @@ Type TileBufType::parse(AsmParser &parser) { auto fractalAttr = IntegerAttr::get(IntegerType::get(ctx, 32), fractal); auto padAttr = PadValueAttr::get(ctx, pv.value()); + auto compactAttr = CompactModeAttr::get(ctx, compact.value()); auto memorySpaceAttr = AddressSpaceAttr::get(ctx, memorySpace.value()); - auto cfg = TileBufConfigAttr::get(ctx, blAttr, slAttr, fractalAttr, padAttr); + auto cfg = + TileBufConfigAttr::get(ctx, blAttr, slAttr, fractalAttr, padAttr, compactAttr); SmallVector shape{rows, cols}; SmallVector validShape{vrow, vcol}; @@ -305,6 +327,19 @@ static llvm::StringRef stringifyLocFromPad(mlir::Attribute pad) { } } +static llvm::StringRef stringifyCompactModeInt(mlir::Attribute compactMode) { + auto compactAttr = llvm::dyn_cast_or_null(compactMode); + if (!compactAttr) return "9999"; + + switch (compactAttr.getValue()) { + case CompactMode::Null: return "0"; + case CompactMode::Normal: return "1"; + case CompactMode::RowPlusOne: return "2"; + default: + return "9999"; + } +} + void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { auto shape = getShape(); int64_t rows = shape.size() > 0 ? shape[0] : 0; @@ -344,6 +379,10 @@ void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { printer << ", blayout=" << stringifyBLayout(blayout.getValue()) << ", slayout=" << stringifySLayout(slayout.getValue()) << ", fractal=" << cfg.getSFractalSize().getInt() - << ", pad=" << stringifyLocFromPad(cfg.getPad()) - << ">"; + << ", pad=" << stringifyLocFromPad(cfg.getPad()); + if (auto compact = llvm::dyn_cast(cfg.getCompactMode())) { + if (compact.getValue() != CompactMode::Null) + printer << ", compact=" << stringifyCompactModeInt(compact); + } + printer << ">"; } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 9a1bb48b..ebd123e5 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -3382,9 +3382,22 @@ struct PointerCastConversion : public OpConversionPattern { case 3: padStr = "PadValue::Min"; break; } + int32_t compactVal = 0; + if (auto attr = dyn_cast(config.getCompactMode())) + compactVal = static_cast(attr.getValue()); + + std::string compactStr = "CompactMode::Null"; + switch (compactVal) { + case 1: compactStr = "CompactMode::Normal"; break; + case 2: compactStr = "CompactMode::RowPlusOne"; break; + } + if (!slStr.empty()) { - extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + padStr; + extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + + padStr + ", " + compactStr; } + } else { + extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; } // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) @@ -5069,7 +5082,8 @@ struct ReinterpretCastToEmitC : public OpConversionPattern"; + std::to_string(cols) + + ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); Value tile = rewriter @@ -7765,6 +7779,21 @@ struct PTOBindTileToEmitC : public OpConversionPattern { } } + std::string compactTok = "CompactMode::Null"; + if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { + switch (static_cast(compactAttr.getValue())) { + case 1: + compactTok = "CompactMode::Normal"; + break; + case 2: + compactTok = "CompactMode::RowPlusOne"; + break; + default: + compactTok = "CompactMode::Null"; + break; + } + } + std::string vrowTok, vcolTok; bool useConstructor = false; bool rowIsDynamic = false; @@ -7829,6 +7858,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { ", " + std::to_string(cols) + ", " + blTok + ", " + vrowTok + ", " + vcolTok + ", " + slTok + ", " + std::to_string(fractal) + ", " + padTok + + ", " + compactTok + ">"; return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; }; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 1c90007b..e99f4cae 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -167,6 +167,18 @@ static bool readSLayoutI32(Attribute attr, int32_t &out) { return false; } +static bool readCompactModeI32(Attribute attr, int32_t &out) { + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getValue(); + return true; + } + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getInt(); + return true; + } + return false; +} + static bool getConstIndexValue(Value v, int64_t &out) { if (auto cOp = v.getDefiningOp()) { out = cOp.value(); @@ -206,9 +218,20 @@ static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, int32_t bl = 0; // RowMajor int32_t sl = 0; // NoneBox int32_t fr = 512; + int32_t compact = 0; // Null (void)readBLayoutI32(cfg.getBLayout(), bl); (void)readSLayoutI32(cfg.getSLayout(), sl); if (auto attr = dyn_cast(cfg.getSFractalSize())) fr = (int32_t)attr.getInt(); + (void)readCompactModeI32(cfg.getCompactMode(), compact); + + // CompactMode::RowPlusOne means adding one padded element in the major-stride + // dimension (the physically contiguous "row pitch" in row-major, or "column + // pitch" in col-major) to reduce bank conflicts on some vector paths. + auto applyCompactToMajorStride = [&](int64_t majorStride) -> int64_t { + if (compact == 2) // CompactMode::RowPlusOne + return majorStride + 1; + return majorStride; + }; // Inner shape if (sl == 0) { @@ -244,9 +267,9 @@ static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, if (sl == 0) { if (bl == 1) { info.rowStride = 1; - info.colStride = rows; + info.colStride = applyCompactToMajorStride(rows); } else { - info.rowStride = cols; + info.rowStride = applyCompactToMajorStride(cols); info.colStride = 1; } } else { @@ -254,10 +277,10 @@ static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, // ColMajor + InnerRowMajor (NZ) is supported. InnerColMajor is unsupported. if (sl != 1) return false; info.rowStride = info.innerCols; - info.colStride = rows; + info.colStride = applyCompactToMajorStride(rows); } else { // RowMajor (ZZ/ZN) - info.rowStride = cols; + info.rowStride = applyCompactToMajorStride(cols); info.colStride = info.innerRows; } } diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 319258fc..73be85a2 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -44,6 +44,8 @@ def _load_local_pto_ext(): SLayoutAttr = _pto_mod.SLayoutAttr PadValue = _pto_mod.PadValue PadValueAttr = _pto_mod.PadValueAttr +CompactMode = _pto_mod.CompactMode +CompactModeAttr = _pto_mod.CompactModeAttr RoundMode = _pto_mod.RoundMode RoundModeAttr = _pto_mod.RoundModeAttr CmpMode = _pto_mod.CmpMode @@ -75,6 +77,7 @@ def _load_local_pto_ext(): "BLayout","BLayoutAttr", "SLayout","SLayoutAttr", "PadValue","PadValueAttr", + "CompactMode", "CompactModeAttr", "RoundMode", "RoundModeAttr", "CmpMode", "CmpModeAttr", "PIPE", "PipeAttr", diff --git a/test/basic/set_validshape_if.pto b/test/basic/set_validshape_if.pto index f6007e0a..d4792397 100644 --- a/test/basic/set_validshape_if.pto +++ b/test/basic/set_validshape_if.pto @@ -33,7 +33,7 @@ module { // CHECK: [[ORIG:v[0-9]+]]; // CHECK: TASSIGN([[ORIG]], -// CHECK: Tile [[TILE:v[0-9]+]] = Tile({{.*}}) +// CHECK: Tile [[TILE:v[0-9]+]] = Tile({{.*}}) // CHECK: __ubuf__ float* [[DATA:v[0-9]+]] = [[ORIG]].data(); // CHECK: TASSIGN([[TILE]], // CHECK: if ( diff --git a/test/basic/set_validshape_local_lowering.pto b/test/basic/set_validshape_local_lowering.pto index 6ec0496f..a630e411 100644 --- a/test/basic/set_validshape_local_lowering.pto +++ b/test/basic/set_validshape_local_lowering.pto @@ -15,9 +15,9 @@ module { } } -// CHECK: Tile [[BASE:v[0-9]+]]; +// CHECK: Tile [[BASE:v[0-9]+]]; // CHECK: TASSIGN([[BASE]], [[ADDR:v[0-9]+]]); -// CHECK: Tile [[TILE:v[0-9]+]] = Tile({{.*}}) +// CHECK: Tile [[TILE:v[0-9]+]] = Tile({{.*}}) // CHECK: __ubuf__ float* [[DATA:v[0-9]+]] = [[BASE]].data(); // CHECK: uint64_t [[TILE_ADDR:v[0-9]+]] = reinterpret_cast([[DATA]]); // CHECK: TASSIGN([[TILE]], [[TILE_ADDR]]); diff --git a/test/basic/tci_i16_emitc.pto b/test/basic/tci_i16_emitc.pto index 5ed2b34c..4e255b7a 100644 --- a/test/basic/tci_i16_emitc.pto +++ b/test/basic/tci_i16_emitc.pto @@ -14,5 +14,5 @@ module { } } -// A3: TCI, int16_t, 0>( -// A3-NOT: TCI, int32_t, 0>( +// A3: TCI, int16_t, 0>( +// A3-NOT: TCI, int32_t, 0>( diff --git a/test/basic/tgather_three_forms_emitc.pto b/test/basic/tgather_three_forms_emitc.pto index 0146f7b8..e9e7c870 100644 --- a/test/basic/tgather_three_forms_emitc.pto +++ b/test/basic/tgather_three_forms_emitc.pto @@ -22,7 +22,7 @@ module { } // A3: TGATHER({{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}); -// A3: TGATHER, Tile, MaskPattern::P1111>( +// A3: TGATHER, Tile, MaskPattern::P1111>( // A3-NOT: reinterpret_cast< -// A3-NOT: TGATHER, Tile, Tile, Tile, CmpMode::EQ, 7>( -// A3: TGATHER, Tile, Tile, Tile, CmpMode::EQ, 7>( +// A3-NOT: TGATHER, Tile, Tile, Tile, CmpMode::EQ, 7>( +// A3: TGATHER, Tile, Tile, Tile, CmpMode::EQ, 7>( diff --git a/test/basic/tile_compact_mode_emitc.pto b/test/basic/tile_compact_mode_emitc.pto new file mode 100644 index 00000000..b7459c60 --- /dev/null +++ b/test/basic/tile_compact_mode_emitc.pto @@ -0,0 +1,33 @@ +// RUN: ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s --check-prefix=A3 + +module { + func.func @tile_compact_mode_emitc() { + %default = pto.alloc_tile + : !pto.tile_buf + %compact = pto.alloc_tile + : !pto.tile_buf + %row_plus_one = pto.alloc_tile + : !pto.tile_buf + pto.tprint ins(%default : !pto.tile_buf) + pto.tprint ins(%compact : !pto.tile_buf) + pto.tprint ins(%row_plus_one : !pto.tile_buf) + return + } +} + +// A3-DAG: memref.alloc() : memref<1x16xf16, strided<[16, 1]>, #pto.address_space> +// A3-DAG: memref.alloc() : memref<1x16xf16, strided<[16, 1]>, #pto.address_space> +// A3-DAG: memref.alloc() : memref<1x16xf16, strided<[17, 1]>, #pto.address_space> +// A3: Tile [[DEFAULT:v[0-9]+]]; +// A3: Tile [[COMPACT:v[0-9]+]]; +// A3: Tile [[ROWP1:v[0-9]+]]; diff --git a/test/basic/tprint_alloc_tile_no_rebind.pto b/test/basic/tprint_alloc_tile_no_rebind.pto index 29f6bf30..a3cb04b0 100644 --- a/test/basic/tprint_alloc_tile_no_rebind.pto +++ b/test/basic/tprint_alloc_tile_no_rebind.pto @@ -13,7 +13,7 @@ module { } // CHECK-LABEL: __global__ AICORE void print_kernel() { -// CHECK: Tile [[TILE:v[0-9]+]]; +// CHECK: Tile [[TILE:v[0-9]+]]; // CHECK: TASSIGN([[TILE]], [[ADDR:v[0-9]+]]); // CHECK-NOT: TASSIGN( // CHECK-NOT: .data() diff --git a/test/basic/tpush_tpop_emitc.pto b/test/basic/tpush_tpop_emitc.pto index 8a6841d7..12254a77 100644 --- a/test/basic/tpush_tpop_emitc.pto +++ b/test/basic/tpush_tpop_emitc.pto @@ -37,7 +37,7 @@ module { // A3: const int64_t {{v[0-9]+}} = 0; // A3: #if defined(__DAV_CUBE__) // A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_C2V, 1024, 8, 8>( -// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( // A3: #endif // __DAV_CUBE__ // A3-LABEL: AICORE void vector_pop_gm( @@ -46,7 +46,7 @@ module { // A3: set_mask_norm(); // A3: set_vector_mask(-1, -1); // A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_C2V, 1024, 8, 8>( -// A3: Tile {{v[0-9]+}}; -// A3: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>( +// A3: Tile {{v[0-9]+}}; +// A3: TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>( // A3: TFREE, TileSplitAxis::TILE_LEFT_RIGHT>( // A3: #endif // __DAV_VEC__ diff --git a/test/basic/tpush_tpop_frontend_lowering_a3.pto b/test/basic/tpush_tpop_frontend_lowering_a3.pto index 4cc08605..3e30b2aa 100644 --- a/test/basic/tpush_tpop_frontend_lowering_a3.pto +++ b/test/basic/tpush_tpop_frontend_lowering_a3.pto @@ -63,31 +63,31 @@ module { // A3-LABEL: AICORE void cube_kernel(__gm__ float* // A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, 4>( // A3: TPUSH -// A3: Tile {{v[0-9]+}}; -// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A3: Tile {{v[0-9]+}}; +// A3: Tile {{v[0-9]+}}; +// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: Tile {{v[0-9]+}}; // A3: TMOV( // A3: TFREE, TileSplitAxis::TILE_NO_SPLIT>( // A3-LABEL: AICORE void vector_kernel(__gm__ float* // A3: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4, 4>( -// A3: Tile {{v[0-9]+}}; -// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A3: Tile {{v[0-9]+}}; +// A3: Tile {{v[0-9]+}}; +// A3: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A3: Tile {{v[0-9]+}}; // A3: TNEG( // A3: TFREE, TileSplitAxis::TILE_NO_SPLIT>( // SYNC-A3-LABEL: AICORE void cube_kernel(__gm__ float* -// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // SYNC-A3: set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); -// SYNC-A3: Tile +// SYNC-A3: Tile // SYNC-A3: wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // SYNC-A3: TMOV( // SYNC-A3-LABEL: AICORE void vector_kernel(__gm__ float* -// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// SYNC-A3: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( // SYNC-A3: set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); -// SYNC-A3: Tile +// SYNC-A3: Tile // SYNC-A3: wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); // SYNC-A3: TNEG( diff --git a/test/basic/tpush_tpop_frontend_lowering_a5.pto b/test/basic/tpush_tpop_frontend_lowering_a5.pto index 42934522..84e20b79 100644 --- a/test/basic/tpush_tpop_frontend_lowering_a5.pto +++ b/test/basic/tpush_tpop_frontend_lowering_a5.pto @@ -59,19 +59,19 @@ module { // A5-LABEL: AICORE void cube_kernel( // A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4>( // A5: TPUSH -// A5: Tile {{v[0-9]+}}; -// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A5: Tile {{v[0-9]+}}; +// A5: Tile {{v[0-9]+}}; +// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A5: Tile {{v[0-9]+}}; // A5: TMOV( // A5: TFREE, TileSplitAxis::TILE_NO_SPLIT>( // A5-LABEL: AICORE void vector_kernel( // A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4>( -// A5: Tile {{v[0-9]+}}; -// A5: Tile {{v[0-9]+}}; +// A5: Tile {{v[0-9]+}}; +// A5: Tile {{v[0-9]+}}; // A5: TMOV( -// A5: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( -// A5: Tile {{v[0-9]+}}; +// A5: TPUSH, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A5: TPOP, Tile, TileSplitAxis::TILE_NO_SPLIT>( +// A5: Tile {{v[0-9]+}}; // A5: TNEG( // A5: TFREE, TileSplitAxis::TILE_NO_SPLIT>(