diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 82b048bb3ae4..5ae0dabd34ef 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -672,34 +672,36 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo float32x4_t acc = vdupq_n_f32(0.0f); for (int ib = 0; ib < nb; ++ib) { + const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs); + const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8); + const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16); + const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24); + const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs); + const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8); + const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16); + const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24); + const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs); const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16); - const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_0, m4b)); const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4)); const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b)); const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4)); - const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs); - const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16); - const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b)); - const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b)); - - const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs); - const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16); - const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), vget_low_s8(q8_1b)); - const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b)); + const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0); + const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0); + const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0); + const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0); + const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1); + const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1); + const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1); + const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1); - const int32x4_t p0 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); - const int32x4_t p1 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); + const int32x4_t p0 = ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi); + const int32x4_t p1 = ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi); + const int32x4_t p2 = ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi); + const int32x4_t p3 = ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi); - const int32x4_t sums = vpaddq_s32(p0, p1); - - // Decode 4 UE4M3 scales to f32 and multiply with q8 scales const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d); const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d); const float32x4_t nvsc = { @@ -710,7 +712,13 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo }; const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1}); - acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales); + const float32x4_t sums = (float32x4_t){ + (float)vaddvq_s32(p0), + (float)vaddvq_s32(p1), + (float)vaddvq_s32(p2), + (float)vaddvq_s32(p3) + }; + acc = vfmaq_f32(acc, sums, scales); } sumf = vaddvq_f32(acc); #else diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 88a9c9ec0572..a3353e787ced 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -319,6 +319,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // !defined(__ARM_FEATURE_DOTPROD) +static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo, + const int8x8_t q4_hi, const int8x8_t q8_hi) { + const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo); + const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi); + const int32x4_t sum_lo = vpaddlq_s16(p_lo); + const int32x4_t sum_hi = vpaddlq_s16(p_hi); + return vaddq_s32(sum_lo, sum_hi); +} + #endif // defined(__ARM_NEON) #ifdef __wasm_simd128__