Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_TURBO2_0,
GGML_TYPE_TURBO3_0,
GGML_TYPE_TURBO4_0,
};

static ggml_type kv_cache_type_from_str(const std::string & s) {
Expand Down
6 changes: 6 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
#include <string.h>
#include <fcntl.h>
#include <io.h>
#ifndef fileno
#define fileno _fileno
#endif
#ifndef isatty
#define isatty _isatty
#endif
#else
#include <sys/ioctl.h>
#include <sys/stat.h>
Expand Down
4 changes: 2 additions & 2 deletions ggml/include/ggml-rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
extern "C" {
#endif

#define RPC_PROTO_MAJOR_VERSION 4
#define RPC_PROTO_MAJOR_VERSION 5
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0

#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC protocol version");
#endif

#define GGML_RPC_MAX_SERVERS 16
Expand Down
18 changes: 17 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,12 @@ extern "C" {
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_Q1_0 = 41,
GGML_TYPE_COUNT = 42,
GGML_TYPE_TURBO2_0 = 42, // TurboQuant 2-bit KV cache: WHT + 2-bit PolarQuant
GGML_TYPE_TURBO3_0 = 43, // TurboQuant 3-bit KV cache: WHT + 3-bit PolarQuant
GGML_TYPE_TURBO4_0 = 44, // TurboQuant 4-bit KV cache: WHT + 4-bit PolarQuant
GGML_TYPE_TQ3_1S = 45, // TurboQuant 3-bit weight: WHT-rotated 8-level Lloyd-Max, block_size=32
GGML_TYPE_TQ4_1S = 46, // TurboQuant 4-bit weight: WHT-rotated 16-level Lloyd-Max, block_size=32
GGML_TYPE_COUNT = 47,
};

// precision
Expand Down Expand Up @@ -567,6 +572,7 @@ extern "C" {
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_GATED_DELTA_NET,
GGML_OP_TURBO_WHT,

GGML_OP_UNARY,
Comment thread
romgenie marked this conversation as resolved.

Expand Down Expand Up @@ -2555,6 +2561,16 @@ extern "C" {
struct ggml_tensor * beta,
struct ggml_tensor * state);

// TurboQuant Walsh-Hadamard Transform (O(d log d) rotation for KV cache compression)
// Applies the WHT rotation independently to fixed-size groups of elements along ne[0]:
// sign1 → butterfly → sign2 → normalize
//   a          - tensor to rotate
//   direction  - 0 = forward (signs1 → WHT → signs2), 1 = inverse (signs2 → WHT → signs1)
//   group_size - elements per rotation group; 0 = auto (64 or 128, chosen from ne[0])
//   scale      - optional InnerQ scaling tensor; NULL = no InnerQ scaling
GGML_API struct ggml_tensor * ggml_turbo_wht(
struct ggml_context * ctx,
struct ggml_tensor * a,
int direction,
int group_size, // 0 = auto (64 or 128 from ne[0])
struct ggml_tensor * scale); // NULL = no InnerQ scaling

// custom operators

typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
Expand Down
1 change: 1 addition & 0 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ add_library(ggml-base
ggml-threading.h
ggml-quants.c
ggml-quants.h
ggml-turbo-quant.c
gguf.cpp)

set_target_properties(ggml-base PROPERTIES
Expand Down
91 changes: 91 additions & 0 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,97 @@ typedef struct {
} block_tq2_0;
static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");

// TurboQuant 3-bit MSE-only: 3-bit PolarQuant indices (no QJL)
// Storage block size = 128 (QK_TURBO3), equal to the rotation group size (QK_TURBO3_GROUP)
// Transform group size = 128 (head_dim, for rotation Gaussianization)
// Per block: norm(fp16, 2 bytes) + lower 2-bit indices (32 bytes) + upper 1-bit (16 bytes) = 50 bytes per 128 values
// = 3.125 bits/value → 5.12× compression vs fp16
// The 3-bit index is split: lower 2 bits in qs[], upper 1 bit in signs[]
#define QK_TURBO3 128 // Block size 128: one block per rotation group, eliminates redundant norms
#define QK_TURBO3_GROUP 128 // rotation group size = head_dim
// Derived: FA template nl parameters (auto-scale with block size)
#define NL_TURBO3 (QK_TURBO3 / 16) // non-vec FA iterations per block
#define NL_TURBO3_VEC (QK_TURBO3 / 4) // vec FA iterations per block
typedef struct {
ggml_half norm; // 2 bytes: vector L2 norm (for rescaling)
uint8_t qs[QK_TURBO3 / 4]; // 32 bytes: lower 2-bit indices (4 per byte)
uint8_t signs[QK_TURBO3 / 8]; // 16 bytes: upper 1-bit of 3-bit index (8 per byte)
} block_turbo3_0; // 50 bytes total
static_assert(sizeof(block_turbo3_0) == sizeof(ggml_half) + QK_TURBO3/4 + QK_TURBO3/8, "wrong turbo3_0 block size/padding");

// TurboQuant 4-bit KV cache block. Two layouts, selected at compile time by TURBO4_USE_4BIT:
//   1 = 4-bit PolarQuant indices, no QJL (new)
//   0 = 3-bit PolarQuant indices + 1-bit QJL signs (legacy)
// Both layouts occupy 68 bytes per 128 values.
// Default: 4-bit on all backends (Metal + CUDA validated)
#ifndef TURBO4_USE_4BIT
# define TURBO4_USE_4BIT 1
#endif

#define QK_TURBO4 128

#if TURBO4_USE_4BIT
// 4-bit PolarQuant: 16 optimal centroids, nibble packed, no QJL
// Per block: norm(fp16) + rnorm(fp16, reserved) + 4-bit indices (64 bytes)
// = 68 bytes per 128 values = 4.25 bits/value → 3.8× compression vs fp16
typedef struct {
ggml_half norm; // 2 bytes
ggml_half rnorm; // 2 bytes (reserved, unused in 4-bit mode)
uint8_t qs[QK_TURBO4 / 2]; // 64 bytes: 4-bit PolarQuant indices (nibble packed)
} block_turbo4_0; // 68 bytes total
static_assert(sizeof(block_turbo4_0) == 68, "wrong turbo4_0 block size");
#else
// Legacy 3-bit PolarQuant + 1-bit QJL (original paper design)
// Per block: norm(fp16) + rnorm(fp16) + 3-bit indices (48 bytes) + 1-bit QJL signs (16 bytes)
// = 68 bytes per 128 values = 4.25 bits/value → 3.8× compression vs fp16
typedef struct {
ggml_half norm; // 2 bytes
ggml_half rnorm; // 2 bytes: residual norm for QJL scale
uint8_t qs[QK_TURBO4 * 3 / 8]; // 48 bytes: 3-bit PolarQuant indices
uint8_t signs[QK_TURBO4 / 8]; // 16 bytes: 1-bit QJL signs
} block_turbo4_0; // 68 bytes total
static_assert(sizeof(block_turbo4_0) == 2*sizeof(ggml_half) + QK_TURBO4*3/8 + QK_TURBO4/8, "wrong turbo4_0 block size");
#endif

// The block size is fixed: changing QK_TURBO4 requires updating the turbo4 kernels as well.
static_assert(QK_TURBO4 == 128, "turbo4 kernels assume QK_TURBO4 == 128");

// TurboQuant 2-bit: 2-bit PolarQuant indices only (no QJL)
// Per block: norm(fp16, 2 bytes) + 2-bit indices (32 bytes) = 34 bytes per 128 values
// = 2.125 bits/value → 7.5× compression vs fp16
// 4 centroids (Lloyd-Max for N(0, 1/128)): {-0.133462, -0.039994, 0.039994, 0.133462}
#define QK_TURBO2 128 // Block size 128: one block per rotation group
#define QK_TURBO2_GROUP 128 // rotation group size = head_dim
// Derived: FA template nl parameters (auto-scale with block size)
#define NL_TURBO2 (QK_TURBO2 / 16) // non-vec FA iterations per block
#define NL_TURBO2_VEC (QK_TURBO2 / 4) // vec FA iterations per block
typedef struct {
ggml_half norm; // 2 bytes: corrected L2 norm
uint8_t qs[QK_TURBO2 / 4]; // 32 bytes: 2-bit indices (4 per byte)
} block_turbo2_0; // 34 bytes total
static_assert(sizeof(block_turbo2_0) == sizeof(ggml_half) + QK_TURBO2/4, "wrong turbo2_0 block size/padding");

// TQ3_1S: WHT-rotated 3-bit weight quantization (8-level Lloyd-Max for N(0,1))
// Block size 32, dual half-block scales (d0 for [0..15], d1 for [16..31])
// Per block: d0(fp16) + d1(fp16) + 3-bit indices packed (12 bytes) = 16 bytes per 32 values
// = 4.0 bits/value
// NOTE(review): the block-size macro is named QK_TQ3_0 although the type is TQ3_1S
// (the 4-bit sibling uses QK_TQ4_1S) — confirm whether the name is intentional before renaming.
#define QK_TQ3_0 32
typedef struct {
ggml_half d0; // 2 bytes: scale for first 16 elements
ggml_half d1; // 2 bytes: scale for last 16 elements
uint8_t qs[QK_TQ3_0 * 3 / 8]; // 12 bytes: 3-bit indices packed (4 groups of 8 indices, 3 bytes per group)
} block_tq3_1s; // 16 bytes total
static_assert(sizeof(block_tq3_1s) == 16, "wrong tq3_1s block size");

// TQ4_1S: WHT-rotated 4-bit weight quantization (16-level Lloyd-Max for N(0,1))
// Block size 32 (QK_TQ4_1S), dual half-block scales (d0 for [0..15], d1 for [16..31])
// Per block: d0(fp16) + d1(fp16) + 4-bit indices packed (16 bytes) = 20 bytes per 32 values
// = 5.0 bits/value
#define QK_TQ4_1S 32
typedef struct {
ggml_half d0; // 2 bytes: scale for first 16 elements
ggml_half d1; // 2 bytes: scale for last 16 elements
uint8_t qs[QK_TQ4_1S / 2]; // 16 bytes: 4-bit indices nibble-packed (2 per byte)
} block_tq4_1s; // 20 bytes total
static_assert(sizeof(block_tq4_1s) == 20, "wrong tq4_1s block size");

//
// Super-block quantization structures
//
Expand Down
163 changes: 163 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "ggml-cpu-impl.h"
#include "ggml-impl.h"
#include "quants.h"
#include "ggml-quants.h"
#include "ggml-threading.h"
#include "unary-ops.h"
#include "binary-ops.h"
Expand Down Expand Up @@ -208,6 +209,23 @@ typedef pthread_t ggml_thread_t;
#include <TargetConditionals.h>
#endif

// Forward declarations of the TurboQuant vec_dot kernels.
// The type_traits_cpu table below stores their addresses, but the definitions appear
// later in this file (after the utility functions), so they must be declared here first.
// The turbo*_0 kernels take an f32 second operand; the tq*_1s kernels take a Q8_0 one.
static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);
static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);
static void ggml_vec_dot_turbo4_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);
static void ggml_vec_dot_tq3_1s_q8_0(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);
static void ggml_vec_dot_tq4_1s_q8_0(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);

static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
Expand Down Expand Up @@ -403,6 +421,36 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
[GGML_TYPE_I32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
},
[GGML_TYPE_TURBO3_0] = {
.from_float = (ggml_from_float_t) quantize_row_turbo3_0_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_turbo3_0_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_TURBO2_0] = {
.from_float = (ggml_from_float_t) quantize_row_turbo2_0_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_turbo2_0_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_TURBO4_0] = {
.from_float = (ggml_from_float_t) quantize_row_turbo4_0_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_turbo4_0_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_TQ3_1S] = {
.from_float = (ggml_from_float_t) quantize_row_tq3_1s_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_tq3_1s_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_TQ4_1S] = {
.from_float = (ggml_from_float_t) quantize_row_tq4_1s_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_tq4_1s_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
};

const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
Expand Down Expand Up @@ -2047,6 +2095,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_gated_delta_net(params, tensor);
} break;
case GGML_OP_TURBO_WHT:
{
ggml_compute_forward_turbo_wht(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
Expand Down Expand Up @@ -2227,6 +2279,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_COUNT_EQUAL:
case GGML_OP_SOLVE_TRI:
case GGML_OP_GATED_DELTA_NET:
case GGML_OP_TURBO_WHT:
{
n_tasks = n_threads;
} break;
Expand Down Expand Up @@ -2947,6 +3000,10 @@ struct ggml_cplan ggml_graph_plan(
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
cur = per_thread * sizeof(float) * n_tasks;
} break;
case GGML_OP_TURBO_WHT:
{
cur = 0; // no extra workspace needed
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -3385,6 +3442,112 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
return ggml_graph_compute(cgraph, &cplan);
}

// turbo3_0 x f32 dot product: expand the quantized operand to f32 scratch, then
// accumulate the element-wise products against the f32 operand.
// Used by CPU flash attention for models with D not supported by CUDA FA (e.g. D=192).
//   n   - number of elements
//   s   - receives the resulting dot product
//   vx  - turbo3_0-quantized data
//   vy  - f32 data
//   nrc - must be 1; bs/bx/by are ignored
static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
                                      const void * GGML_RESTRICT vx, size_t bx,
                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(nrc);

    // Stack scratch holding the dequantized copy of vx.
    const size_t scratch_bytes = (size_t)n * sizeof(float);
    float * tmp = (float *)alloca(scratch_bytes);
    GGML_ASSERT(tmp != NULL);
    ggml_get_type_traits(GGML_TYPE_TURBO3_0)->to_float(vx, tmp, n);

    const float * rhs = (const float *)vy;
    float acc = 0.0f;
    int k = 0;
    while (k < n) {
        acc += tmp[k] * rhs[k];
        ++k;
    }
    s[0] = acc;
}

// turbo2_0 x f32 dot product: expand the quantized operand to f32 scratch, then
// accumulate the element-wise products against the f32 operand.
//   n   - number of elements
//   s   - receives the resulting dot product
//   vx  - turbo2_0-quantized data
//   vy  - f32 data
//   nrc - must be 1; bs/bx/by are ignored
static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
                                      const void * GGML_RESTRICT vx, size_t bx,
                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(nrc);

    // Stack scratch holding the dequantized copy of vx.
    const size_t scratch_bytes = (size_t)n * sizeof(float);
    float * tmp = (float *)alloca(scratch_bytes);
    GGML_ASSERT(tmp != NULL);
    ggml_get_type_traits(GGML_TYPE_TURBO2_0)->to_float(vx, tmp, n);

    const float * rhs = (const float *)vy;
    float acc = 0.0f;
    int k = 0;
    while (k < n) {
        acc += tmp[k] * rhs[k];
        ++k;
    }
    s[0] = acc;
}

// turbo4_0 x f32 dot product: expand the quantized operand to f32 scratch, then
// accumulate the element-wise products against the f32 operand.
//   n   - number of elements
//   s   - receives the resulting dot product
//   vx  - turbo4_0-quantized data
//   vy  - f32 data
//   nrc - must be 1; bs/bx/by are ignored
static void ggml_vec_dot_turbo4_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
                                      const void * GGML_RESTRICT vx, size_t bx,
                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(nrc);

    // Stack scratch holding the dequantized copy of vx.
    const size_t scratch_bytes = (size_t)n * sizeof(float);
    float * tmp = (float *)alloca(scratch_bytes);
    GGML_ASSERT(tmp != NULL);
    ggml_get_type_traits(GGML_TYPE_TURBO4_0)->to_float(vx, tmp, n);

    const float * rhs = (const float *)vy;
    float acc = 0.0f;
    int k = 0;
    while (k < n) {
        acc += tmp[k] * rhs[k];
        ++k;
    }
    s[0] = acc;
}

// TQ3_1S x Q8_0 dot product (scalar reference path): both operands are expanded to f32
// through their to_float traits and the element-wise products are accumulated in order.
//   n   - number of elements in the row (expected to be a whole number of blocks)
//   s   - receives the resulting dot product
//   vx  - TQ3_1S-quantized row
//   vy  - Q8_0-quantized row
//   nrc - must be 1; bs/bx/by are ignored
// The row is processed in fixed-size chunks so that the scratch buffers have a small,
// constant stack footprint: n is the full row length and is not bounded here, and
// alloca() cannot report failure (it never returns NULL), so sizing the scratch by n
// risks an undetected stack overflow on worker threads.
// NOTE(review): assumes to_float decodes blocks independently, so decoding the row chunk
// by chunk yields the same values as decoding it in one call — confirm against the
// TQ3_1S dequantizer.
// TODO: optimize with SIMD intrinsics for ARM NEON / AVX2
static void ggml_vec_dot_tq3_1s_q8_0(int n, float * GGML_RESTRICT s, size_t bs,
                                     const void * GGML_RESTRICT vx, size_t bx,
                                     const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc);

    enum { TQ_DOT_CHUNK = 256 }; // elements decoded per iteration

    const int64_t x_blck = ggml_get_type_traits(GGML_TYPE_TQ3_1S)->blck_size;
    const size_t  x_size = ggml_get_type_traits(GGML_TYPE_TQ3_1S)->type_size;
    const int64_t y_blck = ggml_get_type_traits(GGML_TYPE_Q8_0)->blck_size;
    const size_t  y_size = ggml_get_type_traits(GGML_TYPE_Q8_0)->type_size;

    // A chunk must cover a whole number of blocks of both operand types.
    GGML_ASSERT(x_blck > 0 && TQ_DOT_CHUNK % x_blck == 0);
    GGML_ASSERT(y_blck > 0 && TQ_DOT_CHUNK % y_blck == 0);

    // Bytes of quantized data consumed by one full chunk of each operand.
    const size_t x_chunk_bytes = (size_t)(TQ_DOT_CHUNK / x_blck) * x_size;
    const size_t y_chunk_bytes = (size_t)(TQ_DOT_CHUNK / y_blck) * y_size;

    const char * x = (const char *)vx;
    const char * y = (const char *)vy;

    float xf[TQ_DOT_CHUNK];
    float yf[TQ_DOT_CHUNK];

    // Single accumulator, elements visited in row order: same summation order as a
    // one-shot dequantize-then-dot.
    float sum = 0.0f;
    for (int64_t i0 = 0; i0 < n; i0 += TQ_DOT_CHUNK) {
        const int64_t left  = (int64_t)n - i0;
        const int64_t len   = left < TQ_DOT_CHUNK ? left : TQ_DOT_CHUNK;
        const size_t  chunk = (size_t)(i0 / TQ_DOT_CHUNK);

        ggml_get_type_traits(GGML_TYPE_TQ3_1S)->to_float(x + chunk * x_chunk_bytes, xf, len);
        ggml_get_type_traits(GGML_TYPE_Q8_0)->to_float(y + chunk * y_chunk_bytes, yf, len);

        for (int64_t i = 0; i < len; i++) {
            sum += xf[i] * yf[i];
        }
    }
    *s = sum;
}

// TQ4_1S x Q8_0 dot product (scalar reference path): both operands are expanded to f32
// through their to_float traits and the element-wise products are accumulated in order.
//   n   - number of elements in the row (expected to be a whole number of blocks)
//   s   - receives the resulting dot product
//   vx  - TQ4_1S-quantized row
//   vy  - Q8_0-quantized row
//   nrc - must be 1; bs/bx/by are ignored
// The row is processed in fixed-size chunks so that the scratch buffers have a small,
// constant stack footprint: n is the full row length and is not bounded here, and
// alloca() cannot report failure (it never returns NULL), so sizing the scratch by n
// risks an undetected stack overflow on worker threads.
// NOTE(review): assumes to_float decodes blocks independently, so decoding the row chunk
// by chunk yields the same values as decoding it in one call — confirm against the
// TQ4_1S dequantizer.
// TODO: optimize with SIMD intrinsics
static void ggml_vec_dot_tq4_1s_q8_0(int n, float * GGML_RESTRICT s, size_t bs,
                                     const void * GGML_RESTRICT vx, size_t bx,
                                     const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc);

    enum { TQ_DOT_CHUNK = 256 }; // elements decoded per iteration

    const int64_t x_blck = ggml_get_type_traits(GGML_TYPE_TQ4_1S)->blck_size;
    const size_t  x_size = ggml_get_type_traits(GGML_TYPE_TQ4_1S)->type_size;
    const int64_t y_blck = ggml_get_type_traits(GGML_TYPE_Q8_0)->blck_size;
    const size_t  y_size = ggml_get_type_traits(GGML_TYPE_Q8_0)->type_size;

    // A chunk must cover a whole number of blocks of both operand types.
    GGML_ASSERT(x_blck > 0 && TQ_DOT_CHUNK % x_blck == 0);
    GGML_ASSERT(y_blck > 0 && TQ_DOT_CHUNK % y_blck == 0);

    // Bytes of quantized data consumed by one full chunk of each operand.
    const size_t x_chunk_bytes = (size_t)(TQ_DOT_CHUNK / x_blck) * x_size;
    const size_t y_chunk_bytes = (size_t)(TQ_DOT_CHUNK / y_blck) * y_size;

    const char * x = (const char *)vx;
    const char * y = (const char *)vy;

    float xf[TQ_DOT_CHUNK];
    float yf[TQ_DOT_CHUNK];

    // Single accumulator, elements visited in row order: same summation order as a
    // one-shot dequantize-then-dot.
    float sum = 0.0f;
    for (int64_t i0 = 0; i0 < n; i0 += TQ_DOT_CHUNK) {
        const int64_t left  = (int64_t)n - i0;
        const int64_t len   = left < TQ_DOT_CHUNK ? left : TQ_DOT_CHUNK;
        const size_t  chunk = (size_t)(i0 / TQ_DOT_CHUNK);

        ggml_get_type_traits(GGML_TYPE_TQ4_1S)->to_float(x + chunk * x_chunk_bytes, xf, len);
        ggml_get_type_traits(GGML_TYPE_Q8_0)->to_float(y + chunk * y_chunk_bytes, yf, len);

        for (int64_t i = 0; i < len; i++) {
            sum += xf[i] * yf[i];
        }
    }
    *s = sum;
}

// f32 -> f32 "conversion": copies n floats from x to y unchanged.
// The two ranges are copied with memcpy, so they must not overlap.
void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
    const size_t nbytes = n * sizeof(float);
    memcpy(y, x, nbytes);
}
Expand Down
Loading
Loading