@@ -95,6 +95,10 @@ PYBIND11_MODULE(_pto, m) {
9595 .value (" Max" , mlir::pto::PadValue::Max)
9696 .value (" Min" , mlir::pto::PadValue::Min);
9797
98+ py::enum_<mlir::pto::CompactMode>(m, " CompactMode" )
99+ .value (" Null" , mlir::pto::CompactMode::Null)
100+ .value (" Normal" , mlir::pto::CompactMode::Normal);
101+
98102 py::enum_<mlir::pto::RoundMode>(m, " RoundMode" )
99103 .value (" NONE" , mlir::pto::RoundMode::NONE)
100104 .value (" RINT" , mlir::pto::RoundMode::RINT)
@@ -210,6 +214,19 @@ PYBIND11_MODULE(_pto, m) {
210214 return cls (a);
211215 },
212216 py::arg (" cls" ), py::arg (" value" ), py::arg (" context" ) = py::none ());
217+
218+ mlir_attribute_subclass (m, " CompactModeAttr" ,
219+ [](MlirAttribute a) -> bool {
220+ return mlirPTOAttrIsACompactModeAttr (a);
221+ })
222+ .def_classmethod (
223+ " get" ,
224+ [](py::object cls, mlir::pto::CompactMode value, MlirContext ctx) -> py::object {
225+ MlirAttribute a = mlirPTOCompactModeAttrGet (ctx, static_cast <int32_t >(value));
226+ if (mlirAttributeIsNull (a)) return py::none ();
227+ return cls (a);
228+ },
229+ py::arg (" cls" ), py::arg (" value" ), py::arg (" context" ) = py::none ());
213230 // [保留 HEAD]: AddressSpaceAttr 定义
214231 mlir_attribute_subclass (
215232 m, " AddressSpaceAttr" ,
@@ -571,11 +588,25 @@ PYBIND11_MODULE(_pto, m) {
571588 MlirAttribute slayout,
572589 int32_t s_fractal_size,
573590 MlirAttribute pad,
574- MlirContext ctx) -> py::object {
591+ MlirContext ctx,
592+ py::object compactModeObj) -> py::object {
575593 MlirType i32 = mlirIntegerTypeGet (ctx, 32 );
576594 MlirAttribute sz = mlirIntegerAttrGet (i32 , s_fractal_size);
577-
578- MlirAttribute a = mlirPTOTileBufConfigAttrGet (ctx, blayout, slayout, sz, pad);
595+ MlirAttribute compactMode = mlirPTOCompactModeAttrGet (
596+ ctx, static_cast <int32_t >(mlir::pto::CompactMode::Null));
597+ if (!compactModeObj.is_none ()) {
598+ if (py::isinstance<py::int_>(compactModeObj)) {
599+ compactMode = mlirPTOCompactModeAttrGet (
600+ ctx, compactModeObj.cast <int32_t >());
601+ } else if (py::hasattr (compactModeObj, " value" )) {
602+ compactMode = mlirPTOCompactModeAttrGet (
603+ ctx, compactModeObj.attr (" value" ).cast <int32_t >());
604+ } else {
605+ compactMode = compactModeObj.cast <MlirAttribute>();
606+ }
607+ }
608+ MlirAttribute a = mlirPTOTileBufConfigAttrGetWithCompactMode (
609+ ctx, blayout, slayout, sz, pad, compactMode);
579610 if (mlirAttributeIsNull (a)) return py::none ();
580611 return cls (a);
581612 },
@@ -584,7 +615,8 @@ PYBIND11_MODULE(_pto, m) {
584615 py::arg (" slayout" ),
585616 py::arg (" s_fractal_size" ),
586617 py::arg (" pad" ),
587- py::arg (" context" ) = py::none ());
618+ py::arg (" context" ) = py::none (),
619+ py::arg (" compact_mode" ) = py::none ());
588620
589621 // ---- TileBufType ----
590622 mlir_type_subclass (m, " TileBufType" ,
0 commit comments