Skip to content

Commit 52fcb93

Browse files
committed
tuning q1_0 metal kernels
1 parent f2b50f9 commit 52fcb93

2 files changed

Lines changed: 79 additions & 72 deletions

File tree

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11-
#define N_R0_Q1_0 4
11+
#define N_R0_Q1_0 8
1212
#define N_SG_Q1_0 2
1313

1414
#define N_R0_Q4_0 4

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 78 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -124,56 +124,46 @@ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & re
124124
const float d = xb->d;
125125
const float neg_d = -d;
126126

127-
// Process 16 bits starting at offset il*16
128-
// Optimization: process 2 bytes (16 bits) at once for better memory access
129127
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
130128
const uint8_t b0 = qs[byte_offset];
131129
const uint8_t b1 = qs[byte_offset + 1];
132130

133131
float4x4 reg_f;
134132

135-
// Unroll completely for better ILP
136-
// First byte (bits 0-7)
137-
reg_f[0][0] = (b0 & 0x01) ? d : neg_d;
138-
reg_f[0][1] = (b0 & 0x02) ? d : neg_d;
139-
reg_f[0][2] = (b0 & 0x04) ? d : neg_d;
140-
reg_f[0][3] = (b0 & 0x08) ? d : neg_d;
141-
reg_f[1][0] = (b0 & 0x10) ? d : neg_d;
142-
reg_f[1][1] = (b0 & 0x20) ? d : neg_d;
143-
reg_f[1][2] = (b0 & 0x40) ? d : neg_d;
144-
reg_f[1][3] = (b0 & 0x80) ? d : neg_d;
145-
146-
// Second byte (bits 8-15)
147-
reg_f[2][0] = (b1 & 0x01) ? d : neg_d;
148-
reg_f[2][1] = (b1 & 0x02) ? d : neg_d;
149-
reg_f[2][2] = (b1 & 0x04) ? d : neg_d;
150-
reg_f[2][3] = (b1 & 0x08) ? d : neg_d;
151-
reg_f[3][0] = (b1 & 0x10) ? d : neg_d;
152-
reg_f[3][1] = (b1 & 0x20) ? d : neg_d;
153-
reg_f[3][2] = (b1 & 0x40) ? d : neg_d;
154-
reg_f[3][3] = (b1 & 0x80) ? d : neg_d;
133+
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
134+
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
135+
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
136+
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
137+
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
138+
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
139+
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
140+
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
141+
142+
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
143+
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
144+
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
145+
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
146+
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
147+
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
148+
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
149+
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
155150

156151
reg = (type4x4) reg_f;
157152
}
158153

159154
template <typename type4>
160155
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
161-
device const uint8_t * qs = xb->qs;
162156
const float d = xb->d;
157+
const float neg_d = -d;
158+
const int base = il * 4;
159+
const uint8_t byte = xb->qs[base / 8];
160+
const int s = base % 8;
163161

164162
float4 reg_f;
165-
166-
// Process 4 bits for each call
167-
const int offset = il * 4;
168-
169-
for (int i = 0; i < 4; i++) {
170-
const int bit_idx = offset + i;
171-
const int byte_idx = bit_idx / 8;
172-
const int bit_offset = bit_idx % 8;
173-
174-
const bool bit_val = (qs[byte_idx] >> bit_offset) & 1;
175-
reg_f[i] = bit_val ? d : -d;
176-
}
163+
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
164+
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
165+
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
166+
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
177167

178168
reg = (type4) reg_f;
179169
}
@@ -3176,27 +3166,33 @@ kernel void kernel_group_norm_f32(
31763166
}
31773167
}
31783168

3179-
// function for calculate inner product between part of a q1_0 block and 16 floats (yl), sumy is SUM(yl[i])
3180-
// il indicates where the q1 quants begin (0, 16, 32, ..., 112 for 128-element block)
3181-
// we assume that the yl's have been multiplied with the appropriate scale factor
3169+
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
31823170
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3183-
float d = qb_curr->d;
3171+
device const uint8_t * qs = qb_curr->qs + il / 8;
3172+
const uint8_t b0 = qs[0];
3173+
const uint8_t b1 = qs[1];
31843174

31853175
float acc = 0.0f;
31863176

3187-
// il represents which 16-element chunk of the 128-element block (0, 16, 32, ..., 112)
3188-
// 16 weights = 16 bits = 2 bytes
3189-
const int byte_offset = il / 8;
3190-
device const uint8_t * qs = qb_curr->qs + byte_offset;
3177+
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
3178+
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
3179+
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
3180+
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
3181+
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
3182+
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
3183+
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
3184+
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
31913185

3192-
for (int i = 0; i < 16; i++) {
3193-
const uint8_t byte_idx = i / 8;
3194-
const uint8_t bit_idx = i % 8;
3195-
const int8_t qval = ((qs[byte_idx] >> bit_idx) & 1) ? 1 : -1;
3196-
acc += yl[i] * qval;
3197-
}
3186+
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
3187+
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
3188+
acc += select(0.0f, yl[10], bool(b1 & 0x04));
3189+
acc += select(0.0f, yl[11], bool(b1 & 0x08));
3190+
acc += select(0.0f, yl[12], bool(b1 & 0x10));
3191+
acc += select(0.0f, yl[13], bool(b1 & 0x20));
3192+
acc += select(0.0f, yl[14], bool(b1 & 0x40));
3193+
acc += select(0.0f, yl[15], bool(b1 & 0x80));
31983194

3199-
return d * acc;
3195+
return qb_curr->d * (2.0f * acc - sumy);
32003196
}
32013197

32023198
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -3420,22 +3416,25 @@ void mul_vec_q_n_f32_impl(
34203416
}
34213417
}
34223418

3423-
kernel void kernel_mul_mv_q1_0_f32(
3424-
constant ggml_metal_kargs_mul_mv & args,
3419+
template<int nr0, typename args_t>
3420+
void kernel_mul_mv_q1_0_f32_impl(
3421+
args_t args,
34253422
device const char * src0,
34263423
device const char * src1,
34273424
device char * dst,
3428-
uint3 tgpig[[threadgroup_position_in_grid]],
3429-
ushort tiisg[[thread_index_in_simdgroup]],
3430-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3431-
// Q1_0-specific implementation with 128-element blocks
3425+
threadgroup char * shmem,
3426+
uint3 tgpig,
3427+
ushort tiisg,
3428+
ushort sgitg) {
3429+
const short NSG = FC_mul_mv_nsg;
3430+
34323431
const int nb = args.ne00/QK1_0;
34333432

34343433
const int r0 = tgpig.x;
34353434
const int r1 = tgpig.y;
34363435
const int im = tgpig.z;
34373436

3438-
const int first_row = (r0 * N_SG_Q1_0 + sgitg) * N_R0_Q1_0;
3437+
const int first_row = (r0 * NSG + sgitg) * nr0;
34393438

34403439
const uint i12 = im%args.ne12;
34413440
const uint i13 = im/args.ne12;
@@ -3444,37 +3443,31 @@ kernel void kernel_mul_mv_q1_0_f32(
34443443

34453444
device const float * y = (device const float *) (src1 + offset1);
34463445

3447-
// pointers to src0 rows
3448-
device const block_q1_0 * ax[N_R0_Q1_0];
3449-
for (int row = 0; row < N_R0_Q1_0; ++row) {
3446+
device const block_q1_0 * ax[nr0];
3447+
for (int row = 0; row < nr0; ++row) {
34503448
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3451-
34523449
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
34533450
}
34543451

3455-
float yl[16]; // src1 vector cache
3456-
float sumf[N_R0_Q1_0] = {0.f};
3452+
float yl[16];
3453+
float sumf[nr0] = {0.f};
34573454

3458-
// For 128-element blocks, we need 8 passes of 16 elements each
3459-
// Each thread processes a different 16-element chunk
3460-
const short ix = (tiisg/8); // which block (0 to 3 for 32 threads / 8)
3461-
const short il = (tiisg%8)*16; // which 16-element chunk within the 128-element block (0, 16, 32, ..., 112)
3455+
const short ix = (tiisg/8);
3456+
const short il = (tiisg%8)*16;
34623457

34633458
device const float * yb = y + ix*QK1_0 + il;
34643459

3465-
// each thread in a SIMD group deals with 1/8 of a block (16 elements out of 128)
34663460
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
34673461
float sumy = 0.f;
34683462

3469-
// Q1_0: simple copy
34703463
#pragma unroll
34713464
for (short i = 0; i < 16; i++) {
34723465
yl[i] = yb[i];
34733466
sumy += yb[i];
34743467
}
34753468

34763469
#pragma unroll
3477-
for (short row = 0; row < N_R0_Q1_0; row++) {
3470+
for (short row = 0; row < nr0; row++) {
34783471
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
34793472
}
34803473

@@ -3483,7 +3476,7 @@ kernel void kernel_mul_mv_q1_0_f32(
34833476

34843477
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
34853478

3486-
for (int row = 0; row < N_R0_Q1_0; ++row) {
3479+
for (int row = 0; row < nr0; ++row) {
34873480
const float tot = simd_sum(sumf[row]);
34883481

34893482
if (tiisg == 0 && first_row + row < args.ne01) {
@@ -3492,6 +3485,18 @@ kernel void kernel_mul_mv_q1_0_f32(
34923485
}
34933486
}
34943487

3488+
[[host_name("kernel_mul_mv_q1_0_f32")]]
3489+
kernel void kernel_mul_mv_q1_0_f32(
3490+
constant ggml_metal_kargs_mul_mv & args,
3491+
device const char * src0,
3492+
device const char * src1,
3493+
device char * dst,
3494+
uint3 tgpig[[threadgroup_position_in_grid]],
3495+
ushort tiisg[[thread_index_in_simdgroup]],
3496+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3497+
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
3498+
}
3499+
34953500
kernel void kernel_mul_mv_q4_0_f32(
34963501
constant ggml_metal_kargs_mul_mv & args,
34973502
device const char * src0,
@@ -10022,6 +10027,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
1002210027

1002310028
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
1002410029
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
10030+
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
1002510031
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
1002610032
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
1002710033
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -10231,6 +10237,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
1023110237

1023210238
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
1023310239

10240+
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
1023410241
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
1023510242
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
1023610243
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;

0 commit comments

Comments
 (0)