From 54de6134670236b9b93865ed334f237af28d4f45 Mon Sep 17 00:00:00 2001 From: Frank Barchard Date: Sun, 22 Mar 2026 20:23:29 -0700 Subject: [PATCH] Add 2 bit SSE GEMM microkernels These updates enable 2-bit quantization support for both QS8-QC2W and QD8-F32-QC2W using SSSE3 instructions with MADD optimization. 1. src/qs8-gemm/MRx4c8-ssevnni.c.in: * Added support for QS8_QC2, QC2_F32, and QC2_F16 datatypes. * Introduced the _MM_SET1_EPI8 macro for consistent constant generation. * Updated the ISA and instruction selection logic to support MADD variants (specifically _mm_dpbusd_epi32_madd_kzp2 for 2-bit variants). * Updated the function signature to include the row_sum parameter for the QD8-based QC2W (QC2_F32/QC2_F16) variants. 2. scripts/generate-qs8-gemm.sh: * Added generation rules for SSSE3 with VARIANT=MADD for both QS8_QC2 and QC2_F32 variants. 3. src/xnnpack/gemm.h: * Added microkernel declarations for xnn_qd8_f32_qc2w_gemm_minmax_ukernel_*x4c8__ssse3_madd. * Added microkernel declarations for xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_*x4c8__ssse3_madd. 
PiperOrigin-RevId: 887841831 --- scripts/generate-qs8-gemm.sh | 14 ++ src/qs8-gemm/MRx4c8-ssevnni.c.in | 246 ++++++++++++++++++++++--------- src/xnnpack/gemm.h | 14 ++ 3 files changed, 205 insertions(+), 69 deletions(-) diff --git a/scripts/generate-qs8-gemm.sh b/scripts/generate-qs8-gemm.sh index afbc55df27f..e4f0a52aa7a 100755 --- a/scripts/generate-qs8-gemm.sh +++ b/scripts/generate-qs8-gemm.sh @@ -1721,6 +1721,20 @@ tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QS8_QC4 -D SSE tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=5 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=1 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-5x4c8-minmax-avx-madd-prfm.c & tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=6 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=1 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-6x4c8-minmax-avx-madd-prfm.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=1 -D DATATYPE=QS8_QC2 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x4c8-minmax-fp32-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=2 -D DATATYPE=QS8_QC2 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-2x4c8-minmax-fp32-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=3 -D DATATYPE=QS8_QC2 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-3x4c8-minmax-fp32-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QS8_QC2 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x4c8-minmax-fp32-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=5 -D DATATYPE=QS8_QC2 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D 
REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-5x4c8-minmax-fp32-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=6 -D DATATYPE=QS8_QC2 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x4c8-minmax-fp32-ssse3-madd.c & + +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=1 -D DATATYPE=QC2_F32 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION= -o src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x4c8-minmax-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=2 -D DATATYPE=QC2_F32 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION= -o src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x4c8-minmax-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=3 -D DATATYPE=QC2_F32 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION= -o src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x4c8-minmax-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QC2_F32 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION= -o src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x4c8-minmax-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=5 -D DATATYPE=QC2_F32 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION= -o src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x4c8-minmax-ssse3-madd.c & +tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=6 -D DATATYPE=QC2_F32 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION= -o src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x4c8-minmax-ssse3-madd.c & + ################################## x86 AVX256 VNNI EVEX ################################# ### C8 micro-kernels diff --git a/src/qs8-gemm/MRx4c8-ssevnni.c.in b/src/qs8-gemm/MRx4c8-ssevnni.c.in index 7eb45cc91df..06f966f506b 100644 --- a/src/qs8-gemm/MRx4c8-ssevnni.c.in +++ b/src/qs8-gemm/MRx4c8-ssevnni.c.in @@ -4,7 +4,7 
@@ // LICENSE file in the root directory of this source tree. $assert REQUANTIZATION == "FP32" or not REQUANTIZATION -$assert DATATYPE in ["QC4_F32", "QS8_QC4"] +$assert DATATYPE in ["QC8", "QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16", "QS8_QC4", "QS8_QC2"] $assert SSE in [3, 4] #include #include @@ -22,13 +22,18 @@ $if PREFETCH: #include "src/xnnpack/prefetch.h" #include "src/xnnpack/unaligned.h" +$def _MM_SET1_EPI8(VALUE): +$ return f"_mm_set1_epi8({VALUE})" -$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8_QC4": "qs8_qc4w", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC4_F32": "qd8_f32_qc4w"}[DATATYPE] +$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8_QC4": "qs8_qc4w", "QS8_QC2": "qs8_qc2w", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC2_F16": "qd8_f16_qc2w", "QC4_F32": "qd8_f32_qc4w", "QC2_F32": "qd8_f32_qc2w"}[DATATYPE] $REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else "" $PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" if REQUANTIZATION else "scalar" -$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8_QC4": "union xnn_qs8_qc8w_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE] -$OUT_T = {"QC8": "int8_t", "QS8_QC4": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC4_F32": "float"}[DATATYPE] -$_MM_DPBUSD_EPI32 = "_mm_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm_dpbusd_avx_epi32" if AVX == 2 else "_mm_dpbusd_epi32" +$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8_QC4": "union xnn_qs8_qc8w_conv_minmax_params", "QS8_QC2": "union xnn_qs8_qc8w_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct 
xnn_f32_qc4w_minmax_params", "QC2_F32": "struct xnn_f32_minmax_params", "QC2_F16": "struct xnn_f16_minmax_params",}[DATATYPE] +$OUT_T = {"QC8": "int8_t", "QS8_QC4": "int8_t", "QS8_QC2": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC2_F16": "xnn_float16", "QC4_F32": "float", "QC2_F32": "float"}[DATATYPE] +$if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: + $_MM_DPBUSD_EPI32 = "_mm_dpbusd_epi32_madd_kzp2" if VARIANT == "MADD" else "_mm_dpbusd_avx_epi32" if AVX == 2 else "_mm_dpbusd_epi32" +$else: + $_MM_DPBUSD_EPI32 = "_mm_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm_dpbusd_avx_epi32" if AVX == 2 else "_mm_dpbusd_epi32" $ISA = "avx" if AVX else {3: "ssse3", 4: "sse41"}[SSE] void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__${ISA}${"_madd" if VARIANT == "MADD" else ""}${"_prfm" if PREFETCH else ""}( size_t mr, @@ -40,8 +45,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ ${OUT_T}* restrict c, size_t cm_stride, size_t cn_stride, - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16"]: const ${PARAMS_TYPE}* restrict params, + $if DATATYPE in ["QC2_F32", "QC2_F16"]: + const float* row_sum, const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS $else: const ${PARAMS_TYPE}* restrict params) XNN_OOB_READS @@ -55,15 +62,17 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ assert(w != NULL); assert(c != NULL); + $if DATATYPE in ["QC2_F32", "QC2_F16"]: + const size_t original_kc = kc; kc = round_up_po2(kc, 8 * sizeof(int8_t)); const int8_t* a0 = a; - $if DATATYPE in ["QD8_F16", "QC4_F16"]: + $if DATATYPE in ["QD8_F16", "QC4_F16", "QC2_F16"]: uint16_t* c0 = (uint16_t*) c; $else: ${OUT_T}* c0 = c; $for M in range(1, MR): const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride); - $if DATATYPE in 
["QD8_F16", "QC4_F16"]: + $if DATATYPE in ["QD8_F16", "QC4_F16", "QC2_F16"]: uint16_t* c${M} = (uint16_t*) ((uintptr_t) c${M-1} + cm_stride); $else: ${OUT_T}* c${M} = (${OUT_T}*) ((uintptr_t) c${M-1} + cm_stride); @@ -83,7 +92,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ c${M} = c${M-1}; } - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16"]: $for M in range(MR): const __m128i vinput_zero_point${M} = _mm_set1_epi32((int) quantization_params[${M}].zero_point); $if "F16" in DATATYPE: @@ -93,12 +102,16 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); $else: - const __m128i vsign_mask = _mm_set1_epi8(0x80); + const __m128i vsign_mask = ${_MM_SET1_EPI8(0x80)}; XNN_FORCE_REALIZATION(vsign_mask); const __m128 voutput_max_less_zero_point = _mm_set1_ps((int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_zero_point = _mm_set1_epi32(params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_min = _mm_set1_epi16(params->${PARAMS_STRUCT}.output_min); - $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]: + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: + const __m128i vmask = _mm_set1_epi8(0x03); + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"] and VARIANT != "MADD": + const __m128i vtwo = _mm_set1_epi8(2); // Subtract 2 (kernel zero point) from unsigned 2 bit to sign extend + $elif DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]: $if VARIANT == "MADD": const __m128i vmask = _mm_set1_epi8(0x0F); $else: @@ -108,13 +121,17 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ const __m128i vshl4 = _mm_set1_epi64x(0x01020408); XNN_FORCE_REALIZATION(vshl4); do { - $if DATATYPE in 
["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16"]: const __m128i vksum0123 = _mm_load_si128(w); $if SSE == 4: $for M in range(MR): const __m128i vsum${M}x0123 = _mm_mullo_epi32(vksum0123, vinput_zero_point${M}); - __m128i vacc${M}x01 = _mm_unpacklo_epi32(vsum${M}x0123, _mm_setzero_si128()); - __m128i vacc${M}x23 = _mm_unpackhi_epi32(vsum${M}x0123, _mm_setzero_si128()); + $if DATATYPE in ["QC2_F32", "QC2_F16"]: + __m128i vacc${M}x01 = _mm_setzero_si128(); + __m128i vacc${M}x23 = _mm_setzero_si128(); + $else: + __m128i vacc${M}x01 = _mm_unpacklo_epi32(vsum${M}x0123, _mm_setzero_si128()); + __m128i vacc${M}x23 = _mm_unpackhi_epi32(vsum${M}x0123, _mm_setzero_si128()); $else: const __m128i vksum13 = _mm_shuffle_epi32(vksum0123, 0xF5); $for M in range(MR): @@ -122,8 +139,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ const __m128i vsum${M}x13 = _mm_mul_epu32(vksum13, vinput_zero_point${M}); const __m128i vsum${M}x01 = _mm_unpacklo_epi32(vsum${M}x02, vsum${M}x13); const __m128i vsum${M}x23 = _mm_unpackhi_epi32(vsum${M}x02, vsum${M}x13); - __m128i vacc${M}x01 = _mm_unpacklo_epi32(vsum${M}x01, _mm_setzero_si128()); - __m128i vacc${M}x23 = _mm_unpacklo_epi32(vsum${M}x23, _mm_setzero_si128()); + $if DATATYPE in ["QC2_F32", "QC2_F16"]: + __m128i vacc${M}x01 = _mm_setzero_si128(); + __m128i vacc${M}x23 = _mm_setzero_si128(); + $else: + __m128i vacc${M}x01 = _mm_unpacklo_epi32(vsum${M}x01, _mm_setzero_si128()); + __m128i vacc${M}x23 = _mm_unpacklo_epi32(vsum${M}x23, _mm_setzero_si128()); $else: $if SSE == 4: __m128i vacc0x01 = _mm_cvtepu32_epi64(_mm_loadu_si64(w)); @@ -140,11 +161,61 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ __m128i vacc1x${M}x01 = _mm_setzero_si128(); __m128i vacc1x${M}x23 = _mm_setzero_si128(); w = (const int32_t*) w + 4; + $if DATATYPE in ["QC2_F32", "QC2_F16"]: + // TODO: move kernel zero point 
after weights + const void* kzp = w; + w = (const float*)w + 4; size_t k = kc; + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: + while (k >= 32 * sizeof(int8_t)) { + $for M in range(MR): + const __m128i va${M}x0 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})); + const __m128i va${M}x1 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)); + const __m128i va${M}x2 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 16)); + const __m128i va${M}x3 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 24)); + $if DATATYPE == "QS8_QC2": + $for i in range(4): + va${M}x${i} = _mm_xor_si128(va${M}x${i}, vsign_mask); + a${M} += 32; + + const __m128i vbb0123x0123456789ABCDEF = _mm_load_si128(w); + $if GFNI: + const __m128i vb0123x0123 = _mm_gf2p8affine_epi64_epi8(vbb0123x0123456789ABCDEF, vshr0, 0xFE); + const __m128i vb0123x4567 = _mm_gf2p8affine_epi64_epi8(vbb0123x0123456789ABCDEF, vshr2, 0xFE); + const __m128i vb0123x89AB = _mm_gf2p8affine_epi64_epi8(vbb0123x0123456789ABCDEF, vshr4, 0xFE); + const __m128i vb0123xCDEF = _mm_gf2p8affine_epi64_epi8(vbb0123x0123456789ABCDEF, vshr6, 0xFE); + $else: + __m128i vb0123x0123 = _mm_and_si128(vbb0123x0123456789ABCDEF, vmask); + __m128i vbs0123x4567 = _mm_srli_epi32(vbb0123x0123456789ABCDEF, 2); + __m128i vb0123x4567 = _mm_and_si128(vbs0123x4567, vmask); + __m128i vbs0123x89AB = _mm_srli_epi32(vbb0123x0123456789ABCDEF, 4); + __m128i vb0123x89AB = _mm_and_si128(vbs0123x89AB, vmask); + __m128i vbs0123xCDEF = _mm_srli_epi32(vbb0123x0123456789ABCDEF, 6); + __m128i vb0123xCDEF = _mm_and_si128(vbs0123xCDEF, vmask); + $if VARIANT != "MADD": + vb0123x0123 = _mm_sub_epi8(vb0123x0123, vtwo); + vb0123x4567 = _mm_sub_epi8(vb0123x4567, vtwo); + vb0123x89AB = _mm_sub_epi8(vb0123x89AB, vtwo); + vb0123xCDEF = _mm_sub_epi8(vb0123xCDEF, vtwo); + + $for M in range(MR): + vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x0, vb0123x0123); + $if MR < 3: + vacc1x${M}x01 = ${_MM_DPBUSD_EPI32}(vacc1x${M}x01, va${M}x1, vb0123x4567); + 
$else: + vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x1, vb0123x4567); + vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x2, vb0123x89AB); + $if MR < 3: + vacc1x${M}x23 = ${_MM_DPBUSD_EPI32}(vacc1x${M}x23, va${M}x3, vb0123xCDEF); + $else: + vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x3, vb0123xCDEF); + w = (const int8_t*) w + 16; + k -= 32 * sizeof(int8_t); + } while (k >= 16 * sizeof(int8_t)) { $for M in range(MR): - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16"]: const __m128i va${M}x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})); const __m128i va${M}x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)); $else: @@ -152,89 +223,112 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ const __m128i va${M}x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); a${M} += 16; - $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]: - const __m128i vbb01234567x0123 = _mm_load_si128(w); - const __m128i vbb89ABCDEFx0123 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16)); + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: + // 2 planes of 2 bit. 
potentially 3rd plane handled later + const __m128i vbb0123x01234567 = _mm_load_si128(w); $if GFNI: - const __m128i vb01234567x01 = _mm_gf2p8affine_epi64_epi8(vbb01234567x0123, vshl4, 0); - const __m128i vb89ABCDEFx01 = _mm_gf2p8affine_epi64_epi8(vbb89ABCDEFx0123, vshl4, 0); - const __m128i vb01234567x23 = _mm_and_si128(vbb01234567x0123, vmask); - const __m128i vb89ABCDEFx23 = _mm_and_si128(vbb89ABCDEFx0123, vmask); + const __m128i vb0123x01 = _mm_gf2p8affine_epi64_epi8(vbb0123x01234567, vshr0, 0xFE); + const __m128i vb0123x23 = _mm_gf2p8affine_epi64_epi8(vbb0123x01234567, vshr2, 0xFE); + $else: + __m128i vb0123x01 = _mm_and_si128(vbb0123x01234567, vmask); + __m128i vbs0123x23 = _mm_srli_epi32(vbb0123x01234567, 2); + __m128i vb0123x23 = _mm_and_si128(vbs0123x23, vmask); + $if VARIANT != "MADD": + vb0123x01 = _mm_sub_epi8(vb0123x01, vtwo); + vb0123x23 = _mm_sub_epi8(vb0123x23, vtwo); + $elif DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]: + const __m128i vbb0123x0123 = _mm_load_si128(w); + $if GFNI: + const __m128i vb0123x01 = _mm_gf2p8affine_epi64_epi8(vbb0123x0123, vshl4, 0); + const __m128i vb0123x23 = _mm_and_si128(vbb0123x0123, vmask); $elif VARIANT == "MADD": - const __m128i vbs01234567x23 = _mm_srli_epi32(vbb01234567x0123, 4); - const __m128i vbs89ABCDEFx23 = _mm_srli_epi32(vbb89ABCDEFx0123, 4); - const __m128i vb01234567x01 = _mm_and_si128(vbb01234567x0123, vmask); - const __m128i vb89ABCDEFx01 = _mm_and_si128(vbb89ABCDEFx0123, vmask); - const __m128i vb01234567x23 = _mm_and_si128(vbs01234567x23, vmask); - const __m128i vb89ABCDEFx23 = _mm_and_si128(vbs89ABCDEFx23, vmask); + const __m128i vbs0123x23 = _mm_srli_epi32(vbb0123x0123, 4); + const __m128i vb0123x01 = _mm_and_si128(vbb0123x0123, vmask); + const __m128i vb0123x23 = _mm_and_si128(vbs0123x23, vmask); $else: - const __m128i vbs01234567x01 = _mm_slli_epi32(vbb01234567x0123, 4); - const __m128i vbs89ABCDEFx01 = _mm_slli_epi32(vbb89ABCDEFx0123, 4); - const __m128i vb01234567x23 = 
_mm_and_si128(vbb01234567x0123, vmask); - const __m128i vb89ABCDEFx23 = _mm_and_si128(vbb89ABCDEFx0123, vmask); - const __m128i vb01234567x01 = _mm_and_si128(vbs01234567x01, vmask); - const __m128i vb89ABCDEFx01 = _mm_and_si128(vbs89ABCDEFx01, vmask); + const __m128i vbs0123x01 = _mm_slli_epi32(vbb0123x0123, 4); + const __m128i vb0123x23 = _mm_and_si128(vbb0123x0123, vmask); + const __m128i vb0123x01 = _mm_and_si128(vbs0123x01, vmask); $else: - const __m128i vb01234567x01 = _mm_load_si128(w); - const __m128i vb89ABCDEFx01 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16)); - const __m128i vb01234567x23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32)); - const __m128i vb89ABCDEFx23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48)); + const __m128i vb0123x01 = _mm_load_si128(w); + const __m128i vb0123x23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16)); $if PREFETCH: xnn_prefetch_to_l1((const int8_t*) w + 896); $for M in range(MR): - vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x01234567, vb01234567x01); - vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x01234567, vb89ABCDEFx01); - $if PREFETCH: - xnn_prefetch_to_l1((const int8_t*) w + 960); - $for M in range(MR): + vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x01234567, vb0123x01); $if MR < 3: - vacc1x${M}x01 = ${_MM_DPBUSD_EPI32}(vacc1x${M}x01, va${M}x89ABCDEF, vb01234567x23); - vacc1x${M}x23 = ${_MM_DPBUSD_EPI32}(vacc1x${M}x23, va${M}x89ABCDEF, vb89ABCDEFx23); + vacc1x${M}x23 = ${_MM_DPBUSD_EPI32}(vacc1x${M}x23, va${M}x89ABCDEF, vb0123x23); $else: - vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x89ABCDEF, vb01234567x23); - vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x89ABCDEF, vb89ABCDEFx23); + vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x89ABCDEF, vb0123x23); - $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]: - w = (const int8_t*) w + 32; + $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4", "QS8_QC2", "QC2_F32", "QC2_F16"]: + 
w = (const int8_t*) w + 16; $else: - w = (const int8_t*) w + 64; + w = (const int8_t*) w + 32; k -= 16 * sizeof(int8_t); + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: + // 3rd plane for 2 bit + if (k != 0) { + $for M in range(MR): + const __m128i va${M}x3 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})); + $if DATATYPE == "QS8_QC2": + va${M}x3 = _mm_xor_si128(va${M}x3, vsign_mask); + a${M} += 8; + + // mask 3rd plane of 2 bit. + $if GFNI: + const __m128i vb0123x89AB = _mm_gf2p8affine_epi64_epi8(vbb0123x01234567, vshr4, 0xFE); + $else: + __m128i vbs0123x89AB = _mm_srli_epi32(vbb0123x01234567, 4); + __m128i vb0123x89AB = _mm_and_si128(vbs0123x89AB, vmask); + $if VARIANT != "MADD": + vb0123x89AB = _mm_sub_epi8(vb0123x89AB, vtwo); + $for M in range(MR): + vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x3, vb0123x89AB); + k -= 8 * sizeof(int8_t); + } } if (k != 0) { $for M in range(MR): - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16"]: const __m128i va${M}x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})); $else: const __m128i va${M}x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask); a${M} += 8; - $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]: - const __m128i vbb01234567x0123 = _mm_load_si128(w); - const __m128i vbb89ABCDEFx0123 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16)); + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: + // 1 plane of 2 bit. 
+ const __m128i vbb0123x01234567 = _mm_load_si128(w); + $if GFNI: + const __m128i vb0123x01 = _mm_gf2p8affine_epi64_epi8(vbb0123x01234567, vshr0, 0xFE); + $else: + __m128i vb0123x01 = _mm_and_si128(vbb0123x01234567, vmask); + $if VARIANT != "MADD": + vb0123x01 = _mm_sub_epi8(vb0123x01, vtwo); + $elif DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]: + const __m128i vbb0123x0123 = _mm_load_si128(w); $if GFNI: - const __m128i vb01234567x01 = _mm_gf2p8affine_epi64_epi8(vbb01234567x0123, vshl4, 0); - const __m128i vb89ABCDEFx01 = _mm_gf2p8affine_epi64_epi8(vbb89ABCDEFx0123, vshl4, 0); + const __m128i vb0123x01 = _mm_gf2p8affine_epi64_epi8(vbb0123x0123, vshl4, 0); $elif VARIANT == "MADD": - const __m128i vb01234567x01 = _mm_and_si128(vbb01234567x0123, vmask); - const __m128i vb89ABCDEFx01 = _mm_and_si128(vbb89ABCDEFx0123, vmask); + const __m128i vb0123x01 = _mm_and_si128(vbb0123x0123, vmask); $else: - const __m128i vb01234567x01 = _mm_slli_epi32(vbb01234567x0123, 4); - const __m128i vb89ABCDEFx01 = _mm_slli_epi32(vbb89ABCDEFx0123, 4); + const __m128i vb0123x01 = _mm_slli_epi32(vbb0123x0123, 4); $else: - const __m128i vb01234567x01 = _mm_load_si128(w); - const __m128i vb89ABCDEFx01 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16)); + const __m128i vb0123x01 = _mm_load_si128(w); $for M in range(MR): - vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x01234567, vb01234567x01); - vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x01234567, vb89ABCDEFx01); + vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x01234567, vb0123x01); $if PREFETCH: xnn_prefetch_to_l1((const int8_t*) w + 960); - w = (const int8_t*) w + 32; + w = (const int8_t*) w + 16; k -= 8 * sizeof(int8_t); } + $if DATATYPE in ["QC2_F32", "QC2_F16"]: + // Make sure there were no leftovers. 
+ assert(k == 0); $if MR < 3: $for M in range(MR): vacc${M}x01 = _mm_add_epi32(vacc${M}x01, vacc1x${M}x01); @@ -250,7 +344,21 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ $for M in range(MR): __m128 vout${M}x0123 = _mm_cvtepi32_ps(vacc${M}x0123); - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16"]: + $if DATATYPE in ["QC2_F32", "QC2_F16"]: + const __m128 rh_zero_points_0123 = _mm_load_ps((const float*) kzp); + kzp = (const float*)kzp + 4; + + // Subtract out the scaled left-hand row sums. + $for M in range(MR): + const __m128 lh_row_sum_${M} = _mm_set1_ps(row_sum[${M}]); + vout${M}x0123 = _mm_add_ps(_mm_mul_ps(rh_zero_points_0123, lh_row_sum_${M}), vout${M}x0123); + // Add the product of left/right-hand zero points and `kc`. + // TODO: use kc + $for M in range(MR): + const __m128 vscaled_lh_zero_point_${M} = _mm_set1_ps((float)original_kc * quantization_params[${M}].zero_point); + $for M in range(MR): + vout${M}x0123 = _mm_add_ps(_mm_mul_ps(rh_zero_points_0123, vscaled_lh_zero_point_${M}), vout${M}x0123); $for M in range(MR): vout${M}x0123 = _mm_mul_ps(vout${M}x0123, _mm_set1_ps(quantization_params[${M}].inv_scale)); @@ -267,7 +375,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ $for M in range(MR): vout${M}x0123 = _mm_min_ps(vout${M}x0123, voutput_max); - $if DATATYPE in ["QC4_F16", "QD8_F16"]: + $if DATATYPE in ["QC4_F16", "QC2_F16", "QD8_F16"]: $for M in range(MR): __m128i vfp16out${M}x0123 = _mm_cvtps_ph(vout${M}x0123, _MM_FROUND_TO_NEAREST_INT); if XNN_LIKELY(nc >= 4) { diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 99e14811c0e..20e047c0c72 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -4174,6 +4174,13 @@ DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_u 
DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_7x8c8__avxvnni) DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_8x8c8__avxvnni) +DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_1x4c8__ssse3_madd) +DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_2x4c8__ssse3_madd) +DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_3x4c8__ssse3_madd) +DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_4x4c8__ssse3_madd) +DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_5x4c8__ssse3_madd) +DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_6x4c8__ssse3_madd) + DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_1x1__scalar) DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_1x2__scalar) DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_1x4__scalar) @@ -5250,6 +5257,13 @@ DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_uker DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_3x4__scalar_fmagic) DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_4x4__scalar_fmagic) +DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x4c8__ssse3_madd) +DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_2x4c8__ssse3_madd) +DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_3x4c8__ssse3_madd) +DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_4x4c8__ssse3_madd) +DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_5x4c8__ssse3_madd) 
+DECLARE_QS8_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_6x4c8__ssse3_madd) + #define DECLARE_QS8_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t mr, size_t nc, size_t kc, const int8_t* a, size_t a_stride, \