Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Paged Attention Unroll Golden - tensormap_and_ringbuffer test (production scale, bfloat16).

Input shapes use 4D format: (batch, seq_len, num_heads, head_dim) for query and out.
"""

from paged_attention_golden import (
generate_inputs as _generate_inputs,
compute_golden as _compute_golden,
run_golden_test,
)

# Tensor names the golden-test harness treats as kernel outputs.
__outputs__ = ["out"]

# Comparison tolerances for verifying bfloat16 results against the golden
# reference (relative / absolute).
RTOL = 1e-3
ATOL = 1e-3

# Production-scale bfloat16 test cases.
# Keys per case:
#   batch / num_heads / kv_head_num / head_dim: attention geometry
#       (kv_head_num=1 — all query heads share one KV head)
#   block_size:    paged-KV block length (rows per cache block)
#   context_len:   number of cached tokens attended per sequence
#   max_model_len: maximum sequence capacity of the KV cache
#   dtype:         element type of query/key/value tensors
ALL_CASES = {
    "Case1": {
        "batch": 256,
        "num_heads": 16,
        "kv_head_num": 1,
        "head_dim": 128,
        "block_size": 128,
        "context_len": 8192,
        "max_model_len": 32768,
        "dtype": "bfloat16",
    },
    "Case2": {
        "batch": 64,
        "num_heads": 64,
        "kv_head_num": 1,
        "head_dim": 128,
        "block_size": 64,
        "context_len": 8192,
        "max_model_len": 32768,
        "dtype": "bfloat16",
    },
    "Case3": {
        "batch": 64,
        "num_heads": 64,
        "kv_head_num": 1,
        "head_dim": 256,
        "block_size": 64,
        "context_len": 8192,
        "max_model_len": 32768,
        "dtype": "bfloat16",
    },
}

# Case executed when no explicit case is selected by the harness.
DEFAULT_CASE = "Case1"


def generate_inputs(params: dict) -> list:
    """Generate test inputs with query/out expanded to 4D.

    Delegates tensor creation to the shared paged-attention golden generator,
    then reshapes the "query" and "out" tensors from the base 3D layout to the
    4D (batch, seq_len=1, num_heads, head_dim) layout this test uses.

    Args:
        params: Case parameter dict providing at least "batch", "num_heads"
            and "head_dim".

    Returns:
        List of (name, tensor) pairs in the same order the base generator
        produced them; only "query" and "out" are reshaped.
    """
    base_inputs = _generate_inputs(params)
    batch = params["batch"]
    num_heads = params["num_heads"]
    head_dim = params["head_dim"]
    reshaped = []
    for name, val in base_inputs:
        if name in ("query", "out"):
            # Insert the explicit seq_len=1 axis expected by this test's
            # 4D I/O layout.
            val = val.reshape(batch, 1, num_heads, head_dim)
        reshaped.append((name, val))
    return reshaped


def compute_golden(tensors: dict, params: dict) -> None:
    """Compute the golden output in place, bridging 4D "out" to the 3D base.

    The base ``_compute_golden`` expects tensors["out"] as
    (batch, num_heads, head_dim); this test stores it as
    (batch, 1, num_heads, head_dim). Temporarily swap in a 3D reshape, run
    the base golden, then restore the original 4D tensor.

    NOTE(review): this relies on ``reshape`` returning a view that shares
    storage with ``out_4d`` (true for contiguous numpy/torch tensors), so the
    golden values written through the 3D view remain visible through the
    restored 4D tensor — confirm the base generator produces contiguous "out".
    """
    batch = params["batch"]
    num_heads = params["num_heads"]
    head_dim = params["head_dim"]
    out_4d = tensors["out"]
    tensors["out"] = out_4d.reshape(batch, num_heads, head_dim)
    try:
        _compute_golden(tensors, params)
    finally:
        # Always restore the caller-visible 4D tensor, even if the base
        # golden computation raises.
        tensors["out"] = out_4d


if __name__ == "__main__":
    # NOTE(review): compute_golden is not passed to run_golden_test here —
    # confirm the harness discovers it via module introspection, or pass it
    # explicitly alongside generate_inputs.
    run_golden_test(ALL_CASES, DEFAULT_CASE, generate_inputs, label="Paged Attention Unroll")
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include <cstdint>
#include <pto/pto-inst.hpp>

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

// Tile dimensions — currently unused by this placeholder kernel; presumably
// reserved for the eventual matmul implementation (TODO confirm).
constexpr int M = 16;
constexpr int K = 16;
constexpr int N = 16;

// Placeholder entry point: deliberately a no-op so the file compiles and the
// kernel ABI (single __gm__ int64_t* argument block) is established.
extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// SplitK PV Matmul Kernel: Accumulated P @ V across n_blocks
//
// Processes n_blocks blocks using SplitK accumulation pattern:
// Block 0: TMATMUL(C, A, B) — initialize accumulator
// Block i: TMATMUL_ACC(C, C, A, B) — accumulate into same C
//
// Per-block pij addresses: contiguous slices of pij_buf (n_blocks * M * K)
// Per-block vj addresses: value_cache base + block_indices lookup
// Single output: oi_new (M, N) fp32 = sum of P_i @ V_i across all blocks
//
// Optimizations:
// - Double-buffered L1 tiles (ping/pong for A and B via MTE2)
// - Double-buffered L0 tiles (ping/pong for L0A and L0B via MTE1)
// - TLOAD(next) overlaps with TMATMUL(current) via MTE2/M-pipe parallelism
// - Canonical 3-stage pipeline: TLOAD(MTE2) → TMOV(MTE1) → TMATMUL(M)
// - Reverse-dependency events ensure buffer safety across iterations
//
// Supports two tile configurations via runtime dispatch:
// Case1: (16, 128) @ (128, 128) -> (16, 128)
// Case2: (64, 64) @ ( 64, 128) -> (64, 128)
//
// pij is bfloat16 (from softmax_prepare TCVT).
// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout.

#include <cstdint>
#include <pto/pto-inst.hpp>

#include "tensor.h"

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

// Accumulates oi = sum_{i < n_blocks} P_i @ V_i into one (M, N) fp32 tile via
// the SplitK pattern (TMATMUL initializes on block 0, TMATMUL_ACC after).
// Three-stage pipeline with double-buffered L1 and L0 tiles:
//   TLOAD (MTE2: GM -> L1[cur]) -> TMOV (MTE1: L1[cur] -> L0[cur]) ->
//   TMATMUL (M-pipe: L0A x L0B -> L0C).
// Buffer safety across iterations is enforced by reverse-dependency flags
// keyed on the ping/pong index `cur` (event id 0/1).
template <int M, int K, int N>
static __aicore__ void pv_matmul_n_impl(
    __gm__ bfloat16_t* pij_base,    // n_blocks contiguous (M, K) bf16 P tiles
    __gm__ bfloat16_t* val_base,    // value-cache base; blocks located via block_table
    __gm__ float* oi_base,          // single (M, N) fp32 output
    uint64_t n_blocks,              // number of blocks to accumulate
    __gm__ int32_t* block_table) {  // per-block physical block indices

    // GM views: A = (M, K) P tile, B = (K, N) row-major V block, Out = (M, N).
    using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
    using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
    using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

    // L1 staging tile types for the A and B operands.
    using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
    using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;

    // L0 operand/accumulator tile types consumed by TMATMUL.
    using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
    using RightTile = TileRight<bfloat16_t, K, N, K, N>;
    using AccTile = TileAcc<float, M, N, M, N>;

    // L1 memory layout: double-buffered A and B tiles (tightly packed)
    constexpr int kATileBytes = M * K * static_cast<int>(sizeof(bfloat16_t));
    constexpr int kBTileBytes = K * N * static_cast<int>(sizeof(bfloat16_t));

    // L1 addresses: [A0 | A1 | B0 | B1], A pair first, then the B pair.
    TileMatA aMatTile[2];
    TileMatB bMatTile[2];
    TASSIGN(aMatTile[0], 0x0);
    TASSIGN(aMatTile[1], kATileBytes);
    TASSIGN(bMatTile[0], 2 * kATileBytes);
    TASSIGN(bMatTile[1], 2 * kATileBytes + kBTileBytes);

    // L0 memory layout: double-buffered L0A and L0B, single accumulator L0C.
    // L0A and L0B are distinct physical memories, so both start at 0x0.
    LeftTile aTile[2];
    RightTile bTile[2];
    AccTile cTile;
    TASSIGN(aTile[0], 0x0);
    TASSIGN(aTile[1], kATileBytes);
    TASSIGN(bTile[0], 0x0);
    TASSIGN(bTile[1], kBTileBytes);
    TASSIGN(cTile, 0x0);

    GlobalOut oiGlobal(oi_base);

    // Seed reverse-dependency flags: all ping/pong buffers initially free.
    // PIPE_MTE1 -> PIPE_MTE2: L1 buffer [0/1] safe for TLOAD to overwrite.
    // PIPE_M    -> PIPE_MTE1: L0 buffer [0/1] safe for TMOV to overwrite.
    set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
    set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
    set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
    set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);

    for (uint64_t i = 0; i < n_blocks; i++) {
        int cur = static_cast<int>(i % 2);  // ping/pong buffer index
        GlobalA pijGlobal(pij_base + i * M * K);
        // block_table[i] selects the physical V block for logical block i.
        GlobalB vjGlobal(val_base + block_table[i] * K * N);

        // Stage 1: TLOAD (MTE2: GM -> L1[cur]).
        // Wait for MTE1 to release L1[cur] (reverse dep from 2 iterations ago).
        wait_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur);
        TLOAD(aMatTile[cur], pijGlobal);
        set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // forward: A in L1 ready
        TLOAD(bMatTile[cur], vjGlobal);
        set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // forward: B in L1 ready

        // Stage 2: TMOV (MTE1: L1[cur] -> L0[cur]).
        // Wait for M-pipe to release L0[cur] (reverse dep from previous use).
        wait_flag(PIPE_M, PIPE_MTE1, (event_t)cur);
        wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); // forward: wait A loaded
        TMOV(aTile[cur], aMatTile[cur]);
        wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); // forward: wait B loaded
        TMOV(bTile[cur], bMatTile[cur]);
        set_flag(PIPE_MTE1, PIPE_MTE2, (event_t)cur); // reverse: release L1[cur]

        // Stage 3: TMATMUL (M-pipe: L0A[cur] x L0B[cur] -> L0C).
        set_flag(PIPE_MTE1, PIPE_M, (event_t)cur); // forward: L0[cur] ready
        wait_flag(PIPE_MTE1, PIPE_M, (event_t)cur);
        if (i == 0) {
            TMATMUL(cTile, aTile[cur], bTile[cur]);        // initialize accumulator
        } else {
            TMATMUL_ACC(cTile, cTile, aTile[cur], bTile[cur]); // accumulate
        }
        set_flag(PIPE_M, PIPE_MTE1, (event_t)cur); // reverse: release L0[cur]
    }

    // Drain outstanding reverse-dependency flags so no set_flag is left
    // unconsumed after the loop.
    wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
    wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
    wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
    wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);

    // Ensure the final TMATMUL has completed before FIX stores L0C to GM.
    set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
    TSTORE(oiGlobal, cTile);

    // Fence the scalar pipe on store completion before returning.
    set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
    wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
}

// Kernel entry: unpack the argument block and dispatch to the tile
// configuration matching the P-tile row count.
//   args[0] pij_buf tensor, args[1] value_cache tensor, args[2] oi_new tensor,
//   args[3] block count, args[4] block-table pointer.
extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
    __gm__ TensorData* p_desc = reinterpret_cast<__gm__ TensorData*>(args[0]);
    __gm__ TensorData* v_desc = reinterpret_cast<__gm__ TensorData*>(args[1]);
    __gm__ TensorData* o_desc = reinterpret_cast<__gm__ TensorData*>(args[2]);
    uint64_t block_count = static_cast<uint64_t>(args[3]);
    __gm__ int32_t* table = reinterpret_cast<__gm__ int32_t*>(args[4]);

    // Resolve element pointers; pij/oi honour their start offsets, the value
    // cache is addressed from its buffer base via the block table.
    __gm__ bfloat16_t* p_ptr =
        reinterpret_cast<__gm__ bfloat16_t*>(p_desc->buffer.addr) + p_desc->start_offset;
    __gm__ bfloat16_t* v_ptr = reinterpret_cast<__gm__ bfloat16_t*>(v_desc->buffer.addr);
    __gm__ float* o_ptr =
        reinterpret_cast<__gm__ float*>(o_desc->buffer.addr) + o_desc->start_offset;

    // shapes[0] of pij_buf is the query-tile row count (16 -> Case1, else Case2).
    if (static_cast<uint64_t>(p_desc->shapes[0]) == 16) {
        pv_matmul_n_impl<16, 128, 128>(p_ptr, v_ptr, o_ptr, block_count, table);
    } else {
        pv_matmul_n_impl<64, 64, 128>(p_ptr, v_ptr, o_ptr, block_count, table);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Multi-block QK Matmul Kernel: qi(M, K) @ kj.T(K, N) -> sij(M, N) for each block
//
// Processes n_blocks blocks in a single kernel invocation.
// Per-block kj addresses computed from key_cache base + block_indices lookup.
// qi is shared across all blocks (same query head against different key blocks).
//
// Output layout: n_blocks contiguous (M, N) tiles stacked vertically.
// Block i occupies sij[i*M : (i+1)*M, 0:N].
//
// Optimizations:
// - qi TLOAD hoisted before the loop (constant across all iterations)
//
// Supports two tile configurations via runtime dispatch:
// Case1: (16, 128) @ (128, 128).T -> (16, 128)
// Case2: (64, 128) @ (128, 64).T -> (64, 64)
//
// Template: M=q_tile, K=head_dim, N=block_size

#include <cstdint>
#include <pto/pto-inst.hpp>

#include "tensor.h"

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

// Computes sij_i = qi @ kj_i.T for each of n_blocks key blocks, writing the
// (M, N) results as contiguous tiles stacked at sij_base + i*M*N.
// qi is loaded into L1 once (hoisted) and re-copied to L0A each iteration;
// kj is reloaded per block. Template: M=q_tile, K=head_dim, N=block_size.
template <int M, int K, int N>
static __aicore__ void qk_matmul_n_impl(
    __gm__ bfloat16_t* qi_base,     // (M, K) bf16 query tile, shared by all blocks
    __gm__ bfloat16_t* key_base,    // key-cache base; blocks located via block_table
    __gm__ float* sij_base,         // n_blocks stacked (M, N) fp32 outputs
    uint64_t n_blocks,              // number of key blocks to process
    __gm__ int32_t* block_table) {  // per-block physical block indices

    // GM views. GlobalB uses Stride<..., 1, K> with Layout::DN so the (K, N)
    // view reads kj transposed directly from its row-major (N, K) storage.
    using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
    using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
    using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

    // L1 staging tiles; B uses RowMajor/ColMajor to match the transposed read.
    using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
    using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;

    // L0 operand/accumulator tiles consumed by TMATMUL.
    using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
    using RightTile = TileRight<bfloat16_t, K, N, K, N>;
    using AccTile = TileAcc<float, M, N, M, N>;

    // L1 addresses: A at 0x0, B at fixed offset 0x20000 (beyond the largest
    // A tile for either dispatch case).
    TileMatA aMatTile;
    TileMatB bMatTile;
    TASSIGN(aMatTile, 0x0);
    TASSIGN(bMatTile, 0x20000);

    // L0A/L0B/L0C are distinct physical memories, so all start at 0x0.
    LeftTile aTile;
    RightTile bTile;
    AccTile cTile;
    TASSIGN(aTile, 0x0);
    TASSIGN(bTile, 0x0);
    TASSIGN(cTile, 0x0);

    // Hoist qi TLOAD before the loop (qi is constant across all blocks).
    // MTE2 executes in order, so the per-iteration MTE2->MTE1 flag below also
    // covers this earlier load.
    GlobalA qiGlobal(qi_base);
    TLOAD(aMatTile, qiGlobal);

    for (uint64_t i = 0; i < n_blocks; i++) {
        // block_table[i] selects the physical key block for logical block i;
        // each block occupies N * K bf16 elements.
        GlobalB kjGlobal(key_base + block_table[i] * N * K);
        GlobalOut sijGlobal(sij_base + i * M * N);

        // Load only B each iteration (qi already in L1 from hoist).
        TLOAD(bMatTile, kjGlobal);

        // MTE2 -> MTE1: both operands resident in L1.
        set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);

        // TMOV qi from L1->L0A (re-copy since TMATMUL consumed L0A) and kj from L1->L0B.
        TMOV(aTile, aMatTile);
        TMOV(bTile, bMatTile);

        // MTE1 -> M: L0 operands ready for the matmul.
        set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
        wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);

        TMATMUL(cTile, aTile, bTile);

        // M -> FIX: accumulator complete before storing to GM.
        set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
        wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);

        TSTORE(sijGlobal, cTile);

        // Full barrier between iterations: the single-buffered L1/L0 tiles
        // must not be overwritten while still being read or stored.
        if (i + 1 < n_blocks) {
            pipe_barrier(PIPE_ALL);
        }
    }
    // Fence the scalar pipe on the final store before returning.
    set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
    wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
}

// Kernel entry: unpack the argument block and dispatch to the tile
// configuration matching the query-tile size.
//   args[0] qi tensor, args[1] key_cache tensor, args[2] sij_buf tensor,
//   args[3] block count, args[4] block-table pointer.
extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {
    __gm__ TensorData* q_desc = reinterpret_cast<__gm__ TensorData*>(args[0]);
    __gm__ TensorData* k_desc = reinterpret_cast<__gm__ TensorData*>(args[1]);
    __gm__ TensorData* s_desc = reinterpret_cast<__gm__ TensorData*>(args[2]);
    uint64_t block_count = static_cast<uint64_t>(args[3]);
    __gm__ int32_t* table = reinterpret_cast<__gm__ int32_t*>(args[4]);

    // Resolve element pointers; qi/sij honour their start offsets, the key
    // cache is addressed from its buffer base via the block table.
    __gm__ bfloat16_t* q_ptr =
        reinterpret_cast<__gm__ bfloat16_t*>(q_desc->buffer.addr) + q_desc->start_offset;
    __gm__ bfloat16_t* k_ptr = reinterpret_cast<__gm__ bfloat16_t*>(k_desc->buffer.addr);
    __gm__ float* s_ptr =
        reinterpret_cast<__gm__ float*>(s_desc->buffer.addr) + s_desc->start_offset;

    // qi is a 4D view (1, 1, q_tile, head_dim); shapes[2] is the q-tile size
    // (16 -> Case1, else Case2).
    if (static_cast<uint64_t>(q_desc->shapes[2]) == 16) {
        qk_matmul_n_impl<16, 128, 128>(q_ptr, k_ptr, s_ptr, block_count, table);
    } else {
        qk_matmul_n_impl<64, 128, 64>(q_ptr, k_ptr, s_ptr, block_count, table);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include <cstdint>
#include <pto/pto-inst.hpp>

using namespace pto;

#ifndef __gm__
#define __gm__
#endif

#ifndef __aicore__
#define __aicore__ [aicore]
#endif

// Tile dimensions — currently unused by this placeholder kernel; presumably
// reserved for the eventual matmul implementation (TODO confirm).
constexpr int M = 16;
constexpr int K = 16;
constexpr int N = 16;

// Placeholder entry point: deliberately a no-op so the file compiles and the
// kernel ABI (single __gm__ int64_t* argument block) is established.
extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {}
Loading
Loading