@@ -1936,6 +1936,45 @@ static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type r
19361936 return success ();
19371937}
19381938
1939+ static LogicalResult verifyScaleTileMatchesOperand (Operation *op, Type scaleTy,
1940+ Type operandTy,
1941+ StringRef scaleName,
1942+ StringRef operandName) {
1943+ if (failed (verifyTileBufCommon (op, scaleTy, scaleName)))
1944+ return failure ();
1945+ auto scaleSpace = getPTOMemorySpaceEnum (scaleTy);
1946+ if (!scaleSpace || *scaleSpace != pto::AddressSpace::SCALING)
1947+ return op->emitOpError () << " expects " << scaleName
1948+ << " to be in the scaling address space" ;
1949+
1950+ auto scaleShape = getShapeVec (scaleTy);
1951+ auto operandShape = getShapeVec (operandTy);
1952+ if (scaleShape.size () != operandShape.size ())
1953+ return op->emitOpError () << " expects " << scaleName << " and " << operandName
1954+ << " to have the same rank" ;
1955+ for (size_t i = 0 ; i < scaleShape.size (); ++i) {
1956+ if (scaleShape[i] != ShapedType::kDynamic &&
1957+ operandShape[i] != ShapedType::kDynamic &&
1958+ scaleShape[i] != operandShape[i])
1959+ return op->emitOpError () << " expects " << scaleName << " and " << operandName
1960+ << " to have the same shape" ;
1961+ }
1962+
1963+ auto scaleValid = getValidShapeVec (scaleTy);
1964+ auto operandValid = getValidShapeVec (operandTy);
1965+ if (scaleValid.size () != operandValid.size ())
1966+ return op->emitOpError () << " expects " << scaleName << " and " << operandName
1967+ << " to have the same valid_shape" ;
1968+ for (size_t i = 0 ; i < scaleValid.size (); ++i) {
1969+ if (scaleValid[i] != ShapedType::kDynamic &&
1970+ operandValid[i] != ShapedType::kDynamic &&
1971+ scaleValid[i] != operandValid[i])
1972+ return op->emitOpError () << " expects " << scaleName << " and " << operandName
1973+ << " to have the same valid_shape" ;
1974+ }
1975+ return success ();
1976+ }
1977+
19391978static LogicalResult verifyPartialValidPattern (Operation *op, Type src0Ty,
19401979 Type src1Ty, Type dstTy) {
19411980 auto src0Valid = getValidShapeVec (src0Ty);
@@ -4395,6 +4434,78 @@ LogicalResult TGemvBiasOp::verify() {
43954434 return dispatchVerifierByArch (getOperation (), verifyA2A3, verifyA5);
43964435}
43974436
4437+ LogicalResult TGemvMxOp::verify () {
4438+ auto verifyA2A3 = [&]() -> LogicalResult {
4439+ return emitOpError (" tgemv.mx is only supported on A5 targets" );
4440+ };
4441+ auto verifyA5 = [&]() -> LogicalResult {
4442+ if (failed (verifyScaleTileMatchesOperand (*this , getAScale ().getType (),
4443+ getA ().getType (), " a_scale" , " a" )) ||
4444+ failed (verifyScaleTileMatchesOperand (*this , getBScale ().getType (),
4445+ getB ().getType (), " b_scale" , " b" )) ||
4446+ failed (verifyGemvTileOperands (*this , getA ().getType (), getB ().getType (),
4447+ getDst ().getType ())))
4448+ return failure ();
4449+ return verifyMatmulLike (*this , getA ().getType (), getB ().getType (),
4450+ getDst ().getType ());
4451+ };
4452+ return dispatchVerifierByArch (getOperation (), verifyA2A3, verifyA5);
4453+ }
4454+
4455+ LogicalResult TGemvMxAccOp::verify () {
4456+ auto verifyA2A3 = [&]() -> LogicalResult {
4457+ return emitOpError (" tgemv.mx.acc is only supported on A5 targets" );
4458+ };
4459+ auto verifyA5 = [&]() -> LogicalResult {
4460+ if (failed (verifyAccTileCommon (*this , getCIn ().getType (), " c_in" )) ||
4461+ failed (verifyScaleTileMatchesOperand (*this , getAScale ().getType (),
4462+ getA ().getType (), " a_scale" , " a" )) ||
4463+ failed (verifyScaleTileMatchesOperand (*this , getBScale ().getType (),
4464+ getB ().getType (), " b_scale" , " b" )) ||
4465+ failed (verifyGemvTileOperands (*this , getA ().getType (), getB ().getType (),
4466+ getDst ().getType ())))
4467+ return failure ();
4468+ if (failed (verifyTileBufSameShapeAndElem (*this , getCIn ().getType (),
4469+ getDst ().getType (), " c_in" , " dst" )) ||
4470+ failed (verifyTileBufSameValidShape (*this , getCIn ().getType (),
4471+ getDst ().getType (), " c_in" , " dst" )))
4472+ return failure ();
4473+ return verifyMatmulLike (*this , getA ().getType (), getB ().getType (),
4474+ getDst ().getType ());
4475+ };
4476+ return dispatchVerifierByArch (getOperation (), verifyA2A3, verifyA5);
4477+ }
4478+
4479+ LogicalResult TGemvMxBiasOp::verify () {
4480+ auto verifyA2A3 = [&]() -> LogicalResult {
4481+ return emitOpError (" tgemv.mx.bias is only supported on A5 targets" );
4482+ };
4483+ auto verifyA5 = [&]() -> LogicalResult {
4484+ if (failed (verifyScaleTileMatchesOperand (*this , getAScale ().getType (),
4485+ getA ().getType (), " a_scale" , " a" )) ||
4486+ failed (verifyScaleTileMatchesOperand (*this , getBScale ().getType (),
4487+ getB ().getType (), " b_scale" , " b" )) ||
4488+ failed (verifyGemvTileOperands (*this , getA ().getType (), getB ().getType (),
4489+ getDst ().getType ())) ||
4490+ failed (verifyMatBiasTile (*this , getBias ().getType (), getDst ().getType (),
4491+ /* requireFloatBias=*/ true )))
4492+ return failure ();
4493+ auto biasShape = getShapeVec (getBias ().getType ());
4494+ auto dstShape = getShapeVec (getDst ().getType ());
4495+ if (biasShape.size () != 2 || dstShape.size () != 2 )
4496+ return emitOpError (" expects bias and dst to be rank-2 for tgemv.mx.bias" );
4497+ if (biasShape[1 ] != ShapedType::kDynamic && dstShape[1 ] != ShapedType::kDynamic &&
4498+ biasShape[1 ] != dstShape[1 ])
4499+ return emitOpError (" expects bias and dst to have the same column shape" );
4500+ if (failed (verifyTileBufSameValidShape (*this , getBias ().getType (),
4501+ getDst ().getType (), " bias" , " dst" )))
4502+ return failure ();
4503+ return verifyMatmulLike (*this , getA ().getType (), getB ().getType (),
4504+ getDst ().getType ());
4505+ };
4506+ return dispatchVerifierByArch (getOperation (), verifyA2A3, verifyA5);
4507+ }
4508+
43984509LogicalResult TMatmulBiasOp::verify () {
43994510 auto verifyA2A3 = [&]() -> LogicalResult {
44004511 if (failed (verifyMatTileOperands (*this , getA ().getType (), getB ().getType (),
@@ -8210,6 +8321,38 @@ void TGemvBiasOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryE
82108321 addEffect (effects, &getDstMutable (), MemoryEffects::Write::get ());
82118322}
82128323
8324+ // === TGemvMxOp ===
8325+ // Read: a, a_scale, b, b_scale, Write: dst
8326+ void TGemvMxOp::getEffects (SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
8327+ addEffect (effects, &getAMutable (), MemoryEffects::Read::get ());
8328+ addEffect (effects, &getAScaleMutable (), MemoryEffects::Read::get ());
8329+ addEffect (effects, &getBMutable (), MemoryEffects::Read::get ());
8330+ addEffect (effects, &getBScaleMutable (), MemoryEffects::Read::get ());
8331+ addEffect (effects, &getDstMutable (), MemoryEffects::Write::get ());
8332+ }
8333+
8334+ // === TGemvMxAccOp ===
8335+ // Read: c_in, a, a_scale, b, b_scale, Write: dst
8336+ void TGemvMxAccOp::getEffects (SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
8337+ addEffect (effects, &getCInMutable (), MemoryEffects::Read::get ());
8338+ addEffect (effects, &getAMutable (), MemoryEffects::Read::get ());
8339+ addEffect (effects, &getAScaleMutable (), MemoryEffects::Read::get ());
8340+ addEffect (effects, &getBMutable (), MemoryEffects::Read::get ());
8341+ addEffect (effects, &getBScaleMutable (), MemoryEffects::Read::get ());
8342+ addEffect (effects, &getDstMutable (), MemoryEffects::Write::get ());
8343+ }
8344+
8345+ // === TGemvMxBiasOp ===
8346+ // Read: a, a_scale, b, b_scale, bias, Write: dst
8347+ void TGemvMxBiasOp::getEffects (SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
8348+ addEffect (effects, &getAMutable (), MemoryEffects::Read::get ());
8349+ addEffect (effects, &getAScaleMutable (), MemoryEffects::Read::get ());
8350+ addEffect (effects, &getBMutable (), MemoryEffects::Read::get ());
8351+ addEffect (effects, &getBScaleMutable (), MemoryEffects::Read::get ());
8352+ addEffect (effects, &getBiasMutable (), MemoryEffects::Read::get ());
8353+ addEffect (effects, &getDstMutable (), MemoryEffects::Write::get ());
8354+ }
8355+
82138356// === TMatmulOp ===
82148357void TMatmulMxOp::getEffects (SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
82158358 addEffect (effects, &getAMutable (), MemoryEffects::Read::get ());
0 commit comments