Skip to content

Commit 64fcaab

Browse files
committed
feat: Add TGEMV_MX family ops
1 parent 59be7f8 commit 64fcaab

14 files changed

Lines changed: 772 additions & 14 deletions

File tree

docs/PTO_IR_manual.md

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,95 @@ pto.tgemv.bias ins(%a, %b, %bias : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.
14001400

14011401
---
14021402

1403+
##### `pto.tgemv.mx` - Mixed-Precision Matrix-Vector Multiply
1404+
1405+
**Summary:** Mixed-precision GEMV with explicit A/B scaling tiles.
1406+
1407+
**Semantics:**
1408+
1409+
```
1410+
dst = gemv(a, b) // quantization/mixed-precision behavior is target-defined
1411+
```
1412+
1413+
**Arguments:**
1414+
1415+
| Name | Type | Description |
1416+
|------|------|-------------|
1417+
| `a` | `pto.tile_buf` | Matrix tile (`loc=left`) |
1418+
| `a_scale` | `pto.tile_buf` | Scale tile associated with `a` |
1419+
| `b` | `pto.tile_buf` | Vector tile (`loc=right`) |
1420+
| `b_scale` | `pto.tile_buf` | Scale tile associated with `b` |
1421+
| `dst` | `pto.tile_buf` | Destination accumulator tile (`loc=acc`) |
1422+
1423+
**Results:** None. Writes into `dst` via DPS pattern.
1424+
1425+
**Constraints & Verification:**
1426+
1427+
- `a/b/dst` reuse the same GEMV shape/location checks as `pto.tgemv`.
1428+
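- `a_scale` and `b_scale` must be valid tile buffers in the scaling address space, and each must match its operand (`a` / `b`) in rank, shape, and `valid_shape` (dynamic dimensions are not compared).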
1429+
1430+
**Hardware Mapping:**
1431+
1432+
- Executes on the **Matrix pipeline** (`PIPE_M`); supported on A5 targets only (the verifier rejects the op on A2/A3).
1433+
1434+
**Basic Example:**
1435+
1436+
```mlir
1437+
pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>,
1438+
!pto.tile_buf<...>, !pto.tile_buf<...>)
1439+
outs(%c : !pto.tile_buf<...>)
1440+
```
1441+
1442+
---
1443+
1444+
##### `pto.tgemv.mx.acc` - Mixed-Precision GEMV with Accumulation
1445+
1446+
**Summary:** Mixed-precision GEMV accumulation form using scale tiles.
1447+
1448+
**Semantics:**
1449+
1450+
```
1451+
dst = c_in + gemv(a, b)
1452+
```
1453+
1454+
**Arguments:** `c_in, a, a_scale, b, b_scale, dst`
1455+
1456+
**Hardware Mapping:** Matrix pipeline (`PIPE_M`); supported on A5 targets only (the verifier rejects the op on A2/A3).
1457+
1458+
**Basic Example:**
1459+
1460+
```mlir
1461+
pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>,
1462+
!pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>)
1463+
outs(%c_out : !pto.tile_buf<...>)
1464+
```
1465+
1466+
---
1467+
1468+
##### `pto.tgemv.mx.bias` - Mixed-Precision GEMV with Bias
1469+
1470+
**Summary:** Mixed-precision GEMV bias form using scale tiles.
1471+
1472+
**Semantics:**
1473+
1474+
```
1475+
dst = gemv(a, b) + bias
1476+
```
1477+
1478+
**Arguments:** `a, a_scale, b, b_scale, bias, dst`
1479+
1480+
**Hardware Mapping:** Matrix pipeline (`PIPE_M`); supported on A5 targets only (the verifier rejects the op on A2/A3).
1481+
1482+
**Basic Example:**
1483+
1484+
```mlir
1485+
pto.tgemv.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf<...>, !pto.tile_buf<...>,
1486+
!pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>)
1487+
outs(%c : !pto.tile_buf<...>)
1488+
```
1489+
1490+
---
1491+
14031492
### 4.5 Vector Arithmetic Operations
14041493

14051494
All vector arithmetic operations execute on the **Vector pipeline** (`PIPE_V`) and use `ins`/`outs` with tile buffers in the **VEC (UB)** memory space.

include/PTO/IR/PTOOps.td

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,107 @@ def TGemvBiasOp : PTO_TOp<"tgemv.bias", [
918918
}];
919919
}
920920

921+
// `pto.tgemv.mx`: GEMV op taking a matrix operand `a`, a vector operand `b`,
// one scale operand for each (`a_scale`, `b_scale`), and a destination `dst`.
// Declares the DPS-init, pipe, and memory-effects interfaces; the verifier and
// the memory effects are implemented in C++.
def TGemvMxOp : PTO_TOp<"tgemv.mx", [
922+
PTO_DpsInitOpInterface,
923+
OpPipeInterface,
924+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
925+
]> {
926+
let summary = "Mixed-precision GEMV with scale tiles (tile world, ins/outs).";
927+
928+
// Operand order is a, a_scale, b, b_scale, dst; `dst` is the DPS init operand
// (see getDpsInitsMutable below).
let arguments = (ins
929+
PTODpsType:$a,
930+
PTODpsType:$a_scale,
931+
PTODpsType:$b,
932+
PTODpsType:$b_scale,
933+
PTODpsType:$dst
934+
);
935+
936+
// The ranked-tensor result is optional; it is only printed/parsed when present
// (trailing `-> type` group in the assembly format).
let results = (outs Optional<AnyRankedTensor>:$result);
937+
let hasVerifier = 1;
938+
939+
// Textual form: ins(<a, a_scale, b, b_scale : types>) outs(<dst : type>)
// attr-dict [-> result-type].
let assemblyFormat = [{
940+
`ins` `(` $a `,` $a_scale `,` $b `,` $b_scale
941+
`:` type($a) `,` type($a_scale) `,` type($b) `,` type($b_scale) `)`
942+
`outs` `(` $dst `:` qualified(type($dst) ) `)`
943+
attr-dict
944+
(`->` qualified(type($result))^)?
945+
}];
946+
947+
// Extra C++ members: intrinsic name "TGEMV_MX", pipe PIPE_M, and `dst` as the
// mutable DPS-init range.
let extraClassDeclaration = [{
948+
static StringRef getIntrinsicName() { return "TGEMV_MX"; }
949+
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_M; }
950+
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
951+
}];
952+
}
953+
954+
// `pto.tgemv.mx.acc`: accumulate form of the scaled GEMV. Takes an extra
// leading input `c_in` ahead of a, a_scale, b, b_scale, and writes to `dst`.
// Declares the DPS-init, pipe, and memory-effects interfaces; the verifier and
// the memory effects are implemented in C++.
def TGemvMxAccOp : PTO_TOp<"tgemv.mx.acc", [
955+
PTO_DpsInitOpInterface,
956+
OpPipeInterface,
957+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
958+
]> {
959+
let summary = "Mixed-precision GEMV accumulate with scale tiles (tile world, ins/outs).";
960+
961+
// Operand order is c_in, a, a_scale, b, b_scale, dst; `dst` is the DPS init
// operand (see getDpsInitsMutable below).
let arguments = (ins
962+
PTODpsType:$c_in,
963+
PTODpsType:$a,
964+
PTODpsType:$a_scale,
965+
PTODpsType:$b,
966+
PTODpsType:$b_scale,
967+
PTODpsType:$dst
968+
);
969+
970+
// The ranked-tensor result is optional; it is only printed/parsed when present
// (trailing `-> type` group in the assembly format).
let results = (outs Optional<AnyRankedTensor>:$result);
971+
let hasVerifier = 1;
972+
973+
// Textual form: ins(<c_in, a, a_scale, b, b_scale : types>) outs(<dst : type>)
// attr-dict [-> result-type].
let assemblyFormat = [{
974+
`ins` `(` $c_in `,` $a `,` $a_scale `,` $b `,` $b_scale
975+
`:` type($c_in) `,` type($a) `,` type($a_scale) `,` type($b) `,` type($b_scale) `)`
976+
`outs` `(` $dst `:` qualified(type($dst) ) `)`
977+
attr-dict
978+
(`->` qualified(type($result))^)?
979+
}];
980+
981+
// Extra C++ members: intrinsic name, pipe PIPE_M, and `dst` as the mutable
// DPS-init range.
// NOTE(review): getIntrinsicName() returns "TGEMV_MX", the same string as the
// non-accumulating `tgemv.mx` op — confirm the consumer of this name
// distinguishes the variants some other way (e.g. by operand count).
let extraClassDeclaration = [{
982+
static StringRef getIntrinsicName() { return "TGEMV_MX"; }
983+
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_M; }
984+
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
985+
}];
986+
}
987+
988+
// `pto.tgemv.mx.bias`: bias form of the scaled GEMV. Takes a, a_scale, b,
// b_scale plus a trailing `bias` input, and writes to `dst`.
// Declares the DPS-init, pipe, and memory-effects interfaces; the verifier and
// the memory effects are implemented in C++.
def TGemvMxBiasOp : PTO_TOp<"tgemv.mx.bias", [
989+
PTO_DpsInitOpInterface,
990+
OpPipeInterface,
991+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
992+
]> {
993+
let summary = "Mixed-precision GEMV with bias and scale tiles (tile world, ins/outs).";
994+
995+
// Operand order is a, a_scale, b, b_scale, bias, dst; `dst` is the DPS init
// operand (see getDpsInitsMutable below).
let arguments = (ins
996+
PTODpsType:$a,
997+
PTODpsType:$a_scale,
998+
PTODpsType:$b,
999+
PTODpsType:$b_scale,
1000+
PTODpsType:$bias,
1001+
PTODpsType:$dst
1002+
);
1003+
1004+
// The ranked-tensor result is optional; it is only printed/parsed when present
// (trailing `-> type` group in the assembly format).
let results = (outs Optional<AnyRankedTensor>:$result);
1005+
let hasVerifier = 1;
1006+
1007+
// Textual form: ins(<a, a_scale, b, b_scale, bias : types>) outs(<dst : type>)
// attr-dict [-> result-type]. Only the bias, dst, and result types use the
// `qualified(...)` directive here.
let assemblyFormat = [{
1008+
`ins` `(` $a `,` $a_scale `,` $b `,` $b_scale `,` $bias
1009+
`:` type($a) `,` type($a_scale) `,` type($b) `,` type($b_scale) `,` qualified(type($bias)) `)`
1010+
`outs` `(` $dst `:` qualified(type($dst) ) `)`
1011+
attr-dict
1012+
(`->` qualified(type($result))^)?
1013+
}];
1014+
1015+
// Extra C++ members: intrinsic name, pipe PIPE_M, and `dst` as the mutable
// DPS-init range.
// NOTE(review): getIntrinsicName() returns "TGEMV_MX", the same string as the
// plain `tgemv.mx` op — confirm the consumer of this name distinguishes the
// bias variant some other way (e.g. by operand count).
let extraClassDeclaration = [{
1016+
static StringRef getIntrinsicName() { return "TGEMV_MX"; }
1017+
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_M; }
1018+
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
1019+
}];
1020+
}
1021+
9211022
def TMovOp : PTO_TOp<"tmov", [
9221023
PTO_DpsInitOpInterface,
9231024
OpPipeInterface,

lib/PTO/IR/PTO.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,6 +1936,45 @@ static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type r
19361936
return success();
19371937
}
19381938

1939+
/// Verifies that a scale tile is compatible with the operand it scales.
///
/// Checks, in order:
///   1. `scaleTy` passes verifyTileBufCommon.
///   2. `scaleTy` resolves to the SCALING address space.
///   3. `scaleTy` and `operandTy` have the same rank and the same shape.
///   4. `scaleTy` and `operandTy` have the same valid_shape.
/// In steps 3 and 4 a dimension is only compared when neither side is
/// ShapedType::kDynamic.
///
/// \param op          Operation used to emit diagnostics.
/// \param scaleTy     Type of the scale operand.
/// \param operandTy   Type of the operand the scale belongs to.
/// \param scaleName   Name of the scale operand used in error messages.
/// \param operandName Name of the scaled operand used in error messages.
/// \returns success() if every check passes; otherwise the emitted error.
///
/// NOTE(review): step 3 requires the scale tile to have exactly the operand's
/// shape. If scale tiles are meant to hold one value per block of operand
/// elements, this check would reject them — confirm the intended scale layout.
static LogicalResult verifyScaleTileMatchesOperand(Operation *op, Type scaleTy,
1940+
Type operandTy,
1941+
StringRef scaleName,
1942+
StringRef operandName) {
1943+
if (failed(verifyTileBufCommon(op, scaleTy, scaleName)))
1944+
return failure();
1945+
// A missing memory space is rejected the same way as a non-SCALING one.
auto scaleSpace = getPTOMemorySpaceEnum(scaleTy);
1946+
if (!scaleSpace || *scaleSpace != pto::AddressSpace::SCALING)
1947+
return op->emitOpError() << "expects " << scaleName
1948+
<< " to be in the scaling address space";
1949+
1950+
// Shape comparison: rank first, then each pair of static dimensions.
auto scaleShape = getShapeVec(scaleTy);
1951+
auto operandShape = getShapeVec(operandTy);
1952+
if (scaleShape.size() != operandShape.size())
1953+
return op->emitOpError() << "expects " << scaleName << " and " << operandName
1954+
<< " to have the same rank";
1955+
for (size_t i = 0; i < scaleShape.size(); ++i) {
1956+
if (scaleShape[i] != ShapedType::kDynamic &&
1957+
operandShape[i] != ShapedType::kDynamic &&
1958+
scaleShape[i] != operandShape[i])
1959+
return op->emitOpError() << "expects " << scaleName << " and " << operandName
1960+
<< " to have the same shape";
1961+
}
1962+
1963+
// valid_shape comparison: a length mismatch and a per-dimension mismatch
// both produce the same "same valid_shape" diagnostic.
auto scaleValid = getValidShapeVec(scaleTy);
1964+
auto operandValid = getValidShapeVec(operandTy);
1965+
if (scaleValid.size() != operandValid.size())
1966+
return op->emitOpError() << "expects " << scaleName << " and " << operandName
1967+
<< " to have the same valid_shape";
1968+
for (size_t i = 0; i < scaleValid.size(); ++i) {
1969+
if (scaleValid[i] != ShapedType::kDynamic &&
1970+
operandValid[i] != ShapedType::kDynamic &&
1971+
scaleValid[i] != operandValid[i])
1972+
return op->emitOpError() << "expects " << scaleName << " and " << operandName
1973+
<< " to have the same valid_shape";
1974+
}
1975+
return success();
1976+
}
1977+
19391978
static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty,
19401979
Type src1Ty, Type dstTy) {
19411980
auto src0Valid = getValidShapeVec(src0Ty);
@@ -4395,6 +4434,78 @@ LogicalResult TGemvBiasOp::verify() {
43954434
return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
43964435
}
43974436

4437+
/// Verifier for `pto.tgemv.mx`.
///
/// Builds one callback per architecture family and hands both to
/// dispatchVerifierByArch (presumably it runs the one matching the target —
/// confirm against its definition). The A2/A3 callback always fails; the A5
/// callback checks both scale tiles against their operands, then the GEMV tile
/// operands, then the matmul-like constraints on a/b/dst.
LogicalResult TGemvMxOp::verify() {
4438+
// A2/A3: unconditionally rejected.
auto verifyA2A3 = [&]() -> LogicalResult {
4439+
return emitOpError("tgemv.mx is only supported on A5 targets");
4440+
};
4441+
auto verifyA5 = [&]() -> LogicalResult {
4442+
// Checks are chained with `||`, so evaluation stops at the first failing
// one: a_scale vs a, then b_scale vs b, then the a/b/dst GEMV operand check.
if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(),
4443+
getA().getType(), "a_scale", "a")) ||
4444+
failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(),
4445+
getB().getType(), "b_scale", "b")) ||
4446+
failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(),
4447+
getDst().getType())))
4448+
return failure();
4449+
return verifyMatmulLike(*this, getA().getType(), getB().getType(),
4450+
getDst().getType());
4451+
};
4452+
return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
4453+
}
4454+
4455+
/// Verifier for `pto.tgemv.mx.acc`.
///
/// Builds one callback per architecture family and hands both to
/// dispatchVerifierByArch (presumably it runs the one matching the target —
/// confirm against its definition). The A2/A3 callback always fails; the A5
/// callback validates `c_in`, both scale tiles, the GEMV tile operands, the
/// c_in/dst agreement, and finally the matmul-like constraints on a/b/dst.
LogicalResult TGemvMxAccOp::verify() {
4456+
// A2/A3: unconditionally rejected.
auto verifyA2A3 = [&]() -> LogicalResult {
4457+
return emitOpError("tgemv.mx.acc is only supported on A5 targets");
4458+
};
4459+
auto verifyA5 = [&]() -> LogicalResult {
4460+
// Checks are chained with `||`, so evaluation stops at the first failing
// one: c_in, then a_scale vs a, then b_scale vs b, then the a/b/dst GEMV
// operand check.
if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) ||
4461+
failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(),
4462+
getA().getType(), "a_scale", "a")) ||
4463+
failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(),
4464+
getB().getType(), "b_scale", "b")) ||
4465+
failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(),
4466+
getDst().getType())))
4467+
return failure();
4468+
// c_in must agree with dst via the same-shape-and-element and
// same-valid-shape helpers.
if (failed(verifyTileBufSameShapeAndElem(*this, getCIn().getType(),
4469+
getDst().getType(), "c_in", "dst")) ||
4470+
failed(verifyTileBufSameValidShape(*this, getCIn().getType(),
4471+
getDst().getType(), "c_in", "dst")))
4472+
return failure();
4473+
return verifyMatmulLike(*this, getA().getType(), getB().getType(),
4474+
getDst().getType());
4475+
};
4476+
return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
4477+
}
4478+
4479+
/// Verifier for `pto.tgemv.mx.bias`.
///
/// Builds one callback per architecture family and hands both to
/// dispatchVerifierByArch (presumably it runs the one matching the target —
/// confirm against its definition). The A2/A3 callback always fails; the A5
/// callback validates both scale tiles, the GEMV tile operands, the bias tile
/// (float bias required), the bias/dst shape agreement, and finally the
/// matmul-like constraints on a/b/dst.
LogicalResult TGemvMxBiasOp::verify() {
4480+
// A2/A3: unconditionally rejected.
auto verifyA2A3 = [&]() -> LogicalResult {
4481+
return emitOpError("tgemv.mx.bias is only supported on A5 targets");
4482+
};
4483+
auto verifyA5 = [&]() -> LogicalResult {
4484+
// Checks are chained with `||`, so evaluation stops at the first failing
// one: a_scale vs a, then b_scale vs b, then the a/b/dst GEMV operand check,
// then the bias tile check.
if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(),
4485+
getA().getType(), "a_scale", "a")) ||
4486+
failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(),
4487+
getB().getType(), "b_scale", "b")) ||
4488+
failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(),
4489+
getDst().getType())) ||
4490+
failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(),
4491+
/*requireFloatBias=*/true)))
4492+
return failure();
4493+
// bias and dst must both be rank-2; the rank check precedes the [1] index
// accesses below.
auto biasShape = getShapeVec(getBias().getType());
4494+
auto dstShape = getShapeVec(getDst().getType());
4495+
if (biasShape.size() != 2 || dstShape.size() != 2)
4496+
return emitOpError("expects bias and dst to be rank-2 for tgemv.mx.bias");
4497+
// Only the column dimension (index 1) of the shape is compared, and only
// when both sides are static.
if (biasShape[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic &&
4498+
biasShape[1] != dstShape[1])
4499+
return emitOpError("expects bias and dst to have the same column shape");
4500+
// The valid_shape comparison is delegated to the shared helper.
if (failed(verifyTileBufSameValidShape(*this, getBias().getType(),
4501+
getDst().getType(), "bias", "dst")))
4502+
return failure();
4503+
return verifyMatmulLike(*this, getA().getType(), getB().getType(),
4504+
getDst().getType());
4505+
};
4506+
return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
4507+
}
4508+
43984509
LogicalResult TMatmulBiasOp::verify() {
43994510
auto verifyA2A3 = [&]() -> LogicalResult {
44004511
if (failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(),
@@ -8210,6 +8321,38 @@ void TGemvBiasOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryE
82108321
addEffect(effects, &getDstMutable(), MemoryEffects::Write::get());
82118322
}
82128323

8324+
// === TGemvMxOp ===
8325+
// Read: a, a_scale, b, b_scale, Write: dst
// Appends one memory effect per operand to `effects`: a Read for each of the
// four inputs (including both scale tiles) and a Write for dst.
8326+
void TGemvMxOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
8327+
addEffect(effects, &getAMutable(), MemoryEffects::Read::get());
8328+
addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get());
8329+
addEffect(effects, &getBMutable(), MemoryEffects::Read::get());
8330+
addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get());
8331+
addEffect(effects, &getDstMutable(), MemoryEffects::Write::get());
8332+
}
8333+
8334+
// === TGemvMxAccOp ===
8335+
// Read: c_in, a, a_scale, b, b_scale, Write: dst
// Appends one memory effect per operand to `effects`: a Read for c_in and each
// of the four GEMV inputs, and a Write for dst. dst is registered as Write
// only; the accumulator input is the separate c_in operand.
8336+
void TGemvMxAccOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
8337+
addEffect(effects, &getCInMutable(), MemoryEffects::Read::get());
8338+
addEffect(effects, &getAMutable(), MemoryEffects::Read::get());
8339+
addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get());
8340+
addEffect(effects, &getBMutable(), MemoryEffects::Read::get());
8341+
addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get());
8342+
addEffect(effects, &getDstMutable(), MemoryEffects::Write::get());
8343+
}
8344+
8345+
// === TGemvMxBiasOp ===
8346+
// Read: a, a_scale, b, b_scale, bias, Write: dst
// Appends one memory effect per operand to `effects`: a Read for each of the
// four GEMV inputs and for bias, and a Write for dst.
8347+
void TGemvMxBiasOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
8348+
addEffect(effects, &getAMutable(), MemoryEffects::Read::get());
8349+
addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get());
8350+
addEffect(effects, &getBMutable(), MemoryEffects::Read::get());
8351+
addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get());
8352+
addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get());
8353+
addEffect(effects, &getDstMutable(), MemoryEffects::Write::get());
8354+
}
8355+
82138356
// === TMatmulOp ===
82148357
void TMatmulMxOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
82158358
addEffect(effects, &getAMutable(), MemoryEffects::Read::get());

0 commit comments

Comments
 (0)