Skip to content

Commit ff9be95

Browse files
fix: add pre-loop event init and tail event drain for back-edge sync (#428)
The generated C++ Cube kernel hangs with --enable-insert-sync because: 1. Back-edge wait_flag ops hoisted to loop heads fire on iteration 0 against uninitialized event registers. This patch adds a pre-loop set_flag for each back-edge wait to initialize the event before the first iteration. 2. The tail sync (ptoas_auto_sync_tail with kBarrierAll) only emits pipe_barrier(PIPE_ALL) which does not drain pending event flag registers. This patch adds explicit wait_flag calls at the function tail to drain all back-edge event dependencies. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 927af18 commit ff9be95

2 files changed

Lines changed: 181 additions & 5 deletions

File tree

include/PTO/Transforms/InsertSync/SyncCodegen.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ class SyncCodegen {
5757

5858
// Insert the compiler tail-clean barrier right before function return.
5959
void AppendAutoSyncTailBarrierIfNeeded(IRRewriter &rewriter);
60+
61+
// [Fix #428] Collect back-edge waits at loop heads that need pre-loop
62+
// set_flag initialization to avoid hangs on iteration 0.
63+
void CollectBackEdgeLoopHeadWaits();
64+
65+
// [Fix #428] Emit set_flag ops before a for-loop to initialize event
66+
// registers for back-edge wait_flag ops at the loop head.
67+
void EmitPreLoopEventInit(IRRewriter &rewriter, Operation *op);
68+
69+
// [Fix #428] Emit explicit wait_flag ops before function return to drain
70+
// all pending back-edge event dependencies that pipe_barrier(PIPE_ALL)
71+
// alone does not cover.
72+
void EmitTailEventDrain(IRRewriter &rewriter, func::ReturnOp ret);
6073

6174
void CreateSetWaitOpForSingleBuffer(IRRewriter &rewriter, Operation *op,
6275
SyncOperation *sync, bool beforeInsert);
@@ -92,6 +105,10 @@ class SyncCodegen {
92105

93106
// Deferred tail-clean barrier requested by sync analysis.
94107
bool pendingAutoSyncTailBarrier_ = false;
108+
109+
// [Fix #428] Map from scf.for Operation* to the back-edge wait ops at
110+
// its loop head that need pre-loop set_flag initialization.
111+
DenseMap<const Operation *, SmallVector<SyncOperation *, 4>> preLoopInitWaits_;
95112
};
96113

97114
} // namespace pto

lib/PTO/Transforms/InsertSync/SyncCodegen.cpp

Lines changed: 164 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,25 @@ static void MergeSyncList(SyncOps &dstList, const SyncOps &srcList) {
6666
void SyncCodegen::Run() {
6767
MLIRContext *ctx = func_->getContext();
6868
IRRewriter rewriter(ctx);
69-
69+
7070
UpdateOpInsertSync(rewriter);
71-
72-
// [Optional Debug] 这里的 Debug 打印可以保留或注释掉
73-
// ...
74-
71+
72+
// [Fix #428] Collect back-edge wait ops at loop heads that need pre-loop
73+
// set_flag initialization. A back-edge wait_flag at LOOP_BEGIN fires on
74+
// iteration 0, but the matching set_flag only fires at the end of iteration
75+
// 0. Without a pre-loop set_flag, the event register is uninitialized and
76+
// the hardware hangs.
77+
CollectBackEdgeLoopHeadWaits();
78+
7579
func_->walk<WalkOrder::PreOrder>([&](Operation *op) {
7680
if (op2InsertSync.count(op)) {
81+
// [Fix #428] Before emitting the normal pipeBefore waits for a loop op,
82+
// emit pre-loop set_flag initialization for any back-edge wait_flag that
83+
// would otherwise fire against an uninitialized event register.
84+
if (isa<scf::ForOp>(op)) {
85+
EmitPreLoopEventInit(rewriter, op);
86+
}
87+
7788
// 处理 PRE Sync
7889
for (auto &syncBefore : op2InsertSync[op].pipeBefore) {
7990
SyncInsert(rewriter, op, syncBefore, true);
@@ -299,6 +310,15 @@ void SyncCodegen::AppendAutoSyncTailBarrierIfNeeded(IRRewriter &rewriter) {
299310

300311
auto pipeAllAttr = getPipeAttr(rewriter, PipelineType::PIPE_ALL);
301312
for (auto ret : returns) {
313+
// [Fix #428] Before the tail barrier, emit explicit wait_flag ops to
314+
// drain all pending back-edge event dependencies. pipe_barrier(PIPE_ALL)
315+
// waits for in-flight pipe operations but does NOT drain event flag
316+
// registers. Without explicit wait_flag calls, stale event state can
317+
// leak to the next kernel invocation.
318+
rewriter.setInsertionPoint(ret);
319+
EmitTailEventDrain(rewriter, ret);
320+
321+
// Re-set insertion point before ret (after any drain ops we just emitted)
302322
rewriter.setInsertionPoint(ret);
303323
auto barrier = rewriter.create<pto::BarrierOp>(ret.getLoc(), pipeAllAttr);
304324
barrier->setAttr("pto.auto_sync_tail_barrier", rewriter.getUnitAttr());
@@ -332,6 +352,84 @@ void SyncCodegen::CreateSetWaitOpForSingleBuffer(IRRewriter &rewriter,
332352
rewriter.create<pto::SetFlagOp>(op->getLoc(), srcPipe, dstPipe, eventId);
333353
}
334354
}
355+
356+
// ==============================================================================
357+
// [Fix #428] Tail event drain for back-edge sync events
358+
// ==============================================================================
359+
360+
void SyncCodegen::EmitTailEventDrain(IRRewriter &rewriter,
361+
func::ReturnOp ret) {
362+
// Collect all unique (srcPipe, dstPipe, eventId) triples from back-edge
363+
// syncs across all loops. These events may still be pending when the kernel
364+
// reaches the return statement and must be explicitly drained.
365+
//
366+
// We use a set of tuples to deduplicate — the same event may appear in
367+
// multiple loops or be shared via widen.
368+
struct EventKey {
369+
PipelineType src;
370+
PipelineType dst;
371+
int eventId;
372+
bool operator<(const EventKey &o) const {
373+
if (src != o.src) return static_cast<unsigned>(src) < static_cast<unsigned>(o.src);
374+
if (dst != o.dst) return static_cast<unsigned>(dst) < static_cast<unsigned>(o.dst);
375+
return eventId < o.eventId;
376+
}
377+
bool operator==(const EventKey &o) const {
378+
return src == o.src && dst == o.dst && eventId == o.eventId;
379+
}
380+
};
381+
382+
SmallVector<EventKey> drainEvents;
383+
auto addUnique = [&](PipelineType src, PipelineType dst, int eid) {
384+
EventKey key{src, dst, eid};
385+
for (auto &existing : drainEvents) {
386+
if (existing == key)
387+
return;
388+
}
389+
drainEvents.push_back(key);
390+
};
391+
392+
// Scan all LOOP_END elements for back-edge set/wait pairs with allocated
393+
// event IDs.
394+
for (auto &pair : preLoopInitWaits_) {
395+
for (auto *waitSync : pair.second) {
396+
if (waitSync->uselessSync || waitSync->eventIds.empty())
397+
continue;
398+
addUnique(waitSync->GetActualSrcPipe(), waitSync->GetActualDstPipe(),
399+
waitSync->eventIds[0]);
400+
}
401+
}
402+
403+
// Also scan the last element's pipeAfter for any set_flag ops that might
404+
// leave events pending (these are the "syncEnd" phantom pairs from
405+
// UpdateBackwardMatchSync that sink to the function tail).
406+
if (!syncIR_.empty()) {
407+
for (auto *sync : syncIR_.back()->pipeAfter) {
408+
if (sync->uselessSync || sync->eventIds.empty())
409+
continue;
410+
if (sync->isSyncWaitType()) {
411+
addUnique(sync->GetActualSrcPipe(), sync->GetActualDstPipe(),
412+
sync->eventIds[0]);
413+
}
414+
}
415+
}
416+
417+
if (drainEvents.empty())
418+
return;
419+
420+
// Sort for deterministic output.
421+
llvm::sort(drainEvents);
422+
423+
LLVM_DEBUG(llvm::dbgs() << "[Fix #428] Emitting " << drainEvents.size()
424+
<< " tail event drain wait_flag ops\n");
425+
426+
for (auto &ev : drainEvents) {
427+
auto srcPipe = getPipeAttr(rewriter, ev.src);
428+
auto dstPipe = getPipeAttr(rewriter, ev.dst);
429+
auto eventId = getEventAttr(rewriter, ev.eventId);
430+
rewriter.create<pto::WaitFlagOp>(ret.getLoc(), srcPipe, dstPipe, eventId);
431+
}
432+
}
335433

336434
void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter,
337435
Operation *op,
@@ -399,3 +497,64 @@ Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op,
399497
SyncIndex2SelectBuffer[sync->GetSyncIndex()] = selected;
400498
return selected;
401499
}
500+
501+
// ==============================================================================
502+
// [Fix #428] Pre-loop event initialization for back-edge sync
503+
// ==============================================================================
504+
505+
void SyncCodegen::CollectBackEdgeLoopHeadWaits() {
506+
for (auto &nowElement : syncIR_) {
507+
auto *loopElement = dyn_cast<LoopInstanceElement>(nowElement.get());
508+
if (!loopElement || loopElement->getLoopKind() != KindOfLoop::LOOP_END)
509+
continue;
510+
511+
// Look at the LOOP_BEGIN node's pipeBefore — these are waits that
512+
// MoveSyncState hoisted to the loop head.
513+
auto *loopBegin =
514+
dyn_cast<LoopInstanceElement>(syncIR_[loopElement->beginId].get());
515+
if (!loopBegin)
516+
continue;
517+
518+
for (auto *sync : loopBegin->pipeBefore) {
519+
if (sync->uselessSync)
520+
continue;
521+
if (!sync->isSyncWaitType())
522+
continue;
523+
if (sync->eventIds.empty())
524+
continue;
525+
// This is a wait at loop head with an allocated event ID.
526+
// It needs a pre-loop set_flag to initialize the event register.
527+
// Record {Operation* forOp -> SyncOperation* wait} for later emission.
528+
if (loopElement->elementOp) {
529+
preLoopInitWaits_[loopElement->elementOp].push_back(sync);
530+
}
531+
}
532+
}
533+
}
534+
535+
void SyncCodegen::EmitPreLoopEventInit(IRRewriter &rewriter, Operation *op) {
536+
auto it = preLoopInitWaits_.find(op);
537+
if (it == preLoopInitWaits_.end())
538+
return;
539+
540+
// For each back-edge wait at the loop head, emit a set_flag before the
541+
// for loop to initialize the event register. This ensures that on iteration
542+
// 0, the wait_flag finds a valid (already-set) event instead of hanging.
543+
rewriter.setInsertionPoint(op);
544+
for (auto *waitSync : it->second) {
545+
if (waitSync->uselessSync || waitSync->eventIds.empty())
546+
continue;
547+
548+
auto srcPipe = getPipeAttr(rewriter, waitSync->GetActualSrcPipe());
549+
auto dstPipe = getPipeAttr(rewriter, waitSync->GetActualDstPipe());
550+
auto eventId = getEventAttr(rewriter, waitSync->eventIds[0]);
551+
552+
LLVM_DEBUG(llvm::dbgs()
553+
<< "[Fix #428] Emitting pre-loop set_flag("
554+
<< static_cast<unsigned>(waitSync->GetActualSrcPipe()) << ", "
555+
<< static_cast<unsigned>(waitSync->GetActualDstPipe()) << ", "
556+
<< waitSync->eventIds[0] << ") before loop\n");
557+
558+
rewriter.create<pto::SetFlagOp>(op->getLoc(), srcPipe, dstPipe, eventId);
559+
}
560+
}

0 commit comments

Comments
 (0)