diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index f0da2f4ef..f58fcdb0c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1343,6 +1343,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: + case GGML_TYPE_Q4_POLAR: case GGML_TYPE_I32: return true; default: @@ -1370,6 +1373,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: + case GGML_TYPE_Q4_POLAR: switch (op->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -1441,7 +1447,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_Q5_1: case GGML_TYPE_IQ4_NL: case GGML_TYPE_QJL1_256: - // // ELIZA-QJL-SET-ROWS-V1 + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: + case GGML_TYPE_Q4_POLAR: + // ELIZA-CUSTOM-KV-SET-ROWS-V1 return true; default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 001c1125a..dbf850de8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -509,6 +509,310 @@ void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) { dst.d = sumq2 > 0 ? sumqx/sumq2 : d; } +constant float k_tbq3_codebook_set_rows[8] = { + -2.1519457f, -1.3439093f, -0.7560053f, -0.2450942f, + 0.2450942f, 0.7560053f, 1.3439093f, 2.1519457f, +}; + +constant float k_tbq4_codebook_set_rows[16] = { + -2.7321365f, -2.0685055f, -1.6175243f, -1.2557391f, + -0.9419147f, -0.6564307f, -0.3878412f, -0.1283243f, + 0.1283243f, 0.3878412f, 0.6564307f, 0.9419147f, + 1.2557391f, 1.6175243f, 2.0685055f, 2.7321365f, +}; + +constant float k_tbq_signs_set_rows[QK_TBQ] = { + 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, +}; + +static inline void tbq_hadamard32_set_rows(thread float * x) { + for (int len = 1; len < QK_TBQ; len <<= 1) { + for (int i = 0; i < QK_TBQ; i += 2 * len) { + for (int j = 0; j < len; ++j) { + const float a = x[i + j]; + const float b = x[i + j + len]; + x[i + j] = a + b; + x[i + j + len] = a - b; + } + } + } + + constexpr float norm = 0.1767766952966369f; + for (int i = 0; i < QK_TBQ; ++i) { + x[i] *= norm; + } +} + +static inline uint8_t tbq_best_index_set_rows(constant const float * codebook, int n, float x) { + uint8_t best = 0; + float best_dist = fabs(x - codebook[0]); + + for (int i = 1; i < n; ++i) { + const float dist = fabs(x - codebook[i]); + if (dist < best_dist) { + best = (uint8_t) i; + best_dist = dist; + } + } + + return best; +} + +static inline void tbq3_set_code_set_rows(device uint8_t * qs, int idx, uint8_t code) { + const int bit = idx * 3; + const int byte = bit >> 3; + const int shift = bit & 7; + + qs[byte] = (uint8_t) (qs[byte] | ((code & 0x7u) << shift)); + if (shift > 5 && byte + 1 < (int) (QK_TBQ * 3 / 8)) { + qs[byte + 1] = (uint8_t) (qs[byte + 1] | ((code & 0x7u) >> (8 - shift))); + } +} + +static inline void tbq4_set_code_set_rows(device uint8_t * qs, int idx, uint8_t code) { + const int j = idx % (QK_TBQ / 2); + if (idx < QK_TBQ / 2) { + qs[j] = (uint8_t) ((qs[j] & 0xF0) | (code & 0x0F)); + } else { + qs[j] = (uint8_t) ((qs[j] & 0x0F) | ((code & 0x0F) << 4)); + } +} + +static inline uint8_t tbq3_get_code_set_rows(device const uint8_t * qs, int idx) { + const int bit = idx * 3; + const int byte = bit >> 3; + const int shift = bit & 7; + uint16_t raw = qs[byte]; + if (byte + 1 < (int) (QK_TBQ * 3 / 8)) { + raw = (uint16_t) (raw | ((uint16_t) qs[byte + 1] << 8)); + } + return (uint8_t) ((raw >> shift) & 0x7u); +} + +static inline uint8_t tbq4_get_code_set_rows(device const uint8_t * qs, int idx) { + const int j = idx % (QK_TBQ / 2); + return idx < QK_TBQ / 2 ? (qs[j] & 0x0Fu) : (qs[j] >> 4); +} + +static inline void tbq_precondition_block_set_rows(device const float * src, thread float * rotated) { + for (int i = 0; i < QK_TBQ; ++i) { + rotated[i] = src[i] * k_tbq_signs_set_rows[i]; + } + tbq_hadamard32_set_rows(rotated); +} + +static inline void tbq_uncondition_block_set_rows(thread float * x) { + tbq_hadamard32_set_rows(x); + for (int i = 0; i < QK_TBQ; ++i) { + x[i] *= k_tbq_signs_set_rows[i]; + } +} + +void quantize_tbq3_0(device const float * src, device block_tbq3_0 & dst) { +#pragma METAL fp math_mode(safe) + thread float rotated[QK_TBQ]; + tbq_precondition_block_set_rows(src, rotated); + + float sumsq = 0.0f; + for (int i = 0; i < QK_TBQ; ++i) { + sumsq = fma(rotated[i], rotated[i], sumsq); + } + + const float d = sqrt(sumsq / float(QK_TBQ)); + dst.d = (half) d; + for (int i = 0; i < QK_TBQ * 3 / 8; ++i) { + dst.qs[i] = 0; + } + + if (d == 0.0f) { + return; + } + + const float id = 1.0f / d; + for (int i = 0; i < QK_TBQ; ++i) { + const uint8_t code = tbq_best_index_set_rows(k_tbq3_codebook_set_rows, 8, rotated[i] * id); + tbq3_set_code_set_rows(dst.qs, i, code); + } +} + +void quantize_tbq4_0(device const float * src, device block_tbq4_0 & dst) { +#pragma METAL fp math_mode(safe) + thread float rotated[QK_TBQ]; + tbq_precondition_block_set_rows(src, rotated); + + float sumsq = 0.0f; + for (int i = 0; i < QK_TBQ; ++i) { + sumsq = fma(rotated[i], rotated[i], sumsq); + } + + const float d = sqrt(sumsq / float(QK_TBQ)); + dst.d = (half) d; + for (int i = 0; i < QK_TBQ / 2; ++i) { + dst.qs[i] = 0; + } + + if (d == 0.0f) { + return; + } + + const float id = 1.0f / d; + for (int i = 0; i < QK_TBQ; ++i) { + const uint8_t code = tbq_best_index_set_rows(k_tbq4_codebook_set_rows, 16, rotated[i] * id); + tbq4_set_code_set_rows(dst.qs, i, code); + } +} + +constant float k_polar_q4_centroids_set_rows[16] = { + -2.754354807f, -2.093562707f, -1.643041510f, -1.279739752f, + -0.962640978f, -0.672392117f, -0.397897103f, -0.131757782f, + 0.131757782f, 0.397897103f, 0.672392117f, 0.962640978f, + 1.279739752f, 1.643041510f, 2.093562707f, 2.754354807f, +}; + +constant float k_polar_q4_boundaries_set_rows[15] = { + -2.423958757f, -1.868302108f, -1.461390631f, -1.121190365f, + -0.817516548f, -0.535144610f, -0.264827443f, 0.0f, + 0.264827443f, 0.535144610f, 0.817516548f, 1.121190365f, + 1.461390631f, 1.868302108f, 2.423958757f, +}; + +static inline void polar_hadamard128_set_rows(thread float * x) { + for (int h = 1; h < QK_POLAR; h <<= 1) { + for (int i = 0; i < QK_POLAR; i += (h << 1)) { + for (int j = i; j < i + h; ++j) { + const float a = x[j]; + const float b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + } + } +} + +static inline uint8_t polar_q4_bucketize_set_rows(float v) { + uint8_t code = 0; + for (int i = 0; i < 15; ++i) { + if (v > k_polar_q4_boundaries_set_rows[i]) { + code = (uint8_t) (i + 1); + } + } + return code; +} + +void quantize_q4_polar(device const float * src, device block_q4_polar & dst) { +#pragma METAL fp math_mode(safe) + float sumsq = 0.0f; + for (int i = 0; i < QK_POLAR; ++i) { + sumsq = fma(src[i], src[i], sumsq); + } + + const float l2 = sqrt(sumsq); + const float inv_l2 = l2 > 1e-10f ? 1.0f / l2 : 0.0f; + dst.d = (half) l2; + + thread float buf[QK_POLAR]; + for (int i = 0; i < QK_POLAR; ++i) { + buf[i] = src[i] * inv_l2; + } + + polar_hadamard128_set_rows(buf); + + thread uint8_t codes[QK_POLAR]; + for (int i = 0; i < QK_POLAR; ++i) { + codes[i] = polar_q4_bucketize_set_rows(buf[i]); + } + + for (int i = 0; i < QK_POLAR / 2; ++i) { + const uint8_t lo = codes[2 * i]; + const uint8_t hi = codes[2 * i + 1]; + dst.qs[i] = (uint8_t) ((hi << 4) | (lo & 0x0F)); + } + + for (int i = 0; i < QJL_RESIDUAL_BYTES; ++i) { + dst.qjl[i] = 0; + } + + float proj = 0.0f; + uint state = 42u; + for (int i = 0; i < QK_POLAR; ++i) { + state ^= state << 13; + state ^= state >> 17; + state ^= state << 5; + const float sign = (state & 1u) ? 1.0f : -1.0f; + const float c = k_polar_q4_centroids_set_rows[codes[i]]; + proj = fma(buf[i] - c, sign, proj); + } + dst.qjl[0] = proj >= 0.0f ? 1u : 0u; +} + +template +static inline void store_dequant_chunk_set_rows(thread float * decoded, short il, thread type4x4 & reg) { + float4x4 reg_f; + const int base = il * 16; + for (int i = 0; i < 16; ++i) { + reg_f[i / 4][i % 4] = decoded[base + i]; + } + reg = (type4x4) reg_f; +} + +template +void dequantize_tbq3_0(device const block_tbq3_0 * xb, short il, thread type4x4 & reg) { + thread float decoded[QK_TBQ]; + const float d = xb->d; + if (d == 0.0f) { + for (int i = 0; i < QK_TBQ; ++i) { + decoded[i] = 0.0f; + } + } else { + for (int i = 0; i < QK_TBQ; ++i) { + decoded[i] = d * k_tbq3_codebook_set_rows[tbq3_get_code_set_rows(xb->qs, i)]; + } + tbq_uncondition_block_set_rows(decoded); + } + store_dequant_chunk_set_rows(decoded, il, reg); +} + +template +void dequantize_tbq4_0(device const block_tbq4_0 * xb, short il, thread type4x4 & reg) { + thread float decoded[QK_TBQ]; + const float d = xb->d; + if (d == 0.0f) { + for (int i = 0; i < QK_TBQ; ++i) { + decoded[i] = 0.0f; + } + } else { + for (int i = 0; i < QK_TBQ; ++i) { + decoded[i] = d * k_tbq4_codebook_set_rows[tbq4_get_code_set_rows(xb->qs, i)]; + } + tbq_uncondition_block_set_rows(decoded); + } + store_dequant_chunk_set_rows(decoded, il, reg); +} + +template +void dequantize_q4_polar(device const block_q4_polar * xb, short il, thread type4x4 & reg) { + thread float decoded[QK_POLAR]; + for (int i = 0; i < QK_POLAR / 2; ++i) { + const uint8_t byte = xb->qs[i]; + decoded[2 * i] = k_polar_q4_centroids_set_rows[byte & 0x0Fu]; + decoded[2 * i + 1] = k_polar_q4_centroids_set_rows[(byte >> 4) & 0x0Fu]; + } + + // Match the CPU default: QJL residual correction is opt-in and disabled + // unless runtime metadata enables it. + polar_hadamard128_set_rows(decoded); + + const float scale = ((float) xb->d) / float(QK_POLAR); + for (int i = 0; i < QK_POLAR; ++i) { + decoded[i] *= scale; + } + + store_dequant_chunk_set_rows(decoded, il, reg); +} + template void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 2); @@ -7863,6 +8167,9 @@ template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_ template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_tbq3_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_tbq4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_polar")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -7904,6 +8211,9 @@ template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32< template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_tbq3_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_tbq4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_polar_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -7911,6 +8221,9 @@ template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32< template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_tbq3_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_tbq4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_polar_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; kernel void kernel_concat( constant ggml_metal_kargs_concat & args, @@ -9852,6 +10165,37 @@ kernel void kernel_set_rows_q32( } } +template +kernel void kernel_set_rows_q128( + constant ggml_metal_kargs_set_rows & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + + const int32_t i12 = i03%args.ne12; + const int32_t i11 = i02%args.ne11; + + const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x; + if (i01 >= args.ne01) { + return; + } + + const int32_t i10 = i01; + const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + + device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + + for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { + quantize_func(src_row + QK_POLAR*ind, dst_row[ind]); + } +} + template kernel void kernel_set_rows_f( constant ggml_metal_kargs_set_rows & args, @@ -10694,6 +11038,14 @@ template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kerne template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_tbq3_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_tbq3_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_tbq4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_tbq4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; + +typedef decltype(kernel_set_rows_q128) set_rows_q128_t; +template [[host_name("kernel_set_rows_q4_polar_i64")]] kernel set_rows_q128_t kernel_set_rows_q128; +template [[host_name("kernel_set_rows_q4_polar_i32")]] kernel set_rows_q128_t kernel_set_rows_q128; // // matrix-matrix multiplication diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index c1356ec30..76fa3ff49 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1987,20 +1987,28 @@ ggml_tensor * llm_graph_context::build_attn_mha( v = ggml_cast(ctx0, v, GGML_TYPE_F16); } - // Eliza fused-attn K/V cache types (QJL1_256, Q4_POLAR, TBQ3_TCQ) have - // no vec_dot in the CPU type-traits — they are stored cache types that - // require either the fused custom op (GGML_OP_FUSED_ATTN_QJL_TBQ) or a - // dequantize hop before ggml_flash_attn_ext. The graph builder does - // not yet route to the fused op, so dequantize via F32 -> F16 here. - // This is bit-exact w.r.t. the type's to_float (dequantize_row_*). - if (k->type == GGML_TYPE_QJL1_256 || k->type == GGML_TYPE_TBQ3_TCQ) { - ggml_tensor * k_f32 = ggml_cast(ctx0, k, GGML_TYPE_F32); - k = ggml_cast(ctx0, k_f32, GGML_TYPE_F16); - } - - if (v->type == GGML_TYPE_Q4_POLAR) { - ggml_tensor * v_f32 = ggml_cast(ctx0, v, GGML_TYPE_F32); - v = ggml_cast(ctx0, v_f32, GGML_TYPE_F16); + // Eliza custom cache types need either their fused attention op or a + // dequantize hop before stock flash attention. Keep this fallback + // explicit so manual cache-type overrides do not reach FA with a type + // its backend does not accept. + const auto needs_cache_dequant = [](ggml_type type) { + return type == GGML_TYPE_QJL1_256 || + type == GGML_TYPE_TBQ3_TCQ || + type == GGML_TYPE_TBQ3_0 || + type == GGML_TYPE_TBQ4_0 || + type == GGML_TYPE_Q4_POLAR; + }; + const auto dequant_to_f16 = [&](ggml_tensor * t) { + ggml_tensor * t_f32 = ggml_cast(ctx0, t, GGML_TYPE_F32); + return ggml_cast(ctx0, t_f32, GGML_TYPE_F16); + }; + + if (needs_cache_dequant(k->type)) { + k = dequant_to_f16(k); + } + + if (needs_cache_dequant(v->type)) { + v = dequant_to_f16(v); } cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 48672110c..12405c3df 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2410,6 +2410,12 @@ struct test_set_rows : public test_case { } double max_nmse_err() override { + if (type == GGML_TYPE_TBQ3_0 || type == GGML_TYPE_TBQ4_0) { + return 1e-5; + } + if (type == GGML_TYPE_Q4_POLAR) { + return 5e-4; + } if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) { // estimate what the max nmse error would be if one quantized value is @@ -7962,6 +7968,12 @@ static const ggml_type eliza_custom_quant_types_cpy[] = { GGML_TYPE_TBQ3_TCQ, }; +static const ggml_type eliza_custom_quant_types_set_rows[] = { + GGML_TYPE_TBQ3_0, + GGML_TYPE_TBQ4_0, + GGML_TYPE_Q4_POLAR, +}; + // Mul-mat: only the types that register a vec_dot in the CPU traits table. // QJL1_256 and TBQ3_TCQ are intentionally absent (no vec_dot - they are // scored via dedicated GGML_OP_ATTN_SCORE_QJL / fused attention paths). @@ -8087,6 +8099,12 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I64, { 1, 8, 1, 3 }, { 1, 1 }, 2, false)); test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false)); test_cases.emplace_back(new test_set_rows(GGML_TYPE_Q8_0, GGML_TYPE_I32, { 256, 5, 1, 3 }, { 1, 1, }, 1, false)); + for (ggml_type type : eliza_custom_quant_types_set_rows) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 128, 5, 2, 1 }, { 1, 1 }, 2, v)); + test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I32, { 128, 7, 1, 2 }, { 1, 1 }, 2, v)); + } + } for (ggml_type type : all_types) { for (int b : {1, 7}) { for (bool v : {false, true}) { @@ -8489,12 +8507,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type_dst, {256, 4, 4, 4})); } for (ggml_type type_src : eliza_custom_quant_types_cpy) { - // Skip QJL/POLAR/TCQ for the dequant-to-F32 path: their to_float - // is defined but only via cache-attention codepaths, not the - // generic ggml_compute_forward_dup_q path. The standard quants - // (g32/g128/TBQ3_0/TBQ4_0/K-variants) are full participants. + // Skip QJL/TCQ for the dequant-to-F32 path: their to_float is + // defined but only via cache-attention codepaths, not the generic + // ggml_compute_forward_dup_q path. Q4_POLAR now participates in the + // generic path because flash-attention fallback dequantizes it. if (type_src == GGML_TYPE_QJL1_256 || - type_src == GGML_TYPE_Q4_POLAR || type_src == GGML_TYPE_TBQ3_TCQ) { continue; }