@@ -124,34 +124,29 @@ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & re
124124 const float d = xb->d ;
125125 const float neg_d = -d;
126126
127- // Process 16 bits starting at offset il*16
128- // Optimization: process 2 bytes (16 bits) at once for better memory access
129127 const int byte_offset = il * 2 ; // il*16 bits = il*2 bytes
130128 const uint8_t b0 = qs[byte_offset];
131129 const uint8_t b1 = qs[byte_offset + 1 ];
132130
133131 float4x4 reg_f;
134132
135- // Unroll completely for better ILP
136- // First byte (bits 0-7)
137- reg_f[0 ][0 ] = (b0 & 0x01 ) ? d : neg_d;
138- reg_f[0 ][1 ] = (b0 & 0x02 ) ? d : neg_d;
139- reg_f[0 ][2 ] = (b0 & 0x04 ) ? d : neg_d;
140- reg_f[0 ][3 ] = (b0 & 0x08 ) ? d : neg_d;
141- reg_f[1 ][0 ] = (b0 & 0x10 ) ? d : neg_d;
142- reg_f[1 ][1 ] = (b0 & 0x20 ) ? d : neg_d;
143- reg_f[1 ][2 ] = (b0 & 0x40 ) ? d : neg_d;
144- reg_f[1 ][3 ] = (b0 & 0x80 ) ? d : neg_d;
145-
146- // Second byte (bits 8-15)
147- reg_f[2 ][0 ] = (b1 & 0x01 ) ? d : neg_d;
148- reg_f[2 ][1 ] = (b1 & 0x02 ) ? d : neg_d;
149- reg_f[2 ][2 ] = (b1 & 0x04 ) ? d : neg_d;
150- reg_f[2 ][3 ] = (b1 & 0x08 ) ? d : neg_d;
151- reg_f[3 ][0 ] = (b1 & 0x10 ) ? d : neg_d;
152- reg_f[3 ][1 ] = (b1 & 0x20 ) ? d : neg_d;
153- reg_f[3 ][2 ] = (b1 & 0x40 ) ? d : neg_d;
154- reg_f[3 ][3 ] = (b1 & 0x80 ) ? d : neg_d;
133+ reg_f[0 ][0 ] = select (neg_d, d, bool (b0 & 0x01 ));
134+ reg_f[0 ][1 ] = select (neg_d, d, bool (b0 & 0x02 ));
135+ reg_f[0 ][2 ] = select (neg_d, d, bool (b0 & 0x04 ));
136+ reg_f[0 ][3 ] = select (neg_d, d, bool (b0 & 0x08 ));
137+ reg_f[1 ][0 ] = select (neg_d, d, bool (b0 & 0x10 ));
138+ reg_f[1 ][1 ] = select (neg_d, d, bool (b0 & 0x20 ));
139+ reg_f[1 ][2 ] = select (neg_d, d, bool (b0 & 0x40 ));
140+ reg_f[1 ][3 ] = select (neg_d, d, bool (b0 & 0x80 ));
141+
142+ reg_f[2 ][0 ] = select (neg_d, d, bool (b1 & 0x01 ));
143+ reg_f[2 ][1 ] = select (neg_d, d, bool (b1 & 0x02 ));
144+ reg_f[2 ][2 ] = select (neg_d, d, bool (b1 & 0x04 ));
145+ reg_f[2 ][3 ] = select (neg_d, d, bool (b1 & 0x08 ));
146+ reg_f[3 ][0 ] = select (neg_d, d, bool (b1 & 0x10 ));
147+ reg_f[3 ][1 ] = select (neg_d, d, bool (b1 & 0x20 ));
148+ reg_f[3 ][2 ] = select (neg_d, d, bool (b1 & 0x40 ));
149+ reg_f[3 ][3 ] = select (neg_d, d, bool (b1 & 0x80 ));
155150
156151 reg = (type4x4) reg_f;
157152}
@@ -160,20 +155,18 @@ template <typename type4>
160155void dequantize_q1_0_t4 (device const block_q1_0 * xb, short il, thread type4 & reg) { // writes 4 values, each +d or -d, decoded from bits il*4 .. il*4+3 of xb->qs into reg
161156 device const uint8_t * qs = xb->qs ;
162157 const float d = xb->d ;
158+ const float neg_d = -d;
163159
164- float4 reg_f;
165-
166- // Process 4 bits for each call
167- const int offset = il * 4 ;
168-
169- for (int i = 0 ; i < 4 ; i++) {
170- const int bit_idx = offset + i;
171- const int byte_idx = bit_idx / 8 ;
172- const int bit_offset = bit_idx % 8 ;
160+ // Decode the 4 consecutive bits starting at bit il*4: bit set -> +d, bit clear -> -d. base % 8 is 0 or 4 (assuming il >= 0), so the 4 bits never straddle a byte.
161+ const int base = il * 4 ;
162+ const uint8_t byte = qs[base / 8 ];
163+ const int bit_base = base % 8 ;
173164
174- const bool bit_val = (qs[byte_idx] >> bit_offset) & 1 ;
175- reg_f[i] = bit_val ? d : -d;
176- }
165+ float4 reg_f;
166+ reg_f[0 ] = select (neg_d, d, bool ((byte >> (bit_base )) & 1 ));
167+ reg_f[1 ] = select (neg_d, d, bool ((byte >> (bit_base + 1 )) & 1 ));
168+ reg_f[2 ] = select (neg_d, d, bool ((byte >> (bit_base + 2 )) & 1 ));
169+ reg_f[3 ] = select (neg_d, d, bool ((byte >> (bit_base + 3 )) & 1 ));
177170
178171 reg = (type4) reg_f; // convert the float4 result to the caller's vector type
179172}
@@ -3178,25 +3171,36 @@ kernel void kernel_group_norm_f32(
31783171
31793172// function to calculate the inner product between part of a q1_0 block and 16 floats (yl); sumy is SUM(yl[i])
31803173// il indicates where the q1 quants begin (0, 16, 32, ..., 112 for 128-element block)
3181- // we assume that the yl's have been multiplied with the appropriate scale factor
3174+ // Q1_0 encodes weights as {+1, -1}: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
31823175inline float block_q_n_dot_y (device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3183- float d = qb_curr->d ;
3176+ // il is the index of the first element of this chunk, so its 16 bits are the 2 bytes starting at byte offset il/8
3177+ device const uint8_t * qs = qb_curr->qs + il / 8 ;
3178+ const uint8_t b0 = qs[0 ];
3179+ const uint8_t b1 = qs[1 ];
31843180
3181+ // acc = sum of yl[i] over the positions whose bit is 1; select() contributes 0.0f where the bit is 0
31853182 float acc = 0 .0f ;
31863183
3187- // il represents which 16-element chunk of the 128-element block (0, 16, 32, ..., 112)
3188- // 16 weights = 16 bits = 2 bytes
3189- const int byte_offset = il / 8 ;
3190- device const uint8_t * qs = qb_curr->qs + byte_offset;
3184+ acc += select (0 .0f , yl[ 0 ], bool (b0 & 0x01 ));
3185+ acc += select (0 .0f , yl[ 1 ], bool (b0 & 0x02 ));
3186+ acc += select (0 .0f , yl[ 2 ], bool (b0 & 0x04 ));
3187+ acc += select (0 .0f , yl[ 3 ], bool (b0 & 0x08 ));
3188+ acc += select (0 .0f , yl[ 4 ], bool (b0 & 0x10 ));
3189+ acc += select (0 .0f , yl[ 5 ], bool (b0 & 0x20 ));
3190+ acc += select (0 .0f , yl[ 6 ], bool (b0 & 0x40 ));
3191+ acc += select (0 .0f , yl[ 7 ], bool (b0 & 0x80 ));
31913192
3192- for (int i = 0 ; i < 16 ; i++) {
3193- const uint8_t byte_idx = i / 8 ;
3194- const uint8_t bit_idx = i % 8 ;
3195- const int8_t qval = ((qs[byte_idx] >> bit_idx) & 1 ) ? 1 : -1 ;
3196- acc += yl[i] * qval;
3197- }
3193+ acc += select (0 .0f , yl[ 8 ], bool (b1 & 0x01 ));
3194+ acc += select (0 .0f , yl[ 9 ], bool (b1 & 0x02 ));
3195+ acc += select (0 .0f , yl[10 ], bool (b1 & 0x04 ));
3196+ acc += select (0 .0f , yl[11 ], bool (b1 & 0x08 ));
3197+ acc += select (0 .0f , yl[12 ], bool (b1 & 0x10 ));
3198+ acc += select (0 .0f , yl[13 ], bool (b1 & 0x20 ));
3199+ acc += select (0 .0f , yl[14 ], bool (b1 & 0x40 ));
3200+ acc += select (0 .0f , yl[15 ], bool (b1 & 0x80 ));
31983201
3199- return d * acc;
3202+ // with weights +1 (bit set) / -1 (bit clear): SUM(yl[i]*w[i]) = 2*acc - sumy, which relies on the caller passing sumy = SUM(yl[0..15]); then scale by d
3203+ return qb_curr->d * (2 .0f * acc - sumy);
32003204}
32013205
32023206// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
@@ -10022,6 +10026,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
1002210026
1002310027template [[host_name(" kernel_mul_mm_f32_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, float , float4x4, half, half2x4>;
1002410028template [[host_name(" kernel_mul_mm_f16_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16, half, half4x4, half, half2x4>;
10029+ template [[host_name(" kernel_mul_mm_q1_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8 , dequantize_q1_0, float , float4x4, half, half2x4>; // NOTE(review): the 8 presumably is the number of 16-value dequantize_q1_0 chunks in a 128-element block (128/16) - confirm against kernel_mul_mm's template parameters
1002510030template [[host_name(" kernel_mul_mm_q4_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, float , float4x4, half, half2x4>;
1002610031template [[host_name(" kernel_mul_mm_q4_1_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, float , float4x4, half, half2x4>;
1002710032template [[host_name(" kernel_mul_mm_q5_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, float , float4x4, half, half2x4>;
0 commit comments