diff --git a/include/PTO/Transforms/InsertSync/SyncEventIdAllocation.h b/include/PTO/Transforms/InsertSync/SyncEventIdAllocation.h index 860c25cbe..002ff5b48 100644 --- a/include/PTO/Transforms/InsertSync/SyncEventIdAllocation.h +++ b/include/PTO/Transforms/InsertSync/SyncEventIdAllocation.h @@ -54,7 +54,7 @@ class SyncEventIdAllocation { void SetEventId(SyncOperation *sync); SmallVector GetEventPool(const SyncOperation *sync, size_t eventIdNum); - int ScopePair(const SyncOperation *s); + int ScopePair(const SyncOperation *s) const; void FindUseEventID(unsigned int begin, unsigned int end, const SyncOperation *s, SmallVector &eventId); @@ -91,6 +91,7 @@ class SyncEventIdAllocation { SyncOperation *FindWidenSync(const SyncOperation *setSync, const SyncOperation *waitSync); void ClearEventId(const SyncOperation *sync); + bool scopePairHasLoopCarriedSync(int scopePair) const; SmallVector GetAvailableEventId(SyncOperation *sync, diff --git a/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp b/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp index 4c76e773c..e877b6eab 100644 --- a/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp @@ -253,7 +253,7 @@ SmallVector SyncEventIdAllocation::GetEventPool(const SyncOperation *sync, return eventIdPool; } -int SyncEventIdAllocation::ScopePair(const SyncOperation *s) { +int SyncEventIdAllocation::ScopePair(const SyncOperation *s) const { if (s->GetType() == SyncOperation::TYPE::SYNC_BLOCK_SET || s->GetType() == SyncOperation::TYPE::SYNC_BLOCK_WAIT) { return 0; @@ -480,11 +480,37 @@ void SyncEventIdAllocation::WidenEventId(SyncOps syncVector) { bool canWiden = TryWidenByOtherSync(sync); if (!canWiden) { int scopePair = ScopePair(sync); - reallocatedPipePair.insert(scopePair); + // Loop-carried syncs need a fully initialized head/tail schedule. 
+        // Reallocating an entire scope that already contains back-edge pairs can
+        // rewrite those safe preheat/drain edges into mismatched waits.
+        if (!scopePairHasLoopCarriedSync(scopePair))
+          reallocatedPipePair.insert(scopePair);
       }
     }
   }
 }
+
+// Returns true when `scopePair` already owns a loop-carried sync: a non-null,
+// non-useless entry in any pipeBefore/pipeAfter list of syncIR_ that has a
+// for-end index and maps to `scopePair` via ScopePair().
+bool SyncEventIdAllocation::scopePairHasLoopCarriedSync(int scopePair) const {
+  // Shared scan so pipeBefore and pipeAfter are always filtered identically.
+  auto hasLoopCarried = [&](const auto &syncs) {
+    for (auto *sync : syncs) {
+      if (!sync || sync->uselessSync || !sync->GetForEndIndex().has_value())
+        continue;
+      if (ScopePair(sync) == scopePair)
+        return true;
+    }
+    return false;
+  };
+  for (const auto &element : syncIR_) {
+    if (hasLoopCarried(element->pipeBefore) ||
+        hasLoopCarried(element->pipeAfter))
+      return true;
+  }
+  return false;
+}
 
 void SyncEventIdAllocation::clearAllocatedEventId() {
   // Remove generated BackwardSync
diff --git a/test/basic/issue428_cube_sync_regression.pto b/test/basic/issue428_cube_sync_regression.pto
new file mode 100644
index 000000000..f633c1734
--- /dev/null
+++ b/test/basic/issue428_cube_sync_regression.pto
@@ -0,0 +1,125 @@
+// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s
+//
+// Issue #428 regression guard:
+// 1. The first cube loop must be primed with both PIPE_M -> PIPE_MTE1 events
+//    before entering the loop.
+// 2. The loop head must wait on the same primed events.
+// 3. The kernel tail must still drain the outstanding M/MTE1 + FIX/M events
+//    before the auto tail barrier helper runs.
+//
+// CHECK: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2);
+// CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3);
+// CHECK: for (size_t
+// CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2);
+// CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3);
+// CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2);
+// CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3);
+// CHECK-NEXT: wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2);
+// CHECK-NEXT: wait_flag(PIPE_FIX, PIPE_M, EVENT_ID6);
+// CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll);
+
+module {
+  func.func @tri_inv_block2x2_fp16(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32) {
+    pto.section.cube {  // every op exercised by this test sits in this one cube section
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c64 = arith.constant 64 : index
+      %c32 = arith.constant 32 : index
+      %0 = arith.index_cast %arg3 : i32 to index
+      %1 = arith.subi %0, %c1 : index  // %arg3 - 1: upper bound of both scf.for loops below
+      %2 = pto.get_block_idx
+      %3 = arith.index_cast %2 : i64 to index
+      %4 = pto.get_block_num
+      %5 = arith.index_cast %4 : i64 to index
+      %6 = arith.muli %5, %c64 : index  // block_num * 64: row count of the %9 / %10 views
+      %7 = arith.muli %3, %c64 : index  // block_idx * 64: first row offset used by the partitions
+      %8 = arith.addi %7, %c32 : index  // block_idx * 64 + 32: second row offset used by the partitions
+      %9 = pto.make_tensor_view %arg1, shape = [%6, %c64], strides = [%c64, %c1] : !pto.tensor_view
+      %10 = pto.make_tensor_view %arg0, shape = [%6, %c64], strides = [%c64, %c1] : !pto.tensor_view
+      %11 = pto.make_tensor_view %arg2, shape = [%c32, %c32], strides = [%c32, %c1] : !pto.tensor_view
+      %12 = pto.partition_view %11, offsets = [%c0, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16>
+      %13 = pto.partition_view %9, offsets = [%7, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16>
+      %14 = pto.partition_view %9, offsets = [%8, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16>
+      %15 = pto.partition_view %9, offsets = [%8, %c32], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16>
+      %16 = pto.partition_view %10, offsets = [%7, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32>
+      %17 = pto.partition_view %10, offsets = [%8, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32>
+      %18 = pto.partition_view %10, offsets = [%8, %c32], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32>
+      %19 = pto.alloc_tile : !pto.tile_buf
+      %20 = pto.alloc_tile : !pto.tile_buf
+      %21 = pto.alloc_tile : !pto.tile_buf
+      %22 = pto.alloc_tile : !pto.tile_buf
+      %23 = pto.alloc_tile : !pto.tile_buf
+      %24 = pto.alloc_tile : !pto.tile_buf
+      %25 = pto.alloc_tile : !pto.tile_buf
+      %26 = pto.alloc_tile : !pto.tile_buf
+      %27 = pto.alloc_tile : !pto.tile_buf
+      %28 = pto.alloc_tile : !pto.tile_buf
+      %29 = pto.alloc_tile : !pto.tile_buf
+      pto.tload ins(%12 : !pto.partition_tensor_view<32x32xf16>) outs(%24 : !pto.tile_buf)
+      pto.tmov ins(%24 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+      pto.tmov ins(%24 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+      pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%25 : !pto.tile_buf)
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%19 : !pto.tile_buf)
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%21 : !pto.tile_buf)
+      pto.tload ins(%13 : !pto.partition_tensor_view<32x32xf16>) outs(%20 : !pto.tile_buf)
+      pto.tmov ins(%20 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+      pto.tmov ins(%24 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+      pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%20 : !pto.tile_buf)
+      scf.for %arg4 = %c0 to %1 step %c1 {  // first cube loop: guard items 1 and 2 in the header comment refer to it
+        pto.tmov ins(%19 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+        pto.tmov ins(%25 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+        pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+        pto.tmov ins(%20 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+        pto.tmatmul.acc ins(%29, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+        %30 = arith.addi %arg4, %c1 : index
+        %31 = arith.cmpi slt, %30, %1 : index
+        scf.if %31 {  // taken only when another iteration follows (%arg4 + 1 < %1)
+          pto.tmov ins(%29 : !pto.tile_buf) outs(%19 : !pto.tile_buf)
+          pto.tmov ins(%20 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+          pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+          pto.tmov ins(%29 : !pto.tile_buf) outs(%20 : !pto.tile_buf)
+        }
+      }
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%19 : !pto.tile_buf)
+      pto.tstore ins(%29 : !pto.tile_buf) outs(%16 : !pto.partition_tensor_view<32x32xf32>)
+      pto.tload ins(%15 : !pto.partition_tensor_view<32x32xf16>) outs(%22 : !pto.tile_buf)
+      pto.tmov ins(%22 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+      pto.tmov ins(%24 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+      pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%22 : !pto.tile_buf)
+      scf.for %arg4 = %c0 to %1 step %c1 {  // second loop: same op sequence as the first, with %21/%22 in place of %19/%20
+        pto.tmov ins(%21 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+        pto.tmov ins(%25 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+        pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+        pto.tmov ins(%22 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+        pto.tmatmul.acc ins(%29, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+        %30 = arith.addi %arg4, %c1 : index
+        %31 = arith.cmpi slt, %30, %1 : index
+        scf.if %31 {  // taken only when another iteration follows (%arg4 + 1 < %1)
+          pto.tmov ins(%29 : !pto.tile_buf) outs(%21 : !pto.tile_buf)
+          pto.tmov ins(%22 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+          pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+          pto.tmov ins(%29 : !pto.tile_buf) outs(%22 : !pto.tile_buf)
+        }
+      }
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%21 : !pto.tile_buf)
+      pto.tstore ins(%29 : !pto.tile_buf) outs(%18 : !pto.partition_tensor_view<32x32xf32>)
+      pto.tload ins(%14 : !pto.partition_tensor_view<32x32xf16>) outs(%23 : !pto.tile_buf)
+      pto.tmov ins(%21 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+      pto.tmov ins(%23 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+      pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%26 : !pto.tile_buf)
+      pto.tmov ins(%26 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+      pto.tmov ins(%19 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+      pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+      pto.tmov ins(%29 : !pto.tile_buf) outs(%26 : !pto.tile_buf)
+      pto.tmov ins(%24 : !pto.tile_buf) outs(%27 : !pto.tile_buf)
+      pto.tmov ins(%26 : !pto.tile_buf) outs(%28 : !pto.tile_buf)
+      pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf)
+      pto.tstore ins(%29 : !pto.tile_buf) outs(%17 : !pto.partition_tensor_view<32x32xf32>)  // last op in the cube section
+    }
+    return
+  }
+}