@@ -124,56 +124,46 @@ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & re
124124 const float d = xb->d ;
125125 const float neg_d = -d;
126126
127- // Process 16 bits starting at offset il*16
128- // Optimization: process 2 bytes (16 bits) at once for better memory access
129127 const int byte_offset = il * 2 ; // il*16 bits = il*2 bytes
130128 const uint8_t b0 = qs[byte_offset];
131129 const uint8_t b1 = qs[byte_offset + 1 ];
132130
133131 float4x4 reg_f;
134132
135- // Unroll completely for better ILP
136- // First byte (bits 0-7)
137- reg_f[0 ][0 ] = (b0 & 0x01 ) ? d : neg_d;
138- reg_f[0 ][1 ] = (b0 & 0x02 ) ? d : neg_d;
139- reg_f[0 ][2 ] = (b0 & 0x04 ) ? d : neg_d;
140- reg_f[0 ][3 ] = (b0 & 0x08 ) ? d : neg_d;
141- reg_f[1 ][0 ] = (b0 & 0x10 ) ? d : neg_d;
142- reg_f[1 ][1 ] = (b0 & 0x20 ) ? d : neg_d;
143- reg_f[1 ][2 ] = (b0 & 0x40 ) ? d : neg_d;
144- reg_f[1 ][3 ] = (b0 & 0x80 ) ? d : neg_d;
145-
146- // Second byte (bits 8-15)
147- reg_f[2 ][0 ] = (b1 & 0x01 ) ? d : neg_d;
148- reg_f[2 ][1 ] = (b1 & 0x02 ) ? d : neg_d;
149- reg_f[2 ][2 ] = (b1 & 0x04 ) ? d : neg_d;
150- reg_f[2 ][3 ] = (b1 & 0x08 ) ? d : neg_d;
151- reg_f[3 ][0 ] = (b1 & 0x10 ) ? d : neg_d;
152- reg_f[3 ][1 ] = (b1 & 0x20 ) ? d : neg_d;
153- reg_f[3 ][2 ] = (b1 & 0x40 ) ? d : neg_d;
154- reg_f[3 ][3 ] = (b1 & 0x80 ) ? d : neg_d;
133+ reg_f[0 ][0 ] = select (neg_d, d, bool (b0 & 0x01 ));
134+ reg_f[0 ][1 ] = select (neg_d, d, bool (b0 & 0x02 ));
135+ reg_f[0 ][2 ] = select (neg_d, d, bool (b0 & 0x04 ));
136+ reg_f[0 ][3 ] = select (neg_d, d, bool (b0 & 0x08 ));
137+ reg_f[1 ][0 ] = select (neg_d, d, bool (b0 & 0x10 ));
138+ reg_f[1 ][1 ] = select (neg_d, d, bool (b0 & 0x20 ));
139+ reg_f[1 ][2 ] = select (neg_d, d, bool (b0 & 0x40 ));
140+ reg_f[1 ][3 ] = select (neg_d, d, bool (b0 & 0x80 ));
141+
142+ reg_f[2 ][0 ] = select (neg_d, d, bool (b1 & 0x01 ));
143+ reg_f[2 ][1 ] = select (neg_d, d, bool (b1 & 0x02 ));
144+ reg_f[2 ][2 ] = select (neg_d, d, bool (b1 & 0x04 ));
145+ reg_f[2 ][3 ] = select (neg_d, d, bool (b1 & 0x08 ));
146+ reg_f[3 ][0 ] = select (neg_d, d, bool (b1 & 0x10 ));
147+ reg_f[3 ][1 ] = select (neg_d, d, bool (b1 & 0x20 ));
148+ reg_f[3 ][2 ] = select (neg_d, d, bool (b1 & 0x40 ));
149+ reg_f[3 ][3 ] = select (neg_d, d, bool (b1 & 0x80 ));
155150
156151 reg = (type4x4) reg_f;
157152}
158153
// Dequantize four consecutive Q1_0 weights of block `xb` into `reg`.
//
//   xb  - block supplying the scale `d` and the packed sign bits `qs`
//   il  - index of the 4-weight group; the group starts at bit il*4 of `qs`
//   reg - output; lane i receives +d when bit (il*4 + i) is set and -d otherwise
//
// il*4 is a multiple of 4, so for non-negative il the in-byte shift is either 0 or 4
// and all four bits live in the same byte - a single byte load is enough.
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
    const float d     = xb->d;
    const float neg_d = -d;

    const int     base = il * 4;            // index of the group's first bit
    const uint8_t byte = xb->qs[base / 8];  // the one byte holding all four bits
    const int     s    = base % 8;          // position of lane 0's bit inside `byte`

    float4 reg_f;

    // select(a, b, c) yields b when c is true, i.e. bit set -> +d, bit clear -> -d
    reg_f[0] = select(neg_d, d, bool((byte >> (s    )) & 1));
    reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
    reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
    reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));

    reg = (type4) reg_f;
}
@@ -3176,27 +3166,33 @@ kernel void kernel_group_norm_f32(
31763166 }
31773167}
31783168
// Inner product between one 16-weight slice of a q1_0 block and 16 floats (yl).
//
// A weight is +d when its bit is set and -d otherwise, hence
//     SUM(w[i] * yl[i]) = d * (2 * SUM(yl[i] where bit i is set) - sumy)
// so only the yl values whose bit is set have to be accumulated.
//
//   qb_curr - block supplying the scale `d` and the packed bits `qs`
//   sumy    - SUM(yl[i]) over the 16 values, precomputed by the caller
//   yl      - the 16 floats that line up with this slice
//   il      - where the slice begins inside the block (0, 16, 32, ..., 112 for a
//             128-element block); il / 8 is the matching byte offset into `qs`
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
    // 16 weights = 16 bits = 2 bytes
    device const uint8_t * qs = qb_curr->qs + il / 8;
    const uint8_t b0 = qs[0];
    const uint8_t b1 = qs[1];

    float acc = 0.0f;

    // first byte: weights 0-7
    acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
    acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
    acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
    acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
    acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
    acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
    acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
    acc += select(0.0f, yl[ 7], bool(b0 & 0x80));

    // second byte: weights 8-15
    acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
    acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
    acc += select(0.0f, yl[10], bool(b1 & 0x04));
    acc += select(0.0f, yl[11], bool(b1 & 0x08));
    acc += select(0.0f, yl[12], bool(b1 & 0x10));
    acc += select(0.0f, yl[13], bool(b1 & 0x20));
    acc += select(0.0f, yl[14], bool(b1 & 0x40));
    acc += select(0.0f, yl[15], bool(b1 & 0x80));

    return qb_curr->d * (2.0f * acc - sumy);
}
32013197
32023198// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -3420,22 +3416,25 @@ void mul_vec_q_n_f32_impl(
34203416 }
34213417}
34223418
3423- kernel void kernel_mul_mv_q1_0_f32 (
3424- constant ggml_metal_kargs_mul_mv & args,
3419+ template <int nr0, typename args_t >
3420+ void kernel_mul_mv_q1_0_f32_impl (
3421+ args_t args,
34253422 device const char * src0,
34263423 device const char * src1,
34273424 device char * dst,
3428- uint3 tgpig[[threadgroup_position_in_grid]],
3429- ushort tiisg[[thread_index_in_simdgroup]],
3430- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3431- // Q1_0-specific implementation with 128-element blocks
3425+ threadgroup char * shmem,
3426+ uint3 tgpig,
3427+ ushort tiisg,
3428+ ushort sgitg) {
3429+ const short NSG = FC_mul_mv_nsg;
3430+
34323431 const int nb = args.ne00 /QK1_0;
34333432
34343433 const int r0 = tgpig.x ;
34353434 const int r1 = tgpig.y ;
34363435 const int im = tgpig.z ;
34373436
3438- const int first_row = (r0 * N_SG_Q1_0 + sgitg) * N_R0_Q1_0 ;
3437+ const int first_row = (r0 * NSG + sgitg) * nr0 ;
34393438
34403439 const uint i12 = im%args.ne12 ;
34413440 const uint i13 = im/args.ne12 ;
@@ -3444,37 +3443,31 @@ kernel void kernel_mul_mv_q1_0_f32(
34443443
34453444 device const float * y = (device const float *) (src1 + offset1);
34463445
3447- // pointers to src0 rows
3448- device const block_q1_0 * ax[N_R0_Q1_0];
3449- for (int row = 0 ; row < N_R0_Q1_0; ++row) {
3446+ device const block_q1_0 * ax[nr0];
3447+ for (int row = 0 ; row < nr0; ++row) {
34503448 const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2 )*args.nb02 + (i13/args.r3 )*args.nb03 ;
3451-
34523449 ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
34533450 }
34543451
3455- float yl[16 ]; // src1 vector cache
3456- float sumf[N_R0_Q1_0 ] = {0 .f };
3452+ float yl[16 ];
3453+ float sumf[nr0 ] = {0 .f };
34573454
3458- // For 128-element blocks, we need 8 passes of 16 elements each
3459- // Each thread processes a different 16-element chunk
3460- const short ix = (tiisg/8 ); // which block (0 to 3 for 32 threads / 8)
3461- const short il = (tiisg%8 )*16 ; // which 16-element chunk within the 128-element block (0, 16, 32, ..., 112)
3455+ const short ix = (tiisg/8 );
3456+ const short il = (tiisg%8 )*16 ;
34623457
34633458 device const float * yb = y + ix*QK1_0 + il;
34643459
3465- // each thread in a SIMD group deals with 1/8 of a block (16 elements out of 128)
34663460 for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8 ) {
34673461 float sumy = 0 .f ;
34683462
3469- // Q1_0: simple copy
34703463#pragma unroll
34713464 for (short i = 0 ; i < 16 ; i++) {
34723465 yl[i] = yb[i];
34733466 sumy += yb[i];
34743467 }
34753468
34763469#pragma unroll
3477- for (short row = 0 ; row < N_R0_Q1_0 ; row++) {
3470+ for (short row = 0 ; row < nr0 ; row++) {
34783471 sumf[row] += block_q_n_dot_y (ax[row] + ib, sumy, yl, il);
34793472 }
34803473
@@ -3483,7 +3476,7 @@ kernel void kernel_mul_mv_q1_0_f32(
34833476
34843477 device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
34853478
3486- for (int row = 0 ; row < N_R0_Q1_0 ; ++row) {
3479+ for (int row = 0 ; row < nr0 ; ++row) {
34873480 const float tot = simd_sum (sumf[row]);
34883481
34893482 if (tiisg == 0 && first_row + row < args.ne01 ) {
@@ -3492,6 +3485,18 @@ kernel void kernel_mul_mv_q1_0_f32(
34923485 }
34933486}
34943487
// Entry point for the q1_0 x f32 matrix-vector product.
// Forwards everything to kernel_mul_mv_q1_0_f32_impl instantiated with N_R0_Q1_0 rows
// per simdgroup; the implementation's `shmem` parameter receives nullptr.
[[host_name("kernel_mul_mv_q1_0_f32")]]
kernel void kernel_mul_mv_q1_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
3499+
34953500kernel void kernel_mul_mv_q4_0_f32 (
34963501 constant ggml_metal_kargs_mul_mv & args,
34973502 device const char * src0,
@@ -10022,6 +10027,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
1002210027
template [[host_name("kernel_mul_mm_f32_f16")]]     kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4,   1, dequantize_f32,  float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]]     kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4,    1, dequantize_f16,  half,  half4x4,  half, half2x4>;
// NOTE(review): the 8 presumably counts the 16-value (4x4) chunks in a 128-element q1_0 block - confirm against kernel_mul_mm's parameter list
template [[host_name("kernel_mul_mm_q1_0_f16")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -10231,6 +10237,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
1023110237
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;

// q1_0 has its own templated implementation (like q8_0 above); the q4/q5 formats below go through mul_vec_q_n_f32_impl
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
0 commit comments