Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 36 additions & 46 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def ScalarPtrOrMemRef :
def ScalarType :
AnyTypeOf<[AnySignlessInteger, AnyFloat], "numeric (integer/float)">;

def IndexOrI64 :
Type<
CPred<"$_self.isIndex() || $_self.isSignlessInteger(64)">,
"index or i64">;

def IndexOrI32 :
Type<
CPred<"$_self.isIndex() || $_self.isSignlessInteger(32)">,
"index or i32">;

//===----------------------------------------------------------------------===//
// Op Class
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -89,16 +99,14 @@ def AddPtrOp : PTO_Op<"addptr", [

let arguments = (ins
PtrType:$ptr,
Index:$offset
IndexOrI64:$offset
);

let results = (outs PtrType:$result);

let hasVerifier = 1;

let assemblyFormat = [{
$ptr `,` $offset attr-dict `:` type($ptr) `->` type($result)
}];
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -112,16 +120,14 @@ def LoadScalarOp : PTO_Op<"load_scalar", [

let arguments = (ins
ScalarPtrOrMemRef:$ptr,
Index:$offset
IndexOrI64:$offset
);

let results = (outs AnyType:$value);

let hasVerifier = 1;

let assemblyFormat = [{
$ptr `[` $offset `]` attr-dict `:` type($ptr) `->` type($value)
}];
let hasCustomAssemblyFormat = 1;
}

def StoreScalarOp : PTO_Op<"store_scalar", [
Expand All @@ -131,26 +137,24 @@ def StoreScalarOp : PTO_Op<"store_scalar", [

let arguments = (ins
ScalarPtrOrMemRef:$ptr,
Index:$offset,
IndexOrI64:$offset,
AnyType:$value
);

let results = (outs);

let hasVerifier = 1;

let assemblyFormat = [{
$value `,` $ptr `[` $offset `]` attr-dict `:` type($ptr) `,` type($value)
}];
let hasCustomAssemblyFormat = 1;
}

def MakeTensorViewOp : PTO_Op<"make_tensor_view", [AttrSizedOperandSegments]> {
let summary = "Wrap a pointer as a tensor_view descriptor (no allocation, no copy).";

let arguments = (ins
AnyType:$ptr,
Variadic<Index>:$shape,
Variadic<Index>:$strides,
Variadic<IndexOrI64>:$shape,
Variadic<IndexOrI64>:$strides,
OptionalAttr<PTO_LayoutAttr>:$layout
);

Expand All @@ -176,18 +180,15 @@ def PartitionViewOp : PTO_Op<"partition_view", [AttrSizedOperandSegments]> {

let arguments = (ins
TensorViewType:$source, // 输入: 物理大底座 (MakeTensorViewOp 的结果)
Variadic<Index>:$offsets, // 动态 offsets
Variadic<Index>:$sizes // 动态 sizes
Variadic<IndexOrI64>:$offsets, // 动态 offsets
Variadic<IndexOrI64>:$sizes // 动态 sizes
);

let results = (outs PartitionTensorViewType:$result); // 输出: 逻辑切片

let hasVerifier = 1;

let assemblyFormat = [{
$source `,` `offsets` `=` `[` $offsets `]` `,` `sizes` `=` `[` $sizes `]`
attr-dict `:` qualified(type($source)) `->` qualified(type($result))
}];
let hasCustomAssemblyFormat = 1;
}

// Helper: tensor_view or memref (after lowering tensor_view to memref).
Expand All @@ -210,32 +211,24 @@ def GetTensorViewDimOp : PTO_Op<"get_tensor_view_dim", [Pure]> {
}];
let arguments = (ins
TensorViewOrMemRef:$tensor_view,
Index:$dim_index
IndexOrI64:$dim_index
);
let results = (outs Index:$result);
let assemblyFormat = [{
$tensor_view `,` $dim_index `:` qualified(type($tensor_view)) `->` qualified(type($result))
attr-dict
}];
let hasCustomAssemblyFormat = 1;
}

def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> {
let summary = "Allocates a tile buffer (logical buffer).";

let arguments = (ins
Optional<I64>:$addr,
Optional<Index>:$valid_row,
Optional<Index>:$valid_col
Optional<IndexOrI32>:$valid_row,
Optional<IndexOrI32>:$valid_col
);

let results = (outs TileBufType:$result);

let assemblyFormat = [{
(`addr` `=` $addr^)?
(`valid_row` `=` $valid_row^)?
(`valid_col` `=` $valid_col^)?
attr-dict `:` qualified(type($result))
}];
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
::mlir::LogicalResult verify();
Expand Down Expand Up @@ -329,16 +322,13 @@ def SetValidShapeOp : PTO_Op<"set_validshape", [

let arguments = (ins
TileBufOrMemRef:$source,
Index:$valid_row,
Index:$valid_col
IndexOrI32:$valid_row,
IndexOrI32:$valid_col
);

let hasVerifier = 1;

let assemblyFormat = [{
$source `,` $valid_row `,` $valid_col attr-dict `:`
qualified(type($source))
}];
let hasCustomAssemblyFormat = 1;
}

// ============================================================================
Expand Down Expand Up @@ -399,7 +389,7 @@ def TLoadOp : PTO_TOp<"tload", [
PTODpsType:$dst,
OptionalAttr<PTO_PadModeAttr>:$pad_mode,
Optional<AnyType>:$pad_value,
Optional<Index>:$left_padding_num,
Optional<IndexOrI32>:$left_padding_num,
Optional<AnyType>:$right_padding_num,
DefaultValuedOptionalAttr<BoolAttr, "false">:$init_out_buffer,
Optional<AnyType>:$init_condition
Expand Down Expand Up @@ -2017,7 +2007,7 @@ def TSetValOp : PTO_TOp<"tsetval", [

let arguments = (ins
PTODpsType:$dst,
Index:$offset,
IndexOrI32:$offset,
ScalarType:$val
);

Expand Down Expand Up @@ -2048,7 +2038,7 @@ def TGetValOp : PTO_TOp<"tgetval", [

let arguments = (ins
PTODpsType:$src,
Index:$offset
IndexOrI32:$offset
);

let results = (outs ScalarType:$dst);
Expand Down Expand Up @@ -2858,8 +2848,8 @@ def TExtractOp : PTO_TOp<"textract", [

let arguments = (ins
PTODpsType:$src,
Index:$indexRow,
Index:$indexCol,
IndexOrI32:$indexRow,
IndexOrI32:$indexCol,
PTODpsType:$dst
);

Expand Down Expand Up @@ -2943,8 +2933,8 @@ def TInsertOp : PTO_TOp<"tinsert", [

let arguments = (ins
PTODpsType:$src,
Index:$indexRow,
Index:$indexCol,
IndexOrI32:$indexRow,
IndexOrI32:$indexCol,
PTODpsType:$dst
);

Expand Down
Loading
Loading