Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions ggml/src/ggml-cpu/arch/x86/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,9 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
__m256i i3 = _mm256_cvtps_epi32( v3 );

#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
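// Store the sum as bf16 bits in the 16-bit fp16 slot: bf16 keeps the fp32 exponent range
// (at lower mantissa precision than fp16), which avoids overflow when the sum is used in
// Q4_1/Q5_1/Q4_K/Q5_K dot products with large activations. Readers must decode it as bf16.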
y[i].s = GGML_FP32_TO_BF16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))).bits;

// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
Expand All @@ -476,10 +477,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
__m128i ni6 = _mm256_castsi256_si128( i3 );
__m128i ni7 = _mm256_extractf128_si256( i3, 1);

// Compute the sum of the quants and set y[i].s
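// Store the sum as bf16 bits in the 16-bit fp16 slot: bf16 keeps the fp32 exponent range
// (at lower mantissa precision than fp16), which avoids overflow when the sum is used in
// Q4_1/Q5_1/Q4_K/Q5_K dot products with large activations. Readers must decode it as bf16.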
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));
y[i].s = GGML_FP32_TO_BF16(d * hsum_i32_4(_mm_add_epi32(s0, s1))).bits;

// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 );
Expand Down Expand Up @@ -883,7 +885,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);

summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
// y[ib].s holds bf16 bits (see quantize_row_q8_1); decode as bf16 to preserve full fp32 range.
summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_BF16_TO_FP32((ggml_bf16_t){ .bits = y[ib].s });

const __m256 d0v = _mm256_set1_ps( d0 );
const __m256 d1v = _mm256_set1_ps( d1 );
Expand Down Expand Up @@ -1108,7 +1111,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
for (; ib < nb; ++ib) {
const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));

summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
// y[ib].s holds bf16 bits (see quantize_row_q8_1); decode as bf16 to preserve full fp32 range.
summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_BF16_TO_FP32((ggml_bf16_t){ .bits = y[ib].s });

__m256i qx = bytes_from_nibbles_32(x[ib].qs);
__m256i bxhi = bytes_from_bits_32(x[ib].qh);
Expand All @@ -1135,7 +1139,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
for (; ib < nb; ++ib) {
const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));

summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
// y[ib].s holds bf16 bits (see quantize_row_q8_1); decode as bf16 to preserve full fp32 range.
summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_BF16_TO_FP32((ggml_bf16_t){ .bits = y[ib].s });

__m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
Expand Down
6 changes: 4 additions & 2 deletions ggml/src/ggml-cpu/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
}

int sumi = sumi0 + sumi1;
sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
// y[ib].s holds bf16 bits (see quantize_row_q8_1_ref); decode as bf16 to preserve full fp32 range.
sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_BF16_TO_FP32((ggml_bf16_t){ .bits = y[ib].s });
}

*s = sumf;
Expand Down Expand Up @@ -391,7 +392,8 @@ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
}

int sumi = sumi0 + sumi1;
sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
// y[ib].s holds bf16 bits (see quantize_row_q8_1_ref); decode as bf16 to preserve full fp32 range.
sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_BF16_TO_FP32((ggml_bf16_t){ .bits = y[ib].s });
}

*s = sumf;
Expand Down
15 changes: 15 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,21 @@ static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3
return make_uint2(div_val, mod_val);
}

// CUDA-only bf16 sibling of block_q8_1. Same size as block_q8_1 (enforced by the
// static_assert below), but the (d, s) pair is stored as bf16 instead of fp16 so
// that s = d * sum(qs) cannot overflow the fp16 exponent range when activations
// contain large outliers (Q4_1, Q5_1, Q4_K and Q5_K dot products multiply by s
// and would otherwise produce NaN). The trade-off is fewer mantissa bits than fp16.
// CUDA quantizes Q8_1 activations on-device, so this struct never crosses the
// CPU/GPU boundary; the host-side block_q8_1 in ggml-common.h is unaffected.
// (No d/s union view: nv_bfloat16 has a non-trivial constructor, which C++
// disallows in anonymous structs/unions. All call sites use ds directly.)
struct block_q8_1_bf16 {
// Packed bf16 (d, s) pair: scale d and s = d * sum(qs).
// NOTE(review): presumably .x holds d and .y holds s, following the (d, s) order — confirm against the quantize kernel.
nv_bfloat162 ds;
// QK8_1 quantized int8 values of the block.
int8_t qs[QK8_1];
};

// Guards the size equality with block_q8_1 that lets buffers be reinterpreted between the two types.
static_assert(sizeof(block_q8_1_bf16) == sizeof(block_q8_1), "block_q8_1_bf16 must match block_q8_1 byte layout");

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

static __device__ __forceinline__ float get_alibi_slope(
Expand Down
121 changes: 69 additions & 52 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,26 @@ struct block_fp4_mmq {
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
};

static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
// CUDA-only bf16 sibling of block_q8_1_mmq. Same size as block_q8_1_mmq (enforced
// by the static_assert below); the DS4 layout (1 scale + 1 partial sum per 32 values)
// holds bf16 pairs instead of fp16 pairs to keep the partial sum within range for
// Q4_1/Q5_1/Q4_K/Q5_K dot products. The D4 (fp32) and D2S6 (fp16, Q2_K) layouts
// are unchanged.
// The union is named (.u) because nv_bfloat16 carries a non-trivial constructor
// in cuda_bf16.h, which C++ disallows in anonymous aggregates.
struct block_q8_1_mmq_bf16 {
// Three overlapping 16-byte views of the per-block metadata; which one is valid
// depends on the ds layout chosen for the x type.
union {
// D4 layout: four fp32 values.
float d4[4];
// DS4 layout: four bf16 (scale, partial sum) pairs, one pair per 32 values.
nv_bfloat162 ds4[4];
// D2S6 layout: eight fp16 values (used for Q2_K).
half d2s6[8];
} u;

// 4 * QK8_1 quantized int8 values.
int8_t qs[4 * QK8_1];
};

// Size checks: the bf16 and fp4 variants must be the same size as block_q8_1_mmq.
static_assert(sizeof(block_q8_1_mmq) == 4 * QK8_1 + 4 * sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4 * sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq_bf16) == sizeof(block_q8_1_mmq), "Unexpected block_q8_1_mmq_bf16 size");
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");

static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
Expand Down Expand Up @@ -463,12 +480,12 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const nv_bfloat162 * y_ds = (const nv_bfloat162 *) y;

// #pragma unroll
// #pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
const int k0 = k00 + k01;

Expand Down Expand Up @@ -574,12 +591,12 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const nv_bfloat162 * y_ds = (const nv_bfloat162 *) y;

// #pragma unroll
// #pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
const int k0 = k00 + k01;

Expand Down Expand Up @@ -1170,11 +1187,11 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(

y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2 * MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const nv_bfloat162 * y_ds = (const nv_bfloat162 *) y;

const int i0 = (threadIdx.y / ntx) * rows_per_warp;

Expand All @@ -1197,7 +1214,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
} else {
dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
dB = ggml_cuda_cast<float2>(y_ds[j * MMQ_TILE_Y_K + k01 / QI8_1]).x;
}

#pragma unroll
Expand Down Expand Up @@ -1225,11 +1242,11 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(

y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2 * MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const nv_bfloat162 * y_ds = (const nv_bfloat162 *) y;

tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
Expand Down Expand Up @@ -1272,9 +1289,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int j = j0 + tile_C::get_j(l);

if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
} else {
dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
dB[l] = ggml_cuda_cast<float2>(y_ds[j * MMQ_TILE_Y_K + k01 / QI8_1]).x;
}
}

Expand All @@ -1301,12 +1318,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const nv_bfloat162 * y_ds = (const nv_bfloat162 *) y;

// #pragma unroll
// #pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
const int k0 = k00 + k01;

Expand Down Expand Up @@ -1341,10 +1358,10 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(

y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2 * MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const nv_bfloat162 * y_dm = (const nv_bfloat162 *) y;

const int i0 = (threadIdx.y / ntx) * rows_per_warp;

Expand All @@ -1363,7 +1380,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

const int j = j0 + tile_C::get_j(0);
const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
const float2 dsB = ggml_cuda_cast<float2>(y_dm[j * MMQ_TILE_Y_K + k01 / QI8_1]);

#pragma unroll
for (int n = 0; n < ntx; ++n) {
Expand Down Expand Up @@ -1391,10 +1408,10 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(

y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2 * MMQ_TILE_NE_K;
const int * y_qs = (const int *) y + 4;
const nv_bfloat162 * y_dm = (const nv_bfloat162 *) y;

tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
Expand Down Expand Up @@ -1436,7 +1453,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);

dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
dsB[l] = ggml_cuda_cast<float2>(y_dm[j * MMQ_TILE_Y_K + k01 / QI8_1]);
}

#pragma unroll
Expand Down Expand Up @@ -2206,13 +2223,13 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm;
const int * y_qs = (const int *) y + 4;
const nv_bfloat162 * y_ds = (const nv_bfloat162 *) y;

// #pragma unroll
// #pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
const int k0 = k00 + k01;

Expand Down Expand Up @@ -2363,11 +2380,11 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm;
const int * y_qs = (const int *) y + 4;
const nv_bfloat162 * y_ds = (const nv_bfloat162 *) y;

// #pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
Expand Down
11 changes: 8 additions & 3 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

#include <cstdint>

typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
// Function-pointer type for the per-quantization-type dot-product routine returned
// by get_vec_dot_q_cuda(). The y operand is typed as block_q8_1_bf16, i.e. its
// (d, s) pair must be decoded as bf16 rather than fp16.
// NOTE(review): kbx looks like the x block index and iqs the quant offset inside a block — confirm against the implementations.
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq,
const block_q8_1_bf16 * __restrict__ bq8_1,
const int & kbx,
const int & iqs);

static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
switch (type) {
Expand Down Expand Up @@ -480,7 +483,8 @@ static __global__ void mul_mat_vec_q(
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};

const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
const block_q8_1_bf16 * y =
((const block_q8_1_bf16 *) vy) + sample_y * stride_sample_y + channel_y * stride_channel_y;
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;

for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
Expand Down Expand Up @@ -628,7 +632,8 @@ static __global__ void mul_mat_vec_q_moe(
const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);

const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y;
const block_q8_1_bf16 * y =
((const block_q8_1_bf16 *) vy) + channel_y * stride_channel_y + token_idx * stride_col_y;
const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x;

// partial sum for each thread
Expand Down
Loading