Skip to content

Commit b9b3f68

Browse files
committed
tuning q1_0 metal kernels
1 parent f2b50f9 commit b9b3f68

2 files changed

Lines changed: 53 additions & 48 deletions

File tree

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11-
#define N_R0_Q1_0 4
11+
#define N_R0_Q1_0 8
1212
#define N_SG_Q1_0 2
1313

1414
#define N_R0_Q4_0 4

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 52 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -124,34 +124,29 @@ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & re
124124
const float d = xb->d;
125125
const float neg_d = -d;
126126

127-
// Process 16 bits starting at offset il*16
128-
// Optimization: process 2 bytes (16 bits) at once for better memory access
129127
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
130128
const uint8_t b0 = qs[byte_offset];
131129
const uint8_t b1 = qs[byte_offset + 1];
132130

133131
float4x4 reg_f;
134132

135-
// Unroll completely for better ILP
136-
// First byte (bits 0-7)
137-
reg_f[0][0] = (b0 & 0x01) ? d : neg_d;
138-
reg_f[0][1] = (b0 & 0x02) ? d : neg_d;
139-
reg_f[0][2] = (b0 & 0x04) ? d : neg_d;
140-
reg_f[0][3] = (b0 & 0x08) ? d : neg_d;
141-
reg_f[1][0] = (b0 & 0x10) ? d : neg_d;
142-
reg_f[1][1] = (b0 & 0x20) ? d : neg_d;
143-
reg_f[1][2] = (b0 & 0x40) ? d : neg_d;
144-
reg_f[1][3] = (b0 & 0x80) ? d : neg_d;
145-
146-
// Second byte (bits 8-15)
147-
reg_f[2][0] = (b1 & 0x01) ? d : neg_d;
148-
reg_f[2][1] = (b1 & 0x02) ? d : neg_d;
149-
reg_f[2][2] = (b1 & 0x04) ? d : neg_d;
150-
reg_f[2][3] = (b1 & 0x08) ? d : neg_d;
151-
reg_f[3][0] = (b1 & 0x10) ? d : neg_d;
152-
reg_f[3][1] = (b1 & 0x20) ? d : neg_d;
153-
reg_f[3][2] = (b1 & 0x40) ? d : neg_d;
154-
reg_f[3][3] = (b1 & 0x80) ? d : neg_d;
133+
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
134+
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
135+
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
136+
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
137+
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
138+
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
139+
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
140+
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
141+
142+
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
143+
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
144+
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
145+
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
146+
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
147+
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
148+
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
149+
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
155150

156151
reg = (type4x4) reg_f;
157152
}
@@ -160,20 +155,18 @@ template <typename type4>
160155
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
161156
device const uint8_t * qs = xb->qs;
162157
const float d = xb->d;
158+
const float neg_d = -d;
163159

164-
float4 reg_f;
165-
166-
// Process 4 bits for each call
167-
const int offset = il * 4;
168-
169-
for (int i = 0; i < 4; i++) {
170-
const int bit_idx = offset + i;
171-
const int byte_idx = bit_idx / 8;
172-
const int bit_offset = bit_idx % 8;
160+
// 4 consecutive bits starting at il*4
161+
const int base = il * 4;
162+
const uint8_t byte = qs[base / 8];
163+
const int bit_base = base % 8;
173164

174-
const bool bit_val = (qs[byte_idx] >> bit_offset) & 1;
175-
reg_f[i] = bit_val ? d : -d;
176-
}
165+
float4 reg_f;
166+
reg_f[0] = select(neg_d, d, bool((byte >> (bit_base )) & 1));
167+
reg_f[1] = select(neg_d, d, bool((byte >> (bit_base + 1)) & 1));
168+
reg_f[2] = select(neg_d, d, bool((byte >> (bit_base + 2)) & 1));
169+
reg_f[3] = select(neg_d, d, bool((byte >> (bit_base + 3)) & 1));
177170

178171
reg = (type4) reg_f;
179172
}
@@ -3178,25 +3171,36 @@ kernel void kernel_group_norm_f32(
31783171

31793172
// function for calculate inner product between part of a q1_0 block and 16 floats (yl), sumy is SUM(yl[i])
31803173
// il indicates where the q1 quants begin (0, 16, 32, ..., 112 for 128-element block)
3181-
// we assume that the yl's have been multiplied with the appropriate scale factor
3174+
// Q1_0 encodes weights as {+1, -1}: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
31823175
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3183-
float d = qb_curr->d;
3176+
// 16 elements = 2 bytes starting at byte offset il/8
3177+
device const uint8_t * qs = qb_curr->qs + il / 8;
3178+
const uint8_t b0 = qs[0];
3179+
const uint8_t b1 = qs[1];
31843180

3181+
// Accumulate yl[i] only where bit is 1, using select (branchless)
31853182
float acc = 0.0f;
31863183

3187-
// il represents which 16-element chunk of the 128-element block (0, 16, 32, ..., 112)
3188-
// 16 weights = 16 bits = 2 bytes
3189-
const int byte_offset = il / 8;
3190-
device const uint8_t * qs = qb_curr->qs + byte_offset;
3184+
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
3185+
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
3186+
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
3187+
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
3188+
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
3189+
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
3190+
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
3191+
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
31913192

3192-
for (int i = 0; i < 16; i++) {
3193-
const uint8_t byte_idx = i / 8;
3194-
const uint8_t bit_idx = i % 8;
3195-
const int8_t qval = ((qs[byte_idx] >> bit_idx) & 1) ? 1 : -1;
3196-
acc += yl[i] * qval;
3197-
}
3193+
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
3194+
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
3195+
acc += select(0.0f, yl[10], bool(b1 & 0x04));
3196+
acc += select(0.0f, yl[11], bool(b1 & 0x08));
3197+
acc += select(0.0f, yl[12], bool(b1 & 0x10));
3198+
acc += select(0.0f, yl[13], bool(b1 & 0x20));
3199+
acc += select(0.0f, yl[14], bool(b1 & 0x40));
3200+
acc += select(0.0f, yl[15], bool(b1 & 0x80));
31983201

3199-
return d * acc;
3202+
// dot = d * (2 * sum_where_bit1 - sumy)
3203+
return qb_curr->d * (2.0f * acc - sumy);
32003204
}
32013205

32023206
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -10022,6 +10026,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
1002210026

1002310027
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
1002410028
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
10029+
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
1002510030
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
1002610031
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
1002710032
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;

0 commit comments

Comments
 (0)