From 556f279c5d53cb6b97aed93cc13830bfc9d43129 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Wed, 1 Apr 2026 14:06:33 +0800 Subject: [PATCH 1/2] Add integer compatibility for selected PTO index operands --- include/PTO/IR/PTOOps.td | 82 ++--- lib/PTO/IR/PTO.cpp | 434 ++++++++++++++++++++++++- lib/PTO/Transforms/PTOViewToMemref.cpp | 20 +- 3 files changed, 473 insertions(+), 63 deletions(-) diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 00dd0e41..83b7ec39 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -51,6 +51,16 @@ def ScalarPtrOrMemRef : def ScalarType : AnyTypeOf<[AnySignlessInteger, AnyFloat], "numeric (integer/float)">; +def IndexOrI64 : + Type< + CPred<"$_self.isIndex() || $_self.isSignlessInteger(64)">, + "index or i64">; + +def IndexOrU32 : + Type< + CPred<"$_self.isIndex() || $_self.isUnsignedInteger(32) || $_self.isSignlessInteger(32)">, + "index or u32-compatible i32">; + //===----------------------------------------------------------------------===// // Op Class //===----------------------------------------------------------------------===// @@ -86,16 +96,14 @@ def AddPtrOp : PTO_Op<"addptr", [ let arguments = (ins PtrType:$ptr, - Index:$offset + IndexOrI64:$offset ); let results = (outs PtrType:$result); let hasVerifier = 1; - let assemblyFormat = [{ - $ptr `,` $offset attr-dict `:` type($ptr) `->` type($result) - }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -109,16 +117,14 @@ def LoadScalarOp : PTO_Op<"load_scalar", [ let arguments = (ins ScalarPtrOrMemRef:$ptr, - Index:$offset + IndexOrI64:$offset ); let results = (outs AnyType:$value); let hasVerifier = 1; - let assemblyFormat = [{ - $ptr `[` $offset `]` attr-dict `:` type($ptr) `->` type($value) - }]; + let hasCustomAssemblyFormat = 1; } def StoreScalarOp : PTO_Op<"store_scalar", [ @@ -128,7 +134,7 @@ def StoreScalarOp : PTO_Op<"store_scalar", [ let arguments = (ins ScalarPtrOrMemRef:$ptr, - Index:$offset, + IndexOrI64:$offset, AnyType:$value ); @@ -136,9 +142,7 @@ def StoreScalarOp : PTO_Op<"store_scalar", [ let hasVerifier = 1; - let assemblyFormat = [{ - $value `,` $ptr `[` $offset `]` attr-dict `:` type($ptr) `,` type($value) - }]; + let hasCustomAssemblyFormat = 1; } def MakeTensorViewOp : PTO_Op<"make_tensor_view", [AttrSizedOperandSegments]> { @@ -146,8 +150,8 @@ def MakeTensorViewOp : PTO_Op<"make_tensor_view", [AttrSizedOperandSegments]> { let arguments = (ins AnyType:$ptr, - Variadic:$shape, - Variadic:$strides, + Variadic:$shape, + Variadic:$strides, OptionalAttr:$layout ); @@ -173,18 +177,15 @@ def PartitionViewOp : PTO_Op<"partition_view", [AttrSizedOperandSegments]> { let arguments = (ins TensorViewType:$source, // 输入: 物理大底座 (MakeTensorViewOp 的结果) - Variadic:$offsets, // 动态 offsets - Variadic:$sizes // 动态 sizes + Variadic:$offsets, // 动态 offsets + Variadic:$sizes // 动态 sizes ); let results = (outs PartitionTensorViewType:$result); // 输出: 逻辑切片 let hasVerifier = 1; - let assemblyFormat = [{ - $source `,` `offsets` `=` `[` $offsets `]` `,` `sizes` `=` `[` $sizes `]` - attr-dict `:` qualified(type($source)) `->` qualified(type($result)) - }]; + let hasCustomAssemblyFormat = 1; } // Helper: tensor_view or memref (after lowering tensor_view to memref). @@ -207,13 +208,10 @@ def GetTensorViewDimOp : PTO_Op<"get_tensor_view_dim", [Pure]> { }]; let arguments = (ins TensorViewOrMemRef:$tensor_view, - Index:$dim_index + IndexOrI64:$dim_index ); let results = (outs Index:$result); - let assemblyFormat = [{ - $tensor_view `,` $dim_index `:` qualified(type($tensor_view)) `->` qualified(type($result)) - attr-dict - }]; + let hasCustomAssemblyFormat = 1; } def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> { @@ -221,18 +219,13 @@ def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> { let arguments = (ins Optional:$addr, - Optional:$valid_row, - Optional:$valid_col + Optional:$valid_row, + Optional:$valid_col ); let results = (outs TileBufType:$result); - let assemblyFormat = [{ - (`addr` `=` $addr^)? - (`valid_row` `=` $valid_row^)? - (`valid_col` `=` $valid_col^)? - attr-dict `:` qualified(type($result)) - }]; + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ ::mlir::LogicalResult verify(); @@ -326,16 +319,13 @@ def SetValidShapeOp : PTO_Op<"set_validshape", [ let arguments = (ins TileBufOrMemRef:$source, - Index:$valid_row, - Index:$valid_col + IndexOrU32:$valid_row, + IndexOrU32:$valid_col ); let hasVerifier = 1; - let assemblyFormat = [{ - $source `,` $valid_row `,` $valid_col attr-dict `:` - qualified(type($source)) - }]; + let hasCustomAssemblyFormat = 1; } // ============================================================================ @@ -396,7 +386,7 @@ def TLoadOp : PTO_TOp<"tload", [ PTODpsType:$dst, OptionalAttr:$pad_mode, Optional:$pad_value, - Optional:$left_padding_num, + Optional:$left_padding_num, Optional:$right_padding_num, DefaultValuedOptionalAttr:$init_out_buffer, Optional:$init_condition @@ -1383,7 +1373,7 @@ def TSetValOp : PTO_TOp<"tsetval", [ let arguments = (ins PTODpsType:$dst, - Index:$offset, + IndexOrU32:$offset, ScalarType:$val ); @@ -1414,7 +1404,7 @@ def TGetValOp : PTO_TOp<"tgetval", [ let arguments = (ins PTODpsType:$src, - Index:$offset + IndexOrU32:$offset ); let results = (outs ScalarType:$dst); @@ -2166,8 +2156,8 @@ def TExtractOp : PTO_TOp<"textract", [ let arguments = (ins PTODpsType:$src, - Index:$indexRow, - Index:$indexCol, + IndexOrU32:$indexRow, + IndexOrU32:$indexCol, PTODpsType:$dst ); @@ -2198,8 +2188,8 @@ def TInsertOp : PTO_TOp<"tinsert", [ let arguments = (ins PTODpsType:$src, - Index:$indexRow, - Index:$indexCol, + IndexOrU32:$indexRow, + IndexOrU32:$indexCol, PTODpsType:$dst ); diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index e4ca8db5..94c53945 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -338,6 +338,221 @@ static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { return mlir::Type(); } +static bool isIndexOrI64Type(Type type) { + return type && (type.isIndex() || type.isSignlessInteger(64)); +} + +static bool isIndexOrU32Type(Type type) { + return type && (type.isIndex() || type.isUnsignedInteger(32) || + type.isSignlessInteger(32)); +} + +template +static ParseResult parseOptionalCompatibleType(OpAsmParser &parser, Type &type, + Pred &&isSupportedType, + StringRef expectedDesc) { + type = parser.getBuilder().getIndexType(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseType(type)) + return failure(); + if (!isSupportedType(type)) + return parser.emitError(parser.getCurrentLocation()) + << "expected " << expectedDesc; + } + return success(); +} + +template +static LogicalResult +verifyUniformCompatibleOperandTypes(Operation *op, ValueRange values, + Pred &&isSupportedType, StringRef groupName) { + Type nonIndexType; + for (Value value : values) { + if (!value) + continue; + Type type = value.getType(); + if (!isSupportedType(type)) { + return op->emitOpError() << "expects " << groupName + << " to use compatible integer/index types"; + } + if (type.isIndex()) + continue; + if (nonIndexType && nonIndexType != type) { + return op->emitOpError() << "expects " << groupName + << " to use a uniform non-index type"; + } + nonIndexType = type; + } + return success(); +} + +static Type getCompatibleOperandTypeOrIndex(MLIRContext *ctx, ValueRange values) { + Type nonIndexType; + for (Value value : values) { + if (!value) + continue; + Type type = value.getType(); + if (type.isIndex()) + continue; + if (!nonIndexType) + nonIndexType = type; + } + return nonIndexType ? nonIndexType : IndexType::get(ctx); +} + +static void printCompatibleTypeSuffix(OpAsmPrinter &printer, Type type) { + if (type && !type.isIndex()) + printer << ", " << type; +} + +ParseResult mlir::pto::AddPtrOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand ptr; + OpAsmParser::UnresolvedOperand offset; + Type ptrTy; + Type offsetTy; + SmallVector resultTypes; + + if (parser.parseOperand(ptr) || parser.parseComma() || + parser.parseOperand(offset) || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(ptrTy) || + parseOptionalCompatibleType(parser, offsetTy, isIndexOrI64Type, + "index or i64 offset type") || + parser.parseArrowTypeList(resultTypes)) + return failure(); + + if (resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(ptr, ptrTy, result.operands) || + parser.resolveOperand(offset, offsetTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::AddPtrOp::print(OpAsmPrinter &printer) { + printer << " " << getPtr() << ", " << getOffset(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getPtr().getType(); + printCompatibleTypeSuffix(printer, getOffset().getType()); + printer << " -> " << getResult().getType(); +} + +ParseResult mlir::pto::LoadScalarOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand ptr; + OpAsmParser::UnresolvedOperand offset; + Type ptrTy; + Type offsetTy; + SmallVector resultTypes; + + if (parser.parseOperand(ptr) || parser.parseLSquare() || + parser.parseOperand(offset) || parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(ptrTy) || + parseOptionalCompatibleType(parser, offsetTy, isIndexOrI64Type, + "index or i64 offset type") || + parser.parseArrowTypeList(resultTypes)) + return failure(); + + if (resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(ptr, ptrTy, result.operands) || + parser.resolveOperand(offset, offsetTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::LoadScalarOp::print(OpAsmPrinter &printer) { + printer << " " << getPtr() << "[" << getOffset() << "]"; + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getPtr().getType(); + printCompatibleTypeSuffix(printer, getOffset().getType()); + printer << " -> " << getValue().getType(); +} + +ParseResult mlir::pto::StoreScalarOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand value; + OpAsmParser::UnresolvedOperand ptr; + OpAsmParser::UnresolvedOperand offset; + Type ptrTy; + Type valueTy; + Type compatTy = parser.getBuilder().getIndexType(); + Type secondTy; + + if (parser.parseOperand(value) || parser.parseComma() || parser.parseOperand(ptr) || + parser.parseLSquare() || parser.parseOperand(offset) || parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(ptrTy) || + parser.parseComma() || parser.parseType(secondTy)) + return failure(); + + if (succeeded(parser.parseOptionalComma())) { + compatTy = secondTy; + if (!isIndexOrI64Type(compatTy)) + return parser.emitError(parser.getCurrentLocation()) + << "expected index or i64 offset type"; + if (parser.parseType(valueTy)) + return failure(); + } else { + valueTy = secondTy; + } + + if (parser.resolveOperand(value, valueTy, result.operands) || + parser.resolveOperand(ptr, ptrTy, result.operands) || + parser.resolveOperand(offset, compatTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::StoreScalarOp::print(OpAsmPrinter &printer) { + printer << " " << getValue() << ", " << getPtr() << "[" << getOffset() << "]"; + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getPtr().getType(); + if (!getOffset().getType().isIndex()) + printer << ", " << getOffset().getType(); + printer << ", " << getValue().getType(); +} + +ParseResult mlir::pto::GetTensorViewDimOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand tensorView; + OpAsmParser::UnresolvedOperand dimIndex; + Type tensorViewTy; + Type dimIndexTy; + SmallVector resultTypes; + + if (parser.parseOperand(tensorView) || parser.parseComma() || + parser.parseOperand(dimIndex) || parser.parseColonType(tensorViewTy) || + parseOptionalCompatibleType(parser, dimIndexTy, isIndexOrI64Type, + "index or i64 dim type") || + parser.parseArrowTypeList(resultTypes) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(tensorView, tensorViewTy, result.operands) || + parser.resolveOperand(dimIndex, dimIndexTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::GetTensorViewDimOp::print(OpAsmPrinter &printer) { + printer << " " << getTensorView() << ", " << getDimIndex() << " : " + << getTensorView().getType(); + printCompatibleTypeSuffix(printer, getDimIndex().getType()); + printer << " -> " << getResult().getType(); + printer.printOptionalAttrDict((*this)->getAttrs()); +} + mlir::Type TensorViewType::parse(::mlir::AsmParser &parser) { SmallVector shape; Type elementType; @@ -555,6 +770,7 @@ ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, SmallVector strideOps; Type resultTy; + Type operandTy; // %ptr if (parser.parseOperand(ptr)) @@ -579,7 +795,9 @@ ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, return failure(); // : result-type - if (parser.parseColonType(resultTy)) + if (parser.parseColonType(resultTy) || + parseOptionalCompatibleType(parser, operandTy, isIndexOrI64Type, + "index or i64 operand type")) return failure(); result.addTypes(resultTy); @@ -596,11 +814,9 @@ ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, if (parser.resolveOperand(ptr, ptrTy, result.operands)) return failure(); - // resolve shape/strides 为 index - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(shapeOps, indexTy, result.operands)) + if (parser.resolveOperands(shapeOps, operandTy, result.operands)) return failure(); - if (parser.resolveOperands(strideOps, indexTy, result.operands)) + if (parser.resolveOperands(strideOps, operandTy, result.operands)) return failure(); auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( @@ -625,6 +841,68 @@ void mlir::pto::MakeTensorViewOp::print(OpAsmPrinter &p) { /*elidedAttrs=*/{"operandSegmentSizes"}); p << " : " << getResult().getType(); + SmallVector compatOperands; + compatOperands.reserve(getShape().size() + getStrides().size()); + compatOperands.append(getShape().begin(), getShape().end()); + compatOperands.append(getStrides().begin(), getStrides().end()); + printCompatibleTypeSuffix( + p, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); +} + +ParseResult mlir::pto::PartitionViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + SmallVector offsets; + SmallVector sizes; + Type sourceTy; + Type operandTy; + SmallVector resultTypes; + + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseKeyword("offsets") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(offsets) || + parser.parseRSquare() || parser.parseComma() || + parser.parseKeyword("sizes") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(sizes) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy) || + parseOptionalCompatibleType(parser, operandTy, isIndexOrI64Type, + "index or i64 operand type") || + parser.parseArrowTypeList(resultTypes)) + return failure(); + + if (resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(source, sourceTy, result.operands) || + parser.resolveOperands(offsets, operandTy, result.operands) || + parser.resolveOperands(sizes, operandTy, result.operands)) + return failure(); + + auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( + {1, static_cast(offsets.size()), static_cast(sizes.size())}); + result.addAttribute("operandSegmentSizes", segAttr); + return success(); +} + +void mlir::pto::PartitionViewOp::print(OpAsmPrinter &p) { + p << " " << getSource() << ", offsets = ["; + p.printOperands(getOffsets()); + p << "], sizes = ["; + p.printOperands(getSizes()); + p << "]"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + p << " : " << getSource().getType(); + SmallVector compatOperands; + compatOperands.reserve(getOffsets().size() + getSizes().size()); + compatOperands.append(getOffsets().begin(), getOffsets().end()); + compatOperands.append(getSizes().begin(), getSizes().end()); + printCompatibleTypeSuffix( + p, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); + p << " -> " << getResult().getType(); } // Layout inference helpers for make_tensor_view @@ -895,6 +1173,15 @@ static std::optional getConstantIntegerValue(Value value) { } LogicalResult mlir::pto::MakeTensorViewOp::verify() { + SmallVector compatOperands; + compatOperands.reserve(getShape().size() + getStrides().size()); + compatOperands.append(getShape().begin(), getShape().end()); + compatOperands.append(getStrides().begin(), getStrides().end()); + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), isIndexOrI64Type, + "shape/strides operands"))) + return failure(); + auto tvTy = dyn_cast(getResult().getType()); if (!tvTy) return emitOpError("result must be pto.tensor_view<...>"); @@ -956,6 +1243,15 @@ LogicalResult mlir::pto::MakeTensorViewOp::verify() { } LogicalResult mlir::pto::PartitionViewOp::verify() { + SmallVector compatOperands; + compatOperands.reserve(getOffsets().size() + getSizes().size()); + compatOperands.append(getOffsets().begin(), getOffsets().end()); + compatOperands.append(getSizes().begin(), getSizes().end()); + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), isIndexOrI64Type, + "offset/size operands"))) + return failure(); + auto srcTy = dyn_cast(getSource().getType()); auto resTy = dyn_cast(getResult().getType()); if (!srcTy || !resTy) @@ -1031,6 +1327,72 @@ LogicalResult mlir::pto::AddPtrOp::verify() { return success(); } +ParseResult mlir::pto::AllocTileOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand addr; + OpAsmParser::UnresolvedOperand validRow; + OpAsmParser::UnresolvedOperand validCol; + bool hasAddr = false; + bool hasValidRow = false; + bool hasValidCol = false; + Type compatTy; + Type resultTy; + + if (succeeded(parser.parseOptionalKeyword("addr"))) { + if (parser.parseEqual() || parser.parseOperand(addr)) + return failure(); + hasAddr = true; + } + if (succeeded(parser.parseOptionalKeyword("valid_row"))) { + if (parser.parseEqual() || parser.parseOperand(validRow)) + return failure(); + hasValidRow = true; + } + if (succeeded(parser.parseOptionalKeyword("valid_col"))) { + if (parser.parseEqual() || parser.parseOperand(validCol)) + return failure(); + hasValidCol = true; + } + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(resultTy) || + parseOptionalCompatibleType(parser, compatTy, isIndexOrU32Type, + "index or u32-compatible operand type")) + return failure(); + + result.addTypes(resultTy); + if (hasAddr && + parser.resolveOperand(addr, parser.getBuilder().getI64Type(), result.operands)) + return failure(); + if (hasValidRow && parser.resolveOperand(validRow, compatTy, result.operands)) + return failure(); + if (hasValidCol && parser.resolveOperand(validCol, compatTy, result.operands)) + return failure(); + + auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( + {hasAddr ? 1 : 0, hasValidRow ? 1 : 0, hasValidCol ? 1 : 0}); + result.addAttribute("operandSegmentSizes", segAttr); + return success(); +} + +void mlir::pto::AllocTileOp::print(OpAsmPrinter &printer) { + if (Value addr = getAddr()) + printer << " addr = " << addr; + if (Value validRow = getValidRow()) + printer << " valid_row = " << validRow; + if (Value validCol = getValidCol()) + printer << " valid_col = " << validCol; + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + printer << " : " << getResult().getType(); + SmallVector compatOperands; + if (getValidRow()) + compatOperands.push_back(getValidRow()); + if (getValidCol()) + compatOperands.push_back(getValidCol()); + printCompatibleTypeSuffix( + printer, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); +} + @@ -1218,6 +1580,16 @@ static LogicalResult verifyMemrefTensorStore(Operation *op, Value dst, Value src LogicalResult AllocTileOp::verify() { auto ty = getResult().getType(); // TileBufType + SmallVector compatOperands; + if (getValidRow()) + compatOperands.push_back(getValidRow()); + if (getValidCol()) + compatOperands.push_back(getValidCol()); + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), isIndexOrU32Type, + "valid_row/valid_col operands"))) + return failure(); + // op 上有没有传 operands bool hasVR = getValidRow() != nullptr; bool hasVC = getValidCol() != nullptr; @@ -2922,6 +3294,10 @@ mlir::LogicalResult mlir::pto::TExpandsOp::verify() { mlir::LogicalResult mlir::pto::TExtractOp::verify() { auto getConstIndex = [&](Value v) -> std::optional { + if (auto cst = v.getDefiningOp()) + return cst.value(); + if (auto cst = v.getDefiningOp()) + return cst.value(); auto cst = v.getDefiningOp(); if (!cst) return std::nullopt; @@ -2930,8 +3306,9 @@ mlir::LogicalResult mlir::pto::TExtractOp::verify() { return std::nullopt; }; auto verifyIndexOperands = [&]() -> LogicalResult { - if (!getIndexRow().getType().isIndex() || !getIndexCol().getType().isIndex()) - return emitOpError("expects indexRow and indexCol to be index type"); + if (!isIndexOrU32Type(getIndexRow().getType()) || + !isIndexOrU32Type(getIndexCol().getType())) + return emitOpError("expects indexRow and indexCol to be index or u32-compatible type"); auto row = getConstIndex(getIndexRow()); auto col = getConstIndex(getIndexCol()); if (row && *row < 0) @@ -3096,8 +3473,9 @@ mlir::LogicalResult mlir::pto::TInsertOp::verify() { return emitOpError( "expects src/dst element types to match, or src=f32 with dst=f16/bf16"); - if (!getIndexRow().getType().isIndex() || !getIndexCol().getType().isIndex()) - return emitOpError("expects indexRow/indexCol to be index type"); + if (!isIndexOrU32Type(getIndexRow().getType()) || + !isIndexOrU32Type(getIndexCol().getType())) + return emitOpError("expects indexRow/indexCol to be index or u32-compatible type"); auto readConstIndex = [&](Value v, int64_t &out) -> bool { if (auto cOp = v.getDefiningOp()) { @@ -4803,6 +5181,38 @@ static bool isLocallyBoundTileSource(Value value) { return false; } +ParseResult mlir::pto::SetValidShapeOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + OpAsmParser::UnresolvedOperand validRow; + OpAsmParser::UnresolvedOperand validCol; + Type sourceTy; + Type compatTy; + + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(validRow) || parser.parseComma() || + parser.parseOperand(validCol) || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy) || + parseOptionalCompatibleType(parser, compatTy, isIndexOrU32Type, + "index or u32-compatible operand type")) + return failure(); + + if (parser.resolveOperand(source, sourceTy, result.operands) || + parser.resolveOperand(validRow, compatTy, result.operands) || + parser.resolveOperand(validCol, compatTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::SetValidShapeOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getValidRow() << ", " << getValidCol(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType(); + SmallVector compatOperands = {getValidRow(), getValidCol()}; + printCompatibleTypeSuffix( + printer, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); +} + static std::optional getConstIndexLike(Value v) { if (auto cOp = v.getDefiningOp()) return cOp.value(); @@ -4824,6 +5234,12 @@ static std::optional getConstIndexLike(Value v) { } mlir::LogicalResult mlir::pto::SetValidShapeOp::verify() { + SmallVector compatOperands = {getValidRow(), getValidCol()}; + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), + isIndexOrU32Type, "valid_row/valid_col operands"))) + return failure(); + SmallVector shape; if (auto srcTy = llvm::dyn_cast(getSource().getType())) { if (srcTy.getRank() != 2) diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index a8d9f6c7..dd1f31b9 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -614,14 +614,15 @@ struct PTOViewToMemrefPass // 5. 获取 Config (保持不变) auto configAttr = tbTy.getConfigAttr(); if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); + Value bindVRow = vRow ? ensureIndex(rewriter, loc, vRow, op) : Value(); + Value bindVCol = vCol ? ensureIndex(rewriter, loc, vCol, op) : Value(); // 6. If alloc_tile provides an explicit address, lower directly to // pto.pointer_cast so downstream EmitC lowering can use the integral // address without relying on MemPlan. if (Value addr = op.getAddr()) { auto pc = rewriter.create( - loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); + loc, targetType, ValueRange{addr}, bindVRow, bindVCol, configAttr); markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); rewriter.replaceOp(op, pc.getResult()); continue; @@ -635,8 +636,7 @@ struct PTOViewToMemrefPass // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); + loc, targetType, alloc, bindVRow, bindVCol, configAttr); markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); rewriter.replaceOp(op, bindOp.getResult()); @@ -731,7 +731,7 @@ struct PTOViewToMemrefPass if (!mrTy) continue; // leave it to later passes if it hasn't been lowered yet - Value dimIdx = op.getDimIndex(); + Value dimIdx = ensureIndex(rewriter, loc, op.getDimIndex(), op); Value dim = rewriter.create(loc, view, dimIdx); rewriter.replaceOp(op, dim); } @@ -2120,10 +2120,14 @@ struct PTOViewToMemrefPass Value dst = op.getDst(); auto srcTy = dyn_cast(src.getType()); - auto indexRowTy = dyn_cast(indexRow.getType()); - auto indexColTy = dyn_cast(indexCol.getType()); auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !indexRowTy || !indexColTy || !dstTy) { + bool rowTyOk = indexRow.getType().isIndex() || + indexRow.getType().isUnsignedInteger(32) || + indexRow.getType().isSignlessInteger(32); + bool colTyOk = indexCol.getType().isIndex() || + indexCol.getType().isUnsignedInteger(32) || + indexCol.getType().isSignlessInteger(32); + if (!srcTy || !dstTy || !rowTyOk || !colTyOk) { op.emitError("ins/outs are not correct yet"); signalPassFailure(); return; From 9314fbac10b56f66bd5cde34c87c79295bce293f Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Wed, 1 Apr 2026 15:12:29 +0800 Subject: [PATCH 2/2] Switch index-compatible operands to i32 and add tests --- include/PTO/IR/PTOOps.td | 28 +++++------ lib/PTO/IR/PTO.cpp | 35 +++++++------ lib/PTO/Transforms/PTOViewToMemref.cpp | 12 +++-- test/basic/index_input_compat_i32.pto | 70 ++++++++++++++++++++++++++ test/basic/index_input_compat_i64.pto | 52 +++++++++++++++++++ 5 files changed, 162 insertions(+), 35 deletions(-) create mode 100644 test/basic/index_input_compat_i32.pto create mode 100644 test/basic/index_input_compat_i64.pto diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 83b7ec39..a3cbe796 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -56,10 +56,10 @@ def IndexOrI64 : CPred<"$_self.isIndex() || $_self.isSignlessInteger(64)">, "index or i64">; -def IndexOrU32 : +def IndexOrI32 : Type< - CPred<"$_self.isIndex() || $_self.isUnsignedInteger(32) || $_self.isSignlessInteger(32)">, - "index or u32-compatible i32">; + CPred<"$_self.isIndex() || $_self.isSignlessInteger(32)">, + "index or i32">; //===----------------------------------------------------------------------===// // Op Class @@ -219,8 +219,8 @@ def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> { let arguments = (ins Optional:$addr, - Optional:$valid_row, - Optional:$valid_col + Optional:$valid_row, + Optional:$valid_col ); let results = (outs TileBufType:$result); @@ -319,8 +319,8 @@ def SetValidShapeOp : PTO_Op<"set_validshape", [ let arguments = (ins TileBufOrMemRef:$source, - IndexOrU32:$valid_row, - IndexOrU32:$valid_col + IndexOrI32:$valid_row, + IndexOrI32:$valid_col ); let hasVerifier = 1; @@ -386,7 +386,7 @@ def TLoadOp : PTO_TOp<"tload", [ PTODpsType:$dst, OptionalAttr:$pad_mode, Optional:$pad_value, - Optional:$left_padding_num, + Optional:$left_padding_num, Optional:$right_padding_num, DefaultValuedOptionalAttr:$init_out_buffer, Optional:$init_condition @@ -1373,7 +1373,7 @@ def TSetValOp : PTO_TOp<"tsetval", [ let arguments = (ins PTODpsType:$dst, - IndexOrU32:$offset, + IndexOrI32:$offset, ScalarType:$val ); @@ -1404,7 +1404,7 @@ def TGetValOp : PTO_TOp<"tgetval", [ let arguments = (ins PTODpsType:$src, - IndexOrU32:$offset + IndexOrI32:$offset ); let results = (outs ScalarType:$dst); @@ -2156,8 +2156,8 @@ def TExtractOp : PTO_TOp<"textract", [ let arguments = (ins PTODpsType:$src, - IndexOrU32:$indexRow, - IndexOrU32:$indexCol, + IndexOrI32:$indexRow, + IndexOrI32:$indexCol, PTODpsType:$dst ); @@ -2188,8 +2188,8 @@ def TInsertOp : PTO_TOp<"tinsert", [ let arguments = (ins PTODpsType:$src, - IndexOrU32:$indexRow, - IndexOrU32:$indexCol, + IndexOrI32:$indexRow, + IndexOrI32:$indexCol, PTODpsType:$dst ); diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 94c53945..5290f4c0 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -342,9 +342,8 @@ static bool isIndexOrI64Type(Type type) { return type && (type.isIndex() || type.isSignlessInteger(64)); } -static bool isIndexOrU32Type(Type type) { - return type && (type.isIndex() || type.isUnsignedInteger(32) || - type.isSignlessInteger(32)); +static bool isIndexOrI32Type(Type type) { + return type && (type.isIndex() || type.isSignlessInteger(32)); } template @@ -502,9 +501,9 @@ ParseResult mlir::pto::StoreScalarOp::parse(OpAsmParser &parser, valueTy = secondTy; } - if (parser.resolveOperand(value, valueTy, result.operands) || - parser.resolveOperand(ptr, ptrTy, result.operands) || - parser.resolveOperand(offset, compatTy, result.operands)) + if (parser.resolveOperand(ptr, ptrTy, result.operands) || + parser.resolveOperand(offset, compatTy, result.operands) || + parser.resolveOperand(value, valueTy, result.operands)) return failure(); return success(); } @@ -1355,8 +1354,8 @@ ParseResult mlir::pto::AllocTileOp::parse(OpAsmParser &parser, } if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(resultTy) || - parseOptionalCompatibleType(parser, compatTy, isIndexOrU32Type, - "index or u32-compatible operand type")) + parseOptionalCompatibleType(parser, compatTy, isIndexOrI32Type, + "index or i32 operand type")) return failure(); result.addTypes(resultTy); @@ -1586,7 +1585,7 @@ LogicalResult AllocTileOp::verify() { if (getValidCol()) compatOperands.push_back(getValidCol()); if (failed(verifyUniformCompatibleOperandTypes( - getOperation(), ValueRange(compatOperands), isIndexOrU32Type, + getOperation(), ValueRange(compatOperands), isIndexOrI32Type, "valid_row/valid_col operands"))) return failure(); @@ -3306,9 +3305,9 @@ mlir::LogicalResult mlir::pto::TExtractOp::verify() { return std::nullopt; }; auto verifyIndexOperands = [&]() -> LogicalResult { - if (!isIndexOrU32Type(getIndexRow().getType()) || - !isIndexOrU32Type(getIndexCol().getType())) - return emitOpError("expects indexRow and indexCol to be index or u32-compatible type"); + if (!isIndexOrI32Type(getIndexRow().getType()) || + !isIndexOrI32Type(getIndexCol().getType())) + return emitOpError("expects indexRow and indexCol to be index or i32 type"); auto row = getConstIndex(getIndexRow()); auto col = getConstIndex(getIndexCol()); if (row && *row < 0) @@ -3473,9 +3472,9 @@ mlir::LogicalResult mlir::pto::TInsertOp::verify() { return emitOpError( "expects src/dst element types to match, or src=f32 with dst=f16/bf16"); - if (!isIndexOrU32Type(getIndexRow().getType()) || - !isIndexOrU32Type(getIndexCol().getType())) - return emitOpError("expects indexRow/indexCol to be index or u32-compatible type"); + if (!isIndexOrI32Type(getIndexRow().getType()) || + !isIndexOrI32Type(getIndexCol().getType())) + return emitOpError("expects indexRow/indexCol to be index or i32 type"); auto readConstIndex = [&](Value v, int64_t &out) -> bool { if (auto cOp = v.getDefiningOp()) { @@ -5193,8 +5192,8 @@ ParseResult mlir::pto::SetValidShapeOp::parse(OpAsmParser &parser, parser.parseOperand(validRow) || parser.parseComma() || parser.parseOperand(validCol) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(sourceTy) || - parseOptionalCompatibleType(parser, compatTy, isIndexOrU32Type, - "index or u32-compatible operand type")) + parseOptionalCompatibleType(parser, compatTy, isIndexOrI32Type, + "index or i32 operand type")) return failure(); if (parser.resolveOperand(source, sourceTy, result.operands) || @@ -5237,7 +5236,7 @@ mlir::LogicalResult mlir::pto::SetValidShapeOp::verify() { SmallVector compatOperands = {getValidRow(), getValidCol()}; if (failed(verifyUniformCompatibleOperandTypes( getOperation(), ValueRange(compatOperands), - isIndexOrU32Type, "valid_row/valid_col operands"))) + isIndexOrI32Type, "valid_row/valid_col operands"))) return failure(); SmallVector shape; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index dd1f31b9..a105b334 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1240,7 +1240,15 @@ struct PTOViewToMemrefPass Value dst = op->getOperand(1); auto newOp = - rewriter.create(op.getLoc(), TypeRange{}, src, dst); + rewriter.create(op.getLoc(), TypeRange{}, + src, + dst, + op.getPadModeAttr(), + op.getPadValue(), + op.getLeftPaddingNum(), + op.getRightPaddingNum(), + op.getInitOutBuffer(), + op.getInitCondition()); newOp->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newOp->getResults()); } @@ -2122,10 +2130,8 @@ struct PTOViewToMemrefPass auto srcTy = dyn_cast(src.getType()); auto dstTy = dyn_cast(dst.getType()); bool rowTyOk = indexRow.getType().isIndex() || - indexRow.getType().isUnsignedInteger(32) || indexRow.getType().isSignlessInteger(32); bool colTyOk = indexCol.getType().isIndex() || - indexCol.getType().isUnsignedInteger(32) || indexCol.getType().isSignlessInteger(32); if (!srcTy || !dstTy || !rowTyOk || !colTyOk) { op.emitError("ins/outs are not correct yet"); diff --git a/test/basic/index_input_compat_i32.pto b/test/basic/index_input_compat_i32.pto new file mode 100644 index 00000000..49a96e54 --- /dev/null +++ b/test/basic/index_input_compat_i32.pto @@ -0,0 +1,70 @@ +// RUN: ptoas %s | FileCheck %s + +module { + func.func @index_input_compat_i32(%src: !pto.ptr, %vr: i32, %vc: i32, + %lp: i32, %off: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %view = pto.make_tensor_view %src, + shape = [%c32, %c32], + strides = [%c32, %c1] + : !pto.tensor_view<32x32xf32> + %part = pto.partition_view %view, + offsets = [%c0, %c0], + sizes = [%c32, %c32] + : !pto.tensor_view<32x32xf32> + -> !pto.partition_tensor_view<32x32xf32> + + %tile = pto.alloc_tile valid_row = %vr valid_col = %vc + : !pto.tile_buf, i32 + pto.tload ins(%part : !pto.partition_tensor_view<32x32xf32>) + outs(%tile : !pto.tile_buf) + left_padding_num = %lp : i32 + + %scalar = pto.tgetval ins(%tile, %off : !pto.tile_buf, i32) + outs : f32 + pto.tsetval ins(%off, %scalar : i32, f32) + outs(%tile : !pto.tile_buf) + + pto.set_validshape %tile, %vr, %vc + : !pto.tile_buf, i32 + return + } + + func.func @index_input_compat_i32_extract_insert() { + %c0_i32 = arith.constant 0 : i32 + %src_mat = pto.alloc_tile + : !pto.tile_buf + %dst_left = pto.alloc_tile + : !pto.tile_buf + pto.textract ins(%src_mat, %c0_i32, %c0_i32 + : !pto.tile_buf, i32, i32) + outs(%dst_left + : !pto.tile_buf) + + %src_vec = memref.alloc() : memref<32x32xf16, #pto.address_space> + %dst_vec = memref.alloc() : memref<32x32xf16, #pto.address_space> + pto.tinsert ins(%src_vec, %c0_i32, %c0_i32 : memref<32x32xf16, #pto.address_space>, i32, i32) + outs(%dst_vec : memref<32x32xf16, #pto.address_space>) + return + } +} + +// CHECK-LABEL: AICORE void index_input_compat_i32( +// CHECK: [[TILE:v[0-9]+]].GetValue( +// CHECK: [[TILE]].SetValue( +// CHECK: [[TILE]].SetValidShape( +// CHECK-LABEL: AICORE void index_input_compat_i32_extract_insert( +// CHECK: TEXTRACT( +// CHECK: TINSERT( diff --git a/test/basic/index_input_compat_i64.pto b/test/basic/index_input_compat_i64.pto new file mode 100644 index 00000000..b3ac6bbe --- /dev/null +++ b/test/basic/index_input_compat_i64.pto @@ -0,0 +1,52 @@ +// RUN: ptoas %s | FileCheck %s + +module { + func.func @index_input_compat_i64(%src: !pto.ptr, %dst: !pto.ptr, %off: i64) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %zero = arith.constant 0.0 : f32 + + %ptr = pto.addptr %src, %off : !pto.ptr, i64 -> !pto.ptr + %val = pto.load_scalar %ptr[%off] : !pto.ptr, i64 -> f32 + pto.store_scalar %val, %dst[%off] : !pto.ptr, i64, f32 + + %view = pto.make_tensor_view %src, + shape = [%c32_i64, %c32_i64], + strides = [%c32_i64, %c1_i64] + : !pto.tensor_view<32x32xf32>, i64 + %part_i64 = pto.partition_view %view, + offsets = [%c0_i64, %c0_i64], + sizes = [%c32_i64, %c32_i64] + : !pto.tensor_view<32x32xf32>, i64 + -> !pto.partition_tensor_view<32x32xf32> + + %dim0 = pto.get_tensor_view_dim %view, %c0_i64 : !pto.tensor_view<32x32xf32>, i64 -> index + %part_dim = pto.partition_view %view, + offsets = [%c0, %c0], + sizes = [%dim0, %dim0] + : !pto.tensor_view<32x32xf32> + -> !pto.partition_tensor_view + + pto.store_scalar %zero, %dst[%dim0] : !pto.ptr, f32 + + %tile0 = pto.alloc_tile : !pto.tile_buf + %tile1 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%part_i64 : !pto.partition_tensor_view<32x32xf32>) + outs(%tile0 : !pto.tile_buf) + pto.tload ins(%part_dim : !pto.partition_tensor_view) + outs(%tile1 : !pto.tile_buf) + return + } +} + +// CHECK-LABEL: __global__ AICORE void index_input_compat_i64( +// CHECK-SAME: int64_t +// CHECK: v2[v3] = +// CHECK: GlobalTensor +// CHECK: TLOAD(