diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index ca24f09e..0a9ffb64 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -54,6 +54,16 @@ def ScalarPtrOrMemRef : def ScalarType : AnyTypeOf<[AnySignlessInteger, AnyFloat], "numeric (integer/float)">; +def IndexOrI64 : + Type< + CPred<"$_self.isIndex() || $_self.isSignlessInteger(64)">, + "index or i64">; + +def IndexOrI32 : + Type< + CPred<"$_self.isIndex() || $_self.isSignlessInteger(32)">, + "index or i32">; + //===----------------------------------------------------------------------===// // Op Class //===----------------------------------------------------------------------===// @@ -89,16 +99,14 @@ def AddPtrOp : PTO_Op<"addptr", [ let arguments = (ins PtrType:$ptr, - Index:$offset + IndexOrI64:$offset ); let results = (outs PtrType:$result); let hasVerifier = 1; - let assemblyFormat = [{ - $ptr `,` $offset attr-dict `:` type($ptr) `->` type($result) - }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -112,16 +120,14 @@ def LoadScalarOp : PTO_Op<"load_scalar", [ let arguments = (ins ScalarPtrOrMemRef:$ptr, - Index:$offset + IndexOrI64:$offset ); let results = (outs AnyType:$value); let hasVerifier = 1; - let assemblyFormat = [{ - $ptr `[` $offset `]` attr-dict `:` type($ptr) `->` type($value) - }]; + let hasCustomAssemblyFormat = 1; } def StoreScalarOp : PTO_Op<"store_scalar", [ @@ -131,7 +137,7 @@ def StoreScalarOp : PTO_Op<"store_scalar", [ let arguments = (ins ScalarPtrOrMemRef:$ptr, - Index:$offset, + IndexOrI64:$offset, AnyType:$value ); @@ -139,9 +145,7 @@ def StoreScalarOp : PTO_Op<"store_scalar", [ let hasVerifier = 1; - let assemblyFormat = [{ - $value `,` $ptr `[` $offset `]` attr-dict `:` type($ptr) `,` type($value) - }]; + let hasCustomAssemblyFormat = 1; } def MakeTensorViewOp : PTO_Op<"make_tensor_view", [AttrSizedOperandSegments]> { @@ -149,8 +153,8 @@ def MakeTensorViewOp : 
PTO_Op<"make_tensor_view", [AttrSizedOperandSegments]> { let arguments = (ins AnyType:$ptr, - Variadic:$shape, - Variadic:$strides, + Variadic:$shape, + Variadic:$strides, OptionalAttr:$layout ); @@ -176,18 +180,15 @@ def PartitionViewOp : PTO_Op<"partition_view", [AttrSizedOperandSegments]> { let arguments = (ins TensorViewType:$source, // 输入: 物理大底座 (MakeTensorViewOp 的结果) - Variadic:$offsets, // 动态 offsets - Variadic:$sizes // 动态 sizes + Variadic:$offsets, // 动态 offsets + Variadic:$sizes // 动态 sizes ); let results = (outs PartitionTensorViewType:$result); // 输出: 逻辑切片 let hasVerifier = 1; - let assemblyFormat = [{ - $source `,` `offsets` `=` `[` $offsets `]` `,` `sizes` `=` `[` $sizes `]` - attr-dict `:` qualified(type($source)) `->` qualified(type($result)) - }]; + let hasCustomAssemblyFormat = 1; } // Helper: tensor_view or memref (after lowering tensor_view to memref). @@ -210,13 +211,10 @@ def GetTensorViewDimOp : PTO_Op<"get_tensor_view_dim", [Pure]> { }]; let arguments = (ins TensorViewOrMemRef:$tensor_view, - Index:$dim_index + IndexOrI64:$dim_index ); let results = (outs Index:$result); - let assemblyFormat = [{ - $tensor_view `,` $dim_index `:` qualified(type($tensor_view)) `->` qualified(type($result)) - attr-dict - }]; + let hasCustomAssemblyFormat = 1; } def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> { @@ -224,18 +222,13 @@ def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> { let arguments = (ins Optional:$addr, - Optional:$valid_row, - Optional:$valid_col + Optional:$valid_row, + Optional:$valid_col ); let results = (outs TileBufType:$result); - let assemblyFormat = [{ - (`addr` `=` $addr^)? - (`valid_row` `=` $valid_row^)? - (`valid_col` `=` $valid_col^)? 
- attr-dict `:` qualified(type($result)) - }]; + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ ::mlir::LogicalResult verify(); @@ -329,16 +322,13 @@ def SetValidShapeOp : PTO_Op<"set_validshape", [ let arguments = (ins TileBufOrMemRef:$source, - Index:$valid_row, - Index:$valid_col + IndexOrI32:$valid_row, + IndexOrI32:$valid_col ); let hasVerifier = 1; - let assemblyFormat = [{ - $source `,` $valid_row `,` $valid_col attr-dict `:` - qualified(type($source)) - }]; + let hasCustomAssemblyFormat = 1; } // ============================================================================ @@ -399,7 +389,7 @@ def TLoadOp : PTO_TOp<"tload", [ PTODpsType:$dst, OptionalAttr:$pad_mode, Optional:$pad_value, - Optional:$left_padding_num, + Optional:$left_padding_num, Optional:$right_padding_num, DefaultValuedOptionalAttr:$init_out_buffer, Optional:$init_condition @@ -2017,7 +2007,7 @@ def TSetValOp : PTO_TOp<"tsetval", [ let arguments = (ins PTODpsType:$dst, - Index:$offset, + IndexOrI32:$offset, ScalarType:$val ); @@ -2048,7 +2038,7 @@ def TGetValOp : PTO_TOp<"tgetval", [ let arguments = (ins PTODpsType:$src, - Index:$offset + IndexOrI32:$offset ); let results = (outs ScalarType:$dst); @@ -2858,8 +2848,8 @@ def TExtractOp : PTO_TOp<"textract", [ let arguments = (ins PTODpsType:$src, - Index:$indexRow, - Index:$indexCol, + IndexOrI32:$indexRow, + IndexOrI32:$indexCol, PTODpsType:$dst ); @@ -2943,8 +2933,8 @@ def TInsertOp : PTO_TOp<"tinsert", [ let arguments = (ins PTODpsType:$src, - Index:$indexRow, - Index:$indexCol, + IndexOrI32:$indexRow, + IndexOrI32:$indexCol, PTODpsType:$dst ); diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index ea87ca9b..8c413f22 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -383,6 +383,220 @@ static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { return mlir::Type(); } +static bool isIndexOrI64Type(Type type) { + return type && (type.isIndex() || type.isSignlessInteger(64)); +} + +static bool 
isIndexOrI32Type(Type type) { + return type && (type.isIndex() || type.isSignlessInteger(32)); +} + +template +static ParseResult parseOptionalCompatibleType(OpAsmParser &parser, Type &type, + Pred &&isSupportedType, + StringRef expectedDesc) { + type = parser.getBuilder().getIndexType(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseType(type)) + return failure(); + if (!isSupportedType(type)) + return parser.emitError(parser.getCurrentLocation()) + << "expected " << expectedDesc; + } + return success(); +} + +template +static LogicalResult +verifyUniformCompatibleOperandTypes(Operation *op, ValueRange values, + Pred &&isSupportedType, StringRef groupName) { + Type nonIndexType; + for (Value value : values) { + if (!value) + continue; + Type type = value.getType(); + if (!isSupportedType(type)) { + return op->emitOpError() << "expects " << groupName + << " to use compatible integer/index types"; + } + if (type.isIndex()) + continue; + if (nonIndexType && nonIndexType != type) { + return op->emitOpError() << "expects " << groupName + << " to use a uniform non-index type"; + } + nonIndexType = type; + } + return success(); +} + +static Type getCompatibleOperandTypeOrIndex(MLIRContext *ctx, ValueRange values) { + Type nonIndexType; + for (Value value : values) { + if (!value) + continue; + Type type = value.getType(); + if (type.isIndex()) + continue; + if (!nonIndexType) + nonIndexType = type; + } + return nonIndexType ? 
nonIndexType : IndexType::get(ctx); +} + +static void printCompatibleTypeSuffix(OpAsmPrinter &printer, Type type) { + if (type && !type.isIndex()) + printer << ", " << type; +} + +ParseResult mlir::pto::AddPtrOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand ptr; + OpAsmParser::UnresolvedOperand offset; + Type ptrTy; + Type offsetTy; + SmallVector resultTypes; + + if (parser.parseOperand(ptr) || parser.parseComma() || + parser.parseOperand(offset) || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(ptrTy) || + parseOptionalCompatibleType(parser, offsetTy, isIndexOrI64Type, + "index or i64 offset type") || + parser.parseArrowTypeList(resultTypes)) + return failure(); + + if (resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(ptr, ptrTy, result.operands) || + parser.resolveOperand(offset, offsetTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::AddPtrOp::print(OpAsmPrinter &printer) { + printer << " " << getPtr() << ", " << getOffset(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getPtr().getType(); + printCompatibleTypeSuffix(printer, getOffset().getType()); + printer << " -> " << getResult().getType(); +} + +ParseResult mlir::pto::LoadScalarOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand ptr; + OpAsmParser::UnresolvedOperand offset; + Type ptrTy; + Type offsetTy; + SmallVector resultTypes; + + if (parser.parseOperand(ptr) || parser.parseLSquare() || + parser.parseOperand(offset) || parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(ptrTy) || + parseOptionalCompatibleType(parser, offsetTy, isIndexOrI64Type, + "index or i64 offset type") || + parser.parseArrowTypeList(resultTypes)) + return failure(); + + if 
(resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(ptr, ptrTy, result.operands) || + parser.resolveOperand(offset, offsetTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::LoadScalarOp::print(OpAsmPrinter &printer) { + printer << " " << getPtr() << "[" << getOffset() << "]"; + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getPtr().getType(); + printCompatibleTypeSuffix(printer, getOffset().getType()); + printer << " -> " << getValue().getType(); +} + +ParseResult mlir::pto::StoreScalarOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand value; + OpAsmParser::UnresolvedOperand ptr; + OpAsmParser::UnresolvedOperand offset; + Type ptrTy; + Type valueTy; + Type compatTy = parser.getBuilder().getIndexType(); + Type secondTy; + + if (parser.parseOperand(value) || parser.parseComma() || parser.parseOperand(ptr) || + parser.parseLSquare() || parser.parseOperand(offset) || parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(ptrTy) || + parser.parseComma() || parser.parseType(secondTy)) + return failure(); + + if (succeeded(parser.parseOptionalComma())) { + compatTy = secondTy; + if (!isIndexOrI64Type(compatTy)) + return parser.emitError(parser.getCurrentLocation()) + << "expected index or i64 offset type"; + if (parser.parseType(valueTy)) + return failure(); + } else { + valueTy = secondTy; + } + + if (parser.resolveOperand(ptr, ptrTy, result.operands) || + parser.resolveOperand(offset, compatTy, result.operands) || + parser.resolveOperand(value, valueTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::StoreScalarOp::print(OpAsmPrinter &printer) { + printer << " " << getValue() << ", " << getPtr() << "[" << getOffset() << "]"; + 
printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getPtr().getType(); + if (!getOffset().getType().isIndex()) + printer << ", " << getOffset().getType(); + printer << ", " << getValue().getType(); +} + +ParseResult mlir::pto::GetTensorViewDimOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand tensorView; + OpAsmParser::UnresolvedOperand dimIndex; + Type tensorViewTy; + Type dimIndexTy; + SmallVector resultTypes; + + if (parser.parseOperand(tensorView) || parser.parseComma() || + parser.parseOperand(dimIndex) || parser.parseColonType(tensorViewTy) || + parseOptionalCompatibleType(parser, dimIndexTy, isIndexOrI64Type, + "index or i64 dim type") || + parser.parseArrowTypeList(resultTypes) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(tensorView, tensorViewTy, result.operands) || + parser.resolveOperand(dimIndex, dimIndexTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::GetTensorViewDimOp::print(OpAsmPrinter &printer) { + printer << " " << getTensorView() << ", " << getDimIndex() << " : " + << getTensorView().getType(); + printCompatibleTypeSuffix(printer, getDimIndex().getType()); + printer << " -> " << getResult().getType(); + printer.printOptionalAttrDict((*this)->getAttrs()); +} + mlir::Type TensorViewType::parse(::mlir::AsmParser &parser) { SmallVector shape; Type elementType; @@ -706,6 +920,7 @@ ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, SmallVector strideOps; Type resultTy; + Type operandTy; // %ptr if (parser.parseOperand(ptr)) @@ -730,7 +945,9 @@ ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, return failure(); // : result-type - if (parser.parseColonType(resultTy)) + if (parser.parseColonType(resultTy) || 
+ parseOptionalCompatibleType(parser, operandTy, isIndexOrI64Type, + "index or i64 operand type")) return failure(); result.addTypes(resultTy); @@ -747,11 +964,9 @@ ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, if (parser.resolveOperand(ptr, ptrTy, result.operands)) return failure(); - // resolve shape/strides 为 index - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(shapeOps, indexTy, result.operands)) + if (parser.resolveOperands(shapeOps, operandTy, result.operands)) return failure(); - if (parser.resolveOperands(strideOps, indexTy, result.operands)) + if (parser.resolveOperands(strideOps, operandTy, result.operands)) return failure(); auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( @@ -776,6 +991,68 @@ void mlir::pto::MakeTensorViewOp::print(OpAsmPrinter &p) { /*elidedAttrs=*/{"operandSegmentSizes"}); p << " : " << getResult().getType(); + SmallVector compatOperands; + compatOperands.reserve(getShape().size() + getStrides().size()); + compatOperands.append(getShape().begin(), getShape().end()); + compatOperands.append(getStrides().begin(), getStrides().end()); + printCompatibleTypeSuffix( + p, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); +} + +ParseResult mlir::pto::PartitionViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + SmallVector offsets; + SmallVector sizes; + Type sourceTy; + Type operandTy; + SmallVector resultTypes; + + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseKeyword("offsets") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(offsets) || + parser.parseRSquare() || parser.parseComma() || + parser.parseKeyword("sizes") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(sizes) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy) || + 
parseOptionalCompatibleType(parser, operandTy, isIndexOrI64Type, + "index or i64 operand type") || + parser.parseArrowTypeList(resultTypes)) + return failure(); + + if (resultTypes.size() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one result type"); + + result.addTypes(resultTypes); + if (parser.resolveOperand(source, sourceTy, result.operands) || + parser.resolveOperands(offsets, operandTy, result.operands) || + parser.resolveOperands(sizes, operandTy, result.operands)) + return failure(); + + auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( + {1, static_cast(offsets.size()), static_cast(sizes.size())}); + result.addAttribute("operandSegmentSizes", segAttr); + return success(); +} + +void mlir::pto::PartitionViewOp::print(OpAsmPrinter &p) { + p << " " << getSource() << ", offsets = ["; + p.printOperands(getOffsets()); + p << "], sizes = ["; + p.printOperands(getSizes()); + p << "]"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + p << " : " << getSource().getType(); + SmallVector compatOperands; + compatOperands.reserve(getOffsets().size() + getSizes().size()); + compatOperands.append(getOffsets().begin(), getOffsets().end()); + compatOperands.append(getSizes().begin(), getSizes().end()); + printCompatibleTypeSuffix( + p, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); + p << " -> " << getResult().getType(); } // Layout inference helpers for make_tensor_view @@ -1046,6 +1323,15 @@ static std::optional getConstantIntegerValue(Value value) { } LogicalResult mlir::pto::MakeTensorViewOp::verify() { + SmallVector compatOperands; + compatOperands.reserve(getShape().size() + getStrides().size()); + compatOperands.append(getShape().begin(), getShape().end()); + compatOperands.append(getStrides().begin(), getStrides().end()); + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), isIndexOrI64Type, + "shape/strides 
operands"))) + return failure(); + auto tvTy = dyn_cast(getResult().getType()); if (!tvTy) return emitOpError("result must be pto.tensor_view<...>"); @@ -1107,6 +1393,15 @@ LogicalResult mlir::pto::MakeTensorViewOp::verify() { } LogicalResult mlir::pto::PartitionViewOp::verify() { + SmallVector compatOperands; + compatOperands.reserve(getOffsets().size() + getSizes().size()); + compatOperands.append(getOffsets().begin(), getOffsets().end()); + compatOperands.append(getSizes().begin(), getSizes().end()); + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), isIndexOrI64Type, + "offset/size operands"))) + return failure(); + auto srcTy = dyn_cast(getSource().getType()); auto resTy = dyn_cast(getResult().getType()); if (!srcTy || !resTy) @@ -1182,6 +1477,72 @@ LogicalResult mlir::pto::AddPtrOp::verify() { return success(); } +ParseResult mlir::pto::AllocTileOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand addr; + OpAsmParser::UnresolvedOperand validRow; + OpAsmParser::UnresolvedOperand validCol; + bool hasAddr = false; + bool hasValidRow = false; + bool hasValidCol = false; + Type compatTy; + Type resultTy; + + if (succeeded(parser.parseOptionalKeyword("addr"))) { + if (parser.parseEqual() || parser.parseOperand(addr)) + return failure(); + hasAddr = true; + } + if (succeeded(parser.parseOptionalKeyword("valid_row"))) { + if (parser.parseEqual() || parser.parseOperand(validRow)) + return failure(); + hasValidRow = true; + } + if (succeeded(parser.parseOptionalKeyword("valid_col"))) { + if (parser.parseEqual() || parser.parseOperand(validCol)) + return failure(); + hasValidCol = true; + } + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(resultTy) || + parseOptionalCompatibleType(parser, compatTy, isIndexOrI32Type, + "index or i32 operand type")) + return failure(); + + result.addTypes(resultTy); + if (hasAddr && + parser.resolveOperand(addr, 
parser.getBuilder().getI64Type(), result.operands)) + return failure(); + if (hasValidRow && parser.resolveOperand(validRow, compatTy, result.operands)) + return failure(); + if (hasValidCol && parser.resolveOperand(validCol, compatTy, result.operands)) + return failure(); + + auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( + {hasAddr ? 1 : 0, hasValidRow ? 1 : 0, hasValidCol ? 1 : 0}); + result.addAttribute("operandSegmentSizes", segAttr); + return success(); +} + +void mlir::pto::AllocTileOp::print(OpAsmPrinter &printer) { + if (Value addr = getAddr()) + printer << " addr = " << addr; + if (Value validRow = getValidRow()) + printer << " valid_row = " << validRow; + if (Value validCol = getValidCol()) + printer << " valid_col = " << validCol; + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + printer << " : " << getResult().getType(); + SmallVector compatOperands; + if (getValidRow()) + compatOperands.push_back(getValidRow()); + if (getValidCol()) + compatOperands.push_back(getValidCol()); + printCompatibleTypeSuffix( + printer, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); +} + @@ -1369,6 +1730,16 @@ static LogicalResult verifyMemrefTensorStore(Operation *op, Value dst, Value src LogicalResult AllocTileOp::verify() { auto ty = getResult().getType(); // TileBufType + SmallVector compatOperands; + if (getValidRow()) + compatOperands.push_back(getValidRow()); + if (getValidCol()) + compatOperands.push_back(getValidCol()); + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), isIndexOrI32Type, + "valid_row/valid_col operands"))) + return failure(); + // op 上有没有传 operands bool hasVR = getValidRow() != nullptr; bool hasVC = getValidCol() != nullptr; @@ -3263,6 +3634,10 @@ mlir::LogicalResult mlir::pto::TExpandsOp::verify() { mlir::LogicalResult mlir::pto::TExtractOp::verify() { auto getConstIndex = [&](Value v) -> std::optional { + if (auto 
cst = v.getDefiningOp()) + return cst.value(); + if (auto cst = v.getDefiningOp()) + return cst.value(); auto cst = v.getDefiningOp(); if (!cst) return std::nullopt; @@ -3271,8 +3646,9 @@ mlir::LogicalResult mlir::pto::TExtractOp::verify() { return std::nullopt; }; auto verifyIndexOperands = [&]() -> LogicalResult { - if (!getIndexRow().getType().isIndex() || !getIndexCol().getType().isIndex()) - return emitOpError("expects indexRow and indexCol to be index type"); + if (!isIndexOrI32Type(getIndexRow().getType()) || + !isIndexOrI32Type(getIndexCol().getType())) + return emitOpError("expects indexRow and indexCol to be index or i32 type"); auto row = getConstIndex(getIndexRow()); auto col = getConstIndex(getIndexCol()); if (row && *row < 0) @@ -3429,14 +3805,43 @@ mlir::LogicalResult mlir::pto::TExtractOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } mlir::LogicalResult mlir::pto::TInsertOp::verify() { - auto getConstIndex = [&](Value v) -> std::optional { - if (auto cst = v.getDefiningOp()) - return cst.value(); - if (auto cst = v.getDefiningOp()) - return cst.value(); - if (auto cst = v.getDefiningOp()) { - if (auto attr = mlir::dyn_cast(cst.getValue())) - return attr.getInt(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) + return emitOpError("expects src/dst to be PTO shaped-like types"); + + auto srcShape = getShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (srcShape.size() != 2 || dstShape.size() != 2) + return emitOpError("expects rank-2 shaped types for src/dst"); + + Type srcElemTy = getElemTy(srcTy); + Type dstElemTy = getElemTy(dstTy); + bool sameElemTy = srcElemTy == dstElemTy; + bool castElemTy = + srcElemTy.isF32() && (dstElemTy.isF16() || dstElemTy.isBF16()); + if (!sameElemTy && !castElemTy) + return emitOpError( + "expects src/dst element types to match, or src=f32 with dst=f16/bf16"); + + if 
(!isIndexOrI32Type(getIndexRow().getType()) || + !isIndexOrI32Type(getIndexCol().getType())) + return emitOpError("expects indexRow/indexCol to be index or i32 type"); + + auto readConstIndex = [&](Value v, int64_t &out) -> bool { + if (auto cOp = v.getDefiningOp()) { + out = cOp.value(); + return true; + } + if (auto cInt = v.getDefiningOp()) { + out = cInt.value(); + return true; + } + if (auto cOp = v.getDefiningOp()) { + if (auto ia = mlir::dyn_cast(cOp.getValue())) { + out = ia.getInt(); + return true; + } } - return std::nullopt; + return false; }; }; @@ -5568,6 +5973,38 @@ static bool isLocallyBoundTileSource(Value value) { return false; } +ParseResult mlir::pto::SetValidShapeOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + OpAsmParser::UnresolvedOperand validRow; + OpAsmParser::UnresolvedOperand validCol; + Type sourceTy; + Type compatTy; + + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(validRow) || parser.parseComma() || + parser.parseOperand(validCol) || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy) || + parseOptionalCompatibleType(parser, compatTy, isIndexOrI32Type, + "index or i32 operand type")) + return failure(); + + if (parser.resolveOperand(source, sourceTy, result.operands) || + parser.resolveOperand(validRow, compatTy, result.operands) || + parser.resolveOperand(validCol, compatTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::SetValidShapeOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getValidRow() << ", " << getValidCol(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType(); + SmallVector compatOperands = {getValidRow(), getValidCol()}; + printCompatibleTypeSuffix( + printer, getCompatibleOperandTypeOrIndex(getContext(), ValueRange(compatOperands))); +} + static std::optional getConstIndexLike(Value v) { if (auto cOp = 
v.getDefiningOp()) return cOp.value(); @@ -5589,6 +6026,12 @@ static std::optional getConstIndexLike(Value v) { } mlir::LogicalResult mlir::pto::SetValidShapeOp::verify() { + SmallVector compatOperands = {getValidRow(), getValidCol()}; + if (failed(verifyUniformCompatibleOperandTypes( + getOperation(), ValueRange(compatOperands), + isIndexOrI32Type, "valid_row/valid_col operands"))) + return failure(); + SmallVector shape; if (auto srcTy = llvm::dyn_cast(getSource().getType())) { if (srcTy.getRank() != 2) diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 07254e42..8d09d4f7 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -616,6 +616,8 @@ struct PTOViewToMemrefPass // 5. 获取 Config (保持不变) auto configAttr = tbTy.getConfigAttr(); if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); + Value bindVRow = vRow ? ensureIndex(rewriter, loc, vRow, op) : Value(); + Value bindVCol = vCol ? ensureIndex(rewriter, loc, vCol, op) : Value(); // 6. If alloc_tile provides an explicit address, keep the original // pointer_cast lowering intact and additionally rebind through @@ -624,8 +626,7 @@ struct PTOViewToMemrefPass // unified anchor EmitC uses to recover tile_buf information. if (Value addr = op.getAddr()) { auto pc = rewriter.create( - loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); + loc, targetType, ValueRange{addr}, bindVRow, bindVCol, configAttr); markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); auto bindOp = rewriter.create( loc, targetType, pc.getResult(), vRow ? vRow : Value(), @@ -643,8 +644,7 @@ struct PTOViewToMemrefPass // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? 
vCol : Value(), - configAttr); + loc, targetType, alloc, bindVRow, bindVCol, configAttr); markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); rewriter.replaceOp(op, bindOp.getResult()); @@ -815,7 +815,7 @@ struct PTOViewToMemrefPass if (!mrTy) continue; // leave it to later passes if it hasn't been lowered yet - Value dimIdx = op.getDimIndex(); + Value dimIdx = ensureIndex(rewriter, loc, op.getDimIndex(), op); Value dim = rewriter.create(loc, view, dimIdx); rewriter.replaceOp(op, dim); } @@ -1324,7 +1324,15 @@ struct PTOViewToMemrefPass Value dst = op->getOperand(1); auto newOp = - rewriter.create(op.getLoc(), TypeRange{}, src, dst); + rewriter.create(op.getLoc(), TypeRange{}, + src, + dst, + op.getPadModeAttr(), + op.getPadValue(), + op.getLeftPaddingNum(), + op.getRightPaddingNum(), + op.getInitOutBuffer(), + op.getInitCondition()); newOp->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newOp->getResults()); } @@ -2239,10 +2247,12 @@ struct PTOViewToMemrefPass Value dst = op.getDst(); auto srcTy = dyn_cast(src.getType()); - auto indexRowTy = dyn_cast(indexRow.getType()); - auto indexColTy = dyn_cast(indexCol.getType()); auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !indexRowTy || !indexColTy || !dstTy) { + bool rowTyOk = indexRow.getType().isIndex() || + indexRow.getType().isSignlessInteger(32); + bool colTyOk = indexCol.getType().isIndex() || + indexCol.getType().isSignlessInteger(32); + if (!srcTy || !dstTy || !rowTyOk || !colTyOk) { op.emitError("ins/outs are not correct yet"); signalPassFailure(); return; diff --git a/test/basic/index_input_compat_i32.pto b/test/basic/index_input_compat_i32.pto new file mode 100644 index 00000000..49a96e54 --- /dev/null +++ b/test/basic/index_input_compat_i32.pto @@ -0,0 +1,70 @@ +// RUN: ptoas %s | FileCheck %s + +module { + func.func @index_input_compat_i32(%src: !pto.ptr, %vr: i32, %vc: i32, + %lp: i32, %off: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = 
arith.constant 32 : index + + %view = pto.make_tensor_view %src, + shape = [%c32, %c32], + strides = [%c32, %c1] + : !pto.tensor_view<32x32xf32> + %part = pto.partition_view %view, + offsets = [%c0, %c0], + sizes = [%c32, %c32] + : !pto.tensor_view<32x32xf32> + -> !pto.partition_tensor_view<32x32xf32> + + %tile = pto.alloc_tile valid_row = %vr valid_col = %vc + : !pto.tile_buf, i32 + pto.tload ins(%part : !pto.partition_tensor_view<32x32xf32>) + outs(%tile : !pto.tile_buf) + left_padding_num = %lp : i32 + + %scalar = pto.tgetval ins(%tile, %off : !pto.tile_buf, i32) + outs : f32 + pto.tsetval ins(%off, %scalar : i32, f32) + outs(%tile : !pto.tile_buf) + + pto.set_validshape %tile, %vr, %vc + : !pto.tile_buf, i32 + return + } + + func.func @index_input_compat_i32_extract_insert() { + %c0_i32 = arith.constant 0 : i32 + %src_mat = pto.alloc_tile + : !pto.tile_buf + %dst_left = pto.alloc_tile + : !pto.tile_buf + pto.textract ins(%src_mat, %c0_i32, %c0_i32 + : !pto.tile_buf, i32, i32) + outs(%dst_left + : !pto.tile_buf) + + %src_vec = memref.alloc() : memref<32x32xf16, #pto.address_space> + %dst_vec = memref.alloc() : memref<32x32xf16, #pto.address_space> + pto.tinsert ins(%src_vec, %c0_i32, %c0_i32 : memref<32x32xf16, #pto.address_space>, i32, i32) + outs(%dst_vec : memref<32x32xf16, #pto.address_space>) + return + } +} + +// CHECK-LABEL: AICORE void index_input_compat_i32( +// CHECK: [[TILE:v[0-9]+]].GetValue( +// CHECK: [[TILE]].SetValue( +// CHECK: [[TILE]].SetValidShape( +// CHECK-LABEL: AICORE void index_input_compat_i32_extract_insert( +// CHECK: TEXTRACT( +// CHECK: TINSERT( diff --git a/test/basic/index_input_compat_i64.pto b/test/basic/index_input_compat_i64.pto new file mode 100644 index 00000000..b3ac6bbe --- /dev/null +++ b/test/basic/index_input_compat_i64.pto @@ -0,0 +1,52 @@ +// RUN: ptoas %s | FileCheck %s + +module { + func.func @index_input_compat_i64(%src: !pto.ptr, %dst: !pto.ptr, %off: i64) { + %c0 = arith.constant 0 : index + %c0_i64 = 
arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %zero = arith.constant 0.0 : f32 + + %ptr = pto.addptr %src, %off : !pto.ptr, i64 -> !pto.ptr + %val = pto.load_scalar %ptr[%off] : !pto.ptr, i64 -> f32 + pto.store_scalar %val, %dst[%off] : !pto.ptr, i64, f32 + + %view = pto.make_tensor_view %src, + shape = [%c32_i64, %c32_i64], + strides = [%c32_i64, %c1_i64] + : !pto.tensor_view<32x32xf32>, i64 + %part_i64 = pto.partition_view %view, + offsets = [%c0_i64, %c0_i64], + sizes = [%c32_i64, %c32_i64] + : !pto.tensor_view<32x32xf32>, i64 + -> !pto.partition_tensor_view<32x32xf32> + + %dim0 = pto.get_tensor_view_dim %view, %c0_i64 : !pto.tensor_view<32x32xf32>, i64 -> index + %part_dim = pto.partition_view %view, + offsets = [%c0, %c0], + sizes = [%dim0, %dim0] + : !pto.tensor_view<32x32xf32> + -> !pto.partition_tensor_view + + pto.store_scalar %zero, %dst[%dim0] : !pto.ptr, f32 + + %tile0 = pto.alloc_tile : !pto.tile_buf + %tile1 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%part_i64 : !pto.partition_tensor_view<32x32xf32>) + outs(%tile0 : !pto.tile_buf) + pto.tload ins(%part_dim : !pto.partition_tensor_view) + outs(%tile1 : !pto.tile_buf) + return + } +} + +// CHECK-LABEL: __global__ AICORE void index_input_compat_i64( +// CHECK-SAME: int64_t +// CHECK: v2[v3] = +// CHECK: GlobalTensor +// CHECK: TLOAD(