diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb4749ec..7f1445ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -261,7 +261,7 @@ jobs: # Temporary CI gate: skip cases that still error/flap on the remote NPU. # Update this list as we fix the underlying issues. DEFAULT_SKIP_CASES: >- - mix_kernel,vadd_validshape,vadd_validshape_dynamic,print,storefp + mix_kernel,vadd_validshape,vadd_validshape_dynamic,print,storefp,Gemvmx steps: - name: Resolve validation parameters shell: bash diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index 5f13ab32..c5c2630b 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -1400,6 +1400,95 @@ pto.tgemv.bias ins(%a, %b, %bias : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto. --- +##### `pto.tgemv.mx` - Mixed-Precision Matrix-Vector Multiply + +**Summary:** Mixed-precision GEMV with explicit A/B scaling tiles. + +**Semantics:** + +``` +dst = gemv(a, b) // quantization/mixed-precision behavior is target-defined +``` + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `a` | `pto.tile_buf` | Matrix tile (`loc=left`) | +| `a_scale` | `pto.tile_buf` | Scale tile associated with `a` | +| `b` | `pto.tile_buf` | Vector tile (`loc=right`) | +| `b_scale` | `pto.tile_buf` | Scale tile associated with `b` | +| `dst` | `pto.tile_buf` | Destination accumulator tile (`loc=acc`) | + +**Results:** None. Writes into `dst` via DPS pattern. + +**Constraints & Verification:** + +- `a/b/dst` reuse the same GEMV shape/location checks as `pto.tgemv`. +- `a_scale` and `b_scale` must be tile buffers in the scaling address space, with the same rank, shape, and valid shape as `a` and `b` respectively; the op is only supported on A5 targets, where `a` and `b` must use fp8 element types and `dst` must use f32. 
+ +**Hardware Mapping:** + +- Executes on the **Matrix pipeline** (`PIPE_M`) + +**Basic Example:** + +```mlir +pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%c : !pto.tile_buf<...>) +``` + +--- + +##### `pto.tgemv.mx.acc` - Mixed-Precision GEMV with Accumulation + +**Summary:** Mixed-precision GEMV accumulation form using scale tiles. + +**Semantics:** + +``` +dst = c_in + gemv(a, b) +``` + +**Arguments:** `c_in, a, a_scale, b, b_scale, dst` + +**Hardware Mapping:** Matrix pipeline (`PIPE_M`) + +**Basic Example:** + +```mlir +pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%c_out : !pto.tile_buf<...>) +``` + +--- + +##### `pto.tgemv.mx.bias` - Mixed-Precision GEMV with Bias + +**Summary:** Mixed-precision GEMV bias form using scale tiles. + +**Semantics:** + +``` +dst = gemv(a, b) + bias +``` + +**Arguments:** `a, a_scale, b, b_scale, bias, dst` + +**Hardware Mapping:** Matrix pipeline (`PIPE_M`) + +**Basic Example:** + +```mlir +pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%c : !pto.tile_buf<...>) +``` + +--- + ### 4.5 Vector Arithmetic Operations All vector arithmetic operations execute on the **Vector pipeline** (`PIPE_V`) and use `ins`/`outs` with tile buffers in the **VEC (UB)** memory space. 
diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 28fe6a0e..fc392bd4 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -918,6 +918,107 @@ def TGemvBiasOp : PTO_TOp<"tgemv.bias", [ }]; } +def TGemvMxOp : PTO_TOp<"tgemv.mx", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Mixed-precision GEMV with scale tiles (tile world, ins/outs)."; + + let arguments = (ins + PTODpsType:$a, + PTODpsType:$a_scale, + PTODpsType:$b, + PTODpsType:$b_scale, + PTODpsType:$dst + ); + + let results = (outs Optional:$result); + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $a `,` $a_scale `,` $b `,` $b_scale + `:` type($a) `,` type($a_scale) `,` type($b) `,` type($b_scale) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + (`->` qualified(type($result))^)? + }]; + + let extraClassDeclaration = [{ + static StringRef getIntrinsicName() { return "TGEMV_MX"; } + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_M; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + +def TGemvMxAccOp : PTO_TOp<"tgemv.mx.acc", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Mixed-precision GEMV accumulate with scale tiles (tile world, ins/outs)."; + + let arguments = (ins + PTODpsType:$c_in, + PTODpsType:$a, + PTODpsType:$a_scale, + PTODpsType:$b, + PTODpsType:$b_scale, + PTODpsType:$dst + ); + + let results = (outs Optional:$result); + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $c_in `,` $a `,` $a_scale `,` $b `,` $b_scale + `:` type($c_in) `,` type($a) `,` type($a_scale) `,` type($b) `,` type($b_scale) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + (`->` qualified(type($result))^)? 
+ }]; + + let extraClassDeclaration = [{ + static StringRef getIntrinsicName() { return "TGEMV_MX"; } + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_M; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + +def TGemvMxBiasOp : PTO_TOp<"tgemv.mx.bias", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Mixed-precision GEMV with bias and scale tiles (tile world, ins/outs)."; + + let arguments = (ins + PTODpsType:$a, + PTODpsType:$a_scale, + PTODpsType:$b, + PTODpsType:$b_scale, + PTODpsType:$bias, + PTODpsType:$dst + ); + + let results = (outs Optional:$result); + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $a `,` $a_scale `,` $b `,` $b_scale `,` $bias + `:` type($a) `,` type($a_scale) `,` type($b) `,` type($b_scale) `,` qualified(type($bias)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + (`->` qualified(type($result))^)? + }]; + + let extraClassDeclaration = [{ + static StringRef getIntrinsicName() { return "TGEMV_MX"; } + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_M; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + def TMovOp : PTO_TOp<"tmov", [ PTO_DpsInitOpInterface, OpPipeInterface, diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 51311a25..0c9bb7b5 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -1936,6 +1936,45 @@ static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type r return success(); } +static LogicalResult verifyScaleTileMatchesOperand(Operation *op, Type scaleTy, + Type operandTy, + StringRef scaleName, + StringRef operandName) { + if (failed(verifyTileBufCommon(op, scaleTy, scaleName))) + return failure(); + auto scaleSpace = getPTOMemorySpaceEnum(scaleTy); + if (!scaleSpace || *scaleSpace != pto::AddressSpace::SCALING) + return op->emitOpError() << "expects " << scaleName + << " to be in the scaling address 
space"; + + auto scaleShape = getShapeVec(scaleTy); + auto operandShape = getShapeVec(operandTy); + if (scaleShape.size() != operandShape.size()) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same rank"; + for (size_t i = 0; i < scaleShape.size(); ++i) { + if (scaleShape[i] != ShapedType::kDynamic && + operandShape[i] != ShapedType::kDynamic && + scaleShape[i] != operandShape[i]) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same shape"; + } + + auto scaleValid = getValidShapeVec(scaleTy); + auto operandValid = getValidShapeVec(operandTy); + if (scaleValid.size() != operandValid.size()) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same valid_shape"; + for (size_t i = 0; i < scaleValid.size(); ++i) { + if (scaleValid[i] != ShapedType::kDynamic && + operandValid[i] != ShapedType::kDynamic && + scaleValid[i] != operandValid[i]) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same valid_shape"; + } + return success(); +} + static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, Type src1Ty, Type dstTy) { auto src0Valid = getValidShapeVec(src0Ty); @@ -3606,6 +3645,29 @@ static bool isA5Fp8LikeType(Type ty) { return false; } +static bool isA5MxInputType(Type ty) { + return isA5Fp8LikeType(ty); +} + +static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy, StringRef lhsName, + StringRef rhsName, StringRef dstName) { + Type lhsElem = getElemTy(lhsTy); + Type rhsElem = getElemTy(rhsTy); + Type dstElem = getElemTy(dstTy); + + if (!isA5MxInputType(lhsElem) || !isA5MxInputType(rhsElem)) + return op->emitOpError() + << "expects A5 mx operands " << lhsName << " and " << rhsName + << " to use fp8 element types"; + + if (!dstElem.isF32()) + return op->emitOpError() + << "expects A5 mx result " << dstName << " to use 
f32 element type"; + + return success(); +} + static bool isA5VectorPreQuantTypePair(Type srcElem, Type dstElem) { if (srcElem.isF32()) return dstElem.isInteger(8) || isA5Fp8LikeType(dstElem) || dstElem.isF16() || @@ -4665,6 +4727,87 @@ LogicalResult TGemvBiasOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +LogicalResult TGemvMxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tgemv.mx is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxAccOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tgemv.mx.acc is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || + failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, 
getCIn().getType(), + getDst().getType(), "c_in", "dst")) || + failed(verifyTileBufSameValidShape(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tgemv.mx.bias is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || + failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), + /*requireFloatBias=*/true))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + auto biasShape = getShapeVec(getBias().getType()); + auto dstShape = getShapeVec(getDst().getType()); + if (biasShape.size() != 2 || dstShape.size() != 2) + return emitOpError("expects bias and dst to be rank-2 for tgemv.mx.bias"); + if (biasShape[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && + biasShape[1] != dstShape[1]) + return emitOpError("expects bias and dst to have the same column shape"); + if (failed(verifyTileBufSameValidShape(*this, getBias().getType(), + getDst().getType(), "bias", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + LogicalResult TMatmulBiasOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { if (failed(verifyMatTileOperands(*this, 
getA().getType(), getB().getType(), @@ -4690,7 +4833,12 @@ LogicalResult TMatmulMxOp::verify() { return verifyMatmulLike(*this, getA().getType(), getB().getType(), getDst().getType()); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"); + }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } @@ -4702,7 +4850,19 @@ LogicalResult TMatmulMxAccOp::verify() { return failure(); return success(); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst")) || + failed(verifyTileBufSameValidShape(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst"))) + return failure(); + return success(); + }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } LogicalResult TMatmulMxBiasOp::verify() { @@ -4717,7 +4877,12 @@ LogicalResult TMatmulMxBiasOp::verify() { return verifyMatmulLike(*this, getA().getType(), getB().getType(), getDst().getType()); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"); + }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } // ---- TSetValOp ---- @@ -8618,6 +8783,38 @@ void TGemvBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, 
&getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxAccOp === +// Read: c_in, a, a_scale, b, b_scale, Write: dst +void TGemvMxAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxBiasOp === +// Read: a, a_scale, b, b_scale, bias, Write: dst +void TGemvMxBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + // === TMatmulOp === void TMatmulMxOp::getEffects(SmallVectorImpl> &effects) { addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 492e8253..024eec0f 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -152,6 +152,51 @@ static std::string layoutToEmitCString(mlir::pto::Layout layout) { return "pto::Layout::ND"; } +static std::string getEmitCScalarTypeToken(Type elemTy) { + if (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ()) 
+ return "float8_e4m3_t"; + if (elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ()) + return "float8_e5m2_t"; + if (elemTy.isF16()) + return "half"; + if (elemTy.isBF16()) + return "bfloat16_t"; + if (elemTy.isF32()) + return "float"; + if (elemTy.isF64()) + return "double"; + if (elemTy.isInteger(8)) + return (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) ? "int8_t" + : "uint8_t"; + if (elemTy.isInteger(16)) + return (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) + ? "int16_t" + : "uint16_t"; + if (elemTy.isInteger(32)) + return (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) + ? "int32_t" + : "uint32_t"; + if (elemTy.isInteger(64)) + return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; + return "float"; +} + +static int64_t getEmitCScalarByteWidth(Type elemTy) { + if (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ() || + elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ() || + elemTy.isInteger(8)) + return 1; + if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) + return 2; + if (elemTy.isF32() || elemTy.isInteger(32)) + return 4; + if (elemTy.isF64() || elemTy.isInteger(64)) + return 8; + return 4; +} + //===----------------------------------------------------------------------===// // Type Converter //===----------------------------------------------------------------------===// @@ -163,6 +208,11 @@ class PTOToEmitCTypeConverter : public TypeConverter { // 1. 
基本类型 (f32, i32, index) // --------------------------------------------------------- addConversion([Ctx](FloatType type) -> Type { + if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + return emitc::OpaqueType::get(Ctx, "float8_e4m3_t"); + if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + return emitc::OpaqueType::get(Ctx, "float8_e5m2_t"); if (type.isF32()) return emitc::OpaqueType::get(Ctx, "float"); if (type.isF16()) return emitc::OpaqueType::get(Ctx, "half"); if (type.isBF16()) return emitc::OpaqueType::get(Ctx, "bfloat16_t"); @@ -3027,36 +3077,6 @@ struct SubviewToEmitCPattern : public OpConversionPattern { } }; -//===----------------------------------------------------------------------===// -// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) -//===----------------------------------------------------------------------===// - -static std::string getElemTypeStringForGT(Type elemTy) { - if (elemTy.isF16()) return "half"; - if (elemTy.isBF16()) return "bfloat16_t"; - if (elemTy.isF32()) return "float"; - if (elemTy.isF64()) return "double"; - if (elemTy.isInteger(8)) { - if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) - return "int8_t"; - return "uint8_t"; - } - if (elemTy.isInteger(16)) { - if (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) - return "int16_t"; - return "uint16_t"; - } - if (elemTy.isInteger(32)) { - if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) - return "int32_t"; - return "uint32_t"; - } - if (elemTy.isInteger(64)) { - return cast(elemTy).isUnsigned() ? 
"uint64_t" : "int64_t"; - } - return "float"; -} - static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, Location loc, Value basePtr, MemRefType mrTy, @@ -3104,7 +3124,7 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, std::string strideTypeName = "GTStride" + suffix; std::string gtTypeName = "GT" + suffix; - std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); + std::string elemTypeStr = getEmitCScalarTypeToken(mrTy.getElementType()); SmallVector shapeParamsVec; SmallVector strideParamsVec; @@ -3320,14 +3340,7 @@ struct PointerCastConversion : public OpConversionPattern { TileRole role = inferRole(op); // 2. 类型字符串生成 (elemTypeStr, dimStr) - std::string elemTypeStr = "T"; - if (elemType.isF16()) elemTypeStr = "half"; - else if (elemType.isBF16()) elemTypeStr = "bfloat16_t"; - else if (elemType.isF32()) elemTypeStr = "float"; - else if (elemType.isInteger(8)) elemTypeStr = cast(elemType).isUnsigned() ? "uint8_t" : "int8_t"; - else if (elemType.isInteger(16)) elemTypeStr = cast(elemType).isUnsigned() ? "uint16_t" : "int16_t"; - else if (elemType.isInteger(32)) elemTypeStr = cast(elemType).isUnsigned() ? "uint32_t" : "int32_t"; - else if (elemType.isInteger(64)) elemTypeStr = cast(elemType).isUnsigned() ? "uint64_t" : "int64_t"; + std::string elemTypeStr = getEmitCScalarTypeToken(elemType); std::string dimStr; auto dimToString = [](int64_t dim, const char* symbol) -> std::string { @@ -5006,30 +5019,9 @@ struct ReinterpretCastToEmitC : public OpConversionPattern(elemTy).isUnsigned() ? "uint8_t" : "int8_t"; - else if (elemTy.isInteger(16)) - elemBytes = 2, - elemTok = cast(elemTy).isUnsigned() ? "uint16_t" : "int16_t"; - else if (elemTy.isInteger(32)) - elemBytes = 4, - elemTok = cast(elemTy).isUnsigned() ? "uint32_t" : "int32_t"; - else if (elemTy.isInteger(64)) - elemBytes = 8, - elemTok = cast(elemTy).isUnsigned() ? 
"uint64_t" : "int64_t"; + std::string elemTok = getEmitCScalarTypeToken(elemTy); + int64_t elemBytes = getEmitCScalarByteWidth(elemTy); // Tile role. const char *roleTok = "TileType::Vec"; @@ -6912,6 +6904,18 @@ static void replaceOrEraseWithOpaqueCall(Operation *op, rewriter.replaceOp(op, call.getResults()); } +static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, + StringRef callee, + ArrayRef args, + ConversionPatternRewriter &rewriter) { + rewriter.create( + op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + if (op->getNumResults() == 1) + rewriter.replaceOp(op, dst); + else + rewriter.eraseOp(op); +} + // ---------- TOp ---------- struct PTOTGemvBiasToTGEMV_BIAS : public OpConversionPattern { @@ -6930,6 +6934,62 @@ struct PTOTGemvBiasToTGEMV_BIAS } }; +struct PTOTGemvMXToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTGemvMXAccToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value cIn = peelUnrealized(adaptor.getCIn()); + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + 
+ replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, cIn, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTGemvMXBiasToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, a, aScale, b, bScale, bias}, rewriter); + return success(); + } +}; + struct PTOTMatmulBiasToTMATMUL_BIAS : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -7929,33 +7989,7 @@ struct PTOBindTileToEmitC : public OpConversionPattern { }; auto emitElemTypeToString = [&](Type elemTy) -> std::string { - if (elemTy.isF16()) - return "half"; - if (elemTy.isBF16()) - return "bfloat16_t"; - if (elemTy.isF32()) - return "float"; - if (elemTy.isF64()) - return "double"; - if (elemTy.isInteger(8)) { - if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) - return "int8_t"; - return "uint8_t"; - } - if (elemTy.isInteger(16)) { - if (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) - return "int16_t"; - return "uint16_t"; - } - if (elemTy.isInteger(32)) { - if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) - return "int32_t"; - return "uint32_t"; - } - if (elemTy.isInteger(64)) { - return cast(elemTy).isUnsigned() ? 
"uint64_t" : "int64_t"; - } - return "float"; + return getEmitCScalarTypeToken(elemTy); }; auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { @@ -8841,6 +8875,9 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, PTOTMatmulMXAccToTMATMUL_MX_ACC, PTOTMatmulMXBiasToTMATMUL_MX_BIAS, PTOTGemvBiasToTGEMV_BIAS, + PTOTGemvMXToTGEMV_MX, + PTOTGemvMXAccToTGEMV_MX, + PTOTGemvMXBiasToTGEMV_MX, PTOBarrierToEmitC >(typeConverter, ctx); diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index eaf3cf0e..60ac7454 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1509,6 +1509,41 @@ struct PTOViewToMemrefPass op->getOperand(0), op->getOperand(1), op->getOperand(2), op->getOperand(3)); } + // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- + SmallVector gemvMxs; + func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); + for (auto op : gemvMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), op->getOperand(2), op->getOperand(3), op->getOperand(4)); + } + + // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- + SmallVector gemvMxAccs; + func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); + for (auto op : gemvMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), op->getOperand(2), + op->getOperand(3), op->getOperand(4), op->getOperand(5)); + } + + // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- + SmallVector gemvMxBiass; + func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); + for (auto op : gemvMxBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), op->getOperand(2), + 
op->getOperand(3), op->getOperand(4), op->getOperand(5)); + } + // --- TMovOp [Src, Dst] --- SmallVector movs; func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); diff --git a/test/basic/tgemv_mx_emitc.pto b/test/basic/tgemv_mx_emitc.pto new file mode 100644 index 00000000..66077e63 --- /dev/null +++ b/test/basic/tgemv_mx_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s | FileCheck %s + +module { + func.func @tgemv_mx_emitc() { + %a = pto.alloc_tile : !pto.tile_buf + %a_scale = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + + return + } +} + +// CHECK-LABEL: __global__ AICORE void tgemv_mx_emitc() +// CHECK: TGEMV_MX( diff --git a/test/basic/tgemv_mx_variants_emitc.pto b/test/basic/tgemv_mx_variants_emitc.pto new file mode 100644 index 00000000..6c2551e4 --- /dev/null +++ b/test/basic/tgemv_mx_variants_emitc.pto @@ -0,0 +1,23 @@ +// RUN: ptoas --pto-arch=a5 %s | FileCheck %s + +module { + func.func @tgemv_mx_variants_emitc() { + %a = pto.alloc_tile : !pto.tile_buf + %a_scale = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf + %bias = pto.alloc_tile : !pto.tile_buf + %dst0 = pto.alloc_tile : !pto.tile_buf + %dst1 = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) + pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) + + return + } +} + +// CHECK-LABEL: __global__ AICORE void tgemv_mx_variants_emitc() +// CHECK: TGEMV_MX( +// CHECK: TGEMV_MX( diff --git 
a/test/npu_validation/scripts/run_remote_npu_validation.sh b/test/npu_validation/scripts/run_remote_npu_validation.sh index c86fc6ab..e25ad352 100644 --- a/test/npu_validation/scripts/run_remote_npu_validation.sh +++ b/test/npu_validation/scripts/run_remote_npu_validation.sh @@ -241,6 +241,21 @@ while IFS= read -r -d '' cpp; do log "SKIP: ${testcase} (SKIP_CASES)" continue fi + if [[ "${testcase}" == "gemvmx" ]]; then + soc_lc="$(printf '%s' "${SOC_VERSION:-}" | tr '[:upper:]' '[:lower:]')" + if [[ "$soc_lc" != *"a5"* && "$soc_lc" != *"950"* ]]; then + skip_count=$((skip_count + 1)) + printf "%s\tSKIP\t%s\trequires A5 (set SOC_VERSION to A5/950)\n" "${testcase}" "${STAGE}" >> "${RESULTS_TSV}" + log "SKIP: ${testcase} (requires A5 SOC_VERSION)" + continue + fi + if [[ "${PTOAS_BOARD_IS_A3:-0}" == "1" ]]; then + skip_count=$((skip_count + 1)) + printf "%s\tSKIP\t%s\trequires A5 board\n" "${testcase}" "${STAGE}" >> "${RESULTS_TSV}" + log "SKIP: ${testcase} (requires A5 board)" + continue + fi + fi echo log "=== CASE: ${cpp} ===" diff --git a/test/samples/Gemvmx/gemvmx.py b/test/samples/Gemvmx/gemvmx.py new file mode 100644 index 00000000..8c639f23 --- /dev/null +++ b/test/samples/Gemvmx/gemvmx.py @@ -0,0 +1,211 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr +from mlir.dialects import func, arith, pto +from mlir.ir import F16Type, F32Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") + + f16 = F16Type.get(ctx) + f32 = F32Type.get(ctx) + ptr_f16 = pto.PtrType.get(f16, ctx) + ptr_f32 = pto.PtrType.get(f32, ctx) + + # TGEMV_MX family: + # TGEMV_MX(dst, a, a_scale, b, b_scale) + # TGEMV_MX(dst, c_in, a, a_scale, b, b_scale) + # TGEMV_MX(dst, a, a_scale, b, b_scale, bias) + M = 1 + M_ALIGN = M + K = 128 + N = 16 + + tv2_f16 = pto.TensorViewType.get(2, f16, ctx) + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_a = pto.PartitionTensorViewType.get([M, K], f16, ctx) + tile_view_b = pto.PartitionTensorViewType.get([K, N], f16, ctx) + tile_view_as = pto.PartitionTensorViewType.get([M, K], f16, ctx) + tile_view_bs = pto.PartitionTensorViewType.get([K, N], f16, ctx) + tile_view_c = pto.PartitionTensorViewType.get([M, N], f32, ctx) + tile_view_bias = pto.PartitionTensorViewType.get([M, N], f32, ctx) + + mat = pto.AddressSpaceAttr.get(pto.AddressSpace.MAT, ctx) + left = pto.AddressSpaceAttr.get(pto.AddressSpace.LEFT, ctx) + right = pto.AddressSpaceAttr.get(pto.AddressSpace.RIGHT, ctx) + scaling = pto.AddressSpaceAttr.get(pto.AddressSpace.SCALING, ctx) + acc = pto.AddressSpaceAttr.get(pto.AddressSpace.ACC, ctx) + bias_space = pto.AddressSpaceAttr.get(pto.AddressSpace.BIAS, ctx) + + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + cfg_a_mat = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx), + pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx), + pto.TileConfig.fractalABSize, + pd, + ctx, + ) + cfg_b_mat = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx), + pto.SLayoutAttr.get(pto.SLayout.RowMajor, ctx), + 
pto.TileConfig.fractalABSize, + pd, + ctx, + ) + cfg_left = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx), + pto.SLayoutAttr.get(pto.SLayout.RowMajor, ctx), + pto.TileConfig.fractalABSize, + pd, + ctx, + ) + cfg_right = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx), + pto.SLayoutAttr.get(pto.SLayout.ColMajor, ctx), + pto.TileConfig.fractalABSize, + pd, + ctx, + ) + cfg_scaling = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx), + pto.SLayoutAttr.get(pto.SLayout.RowMajor, ctx), + pto.TileConfig.fractalABSize, + pd, + ctx, + ) + cfg_acc = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx), + pto.SLayoutAttr.get(pto.SLayout.RowMajor, ctx), + pto.TileConfig.fractalCSize, + pd, + ctx, + ) + cfg_bias = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx), + pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx), + pto.TileConfig.fractalABSize, + pd, + ctx, + ) + + tile_buf_a_mat = pto.TileBufType.get([M, K], f16, mat, [M, K], cfg_a_mat, ctx) + tile_buf_b_mat = pto.TileBufType.get([K, N], f16, mat, [K, N], cfg_b_mat, ctx) + tile_buf_as_mat = pto.TileBufType.get([M, K], f16, mat, [M, K], cfg_a_mat, ctx) + tile_buf_bs_mat = pto.TileBufType.get([K, N], f16, mat, [K, N], cfg_b_mat, ctx) + tile_buf_bias_mat = pto.TileBufType.get([M, N], f32, mat, [M, N], cfg_bias, ctx) + + tile_buf_a = pto.TileBufType.get([M, K], f16, left, [M, K], cfg_left, ctx) + tile_buf_b = pto.TileBufType.get([K, N], f16, right, [K, N], cfg_right, ctx) + tile_buf_as = pto.TileBufType.get([M, K], f16, scaling, [M, K], cfg_scaling, ctx) + tile_buf_bs = pto.TileBufType.get([K, N], f16, scaling, [K, N], cfg_scaling, ctx) + tile_buf_bias = pto.TileBufType.get([M, N], f32, bias_space, [M, N], cfg_bias, ctx) + tile_buf_c = pto.TileBufType.get([M_ALIGN, N], f32, acc, [M, N], cfg_acc, ctx) + + fn_ty = func.FunctionType.get( + [ + ptr_f16, # a + ptr_f16, # b + ptr_f16, # a_scale + ptr_f16, # 
b_scale + ptr_f32, # bias (for mx.bias) + ptr_f32, # out_mx + ptr_f32, # out_mx_acc + ptr_f32, # out_mx_bias + ], + [], + ) + with InsertionPoint(m.body): + fn = func.FuncOp("gemvmx_kernel", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(IndexType.get(ctx), 0).result + c1 = arith.ConstantOp(IndexType.get(ctx), 1).result + cM = arith.ConstantOp(IndexType.get(ctx), M).result + cK = arith.ConstantOp(IndexType.get(ctx), K).result + cN = arith.ConstantOp(IndexType.get(ctx), N).result + + ( + arg_a, + arg_b, + arg_as, + arg_bs, + arg_bias, + arg_out_mx, + arg_out_mx_acc, + arg_out_mx_bias, + ) = entry.arguments + + tv_a = pto.MakeTensorViewOp(tv2_f16, arg_a, [cM, cK], [cK, c1]).result + tv_b = pto.MakeTensorViewOp(tv2_f16, arg_b, [cK, cN], [cN, c1]).result + tv_as = pto.MakeTensorViewOp(tv2_f16, arg_as, [cM, cK], [cK, c1]).result + tv_bs = pto.MakeTensorViewOp(tv2_f16, arg_bs, [cK, cN], [cN, c1]).result + tv_bias = pto.MakeTensorViewOp(tv2_f32, arg_bias, [cM, cN], [cN, c1]).result + tv_out_mx = pto.MakeTensorViewOp(tv2_f32, arg_out_mx, [cM, cN], [cN, c1]).result + tv_out_mx_acc = pto.MakeTensorViewOp(tv2_f32, arg_out_mx_acc, [cM, cN], [cN, c1]).result + tv_out_mx_bias = pto.MakeTensorViewOp(tv2_f32, arg_out_mx_bias, [cM, cN], [cN, c1]).result + + sv_a = pto.PartitionViewOp(tile_view_a, tv_a, offsets=[c0, c0], sizes=[cM, cK]).result + sv_b = pto.PartitionViewOp(tile_view_b, tv_b, offsets=[c0, c0], sizes=[cK, cN]).result + sv_as = pto.PartitionViewOp(tile_view_as, tv_as, offsets=[c0, c0], sizes=[cM, cK]).result + sv_bs = pto.PartitionViewOp(tile_view_bs, tv_bs, offsets=[c0, c0], sizes=[cK, cN]).result + sv_bias = pto.PartitionViewOp(tile_view_bias, tv_bias, offsets=[c0, c0], sizes=[cM, cN]).result + sv_out_mx = pto.PartitionViewOp(tile_view_c, tv_out_mx, offsets=[c0, c0], sizes=[cM, cN]).result + sv_out_mx_acc = pto.PartitionViewOp(tile_view_c, tv_out_mx_acc, offsets=[c0, c0], sizes=[cM, cN]).result + sv_out_mx_bias = 
pto.PartitionViewOp(tile_view_c, tv_out_mx_bias, offsets=[c0, c0], sizes=[cM, cN]).result + + a_mat = pto.AllocTileOp(tile_buf_a_mat).result + b_mat = pto.AllocTileOp(tile_buf_b_mat).result + as_mat = pto.AllocTileOp(tile_buf_as_mat).result + bs_mat = pto.AllocTileOp(tile_buf_bs_mat).result + bias_mat = pto.AllocTileOp(tile_buf_bias_mat).result + a_tile = pto.AllocTileOp(tile_buf_a).result + b_tile = pto.AllocTileOp(tile_buf_b).result + as_tile = pto.AllocTileOp(tile_buf_as).result + bs_tile = pto.AllocTileOp(tile_buf_bs).result + bias_tile = pto.AllocTileOp(tile_buf_bias).result + c_mx_tile = pto.AllocTileOp(tile_buf_c).result + c_mx_acc_tile = pto.AllocTileOp(tile_buf_c).result + c_mx_bias_tile = pto.AllocTileOp(tile_buf_c).result + + pto.TLoadOp(None, sv_a, a_mat) + pto.TLoadOp(None, sv_b, b_mat) + pto.TLoadOp(None, sv_as, as_mat) + pto.TLoadOp(None, sv_bs, bs_mat) + pto.TLoadOp(None, sv_bias, bias_mat) + + pto.TExtractOp(a_mat, c0, c0, a_tile) + pto.TMovOp(None, b_mat, b_tile) + pto.TMovOp(None, as_mat, as_tile) + pto.TMovOp(None, bs_mat, bs_tile) + pto.TMovOp(None, bias_mat, bias_tile) + + pto.TGemvMxOp(None, a_tile, as_tile, b_tile, bs_tile, c_mx_tile) + pto.TGemvMxAccOp(None, c_mx_tile, a_tile, as_tile, b_tile, bs_tile, c_mx_acc_tile) + pto.TGemvMxBiasOp(None, a_tile, as_tile, b_tile, bs_tile, bias_tile, c_mx_bias_tile) + + pto.TStoreOp(None, c_mx_tile, sv_out_mx) + pto.TStoreOp(None, c_mx_acc_tile, sv_out_mx_acc) + pto.TStoreOp(None, c_mx_bias_tile, sv_out_mx_bias) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index bd992aea..1bd7f838 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -235,7 +235,6 @@ process_one_dir() { continue fi fi - # Inter-core sync regression samples are arch-specific. 
if [[ "$base" == "test_intercore_sync_a5" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" != "a5" ]]; then echo -e "${A}(${base}.py)\tSKIP\trequires --pto-arch=a5" @@ -253,6 +252,10 @@ process_one_dir() { echo -e "${A}(${base}.py)\tSKIP\trequires --pto-arch=a5" continue fi + if [[ "$base" == "gemvmx" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" != "a5" ]]; then + echo -e "${A}(${base}.py)\tSKIP\trequires --pto-arch=a5" + continue + fi if [[ "$base" == "test_intercore_sync_a3" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" != "a3" ]]; then echo -e "${A}(${base}.py)\tSKIP\trequires --pto-arch=a3" continue diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e7034eca..29170aca 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -940,6 +940,15 @@ int main(int argc, char **argv) { llvm::cl::SetVersionPrinter(printPTOASVersion); + bool cliArchSpecified = false; + for (int i = 1; i < argc; ++i) { + llvm::StringRef arg(argv[i]); + if (arg == "--pto-arch" || arg.starts_with("--pto-arch=")) { + cliArchSpecified = true; + break; + } + } + // Parse command line options llvm::cl::ParseCommandLineOptions(argc, argv, "PTO Assembler (ptoas)\n"); @@ -964,18 +973,37 @@ int main(int argc, char **argv) { context.getOrLoadDialect(); context.getOrLoadDialect(); - std::string arch = ptoTargetArch; - for (char &c : arch) - c = static_cast(std::tolower(static_cast(c))); - if (arch != "a3" && arch != "a5") { - llvm::errs() << "Error: invalid --pto-arch='" << ptoTargetArch - << "'. 
Expected 'a3' or 'a5'.\n"; - return 1; - } - OwningOpRef module; llvm::StringRef buf = (*fileOrErr)->getBuffer(); const bool isPTOBC = (buf.size() >= 6 && std::memcmp(buf.data(), "PTOBC\0", 6) == 0); + auto normalizeArch = [](llvm::StringRef archValue) { + std::string normalized = archValue.str(); + for (char &c : normalized) + c = static_cast(std::tolower(static_cast(c))); + return normalized; + }; + auto detectTextualModuleArch = [&](llvm::StringRef text) -> std::optional { + llvm::SmallVector matches; + llvm::Regex archRegex( + R"ptoarch("?(pto\.target_arch)"?[[:space:]]*=[[:space:]]*"([[:alpha:][:digit:]_]+)")ptoarch"); + if (!archRegex.match(text, &matches) || matches.size() < 3) + return std::nullopt; + return normalizeArch(matches[2]); + }; + + std::string arch = normalizeArch(ptoTargetArch); + if (cliArchSpecified) { + if (arch != "a3" && arch != "a5") { + llvm::errs() << "Error: invalid --pto-arch='" << ptoTargetArch + << "'. Expected 'a3' or 'a5'.\n"; + return 1; + } + } else if (!isPTOBC) { + if (auto detectedArch = detectTextualModuleArch(buf)) + arch = *detectedArch; + } + if (arch != "a3" && arch != "a5") + arch = "a3"; if (isPTOBC) { // Decode PTO bytecode directly into an MLIR module. @@ -1008,10 +1036,13 @@ int main(int argc, char **argv) { } } - // Set target arch on the module from CLI before any passes run. - // This is the single source of truth — input IR does not need pto.target_arch. - module->getOperation()->setAttr("pto.target_arch", - mlir::StringAttr::get(&context, arch)); + // If the CLI explicitly requested an arch, it overrides the input module. + // Otherwise, preserve the textual module's arch when present and only fall + // back to the effective default. 
+ if (cliArchSpecified || !module->getOperation()->hasAttr("pto.target_arch")) { + module->getOperation()->setAttr("pto.target_arch", + mlir::StringAttr::get(&context, arch)); + } PTOBuildLevel effectiveLevel = defaultBuildLevel(); if (!parseBuildLevel(ptoBuildLevel, effectiveLevel)) { diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index d245434c..19abacbd 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -484,6 +484,8 @@ inline std::optional lookupOpcodeAndVariantByFullName(llvm::St .Case("pto.tgemv.acc", OpcodeAndVariant{0x102A, 1, 1}) .Case("pto.tgemv.bias", OpcodeAndVariant{0x102A, 1, 2}) .Case("pto.tgemv.mx", OpcodeAndVariant{0x102A, 1, 3}) + .Case("pto.tgemv.mx.acc", OpcodeAndVariant{0x102A, 1, 4}) + .Case("pto.tgemv.mx.bias", OpcodeAndVariant{0x102A, 1, 5}) .Case("pto.tmatmul", OpcodeAndVariant{0x1032, 1, 0}) .Case("pto.tmatmul.acc", OpcodeAndVariant{0x1032, 1, 1}) .Case("pto.tmatmul.bias", OpcodeAndVariant{0x1032, 1, 2}) @@ -510,6 +512,8 @@ inline const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { case 1: return "pto.tgemv.acc"; case 2: return "pto.tgemv.bias"; case 3: return "pto.tgemv.mx"; + case 4: return "pto.tgemv.mx.acc"; + case 5: return "pto.tgemv.mx.bias"; default: return info->name; } case 0x1032: @@ -538,6 +542,8 @@ inline std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t varia case 1: return 4; case 2: return 4; case 3: return 5; + case 4: return 6; + case 5: return 6; default: return std::nullopt; } case 0x1032: diff --git a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto index 5fb4e3ba..732b8a80 100644 --- a/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto +++ b/tools/ptobc/testdata/recent_ops_v0_roundtrip.pto @@ -22,6 +22,18 @@ module { pto.trsqrt ins(%src : !pto.tile_buf) outs(%rs_dst0 : !pto.tile_buf) pto.trsqrt ins(%src, %rs_tmp : !pto.tile_buf, 
!pto.tile_buf) outs(%rs_dst1 : !pto.tile_buf) pto.tpartmul ins(%part0, %part1 : !pto.tile_buf, !pto.tile_buf) outs(%partdst : !pto.tile_buf) + %a = pto.alloc_tile : !pto.tile_buf + %a_scale = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf + %bias_mx = pto.alloc_tile : !pto.tile_buf + %dst_mx = pto.alloc_tile : !pto.tile_buf + %dst_mx_acc = pto.alloc_tile : !pto.tile_buf + %dst_mx_bias = pto.alloc_tile : !pto.tile_buf + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx : !pto.tile_buf) + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx_acc : !pto.tile_buf) + pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias_mx : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst_mx_bias : !pto.tile_buf) return } } diff --git a/tools/ptobc/tests/recent_ops_v0_encode.sh b/tools/ptobc/tests/recent_ops_v0_encode.sh index 7f9c1dc2..2b3afe5e 100755 --- a/tools/ptobc/tests/recent_ops_v0_encode.sh +++ b/tools/ptobc/tests/recent_ops_v0_encode.sh @@ -38,3 +38,6 @@ grep -F "pto.trowexpandmul ins(" "${ROUNDTRIP}" >/dev/null grep -F "pto.trsqrt ins(" "${ROUNDTRIP}" >/dev/null grep -E "pto\\.trsqrt ins\\(%[^,]+, %[^:]+ :" "${ROUNDTRIP}" >/dev/null grep -F "pto.tpartmul ins(" "${ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx ins(" "${ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx.acc ins(" "${ROUNDTRIP}" >/dev/null +grep -F "pto.tgemv.mx.bias ins(" "${ROUNDTRIP}" >/dev/null