Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 260 additions & 0 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2374,6 +2374,35 @@ def TColExpandOp : PTO_TOp<"tcolexpand", [
}];
}

def TColExpandAddOp : PTO_TOp<"tcolexpandadd", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Column-wise broadcast add: add a per-column scalar vector src1 to src0 ";

let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];

let assemblyFormat = [{
`ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)`
`outs` `(` $dst `:` qualified(type($dst) ) `)`
attr-dict
}];
}

def TColExpandMulOp : PTO_TOp<"tcolexpandmul", [
PTO_DpsInitOpInterface,
OpPipeInterface,
Expand Down Expand Up @@ -2461,6 +2490,35 @@ def TColExpandSubOp : PTO_TOp<"tcolexpandsub", [
}];
}

def TColExpandExpdifOp : PTO_TOp<"tcolexpandexpdif", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Column-wise broadcast expdif: compute exp(src0 - src1) using a per-column scalar vector src1 ";

let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];

let assemblyFormat = [{
`ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)`
`outs` `(` $dst `:` qualified(type($dst) ) `)`
attr-dict
}];
}

def TColExpandMaxOp : PTO_TOp<"tcolexpandmax", [
PTO_DpsInitOpInterface,
OpPipeInterface,
Expand Down Expand Up @@ -2603,6 +2661,34 @@ def TColSumOp : PTO_TOp<"tcolsum", [
}];
}

def TColProdOp : PTO_TOp<"tcolprod", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Reduce each column by multiplying across rows";

let arguments = (ins
PTODpsType:$src,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];

let assemblyFormat = [{
`ins` `(` $src `:` qualified(type($src)) `)`
`outs` `(` $dst `:` qualified(type($dst) ) `)`
attr-dict
}];
}

def TCvtOp : PTO_TOp<"tcvt", [
PTO_DpsInitOpInterface,
OpPipeInterface,
Expand Down Expand Up @@ -2692,6 +2778,64 @@ def TDivSOp : PTO_TOp<"tdivs", [

}

def TFModOp : PTO_TOp<"tfmod", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Elementwise fmod/remainder of two tiles (tilebuf, DPS)";

let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];

let assemblyFormat = [{
`ins` `(` $src0 `,` $src1 `:` qualified(type($src0)) `,` qualified(type($src1)) `)`
`outs` `(` $dst `:` qualified(type($dst) ) `)`
attr-dict
}];
}

def TFModSOp : PTO_TOp<"tfmods", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Elementwise fmod/remainder with a scalar (tilebuf, DPS)";

let arguments = (ins
PTODpsType:$src,
ScalarType:$scalar,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let assemblyFormat = [{
`ins` `(` $src `,` $scalar `:` qualified(type($src)) `,` type($scalar) `)`
`outs` `(` $dst `:` qualified(type($dst) ) `)`
attr-dict
}];

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TExpOp : PTO_TOp<"texp", [
PTO_DpsInitOpInterface,
OpPipeInterface,
Expand Down Expand Up @@ -4046,6 +4190,93 @@ def TRowExpandAddOp: PTO_TOp<"trowexpandadd", [
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TRowExpandExpdifOp: PTO_TOp<"trowexpandexpdif", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "TROWEXPANDEXPDIF: Row-wise broadcast expdif with per-row scalar vector.";
let description = [{
pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only.
}];

let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
Optional<PTODpsType>:$tmp,
PTODpsType:$dst
Comment on lines +4205 to +4208
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Align ptobc arity with optional tmp row-expand ops

These new row-expand ops are defined with an optional tmp operand (so A2/A3 can legally build 4-operand forms), but the v0 ptobc schema still encodes them as fixed 3-operand ops and the encoder enforces exact operand counts. As a result, valid IR such as pto.trowexpandmax/min/expdif with tmp will fail ptobc encode with an operand-count mismatch, breaking the bytecode workflow for supported A2/A3 forms.

Useful? React with 👍 / 👎.

);

let results = (outs);

let hasVerifier = 1;

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TRowExpandMaxOp: PTO_TOp<"trowexpandmax", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "TROWEXPANDMAX: Row-wise broadcast max with per-row scalar vector.";
let description = [{
pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only.
}];

let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
Optional<PTODpsType>:$tmp,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TRowExpandMinOp: PTO_TOp<"trowexpandmin", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "TROWEXPANDMIN: Row-wise broadcast min with per-row scalar vector.";
let description = [{
pto-isa has overloads with/without tmp on A2/A3; A5 supports the 3-operand form only.
}];

let arguments = (ins
PTODpsType:$src0,
PTODpsType:$src1,
Optional<PTODpsType>:$tmp,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}
//===----------------------------------------------------------------------===//
// PTOOps.td (add TROWMAX TBDPS/tile buffer op)
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4142,6 +4373,35 @@ def TRowSumOp: PTO_TOp<"trowsum", [
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}

def TRowProdOp: PTO_TOp<"trowprod", [
PTO_DpsInitOpInterface,
OpPipeInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "TROWPROD: Reduce each row by multiplying across columns.";

let arguments = (ins
PTODpsType:$src,
PTODpsType:$tmp,
PTODpsType:$dst
);

let results = (outs);

let hasVerifier = 1;

let assemblyFormat = [{
`ins` `(` $src `,` $tmp `:` qualified(type($src)) `,` qualified(type($tmp)) `)`
`outs` `(` $dst `:` qualified(type($dst) ) `)`
attr-dict
}];

let extraClassDeclaration = [{
::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; }
::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }
}];
}
//===----------------------------------------------------------------------===//
// PTOOps.td (add TRSQRT TBDPS/tile buffer op)
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading