From 2a7d6f111d19a0f231934687fb7a66ba30e5d4aa Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Wed, 1 Apr 2026 10:26:03 +0800 Subject: [PATCH 1/4] feat: add more PTO tensor op coverage --- include/PTO/IR/PTOOps.td | 260 +++++++++ lib/PTO/IR/PTO.cpp | 551 ++++++++++++++---- lib/PTO/Transforms/PTOToEmitC.cpp | 223 +++++++ test/samples/Colexpandadd/colexpandadd.py | 66 +++ .../Colexpandexpdif/colexpandexpdif.py | 66 +++ .../colexpandexpdif_dtype_invalid.py | 43 ++ test/samples/Colprod/colprod.py | 61 ++ test/samples/Fmod/fmod.py | 64 ++ test/samples/Fmods/fmods.py | 61 ++ .../Fmods/tfmods_scalar_type_invalid.py | 43 ++ .../Rowexpandexpdif/rowexpandexpdif.py | 67 +++ test/samples/Rowexpandmax/rowexpandmax.py | 67 +++ .../rowexpandmax_a5_tmp_invalid.py | 48 ++ test/samples/Rowexpandmin/rowexpandmin.py | 67 +++ test/samples/Rowprod/rowprod.py | 63 ++ .../Rowprod/trowprod_tmp_mismatch_invalid.py | 44 ++ test/samples/runop.sh | 2 +- 17 files changed, 1682 insertions(+), 114 deletions(-) create mode 100644 test/samples/Colexpandadd/colexpandadd.py create mode 100644 test/samples/Colexpandexpdif/colexpandexpdif.py create mode 100644 test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py create mode 100644 test/samples/Colprod/colprod.py create mode 100644 test/samples/Fmod/fmod.py create mode 100644 test/samples/Fmods/fmods.py create mode 100644 test/samples/Fmods/tfmods_scalar_type_invalid.py create mode 100644 test/samples/Rowexpandexpdif/rowexpandexpdif.py create mode 100644 test/samples/Rowexpandmax/rowexpandmax.py create mode 100644 test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py create mode 100644 test/samples/Rowexpandmin/rowexpandmin.py create mode 100644 test/samples/Rowprod/rowprod.py create mode 100644 test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 269e9b2b3..7b659e264 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -2301,6 
+2301,35 @@ def TColExpandOp : PTO_TOp<"tcolexpand", [ }]; } +def TColExpandAddOp : PTO_TOp<"tcolexpandadd", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "Column-wise broadcast add: add a per-column scalar vector src1 to src0"; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + def TColExpandMulOp : PTO_TOp<"tcolexpandmul", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -2388,6 +2417,35 @@ def TColExpandSubOp : PTO_TOp<"tcolexpandsub", [ }]; } +def TColExpandExpdifOp : PTO_TOp<"tcolexpandexpdif", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "Column-wise broadcast expdif: compute exp(src0 - src1) using a per-column scalar vector src1"; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + def TColExpandMaxOp : PTO_TOp<"tcolexpandmax", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -2530,6 +2588,34 @@ def TColSumOp : PTO_TOp<"tcolsum", [ }]; } +def TColProdOp : PTO_TOp<"tcolprod", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { +
let summary = "Reduce each column by multiplying across rows"; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src `:` qualified(type($src)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + def TCvtOp : PTO_TOp<"tcvt", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -2619,6 +2705,64 @@ def TDivSOp : PTO_TOp<"tdivs", [ } +def TFModOp : PTO_TOp<"tfmod", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "Elementwise fmod/remainder of two tiles (tilebuf, DPS)"; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; + + let assemblyFormat = [{ + `ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; +} + +def TFModSOp : PTO_TOp<"tfmods", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "Elementwise fmod/remainder with a scalar (tilebuf, DPS)"; + + let arguments = (ins + PTODpsType:$src, + ScalarType:$scalar, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `,` $scalar `:` qualified(type($src)) `,` type($scalar) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange
getDpsInitsMutable() { return getDstMutable(); } + }]; +} + def TExpOp : PTO_TOp<"texp", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -3835,6 +3979,93 @@ def TRowExpandAddOp: PTO_TOp<"trowexpandadd", [ ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } }]; } + +def TRowExpandExpdifOp: PTO_TOp<"trowexpandexpdif", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "TROWEXPANDEXPDIF: Row-wise broadcast expdif with per-row scalar vector."; + let description = [{ + pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only. + }]; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + Optional<PTODpsType>:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + +def TRowExpandMaxOp: PTO_TOp<"trowexpandmax", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "TROWEXPANDMAX: Row-wise broadcast max with per-row scalar vector."; + let description = [{ + pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only.
+ }]; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + Optional<PTODpsType>:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + +def TRowExpandMinOp: PTO_TOp<"trowexpandmin", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "TROWEXPANDMIN: Row-wise broadcast min with per-row scalar vector."; + let description = [{ + pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only. + }]; + + let arguments = (ins + PTODpsType:$src0, + PTODpsType:$src1, + Optional<PTODpsType>:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} //===----------------------------------------------------------------------===// // PTOOps.td (add TROWMAX TBDPS/tile buffer op) //===----------------------------------------------------------------------===// @@ -3931,6 +4162,35 @@ def TRowSumOp: PTO_TOp<"trowsum", [ ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } }]; } + +def TRowProdOp: PTO_TOp<"trowprod", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface> +]> { + let summary = "TROWPROD: Reduce each row by multiplying across columns."; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `,` $tmp `:` qualified(type($src)) `,` qualified(type($tmp)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration =
[{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} //===----------------------------------------------------------------------===// // PTOOps.td (add TRSQRT TBDPS/tile buffer op) //===----------------------------------------------------------------------===// diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 2d20580fd..f66f2515f 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2795,6 +2795,10 @@ LogicalResult pto::TColExpandMulOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); } +LogicalResult pto::TColExpandAddOp::verify() { + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType()); +} LogicalResult pto::TColExpandDivOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); @@ -2803,6 +2807,10 @@ LogicalResult pto::TColExpandSubOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); } +LogicalResult pto::TColExpandExpdifOp::verify() { + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType()); +} LogicalResult pto::TColExpandMaxOp::verify() { return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType()); @@ -3039,6 +3047,44 @@ LogicalResult pto::TColSumOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +LogicalResult pto::TColProdOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != 
getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/false))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isInteger(16) || elem.isInteger(32))) + return emitOpError("expects A2/A3 tcolprod element type to be f16/f32/i16/i32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/false))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isBF16() || + elem.isInteger(16) || elem.isUnsignedInteger(16) || + elem.isInteger(32) || elem.isUnsignedInteger(32))) + return emitOpError("expects A5 tcolprod element type to be i16/ui16/i32/ui32/f16/bf16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + llvm::LogicalResult mlir::pto::TCvtOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); @@ -5257,6 +5303,52 @@ mlir::LogicalResult mlir::pto::TRemOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +mlir::LogicalResult mlir::pto::TFModOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if 
(failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src0, src1, and dst to use row-major layout"); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tfmod element type to be i32/i16/f16/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameShapeAndElem(*this, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src0, src1, and dst to use row-major layout"); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tfmod element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + mlir::LogicalResult mlir::pto::TRemSOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return 
success(); @@ -5272,6 +5364,39 @@ mlir::LogicalResult mlir::pto::TRemSOp::verify() { return mlir::success(); } +mlir::LogicalResult mlir::pto::TFModSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type scalarTy = getScalar().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src and dst to use row-major layout"); + + Type elem = getElemTy(srcTy); + if (scalarTy != elem) + return emitOpError("expects scalar type to match the tile element type"); + + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + static std::optional getStaticNumElements(ArrayRef shape) { int64_t numel = 1; @@ -5599,7 +5724,8 @@ void mlir::pto::TSort32Op::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); } -ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { +static ParseResult parseTRowExpandBinaryLikeOp(OpAsmParser &parser, + OperationState &result) { OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; Type src0Ty, src1Ty, tmpTy, dstTy; bool hasTmp 
= false; @@ -5645,135 +5771,75 @@ ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationStat return success(); } -void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc0() << ", " << getSrc1(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ", " - << getTmp().getType() << ")"; +static void printTRowExpandBinaryLikeOp(OpAsmPrinter &p, Operation *op, Value src0, + Value src1, Value tmp, Value dst) { + p << " ins(" << src0 << ", " << src1; + if (tmp) { + p << ", " << tmp; + p << " : " << src0.getType() << ", " << src1.getType() << ", " + << tmp.getType() << ")"; } else { - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ")"; + p << " : " << src0.getType() << ", " << src1.getType() << ")"; } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); + p << " outs(" << dst << " : " << dst.getType() << ")"; + p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); } -ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; - Type src0Ty, src1Ty, tmpTy, dstTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColon()) - return failure(); - if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - if (parser.parseKeyword("outs") || parser.parseLParen() || - 
parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); +ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} - if (parser.resolveOperand(src0, src0Ty, result.operands) || - parser.resolveOperand(src1, src1Ty, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); +void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 1 : 0, 1})); - return success(); +ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); } void mlir::pto::TRowExpandMulOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc0() << ", " << getSrc1(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ", " - << getTmp().getType() << ")"; - } else { - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ")"; - } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); } ParseResult mlir::pto::TRowExpandSubOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; - Type src0Ty, src1Ty, tmpTy, dstTy; - bool hasTmp = false; + return parseTRowExpandBinaryLikeOp(parser, result); +} - if (parser.parseKeyword("ins") || 
parser.parseLParen() || - parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColon()) - return failure(); - if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); +void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} - if (parser.resolveOperand(src0, src0Ty, result.operands) || - parser.resolveOperand(src1, src1Ty, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); +ParseResult mlir::pto::TRowExpandExpdifOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0})); - return success(); +void mlir::pto::TRowExpandExpdifOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); } -void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc0() << ", " << getSrc1(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ", " - << getTmp().getType() << ")"; - } else { - p << " : " << getSrc0().getType() << ", " << getSrc1().getType() << ")"; - } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); +ParseResult mlir::pto::TRowExpandMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMaxOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMinOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); } mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { @@ -5914,6 +5980,184 @@ mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); }
+// Shared verifier for trowexpandexpdif / trowexpandmax / trowexpandmin.
+//
+// One of src0/src1 must match dst's valid_shape (the "full" operand) and the
+// other must be a per-row scalar vector (the "broadcast" operand). Either
+// operand may take either role, so each candidate pairing is probed silently
+// and a single error -- with one note per rejected pairing -- is emitted only
+// when no pairing is acceptable. Emitting from inside a probe would leave a
+// stray error behind whenever the second pairing goes on to succeed.
+static LogicalResult verifyTRowExpandReduceLikeOp(Operation *op, Type src0Ty,
+                                                  Type src1Ty, Type dstTy,
+                                                  Type tmpTy, bool hasTmp,
+                                                  PTOArch targetArch,
+                                                  StringRef opName) {
+  if (failed(verifyTileBufCommon(op, src0Ty, "src0")) ||
+      failed(verifyTileBufCommon(op, src1Ty, "src1")) ||
+      failed(verifyTileBufCommon(op, dstTy, "dst")))
+    return failure();
+  if (hasTmp) {
+    if (failed(verifyTileBufCommon(op, tmpTy, "tmp")))
+      return failure();
+    if (getElemTy(tmpTy) != getElemTy(dstTy))
+      return op->emitOpError() << "expects tmp and dst to have the same element type";
+  }
+
+  Type elem = getElemTy(dstTy);
+  if (!elem || getElemTy(src0Ty) != elem || getElemTy(src1Ty) != elem)
+    return op->emitOpError("expects src0, src1, and dst to have the same element type");
+  if (!elem.isF16() && !elem.isF32())
+    return op->emitOpError() << "expects " << opName << " element type to be f16 or f32";
+
+  if (!isRowMajorTileBuf(dstTy))
+    return op->emitOpError("expects dst to use row-major layout");
+
+  auto src0Valid = getValidShapeVec(src0Ty);
+  auto src1Valid = getValidShapeVec(src1Ty);
+  auto dstValid = getValidShapeVec(dstTy);
+  if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2)
+    return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape");
+
+  if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0)
+    return op->emitOpError("expects dst valid_shape[0] to be non-zero");
+  if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0)
+    return op->emitOpError("expects dst valid_shape[1] to be non-zero");
+
+  if (hasTmp && targetArch == PTOArch::A5)
+    return op->emitOpError("expects A5 form to omit tmp");
+
+  // Two extents are compatible unless both are static and different.
+  auto extentsAgree = [](int64_t lhs, int64_t rhs) {
+    return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic ||
+           lhs == rhs;
+  };
+  // All valid shapes are rank 2 at this point (checked above).
+  auto matchesDst = [&](ArrayRef<int64_t> valid) {
+    return extentsAgree(valid[0], dstValid[0]) &&
+           extentsAgree(valid[1], dstValid[1]);
+  };
+
+  // Returns an empty string when the operand is an acceptable per-row scalar
+  // vector for dst, otherwise the reason it is not. Never emits a diagnostic.
+  auto describeBroadcastMismatch = [&](Type operandTy,
+                                       ArrayRef<int64_t> operandValid,
+                                       StringRef operandName) -> std::string {
+    if (!extentsAgree(operandValid[0], dstValid[0]))
+      return (llvm::Twine("expects ") + operandName +
+              " valid_shape[0] to equal dst valid_shape[0]")
+          .str();
+    const bool operandIsRowMajor = isRowMajorTileBuf(operandTy);
+    // The A2/A3 callers pass PTOArch::A3; with tmp present that form rejects
+    // a row-major broadcast operand.
+    if (hasTmp && targetArch == PTOArch::A3 && operandIsRowMajor)
+      return (llvm::Twine("expects ") + operandName +
+              " to use a non-row-major layout when tmp is present")
+          .str();
+    const int64_t operandCol = operandValid[1];
+    if (operandIsRowMajor) {
+      // 32 / sizeof(dtype): 16 columns for f16, 8 for f32.
+      const int64_t expectedCol = elem.isF16() ? 16 : 8;
+      if (!extentsAgree(operandCol, expectedCol))
+        return (llvm::Twine("expects row-major ") + operandName +
+                " valid_shape[1] to be 32/sizeof(dtype)")
+            .str();
+      return std::string();
+    }
+    if (!extentsAgree(operandCol, 1))
+      return (llvm::Twine("expects non-row-major ") + operandName +
+              " valid_shape[1] to be 1")
+          .str();
+    return std::string();
+  };
+
+  // Probes one full/broadcast role assignment, recording why it was rejected.
+  SmallVector<std::string, 2> rejections;
+  auto tryPairing = [&](Type fullTy, StringRef fullName, Type broadcastTy,
+                        ArrayRef<int64_t> broadcastValid,
+                        StringRef broadcastName) -> bool {
+    std::string reason =
+        isRowMajorTileBuf(fullTy)
+            ? describeBroadcastMismatch(broadcastTy, broadcastValid,
+                                        broadcastName)
+            : (llvm::Twine("expects ") + fullName +
+               " to use row-major layout when it matches dst")
+                  .str();
+    if (reason.empty())
+      return true;
+    rejections.push_back(std::move(reason));
+    return false;
+  };
+
+  if (matchesDst(src0Valid) &&
+      tryPairing(src0Ty, "src0", src1Ty, src1Valid, "src1"))
+    return success();
+  if (matchesDst(src1Valid) &&
+      tryPairing(src1Ty, "src1", src0Ty, src0Valid, "src0"))
+    return success();
+
+  InFlightDiagnostic diag =
+      op->emitOpError() << "expects one of src0/src1 to match dst valid_shape"
+                        << " and the other to be a per-row scalar vector";
+  for (const std::string &reason : rejections)
+    diag.attachNote() << reason;
+  return diag;
+}
+
+mlir::LogicalResult mlir::pto::TRowExpandExpdifOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandexpdif"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandexpdif"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandMaxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmax"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmax"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandMinOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmin"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ?
getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmin"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + mlir::LogicalResult mlir::pto::TRowMaxOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { @@ -6029,6 +6268,39 @@ mlir::LogicalResult mlir::pto::TRowSumOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); }
+mlir::LogicalResult mlir::pto::TRowProdOp::verify() {
+  // A2/A3 and A5 run the same checks; only the element-type diagnostic names
+  // the target.
+  auto verifyCommon = [&](StringRef elemTyMsg) -> LogicalResult {
+    Type srcTy = getSrc().getType();
+    Type tmpTy = getTmp().getType();
+    Type dstTy = getDst().getType();
+    if (failed(verifyRowReductionSrcLayout(*this, srcTy, "src")) ||
+        failed(verifyVecTileCommon(*this, tmpTy, "tmp")) ||
+        failed(verifyRowReductionDstLayout(*this, dstTy, "dst")))
+      return failure();
+    // tmp is checked against src with the shared shape/elem helpers.
+    if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) ||
+        failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp")))
+      return failure();
+    Type elem = getElemTy(srcTy);
+    if (elem != getElemTy(dstTy))
+      return emitOpError("expects src and dst to have the same element type");
+    if (failed(verifyRowReductionValidRegion(*this, srcTy, dstTy)))
+      return failure();
+    if (!(elem.isInteger(16) || elem.isInteger(32) || elem.isF16() || elem.isF32()))
+      return emitOpError(elemTyMsg);
+    return success();
+  };
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    return verifyCommon("expects A2/A3 trowprod element type to be i16/i32/f16/f32");
+  };
+  auto verifyA5 = [&]() -> LogicalResult {
+    return verifyCommon("expects A5 trowprod element type to be i16/i32/f16/f32");
+  };
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+
 mlir::LogicalResult mlir::pto::TRsqrtOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) @@ -7701,13 +7984,16 @@ PTO_DEFINE_BINARY_EFFECTS(TCmpOp, getSrc0Mutable(), getSrc1Mutable(), getDstMuta PTO_DEFINE_UNARY_EFFECTS(TCmpSOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TColExpandOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandExpdifOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TColExpandMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TColMaxOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TColMinOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColProdOp, getSrcMutable(), getDstMutable()) void TColSumOp::getEffects( SmallVectorImpl> &effects) { @@ -7813,6 +8099,8 @@ void TPReluOp::getEffects( PTO_DEFINE_UNARY_EFFECTS(TRecipOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TReluOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TFModOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFModSOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TRemOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TRemSOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TRowExpandOp, getSrcMutable(), getDstMutable()) @@ -7849,6 +8137,36 @@ void TRowExpandSubOp::getEffects( PTO_DEFINE_BINARY_EFFECTS(TRowExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable())
+void TRowExpandExpdifOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
+  PTO_ADD_READ(getSrc0Mutable());
+  PTO_ADD_READ(getSrc1Mutable());
+  auto tmp = getTmpMutable();
+  if (!tmp.empty())
+    PTO_ADD_WRITE(tmp[0]);
+  PTO_ADD_WRITE(getDstMutable());
+}
+
+void TRowExpandMaxOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
+  PTO_ADD_READ(getSrc0Mutable());
+  PTO_ADD_READ(getSrc1Mutable());
+  auto tmp = getTmpMutable();
+  if (!tmp.empty())
+    PTO_ADD_WRITE(tmp[0]);
+  PTO_ADD_WRITE(getDstMutable());
+}
+
+void TRowExpandMinOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
+  PTO_ADD_READ(getSrc0Mutable());
+  PTO_ADD_READ(getSrc1Mutable());
+  auto tmp = getTmpMutable();
+  if (!tmp.empty())
+    PTO_ADD_WRITE(tmp[0]);
+  PTO_ADD_WRITE(getDstMutable());
+}
+
 // Row reductions use tmp scratch tile.
void TRowMaxOp::getEffects( SmallVectorImpl> &effects) { @@ -7871,6 +8189,13 @@ void TRowSumOp::getEffects( PTO_ADD_WRITE(getDstMutable()); } +void TRowProdOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + PTO_DEFINE_UNARY_EFFECTS(TRsqrtOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TScatterOp, getSrcMutable(), getIndexesMutable(), getDstMutable()) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 69c34cedd..5fb6da5b0 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -5184,6 +5184,28 @@ struct PTOColExpandMulToEmitC : public OpConversionPattern } }; +struct PTOColExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDADD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + struct PTOColExpandDivToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -5206,6 +5228,29 @@ struct PTOColExpandDivToEmitC : public OpConversionPattern } }; +struct PTOColExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = 
peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDEXPDIF", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + struct PTOColExpandSubToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -5427,6 +5472,27 @@ struct PTOColSumToEmitC : public OpConversionPattern { return success(); } }; + +struct PTOColProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLPROD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { using RM = mlir::pto::RoundMode; switch (attr.getValue()) { @@ -6500,6 +6566,28 @@ struct PTORemToEmitC : public OpConversionPattern { return success(); } }; + +struct PTOFModToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TFMOD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + 
patterns.add for TREMS DPS/memref op) //===----------------------------------------------------------------------===// @@ -6526,6 +6614,28 @@ struct PTORemSToEmitC : public OpConversionPattern { } }; +struct PTOFModSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TFMODS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) //===----------------------------------------------------------------------===// @@ -6573,6 +6683,34 @@ struct PTORowExpandAddToEmitC : public OpConversionPattern } }; +struct PTORowExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDEXPDIF", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) //===----------------------------------------------------------------------===// @@ -6769,6 +6907,60 @@ struct PTORowExpandSubToEmitC : public OpConversionPattern } }; +struct PTORowExpandMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) //===----------------------------------------------------------------------===// @@ -6845,6 +7037,28 @@ struct PTORowSumToEmitC : public OpConversionPattern { return success(); } }; + +struct PTORowProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWPROD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) //===----------------------------------------------------------------------===// @@ -8347,13 +8561,19 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, 
ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -8366,16 +8586,19 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/test/samples/Colexpandadd/colexpandadd.py b/test/samples/Colexpandadd/colexpandadd.py new file mode 100644 index 000000000..7bc11d019 --- /dev/null +++ b/test/samples/Colexpandadd/colexpandadd.py @@ -0,0 +1,66 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + 
pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_1x32 = pto.TileBufType.get([1, 32], f32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolexpandadd_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c1, c32], [c32, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_1x32).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TColExpandAddOp(tb0, tb1, tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Colexpandexpdif/colexpandexpdif.py b/test/samples/Colexpandexpdif/colexpandexpdif.py new file mode 100644 index 000000000..714703821 --- /dev/null +++ b/test/samples/Colexpandexpdif/colexpandexpdif.py @@ -0,0 +1,66 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import 
F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_1x32 = pto.TileBufType.get([1, 32], f32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolexpandexpdif_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c1, c32], [c32, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_1x32).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TColExpandExpdifOp(tb0, tb1, tb2) + + sv2 = 
pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py b/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py new file mode 100644 index 000000000..33e90d535 --- /dev/null +++ b/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py @@ -0,0 +1,43 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, pto +from mlir.ir import IntegerType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + i32 = IntegerType.get_signless(32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + full_ty = pto.TileBufType.get([32, 32], i32, vec, [32, 32], cfg, ctx) + scalar_ty = pto.TileBufType.get([1, 32], i32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolexpandexpdif_dtype_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src0 = pto.AllocTileOp(full_ty).result + src1 = pto.AllocTileOp(scalar_ty).result + dst = pto.AllocTileOp(full_ty).result + pto.TColExpandExpdifOp(src0, src1, dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Colprod/colprod.py b/test/samples/Colprod/colprod.py new file mode 100644 index 000000000..be3e2bd24 --- /dev/null +++ b/test/samples/Colprod/colprod.py @@ -0,0 +1,61 
@@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_1x32 = pto.PartitionTensorViewType.get([1, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_1x32 = pto.TileBufType.get([1, 32], f32, vec, [1, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tcolprod_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c1, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_1x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TColProdOp(tb0, tb1) + + sv1 = pto.PartitionViewOp(tile_view_1x32, tv1, offsets=[c0, c0], sizes=[c1, c32]).result + pto.TStoreOp(None, tb1, sv1) + + func.ReturnOp([]) + + m.operation.verify() 
+ return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Fmod/fmod.py b/test/samples/Fmod/fmod.py new file mode 100644 index 000000000..f60108f6c --- /dev/null +++ b/test/samples/Fmod/fmod.py @@ -0,0 +1,64 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tfmod_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c32], [c32, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32, tv1, offsets=[c0, c0], sizes=[c32, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32).result + tb1 = 
pto.AllocTileOp(tile_buf_32).result + tb2 = pto.AllocTileOp(tile_buf_32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TFModOp(tb0, tb1, tb2) + + sv2 = pto.PartitionViewOp(tile_view_32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Fmods/fmods.py b/test/samples/Fmods/fmods.py new file mode 100644 index 000000000..95859d7fe --- /dev/null +++ b/test/samples/Fmods/fmods.py @@ -0,0 +1,61 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tfmods_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + scalar = arith.ConstantOp(f32, 3.25).result + + arg0, arg1 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = 
pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32).result + tb1 = pto.AllocTileOp(tile_buf_32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TFModSOp(tb0, scalar, tb1) + + sv1 = pto.PartitionViewOp(tile_view_32, tv1, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb1, sv1) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Fmods/tfmods_scalar_type_invalid.py b/test/samples/Fmods/tfmods_scalar_type_invalid.py new file mode 100644 index 000000000..d76fa2adf --- /dev/null +++ b/test/samples/Fmods/tfmods_scalar_type_invalid.py @@ -0,0 +1,43 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IntegerType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + i32 = IntegerType.get_signless(32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("tfmods_scalar_type_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src = pto.AllocTileOp(tile_ty).result + dst = pto.AllocTileOp(tile_ty).result + scalar = arith.ConstantOp(i32, 7).result + pto.TFModSOp(src, scalar, dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + 
+if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandexpdif/rowexpandexpdif.py b/test/samples/Rowexpandexpdif/rowexpandexpdif.py new file mode 100644 index 000000000..878dc69f3 --- /dev/null +++ b/test/samples/Rowexpandexpdif/rowexpandexpdif.py @@ -0,0 +1,67 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x8 = pto.PartitionTensorViewType.get([32, 8], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x8 = pto.TileBufType.get([32, 8], f32, vec, [32, 8], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandexpdif_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c8 = arith.ConstantOp(IndexType.get(ctx), 8).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c8], [c8, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], 
[c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32x8, tv1, offsets=[c0, c0], sizes=[c32, c8]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x8).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TRowExpandExpdifOp(src0=tb0, src1=tb1, dst=tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandmax/rowexpandmax.py b/test/samples/Rowexpandmax/rowexpandmax.py new file mode 100644 index 000000000..6efd10f9c --- /dev/null +++ b/test/samples/Rowexpandmax/rowexpandmax.py @@ -0,0 +1,67 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x8 = pto.PartitionTensorViewType.get([32, 8], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x8 = pto.TileBufType.get([32, 8], f32, vec, [32, 8], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, 
ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandmax_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c8 = arith.ConstantOp(IndexType.get(ctx), 8).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c8], [c8, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32x8, tv1, offsets=[c0, c0], sizes=[c32, c8]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x8).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TRowExpandMaxOp(src0=tb0, src1=tb1, dst=tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py b/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py new file mode 100644 index 000000000..1bc752dff --- /dev/null +++ b/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py @@ -0,0 +1,48 @@ +from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr +from mlir.dialects import func, pto +from mlir.ir import F32Type + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") + + f32 = F32Type.get(ctx) + vec = 
pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl_row = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + bl_col = pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg_row = pto.TileBufConfigAttr.get(bl_row, sl, fractal_ab_size, pd, ctx) + cfg_col = pto.TileBufConfigAttr.get(bl_col, sl, fractal_ab_size, pd, ctx) + + full_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg_row, ctx) + scalar_ty = pto.TileBufType.get([32, 1], f32, vec, [32, 1], cfg_col, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandmax_a5_tmp_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src0 = pto.AllocTileOp(full_ty).result + src1 = pto.AllocTileOp(scalar_ty).result + tmp = pto.AllocTileOp(full_ty).result + dst = pto.AllocTileOp(full_ty).result + pto.TRowExpandMaxOp(src0=src0, src1=src1, tmp=tmp, dst=dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowexpandmin/rowexpandmin.py b/test/samples/Rowexpandmin/rowexpandmin.py new file mode 100644 index 000000000..260b95e70 --- /dev/null +++ b/test/samples/Rowexpandmin/rowexpandmin.py @@ -0,0 +1,67 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x8 = pto.PartitionTensorViewType.get([32, 8], f32, ctx) + vec = 
pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x8 = pto.TileBufType.get([32, 8], f32, vec, [32, 8], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowexpandmin_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c8 = arith.ConstantOp(IndexType.get(ctx), 8).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1, arg2 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c8], [c8, c1]).result + tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c32, c32], [c32, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + sv1 = pto.PartitionViewOp(tile_view_32x8, tv1, offsets=[c0, c0], sizes=[c32, c8]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x8).result + tb2 = pto.AllocTileOp(tile_buf_32x32).result + + pto.TLoadOp(None, sv0, tb0) + pto.TLoadOp(None, sv1, tb1) + pto.TRowExpandMinOp(src0=tb0, src1=tb1, dst=tb2) + + sv2 = pto.PartitionViewOp(tile_view_32x32, tv2, offsets=[c0, c0], sizes=[c32, c32]).result + pto.TStoreOp(None, tb2, sv2) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowprod/rowprod.py b/test/samples/Rowprod/rowprod.py new file mode 100644 index 000000000..be8a4cadd --- /dev/null +++ 
b/test/samples/Rowprod/rowprod.py @@ -0,0 +1,63 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, arith, pto +from mlir.ir import F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32x32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + tile_view_32x1 = pto.PartitionTensorViewType.get([32, 1], f32, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_32x32 = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tile_buf_32x1 = pto.TileBufType.get([32, 32], f32, vec, [32, 1], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowprod_kernel_2d", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + c32 = arith.ConstantOp(IndexType.get(ctx), 32).result + + arg0, arg1 = entry.arguments + + tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c32, c32], [c32, c1]).result + tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c32, c1], [c1, c1]).result + + sv0 = pto.PartitionViewOp(tile_view_32x32, tv0, offsets=[c0, c0], sizes=[c32, c32]).result + + tb0 = pto.AllocTileOp(tile_buf_32x32).result + tb_tmp = pto.AllocTileOp(tile_buf_32x32).result + tb1 = pto.AllocTileOp(tile_buf_32x1).result + + pto.TLoadOp(None, sv0, tb0) + pto.TRowProdOp(tb0, tb_tmp, tb1) + + sv1 = pto.PartitionViewOp(tile_view_32x1, tv1, 
offsets=[c0, c0], sizes=[c32, c1]).result + pto.TStoreOp(None, tb1, sv1) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py b/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py new file mode 100644 index 000000000..3a787e769 --- /dev/null +++ b/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py @@ -0,0 +1,44 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import func, pto +from mlir.ir import F32Type + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f32 = F32Type.get(ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + src_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 32], cfg, ctx) + tmp_ty = pto.TileBufType.get([32, 16], f32, vec, [32, 16], cfg, ctx) + dst_ty = pto.TileBufType.get([32, 32], f32, vec, [32, 1], cfg, ctx) + + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(m.body): + fn = func.FuncOp("trowprod_tmp_mismatch_invalid", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src = pto.AllocTileOp(src_ty).result + tmp = pto.AllocTileOp(tmp_ty).result + dst = pto.AllocTileOp(dst_ty).result + pto.TRowProdOp(src, tmp, dst) + func.ReturnOp([]) + + ok = m.operation.verify() + if ok: + return m + raise SystemExit(1) + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index f41f2410b..5f43093cc 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -769,7 +769,7 @@ PY # Regression guard for row-reduction kernels: # (32 x 1) 
row-major outputs are minor-2D ambiguous; layout must align with # row-major tiles (ND), otherwise pto-isa can hit layout/tile static_assert. - if [[ "$base" == "rowmin" || "$base" == "rowsum" || "$base" == "rowmax" ]]; then + if [[ "$base" == "rowmin" || "$base" == "rowsum" || "$base" == "rowmax" || "$base" == "rowprod" ]]; then if ! grep -Eq "pto::Shape<1, 1, 1, 32, 1>.*pto::Layout::ND" "$cpp"; then echo -e "${A}(${base}.py)\tFAIL\texpected pto::Layout::ND for shape (32 x 1) GlobalTensor" overall=1 From f8f3b083b30f1aee3eabf13f158ee60ae2373b89 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Wed, 1 Apr 2026 10:35:35 +0800 Subject: [PATCH 2/4] fix(ci): add PR386 license header support --- .github/scripts/check_license_headers.py | 183 ++++++++++++++++++ include/PTO/IR/PTOOps.td | 8 + lib/PTO/IR/PTO.cpp | 8 + lib/PTO/Transforms/PTOToEmitC.cpp | 8 + test/samples/Colexpandadd/colexpandadd.py | 8 + .../Colexpandexpdif/colexpandexpdif.py | 8 + .../colexpandexpdif_dtype_invalid.py | 8 + test/samples/Colprod/colprod.py | 8 + test/samples/Fmod/fmod.py | 8 + test/samples/Fmods/fmods.py | 8 + .../Fmods/tfmods_scalar_type_invalid.py | 8 + .../Rowexpandexpdif/rowexpandexpdif.py | 8 + test/samples/Rowexpandmax/rowexpandmax.py | 8 + .../rowexpandmax_a5_tmp_invalid.py | 8 + test/samples/Rowexpandmin/rowexpandmin.py | 8 + test/samples/Rowprod/rowprod.py | 8 + .../Rowprod/trowprod_tmp_mismatch_invalid.py | 8 + test/samples/runop.sh | 8 + 18 files changed, 319 insertions(+) create mode 100644 .github/scripts/check_license_headers.py diff --git a/.github/scripts/check_license_headers.py b/.github/scripts/check_license_headers.py new file mode 100644 index 000000000..c1c175974 --- /dev/null +++ b/.github/scripts/check_license_headers.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +import urllib.error +import urllib.parse +import urllib.request +from pathlib import Path + +HEADER_BODY = [ + "Copyright (c) 2026 Huawei Technologies Co., Ltd.", + "This program is free software, you can redistribute it and/or modify it under the terms and conditions of", + 'CANN Open Software License Agreement Version 2.0 (the "License").', + "Please refer to the License for details. 
You may not use this file except in compliance with the License.", + "THIS SOFTWARE IS PROVIDED ON AN \"AS IS\" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,", + "INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.", + "See LICENSE in the root of the software repository for the full text of the License.", +] +HASH_HEADER = [f"# {line}" for line in HEADER_BODY] +SLASH_HEADER = [f"// {line}" for line in HEADER_BODY] +HASH_FILE_SUFFIXES = {".py", ".sh", ".cmake"} +SLASH_FILE_SUFFIXES = {".c", ".cc", ".cpp", ".cxx", ".h", ".hh", ".hpp", ".hxx", ".td"} +HASH_FILE_BASENAMES = {"CMakeLists.txt"} +SHEBANG_SUFFIXES = {".py", ".sh"} +ZERO_SHA = "0" * 40 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Check PR386-style license headers on changed files.") + parser.add_argument("--repo", required=True, help="owner/repo for GitHub API lookups") + parser.add_argument("--event-name", required=True, help="GitHub event name") + parser.add_argument("--pr-number", default="", help="Pull request number for pull_request events") + parser.add_argument("--base-sha", default="", help="Git base SHA for push events") + parser.add_argument("--head-sha", default="HEAD", help="Git head SHA for push events") + parser.add_argument("--github-token", default="", help="GitHub token used for PR file listing") + return parser.parse_args() + + +def comment_style_for(path_str: str) -> str | None: + path = Path(path_str) + suffix = path.suffix.lower() + if path.name in HASH_FILE_BASENAMES or suffix in HASH_FILE_SUFFIXES: + return "#" + if suffix in SLASH_FILE_SUFFIXES: + return "//" + return None + + +def expected_header(style: str) -> list[str]: + return HASH_HEADER if style == "#" else SLASH_HEADER + + +def git_output(*args: str) -> list[str]: + proc = subprocess.run( + ["git", *args], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + return [line.strip() 
for line in proc.stdout.splitlines() if line.strip()] + + +def changed_files_from_git(base_sha: str, head_sha: str) -> list[str]: + if base_sha and base_sha != ZERO_SHA: + try: + return git_output("diff", "--name-only", "--diff-filter=ACMR", base_sha, head_sha) + except subprocess.CalledProcessError: + pass + return git_output("diff-tree", "--no-commit-id", "--name-only", "--diff-filter=ACMR", "-r", head_sha) + + +def github_api_json(url: str, token: str) -> list[dict]: + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + if token: + headers["Authorization"] = f"Bearer {token}" + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req, timeout=30) as resp: + return json.load(resp) + + +def changed_files_from_pr(repo: str, pr_number: str, token: str) -> list[str]: + files: list[str] = [] + page = 1 + while True: + url = ( + f"https://api.github.com/repos/{urllib.parse.quote(repo, safe='/')}/pulls/" + f"{pr_number}/files?per_page=100&page={page}" + ) + page_items = github_api_json(url, token) + if not page_items: + break + for item in page_items: + if item.get("status") == "removed": + continue + filename = str(item.get("filename") or "").strip() + if filename: + files.append(filename) + page += 1 + return files + + +def normalize_lines(path: Path) -> list[str]: + lines = path.read_text(encoding="utf-8", errors="replace").splitlines() + if lines: + lines[0] = lines[0].lstrip("\ufeff") + return lines + + +def header_start_index(path: Path, lines: list[str]) -> int: + if lines and path.suffix.lower() in SHEBANG_SUFFIXES and lines[0].startswith("#!"): + return 1 + return 0 + + +def has_expected_header(path_str: str, style: str) -> bool: + path = Path(path_str) + if not path.exists(): + return True + lines = normalize_lines(path) + start = header_start_index(path, lines) + expected = expected_header(style) + return lines[start : start + len(expected)] == expected + + +def main() -> int: + args 
= parse_args() + try: + if args.event_name == "pull_request" and args.pr_number: + changed_files = changed_files_from_pr(args.repo, args.pr_number, args.github_token) + else: + changed_files = changed_files_from_git(args.base_sha, args.head_sha) + except urllib.error.URLError as exc: + print(f"Failed to query GitHub API: {exc}", file=sys.stderr) + return 2 + except subprocess.CalledProcessError as exc: + print(exc.stderr or str(exc), file=sys.stderr) + return 2 + + relevant_files: list[tuple[str, str]] = [] + for path_str in changed_files: + style = comment_style_for(path_str) + if style is None: + continue + relevant_files.append((path_str, style)) + + if not relevant_files: + print("No changed source/script files require the PR386 license header.") + return 0 + + missing: list[tuple[str, str]] = [] + for path_str, style in relevant_files: + if not has_expected_header(path_str, style): + missing.append((path_str, style)) + + if missing: + print("Missing PR386 license header in changed files:", file=sys.stderr) + for path_str, style in missing: + print(f"- {path_str}", file=sys.stderr) + for line in expected_header(style): + print(f" {line}", file=sys.stderr) + return 1 + + print(f"Checked {len(relevant_files)} changed source/script files: all headers present.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 7b659e264..96c619a2a 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- PTOOps.td - Pattern descriptor operations -----------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index f66f2515f..ac026a43d 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- PTO.cpp - PTO Dialect ----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 5fb6da5b0..3d9114cdb 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + //===- PTOToEmitC.cpp - PTO to EmitC conversion pass ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. diff --git a/test/samples/Colexpandadd/colexpandadd.py b/test/samples/Colexpandadd/colexpandadd.py index 7bc11d019..e84da3c42 100644 --- a/test/samples/Colexpandadd/colexpandadd.py +++ b/test/samples/Colexpandadd/colexpandadd.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Colexpandexpdif/colexpandexpdif.py b/test/samples/Colexpandexpdif/colexpandexpdif.py index 714703821..3f663f1ac 100644 --- a/test/samples/Colexpandexpdif/colexpandexpdif.py +++ b/test/samples/Colexpandexpdif/colexpandexpdif.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py b/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py index 33e90d535..a83a24693 100644 --- a/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py +++ b/test/samples/Colexpandexpdif/colexpandexpdif_dtype_invalid.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, pto from mlir.ir import IntegerType diff --git a/test/samples/Colprod/colprod.py b/test/samples/Colprod/colprod.py index be3e2bd24..f2e35019d 100644 --- a/test/samples/Colprod/colprod.py +++ b/test/samples/Colprod/colprod.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Fmod/fmod.py b/test/samples/Fmod/fmod.py index f60108f6c..88dbe3fa3 100644 --- a/test/samples/Fmod/fmod.py +++ b/test/samples/Fmod/fmod.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Fmods/fmods.py b/test/samples/Fmods/fmods.py index 95859d7fe..6572582a1 100644 --- a/test/samples/Fmods/fmods.py +++ b/test/samples/Fmods/fmods.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Fmods/tfmods_scalar_type_invalid.py b/test/samples/Fmods/tfmods_scalar_type_invalid.py index d76fa2adf..40a8c5243 100644 --- a/test/samples/Fmods/tfmods_scalar_type_invalid.py +++ b/test/samples/Fmods/tfmods_scalar_type_invalid.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IntegerType diff --git a/test/samples/Rowexpandexpdif/rowexpandexpdif.py b/test/samples/Rowexpandexpdif/rowexpandexpdif.py index 878dc69f3..062ee8b5f 100644 --- a/test/samples/Rowexpandexpdif/rowexpandexpdif.py +++ b/test/samples/Rowexpandexpdif/rowexpandexpdif.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Rowexpandmax/rowexpandmax.py b/test/samples/Rowexpandmax/rowexpandmax.py index 6efd10f9c..5c086ce24 100644 --- a/test/samples/Rowexpandmax/rowexpandmax.py +++ b/test/samples/Rowexpandmax/rowexpandmax.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py b/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py index 1bc752dff..a0d7e5a05 100644 --- a/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py +++ b/test/samples/Rowexpandmax/rowexpandmax_a5_tmp_invalid.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr from mlir.dialects import func, pto from mlir.ir import F32Type diff --git a/test/samples/Rowexpandmin/rowexpandmin.py b/test/samples/Rowexpandmin/rowexpandmin.py index 260b95e70..1e89234c6 100644 --- a/test/samples/Rowexpandmin/rowexpandmin.py +++ b/test/samples/Rowexpandmin/rowexpandmin.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Rowprod/rowprod.py b/test/samples/Rowprod/rowprod.py index be8a4cadd..f4b01cd9c 100644 --- a/test/samples/Rowprod/rowprod.py +++ b/test/samples/Rowprod/rowprod.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, arith, pto from mlir.ir import F32Type, IndexType diff --git a/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py b/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py index 3a787e769..5a7a9f61b 100644 --- a/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py +++ b/test/samples/Rowprod/trowprod_tmp_mismatch_invalid.py @@ -1,3 +1,11 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import func, pto from mlir.ir import F32Type diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 5f43093cc..d8bb58a00 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. + set -uo pipefail # 注意:去掉 -e,避免失败直接退出整个脚本 BASE_DIR="$(cd -- "$(dirname -- "$0")" && pwd)" From 6cb0a68fe1e00cfdc5d9049de56b456523af1d42 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Thu, 2 Apr 2026 14:57:24 +0800 Subject: [PATCH 3/4] fix(ptobc): allow optional tmp row-expand v0 encoding --- tools/ptobc/generated/ptobc_opcodes_v0.h | 6 +++--- tools/ptobc/testdata/recent_ops_v0_roundtrip.pto | 12 ++++++++++++ tools/ptobc/tests/recent_ops_v0_encode.sh | 4 ++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index 4885ca380..fbca6e60d 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -108,9 +108,9 @@ inline constexpr OpInfo kOpTable[] = { {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1050, "pto.trowexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1051, "pto.trowexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, diff --git a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto index 3a5f508de..4991431ac 100644 --- a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto +++ b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto @@ -10,8 +10,20 @@ module { %tmp = pto.alloc_tile : !pto.tile_buf %dst0 = pto.alloc_tile : !pto.tile_buf %dst1 = 
pto.alloc_tile : !pto.tile_buf + %dst2 = pto.alloc_tile : !pto.tile_buf + %dst3 = pto.alloc_tile : !pto.tile_buf + %dst4 = pto.alloc_tile : !pto.tile_buf + %dst5 = pto.alloc_tile : !pto.tile_buf + %dst6 = pto.alloc_tile : !pto.tile_buf + %dst7 = pto.alloc_tile : !pto.tile_buf pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) pto.trowexpandmul ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) + pto.trowexpandexpdif ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst3 : !pto.tile_buf) + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst4 : !pto.tile_buf) + pto.trowexpandmax ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst5 : !pto.tile_buf) + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst6 : !pto.tile_buf) + pto.trowexpandmin ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst7 : !pto.tile_buf) return } } diff --git a/tools/ptobc/tests/recent_ops_v0_encode.sh b/tools/ptobc/tests/recent_ops_v0_encode.sh index c4f524a7a..a01f778bc 100755 --- a/tools/ptobc/tests/recent_ops_v0_encode.sh +++ b/tools/ptobc/tests/recent_ops_v0_encode.sh @@ -27,3 +27,7 @@ grep -F "pto.subset " "${ROUNDTRIP}" >/dev/null grep -F "pto.tprint ins(" "${ROUNDTRIP}" >/dev/null grep -F "pto.trowexpanddiv ins(" "${ROUNDTRIP}" >/dev/null grep -F "pto.trowexpandmul ins(" "${ROUNDTRIP}" >/dev/null + +[[ $(grep -Fc "pto.trowexpandexpdif ins(" "${ROUNDTRIP}") -eq 2 ]] +[[ $(grep -Fc "pto.trowexpandmax ins(" "${ROUNDTRIP}") -eq 2 ]] +[[ $(grep -Fc "pto.trowexpandmin ins(" "${ROUNDTRIP}") -eq 2 ]] From 5dd7daa7bbeac094da1578bf0ee122d2f482d6ee Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Thu, 2 Apr 2026 15:15:02 +0800 Subject: [PATCH 4/4] chore(ci): add 
missing license headers --- tools/ptobc/generated/ptobc_opcodes_v0.h | 8 ++++++++ tools/ptobc/tests/recent_ops_v0_encode.sh | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index fbca6e60d..6ce728bb2 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -1,3 +1,11 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// // Generated by docs/bytecode/tools/gen_v0_tables.py #pragma once diff --git a/tools/ptobc/tests/recent_ops_v0_encode.sh b/tools/ptobc/tests/recent_ops_v0_encode.sh index a01f778bc..15ac38db2 100755 --- a/tools/ptobc/tests/recent_ops_v0_encode.sh +++ b/tools/ptobc/tests/recent_ops_v0_encode.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. + set -euo pipefail PTOBC_BIN=${PTOBC_BIN:-}