From b0c8c8975350980c82d18701c273fb916c8e4e19 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 6 Apr 2026 07:35:26 -0700 Subject: [PATCH 1/5] initial Q1_0 Metal backend --- ggml/src/ggml-metal/ggml-metal-device.cpp | 10 ++ ggml/src/ggml-metal/ggml-metal-impl.h | 3 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 161 ++++++++++++++++++++++ 4 files changed, 175 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 89539bd7615..e8548b053e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m smem = 32*sizeof(float)*nr0; suffix = ne00 % 4 == 0 ? "_4" : ""; } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index eb2253e029a..3f5c58dad4e 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,6 +8,9 @@ // // TODO: for optimal performance, become function of the device and work size +#define N_R0_Q1_0 4 +#define N_SG_Q1_0 2 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3cda21be43e..846225d9077 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q1_0 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2074211594c..e92857ce2ba 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -118,6 +118,66 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg } #endif +template +void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = xb->d; + const float neg_d = -d; + + // Process 16 bits starting at offset il*16 + // Optimization: process 2 bytes (16 bits) at once for better memory access + const int byte_offset = il * 2; // il*16 bits = il*2 bytes + const uint8_t b0 = qs[byte_offset]; + const uint8_t b1 = qs[byte_offset + 1]; + + float4x4 reg_f; + + // Unroll completely for better ILP + // First byte (bits 0-7) + reg_f[0][0] = (b0 & 0x01) ? d : neg_d; + reg_f[0][1] = (b0 & 0x02) ? d : neg_d; + reg_f[0][2] = (b0 & 0x04) ? d : neg_d; + reg_f[0][3] = (b0 & 0x08) ? d : neg_d; + reg_f[1][0] = (b0 & 0x10) ? d : neg_d; + reg_f[1][1] = (b0 & 0x20) ? d : neg_d; + reg_f[1][2] = (b0 & 0x40) ? d : neg_d; + reg_f[1][3] = (b0 & 0x80) ? d : neg_d; + + // Second byte (bits 8-15) + reg_f[2][0] = (b1 & 0x01) ? d : neg_d; + reg_f[2][1] = (b1 & 0x02) ? d : neg_d; + reg_f[2][2] = (b1 & 0x04) ? d : neg_d; + reg_f[2][3] = (b1 & 0x08) ? d : neg_d; + reg_f[3][0] = (b1 & 0x10) ? d : neg_d; + reg_f[3][1] = (b1 & 0x20) ? d : neg_d; + reg_f[3][2] = (b1 & 0x40) ? d : neg_d; + reg_f[3][3] = (b1 & 0x80) ? d : neg_d; + + reg = (type4x4) reg_f; +} + +template +void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = xb->d; + + float4 reg_f; + + // Process 4 bits for each call + const int offset = il * 4; + + for (int i = 0; i < 4; i++) { + const int bit_idx = offset + i; + const int byte_idx = bit_idx / 8; + const int bit_offset = bit_idx % 8; + + const bool bit_val = (qs[byte_idx] >> bit_offset) & 1; + reg_f[i] = bit_val ? d : -d; + } + + reg = (type4) reg_f; +} + template void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); @@ -3116,6 +3176,29 @@ kernel void kernel_group_norm_f32( } } +// function for calculate inner product between part of a q1_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q1 quants begin (0, 16, 32, ..., 112 for 128-element block) +// we assume that the yl's have been multiplied with the appropriate scale factor +inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc = 0.0f; + + // il represents which 16-element chunk of the 128-element block (0, 16, 32, ..., 112) + // 16 weights = 16 bits = 2 bytes + const int byte_offset = il / 8; + device const uint8_t * qs = qb_curr->qs + byte_offset; + + for (int i = 0; i < 16; i++) { + const uint8_t byte_idx = i / 8; + const uint8_t bit_idx = i % 8; + const int8_t qval = ((qs[byte_idx] >> bit_idx) & 1) ? 1 : -1; + acc += yl[i] * qval; + } + + return d * acc; +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -3337,6 +3420,78 @@ void mul_vec_q_n_f32_impl( } } +kernel void kernel_mul_mv_q1_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + // Q1_0-specific implementation with 128-element blocks + const int nb = args.ne00/QK1_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SG_Q1_0 + sgitg) * N_R0_Q1_0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q1_0 * ax[N_R0_Q1_0]; + for (int row = 0; row < N_R0_Q1_0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[N_R0_Q1_0] = {0.f}; + + // For 128-element blocks, we need 8 passes of 16 elements each + // Each thread processes a different 16-element chunk + const short ix = (tiisg/8); // which block (0 to 3 for 32 threads / 8) + const short il = (tiisg%8)*16; // which 16-element chunk within the 128-element block (0, 16, 32, ..., 112) + + device const float * yb = y + ix*QK1_0 + il; + + // each thread in a SIMD group deals with 1/8 of a block (16 elements out of 128) + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + float sumy = 0.f; + + // Q1_0: simple copy +#pragma unroll + for (short i = 0; i < 16; i++) { + yl[i] = yb[i]; + sumy += yb[i]; + } + +#pragma unroll + for (short row = 0; row < N_R0_Q1_0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); + } + + yb += QK1_0 * (N_SIMDWIDTH/8); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_R0_Q1_0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3729,6 +3884,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>; #endif +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>; + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -9838,6 +9998,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif +template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; From 0776cd206a84b76d860ad38aa748b71fdb92bfe5 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 6 Apr 2026 08:13:59 -0700 Subject: [PATCH 2/5] tuning q1_0 metal kernels --- ggml/src/ggml-metal/ggml-metal-impl.h | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 149 ++++++++++++++------------ 2 files changed, 79 insertions(+), 72 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 3f5c58dad4e..62b028f4a4a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,7 +8,7 @@ // // TODO: for optimal performance, become function of the device and work size -#define N_R0_Q1_0 4 +#define N_R0_Q1_0 8 #define N_SG_Q1_0 2 #define N_R0_Q4_0 4 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e92857ce2ba..437b9b0d63d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -124,56 +124,46 @@ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & re const float d = xb->d; const float neg_d = -d; - // Process 16 bits starting at offset il*16 - // Optimization: process 2 bytes (16 bits) at once for better memory access const int byte_offset = il * 2; // il*16 bits = il*2 bytes const uint8_t b0 = qs[byte_offset]; const uint8_t b1 = qs[byte_offset + 1]; float4x4 reg_f; - // Unroll completely for better ILP - // First byte (bits 0-7) - reg_f[0][0] = (b0 & 0x01) ? d : neg_d; - reg_f[0][1] = (b0 & 0x02) ? d : neg_d; - reg_f[0][2] = (b0 & 0x04) ? d : neg_d; - reg_f[0][3] = (b0 & 0x08) ? d : neg_d; - reg_f[1][0] = (b0 & 0x10) ? d : neg_d; - reg_f[1][1] = (b0 & 0x20) ? d : neg_d; - reg_f[1][2] = (b0 & 0x40) ? d : neg_d; - reg_f[1][3] = (b0 & 0x80) ? d : neg_d; - - // Second byte (bits 8-15) - reg_f[2][0] = (b1 & 0x01) ? d : neg_d; - reg_f[2][1] = (b1 & 0x02) ? d : neg_d; - reg_f[2][2] = (b1 & 0x04) ? d : neg_d; - reg_f[2][3] = (b1 & 0x08) ? d : neg_d; - reg_f[3][0] = (b1 & 0x10) ? d : neg_d; - reg_f[3][1] = (b1 & 0x20) ? d : neg_d; - reg_f[3][2] = (b1 & 0x40) ? d : neg_d; - reg_f[3][3] = (b1 & 0x80) ? d : neg_d; + reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01)); + reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02)); + reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04)); + reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08)); + reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10)); + reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20)); + reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40)); + reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80)); + + reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01)); + reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02)); + reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04)); + reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08)); + reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10)); + reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20)); + reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40)); + reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80)); reg = (type4x4) reg_f; } template void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) { - device const uint8_t * qs = xb->qs; const float d = xb->d; + const float neg_d = -d; + const int base = il * 4; + const uint8_t byte = xb->qs[base / 8]; + const int s = base % 8; float4 reg_f; - - // Process 4 bits for each call - const int offset = il * 4; - - for (int i = 0; i < 4; i++) { - const int bit_idx = offset + i; - const int byte_idx = bit_idx / 8; - const int bit_offset = bit_idx % 8; - - const bool bit_val = (qs[byte_idx] >> bit_offset) & 1; - reg_f[i] = bit_val ? d : -d; - } + reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1)); + reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1)); + reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1)); + reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1)); reg = (type4) reg_f; } @@ -3176,27 +3166,33 @@ kernel void kernel_group_norm_f32( } } -// function for calculate inner product between part of a q1_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q1 quants begin (0, 16, 32, ..., 112 for 128-element block) -// we assume that the yl's have been multiplied with the appropriate scale factor +// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy) inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; + device const uint8_t * qs = qb_curr->qs + il / 8; + const uint8_t b0 = qs[0]; + const uint8_t b1 = qs[1]; float acc = 0.0f; - // il represents which 16-element chunk of the 128-element block (0, 16, 32, ..., 112) - // 16 weights = 16 bits = 2 bytes - const int byte_offset = il / 8; - device const uint8_t * qs = qb_curr->qs + byte_offset; + acc += select(0.0f, yl[ 0], bool(b0 & 0x01)); + acc += select(0.0f, yl[ 1], bool(b0 & 0x02)); + acc += select(0.0f, yl[ 2], bool(b0 & 0x04)); + acc += select(0.0f, yl[ 3], bool(b0 & 0x08)); + acc += select(0.0f, yl[ 4], bool(b0 & 0x10)); + acc += select(0.0f, yl[ 5], bool(b0 & 0x20)); + acc += select(0.0f, yl[ 6], bool(b0 & 0x40)); + acc += select(0.0f, yl[ 7], bool(b0 & 0x80)); - for (int i = 0; i < 16; i++) { - const uint8_t byte_idx = i / 8; - const uint8_t bit_idx = i % 8; - const int8_t qval = ((qs[byte_idx] >> bit_idx) & 1) ? 1 : -1; - acc += yl[i] * qval; - } + acc += select(0.0f, yl[ 8], bool(b1 & 0x01)); + acc += select(0.0f, yl[ 9], bool(b1 & 0x02)); + acc += select(0.0f, yl[10], bool(b1 & 0x04)); + acc += select(0.0f, yl[11], bool(b1 & 0x08)); + acc += select(0.0f, yl[12], bool(b1 & 0x10)); + acc += select(0.0f, yl[13], bool(b1 & 0x20)); + acc += select(0.0f, yl[14], bool(b1 & 0x40)); + acc += select(0.0f, yl[15], bool(b1 & 0x80)); - return d * acc; + return qb_curr->d * (2.0f * acc - sumy); } // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -3420,22 +3416,25 @@ void mul_vec_q_n_f32_impl( } } -kernel void kernel_mul_mv_q1_0_f32( - constant ggml_metal_kargs_mul_mv & args, +template +void kernel_mul_mv_q1_0_f32_impl( + args_t args, device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - // Q1_0-specific implementation with 128-element blocks + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + const int nb = args.ne00/QK1_0; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SG_Q1_0 + sgitg) * N_R0_Q1_0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -3444,29 +3443,23 @@ kernel void kernel_mul_mv_q1_0_f32( device const float * y = (device const float *) (src1 + offset1); - // pointers to src0 rows - device const block_q1_0 * ax[N_R0_Q1_0]; - for (int row = 0; row < N_R0_Q1_0; ++row) { + device const block_q1_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); } - float yl[16]; // src1 vector cache - float sumf[N_R0_Q1_0] = {0.f}; + float yl[16]; + float sumf[nr0] = {0.f}; - // For 128-element blocks, we need 8 passes of 16 elements each - // Each thread processes a different 16-element chunk - const short ix = (tiisg/8); // which block (0 to 3 for 32 threads / 8) - const short il = (tiisg%8)*16; // which 16-element chunk within the 128-element block (0, 16, 32, ..., 112) + const short ix = (tiisg/8); + const short il = (tiisg%8)*16; device const float * yb = y + ix*QK1_0 + il; - // each thread in a SIMD group deals with 1/8 of a block (16 elements out of 128) for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { float sumy = 0.f; - // Q1_0: simple copy #pragma unroll for (short i = 0; i < 16; i++) { yl[i] = yb[i]; @@ -3474,7 +3467,7 @@ kernel void kernel_mul_mv_q1_0_f32( } #pragma unroll - for (short row = 0; row < N_R0_Q1_0; row++) { + for (short row = 0; row < nr0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); } @@ -3483,7 +3476,7 @@ kernel void kernel_mul_mv_q1_0_f32( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_R0_Q1_0; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -3492,6 +3485,18 @@ kernel void kernel_mul_mv_q1_0_f32( } } +[[host_name("kernel_mul_mv_q1_0_f32")]] +kernel void kernel_mul_mv_q1_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q1_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -10022,6 +10027,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -10231,6 +10237,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; From 6a55d70311ee7839442b11d464c845a799e5c661 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 6 Apr 2026 14:11:24 -0700 Subject: [PATCH 3/5] add Q1_0 to test-backend-ops --- ggml/src/ggml-metal/ggml-metal.metal | 3 +++ tests/test-backend-ops.cpp | 2 ++ 2 files changed, 5 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 437b9b0d63d..2687a50f59f 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7338,12 +7338,14 @@ kernel void kernel_cpy_q_f32( typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; +template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -9941,6 +9943,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q) get_rows_q_t; +template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 781c621d930..c813d57cc0b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7284,6 +7284,7 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, + GGML_TYPE_Q1_0, GGML_TYPE_MXFP4, GGML_TYPE_NVFP4, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, @@ -7308,6 +7309,7 @@ static const ggml_type other_types[] = { GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, + GGML_TYPE_Q1_0, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, From 74c9bdd078ef475db8c6e46a91ff6f5e93511741 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 6 Apr 2026 15:10:55 -0700 Subject: [PATCH 4/5] add Q1_0<->F32 copy test --- ggml/src/ggml-metal/ggml-metal-device.m | 2 ++ ggml/src/ggml-metal/ggml-metal.metal | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 17d51b11b6e..40cacb46520 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2687a50f59f..6da79a59a72 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -202,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q1_0(device const float * src, device block_q1_0 & dst) { + float sum_abs = 0.0f; + for (int j = 0; j < QK1_0; j++) { + sum_abs += fabs(src[j]); + } + dst.d = sum_abs / QK1_0; + + for (int j = 0; j < QK1_0 / 8; j++) { + dst.qs[j] = 0; + } + for (int j = 0; j < QK1_0; j++) { + if (src[j] >= 0.0f) { + dst.qs[j / 8] |= (1 << (j % 8)); + } + } +} + void quantize_q4_0(device const float * src, device block_q4_0 & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -7298,6 +7315,7 @@ kernel void kernel_cpy_f32_q( typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; From a1517c29f03a77b4e628da48ee51270b97da6bdb Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Tue, 7 Apr 2026 10:27:08 -0700 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal.metal | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6da79a59a72..f28bfa0b95b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3477,14 +3477,12 @@ void kernel_mul_mv_q1_0_f32_impl( for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { float sumy = 0.f; -#pragma unroll - for (short i = 0; i < 16; i++) { + FOR_UNROLL (short i = 0; i < 16; i++) { yl[i] = yb[i]; sumy += yb[i]; } -#pragma unroll - for (short row = 0; row < nr0; row++) { + FOR_UNROLL (short row = 0; row < nr0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); }