diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index a6312694..13fe85f1 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1373,6 +1373,110 @@ def ImportReservedBufferOp : PTO_Op<"import_reserved_buffer"> { // TPUSH/TPOP Pipe Communication Ops //===----------------------------------------------------------------------===// +def BuildAsyncSessionOp : PTO_Op<"build_async_session", [ + DeclareOpInterfaceMethods +]> { + let summary = "Build an async DMA session handle for TPUT_ASYNC/TGET_ASYNC"; + + let arguments = (ins + TileBufOrMemRef:$scratch, + ScalarPtrOrMemRef:$workspace, + OptionalAttr:$sync_id, + OptionalAttr:$block_bytes, + OptionalAttr:$comm_block_offset, + OptionalAttr:$queue_num, + OptionalAttr:$channel_group_idx + ); + + let results = (outs AsyncSessionType:$session); + let hasVerifier = 1; + + let assemblyFormat = [{ + `(` $scratch `,` $workspace `:` qualified(type($scratch)) `,` type($workspace) `)` + attr-dict `->` qualified(type($session)) + }]; +} + +def TPutAsyncOp : PTO_Op<"tput_async", [ + DeclareOpInterfaceMethods +]> { + let summary = "Asynchronous remote write from local GM to remote GM"; + + let arguments = (ins + AnyMemRef:$dst, + AnyMemRef:$src, + AsyncSessionType:$session + ); + + let results = (outs AsyncEventType:$event); + let hasVerifier = 1; + + let assemblyFormat = [{ + `(` $dst `,` $src `,` $session `:` + type($dst) `,` type($src) `,` qualified(type($session)) `)` + attr-dict `->` qualified(type($event)) + }]; +} + +def TGetAsyncOp : PTO_Op<"tget_async", [ + DeclareOpInterfaceMethods +]> { + let summary = "Asynchronous remote read from remote GM to local GM"; + + let arguments = (ins + AnyMemRef:$dst, + AnyMemRef:$src, + AsyncSessionType:$session + ); + + let results = (outs AsyncEventType:$event); + let hasVerifier = 1; + + let assemblyFormat = [{ + `(` $dst `,` $src `,` $session `:` + type($dst) `,` type($src) `,` qualified(type($session)) `)` + attr-dict `->` qualified(type($event)) + }]; +} + +def WaitAsyncEventOp : PTO_Op<"wait_async_event", [ + DeclareOpInterfaceMethods +]> { + let summary = "Block until an async DMA event completes"; + + let arguments = (ins + AsyncEventType:$event, + AsyncSessionType:$session + ); + + let results = (outs I1:$completed); + + let assemblyFormat = [{ + `(` $event `,` $session `:` + qualified(type($event)) `,` qualified(type($session)) `)` + attr-dict `->` type($completed) + }]; +} + +def TestAsyncEventOp : PTO_Op<"test_async_event", [ + DeclareOpInterfaceMethods +]> { + let summary = "Non-blocking completion test for an async DMA event"; + + let arguments = (ins + AsyncEventType:$event, + AsyncSessionType:$session + ); + + let results = (outs I1:$completed); + + let assemblyFormat = [{ + `(` $event `,` $session `:` + qualified(type($event)) `,` qualified(type($session)) `)` + attr-dict `->` type($completed) + }]; +} + def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [ DeclareOpInterfaceMethods ]> { diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 3e507fa2..928c7c83 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -221,3 +221,13 @@ def PipeType : TypeDef { let mnemonic = "pipe"; let summary = "Opaque pipe handle type for TPUSH/TPOP communication"; } + +def AsyncSessionType : TypeDef { + let mnemonic = "async_session"; + let summary = "Opaque async DMA session handle type"; +} + +def AsyncEventType : TypeDef { + let mnemonic = "async_event"; + let summary = "Opaque async DMA event handle type"; +} diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h index 42f8f391..4909cb05 100644 --- a/include/pto-c/Dialect/PTO.h +++ b/include/pto-c/Dialect/PTO.h @@ -28,6 +28,12 @@ bool mlirPTOTypeIsAPtrType(MlirType type); MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType); MlirType mlirPTOPtrTypeGetElementType(MlirType type); +// ---- !pto.async_session / !pto.async_event ---- +bool mlirPTOTypeIsAAsyncSessionType(MlirType type); +MlirType mlirPTOAsyncSessionTypeGet(MlirContext ctx); +bool mlirPTOTypeIsAAsyncEventType(MlirType type); +MlirType mlirPTOAsyncEventTypeGet(MlirContext ctx); + // ---- #pto.address_space<...> ---- bool mlirPTOAttrIsAAddressSpaceAttr(MlirAttribute attr); diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index a1655772..44e5a60c 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -450,6 +450,28 @@ PYBIND11_MODULE(_pto, m) { return mlirPTOPtrTypeGetElementType(self); }); + mlir_type_subclass( + m, "AsyncSessionType", + [](MlirType type) -> bool { return mlirPTOTypeIsAAsyncSessionType(type); }) + .def_classmethod( + "get", + [](py::object cls, MlirContext context) -> py::object { + MlirType t = mlirPTOAsyncSessionTypeGet(context); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("context") = py::none()); + + mlir_type_subclass( + m, "AsyncEventType", + [](MlirType type) -> bool { return mlirPTOTypeIsAAsyncEventType(type); }) + .def_classmethod( + "get", + [](py::object cls, MlirContext context) -> py::object { + MlirType t = mlirPTOAsyncEventTypeGet(context); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("context") = py::none()); + // -------------------------------------------------------------------------- // !pto.tensor_view // -------------------------------------------------------------------------- diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp index eef0d827..c0d8032a 100644 --- a/lib/CAPI/Dialect/PTO.cpp +++ b/lib/CAPI/Dialect/PTO.cpp @@ -65,6 +65,22 @@ MlirType mlirPTOPtrTypeGetElementType(MlirType type) { return wrap(t.getElementType()); } +bool mlirPTOTypeIsAAsyncSessionType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirPTOAsyncSessionTypeGet(MlirContext ctx) { + return wrap(mlir::pto::AsyncSessionType::get(unwrap(ctx))); +} + +bool mlirPTOTypeIsAAsyncEventType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirPTOAsyncEventTypeGet(MlirContext ctx) { + return wrap(mlir::pto::AsyncEventType::get(unwrap(ctx))); +} + bool mlirPTOAttrIsAAddressSpaceAttr(MlirAttribute attr) { return mlir::isa(unwrap(attr)); } diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index b6e75eeb..19528c4d 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -1801,6 +1801,87 @@ static SmallVector getValidShapeVec(Value value) { return valid; } +static bool isByteIntegerType(Type ty) { + auto intTy = dyn_cast(ty); + return intTy && intTy.getWidth() == 8; +} + +static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, + Value value, + StringRef name) { + auto memTy = dyn_cast(value.getType()); + if (!memTy) + return op->emitOpError() << "expects " << name << " to be a memref"; + if (!memTy.hasRank()) + return op->emitOpError() << "expects " << name << " to be a ranked memref"; + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return op->emitOpError() << "expects " << name + << " to be in GM address space"; + + ArrayRef shape = memTy.getShape(); + if (shape.empty()) + return op->emitOpError() << "expects " << name + << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + return op->emitOpError() << "expects " << name + << " to have a static shape"; + } + + SmallVector strides; + int64_t offset = 0; + if (failed(getStridesAndOffset(memTy, strides, offset))) + return op->emitOpError() << "expects " << name + << " to use a strided memref layout"; + if (offset == ShapedType::kDynamic) + return op->emitOpError() << "expects " << name + << " to have a static offset"; + for (int64_t stride : strides) { + if (stride == ShapedType::kDynamic) + return op->emitOpError() << "expects " << name + << " to have static strides"; + } + + bool packed = !strides.empty() && strides.back() == 1; + for (int i = static_cast(shape.size()) - 2; i >= 0 && packed; --i) + packed &= strides[i] == strides[i + 1] * shape[i + 1]; + if (!packed) + return op->emitOpError() + << "expects " << name + << " to be a static flat contiguous logical 1D GM memref"; + + bool logical1D = true; + for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) + logical1D &= shape[i] == 1; + if (!logical1D) + return op->emitOpError() + << "expects " << name + << " to be a static flat contiguous logical 1D GM memref"; + + return success(); +} + +static std::optional getStaticByteSize(Type ty) { + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return std::nullopt; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim < 0) + return std::nullopt; + } + + Type elemTy = getElemTy(ty); + uint64_t elemBytes = getElemByteSize(elemTy); + if (elemBytes == 0) + return std::nullopt; + + uint64_t total = elemBytes; + for (int64_t dim : shape) { + total *= static_cast(dim); + } + return total; +} + static std::optional getPTOMemorySpaceEnum(Type ty) { if (auto tb = dyn_cast(ty)) { if (auto as = dyn_cast_or_null(tb.getMemorySpace())) @@ -8380,6 +8461,94 @@ static LogicalResult verifyPipeHandleProducer(Operation *op, Value pipeHandle) { return success(); } +LogicalResult BuildAsyncSessionOp::verify() { + Type scratchTy = getScratch().getType(); + if (!isa(scratchTy)) + return emitOpError("expects scratch to be tile_buf or memref type"); + + auto scratchSpace = getPTOMemorySpaceEnum(scratchTy); + if (!scratchSpace || *scratchSpace != pto::AddressSpace::VEC) + return emitOpError("expects scratch to be in vec address space"); + + auto scratchShape = getShapeVec(scratchTy); + if (scratchShape.empty() || scratchShape.size() > 2) + return emitOpError("expects scratch to be rank-1 or rank-2"); + for (int64_t dim : scratchShape) { + if (dim == ShapedType::kDynamic) + return emitOpError("expects scratch to have a static shape"); + } + + auto scratchBytes = getStaticByteSize(scratchTy); + if (!scratchBytes) + return emitOpError("expects scratch byte size to be statically known"); + if (*scratchBytes < sizeof(uint64_t)) + return emitOpError("expects scratch to provide at least 8 bytes"); + + Type workspaceElemTy; + Type workspaceTy = getWorkspace().getType(); + if (auto ptrTy = dyn_cast(workspaceTy)) { + workspaceElemTy = ptrTy.getElementType(); + } else if (auto memTy = dyn_cast(workspaceTy)) { + workspaceElemTy = memTy.getElementType(); + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError("expects workspace to be in GM address space"); + } else { + return emitOpError("expects workspace to be !pto.ptr or memref type"); + } + if (!isByteIntegerType(workspaceElemTy)) + return emitOpError("expects workspace element type to be an 8-bit integer"); + + if (auto syncIdAttr = getSyncIdAttr()) { + int64_t syncId = syncIdAttr.getInt(); + if (syncId < 0 || syncId > 7) + return emitOpError("expects sync_id in range [0, 7]"); + } + if (auto blockBytesAttr = getBlockBytesAttr()) { + if (blockBytesAttr.getInt() <= 0) + return emitOpError("expects block_bytes to be greater than 0"); + } + if (auto commBlockOffsetAttr = getCommBlockOffsetAttr()) { + if (commBlockOffsetAttr.getInt() < 0) + return emitOpError("expects comm_block_offset to be non-negative"); + } + if (auto queueNumAttr = getQueueNumAttr()) { + if (queueNumAttr.getInt() <= 0) + return emitOpError("expects queue_num to be greater than 0"); + } + if (auto channelGroupIdxAttr = getChannelGroupIdxAttr()) { + APInt value = channelGroupIdxAttr.getValue(); + if (value.isNegative()) + return emitOpError("expects channel_group_idx to be non-negative"); + if (value.ugt(UINT32_MAX)) + return emitOpError("expects channel_group_idx to fit in uint32"); + } + + return success(); +} + +static LogicalResult verifyAsyncTransferOp(Operation *op, Value dst, Value src) { + auto dstTy = dyn_cast(dst.getType()); + auto srcTy = dyn_cast(src.getType()); + if (!dstTy || !srcTy) + return op->emitOpError("expects src and dst to be memref types"); + if (dstTy.getElementType() != srcTy.getElementType()) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyAsyncFlatContiguous1DGMMemRef(op, dst, "dst")) || + failed(verifyAsyncFlatContiguous1DGMMemRef(op, src, "src"))) + return failure(); + if (dstTy.getShape() != srcTy.getShape()) + return op->emitOpError("expects src and dst to have the same static shape"); + return success(); +} + +LogicalResult TPutAsyncOp::verify() { + return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); +} + +LogicalResult TGetAsyncOp::verify() { + return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); +} + LogicalResult AicInitializePipeOp::verify() { return verifyFrontendInitCommon(*this, FunctionKernelKind::Cube, "cube"); } @@ -8490,6 +8659,48 @@ LogicalResult TFreeOp::verify() { return verifySplitAttr(getOperation(), getSplit()); } +void BuildAsyncSessionOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getScratchMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getWorkspaceMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TPutAsyncOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TGetAsyncOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void WaitAsyncEventOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getEventMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TestAsyncEventOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getEventMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + void InitializeL2G2LPipeOp::getEffects( SmallVectorImpl> &effects) { diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index cedf3e95..c0389f8b 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -306,7 +306,8 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { UpdateOpGenInfo(curOpInfo, llvm::to_vector(callOp->getOperands())); OpKillHandle(curOpInfo, live, op->getBlock()); } else if (isa(op)) { + pto::InitializeL2G2LPipeOp, pto::BuildAsyncSessionOp, + pto::TPutAsyncOp, pto::TGetAsyncOp>(op)) { UpdateOpGenInfo(curOpInfo, llvm::to_vector(op->getOperands())); OpKillHandle(curOpInfo, live, op->getBlock()); } else if (auto gpuLaunchOp = dyn_cast(op)) { diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 4f0ca524..8ff9910d 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -9,6 +9,8 @@ //===- PTOToEmitC.cpp - PTO to EmitC conversion pass ----------------------===// //===----------------------------------------------------------------------===// +#include + #include "PTO/IR/PTO.h" #include "PTO/IR/PTOSyncUtils.h" #include "PTO/Transforms/Passes.h" @@ -249,6 +251,16 @@ class PTOToEmitCTypeConverter : public TypeConverter { std::string tok = "PTOAS_EventIdArray<" + std::to_string(type.getSize()) + ">"; return emitc::OpaqueType::get(Ctx, tok); }); + + addConversion([Ctx](pto::AsyncSessionType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncSession"); + }); + + addConversion([Ctx](pto::AsyncEventType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncEvent"); + }); // --------------------------------------------------------- // 3. MemRef 转换 (Debug 重点) @@ -3099,11 +3111,6 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, offVal); } - std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); - std::string shapeTypeName = "GTShape" + suffix; - std::string strideTypeName = "GTStride" + suffix; - std::string gtTypeName = "GT" + suffix; - std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); SmallVector shapeParamsVec; @@ -3148,10 +3155,8 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, std::string shapeParams = joinParams(finalShape); std::string strideParams = joinParams(finalStride); - rewriter.create( - loc, "using " + shapeTypeName + " = pto::Shape<" + shapeParams + ">;"); - rewriter.create( - loc, "using " + strideTypeName + " = pto::Stride<" + strideParams + ">;"); + std::string shapeCppType = "pto::Shape<" + shapeParams + ">"; + std::string strideCppType = "pto::Stride<" + strideParams + ">"; // Layout: prefer the attribute from InferPTOLayout; only fall back to local // inference when the pass is disabled. @@ -3192,24 +3197,18 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, else if (layoutTag == 2) layoutEnum = "pto::Layout::NZ"; } - std::string layoutConstName = gtTypeName + "_layout"; - rewriter.create( - loc, "constexpr pto::Layout " + layoutConstName + " = " + layoutEnum + ";"); - - auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeTypeName); - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideTypeName); + auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeCppType); + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideCppType); auto shapeInstOp = rewriter.create( - loc, shapeTypeOpaque, shapeTypeName, ArrayAttr{}, ArrayAttr{}, + loc, shapeTypeOpaque, shapeCppType, ArrayAttr{}, ArrayAttr{}, ValueRange{}); auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, strideTypeName, ArrayAttr{}, ArrayAttr{}, + loc, strideTypeOpaque, strideCppType, ArrayAttr{}, ArrayAttr{}, ValueRange{}); - rewriter.create( - loc, "using " + gtTypeName + " = GlobalTensor<" + elemTypeStr + ", " + - shapeTypeName + ", " + strideTypeName + ", " + - layoutConstName + ">;"); - auto gtType = emitc::OpaqueType::get(ctx, gtTypeName); + std::string gtCppType = "GlobalTensor<" + elemTypeStr + ", " + shapeCppType + + ", " + strideCppType + ", " + layoutEnum + ">"; + auto gtType = emitc::OpaqueType::get(ctx, gtCppType); SmallVector gtArgs; gtArgs.push_back(ptr); @@ -3217,11 +3216,131 @@ static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, gtArgs.push_back(strideInstOp.getResult(0)); auto gtInst = rewriter.create( - loc, gtType, gtTypeName, ArrayAttr{}, ArrayAttr{}, ValueRange(gtArgs)); + loc, gtType, gtCppType, ArrayAttr{}, ArrayAttr{}, ValueRange(gtArgs)); return gtInst.getResult(0); } +static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, + Location loc, Value value) { + auto *ctx = rewriter.getContext(); + auto targetTy = emitc::OpaqueType::get(ctx, "__gm__ uint8_t*"); + if (value.getType() == targetTy) + return value; + + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "__gm__ uint8_t*")}); + if (isSetFFTsPointerLikeType(value.getType())) { + return rewriter + .create(loc, targetTy, "reinterpret_cast", + ArrayAttr{}, castTyAttr, + ValueRange{value}) + .getResult(0); + } + return rewriter.create(loc, targetTy, value).getResult(); +} + +static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string blTok = "BLayout::RowMajor"; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) { + if (static_cast(blAttr.getValue()) == 1) + blTok = "BLayout::ColMajor"; + } + return blTok; +} + +static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string slTok = "SLayout::NoneBox"; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) { + int32_t slVal = static_cast(slAttr.getValue()); + slTok = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? "SLayout::ColMajor" + : "SLayout::NoneBox"; + } + return slTok; +} + +static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { + std::string padTok = "PadValue::Null"; + if (auto padAttr = dyn_cast(configAttr.getPad())) { + switch (static_cast(padAttr.getValue())) { + case 1: + padTok = "PadValue::Zero"; + break; + case 2: + padTok = "PadValue::Max"; + break; + case 3: + padTok = "PadValue::Min"; + break; + default: + padTok = "PadValue::Null"; + break; + } + } + return padTok; +} + +static FailureOr buildAsyncScratchTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, + Value emittedScratch) { + Value scratch = peelUnrealized(emittedScratch); + if (auto opaqueTy = dyn_cast(scratch.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return scratch; + } + + auto memTy = dyn_cast(originalScratch.getType()); + if (!memTy) + return failure(); + + ArrayRef shape = memTy.getShape(); + if (!memTy.hasStaticShape() || shape.empty() || shape.size() > 2) + return failure(); + + int64_t rows = shape.size() == 1 ? 1 : shape[0]; + int64_t cols = shape.size() == 1 ? shape[0] : shape[1]; + + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalScratch.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalScratch.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; + } + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + std::string elemTypeStr = getElemTypeStringForGT(memTy.getElementType()); + std::string tileTypeStr = + "Tile"; + + Value tile = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, tileTypeStr), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + Value scratchAddr = + rewriter + .create(loc, emitc::OpaqueType::get(ctx, "uint64_t"), + "reinterpret_cast", ArrayAttr{}, addr, + ValueRange{scratch}) + .getResult(0); + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, scratchAddr}); + return tile; +} + //===----------------------------------------------------------------------===// // pto.pointer_cast lowering //===----------------------------------------------------------------------=== @@ -4673,6 +4792,148 @@ struct PTOInitializeL2LPipeToEmitC PTOArch targetArch; }; +struct PTOBuildAsyncSessionToEmitC + : public OpConversionPattern { + PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) {} + + LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + auto sessionTy = + dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); + if (!sessionTy) + return rewriter.notifyMatchFailure(op, "failed to convert async session type"); + + FailureOr scratchTile = + buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), + adaptor.getScratch()); + if (failed(scratchTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); + + Value workspace = + castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); + + Value session = rewriter + .create( + loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + + auto makeU32Const = [&](uint64_t value) -> Value { + return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, + std::to_string(value) + "u"); + }; + uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; + uint64_t blockBytes = + op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; + uint64_t commBlockOffset = + op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; + uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; + uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() + ? op.getChannelGroupIdxAttr().getInt() + : UINT32_MAX; + + Value syncIdVal = makeU32Const(syncId); + Value channelGroupIdxVal = + channelGroupIdx == UINT32_MAX + ? makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") + : makeU32Const(channelGroupIdx); + + auto baseConfigTy = + emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); + Value baseConfig = + rewriter + .create( + loc, baseConfigTy, + emitc::OpaqueAttr::get( + ctx, "{" + std::to_string(blockBytes) + "ULL, " + + std::to_string(commBlockOffset) + "ULL, " + + std::to_string(queueNum) + "u}")) + .getResult(); + + rewriter.create( + loc, TypeRange{}, "pto::comm::BuildAsyncSession", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, + channelGroupIdxVal}); + + rewriter.replaceOp(op, session); + return success(); + } +}; + +template +struct PTOAsyncTransferToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + auto dstMrTy = dyn_cast(op.getDst().getType()); + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!dstMrTy || !srcMrTy) + return rewriter.notifyMatchFailure(op, "expected memref src/dst"); + + Value dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getDst().getDefiningOp() + ? op.getDst().getDefiningOp() + : op.getOperation()); + Value srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? op.getSrc().getDefiningOp() + : op.getOperation()); + if (!dstGT || !srcGT) + return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); + + Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +template +struct PTOAsyncEventToEmitC : public OpConversionPattern { + explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncEventOp op, + typename AsyncEventOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + this->getTypeConverter()->convertType(op.getCompleted().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getEvent()), + peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + struct PTODeclareTileMemRefToEmitC : public OpConversionPattern { using OpConversionPattern< @@ -8619,6 +8880,17 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx, + "pto::comm::TPUT_ASYNC"); + patterns.add>( + typeConverter, ctx, + "pto::comm::TGET_ASYNC"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); patterns.add(typeConverter, ctx, targetArch); patterns.add(typeConverter, ctx, targetArch); patterns.add(typeConverter, ctx); @@ -8708,8 +8980,9 @@ struct EmitPTOManualPass } bool needsEventIdArrayHelper = false; - mop.walk([&](mlir::pto::DeclareEventIdArrayOp) { - needsEventIdArrayHelper = true; + mop.walk([&](Operation *op) { + if (isa(op)) + needsEventIdArrayHelper = true; }); // 1. 插入头文件 @@ -8754,7 +9027,6 @@ static AICORE inline void ptoas_auto_sync_tail( } } )cpp")); - // Only inject the bitcast helper when we actually lower ops that need it // (e.g. arith.bitcast or arith.maximumf/minimumf tie-breaking on zeros). bool needsBitcastHelper = false; diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index 7cf53e5e..39b6ba25 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -31,6 +31,8 @@ def _load_local_pto_ext(): register_dialect = _pto_mod.register_dialect PtrType = _pto_mod.PtrType +AsyncSessionType = _pto_mod.AsyncSessionType +AsyncEventType = _pto_mod.AsyncEventType TensorViewType = _pto_mod.TensorViewType PartitionTensorViewType = _pto_mod.PartitionTensorViewType TileType = _pto_mod.TileType @@ -65,6 +67,8 @@ def _load_local_pto_ext(): # Types "PtrType", + "AsyncSessionType", + "AsyncEventType", "TensorViewType", "PartitionTensorViewType", "TileType", diff --git a/test/basic/async_put_get_emitc.pto b/test/basic/async_put_get_emitc.pto new file mode 100644 index 00000000..710dacd7 --- /dev/null +++ b/test/basic/async_put_get_emitc.pto @@ -0,0 +1,35 @@ +// RUN: ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s --check-prefix=A3 + +module { + func.func @async_put_get(%dst: memref<128xf32, #pto.address_space>, + %src: memref<128xf32, #pto.address_space>, + %workspace: memref<1024xi8, #pto.address_space>) { + %scratch = pto.alloc_tile : !pto.tile_buf + %session = pto.build_async_session(%scratch, %workspace : !pto.tile_buf, memref<1024xi8, #pto.address_space>) -> !pto.async_session + %put = pto.tput_async(%dst, %src, %session : memref<128xf32, #pto.address_space>, memref<128xf32, #pto.address_space>, !pto.async_session) -> !pto.async_event + %get = pto.tget_async(%src, %dst, %session : memref<128xf32, #pto.address_space>, memref<128xf32, #pto.address_space>, !pto.async_session) -> !pto.async_event + %put_done = pto.wait_async_event(%put, %session : !pto.async_event, !pto.async_session) -> i1 + %get_done = pto.test_async_event(%get, %session : !pto.async_event, !pto.async_session) -> i1 + return + } +} + +// A3-LABEL: AICORE void async_put_get( +// A3: Tile [[SCRATCH:v[0-9]+]]; +// A3: TASSIGN([[SCRATCH]], [[SCRATCH_ADDR:v[0-9]+]]); +// A3: pto::comm::AsyncSession [[SESSION:v[0-9]+]]; +// A3: pto::comm::sdma::SdmaBaseConfig [[CFG:v[0-9]+]] = {32768ULL, 0ULL, 1u}; +// A3: pto::comm::BuildAsyncSession([[SCRATCH]], {{.*}}, [[SESSION]], {{.*}}, [[CFG]], {{.*}}); +// A3-NOT: using GTShape_ +// A3-NOT: using GTStride_ +// A3-NOT: using GT_ +// A3: pto::Shape<1, 1, 1, 1, 128> [[SHAPE0:v[0-9]+]] = pto::Shape<1, 1, 1, 1, 128>(); +// A3: pto::Stride<128, 128, 128, 128, 1> [[STRIDE0:v[0-9]+]] = pto::Stride<128, 128, 128, 128, 1>(); +// A3: GlobalTensor, pto::Stride<128, 128, 128, 128, 1>, pto::Layout::ND> [[GT0:v[0-9]+]] = GlobalTensor, pto::Stride<128, 128, 128, 128, 1>, pto::Layout::ND>({{.*}}, [[SHAPE0]], [[STRIDE0]]); +// A3: pto::Shape<1, 1, 1, 1, 128> [[SHAPE1:v[0-9]+]] = pto::Shape<1, 1, 1, 1, 128>(); +// A3: pto::Stride<128, 128, 128, 128, 1> [[STRIDE1:v[0-9]+]] = pto::Stride<128, 128, 128, 128, 1>(); +// A3: GlobalTensor, pto::Stride<128, 128, 128, 128, 1>, pto::Layout::ND> [[GT1:v[0-9]+]] = GlobalTensor, pto::Stride<128, 128, 128, 128, 1>, pto::Layout::ND>({{.*}}, [[SHAPE1]], [[STRIDE1]]); +// A3: pto::comm::AsyncEvent [[PUT_EVT:v[0-9]+]] = pto::comm::TPUT_ASYNC( +// A3: pto::comm::AsyncEvent [[GET_EVT:v[0-9]+]] = pto::comm::TGET_ASYNC( +// A3: bool [[PUT_DONE:v[0-9]+]] = [[PUT_EVT]].Wait([[SESSION]]); +// A3: bool [[GET_DONE:v[0-9]+]] = [[GET_EVT]].Test([[SESSION]]); diff --git a/test/basic/async_put_invalid_non_1d.pto b/test/basic/async_put_invalid_non_1d.pto new file mode 100644 index 00000000..b99e79d0 --- /dev/null +++ b/test/basic/async_put_invalid_non_1d.pto @@ -0,0 +1,14 @@ +// RUN: not ptoas %s 2>&1 | FileCheck %s + +module { + func.func @bad_async_put(%dst: memref<4x32xf32, #pto.address_space>, + %src: memref<4x32xf32, #pto.address_space>, + %workspace: memref<1024xi8, #pto.address_space>) { + %scratch = pto.alloc_tile : !pto.tile_buf + %session = pto.build_async_session(%scratch, %workspace : !pto.tile_buf, memref<1024xi8, #pto.address_space>) -> !pto.async_session + %event = pto.tput_async(%dst, %src, %session : memref<4x32xf32, #pto.address_space>, memref<4x32xf32, #pto.address_space>, !pto.async_session) -> !pto.async_event + return + } +} + +// CHECK: error: 'pto.tput_async' op expects dst to be a static flat contiguous logical 1D GM memref diff --git a/test/npu_validation/scripts/run_remote_npu_validation.sh b/test/npu_validation/scripts/run_remote_npu_validation.sh index c86fc6ab..6873d505 100644 --- a/test/npu_validation/scripts/run_remote_npu_validation.sh +++ b/test/npu_validation/scripts/run_remote_npu_validation.sh @@ -232,6 +232,16 @@ while IFS= read -r -d '' cpp; do testcase="${testcase%-pto}" testcase="${testcase%_pto}" + # AsyncComm smoke sample issues async remote DMA against plain local buffers. + # In board-runtime STAGE=run this can trigger invalid MPU access on single-rank + # execution, so skip it in runtime stage. + if [[ "${STAGE}" == "run" && "${testcase}" == "async_comm" ]]; then + skip_count=$((skip_count + 1)) + printf "%s\tSKIP\t%s\truntime skip: async_comm\n" "${testcase}" "${STAGE}" >> "${RESULTS_TSV}" + log "SKIP: ${testcase} (runtime skip)" + continue + fi + if [[ -n "${RUN_ONLY_CASES_NORM}" ]] && ! list_contains "${RUN_ONLY_CASES_NORM}" "${testcase}"; then continue fi diff --git a/test/samples/AsyncComm/async_comm.py b/test/samples/AsyncComm/async_comm.py new file mode 100644 index 00000000..d3246346 --- /dev/null +++ b/test/samples/AsyncComm/async_comm.py @@ -0,0 +1,64 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, IntegerType, MemRefType +from mlir.dialects import arith, func, pto, scf + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + f32 = F32Type.get(ctx) + i8 = IntegerType.get_signless(8, ctx) + + gm = pto.AddressSpaceAttr.get(pto.AddressSpace.GM, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + + data_ty = MemRefType.get([128], f32, memory_space=gm) + workspace_ty = MemRefType.get([1024], i8, memory_space=gm) + scratch_ty = pto.TileBufType.get([1, 256], i8, vec, [1, 256], None, ctx) + + i32 = IntegerType.get_signless(32, ctx) + + fn_ty = func.FunctionType.get([data_ty, data_ty, workspace_ty, i32], []) + with InsertionPoint(module.body): + fn = func.FuncOp("async_comm_kernel", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + dst, src, workspace, nranks = entry.arguments + c1_i32 = arith.ConstantOp(i32, 1).result + single_rank = arith.CmpIOp( + arith.CmpIPredicate.sle, nranks, c1_i32 + ).result + guarded = scf.IfOp(single_rank, [], hasElse=True) + + with InsertionPoint(guarded.then_block): + scf.YieldOp([]) + + with InsertionPoint(guarded.else_block): + scratch = pto.AllocTileOp(scratch_ty).result + session = pto.BuildAsyncSessionOp(scratch, workspace).result + put_event = pto.TPutAsyncOp(dst, src, session).result + get_event = pto.TGetAsyncOp(src, dst, session).result + pto.WaitAsyncEventOp(put_event, session) + pto.TestAsyncEventOp(get_event, session) + scf.YieldOp([]) + + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/AsyncComm/tget_async_kernel_impl_like.py b/test/samples/AsyncComm/tget_async_kernel_impl_like.py new file mode 100644 index 00000000..fd18b63e --- /dev/null +++ b/test/samples/AsyncComm/tget_async_kernel_impl_like.py @@ -0,0 +1,214 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import ( + Context, + IndexType, + InsertionPoint, + IntegerAttr, + IntegerType, + Location, + MemRefType, + Module, + Operation, + Type, + F32Type, +) +from mlir.dialects import arith, func, pto, scf + + +def _build_async_session(scratch, workspace, i32, sync_id=0): + if hasattr(pto, "BuildAsyncSessionOp"): + return pto.BuildAsyncSessionOp(scratch, workspace, sync_id=sync_id).result + if hasattr(pto, "build_async_session"): + return pto.build_async_session(scratch, workspace, sync_id=sync_id) + op = Operation.create( + "pto.build_async_session", + operands=[scratch, workspace], + attributes={"sync_id": IntegerAttr.get(i32, sync_id)}, + results=[Type.parse("!pto.async_session")], + ) + return op.result + + +def _tget_async(dst, src, session): + if hasattr(pto, "TGetAsyncOp"): + return pto.TGetAsyncOp(dst, src, session).result + if hasattr(pto, "tget_async"): + return pto.tget_async(dst, src, session) + op = Operation.create( + "pto.tget_async", + operands=[dst, src, session], + results=[Type.parse("!pto.async_event")], + ) + return op.result + + +def _wait_async_event(event, session): + if hasattr(pto, "WaitAsyncEventOp"): + return pto.WaitAsyncEventOp(event, session).result + if hasattr(pto, "wait_async_event"): + return pto.wait_async_event(event, session) + op = Operation.create( + "pto.wait_async_event", + operands=[event, session], + results=[IntegerType.get_signless(1)], + ) + return op.result + + +def _wait_after_async(event, session): + _wait_async_event(event, session) + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + f32 = F32Type.get(ctx) + i8 = IntegerType.get_signless(8, ctx) + i32 = IntegerType.get_signless(32, ctx) + idx = IndexType.get(ctx) + + gm = pto.AddressSpaceAttr.get(pto.AddressSpace.GM, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + pipe_all = pto.PipeAttr.get(pto.PIPE.PIPE_ALL, ctx) + + data_ty = MemRefType.get([256], f32, memory_space=gm) + workspace_ty = MemRefType.get([1024], i8, memory_space=gm) + scratch_ty = pto.TileBufType.get([1, 256], i8, vec, [1, 256], None, ctx) + + fn_ty = func.FunctionType.get( + [ + data_ty, # dst_from_rank1 + data_ty, # dst_from_rank2 + data_ty, # dst_from_rank3 + data_ty, # src_rank1 + data_ty, # src_rank2 + data_ty, # src_rank3 + workspace_ty, + i32, # nranks + i32, # root_rank + i32, # my_rank + i32, # elem_offset + i32, # elem_count + ], + [], + ) + + with InsertionPoint(module.body): + fn = func.FuncOp("tget_async_kernel_impl_like", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + ( + dst_rank1, + dst_rank2, + dst_rank3, + src_rank1, + src_rank2, + src_rank3, + workspace, + nranks, + root_rank, + my_rank, + elem_offset, + elem_count, + ) = entry.arguments + + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c0_i32 = arith.ConstantOp(i32, 0).result + c1_i32 = arith.ConstantOp(i32, 1).result + c2_i32 = arith.ConstantOp(i32, 2).result + c3_i32 = arith.ConstantOp(i32, 3).result + c256_i32 = arith.ConstantOp(i32, 256).result + + count_gt_zero = arith.CmpIOp( + arith.CmpIPredicate.sgt, elem_count, c0_i32 + ).result + offset_ge_zero = arith.CmpIOp( + arith.CmpIPredicate.sge, elem_offset, c0_i32 + ).result + end_index = arith.AddIOp(elem_offset, elem_count).result + end_le_bound = arith.CmpIOp( + arith.CmpIPredicate.sle, end_index, c256_i32 + ).result + valid = arith.AndIOp( + arith.AndIOp(count_gt_zero, offset_ge_zero).result, end_le_bound + ).result + + valid_if = scf.IfOp(valid, [], hasElse=False) + with InsertionPoint(valid_if.then_block): + pto.barrier(pipe_all) + + scratch = pto.AllocTileOp(scratch_ty).result + session = _build_async_session(scratch, workspace, i32, sync_id=0) + + is_root = arith.CmpIOp( + arith.CmpIPredicate.eq, my_rank, root_rank + ).result + root_if = scf.IfOp(is_root, [], hasElse=False) + + with InsertionPoint(root_if.then_block): + nranks_idx = arith.IndexCastOp(idx, nranks).result + loop = scf.ForOp(c0, nranks_idx, c1, []) + with InsertionPoint(loop.body): + target_rank = loop.induction_variable + target_rank_i32 = arith.IndexCastOp(i32, target_rank).result + is_not_self = arith.CmpIOp( + arith.CmpIPredicate.ne, target_rank_i32, root_rank + ).result + target_if = scf.IfOp(is_not_self, [], hasElse=False) + + with InsertionPoint(target_if.then_block): + is_rank1 = arith.CmpIOp( + arith.CmpIPredicate.eq, target_rank_i32, c1_i32 + ).result + rank1_if = scf.IfOp(is_rank1, [], hasElse=False) + with InsertionPoint(rank1_if.then_block): + event1 = _tget_async(dst_rank1, src_rank1, session) + _wait_after_async(event1, session) + scf.YieldOp([]) + + is_rank2 = arith.CmpIOp( + arith.CmpIPredicate.eq, target_rank_i32, c2_i32 + ).result + rank2_if = scf.IfOp(is_rank2, [], hasElse=False) + with InsertionPoint(rank2_if.then_block): + event2 = _tget_async(dst_rank2, src_rank2, session) + _wait_after_async(event2, session) + scf.YieldOp([]) + + is_rank3 = arith.CmpIOp( + arith.CmpIPredicate.eq, target_rank_i32, c3_i32 + ).result + rank3_if = scf.IfOp(is_rank3, [], hasElse=False) + with InsertionPoint(rank3_if.then_block): + event3 = _tget_async(dst_rank3, src_rank3, session) + _wait_after_async(event3, session) + scf.YieldOp([]) + + scf.YieldOp([]) + scf.YieldOp([]) + scf.YieldOp([]) + + pto.barrier(pipe_all) + scf.YieldOp([]) + + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/AsyncComm/tput_async_kernel_impl_like.py b/test/samples/AsyncComm/tput_async_kernel_impl_like.py new file mode 100644 index 00000000..3eacb9c4 --- /dev/null +++ b/test/samples/AsyncComm/tput_async_kernel_impl_like.py @@ -0,0 +1,209 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from mlir.ir import ( + Context, + IndexType, + InsertionPoint, + IntegerAttr, + IntegerType, + Location, + MemRefType, + Module, + Operation, + Type, + F32Type, +) +from mlir.dialects import arith, func, pto, scf + + +def _build_async_session(scratch, workspace, i32, sync_id=0): + if hasattr(pto, "BuildAsyncSessionOp"): + return pto.BuildAsyncSessionOp(scratch, workspace, sync_id=sync_id).result + if hasattr(pto, "build_async_session"): + return pto.build_async_session(scratch, workspace, sync_id=sync_id) + op = Operation.create( + "pto.build_async_session", + operands=[scratch, workspace], + attributes={"sync_id": IntegerAttr.get(i32, sync_id)}, + results=[Type.parse("!pto.async_session")], + ) + return op.result + + +def _tput_async(dst, src, session): + if hasattr(pto, "TPutAsyncOp"): + return pto.TPutAsyncOp(dst, src, session).result + if hasattr(pto, "tput_async"): + return pto.tput_async(dst, src, session) + op = Operation.create( + "pto.tput_async", + operands=[dst, src, session], + results=[Type.parse("!pto.async_event")], + ) + return op.result + + +def _wait_async_event(event, session): + if hasattr(pto, "WaitAsyncEventOp"): + return pto.WaitAsyncEventOp(event, session).result + if hasattr(pto, "wait_async_event"): + return pto.wait_async_event(event, session) + op = Operation.create( + "pto.wait_async_event", + operands=[event, session], + results=[IntegerType.get_signless(1)], + ) + return op.result + + +def _wait_after_async(event, session): + _wait_async_event(event, session) + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + module = Module.create() + + f32 = F32Type.get(ctx) + i8 = IntegerType.get_signless(8, ctx) + i32 = IntegerType.get_signless(32, ctx) + idx = IndexType.get(ctx) + + gm = pto.AddressSpaceAttr.get(pto.AddressSpace.GM, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + pipe_all = pto.PipeAttr.get(pto.PIPE.PIPE_ALL, ctx) + + data_ty = MemRefType.get([256], f32, memory_space=gm) + workspace_ty = MemRefType.get([1024], i8, memory_space=gm) + scratch_ty = pto.TileBufType.get([1, 256], i8, vec, [1, 256], None, ctx) + + fn_ty = func.FunctionType.get( + [ + data_ty, + data_ty, + data_ty, + data_ty, + workspace_ty, + i32, # nranks + i32, # root_rank + i32, # my_rank + i32, # elem_offset + i32, # elem_count + ], + [], + ) + + with InsertionPoint(module.body): + fn = func.FuncOp("tput_async_kernel_impl_like", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + ( + dst_rank1, + dst_rank2, + dst_rank3, + src, + workspace, + nranks, + root_rank, + my_rank, + elem_offset, + elem_count, + ) = entry.arguments + + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c0_i32 = arith.ConstantOp(i32, 0).result + c1_i32 = arith.ConstantOp(i32, 1).result + c2_i32 = arith.ConstantOp(i32, 2).result + c3_i32 = arith.ConstantOp(i32, 3).result + c256_i32 = arith.ConstantOp(i32, 256).result + + count_gt_zero = arith.CmpIOp( + arith.CmpIPredicate.sgt, elem_count, c0_i32 + ).result + offset_ge_zero = arith.CmpIOp( + arith.CmpIPredicate.sge, elem_offset, c0_i32 + ).result + end_index = arith.AddIOp(elem_offset, elem_count).result + end_le_bound = arith.CmpIOp( + arith.CmpIPredicate.sle, end_index, c256_i32 + ).result + valid = arith.AndIOp( + arith.AndIOp(count_gt_zero, offset_ge_zero).result, end_le_bound + ).result + + valid_if = scf.IfOp(valid, [], hasElse=False) + + with InsertionPoint(valid_if.then_block): + scratch = pto.AllocTileOp(scratch_ty).result + session = _build_async_session(scratch, workspace, i32, sync_id=0) + + is_root = arith.CmpIOp( + arith.CmpIPredicate.eq, my_rank, root_rank + ).result + root_if = scf.IfOp(is_root, [], hasElse=False) + + with InsertionPoint(root_if.then_block): + nranks_idx = arith.IndexCastOp(idx, nranks).result + loop = scf.ForOp(c0, nranks_idx, c1, []) + with InsertionPoint(loop.body): + target_rank = loop.induction_variable + target_rank_i32 = arith.IndexCastOp(i32, target_rank).result + is_not_self = arith.CmpIOp( + arith.CmpIPredicate.ne, target_rank_i32, root_rank + ).result + target_if = scf.IfOp(is_not_self, [], hasElse=False) + + with InsertionPoint(target_if.then_block): + is_rank1 = arith.CmpIOp( + arith.CmpIPredicate.eq, target_rank_i32, c1_i32 + ).result + rank1_if = scf.IfOp(is_rank1, [], hasElse=False) + with InsertionPoint(rank1_if.then_block): + event1 = _tput_async(dst_rank1, src, session) + _wait_after_async(event1, session) + scf.YieldOp([]) + + is_rank2 = arith.CmpIOp( + arith.CmpIPredicate.eq, target_rank_i32, c2_i32 + ).result + rank2_if = scf.IfOp(is_rank2, [], hasElse=False) + with InsertionPoint(rank2_if.then_block): + event2 = _tput_async(dst_rank2, src, session) + _wait_after_async(event2, session) + scf.YieldOp([]) + + is_rank3 = arith.CmpIOp( + arith.CmpIPredicate.eq, target_rank_i32, c3_i32 + ).result + rank3_if = scf.IfOp(is_rank3, [], hasElse=False) + with InsertionPoint(rank3_if.then_block): + event3 = _tput_async(dst_rank3, src, session) + _wait_after_async(event3, session) + scf.YieldOp([]) + + scf.YieldOp([]) + scf.YieldOp([]) + scf.YieldOp([]) + + scf.YieldOp([]) + + pto.barrier(pipe_all) + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 06ee16ec..c8ea9cf6 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1,10 +1,10 @@ -//===- ptoas.cpp -------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" @@ -372,6 +372,17 @@ static void rewriteTileGetSetValueMarkers(std::string &cpp) { } } +static void rewriteAsyncEventMarkers(std::string &cpp) { + bool changed = true; + while (changed) { + changed = false; + changed |= rewriteMarkerCallToMember( + cpp, "PTOAS__ASYNC_EVENT_WAIT", "Wait", /*expectedNumArgs=*/2); + changed |= rewriteMarkerCallToMember( + cpp, "PTOAS__ASYNC_EVENT_TEST", "Test", /*expectedNumArgs=*/2); + } +} + // -------------------------------------------------------------------------- // EmitC cleanup: drop empty emitc.expression ops. // @@ -1148,6 +1159,7 @@ int main(int argc, char **argv) { } cppOS.flush(); rewriteTileGetSetValueMarkers(cppOutput); + rewriteAsyncEventMarkers(cppOutput); rewritePtrScalarMarkers(cppOutput); rewriteEventIdArrayMarkers(cppOutput); rewriteAddPtrTraceMarkers(cppOutput, emitAddPtrTrace);