Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,110 @@ def ImportReservedBufferOp : PTO_Op<"import_reserved_buffer"> {
// TPUSH/TPOP Pipe Communication Ops
//===----------------------------------------------------------------------===//

def BuildAsyncSessionOp : PTO_Op<"build_async_session", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Build an async DMA session handle for TPUT_ASYNC/TGET_ASYNC";

let arguments = (ins
TileBufOrMemRef:$scratch,
ScalarPtrOrMemRef:$workspace,
OptionalAttr<I32Attr>:$sync_id,
OptionalAttr<I64Attr>:$block_bytes,
OptionalAttr<I64Attr>:$comm_block_offset,
OptionalAttr<I32Attr>:$queue_num,
OptionalAttr<I64Attr>:$channel_group_idx
);

let results = (outs AsyncSessionType:$session);
let hasVerifier = 1;

let assemblyFormat = [{
`(` $scratch `,` $workspace `:` qualified(type($scratch)) `,` type($workspace) `)`
attr-dict `->` qualified(type($session))
}];
}

def TPutAsyncOp : PTO_Op<"tput_async", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Asynchronous remote write from local GM to remote GM";

let arguments = (ins
AnyMemRef:$dst,
AnyMemRef:$src,
AsyncSessionType:$session
);

let results = (outs AsyncEventType:$event);
let hasVerifier = 1;

let assemblyFormat = [{
`(` $dst `,` $src `,` $session `:`
type($dst) `,` type($src) `,` qualified(type($session)) `)`
attr-dict `->` qualified(type($event))
}];
}

def TGetAsyncOp : PTO_Op<"tget_async", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Asynchronous remote read from remote GM to local GM";

let arguments = (ins
AnyMemRef:$dst,
AnyMemRef:$src,
AsyncSessionType:$session
);

let results = (outs AsyncEventType:$event);
let hasVerifier = 1;

let assemblyFormat = [{
`(` $dst `,` $src `,` $session `:`
type($dst) `,` type($src) `,` qualified(type($session)) `)`
attr-dict `->` qualified(type($event))
}];
}

def WaitAsyncEventOp : PTO_Op<"wait_async_event", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Block until an async DMA event completes";

let arguments = (ins
AsyncEventType:$event,
AsyncSessionType:$session
);

let results = (outs I1:$completed);

let assemblyFormat = [{
`(` $event `,` $session `:`
qualified(type($event)) `,` qualified(type($session)) `)`
attr-dict `->` type($completed)
}];
}

def TestAsyncEventOp : PTO_Op<"test_async_event", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Non-blocking completion test for an async DMA event";

let arguments = (ins
AsyncEventType:$event,
AsyncSessionType:$session
);

let results = (outs I1:$completed);

let assemblyFormat = [{
`(` $event `,` $session `:`
qualified(type($event)) `,` qualified(type($session)) `)`
attr-dict `->` type($completed)
}];
}

def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
Expand Down
10 changes: 10 additions & 0 deletions include/PTO/IR/PTOTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,13 @@ def PipeType : TypeDef<PTO_Dialect, "Pipe"> {
let mnemonic = "pipe";
let summary = "Opaque pipe handle type for TPUSH/TPOP communication";
}

def AsyncSessionType : TypeDef<PTO_Dialect, "AsyncSession"> {
let mnemonic = "async_session";
let summary = "Opaque async DMA session handle type";
}

def AsyncEventType : TypeDef<PTO_Dialect, "AsyncEvent"> {
let mnemonic = "async_event";
let summary = "Opaque async DMA event handle type";
}
6 changes: 6 additions & 0 deletions include/pto-c/Dialect/PTO.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ bool mlirPTOTypeIsAPtrType(MlirType type);
MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType);
MlirType mlirPTOPtrTypeGetElementType(MlirType type);

// ---- !pto.async_session / !pto.async_event ----
bool mlirPTOTypeIsAAsyncSessionType(MlirType type);
MlirType mlirPTOAsyncSessionTypeGet(MlirContext ctx);
bool mlirPTOTypeIsAAsyncEventType(MlirType type);
MlirType mlirPTOAsyncEventTypeGet(MlirContext ctx);

// ---- #pto.address_space<...> ----
bool mlirPTOAttrIsAAddressSpaceAttr(MlirAttribute attr);

Expand Down
22 changes: 22 additions & 0 deletions lib/Bindings/Python/PTOModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,28 @@ PYBIND11_MODULE(_pto, m) {
return mlirPTOPtrTypeGetElementType(self);
});

mlir_type_subclass(
m, "AsyncSessionType",
[](MlirType type) -> bool { return mlirPTOTypeIsAAsyncSessionType(type); })
.def_classmethod(
"get",
[](py::object cls, MlirContext context) -> py::object {
MlirType t = mlirPTOAsyncSessionTypeGet(context);
return cls.attr("__call__")(t);
},
py::arg("cls"), py::arg("context") = py::none());

mlir_type_subclass(
m, "AsyncEventType",
[](MlirType type) -> bool { return mlirPTOTypeIsAAsyncEventType(type); })
.def_classmethod(
"get",
[](py::object cls, MlirContext context) -> py::object {
MlirType t = mlirPTOAsyncEventTypeGet(context);
return cls.attr("__call__")(t);
},
py::arg("cls"), py::arg("context") = py::none());

// --------------------------------------------------------------------------
// !pto.tensor_view<shape x elem>
// --------------------------------------------------------------------------
Expand Down
16 changes: 16 additions & 0 deletions lib/CAPI/Dialect/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ MlirType mlirPTOPtrTypeGetElementType(MlirType type) {
return wrap(t.getElementType());
}

bool mlirPTOTypeIsAAsyncSessionType(MlirType type) {
return isa<mlir::pto::AsyncSessionType>(unwrap(type));
}

MlirType mlirPTOAsyncSessionTypeGet(MlirContext ctx) {
return wrap(mlir::pto::AsyncSessionType::get(unwrap(ctx)));
}

bool mlirPTOTypeIsAAsyncEventType(MlirType type) {
return isa<mlir::pto::AsyncEventType>(unwrap(type));
}

MlirType mlirPTOAsyncEventTypeGet(MlirContext ctx) {
return wrap(mlir::pto::AsyncEventType::get(unwrap(ctx)));
}

bool mlirPTOAttrIsAAddressSpaceAttr(MlirAttribute attr) {
return mlir::isa<mlir::pto::AddressSpaceAttr>(unwrap(attr));
}
Expand Down
Loading
Loading