From d8ccf87087b77e105fd2c88bd17a780843e563d3 Mon Sep 17 00:00:00 2001 From: chenshengxin Date: Fri, 27 Mar 2026 14:05:21 +0800 Subject: [PATCH] Add: paged attention unroll scene test with 4D input shapes - New paged_attention_unroll_4dims test under tensormap_and_ringbuffer - Query and output tensors use 4D format (batch, seq_len, num_heads, head_dim) - 6 kernels: QK/PV matmul (AIC), softmax_prepare/online_update (AIV), hub stubs - Orchestration with N_UNROLL=64, 4 tasks per group, online softmax accumulation - Golden wraps shared paged_attention_golden with 4D reshape adapter - Three test cases: varying batch/heads/head_dim at production scale (bfloat16) --- .../paged_attention_unroll_4dims/golden.py | 77 +++++ .../kernels/aic/aic_hub.cpp | 18 + .../kernels/aic/aic_pv_matmul.cpp | 155 +++++++++ .../kernels/aic/aic_qk_matmul.cpp | 120 +++++++ .../kernels/aiv/aiv_hub.cpp | 18 + .../kernels/aiv/aiv_online_update.cpp | 251 ++++++++++++++ .../kernels/aiv/aiv_softmax_prepare.cpp | 263 +++++++++++++++ .../kernels/kernel_config.py | 46 +++ .../orchestration/paged_attention_orch.cpp | 318 ++++++++++++++++++ 9 files changed, 1266 insertions(+) create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/golden.py create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_hub.cpp create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_pv_matmul.cpp create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_qk_matmul.cpp create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_hub.cpp create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_online_update.cpp create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_softmax_prepare.cpp create mode 100644 
tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/kernel_config.py create mode 100644 tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/golden.py b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/golden.py new file mode 100644 index 000000000..66cd11460 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/golden.py @@ -0,0 +1,77 @@ +"""Paged Attention Unroll Golden - tensormap_and_ringbuffer test (production scale, bfloat16). + +Input shapes use 4D format: (batch, seq_len, num_heads, head_dim) for query and out. +""" + +from paged_attention_golden import ( + generate_inputs as _generate_inputs, + compute_golden as _compute_golden, + run_golden_test, +) + +__outputs__ = ["out"] + +RTOL = 1e-3 +ATOL = 1e-3 + +ALL_CASES = { + "Case1": { + "batch": 256, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 128, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + "Case2": { + "batch": 64, + "num_heads": 64, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 64, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, + "Case3": { + "batch": 64, + "num_heads": 64, + "kv_head_num": 1, + "head_dim": 256, + "block_size": 64, + "context_len": 8192, + "max_model_len": 32768, + "dtype": "bfloat16", + }, +} + +DEFAULT_CASE = "Case1" + + +def generate_inputs(params: dict) -> list: + result = _generate_inputs(params) + batch = params["batch"] + num_heads = params["num_heads"] + head_dim = params["head_dim"] + reshaped = [] + for name, val in result: + if name in ("query", "out"): + val = val.reshape(batch, 1, num_heads, head_dim) + reshaped.append((name, val)) + return reshaped + + +def compute_golden(tensors: dict, params: dict) -> None: + batch = params["batch"] + num_heads = 
params["num_heads"] + head_dim = params["head_dim"] + out_4d = tensors["out"] + tensors["out"] = out_4d.reshape(batch, num_heads, head_dim) + _compute_golden(tensors, params) + tensors["out"] = out_4d + + +if __name__ == "__main__": + run_golden_test(ALL_CASES, DEFAULT_CASE, generate_inputs, label="Paged Attention Unroll") diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_hub.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_hub.cpp new file mode 100644 index 000000000..0974de371 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_hub.cpp @@ -0,0 +1,18 @@ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +constexpr int M = 16; +constexpr int K = 16; +constexpr int N = 16; + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_pv_matmul.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_pv_matmul.cpp new file mode 100644 index 000000000..cc25dcfaa --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_pv_matmul.cpp @@ -0,0 +1,155 @@ +// SplitK PV Matmul Kernel: Accumulated P @ V across n_blocks +// +// Processes n_blocks blocks using SplitK accumulation pattern: +// Block 0: TMATMUL(C, A, B) — initialize accumulator +// Block i: TMATMUL_ACC(C, C, A, B) — accumulate into same C +// +// Per-block pij addresses: contiguous slices of pij_buf (n_blocks * M * K) +// Per-block vj addresses: value_cache base + block_indices lookup +// Single output: oi_new (M, N) fp32 = sum of P_i @ V_i across all blocks +// +// Optimizations: +// - Double-buffered L1 tiles (ping/pong for A and B via MTE2) +// - Double-buffered L0 tiles (ping/pong for L0A and L0B 
via MTE1) +// - TLOAD(next) overlaps with TMATMUL(current) via MTE2/M-pipe parallelism +// - Canonical 3-stage pipeline: TLOAD(MTE2) → TMOV(MTE1) → TMATMUL(M) +// - Reverse-dependency events ensure buffer safety across iterations +// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128) -> (16, 128) +// Case2: (64, 64) @ ( 64, 128) -> (64, 128) +// +// pij is bfloat16 (from softmax_prepare TCVT). +// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout. + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void pv_matmul_n_impl( + __gm__ bfloat16_t* pij_base, + __gm__ bfloat16_t* val_base, + __gm__ float* oi_base, + uint64_t n_blocks, + __gm__ int32_t* block_table) { + + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride>; + using GlobalOut = GlobalTensor, Stride>; + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + // L1 memory layout: double-buffered A and B tiles (tightly packed) + constexpr int kATileBytes = M * K * static_cast(sizeof(bfloat16_t)); + constexpr int kBTileBytes = K * N * static_cast(sizeof(bfloat16_t)); + + TileMatA aMatTile[2]; + TileMatB bMatTile[2]; + TASSIGN(aMatTile[0], 0x0); + TASSIGN(aMatTile[1], kATileBytes); + TASSIGN(bMatTile[0], 2 * kATileBytes); + TASSIGN(bMatTile[1], 2 * kATileBytes + kBTileBytes); + + // L0 memory layout: double-buffered L0A and L0B, single accumulator L0C + LeftTile aTile[2]; + RightTile bTile[2]; + AccTile cTile; + TASSIGN(aTile[0], 0x0); + TASSIGN(aTile[1], kATileBytes); + TASSIGN(bTile[0], 0x0); + TASSIGN(bTile[1], kBTileBytes); + TASSIGN(cTile, 0x0); + + GlobalOut oiGlobal(oi_base); + + // Seed reverse-dependency flags: all ping/pong buffers initially free + // PIPE_MTE1 → PIPE_MTE2: L1 
buffer [0/1] safe for TLOAD to overwrite + // PIPE_M → PIPE_MTE1: L0 buffer [0/1] safe for TMOV to overwrite + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + + for (uint64_t i = 0; i < n_blocks; i++) { + int cur = static_cast(i % 2); + GlobalA pijGlobal(pij_base + i * M * K); + GlobalB vjGlobal(val_base + block_table[i] * K * N); + + // Stage 1: TLOAD (MTE2: GM → L1[cur]) + // Wait for MTE1 to release L1[cur] (reverse dep from previous iteration) + wait_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur); + TLOAD(aMatTile[cur], pijGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // forward: A in L1 ready + TLOAD(bMatTile[cur], vjGlobal); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // forward: B in L1 ready + + // Stage 2: TMOV (MTE1: L1[cur] → L0[cur]) + // Wait for M-pipe to release L0[cur] (reverse dep from previous iteration) + wait_flag(PIPE_M, PIPE_MTE1, (event_t)cur); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // forward: wait A loaded + TMOV(aTile[cur], aMatTile[cur]); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // forward: wait B loaded + TMOV(bTile[cur], bMatTile[cur]); + set_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur); // reverse: release L1[cur] + + // Stage 3: TMATMUL (M-pipe: L0A[cur] × L0B[cur] → L0C) + set_flag(PIPE_MTE1, PIPE_M, (event_t)cur); // forward: L0[cur] ready + wait_flag(PIPE_MTE1, PIPE_M, (event_t)cur); + if (i == 0) { + TMATMUL(cTile, aTile[cur], bTile[cur]); + } else { + TMATMUL_ACC(cTile, cTile, aTile[cur], bTile[cur]); + } + set_flag(PIPE_M, PIPE_MTE1, (event_t)cur); // reverse: release L0[cur] + } + + // Drain outstanding reverse-dependency flags + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + 
TSTORE(oiGlobal, cTile); + + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* pij_buf = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* value_cache = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* oi_new = reinterpret_cast<__gm__ TensorData*>(args[2]); + uint64_t n_blocks = static_cast(args[3]); + __gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]); + + __gm__ bfloat16_t* pij_base = reinterpret_cast<__gm__ bfloat16_t*>(pij_buf->buffer.addr) + pij_buf->start_offset; + __gm__ bfloat16_t* val_base = reinterpret_cast<__gm__ bfloat16_t*>(value_cache->buffer.addr); + __gm__ float* oi_base = reinterpret_cast<__gm__ float*>(oi_new->buffer.addr) + oi_new->start_offset; + + uint64_t q_tile_size = static_cast(pij_buf->shapes[0]); + + if (q_tile_size == 16) { + pv_matmul_n_impl<16, 128, 128>(pij_base, val_base, oi_base, n_blocks, block_table); + } else { + pv_matmul_n_impl<64, 64, 128>(pij_base, val_base, oi_base, n_blocks, block_table); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_qk_matmul.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_qk_matmul.cpp new file mode 100644 index 000000000..6d790b8fe --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aic/aic_qk_matmul.cpp @@ -0,0 +1,120 @@ +// Multi-block QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) for each block +// +// Processes n_blocks blocks in a single kernel invocation. +// Per-block kj addresses computed from key_cache base + block_indices lookup. +// qi is shared across all blocks (same query head against different key blocks). +// +// Output layout: n_blocks contiguous (M, N) tiles stacked vertically. +// Block i occupies sij[i*M : (i+1)*M, 0:N]. 
+// +// Optimizations: +// - qi TLOAD hoisted before the loop (constant across all iterations) +// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128).T -> (16, 128) +// Case2: (64, 128) @ (128, 64).T -> (64, 64) +// +// Template: M=q_tile, K=head_dim, N=block_size + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void qk_matmul_n_impl( + __gm__ bfloat16_t* qi_base, + __gm__ bfloat16_t* key_base, + __gm__ float* sij_base, + uint64_t n_blocks, + __gm__ int32_t* block_table) { + + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride, Layout::DN>; + using GlobalOut = GlobalTensor, Stride>; + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + // Hoist qi TLOAD before the loop (qi is constant across all blocks) + GlobalA qiGlobal(qi_base); + TLOAD(aMatTile, qiGlobal); + + for (uint64_t i = 0; i < n_blocks; i++) { + GlobalB kjGlobal(key_base + block_table[i] * N * K); + GlobalOut sijGlobal(sij_base + i * M * N); + + // Load only B each iteration (qi already in L1 from hoist) + TLOAD(bMatTile, kjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + // TMOV qi from L1→L0A (re-copy since TMATMUL consumed L0A) and kj from L1→L0B + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + 
TSTORE(sijGlobal, cTile); + + if (i + 1 < n_blocks) { + pipe_barrier(PIPE_ALL); + } + } + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* qi = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* key_cache = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* sij_buf = reinterpret_cast<__gm__ TensorData*>(args[2]); + uint64_t n_blocks = static_cast(args[3]); + __gm__ int32_t* block_table = reinterpret_cast<__gm__ int32_t*>(args[4]); + + __gm__ bfloat16_t* qi_base = reinterpret_cast<__gm__ bfloat16_t*>(qi->buffer.addr) + qi->start_offset; + __gm__ bfloat16_t* key_base = reinterpret_cast<__gm__ bfloat16_t*>(key_cache->buffer.addr); + __gm__ float* sij_base = reinterpret_cast<__gm__ float*>(sij_buf->buffer.addr) + sij_buf->start_offset; + + // qi is a 4D view: (1, 1, q_tile, head_dim) + uint64_t q_tile_size = static_cast(qi->shapes[2]); + + if (q_tile_size == 16) { + qk_matmul_n_impl<16, 128, 128>(qi_base, key_base, sij_base, n_blocks, block_table); + } else { + qk_matmul_n_impl<64, 128, 64>(qi_base, key_base, sij_base, n_blocks, block_table); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_hub.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_hub.cpp new file mode 100644 index 000000000..0974de371 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_hub.cpp @@ -0,0 +1,18 @@ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +constexpr int M = 16; +constexpr int K = 16; +constexpr int N = 16; + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_online_update.cpp 
b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_online_update.cpp new file mode 100644 index 000000000..5d7ef9ba9 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_online_update.cpp @@ -0,0 +1,251 @@ +// Online Softmax Update + Normalize Kernel (AIV) +// +// Operates on full tiles where M=q_tile_size, N=head_dim (128): +// Case1: oi/oi_new are (16, 128), mij/lij/mi/li are 16-element vectors +// Case2: oi/oi_new are (64, 128), mij/lij/mi/li are 64-element vectors +// +// Scalar layout strategy using TRESHAPE (zero-copy UB reshape): +// Scalars loaded as DN ColMajor (M, 1) for TROWEXPANDMUL/TROWEXPANDDIV. +// For element-wise ops (TMAX, TSUB, TEXP, etc.), TRESHAPE to RowMajor (1, M). +// After arithmetic, TRESHAPE back to ColMajor (M, 1) for row-broadcast ops. +// This eliminates the GM round-trip (TSTORE ND → TLOAD DN) used in the original. + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void online_update_impl(__gm__ TensorData* mij, + __gm__ TensorData* lij, + __gm__ TensorData* oi_new, + __gm__ TensorData* mi, + __gm__ TensorData* li, + __gm__ TensorData* oi, + uint64_t is_first, + uint64_t is_last, + __gm__ TensorData* dst) { + __gm__ float* mij_ptr = reinterpret_cast<__gm__ float*>(mij->buffer.addr); + __gm__ float* lij_ptr = reinterpret_cast<__gm__ float*>(lij->buffer.addr); + __gm__ float* oi_new_ptr = reinterpret_cast<__gm__ float*>(oi_new->buffer.addr); + __gm__ float* mi_ptr = reinterpret_cast<__gm__ float*>(mi->buffer.addr); + __gm__ float* li_ptr = reinterpret_cast<__gm__ float*>(li->buffer.addr); + __gm__ float* oi_ptr = reinterpret_cast<__gm__ float*>(oi->buffer.addr); + __gm__ float* dst_ptr = reinterpret_cast<__gm__ float*>(dst->buffer.addr); + + // Aligned rows for ColMajor DN tiles (32-byte alignment) + 
constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + // --- GlobalTensor types --- + + // Data (M, N) RowMajor + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + + // Scalar DN: M contiguous floats as (kAlignedRows, 1) ColMajor for TROWEXPAND ops and loading + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + // Scalar ND: for storing mi_new and li_new back to GM + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + using GlobalScalarND = + GlobalTensor, Stride<1, 1, 1, kScalarCols, 1>>; + + // --- GlobalTensor instances --- + + GlobalDataMxN oiNewGlobal(oi_new_ptr + oi_new->start_offset); + GlobalDataMxN oiGlobal(oi_ptr + oi->start_offset); + GlobalDataMxN dstGlobal(dst_ptr + dst->start_offset); + + // DN globals for loading scalars as ColMajor + GlobalScalarDN mijGlobalDN(mij_ptr + mij->start_offset); + GlobalScalarDN lijGlobalDN(lij_ptr + lij->start_offset); + GlobalScalarDN miGlobalDN(mi_ptr + mi->start_offset); + GlobalScalarDN liGlobalDN(li_ptr + li->start_offset); + + // ND globals for storing scalar results + GlobalScalarND miGlobalND(mi_ptr + mi->start_offset); + GlobalScalarND liGlobalND(li_ptr + li->start_offset); + + // --- Tile types --- + + using TileDataMxN = Tile; + using TileScalarDN = Tile; + + // RowMajor (1, M) tiles for element-wise arithmetic via TRESHAPE + using TileScalarRow = Tile; + + // ND tile for storing back to GM + using TileScalarND = + Tile; + + // --- UB memory layout --- + + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + // Data tiles + TileDataMxN oiNewTile; + TileDataMxN oiTile; + + // Scalar DN tiles loaded from GM (ColMajor) + TileScalarDN mijDN, lijDN, miDN, liDN; + + // Temporary DN tiles for results + TileScalarDN miNewDN, alphaDN, betaDN, liNewDN, tmpDN; + + TASSIGN(oiNewTile, 0); + TASSIGN(oiTile, kDataBytes); + TASSIGN(mijDN, 2 * 
kDataBytes); + TASSIGN(lijDN, 2 * kDataBytes + kScalarDNBytes); + TASSIGN(miDN, 2 * kDataBytes + 2 * kScalarDNBytes); + TASSIGN(liDN, 2 * kDataBytes + 3 * kScalarDNBytes); + TASSIGN(miNewDN, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 5 * kScalarDNBytes); + TASSIGN(betaDN, 2 * kDataBytes + 6 * kScalarDNBytes); + TASSIGN(liNewDN, 2 * kDataBytes + 7 * kScalarDNBytes); + TASSIGN(tmpDN, 2 * kDataBytes + 8 * kScalarDNBytes); + + if (is_first) { + // --- First block: copy inputs to accumulators --- + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(mijDN, mijGlobalDN); + TLOAD(lijDN, lijGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Store mi = mij, li = lij, oi = oi_new + // Alias ND tiles to same UB as DN tiles for ND-format store + TileScalarND mijND, lijND; + TASSIGN(mijND, 2 * kDataBytes); // alias same UB as mijDN + TASSIGN(lijND, 2 * kDataBytes + kScalarDNBytes); // alias same UB as lijDN + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, mijND); // mi = mij + TSTORE(liGlobalND, lijND); // li = lij + TSTORE(oiGlobal, oiNewTile); // oi = oi_new + + if (is_last) { + // Single block: normalize dst = oi_new / lij + // lijDN already in ColMajor DN format, use directly for TROWEXPANDDIV + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TROWEXPANDDIV(oiNewTile, oiNewTile, lijDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiNewTile); + } + } else { + // --- Subsequent blocks: accumulate --- + + // Load all inputs as DN (ColMajor) + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(oiTile, oiGlobal); + TLOAD(mijDN, mijGlobalDN); + TLOAD(lijDN, lijGlobalDN); + TLOAD(miDN, miGlobalDN); + TLOAD(liDN, liGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TRESHAPE: ColMajor(M,1) → RowMajor(1,M) for element-wise 
arithmetic + TileScalarRow miRow, mijRow, liRow, lijRow; + TRESHAPE(miRow, miDN); + TRESHAPE(mijRow, mijDN); + TRESHAPE(liRow, liDN); + TRESHAPE(lijRow, lijDN); + + // Scalar arithmetic in RowMajor (1, M) layout + TileScalarRow miNewRow, alphaRow, betaRow, liNewRow, tmpRow; + TASSIGN(miNewRow, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(alphaRow, 2 * kDataBytes + 5 * kScalarDNBytes); + TASSIGN(betaRow, 2 * kDataBytes + 6 * kScalarDNBytes); + TASSIGN(liNewRow, 2 * kDataBytes + 7 * kScalarDNBytes); + TASSIGN(tmpRow, 2 * kDataBytes + 8 * kScalarDNBytes); + + TMAX(miNewRow, miRow, mijRow); // mi_new = max(mi, mij) + pipe_barrier(PIPE_V); + TSUB(alphaRow, miRow, miNewRow); // alpha_exp = mi - mi_new + pipe_barrier(PIPE_V); + TEXP(alphaRow, alphaRow); // alpha = exp(mi - mi_new) + pipe_barrier(PIPE_V); + TSUB(betaRow, mijRow, miNewRow); // beta_exp = mij - mi_new + pipe_barrier(PIPE_V); + TEXP(betaRow, betaRow); // beta = exp(mij - mi_new) + pipe_barrier(PIPE_V); + TMUL(tmpRow, alphaRow, liRow); // alpha * li + pipe_barrier(PIPE_V); + TMUL(liNewRow, betaRow, lijRow); // beta * lij + pipe_barrier(PIPE_V); + TADD(liNewRow, tmpRow, liNewRow); // li_new = alpha*li + beta*lij + + // TRESHAPE back: RowMajor(1,M) → ColMajor(M,1) for TROWEXPANDMUL + pipe_barrier(PIPE_V); + TRESHAPE(alphaDN, alphaRow); + TRESHAPE(betaDN, betaRow); + + // Scale data tiles using row-broadcast multiply + TROWEXPANDMUL(oiTile, oiTile, alphaDN); // oi *= alpha + pipe_barrier(PIPE_V); + TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); // oi_new *= beta + pipe_barrier(PIPE_V); + TADD(oiTile, oiTile, oiNewTile); // oi = alpha*oi + beta*oi_new + + // Store mi_new and li_new to GM (ND format) + // Alias ND tiles to the same UB locations as miNewRow and liNewRow + TileScalarND miNewND, liNewND; + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarDNBytes); + TASSIGN(liNewND, 2 * kDataBytes + 7 * kScalarDNBytes); + + if (is_last) { + // Normalize and output: dst = oi / li_new + TRESHAPE(liNewDN, liNewRow); + 
pipe_barrier(PIPE_V); + TROWEXPANDDIV(oiTile, oiTile, liNewDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liNewND); // persist li_new + TSTORE(dstGlobal, oiTile); + } else { + // Store updated accumulators + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); // persist mi_new + TSTORE(liGlobalND, liNewND); // persist li_new + TSTORE(oiGlobal, oiTile); + } + } + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* mij = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* lij = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* oi_new = reinterpret_cast<__gm__ TensorData*>(args[2]); + __gm__ TensorData* mi = reinterpret_cast<__gm__ TensorData*>(args[3]); + __gm__ TensorData* li = reinterpret_cast<__gm__ TensorData*>(args[4]); + __gm__ TensorData* oi = reinterpret_cast<__gm__ TensorData*>(args[5]); + __gm__ TensorData* dst = reinterpret_cast<__gm__ TensorData*>(args[6]); + uint64_t is_first = static_cast(args[7]); + uint64_t is_last = static_cast(args[8]); + uint64_t q_tile_size = static_cast(mij->shapes[0]); + // args[10] = head_dim (128) + + if (q_tile_size == 16) { + online_update_impl<16, 128>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + } else { + online_update_impl<64, 128>(mij, lij, oi_new, mi, li, oi, is_first, is_last, dst); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_softmax_prepare.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_softmax_prepare.cpp new file mode 100644 index 000000000..e78c39546 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/aiv/aiv_softmax_prepare.cpp @@ 
-0,0 +1,263 @@ +// Two-Pass Softmax Kernel (AIV) for n_blocks tiles +// +// Input: sij_buf (n_blocks * M, N) fp32 — QK results stacked vertically +// Output: pij_buf (n_blocks * M, N) bf16 — attention weights per block +// mij (M,) fp32 — global row max across all blocks +// lij (M,) fp32 — total row sum across all blocks +// +// Pass 1: Iterate over n_blocks tiles, apply scale, mask last block, +// find global m = max over all blocks of rowmax(S_i * scale) +// Uses TRESHAPE for DN↔Row conversion to keep globalMax in UB +// (eliminates 63 × 4 GM round-trip operations). +// Pass 2: Iterate again, compute P_i = exp(S_i * scale - m) -> bf16, +// accumulate l = sum over all blocks of rowsum(P_i) +// Uses double-buffered sij tiles to overlap TLOAD with computation. +// +// Two-pass ensures all P_i tiles share the same scale (global max), +// enabling direct TMATMUL_ACC accumulation in the PV kernel. +// +// Supports two tile configurations via runtime dispatch: +// Case1: M=16, N=128 (q_tile=16, block_size=128) +// Case2: M=64, N=64 (q_tile=64, block_size=64) + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void softmax_prepare_n_impl( + __gm__ float* sij_base, + float scale_value, + __gm__ bfloat16_t* pij_base, + __gm__ float* mij_addr, + __gm__ float* lij_addr, + uint64_t n_blocks, + uint64_t valid_len_last) { + + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + + // --- GlobalTensor types --- + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalDataMxN_bf16 = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + using GlobalScalarND = + GlobalTensor, Stride<1, 1, 1, kScalarCols, 1>>; + + // --- Tile 
types --- + using TileSijDyn = Tile; + using TileSijPad = Tile; + using TileVecMxN = Tile; + using TileVecMxN_bf16 = Tile; + using TileScalarDN = Tile; + using TileScalarND = + Tile; + // RowMajor (1, M) tile for element-wise arithmetic via TRESHAPE + using TileScalarRow = Tile; + + // --- UB memory layout (double-buffered sij) --- + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + // Double-buffered sij tiles + TileVecMxN sijTile_A; + TileSijPad sijPadTile_A; + TileVecMxN sijTile_B; + TileSijPad sijPadTile_B; + TileVecMxN pijTile; + TileVecMxN tmpTile; + TileVecMxN sumAccTile; + TileScalarDN localMaxDN; + TileScalarDN globalMaxDN; + TileScalarDN sumDN; + TileVecMxN_bf16 pijBf16Tile; + + // TRESHAPE aliases (same UB address as their DN counterparts) + TileScalarRow localMaxRow; + TileScalarRow globalMaxRow; + + // ND alias for storing globalMax to GM + TileScalarND globalMaxND; + + TASSIGN(sijTile_A, 0x0); + TASSIGN(sijPadTile_A, 0x0); + TASSIGN(sijTile_B, kDataBytes); + TASSIGN(sijPadTile_B, kDataBytes); + TASSIGN(pijTile, 2 * kDataBytes); + TASSIGN(tmpTile, 3 * kDataBytes); + TASSIGN(sumAccTile, 4 * kDataBytes); + int scalarBase = 5 * kDataBytes; + TASSIGN(localMaxDN, scalarBase); + TASSIGN(localMaxRow, scalarBase); // alias: same UB as localMaxDN + TASSIGN(globalMaxDN, scalarBase + kScalarDNBytes); + TASSIGN(globalMaxRow, scalarBase + kScalarDNBytes); // alias: same UB as globalMaxDN + TASSIGN(globalMaxND, scalarBase + kScalarDNBytes); // alias: same UB as globalMaxDN + TASSIGN(sumDN, scalarBase + 2 * kScalarDNBytes); + TASSIGN(pijBf16Tile, scalarBase + 3 * kScalarDNBytes); + + // GM aliases (mij/lij output buffers) + GlobalScalarND mijGlobalND(mij_addr); + GlobalScalarDN lijGlobalDN(lij_addr); + + // ======== Pass 1: Find global row max via TRESHAPE (no GM round-trip) ======== + for (uint64_t i = 0; i < n_blocks; i++) { + GlobalDataMxN sijGlobal(sij_base + i * M * N); + TLOAD(sijTile_A, 
sijGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + if (i == n_blocks - 1 && valid_len_last < static_cast(N)) { + TileSijDyn sijDynTile(static_cast(valid_len_last)); + TASSIGN(sijDynTile, 0x0); + TFILLPAD_INPLACE(sijPadTile_A, sijDynTile); + pipe_barrier(PIPE_V); + } + + TMULS(sijTile_A, sijTile_A, scale_value); + pipe_barrier(PIPE_V); + TROWMAX(localMaxDN, sijTile_A, tmpTile); + pipe_barrier(PIPE_V); + + // TRESHAPE: ColMajor(M,1) → RowMajor(1,M) for element-wise TMAX + TRESHAPE(localMaxRow, localMaxDN); + if (i == 0) { + TMAX(globalMaxRow, localMaxRow, localMaxRow); + } else { + TMAX(globalMaxRow, globalMaxRow, localMaxRow); + } + pipe_barrier(PIPE_V); + } + + // TRESHAPE back: RowMajor(1,M) → ColMajor(M,1) for Pass 2's TROWEXPANDSUB + TRESHAPE(globalMaxDN, globalMaxRow); + + // Store final global max to mij for online_update to consume + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mijGlobalND, globalMaxND); + + // ======== Pass 2: Compute softmax with double-buffered sij ======== + // globalMaxDN is already in UB from TRESHAPE — no reload needed. + // Sync MTE3→MTE2 to ensure the mij TSTORE completed before first sij TLOAD. 
+ set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + // Pre-load first sij tile into buffer A + GlobalDataMxN sijGlobal_0(sij_base); + TLOAD(sijTile_A, sijGlobal_0); + + for (uint64_t i = 0; i < n_blocks; i++) { + GlobalDataMxN_bf16 pijGlobal(pij_base + i * M * N); + + // Wait for current tile's TLOAD to complete + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TFILLPAD on current buffer if last block with partial valid length + if (i == n_blocks - 1 && valid_len_last < static_cast(N)) { + TileSijDyn curSijDyn(static_cast(valid_len_last)); + if (i % 2 == 0) { + TASSIGN(curSijDyn, 0x0); + TFILLPAD_INPLACE(sijPadTile_A, curSijDyn); + } else { + TASSIGN(curSijDyn, static_cast(kDataBytes)); + TFILLPAD_INPLACE(sijPadTile_B, curSijDyn); + } + pipe_barrier(PIPE_V); + } + + // Compute on current buffer (select A or B based on iteration parity) + if (i % 2 == 0) { + TMULS(sijTile_A, sijTile_A, scale_value); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(pijTile, sijTile_A, globalMaxDN); + } else { + TMULS(sijTile_B, sijTile_B, scale_value); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(pijTile, sijTile_B, globalMaxDN); + } + pipe_barrier(PIPE_V); + TEXP(pijTile, pijTile); + pipe_barrier(PIPE_V); + TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND); + pipe_barrier(PIPE_V); + TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND); + + pipe_barrier(PIPE_V); + if (i == 0) { + TMULS(sumAccTile, pijTile, 1.0f); + } else { + TADD(sumAccTile, sumAccTile, pijTile); + } + + // Store pij (must complete before next iteration's TCVT overwrites pijBf16Tile) + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(pijGlobal, pijBf16Tile); + + // Prefetch next sij into alternate buffer (after TSTORE to avoid UB race) + if (i + 1 < n_blocks) { + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + GlobalDataMxN 
sijGlobal_next(sij_base + (i + 1) * M * N); + if (i % 2 == 0) { + TLOAD(sijTile_B, sijGlobal_next); + } else { + TLOAD(sijTile_A, sijGlobal_next); + } + } + } + + // Compute final row sum from accumulated pij values + pipe_barrier(PIPE_V); + TROWSUM(sumDN, sumAccTile, tmpTile); + + // Store lij (total sum). mij already stored after Pass 1. + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(lijGlobalDN, sumDN); + + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* sij_buf = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* pij_buf = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* mij = reinterpret_cast<__gm__ TensorData*>(args[2]); + __gm__ TensorData* lij = reinterpret_cast<__gm__ TensorData*>(args[3]); + union { + uint64_t u; + float f; + } scale_conv; + scale_conv.u = static_cast(args[4]); + float scale_value = scale_conv.f; + uint64_t n_blocks = static_cast(args[5]); + uint64_t valid_len_last = static_cast(args[6]); + + __gm__ float* sij_base = reinterpret_cast<__gm__ float*>(sij_buf->buffer.addr) + sij_buf->start_offset; + __gm__ bfloat16_t* pij_base = reinterpret_cast<__gm__ bfloat16_t*>(pij_buf->buffer.addr) + pij_buf->start_offset; + __gm__ float* mij_addr = reinterpret_cast<__gm__ float*>(mij->buffer.addr) + mij->start_offset; + __gm__ float* lij_addr = reinterpret_cast<__gm__ float*>(lij->buffer.addr) + lij->start_offset; + + uint64_t q_tile_size = static_cast(sij_buf->shapes[0]); + + if (q_tile_size == 16) { + softmax_prepare_n_impl<16, 128>(sij_base, scale_value, pij_base, mij_addr, lij_addr, n_blocks, valid_len_last); + } else { + softmax_prepare_n_impl<64, 64>(sij_base, scale_value, pij_base, mij_addr, lij_addr, n_blocks, valid_len_last); + } +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/kernel_config.py 
b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/kernel_config.py new file mode 100644 index 000000000..3ce1cd680 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/kernel_config.py @@ -0,0 +1,46 @@ +""" +Paged Attention Kernel and Orchestration Configuration + +Defines the kernels and orchestration function for paged attention +with AIC/AIV subgraph splitting: + +AIC Kernels (Matrix Multiplication): + - aic_qk_matmul: Q @ K^T computation + - aic_pv_matmul: P @ V computation + +AIV Kernels (Vector Operations): + - aiv_softmax_prepare: scale, rowmax, exp, rowsum + - aiv_online_update: online softmax accumulation + fused normalization + +Note: aiv_normalize has been merged into aiv_online_update for efficiency. +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +# Orchestration config +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "paged_attention_orch.cpp"), + "function_name": "build_paged_attention_graph", +} + +# Kernel configs (aiv_normalize removed - merged into aiv_online_update) +KERNELS = [ + # AIC kernels (matrix multiplication using Cube unit) + {"func_id": 0, "name": "QK", "source": str(_KERNELS_ROOT / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 2, "name": "PV", "source": str(_KERNELS_ROOT / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 4, "name": "AIC_HUB", "source": str(_KERNELS_ROOT / "aic" / "aic_hub.cpp"), "core_type": "aic"}, + # AIV kernels (vector operations) + {"func_id": 1, "name": "SF", "source": str(_KERNELS_ROOT / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "UP", "source": str(_KERNELS_ROOT / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 5, "name": "AIV_HUB", "source": str(_KERNELS_ROOT / "aiv" / "aiv_hub.cpp"), "core_type": "aiv"}, +] + +# Runtime configuration +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + 
"aicpu_thread_num": 4, + "orch_thread_num": 1, + "block_dim": 24, +} diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp new file mode 100644 index 000000000..03b5ba20e --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/paged_attention_unroll_4dims/kernels/orchestration/paged_attention_orch.cpp @@ -0,0 +1,318 @@ +/** + * Paged Attention Orchestration Function V2 - N_UNROLL=64, 4 Tasks Per Group + * + * Batches up to N_UNROLL blocks per group. Each group submits exactly 4 tasks: + * 1. QK matmul: qi @ K^T for n_blocks → sij_buf (q_tile, n_blocks * block_size) + * 2. Softmax: two-pass over sij_buf → pij_buf, mi, li + * 3. PV matmul: SplitK accumulated P @ V → oi_new (q_tile, head_dim) + * 4. Update: online softmax accumulation with group-level mi, li, oi_new + * + * Memory Layout: + * Query: (batch, seq_len, num_heads, head_dim) bf16 + * Key: (total_blocks, block_size, kv_head_num, head_dim) bf16 + * Value: (total_blocks, block_size, kv_head_num, head_dim) bf16 + */ + +#include +#include + +#include "pto_orchestration_api.h" + +#define N_UNROLL 64 + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 +#define FUNC_AIC_HUB 4 +#define FUNC_AIV_HUB 5 + +constexpr uint64_t PLATFORM_PROF_SYS_CNT_FREQ = 50000000; // 50 MHz + +inline double cycles_to_us(uint64_t cycles) { + return (static_cast(cycles) / PLATFORM_PROF_SYS_CNT_FREQ) * 1000000.0; +}; + +inline uint64_t get_sys_cnt_aicpu() { + uint64_t ticks; + asm volatile("mrs %0, cntvct_el0" : "=r"(ticks)); + return ticks; +} + +#ifdef ENABLE_PROFILING +#define CYCLE_COUNT_START() uint64_t _t0 = get_sys_cnt_aicpu(), _t1 +#define CYCLE_COUNT_LAP(acc) \ + do { \ + _t1 = get_sys_cnt_aicpu(); \ + acc += (_t1 - _t0); \ + _t0 = _t1; \ + } while (0) +#else +#define
CYCLE_COUNT_START() (void)0 +#define CYCLE_COUNT_LAP(acc) (void)0 +#endif + +extern "C" { +/** + * Orchestration config — the executor reads these values to set up + * shared memory and runtime before calling aicpu_orchestration_entry. + */ +__attribute__((visibility("default"))) PTO2OrchestrationConfig aicpu_orchestration_config( + TaskArg* orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 7, + }; +} + +__attribute__((visibility("default"))) void aicpu_orchestration_entry(TaskArg* orch_args, int orch_thread_num, int orch_thread_index) { + (void)orch_thread_num; + (void)orch_thread_index; +#ifdef ENABLE_PROFILING + uint64_t prof_param_extract = 0; + uint64_t prof_ext_tensor = 0; + uint64_t prof_make_tensor = 0; + uint64_t prof_tensor_view = 0; + uint64_t prof_param_setup = 0; + uint64_t prof_submit_task = 0; + uint64_t prof_scope_and_loop = 0; + int prof_submit_count = 0; + int prof_make_count = 0; + int prof_view_count = 0; +#endif + + CYCLE_COUNT_START(); + + // Read dimensions from TaskArg tensor metadata + // query: shape=[batch, seq_len, num_heads, head_dim] + uint64_t batch = orch_args[0].tensor.shapes[0]; + uint64_t seq_len = orch_args[0].tensor.shapes[1]; + uint64_t num_heads = orch_args[0].tensor.shapes[2]; + uint64_t head_dim = orch_args[0].tensor.shapes[3]; + DataType data_type = orch_args[0].tensor.dtype; + + // key_cache: shape=[total_blocks, block_size, kv_head_num, head_dim] + uint64_t block_size = orch_args[1].tensor.shapes[1]; + uint64_t kv_head_num = orch_args[1].tensor.shapes[2]; + + // block_table: shape=[batch, max_num_blocks_per_req] + uint64_t block_num = orch_args[3].tensor.shapes[1]; + + // scale from scalar arg + uint64_t scale_value = orch_args[6].scalar; + uint64_t q_head_num = num_heads; + uint64_t q_tile = std::min(num_heads, 128UL); + uint64_t q_loop = (q_head_num + q_tile - 1) / q_tile; + CYCLE_COUNT_LAP(prof_param_extract); + + void* query_ptr = orch_args[0].data(); + void* kc_ptr = 
orch_args[1].data(); + void* vc_ptr = orch_args[2].data(); + void* out_ptr = orch_args[5].data(); + + uint64_t total_blocks_count = orch_args[1].tensor.shapes[0]; + + // 4D external tensors matching input data shapes + uint32_t query_shapes[4] = {(uint32_t)batch, (uint32_t)seq_len, (uint32_t)num_heads, (uint32_t)head_dim}; + uint32_t key_cache_shapes[4] = {(uint32_t)total_blocks_count, (uint32_t)block_size, (uint32_t)kv_head_num, (uint32_t)head_dim}; + uint32_t value_cache_shapes[4] = {(uint32_t)total_blocks_count, (uint32_t)block_size, (uint32_t)kv_head_num, (uint32_t)head_dim}; + uint32_t out_shapes[4] = {(uint32_t)batch, (uint32_t)seq_len, (uint32_t)num_heads, (uint32_t)head_dim}; + + Tensor query = make_tensor_external(query_ptr, query_shapes, 4, data_type, false); + Tensor key_cache = make_tensor_external(kc_ptr, key_cache_shapes, 4, data_type, false); + Tensor value_cache = make_tensor_external(vc_ptr, value_cache_shapes, 4, data_type, false); + Tensor out = make_tensor_external(out_ptr, out_shapes, 4, DataType::FLOAT32); + + int* host_block_table = orch_args[3].data(); + int* host_context_lens = orch_args[4].data(); + +#ifdef ENABLE_PROFILING + CYCLE_COUNT_LAP(prof_ext_tensor); +#endif + + // Prefetch first block host_context_lens data into cache + __builtin_prefetch(&host_context_lens[0], 0, 3); + + for (uint64_t b_idx = 0; b_idx < batch; b_idx++) { + uint64_t cur_seq = host_context_lens[b_idx]; + uint64_t bn_this_batch = (cur_seq + block_size - 1) / block_size; + // Pre-compute block table base pointer for this batch + int* bt_base = host_block_table + b_idx * block_num; + + // Prefetch next block host_context_lens data while processing current batch + if (b_idx + 1 < batch) { + __builtin_prefetch(&host_context_lens[b_idx + 1], 0, 3); + } + for (uint64_t q_idx = 0; q_idx < q_loop; q_idx++) { + CYCLE_COUNT_LAP(prof_scope_and_loop); + PTO2_SCOPE() { + uint32_t oi_shapes[2] = {(uint32_t)q_tile, (uint32_t)head_dim}; + uint32_t li_shapes[1] = 
{(uint32_t)q_tile}; + uint32_t mi_shapes[1] = {(uint32_t)q_tile}; + Tensor oi = make_tensor(oi_shapes, 2, DataType::FLOAT32); + Tensor li_update = make_tensor(li_shapes, 1, DataType::FLOAT32, false); + Tensor mi_update = make_tensor(mi_shapes, 1, DataType::FLOAT32, false); +#ifdef ENABLE_PROFILING + prof_make_count += 3; + CYCLE_COUNT_LAP(prof_make_tensor); +#endif + + uint32_t qi_shapes[4] = {1, 1, (uint32_t)q_tile, (uint32_t)head_dim}; + uint32_t qi_offsets[4] = {(uint32_t)b_idx, 0, (uint32_t)(q_idx * q_tile), 0}; + Tensor qi = query.view(qi_shapes, qi_offsets); + uint32_t out_view_shapes[4] = {1, 1, (uint32_t)q_tile, (uint32_t)head_dim}; + uint32_t out_view_offsets[4] = {(uint32_t)b_idx, 0, (uint32_t)(q_idx * q_tile), 0}; + Tensor out_view = out.view(out_view_shapes, out_view_offsets); +#ifdef ENABLE_PROFILING + prof_view_count += 2; + CYCLE_COUNT_LAP(prof_tensor_view); +#endif + PTOParam params_inplace; + params_inplace.add_output(oi); + params_inplace.add_output(li_update); + params_inplace.add_output(mi_update); + CYCLE_COUNT_LAP(prof_param_setup); + pto2_rt_submit_aiv_task(FUNC_AIV_HUB, params_inplace); +#ifdef ENABLE_PROFILING + prof_submit_count++; + CYCLE_COUNT_LAP(prof_submit_task); +#endif + + // Reusable PTOParam objects — reset() before each use avoids + // repeated stack-frame construction in the inner loop. 
+ PTOParam params_qk, params_sf, params_pv, params_up; + + for (uint64_t bn = 0; bn < bn_this_batch; bn += N_UNROLL) { + uint64_t n_blocks = std::min((uint64_t)N_UNROLL, bn_this_batch - bn); + + // Valid length for last block in this group + uint64_t last_block_seq_start = (bn + n_blocks - 1) * block_size; + uint64_t valid_len_last = std::min(block_size, cur_seq - last_block_seq_start); + CYCLE_COUNT_LAP(prof_param_extract); + + // === Task 1: Batched QK matmul === + uint32_t sij_buf_shapes[2] = {(uint32_t)q_tile, (uint32_t)(n_blocks * block_size)}; + Tensor sij_buf = make_tensor(sij_buf_shapes, 2, DataType::FLOAT32); +#ifdef ENABLE_PROFILING + prof_make_count += 1; + CYCLE_COUNT_LAP(prof_make_tensor); +#endif + + params_qk.reset(); + params_qk.add_input(qi); + params_qk.add_input(key_cache); + params_qk.add_output(sij_buf); + params_qk.add_scalar(n_blocks); + params_qk.add_scalar(reinterpret_cast(bt_base + bn)); + CYCLE_COUNT_LAP(prof_param_setup); + pto2_rt_submit_aic_task(FUNC_QK_MATMUL, params_qk); +#ifdef ENABLE_PROFILING + prof_submit_count++; + CYCLE_COUNT_LAP(prof_submit_task); +#endif + + // === Task 2: Two-pass softmax over all blocks in group === + uint32_t pij_buf_shapes[2] = {(uint32_t)q_tile, (uint32_t)(n_blocks * block_size)}; + Tensor pij_buf = make_tensor(pij_buf_shapes, 2, data_type); + Tensor mi = make_tensor(mi_shapes, 1, DataType::FLOAT32); + Tensor li = make_tensor(li_shapes, 1, DataType::FLOAT32); +#ifdef ENABLE_PROFILING + prof_make_count += 3; + CYCLE_COUNT_LAP(prof_make_tensor); +#endif + + params_sf.reset(); + params_sf.add_input(sij_buf); + params_sf.add_output(pij_buf); + params_sf.add_output(mi); + params_sf.add_output(li); + params_sf.add_scalar(scale_value); + params_sf.add_scalar(n_blocks); + params_sf.add_scalar(valid_len_last); + CYCLE_COUNT_LAP(prof_param_setup); + pto2_rt_submit_aiv_task(FUNC_SOFTMAX_PREPARE, params_sf); +#ifdef ENABLE_PROFILING + prof_submit_count++; + CYCLE_COUNT_LAP(prof_submit_task); +#endif + + // === Task 
3: SplitK PV matmul (accumulated P @ V) === + uint32_t oi_new_shapes[2] = {(uint32_t)q_tile, (uint32_t)head_dim}; + Tensor oi_new = make_tensor(oi_new_shapes, 2, DataType::FLOAT32); +#ifdef ENABLE_PROFILING + prof_make_count += 1; + CYCLE_COUNT_LAP(prof_make_tensor); +#endif + + params_pv.reset(); + params_pv.add_input(pij_buf); + params_pv.add_input(value_cache); + params_pv.add_output(oi_new); + params_pv.add_scalar(n_blocks); + params_pv.add_scalar(reinterpret_cast(bt_base + bn)); + CYCLE_COUNT_LAP(prof_param_setup); + pto2_rt_submit_aic_task(FUNC_PV_MATMUL, params_pv); +#ifdef ENABLE_PROFILING + prof_submit_count++; + CYCLE_COUNT_LAP(prof_submit_task); +#endif + + // === Task 4: Online update (per-group) === + uint64_t is_first = (bn == 0) ? 1 : 0; + uint64_t is_last = (bn + n_blocks >= bn_this_batch) ? 1 : 0; + + params_up.reset(); + params_up.add_input(mi); + params_up.add_input(li); + params_up.add_input(oi_new); + params_up.add_inout(mi_update); + params_up.add_inout(li_update); + params_up.add_inout(oi); + params_up.add_output(out_view); + params_up.add_scalar(is_first); + params_up.add_scalar(is_last); + CYCLE_COUNT_LAP(prof_param_setup); + pto2_rt_submit_aiv_task(FUNC_ONLINE_UPDATE, params_up); +#ifdef ENABLE_PROFILING + prof_submit_count++; + CYCLE_COUNT_LAP(prof_submit_task); +#endif + } + } + CYCLE_COUNT_LAP(prof_scope_and_loop); + } + } + CYCLE_COUNT_LAP(prof_scope_and_loop); + +#ifdef ENABLE_PROFILING + uint64_t total = prof_param_extract + prof_ext_tensor + prof_make_tensor + + prof_tensor_view + prof_param_setup + prof_submit_task + + prof_scope_and_loop; + LOG_ALWAYS("=== PagedAttn Orch Profiling: %d submits, %d makes, %d views, total=%.3fus ===", + prof_submit_count, prof_make_count, prof_view_count, cycles_to_us(total)); + if (total > 0) { + LOG_ALWAYS(" param_extract : %7.3fus (%5.1f%%)", + cycles_to_us(prof_param_extract), prof_param_extract * 100.0 / total); + LOG_ALWAYS(" ext_tensor(x4) : %7.3fus (%5.1f%%)", + cycles_to_us(prof_ext_tensor), 
prof_ext_tensor * 100.0 / total); + LOG_ALWAYS(" make_tensor(x%d) : %7.3fus (%5.1f%%) avg=%.3fus", + prof_make_count, cycles_to_us(prof_make_tensor), prof_make_tensor * 100.0 / total, + prof_make_count > 0 ? cycles_to_us(prof_make_tensor) / prof_make_count : 0.0); + LOG_ALWAYS(" tensor_view(x%d) : %7.3fus (%5.1f%%) avg=%.3fus", + prof_view_count, cycles_to_us(prof_tensor_view), prof_tensor_view * 100.0 / total, + prof_view_count > 0 ? cycles_to_us(prof_tensor_view) / prof_view_count : 0.0); + LOG_ALWAYS(" param_setup : %7.3fus (%5.1f%%)", + cycles_to_us(prof_param_setup), prof_param_setup * 100.0 / total); + LOG_ALWAYS(" submit_task(x%d) : %7.3fus (%5.1f%%) avg=%.3fus", + prof_submit_count, cycles_to_us(prof_submit_task), prof_submit_task * 100.0 / total, + prof_submit_count > 0 ? cycles_to_us(prof_submit_task) / prof_submit_count : 0.0); + LOG_ALWAYS(" scope_and_loop : %7.3fus (%5.1f%%)", + cycles_to_us(prof_scope_and_loop), prof_scope_and_loop * 100.0 / total); + } +#endif + +#undef CYCLE_COUNT_START +#undef CYCLE_COUNT_LAP +} + +} // extern "C"