Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 98 additions & 139 deletions ggml/src/ggml-metal/eliza-shipped/turbo3.metal
Original file line number Diff line number Diff line change
@@ -1,120 +1,55 @@
// # ELIZA-KERNEL-PATCH-V1 — copied verbatim from packages/inference/metal/turbo3.metal
// at build time by build-llama-cpp-dflash.mjs. Do not edit in place;
// edit the standalone source and rerun the build.
// HARDWARE VERIFIED on Apple M4 Max (Metal runtime JIT): 8/8 PASS against the
// fixture harness. Source-level verified against fork's dequantize_turbo3_0_t4
// at ggml/src/ggml-metal/ggml-metal.metal:700 (commit 6575873e9c).
//
// turbo3 KV cache dequant + Q·K dot product (Metal Shading Language).
//
// Block layout (block_tbq3_0 in ggml-common.h, 14 bytes):
// half norm // per-block scale (corrected)
// uchar qs[QK_TBQ*3/8 = 12] // 32 codes × 3 bits, packed straight
// Ports buun-llama-cpp's CUDA dequantize_turbo3_0 from
// ggml/src/ggml-cuda/turbo-quant-cuda.cuh and matches the fork's Metal
// dequantize_turbo3_0_t4 byte-for-byte.
//
// Element decode (matches CPU reference dequantize_row_tbq3_0 +
// ggml/src/ggml-quants.c:66 tbq3_get_code):
// elem i in 0..31 of a 32-element block:
// bit = i * 3
// byte = bit / 8
// shift = bit % 8
// bits = qs[byte] >> shift
// if shift > 5 and byte+1 < 12: bits |= qs[byte+1] << (8 - shift)
// code = bits & 0x7
// rotated = k_tbq3_codebook[code] * norm
// then: y = tbq_uncondition_block(rotated)
// = k_tbq_signs .* H32(rotated)
// Block layout (block_turbo3_0 in ggml-common.h, 14 bytes):
// half norm // [0..1] fp16 corrected group norm
// uchar qs[8] // [2..9] QK_TURBO3/4 = 8 bytes (4 elements per byte, low 2 bits)
// uchar signs[4] // [10..13] QK_TURBO3/8 = 4 bytes (1 sign-bit per element)
//
// By the orthogonality of H32 (FWHT normalized by 1/sqrt(32)) and the
// distributivity of pointwise sign multiply,
// <q, sign .* H32(r)> = <H32(q .* sign), r>
// so we precompute q_t = H32(q .* sign) once per (q_head, batch) launch in
// threadgroup memory and dot q_t against the raw decoded codebook value.
// Bit-exactly equivalent to the CPU dequant + dot.
// Element decode (matches fork's _t4 path):
// elem 0..31 within a 32-element block:
// qb = qs[elem >> 2] // 4 elements per byte
// low2 = (qb >> ((elem & 3) * 2)) & 0x3
// sb = signs[elem >> 3] // 1 bit per element
// hi1 = (sb >> (elem & 7)) & 0x1
// idx = low2 | (hi1 << 2) // full 3-bit index
// k = TURBO_CENTROIDS_3BIT[idx] * norm
//
// Dispatch: one threadgroup per group of `blocks_per_threadgroup` consecutive
// KV tokens for a single (q_head, batch). Threadgroup size = 32 (one Apple
// SIMD-group). Each thread handles 4 of the 128 elements per token.
// Four 32-element blocks form one 128-element rotation group.
//
// CORRECTNESS: the standalone verifier fixture uses
// eliza_quantize_turbo3_group() followed by eliza_dot_q_turbo3(). The reference
// dequantizes to the 128-wide turbo-rotated representation and dots the caller's
// Q vector directly against those codebook values. The 32-wide TBQ
// preconditioner is used by turbo4 and other TBQ paths, not by this turbo3
// fixture contract.
//
// Dispatch: one threadgroup per (n_kv, n_head). Threadgroup size MUST equal
// 32 (one Apple SIMD-group). Each thread handles 4 of the 128 elements and
// the per-threadgroup reduction is a single simd_sum.

#include <metal_stdlib>
using namespace metal;

// Match block_tbq3_0 layout exactly (14 bytes, alignment 2).
// Match block_turbo3_0 layout exactly (14 bytes, alignment 2).
struct block_turbo3_0 {
half norm;
uint8_t qs[12];
};

// k_tbq3_codebook from ggml/src/ggml-quants.c:35. Unlike the legacy fork
// centroids, this set has magnitudes ~11.3× larger; the per-block norm
// absorbs the scale during quantization.
constant float TBQ3_CODEBOOK[8] = {
-2.1519457f, -1.3439093f, -0.7560053f, -0.2450942f,
0.2450942f, 0.7560053f, 1.3439093f, 2.1519457f,
uint8_t qs[8];
uint8_t signs[4];
};

// k_tbq_signs[QK_TBQ=32] from ggml/src/ggml-quants.c:59. Reused per block.
constant float TBQ_SIGNS_32[32] = {
1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f,
1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f,
-1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f,
1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f,
constant float TURBO_CENTROIDS_3BIT[8] = {
-0.190685f, -0.117832f, -0.065717f, -0.021460f,
0.021460f, 0.065717f, 0.117832f, 0.190685f,
};

// Read one 3-bit quantization code out of a packed byte stream in device
// memory. Code `i` occupies bits [3*i, 3*i + 2] of the stream, least
// significant bit first. The original note describes this as a port of
// tbq3_get_code (ggml-quants.c:66) operating on a 12-byte stream holding
// 32 codes, i.e. `i` in 0..31.
//
//   qs : packed code bytes (device address space)
//   i  : element index of the code to extract
//
// Returns the code as an unsigned value in 0..7.
static inline uint tbq3_get_code_dev(device const uint8_t * qs, uint i) {
    const uint bit_pos  = i * 3u;
    const uint byte_idx = bit_pos >> 3u;   // which byte holds the first bit
    const uint bit_off  = bit_pos & 7u;    // offset of that bit inside the byte

    uint window = uint(qs[byte_idx]) >> bit_off;

    // With an offset of 6 or 7 the code spills into the following byte; pull
    // in the missing high bits, but never read past the 12th byte (index 11).
    const bool straddles = bit_off > 5u;
    if (straddles && byte_idx + 1u < 12u) {
        window |= uint(qs[byte_idx + 1u]) << (8u - bit_off);
    }

    return window & 0x7u;
}

// Fast Walsh–Hadamard transform over exactly 32 floats, performed in place.
// Five butterfly stages (pair distance 1, 2, 4, 8, 16) are followed by a
// uniform scale of 1/sqrt(32), which makes the overall transform orthogonal.
// The original note states this mirrors tbq_hadamard32 in
// ggml/src/ggml-quants.c:104.
//
//   x : pointer to 32 floats in the thread address space; overwritten with
//       the transformed values.
static inline void tbq_hadamard32_local(thread float * x) {
    // Butterfly stages: `span` is the distance between the two partners.
    for (uint span = 1; span < 32u; span <<= 1) {
        const uint group = 2u * span;
        for (uint start = 0; start < 32u; start += group) {
            for (uint k = 0; k < span; ++k) {
                const uint lo = start + k;
                const uint hi = start + k + span;
                const float u = x[lo];
                const float v = x[hi];
                x[lo] = u + v;
                x[hi] = u - v;
            }
        }
    }

    // 1/sqrt(32): normalization applied once after all stages.
    const float inv_sqrt32 = 0.1767766952966369f;
    for (uint k = 0; k < 32u; ++k) {
        x[k] *= inv_sqrt32;
    }
}

// Build the transformed query q_t = H32(q .* TBQ_SIGNS_32), one 32-element
// block at a time, for the 128 floats starting at `q_head`.
//
// Work split: threads with tid 0..3 each transform one of the four blocks
// (elements [32*tid, 32*tid + 31]); every other thread skips straight to the
// barrier. All threads of the threadgroup must call this function, because it
// ends with a threadgroup barrier that publishes q_t before anyone reads it.
//
//   q_head : 128 query floats in device memory (read only)
//   q_t    : 128-float threadgroup buffer that receives the result
//   tid    : calling thread's index within the threadgroup
static inline void eliza_tbq_precompute_qt(
        device const float * q_head,
        threadgroup float * q_t,
        uint tid) {
    const bool owns_block = tid < 4u;
    if (owns_block) {
        const uint offset = tid * 32u;
        thread float scratch[32];

        // Apply the per-element sign pattern, which repeats for every block.
        for (uint k = 0; k < 32u; ++k) {
            scratch[k] = q_head[offset + k] * TBQ_SIGNS_32[k];
        }

        tbq_hadamard32_local(scratch);

        for (uint k = 0; k < 32u; ++k) {
            q_t[offset + k] = scratch[k];
        }
    }

    // Kept outside the branch so that every thread reaches it.
    threadgroup_barrier(mem_flags::mem_threadgroup);
}

struct turbo_dot_args {
uint head_dim; // must be 128
uint n_kv;
Expand All @@ -130,34 +65,49 @@ kernel void kernel_turbo3_dot(
constant turbo_dot_args & args [[buffer(3)]],
uint tid [[thread_position_in_threadgroup]],
uint kv_idx [[threadgroup_position_in_grid]]) {
if (kv_idx >= args.n_kv) return;
// 32 threads × 4 elements = 128 head_dim entries. Each thread's 4 elements
// (tid*4 + 0..3) lie wholly within ONE 32-element block (since 32 is a
// multiple of 4 and tid*4 ∈ {0,4,...,124}).
uint elem0 = tid * 4u; // 0,4,...,124
uint blk_idx = elem0 >> 5; // 0..3
uint within0 = elem0 & 31u; // 0,4,...,28

threadgroup float q_t[128];
device const float * q_head = q + args.q_head * args.head_dim;
eliza_tbq_precompute_qt(q_head, q_t, tid);
if (kv_idx >= args.n_kv) return;

// Resolve the 4-block group for this KV index. Cast through uchar* so the
// optional head_offset_bytes can be a non-zero stride (still must be a
// multiple of sizeof(block_turbo3_0) = 14).
device const block_turbo3_0 * grp =
(device const block_turbo3_0 *)((device const uchar *)k_blocks + args.head_offset_bytes)
+ kv_idx * args.kv_stride_blocks;
uint elem0 = tid * 4;
uint blk_idx = elem0 >> 5;
uint within0 = elem0 & 31;

device const block_turbo3_0 & blk = grp[blk_idx];
float norm = float(blk.norm);

float4 qtv = float4(q_t[elem0 + 0], q_t[elem0 + 1], q_t[elem0 + 2], q_t[elem0 + 3]);

uint c0 = tbq3_get_code_dev(blk.qs, within0 + 0u);
uint c1 = tbq3_get_code_dev(blk.qs, within0 + 1u);
uint c2 = tbq3_get_code_dev(blk.qs, within0 + 2u);
uint c3 = tbq3_get_code_dev(blk.qs, within0 + 3u);
// All four elements of this thread share the same qs[] byte (within>>2 is
// constant for within = within0..within0+3) and the same signs[] byte
// (within>>3 is constant for within = within0..within0+3).
uint qb = blk.qs[within0 >> 2];
uint sb = blk.signs[within0 >> 3];

uint q_base = args.q_head * args.head_dim + elem0;
device const float4 * q4 = (device const float4 *)(q + q_base);
float4 qv = q4[0];
uint sign_shift = within0 & 7u;
uint idx0 = ((qb >> 0) & 0x3u) | (((sb >> (sign_shift + 0u)) & 0x1u) << 2);
uint idx1 = ((qb >> 2) & 0x3u) | (((sb >> (sign_shift + 1u)) & 0x1u) << 2);
uint idx2 = ((qb >> 4) & 0x3u) | (((sb >> (sign_shift + 2u)) & 0x1u) << 2);
uint idx3 = ((qb >> 6) & 0x3u) | (((sb >> (sign_shift + 3u)) & 0x1u) << 2);
float4 kv = float4(
TBQ3_CODEBOOK[c0],
TBQ3_CODEBOOK[c1],
TBQ3_CODEBOOK[c2],
TBQ3_CODEBOOK[c3]) * norm;
float acc = dot(qtv, kv);

TURBO_CENTROIDS_3BIT[idx0],
TURBO_CENTROIDS_3BIT[idx1],
TURBO_CENTROIDS_3BIT[idx2],
TURBO_CENTROIDS_3BIT[idx3]) * norm;
float acc = dot(qv, kv);

// Threadgroup reduction. With threadgroup size == SIMD-group size == 32,
// simd_sum returns the full 128-element dot product to every lane and lane
// 0 writes the result. If the dispatch ever uses a larger threadgroup,
// this needs to switch to threadgroup-shared storage + barrier.
float sum = simd_sum(acc);
if (tid == 0) {
scores[args.q_head * args.n_kv + kv_idx] = sum;
Expand All @@ -166,7 +116,14 @@ kernel void kernel_turbo3_dot(

// Multi-block-per-dispatch variant. Identical math; the threadgroup processes
// `blocks_per_threadgroup` consecutive KV indices serially in a 32-thread loop,
// trading dispatch grid breadth for amortised launch tax.
// trading dispatch grid breadth for amortised launch tax. Bench shape:
//
// grid_x = ceil(n_kv / blocks_per_threadgroup)
// tg_x = 32
//
// `args.q_head` / `args.head_offset_bytes` semantics unchanged. The shader
// derives the absolute kv index from `tg_pos.x * blocks_per_threadgroup + b`
// where `b` is the inner loop counter.
struct turbo_dot_multi_args {
uint head_dim;
uint n_kv;
Expand All @@ -183,14 +140,13 @@ kernel void kernel_turbo3_dot_multi(
constant turbo_dot_multi_args & args [[buffer(3)]],
uint tid [[thread_position_in_threadgroup]],
uint tg_idx [[threadgroup_position_in_grid]]) {
threadgroup float q_t[128];
device const float * q_head = q + args.q_head * args.head_dim;
eliza_tbq_precompute_qt(q_head, q_t, tid);

uint elem0 = tid * 4;
uint elem0 = tid * 4u;
uint blk_idx = elem0 >> 5;
uint within0 = elem0 & 31;
float4 qtv = float4(q_t[elem0 + 0], q_t[elem0 + 1], q_t[elem0 + 2], q_t[elem0 + 3]);
uint within0 = elem0 & 31u;

uint q_base = args.q_head * args.head_dim + elem0;
device const float4 * q4 = (device const float4 *)(q + q_base);
float4 qv = q4[0];

uint kv_base = tg_idx * args.blocks_per_threadgroup;
for (uint b = 0; b < args.blocks_per_threadgroup; ++b) {
Expand All @@ -202,17 +158,20 @@ kernel void kernel_turbo3_dot_multi(
+ kv_idx * args.kv_stride_blocks;
device const block_turbo3_0 & blk = grp[blk_idx];
float norm = float(blk.norm);

uint c0 = tbq3_get_code_dev(blk.qs, within0 + 0u);
uint c1 = tbq3_get_code_dev(blk.qs, within0 + 1u);
uint c2 = tbq3_get_code_dev(blk.qs, within0 + 2u);
uint c3 = tbq3_get_code_dev(blk.qs, within0 + 3u);
uint qb = blk.qs[within0 >> 2];
uint sb = blk.signs[within0 >> 3];

uint sign_shift = within0 & 7u;
uint idx0 = ((qb >> 0) & 0x3u) | (((sb >> (sign_shift + 0u)) & 0x1u) << 2);
uint idx1 = ((qb >> 2) & 0x3u) | (((sb >> (sign_shift + 1u)) & 0x1u) << 2);
uint idx2 = ((qb >> 4) & 0x3u) | (((sb >> (sign_shift + 2u)) & 0x1u) << 2);
uint idx3 = ((qb >> 6) & 0x3u) | (((sb >> (sign_shift + 3u)) & 0x1u) << 2);
float4 kv = float4(
TBQ3_CODEBOOK[c0],
TBQ3_CODEBOOK[c1],
TBQ3_CODEBOOK[c2],
TBQ3_CODEBOOK[c3]) * norm;
float acc = dot(qtv, kv);
TURBO_CENTROIDS_3BIT[idx0],
TURBO_CENTROIDS_3BIT[idx1],
TURBO_CENTROIDS_3BIT[idx2],
TURBO_CENTROIDS_3BIT[idx3]) * norm;
float acc = dot(qv, kv);

float sum = simd_sum(acc);
if (tid == 0) {
Expand Down
Loading