diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 28fe6a0e3..9e8e02b0c 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -2374,6 +2374,35 @@ def TColExpandOp : PTO_TOp<"tcolexpand", [ }]; } +def TColExpandAddOp : PTO_TOp<"tcolexpandadd", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Column-wise broadcast add: add a per-column scalar vector src1 to src0 "; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + def TColExpandMulOp : PTO_TOp<"tcolexpandmul", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -2461,6 +2490,35 @@ def TColExpandSubOp : PTO_TOp<"tcolexpandsub", [ }]; } +def TColExpandExpdifOp : PTO_TOp<"tcolexpandexpdif", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Column-wise broadcast expdif: compute exp(src0 - src1) using a per-column scalar vector src1 "; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + def TColExpandMaxOp : PTO_TOp<"tcolexpandmax", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -2603,6 +2661,34 @@ def TColSumOp : PTO_TOp<"tcolsum", [ }]; } +def TColProdOp : PTO_TOp<"tcolprod", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Reduce each column by multiplying across rows"; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src `:` qualified(type($src)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + def TCvtOp : PTO_TOp<"tcvt", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -2692,6 +2778,64 @@ def TDivSOp : PTO_TOp<"tdivs", [ } +def TFModOp : PTO_TOp<"tfmod", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Elementwise fmod/remainder of two tiles (tilebuf, DPS)"; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + +def TFModSOp : PTO_TOp<"tfmods", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Elementwise fmod/remainder with a scalar (tilebuf, DPS)"; + + let arguments = (ins + PTODpsType:$src, + ScalarType:$scalar, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `,` $scalar `:` qualified(type($src)) `,` type($scalar) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + def TExpOp : PTO_TOp<"texp", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -4046,6 +4190,93 @@ def TRowExpandAddOp: PTO_TOp<"trowexpandadd", [ ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } }]; } + +def TRowExpandExpdifOp: PTO_TOp<"trowexpandexpdif", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "TROWEXPANDEXPDIF: Row-wise broadcast expdif with per-row scalar vector."; + let description = [{ + pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only. + }]; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + Optional:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + +def TRowExpandMaxOp: PTO_TOp<"trowexpandmax", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "TROWEXPANDMAX: Row-wise broadcast max with per-row scalar vector."; + let description = [{ + pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only. + }]; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + Optional:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + +def TRowExpandMinOp: PTO_TOp<"trowexpandmin", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "TROWEXPANDMIN: Row-wise broadcast min with per-row scalar vector."; + let description = [{ + pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only. + }]; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + Optional:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} //===----------------------------------------------------------------------===// // PTOOps.td (add TROWMAX TBDPS/tile buffer op) //===----------------------------------------------------------------------===// @@ -4142,6 +4373,35 @@ def TRowSumOp: PTO_TOp<"trowsum", [ ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } }]; } + +def TRowProdOp: PTO_TOp<"trowprod", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "TROWPROD: Reduce each row by multiplying across columns."; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `,` $tmp `:` qualified(type($src)) `,` qualified(type($tmp)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} //===----------------------------------------------------------------------===// // PTOOps.td (add TRSQRT TBDPS/tile buffer op) //===----------------------------------------------------------------------===// diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 51311a25a..7f448056c 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2854,6 +2854,10 @@ LogicalResult pto::TColExpandMulOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); } +LogicalResult pto::TColExpandAddOp::verify() { + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType()); +} LogicalResult pto::TColExpandDivOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); @@ -2862,6 +2866,10 @@ LogicalResult pto::TColExpandSubOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); } +LogicalResult pto::TColExpandExpdifOp::verify() { + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType()); +} LogicalResult pto::TColExpandMaxOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); @@ -3098,6 +3106,44 @@ LogicalResult pto::TColSumOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +LogicalResult pto::TColProdOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/false))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isInteger(16) || elem.isInteger(32))) + return emitOpError("expects A2/A3 tcolprod element type to be f16/f32/i16/i32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/false))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isBF16() || + elem.isInteger(16) || elem.isUnsignedInteger(16) || + elem.isInteger(32) || elem.isUnsignedInteger(32))) + return emitOpError("expects A5 tcolprod element type to be i16/ui16/i32/ui32/f16/bf16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + llvm::LogicalResult mlir::pto::TCvtOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); @@ -5776,6 +5822,52 @@ mlir::LogicalResult mlir::pto::TRemOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +mlir::LogicalResult mlir::pto::TFModOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src0, src1, and dst to use row-major layout"); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tfmod element type to be i32/i16/f16/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src0, src1, and dst to use row-major layout"); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tfmod element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + mlir::LogicalResult mlir::pto::TRemSOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); @@ -5789,6 +5881,39 @@ mlir::LogicalResult mlir::pto::TRemSOp::verify() { return mlir::success(); } +mlir::LogicalResult mlir::pto::TFModSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type scalarTy = getScalar().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src and dst to use row-major layout"); + + Type elem = getElemTy(srcTy); + if (scalarTy != elem) + return emitOpError("expects scalar type to match the tile element type"); + + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + static std::optional getStaticNumElements(ArrayRef shape) { int64_t numel = 1; @@ -6165,7 +6290,8 @@ void mlir::pto::TRsqrtOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); } -ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { +static ParseResult parseTRowExpandBinaryLikeOp(OpAsmParser &parser, + OperationState &result) { OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; Type src0Ty, src1Ty, tmpTy, dstTy; bool hasTmp = false; @@ -6211,135 +6337,75 @@ ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationStat return success(); } -void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc0() << ", " << getSrc1(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ", " - << getTmp().getType() << ")"; +static void printTRowExpandBinaryLikeOp(OpAsmPrinter &p, Operation *op, Value src0, + Value src1, Value tmp, Value dst) { + p << " ins(" << src0 << ", " << src1; + if (tmp) { + p << ", " << tmp; + p << " : " << src0.getType() << ", " << src1.getType() << ", " + << tmp.getType() << ")"; } else { - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ")"; + p << " : " << src0.getType() << ", " << src1.getType() << ")"; } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); + p << " outs(" << dst << " : " << dst.getType() << ")"; + p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); } -ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; - Type src0Ty, src1Ty, tmpTy, dstTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColon()) - return failure(); - if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); +ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} - if (parser.resolveOperand(src0, src0Ty, result.operands) || - parser.resolveOperand(src1, src1Ty, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); +void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 1 : 0, 1})); - return success(); +ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); } void mlir::pto::TRowExpandMulOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc0() << ", " << getSrc1(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ", " - << getTmp().getType() << ")"; - } else { - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ")"; - } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); } ParseResult mlir::pto::TRowExpandSubOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; - Type src0Ty, src1Ty, tmpTy, dstTy; - bool hasTmp = false; + return parseTRowExpandBinaryLikeOp(parser, result); +} - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColon()) - return failure(); - if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); +void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} - if (parser.resolveOperand(src0, src0Ty, result.operands) || - parser.resolveOperand(src1, src1Ty, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); +ParseResult mlir::pto::TRowExpandExpdifOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 1 : 0, 1})); - return success(); +void mlir::pto::TRowExpandExpdifOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); } -void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc0() << ", " << getSrc1(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ", " - << getTmp().getType() << ")"; - } else { - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ")"; - } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); +ParseResult mlir::pto::TRowExpandMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMaxOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMinOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); } mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { @@ -6480,6 +6546,179 @@ mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +static LogicalResult verifyTRowExpandReduceLikeOp(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy, + Type tmpTy, bool hasTmp, + PTOArch targetArch, + StringRef opName) { + if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + if (hasTmp) { + if (failed(verifyTileBufCommon(op, tmpTy, "tmp"))) + return failure(); + if (getElemTy(tmpTy) != getElemTy(dstTy)) + return op->emitOpError() << "expects tmp and dst to have the same element type"; + } + + Type elem = getElemTy(dstTy); + if (!elem || getElemTy(src0Ty) != elem || getElemTy(src1Ty) != elem) + return op->emitOpError("expects src0, src1, and dst to have the same element type"); + auto ft = elem.dyn_cast(); + if (!ft || (!ft.isF16() && !ft.isF32())) + return op->emitOpError() << "expects " << opName << " element type to be f16 or f32"; + + if (!isRowMajorTileBuf(dstTy)) + return op->emitOpError("expects dst to use row-major layout"); + + auto src0Valid = getValidShapeVec(src0Ty); + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) + return op->emitOpError("expects dst valid_shape[0] to be non-zero"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) + return op->emitOpError("expects dst valid_shape[1] to be non-zero"); + + auto validShapeMatches = [](ArrayRef lhs, + ArrayRef rhs) -> bool { + if (lhs.size() != rhs.size()) + return false; + for (auto [l, r] : llvm::zip(lhs, rhs)) { + if (l != ShapedType::kDynamic && r != ShapedType::kDynamic && l != r) + return false; + } + return true; + }; + + const bool src0MatchesDst = validShapeMatches(src0Valid, dstValid); + const bool src1MatchesDst = validShapeMatches(src1Valid, dstValid); + + auto checkBroadcastOperand = [&](Type operandTy, ArrayRef operandValid, + StringRef operandName, + bool requireNonRowMajor) -> LogicalResult { + if (operandValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + operandValid[0] != dstValid[0]) { + return op->emitOpError() << "expects " << operandName + << " valid_shape[0] to equal dst valid_shape[0]"; + } + int64_t expectedCol = ft.isF16() ? 16 : 8; + int64_t operandCol = operandValid[1]; + bool operandIsRowMajor = isRowMajorTileBuf(operandTy); + if (requireNonRowMajor && operandIsRowMajor) { + return op->emitOpError() << "expects " << operandName + << " to use a non-row-major layout when tmp is present"; + } + if (operandIsRowMajor) { + if (operandCol != ShapedType::kDynamic && operandCol != expectedCol) { + return op->emitOpError() + << "expects row-major " << operandName + << " valid_shape[1] to be 32/sizeof(dtype)"; + } + return success(); + } + if (operandCol != ShapedType::kDynamic && operandCol != 1) { + return op->emitOpError() << "expects non-row-major " << operandName + << " valid_shape[1] to be 1"; + } + return success(); + }; + + auto checkFullAndBroadcast = [&](Type fullTy, ArrayRef fullValid, + StringRef fullName, Type broadcastTy, + ArrayRef broadcastValid, + StringRef broadcastName) -> LogicalResult { + if (!isRowMajorTileBuf(fullTy)) + return op->emitOpError() << "expects " << fullName + << " to use row-major layout when it matches dst"; + if (fullValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + fullValid[0] != dstValid[0]) + return op->emitOpError() << "expects " << fullName + << " valid_shape[0] to equal dst valid_shape[0]"; + if (fullValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + fullValid[1] != dstValid[1]) + return op->emitOpError() << "expects " << fullName + << " valid_shape[1] to equal dst valid_shape[1]"; + return checkBroadcastOperand(broadcastTy, broadcastValid, broadcastName, + /*requireNonRowMajor=*/hasTmp && + targetArch == PTOArch::A3); + }; + + if (hasTmp && targetArch == PTOArch::A5) + return op->emitOpError("expects A5 form to omit tmp"); + + if (src0MatchesDst) { + if (succeeded(checkFullAndBroadcast(src0Ty, src0Valid, "src0", src1Ty, + src1Valid, "src1"))) + return success(); + } + if (src1MatchesDst) { + if (succeeded(checkFullAndBroadcast(src1Ty, src1Valid, "src1", src0Ty, + src0Valid, "src0"))) + return success(); + } + + return op->emitOpError() << "expects one of src0/src1 to match dst valid_shape" + << " and the other to be a per-row scalar vector"; +} + +mlir::LogicalResult mlir::pto::TRowExpandExpdifOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandexpdif"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandexpdif"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandMaxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmax"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmax"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandMinOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmin"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmin"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + mlir::LogicalResult mlir::pto::TRowMaxOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { @@ -6595,6 +6834,50 @@ mlir::LogicalResult mlir::pto::TRowSumOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +mlir::LogicalResult mlir::pto::TRowProdOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyRowReductionSrcLayout(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (failed(verifyRowReductionValidRegion(*this, srcTy, dstTy))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isInteger(16) || elem.isInteger(32) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 trowprod element type to be i16/i32/f16/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyRowReductionSrcLayout(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (failed(verifyRowReductionValidRegion(*this, srcTy, dstTy))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isInteger(16) || elem.isInteger(32) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 trowprod element type to be i16/i32/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + mlir::LogicalResult mlir::pto::TRsqrtOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) @@ -8280,13 +8563,16 @@ PTO_DEFINE_BINARY_EFFECTS(TCmpOp, getSrc0Mutable(), getSrc1Mutable(), getDstMuta PTO_DEFINE_UNARY_EFFECTS(TCmpSOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TColExpandOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandExpdifOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TColMaxOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TColMinOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColProdOp, getSrcMutable(), getDstMutable()) void TColSumOp::getEffects( SmallVectorImpl> &effects) { @@ -8418,6 +8704,8 @@ void TQuantOp::getEffects( } PTO_DEFINE_UNARY_EFFECTS(TRecipOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TReluOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TFModOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFModSOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TRemOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TRemSOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TRowExpandOp, getSrcMutable(), getDstMutable()) @@ -8454,6 +8742,36 @@ void TRowExpandSubOp::getEffects( PTO_DEFINE_BINARY_EFFECTS(TRowExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +void TRowExpandExpdifOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + // Row reductions use tmp scratch tile. void TRowMaxOp::getEffects( SmallVectorImpl> &effects) { @@ -8476,6 +8794,12 @@ void TRowSumOp::getEffects( PTO_ADD_WRITE(getDstMutable()); } +void TRowProdOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} void TRsqrtOp::getEffects( SmallVectorImpl> &effects) { PTO_ADD_READ(getSrcMutable()); @@ -8484,6 +8808,7 @@ void TRsqrtOp::getEffects( PTO_ADD_READ(tmp[0]); PTO_ADD_WRITE(getDstMutable()); } + PTO_DEFINE_BINARY_EFFECTS(TScatterOp, getSrcMutable(), getIndexesMutable(), getDstMutable()) // Select: Read(mask, src0, src1) -> Write(tmp, dst) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 492e82532..81f59a205 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -5367,6 +5367,28 @@ struct PTOColExpandMulToEmitC : public OpConversionPattern } }; +struct PTOColExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDADD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + struct PTOColExpandDivToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -5389,6 +5411,29 @@ struct PTOColExpandDivToEmitC : public OpConversionPattern } }; +struct PTOColExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDEXPDIF", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + struct PTOColExpandSubToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -5610,6 +5655,27 @@ struct PTOColSumToEmitC : public OpConversionPattern { return success(); } }; + +struct PTOColProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLPROD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { using RM = mlir::pto::RoundMode; switch (attr.getValue()) { @@ -6822,6 +6888,28 @@ struct PTORemToEmitC : public OpConversionPattern { return success(); } }; + +struct PTOFModToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TFMOD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) //===----------------------------------------------------------------------===// @@ -6848,6 +6936,28 @@ struct PTORemSToEmitC : public OpConversionPattern { } }; +struct PTOFModSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TFMODS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) //===----------------------------------------------------------------------===// @@ -6895,6 +7005,34 @@ struct PTORowExpandAddToEmitC : public OpConversionPattern } }; +struct PTORowExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDEXPDIF", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) //===----------------------------------------------------------------------===// @@ -7091,6 +7229,60 @@ struct PTORowExpandSubToEmitC : public OpConversionPattern } }; +struct PTORowExpandMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) //===----------------------------------------------------------------------===// @@ -7167,6 +7359,28 @@ struct PTORowSumToEmitC : public OpConversionPattern { return success(); } }; + +struct PTORowProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWPROD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) // - no-tmp form : TRSQRT(dst, src) @@ -8672,13 +8886,19 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -8691,16 +8911,19 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/test/samples/Colexpandadd/colexpandadd.py b/test/samples/Colexpandadd/colexpandadd.py new file mode 100644 index 000000000..e84da3c42 --- /dev/null +++ b/test/samples/Colexpandadd/colexpandadd.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_1x32 = pto.TileBufType.get([1, 32], f32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolexpandadd_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c1, c32], [c32, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_1x32).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TColExpandAddOp(tb0, tb1, tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Colexpandexpdif/colexpandexpdif.py b/test/samples/Colexpandexpdif/colexpandexpdif.py new file mode 100644 index 000000000..3f663f1ac --- /dev/null +++ b/test/samples/Colexpandexpdif/colexpandexpdif.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_1x32 = pto.TileBufType.get([1, 32], f32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolexpandexpdif_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c1, c32], [c32, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_1x32).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TColExpandExpdifOp(tb0, tb1, tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py b/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py new file mode 100644 index 000000000..a83a24693 --- /dev/null +++ b/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py @@ -0,0 +1,51 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, pto +from mlir.ir import IntegerType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + i32 = IntegerType.get_signless(32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + full_ty = pto.TileBufType.get([32, 32], i32, vec, [32, 32], cfg, ctx) + scalar_ty = pto.TileBufType.get([1, 32], i32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolexpandexpdif_dtype_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src0 = pto.AllocTileOp(full_ty).result + src1 = pto.AllocTileOp(scalar_ty).result + dst = pto.AllocTileOp(full_ty).result + pto.TColExpandExpdifOp(src0, src1, dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Colprod/colprod.py b/test/samples/Colprod/colprod.py new file mode 100644 index 000000000..f2e35019d --- /dev/null +++ b/test/samples/Colprod/colprod.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_1x32 = pto.TileBufType.get([1, 32], f32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolprod_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c1, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_1x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TColProdOp(tb0, tb1) + + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result + pto.TStoreOp(None, tb1, sv1) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Fmod/fmod.py b/test/samples/Fmod/fmod.py new file mode 100644 index 000000000..88dbe3fa3 --- /dev/null +++ b/test/samples/Fmod/fmod.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tfmod_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c32], [c32, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32, tv1, offsets=[c0, c0], sizes=[c32, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32).result + tb1 = pto.AllocTileOp(tile_buf_32).result + tb2 = pto.AllocTileOp(tile_buf_32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TFModOp(tb0, tb1, tb2) + + sv2 = pto.PartitionViewOp(tile_view_32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Fmods/fmods.py b/test/samples/Fmods/fmods.py new file mode 100644 index 000000000..6572582a1 --- /dev/null +++ b/test/samples/Fmods/fmods.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tfmods_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + scalar = arith.ConstantOp(f32, 3.25).result + + arg0, arg1 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32).result + tb1 = pto.AllocTileOp(tile_buf_32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TFModSOp(tb0, scalar, tb1) + + sv1 = pto.PartitionViewOp(tile_view_32, tv1, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb1, sv1) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Fmods/tfmods_scalar_type_invalid.py b/test/samples/Fmods/tfmods_scalar_type_invalid.py new file mode 100644 index 000000000..40a8c5243 --- /dev/null +++ b/test/samples/Fmods/tfmods_scalar_type_invalid.py @@ -0,0 +1,51 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IntegerType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + i32 = IntegerType.get_signless(32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tfmods_scalar_type_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src = pto.AllocTileOp(tile_ty).result + dst = pto.AllocTileOp(tile_ty).result + scalar = arith.ConstantOp(i32, 7).result + pto.TFModSOp(src, scalar, dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandexpdif/rowexpandexpdif.py b/test/samples/Rowexpandexpdif/rowexpandexpdif.py new file mode 100644 index 000000000..062ee8b5f --- /dev/null +++ b/test/samples/Rowexpandexpdif/rowexpandexpdif.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x8 = pto.PartitionTensorViewType.get([32, 8], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x8 = pto.TileBufType.get([32, 8], f32, vec, [32, 8], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandexpdif_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c8 = arith.ConstantOp(IndexType.get(ctx), 8).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c8], [c8, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32x8, tv1, offsets=[c0, c0], sizes=[c32, c8]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x8).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TRowExpandExpdifOp(src0=tb0, src1=tb1, dst=tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandmax/rowexpandmax.py b/test/samples/Rowexpandmax/rowexpandmax.py new file mode 100644 index 000000000..5c086ce24 --- /dev/null +++ b/test/samples/Rowexpandmax/rowexpandmax.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x8 = pto.PartitionTensorViewType.get([32, 8], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x8 = pto.TileBufType.get([32, 8], f32, vec, [32, 8], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandmax_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c8 = arith.ConstantOp(IndexType.get(ctx), 8).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c8], [c8, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32x8, tv1, offsets=[c0, c0], sizes=[c32, c8]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x8).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TRowExpandMaxOp(src0=tb0, src1=tb1, dst=tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py b/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py new file mode 100644 index 000000000..a0d7e5a05 --- /dev/null +++ b/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr +from mlir.dialects import func, pto +from mlir.ir import F32Type + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") + + f32 = F32Type.get(ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl_row = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + bl_col = pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg_row = pto.TileBufConfigAttr.get(bl_row, sl, fractal_ab_size, pd, ctx) + cfg_col = pto.TileBufConfigAttr.get(bl_col, sl, fractal_ab_size, pd, ctx) + + full_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg_row, ctx) + scalar_ty = pto.TileBufType.get([32, 1], f32, vec, [32, 1], cfg_col, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandmax_a5_tmp_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src0 = pto.AllocTileOp(full_ty).result + src1 = pto.AllocTileOp(scalar_ty).result + tmp = pto.AllocTileOp(full_ty).result + dst = pto.AllocTileOp(full_ty).result + pto.TRowExpandMaxOp(src0=src0, src1=src1, tmp=tmp, dst=dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandmin/rowexpandmin.py b/test/samples/Rowexpandmin/rowexpandmin.py new file mode 100644 index 000000000..1e89234c6 --- /dev/null +++ b/test/samples/Rowexpandmin/rowexpandmin.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x8 = pto.PartitionTensorViewType.get([32, 8], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x8 = pto.TileBufType.get([32, 8], f32, vec, [32, 8], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandmin_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c8 = arith.ConstantOp(IndexType.get(ctx), 8).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c8], [c8, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32x8, tv1, offsets=[c0, c0], sizes=[c32, c8]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x8).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TRowExpandMinOp(src0=tb0, src1=tb1, dst=tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowprod/rowprod.py b/test/samples/Rowprod/rowprod.py new file mode 100644 index 000000000..f4b01cd9c --- /dev/null +++ b/test/samples/Rowprod/rowprod.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x1 = pto.PartitionTensorViewType.get([32, 1], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x1 = pto.TileBufType.get([32, 32], f32, vec, [32, 1], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowprod_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c1], [c1, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb_tmp = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x1).result + + pto.TLoadOp(None, sv0, tb0) + pto.TRowProdOp(tb0, tb_tmp, tb1) + + sv1 = pto.PartitionViewOp(tile_view_32x1, tv1, offsets=[c0, c0], sizes=[c32, c1]).result + pto.TStoreOp(None, tb1, sv1) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py b/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py new file mode 100644 index 000000000..5a7a9f61b --- /dev/null +++ b/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, pto +from mlir.ir import F32Type + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + src_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tmp_ty = pto.TileBufType.get([32, 16], f32, vec, [32, 16], cfg, ctx) + dst_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 1], cfg, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowprod_tmp_mismatch_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src = pto.AllocTileOp(src_ty).result + tmp = pto.AllocTileOp(tmp_ty).result + dst = pto.AllocTileOp(dst_ty).result + pto.TRowProdOp(src, tmp, dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index bd992aeab..e191daf57 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -872,7 +872,7 @@ PY # Regression guard for row-reduction kernels: # (32 x 1) row-major outputs are minor-2D ambiguous; layout must align with # row-major tiles (ND), otherwise pto-isa can hit layout/tile static_assert. - if [[ "$base" == "rowmin" || "$base" == "rowsum" || "$base" == "rowmax" ]]; then + if [[ "$base" == "rowmin" || "$base" == "rowsum" || "$base" == "rowmax" || "$base" == "rowprod" ]]; then if ! grep -Eq "pto::Shape<1, 1, 1, 32, 1>.*pto::Layout::ND" "$cpp"; then echo -e "${A}(${base}.py)\tFAIL\texpected pto::Layout::ND for shape (32 x 1) GlobalTensor" overall=1 diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index d245434cf..1b3b03699 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -6,11 +6,6 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - // Generated by docs/bytecode/tools/gen_v0_tables.py #pragma once @@ -121,9 +116,9 @@ inline constexpr OpInfo kOpTable[] = { {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1050, "pto.trowexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1051, "pto.trowexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, diff --git a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto index 5fb4e3ba1..7a4b39105 100644 --- a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto +++ b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto @@ -10,11 +10,23 @@ module { %tmp = pto.alloc_tile : !pto.tile_buf %dst0 = pto.alloc_tile : !pto.tile_buf %dst1 = pto.alloc_tile : !pto.tile_buf + %dst2 = pto.alloc_tile : !pto.tile_buf + %dst3 = pto.alloc_tile : !pto.tile_buf + %dst4 = pto.alloc_tile : !pto.tile_buf + %dst5 = pto.alloc_tile : !pto.tile_buf + %dst6 = pto.alloc_tile : !pto.tile_buf + %dst7 = pto.alloc_tile : !pto.tile_buf + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) + pto.trowexpandexpdif ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst3 : !pto.tile_buf) + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst4 : !pto.tile_buf) + pto.trowexpandmax ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst5 : !pto.tile_buf) + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst6 : !pto.tile_buf) + pto.trowexpandmin ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst7 : !pto.tile_buf) %part0 = pto.alloc_tile : !pto.tile_buf %part1 = pto.alloc_tile : !pto.tile_buf %partdst = pto.alloc_tile : !pto.tile_buf - pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) - pto.trowexpandmul ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) %rs_dst0 = pto.alloc_tile : !pto.tile_buf %rs_tmp = pto.alloc_tile : !pto.tile_buf diff --git a/tools/ptobc/tests/recent_ops_v0_encode.sh b/tools/ptobc/tests/recent_ops_v0_encode.sh index 7f9c1dc2b..26160e727 100755 --- a/tools/ptobc/tests/recent_ops_v0_encode.sh +++ b/tools/ptobc/tests/recent_ops_v0_encode.sh @@ -35,6 +35,9 @@ grep -F "pto.subset " "${ROUNDTRIP}" >/dev/null grep -F "pto.tprint ins(" "${ROUNDTRIP}" >/dev/null grep -F "pto.trowexpanddiv ins(" "${ROUNDTRIP}" >/dev/null grep -F "pto.trowexpandmul ins(" "${ROUNDTRIP}" >/dev/null +[[ $(grep -Fc "pto.trowexpandexpdif ins(" "${ROUNDTRIP}") -eq 2 ]] +[[ $(grep -Fc "pto.trowexpandmax ins(" "${ROUNDTRIP}") -eq 2 ]] +[[ $(grep -Fc "pto.trowexpandmin ins(" "${ROUNDTRIP}") -eq 2 ]] grep -F "pto.trsqrt ins(" "${ROUNDTRIP}" >/dev/null grep -E "pto\\.trsqrt ins\\(%[^,]+, %[^:]+ :" "${ROUNDTRIP}" >/dev/null grep -F "pto.tpartmul ins(" "${ROUNDTRIP}" >/dev/null