From e5fc888ac2c0cd6ccfaf57c7f5cb45bf5afc54a8 Mon Sep 17 00:00:00 2001 From: Frank Barchard Date: Wed, 11 Mar 2026 17:58:39 -0700 Subject: [PATCH] [WIP] NEONDOT qd8_bf16_qc8w GEMM microkernels - add BF16 variant of qd8 gemm PiperOrigin-RevId: 882289963 --- bench/qd8-f16-qc2w-gemm.cc | 135 ++++ cmake/gen/neondotfp16arith_microkernels.cmake | 12 + gen/neondotfp16arith_microkernels.bzl | 12 + scripts/generate-qs8-gemm.sh | 16 + src/configs/gemm-config.c | 49 +- ...-bf16-qc8w-gemm-1x8c4-minmax-neondotbf16.c | 145 ++++ ...qb4w-gemm-1x16c4-minmax-neondotfp16arith.c | 2 +- ...-qb4w-gemm-1x8c4-minmax-neondotfp16arith.c | 2 +- ...qb4w-gemm-2x16c4-minmax-neondotfp16arith.c | 2 +- ...-qb4w-gemm-2x8c4-minmax-neondotfp16arith.c | 2 +- ...qb4w-gemm-3x16c4-minmax-neondotfp16arith.c | 2 +- ...-qb4w-gemm-3x8c4-minmax-neondotfp16arith.c | 2 +- ...qb4w-gemm-4x16c4-minmax-neondotfp16arith.c | 2 +- ...-qb4w-gemm-4x8c4-minmax-neondotfp16arith.c | 2 +- ...qb4w-gemm-5x16c4-minmax-neondotfp16arith.c | 2 +- ...-qb4w-gemm-5x8c4-minmax-neondotfp16arith.c | 2 +- ...qb4w-gemm-6x16c4-minmax-neondotfp16arith.c | 2 +- ...-qb4w-gemm-6x8c4-minmax-neondotfp16arith.c | 2 +- ...qc2w-gemm-1x16c4-minmax-neondotfp16arith.c | 283 +++++++ ...-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c | 209 +++++ ...qc2w-gemm-2x16c4-minmax-neondotfp16arith.c | 377 +++++++++ ...-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c | 267 +++++++ ...qc2w-gemm-3x16c4-minmax-neondotfp16arith.c | 472 +++++++++++ ...-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c | 326 ++++++++ ...qc2w-gemm-4x16c4-minmax-neondotfp16arith.c | 566 +++++++++++++ ...-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c | 384 +++++++++ ...qc2w-gemm-5x16c4-minmax-neondotfp16arith.c | 661 +++++++++++++++ ...-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c | 443 ++++++++++ ...qc2w-gemm-6x16c4-minmax-neondotfp16arith.c | 755 ++++++++++++++++++ ...-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c | 501 ++++++++++++ ...qc4w-gemm-1x16c4-minmax-neondotfp16arith.c | 2 +- 
...-qc4w-gemm-1x8c4-minmax-neondotfp16arith.c | 2 +- ...qc4w-gemm-2x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc4w-gemm-2x8c4-minmax-neondotfp16arith.c | 2 +- ...qc4w-gemm-3x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc4w-gemm-3x8c4-minmax-neondotfp16arith.c | 2 +- ...qc4w-gemm-4x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc4w-gemm-4x8c4-minmax-neondotfp16arith.c | 2 +- ...qc4w-gemm-5x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc4w-gemm-5x8c4-minmax-neondotfp16arith.c | 2 +- ...qc4w-gemm-6x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc4w-gemm-6x8c4-minmax-neondotfp16arith.c | 2 +- ...qc8w-gemm-1x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c | 2 +- ...qc8w-gemm-2x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc8w-gemm-2x8c4-minmax-neondotfp16arith.c | 2 +- ...qc8w-gemm-3x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc8w-gemm-3x8c4-minmax-neondotfp16arith.c | 2 +- ...qc8w-gemm-4x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc8w-gemm-4x8c4-minmax-neondotfp16arith.c | 2 +- ...qc8w-gemm-5x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc8w-gemm-5x8c4-minmax-neondotfp16arith.c | 2 +- ...qc8w-gemm-6x16c4-minmax-neondotfp16arith.c | 2 +- ...-qc8w-gemm-6x8c4-minmax-neondotfp16arith.c | 2 +- .../qd8-f32-qb4w-gemm-1x16c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-1x8c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-2x16c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-2x8c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-3x16c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-3x8c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-4x16c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-4x8c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-5x16c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-5x8c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-6x16c4-minmax-neondot.c | 1 + .../qd8-f32-qb4w-gemm-6x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-1x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-1x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-2x16c4-minmax-neondot.c | 1 + 
.../qd8-f32-qc2w-gemm-2x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-3x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-3x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-4x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-4x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-5x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-5x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-6x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-6x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-7x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-7x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-8x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc2w-gemm-8x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-1x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-2x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-2x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-3x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-4x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-4x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-6x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc4w-gemm-6x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-1x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-1x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-2x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-2x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-3x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-3x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-4x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-4x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-5x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-5x8c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-6x16c4-minmax-neondot.c | 1 + .../qd8-f32-qc8w-gemm-6x8c4-minmax-neondot.c | 1 + src/qs8-gemm/c4-neondot.c.in | 71 +- ...qs8-qc2w-gemm-1x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc2w-gemm-1x8c4-minmax-fp32-neondot.c | 1 + ...qs8-qc2w-gemm-4x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc2w-gemm-4x8c4-minmax-fp32-neondot.c | 1 + 
...qs8-qc2w-gemm-6x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc2w-gemm-6x8c4-minmax-fp32-neondot.c | 1 + ...qs8-qc2w-gemm-8x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc2w-gemm-8x8c4-minmax-fp32-neondot.c | 1 + ...qs8-qc8w-gemm-1x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc8w-gemm-1x8c4-minmax-fp32-neondot.c | 1 + ...qs8-qc8w-gemm-4x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc8w-gemm-4x8c4-minmax-fp32-neondot.c | 1 + ...qs8-qc8w-gemm-6x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc8w-gemm-6x8c4-minmax-fp32-neondot.c | 1 + ...qs8-qc8w-gemm-8x16c4-minmax-fp32-neondot.c | 1 + .../qs8-qc8w-gemm-8x8c4-minmax-fp32-neondot.c | 1 + src/xnnpack/gemm.h | 14 + test/qd8-f16-qc2w-gemm-minmax.cc | 231 ++++++ test/qd8-f16-qc2w-gemm-minmax.yaml | 63 ++ 123 files changed, 6050 insertions(+), 79 deletions(-) create mode 100644 src/qd8-bf16-qc8w-gemm/gen/qd8-bf16-qc8w-gemm-1x8c4-minmax-neondotbf16.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x16c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x16c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x16c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x16c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x16c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c create mode 100644 src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x16c4-minmax-neondotfp16arith.c create mode 100644 
src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c diff --git a/bench/qd8-f16-qc2w-gemm.cc b/bench/qd8-f16-qc2w-gemm.cc index 356d0e3247d..2fb71c69909 100644 --- a/bench/qd8-f16-qc2w-gemm.cc +++ b/bench/qd8-f16-qc2w-gemm.cc @@ -296,6 +296,141 @@ static void qd8_f16_qc2w_gemm_minmax_ukernel_4x4__scalar(benchmark::State& state BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_4x4__scalar) +#if XNN_ENABLE_ARM_DOTPROD && XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + static void qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_3x8c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_3x8c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_4x8c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + 
xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_4x8c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_5x8c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_5x8c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_1x16c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_1x16c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_2x16c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + 
/*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_2x16c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_3x16c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_3x16c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_4x16c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_4x16c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_5x16c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_5x16c4__neondotfp16arith) + + static void qd8_f16_qc2w_gemm_minmax_ukernel_6x16c4__neondotfp16arith(benchmark::State& state) { + GEMMBenchmark(state, + xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w, + /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*arch_flags=*/xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith); + } + + BENCHMARK_GEMM(qd8_f16_qc2w_gemm_minmax_ukernel_6x16c4__neondotfp16arith) +#endif // 
XNN_ENABLE_ARM_DOTPROD && XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + } // namespace #ifndef XNNPACK_BENCHMARK_NO_MAIN diff --git a/cmake/gen/neondotfp16arith_microkernels.cmake b/cmake/gen/neondotfp16arith_microkernels.cmake index 632ca91abee..bfba3f28a1c 100644 --- a/cmake/gen/neondotfp16arith_microkernels.cmake +++ b/cmake/gen/neondotfp16arith_microkernels.cmake @@ -12,6 +12,9 @@ SET(PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x16c4-minmax-neondotfp16arith.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x16c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x16c4-minmax-neondotfp16arith.c src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x16c4-minmax-neondotfp16arith.c src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c @@ -34,6 +37,15 @@ SET(NON_PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x16c4-minmax-neondotfp16arith.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x8c4-minmax-neondotfp16arith.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x16c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x16c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x16c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x16c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x16c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c + src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x16c4-minmax-neondotfp16arith.c + 
src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x16c4-minmax-neondotfp16arith.c src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c4-minmax-neondotfp16arith.c src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c4-minmax-neondotfp16arith.c src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x16c4-minmax-neondotfp16arith.c diff --git a/gen/neondotfp16arith_microkernels.bzl b/gen/neondotfp16arith_microkernels.bzl index efd2a2306d4..0fae7fd268b 100644 --- a/gen/neondotfp16arith_microkernels.bzl +++ b/gen/neondotfp16arith_microkernels.bzl @@ -8,6 +8,9 @@ PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS = [ "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x16c4-minmax-neondotfp16arith.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x16c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c", "src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x16c4-minmax-neondotfp16arith.c", "src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x16c4-minmax-neondotfp16arith.c", "src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c", @@ -31,6 +34,15 @@ NON_PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS = [ "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x16c4-minmax-neondotfp16arith.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x8c4-minmax-neondotfp16arith.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x16c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x16c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x16c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x16c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c", + 
"src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x16c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x16c4-minmax-neondotfp16arith.c", + "src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x16c4-minmax-neondotfp16arith.c", "src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c4-minmax-neondotfp16arith.c", "src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c4-minmax-neondotfp16arith.c", "src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x16c4-minmax-neondotfp16arith.c", diff --git a/scripts/generate-qs8-gemm.sh b/scripts/generate-qs8-gemm.sh index aabf1cb3082..0174a49420f 100755 --- a/scripts/generate-qs8-gemm.sh +++ b/scripts/generate-qs8-gemm.sh @@ -610,6 +610,8 @@ tools/xngen src/qs8-gemm/c8-neon-mull.c.in -D MR=1 -D NR=8 -D MLA=1 -D REQUANTI tools/xngen src/qs8-gemm/c8-neon-mull.c.in -D MR=2 -D NR=8 -D MLA=1 -D REQUANTIZATION=FP32 -D DATATYPE=QC8 -D ARMV8=1 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-neonv8-mlal.c & ### C4 micro-kernels +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=1 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QD8_BF16 -o src/qd8-bf16-qc8w-gemm/gen/qd8-bf16-qc8w-gemm-1x8c4-minmax-neondotbf16.c & + tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=1 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QD8_F16 -o src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c & tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=2 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QD8_F16 -o src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c4-minmax-neondotfp16arith.c & tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=3 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QD8_F16 -o src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c4-minmax-neondotfp16arith.c & @@ -694,6 +696,20 @@ tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=4 -D NR=16 -D REQUANTIZATION= -D tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=5 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC4_F16 -o 
src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x16c4-minmax-neondotfp16arith.c & tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=6 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC4_F16 -o src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x16c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=1 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=2 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=3 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=4 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=5 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=6 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c & + +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=1 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x16c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=2 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x16c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=3 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x16c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=4 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o 
src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x16c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=5 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x16c4-minmax-neondotfp16arith.c & +tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=6 -D NR=16 -D REQUANTIZATION= -D DATATYPE=QC2_F16 -o src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x16c4-minmax-neondotfp16arith.c & + tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=1 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QB4_F16 -o src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8c4-minmax-neondotfp16arith.c & tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=2 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QB4_F16 -o src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8c4-minmax-neondotfp16arith.c & tools/xngen src/qs8-gemm/c4-neondot.c.in -D MR=3 -D NR=8 -D REQUANTIZATION= -D DATATYPE=QB4_F16 -o src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x8c4-minmax-neondotfp16arith.c & diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 1c9fada56b7..52aa2f43331 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -2319,10 +2319,36 @@ static void init_qd8_f16_qc2w_gemm_config(void) { qd8_f16_qc2w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qd8_qc2w_gemm_goi_w; // Ignored - qd8_f16_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = - XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x2__scalar); - qd8_f16_qc2w_gemm_config.mr = 1; - qd8_f16_qc2w_gemm_config.nr = 2; + #if XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
+ #if XNN_ENABLE_ARM_DOTPROD + if (hardware_config->arch_flags & xnn_arch_arm_neon_dot) { + qd8_f16_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith); + #if XNN_ARCH_ARM64 + qd8_f16_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(6)] = XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith); + qd8_f16_qc2w_gemm_config.mr = 6; + #else + qd8_f16_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(2)] = XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith); + qd8_f16_qc2w_gemm_config.mr = 2; + #endif + qd8_f16_qc2w_gemm_config.nr = 8; + qd8_f16_qc2w_gemm_config.log2_kr = 2; + } else + #endif // XNN_ENABLE_ARM_DOTPROD + { + qd8_f16_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = + XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x2__scalar); + qd8_f16_qc2w_gemm_config.mr = 1; + qd8_f16_qc2w_gemm_config.nr = 2; + } + #else + qd8_f16_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = + XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x2__scalar); + qd8_f16_qc2w_gemm_config.mr = 1; + qd8_f16_qc2w_gemm_config.nr = 2; + #endif assert(qd8_f16_qc2w_gemm_config.mr <= XNN_MAX_MR); assert(qd8_f16_qc2w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); @@ -2354,7 +2380,7 @@ static void init_qdu8_f16_qc2w_gemm_config(void) { qdu8_f16_qc2w_gemm_config.nr = 8; qdu8_f16_qc2w_gemm_config.log2_kr = 3; qdu8_f16_qc2w_gemm_config.planes = 4; - } + } else #endif #if XNN_ENABLE_AVX2 if (hardware_config->arch_flags & xnn_arch_x86_avx2) { @@ -2370,8 +2396,10 @@ static void init_qdu8_f16_qc2w_gemm_config(void) { qdu8_f16_qc2w_gemm_config.nr = 8; qdu8_f16_qc2w_gemm_config.log2_kr = 3; qdu8_f16_qc2w_gemm_config.planes = 4; - } + } else #endif + { + } #endif //XNN_ARCH_X86 || XNN_ARCH_X86_64 assert(qdu8_f16_qc2w_gemm_config.mr <= XNN_MAX_MR); assert(qdu8_f16_qc2w_gemm_config.mr <= 
(XNN_EXTRA_QUANTIZATION_PARAMS + 1)); @@ -2482,11 +2510,11 @@ static void init_qd8_f32_qc2w_gemm_config(void) { qd8_f32_qc2w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qd8_qc2w_gemm_goi_w; // Ignored #if XNN_ARCH_ARM || XNN_ARCH_ARM64 - #if XNN_ENABLE_ARM_DOTPROD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); (void) hardware_config; // May be unused. + #if XNN_ENABLE_ARM_DOTPROD if (hardware_config->arch_flags & xnn_arch_arm_neon_dot) { qd8_f32_qc2w_gemm_config.arch = xnn_arch_arm_neon_dot; qd8_f32_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = @@ -2496,6 +2524,7 @@ static void init_qd8_f32_qc2w_gemm_config(void) { XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_8x8c4__neondot); qd8_f32_qc2w_gemm_config.mr = 8; #else + // TODO: fix sdot lane in clang qd8_f32_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(2)] = XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qd8_f32_qc2w_gemm_minmax_ukernel_2x8c4__neondot); qd8_f32_qc2w_gemm_config.mr = 2; @@ -2546,7 +2575,7 @@ static void init_qdu8_f32_qc2w_gemm_config(void) { qdu8_f32_qc2w_gemm_config.nr = 8; qdu8_f32_qc2w_gemm_config.log2_kr = 3; qdu8_f32_qc2w_gemm_config.planes = 4; - } + } else #endif #if XNN_ENABLE_AVX2 if (hardware_config->arch_flags & xnn_arch_x86_avx2) { @@ -2562,8 +2591,10 @@ static void init_qdu8_f32_qc2w_gemm_config(void) { qdu8_f32_qc2w_gemm_config.nr = 8; qdu8_f32_qc2w_gemm_config.log2_kr = 3; qdu8_f32_qc2w_gemm_config.planes = 4; - } + } else #endif + { + } #endif //XNN_ARCH_X86 || XNN_ARCH_X86_64 assert(qdu8_f32_qc2w_gemm_config.mr <= XNN_MAX_MR); assert(qdu8_f32_qc2w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); diff --git a/src/qd8-bf16-qc8w-gemm/gen/qd8-bf16-qc8w-gemm-1x8c4-minmax-neondotbf16.c b/src/qd8-bf16-qc8w-gemm/gen/qd8-bf16-qc8w-gemm-1x8c4-minmax-neondotbf16.c new file mode 100644 index 00000000000..325e3046393 --- /dev/null +++ 
b/src/qd8-bf16-qc8w-gemm/gen/qd8-bf16-qc8w-gemm-1x8c4-minmax-neondotbf16.c @@ -0,0 +1,145 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_bf16_qc8w_gemm_minmax_ukernel_1x8c4__neondotbf16( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_bfloat16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_bf16_minmax_params* restrict params, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + + // Loop over groups of 8 columns. + do { + // Initialize accumulators with bias. 8 bias values are loaded from the + // weight matrix, at the start of the group of 8 columns. + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + + // Inner accumulation loop along the 8 columns. + size_t k = kc; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 1x8 block of activations. 
+ const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + + // Load a 8x8 block of weights. + const int8x16_t vb0123x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb0123x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + + // Multiply-accumulate: 1x8 * 8x8 --> 1x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + + k -= 8 * sizeof(int8_t); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 1x4 block of activations. + const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + + // Load a 4x8 block of weights. + const int8x16_t vb0123x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb0123x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + + // Multiply-accumulate: 1x4 * 4x8 --> 1x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + const float32x4_t vinput_scale0 = vld1q_dup_f32(&quantization_params[0].inv_scale); + vout0x0123 = vmulq_f32(vout0x0123, vinput_scale0); + vout0x4567 = vmulq_f32(vout0x4567, vinput_scale0); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_bf16_f32(vout0x0123), vcvt_bf16_f32(vout0x4567)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); + vfp16out0x01234567 = vmaxq_bf16(vfp16out0x01234567, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); + vfp16out0x01234567 = vminq_bf16(vfp16out0x01234567, voutput_max); + if XNN_LIKELY(nc >= 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + + nc -= 8; + } else { + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vfp16out0x0123 = 
vget_high_f16(vfp16out0x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x16c4-minmax-neondotfp16arith.c index 3771e197a7c..fcd05de7733 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -152,7 +153,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x16c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8c4-minmax-neondotfp16arith.c index b1d1c7a1e63..65110398574 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -119,7 +120,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8c4__neondotfp16arith( vout0x4567 = 
vaddq_f32(vbias4567, vout0x4567); float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x16c4-minmax-neondotfp16arith.c index 3f963e3a29b..59376303025 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -198,7 +199,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x16c4__neondotfp16arith( float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8c4-minmax-neondotfp16arith.c index 90181ff08fc..69eb0a5bb9b 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-2x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void 
xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -146,7 +147,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x16c4-minmax-neondotfp16arith.c index f968e3dcb84..df6b2169d83 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -246,7 +247,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x16c4__neondotfp16arith( float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x8c4-minmax-neondotfp16arith.c index 4481f299106..c0e6d12cf6a 100644 --- 
a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -175,7 +176,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x8c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x16c4-minmax-neondotfp16arith.c index 354bc08c708..2591f54a667 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -292,7 +293,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x16c4__neondotfp16arith( float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, 
voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x8c4-minmax-neondotfp16arith.c index 5c2ea91f093..ead90e617b3 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-4x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -202,7 +203,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x8c4__neondotfp16arith( float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x16c4-minmax-neondotfp16arith.c index 668021735f1..41cb091a4ef 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -340,7 +341,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x16c4__neondotfp16arith( float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), 
vcvt_f16_f32(vout4x4567)); float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x8c4-minmax-neondotfp16arith.c index a744fec8cf3..ac53b9c4ac6 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-5x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -231,7 +232,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x8c4__neondotfp16arith( float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x16c4-minmax-neondotfp16arith.c index 83ccf589abb..a9ece0f66f8 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x16c4__neondotfp16arith( 
size_t mr, size_t nc, @@ -386,7 +387,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x16c4__neondotfp16arith( float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF)); float16x8_t vfp16out5x01234567 = vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); float16x8_t vfp16out5x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout5x89AB), vcvt_f16_f32(vout5xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x8c4-minmax-neondotfp16arith.c index e80fa03746e..3074b08af06 100644 --- a/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-6x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -258,7 +259,6 @@ void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x8c4__neondotfp16arith( float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); float16x8_t vfp16out5x01234567 = vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x16c4-minmax-neondotfp16arith.c new file mode 100644 index 
00000000000..5086115d926 --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x16c4-minmax-neondotfp16arith.c @@ -0,0 +1,283 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x16c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 16 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. 
+ const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum89AB = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksumCDEF = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + int32x4_t vacc0x89AB = vmulq_s32(vksum89AB, vinput_zero_point0); + int32x4_t vacc0xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point0); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 16; + + // Inner accumulation loop along the 16 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 1x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + + // Load a 16x16 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb89ABx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vbCDEFx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + const int8x16_t vb89ABx0123 = vandq_s8(vb89ABx16, vmask); + const int8x16_t vbCDEFx0123 = vandq_s8(vbCDEFx16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + const int8x16_t vb89ABx4567 = vandq_s8(vshrq_n_s8(vb89ABx16, 2), vmask); + const int8x16_t vbCDEFx4567 = vandq_s8(vshrq_n_s8(vbCDEFx16, 2), vmask); + // Third crumb. 
+ const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + const int8x16_t vb89ABx89AB = vandq_s8(vshrq_n_s8(vb89ABx16, 4), vmask); + const int8x16_t vbCDEFx89AB = vandq_s8(vshrq_n_s8(vbCDEFx16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + const int8x16_t vb89ABxCDEF = vandq_s8(vshrq_n_s8(vb89ABx16, 6), vmask); + const int8x16_t vbCDEFxCDEF = vandq_s8(vshrq_n_s8(vbCDEFx16, 6), vmask); + + // Multiply-accumulate: 1x16 * 16x16 --> 1x16. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx0123, vget_low_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx4567, vget_low_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx89AB, vget_high_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABxCDEF, vget_high_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFxCDEF, vget_high_s8(va_0x16), 1); + + k -= 16 * 
sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x89AB = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567xCDEF = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 1x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + + // Load a 8x16 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb01234567x89AB, 2), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb01234567xCDEF, 2), vmask); + + // Multiply-accumulate: 1x8 * 8x16 --> 1x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x01234567, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb4567x89AB, va0x01234567, 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + vb01234567x89AB = vshrq_n_s8(vb01234567x89AB, 4); + vb01234567xCDEF = vshrq_n_s8(vb01234567xCDEF, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 1x4 block of activations. + const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + + // Multiply-accumulate: 1x4 * 4x16 --> 1x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x0123, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout0x89AB = vcvtq_f32_s32(vacc0x89AB); + float32x4_t vout0xCDEF = vcvtq_f32_s32(vacc0xCDEF); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + const float32x4_t kernel_zero_points_89AB = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_89AB = vaddq_f32(kernel_zero_points_89AB, vtwo); + const float32x4_t kernel_zero_points_CDEF = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_CDEF = vaddq_f32(kernel_zero_points_CDEF, vtwo); + + // Subtract out the scaled left-hand row sums. + const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + vout0x89AB = vfmsq_f32(vout0x89AB, biased_kernel_zero_points_89AB, lh_row_sum_0); + vout0xCDEF = vfmsq_f32(vout0xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_0); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout0x89AB = + vmlaq_f32(vout0x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_0); + vout0xCDEF = + vmlaq_f32(vout0xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_0); + const float32x4_t vinput_scale0 = vld1q_dup_f32(&quantization_params[0].inv_scale); + vout0x0123 = vmulq_f32(vout0x0123, vinput_scale0); + vout0x4567 = vmulq_f32(vout0x4567, vinput_scale0); + vout0x89AB = vmulq_f32(vout0x89AB, vinput_scale0); + vout0xCDEF = vmulq_f32(vout0xCDEF, vinput_scale0); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale89AB = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scaleCDEF = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + #endif + const float32x4_t vbias89AB = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x89AB = vfmaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + #else + vout0x89AB = vmlaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + #endif + const float32x4_t vbiasCDEF = vld1q_f32(w); w = (const float*) w + 4; + #if 
XNN_ARCH_ARM64 + vout0xCDEF = vfmaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + #else + vout0xCDEF = vmlaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out0x89ABCDEF = vminq_f16(vfp16out0x89ABCDEF, voutput_max); + if XNN_LIKELY(nc >= 16) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c0 + 8, vreinterpretq_u16_f16(vfp16out0x89ABCDEF)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); c0 += 8; + vfp16out0x01234567 = vfp16out0x89ABCDEF; + } + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..526a44fd644 --- /dev/null +++ 
b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-1x8c4-minmax-neondotfp16arith.c @@ -0,0 +1,209 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 8 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. + const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 8; + + // Inner accumulation loop along the 8 columns. 
+ size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 1x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + + // Load a 16x8 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + + // Multiply-accumulate: 1x16 * 16x8 --> 1x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. 
+ if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 1x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + + // Load a 8x8 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + + // Multiply-accumulate: 1x8 * 8x8 --> 1x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 1x4 block of activations. + const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + + // Multiply-accumulate: 1x4 * 4x8 --> 1x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + + // Subtract out the scaled left-hand row sums. + const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + const float32x4_t vinput_scale0 = vld1q_dup_f32(&quantization_params[0].inv_scale); + vout0x0123 = vmulq_f32(vout0x0123, vinput_scale0); + vout0x4567 = vmulq_f32(vout0x4567, vinput_scale0); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + if XNN_LIKELY(nc >= 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + + nc -= 8; + } else { + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + 
vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x16c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..e04b94c1a36 --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x16c4-minmax-neondotfp16arith.c @@ -0,0 +1,377 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x16c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if 
XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 16 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. + const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum89AB = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksumCDEF = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + int32x4_t vacc0x89AB = vmulq_s32(vksum89AB, vinput_zero_point0); + int32x4_t vacc0xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + int32x4_t vacc1x89AB = vmulq_s32(vksum89AB, vinput_zero_point1); + int32x4_t vacc1xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point1); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 16; + + // Inner accumulation loop along the 16 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 2x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + + // Load a 16x16 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb89ABx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vbCDEFx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. 
+ const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + const int8x16_t vb89ABx0123 = vandq_s8(vb89ABx16, vmask); + const int8x16_t vbCDEFx0123 = vandq_s8(vbCDEFx16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + const int8x16_t vb89ABx4567 = vandq_s8(vshrq_n_s8(vb89ABx16, 2), vmask); + const int8x16_t vbCDEFx4567 = vandq_s8(vshrq_n_s8(vbCDEFx16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + const int8x16_t vb89ABx89AB = vandq_s8(vshrq_n_s8(vb89ABx16, 4), vmask); + const int8x16_t vbCDEFx89AB = vandq_s8(vshrq_n_s8(vbCDEFx16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + const int8x16_t vb89ABxCDEF = vandq_s8(vshrq_n_s8(vb89ABx16, 6), vmask); + const int8x16_t vbCDEFxCDEF = vandq_s8(vshrq_n_s8(vbCDEFx16, 6), vmask); + + // Multiply-accumulate: 2x16 * 16x16 --> 2x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx0123, vget_low_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx4567, vget_low_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx89AB, vget_high_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABxCDEF, vget_high_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFxCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx0123, vget_low_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx4567, vget_low_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = 
vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx89AB, vget_high_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABxCDEF, vget_high_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFxCDEF, vget_high_s8(va_1x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x89AB = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567xCDEF = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 2x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + + // Load a 8x16 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb01234567x89AB, 2), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb01234567xCDEF, 2), vmask); + + // Multiply-accumulate: 2x8 * 8x16 --> 2x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x01234567, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x01234567, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb4567x89AB, va0x01234567, 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb4567x89AB, va1x01234567, 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb4567xCDEF, va1x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + vb01234567x89AB = vshrq_n_s8(vb01234567x89AB, 4); + vb01234567xCDEF = vshrq_n_s8(vb01234567xCDEF, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 2x4 block of activations. 
+ const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + + // Multiply-accumulate: 2x4 * 4x16 --> 2x16. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x0123, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x0123, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout0x89AB = vcvtq_f32_s32(vacc0x89AB); + float32x4_t vout0xCDEF = vcvtq_f32_s32(vacc0xCDEF); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + float32x4_t vout1x89AB = vcvtq_f32_s32(vacc1x89AB); + float32x4_t vout1xCDEF = vcvtq_f32_s32(vacc1xCDEF); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + const float32x4_t kernel_zero_points_89AB = vld1q_f32(kzp); kzp = (const float*)kzp + 
4; + const float32x4_t biased_kernel_zero_points_89AB = vaddq_f32(kernel_zero_points_89AB, vtwo); + const float32x4_t kernel_zero_points_CDEF = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_CDEF = vaddq_f32(kernel_zero_points_CDEF, vtwo); + + // Subtract out the scaled left-hand row sums. + const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + vout0x89AB = vfmsq_f32(vout0x89AB, biased_kernel_zero_points_89AB, lh_row_sum_0); + vout0xCDEF = vfmsq_f32(vout0xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + vout1x89AB = vfmsq_f32(vout1x89AB, biased_kernel_zero_points_89AB, lh_row_sum_1); + vout1xCDEF = vfmsq_f32(vout1xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_1); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout0x89AB = + vmlaq_f32(vout0x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_0); + vout0xCDEF = + vmlaq_f32(vout0xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + vout1x89AB = + vmlaq_f32(vout1x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_1); + vout1xCDEF = + vmlaq_f32(vout1xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_1); + const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + vout0x89AB = vmulq_lane_f32(vout0x89AB, vget_low_f32(vinput_scale01), 1); + vout1x89AB = vmulq_lane_f32(vout1x89AB, vget_high_f32(vinput_scale01), 1); + vout0xCDEF = vmulq_lane_f32(vout0xCDEF, vget_low_f32(vinput_scale01), 1); + vout1xCDEF = vmulq_lane_f32(vout1xCDEF, vget_high_f32(vinput_scale01), 1); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale89AB = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scaleCDEF = 
vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + #endif + const float32x4_t vbias89AB = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x89AB = vfmaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vfmaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + #else + vout0x89AB = vmlaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vmlaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + #endif + const float32x4_t vbiasCDEF = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0xCDEF = vfmaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vfmaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + #else + vout0xCDEF = vmlaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vmlaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + 
float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); + vfp16out1x89ABCDEF = vmaxq_f16(vfp16out1x89ABCDEF, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out0x89ABCDEF = vminq_f16(vfp16out0x89ABCDEF, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + vfp16out1x89ABCDEF = vminq_f16(vfp16out1x89ABCDEF, voutput_max); + if XNN_LIKELY(nc >= 16) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c0 + 8, vreinterpretq_u16_f16(vfp16out0x89ABCDEF)); + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); + vst1q_u16(c1 + 8, vreinterpretq_u16_f16(vfp16out1x89ABCDEF)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); c0 += 8; + vfp16out0x01234567 = vfp16out0x89ABCDEF; + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); c1 += 8; + vfp16out1x01234567 = vfp16out1x89ABCDEF; + } + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + } + if (nc & 2) { + 
vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..5bcbe410bcf --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-2x8c4-minmax-neondotfp16arith.c @@ -0,0 +1,267 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 8 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. + const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 8; + + // Inner accumulation loop along the 8 columns. 
+ size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 2x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + + // Load a 16x8 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + + // Multiply-accumulate: 2x16 * 16x8 --> 2x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 2x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + + // Load a 8x8 block of weights. 
+ const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + + // Multiply-accumulate: 2x8 * 8x8 --> 2x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 2x4 block of activations. + const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + + // Multiply-accumulate: 2x4 * 4x8 --> 2x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + + // Subtract out the scaled left-hand row sums. + const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + 
#endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + if XNN_LIKELY(nc >= 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + + nc -= 8; + } else { + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git 
a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x16c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..22fb1c1e43d --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x16c4-minmax-neondotfp16arith.c @@ -0,0 +1,472 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x16c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 16 columns. 
+ do { + // Initialize the bias with the scaled left-hand weight sums. + const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum89AB = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksumCDEF = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + int32x4_t vacc0x89AB = vmulq_s32(vksum89AB, vinput_zero_point0); + int32x4_t vacc0xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + int32x4_t vacc1x89AB = vmulq_s32(vksum89AB, vinput_zero_point1); + int32x4_t vacc1xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point1); + const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point); + int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2); + int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2); + int32x4_t vacc2x89AB = vmulq_s32(vksum89AB, vinput_zero_point2); + int32x4_t vacc2xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point2); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 16; + + // Inner accumulation loop along the 16 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 3x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16; + + // Load a 16x16 block of weights. 
+ const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb89ABx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vbCDEFx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + const int8x16_t vb89ABx0123 = vandq_s8(vb89ABx16, vmask); + const int8x16_t vbCDEFx0123 = vandq_s8(vbCDEFx16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + const int8x16_t vb89ABx4567 = vandq_s8(vshrq_n_s8(vb89ABx16, 2), vmask); + const int8x16_t vbCDEFx4567 = vandq_s8(vshrq_n_s8(vbCDEFx16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + const int8x16_t vb89ABx89AB = vandq_s8(vshrq_n_s8(vb89ABx16, 4), vmask); + const int8x16_t vbCDEFx89AB = vandq_s8(vshrq_n_s8(vbCDEFx16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + const int8x16_t vb89ABxCDEF = vandq_s8(vshrq_n_s8(vb89ABx16, 6), vmask); + const int8x16_t vbCDEFxCDEF = vandq_s8(vshrq_n_s8(vbCDEFx16, 6), vmask); + + // Multiply-accumulate: 3x16 * 16x16 --> 3x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx0123, vget_low_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx4567, vget_low_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx89AB, vget_high_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABxCDEF, vget_high_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFxCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx0123, vget_low_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx4567, vget_low_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = 
vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx89AB, vget_high_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABxCDEF, vget_high_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFxCDEF, vget_high_s8(va_1x16), 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx0123, vget_low_s8(va_2x16), 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx0123, vget_low_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx4567, vget_low_s8(va_2x16), 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx4567, vget_low_s8(va_2x16), 1); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x89AB, vget_high_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, vget_high_s8(va_2x16), 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx89AB, vget_high_s8(va_2x16), 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx89AB, vget_high_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABxCDEF, vget_high_s8(va_2x16), 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFxCDEF, vget_high_s8(va_2x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle 
up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x89AB = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567xCDEF = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 3x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8; + + // Load a 8x16 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb01234567x89AB, 2), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb01234567xCDEF, 2), vmask); + + // Multiply-accumulate: 3x8 * 8x16 --> 3x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x01234567, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x01234567, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x01234567, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x01234567, 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb4567x89AB, va0x01234567, 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb4567x89AB, va1x01234567, 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb4567xCDEF, va1x01234567, 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb4567x89AB, va2x01234567, 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb4567xCDEF, va2x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + vb01234567x89AB = vshrq_n_s8(vb01234567x89AB, 4); + vb01234567xCDEF = vshrq_n_s8(vb01234567xCDEF, 4); + } + // 
Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 3x4 block of activations. + const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + + // Multiply-accumulate: 3x4 * 4x16 --> 3x16. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x0123, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x0123, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x0123, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x0123, 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout0x89AB = vcvtq_f32_s32(vacc0x89AB); + float32x4_t vout0xCDEF = vcvtq_f32_s32(vacc0xCDEF); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + float32x4_t vout1x89AB = vcvtq_f32_s32(vacc1x89AB); + float32x4_t vout1xCDEF = vcvtq_f32_s32(vacc1xCDEF); + float32x4_t vout2x0123 = 
vcvtq_f32_s32(vacc2x0123); + float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567); + float32x4_t vout2x89AB = vcvtq_f32_s32(vacc2x89AB); + float32x4_t vout2xCDEF = vcvtq_f32_s32(vacc2xCDEF); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + const float32x4_t kernel_zero_points_89AB = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_89AB = vaddq_f32(kernel_zero_points_89AB, vtwo); + const float32x4_t kernel_zero_points_CDEF = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_CDEF = vaddq_f32(kernel_zero_points_CDEF, vtwo); + + // Subtract out the scaled left-hand row sums. 
+ const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + vout0x89AB = vfmsq_f32(vout0x89AB, biased_kernel_zero_points_89AB, lh_row_sum_0); + vout0xCDEF = vfmsq_f32(vout0xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + vout1x89AB = vfmsq_f32(vout1x89AB, biased_kernel_zero_points_89AB, lh_row_sum_1); + vout1xCDEF = vfmsq_f32(vout1xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_1); + const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]); + vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2); + vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2); + vout2x89AB = vfmsq_f32(vout2x89AB, biased_kernel_zero_points_89AB, lh_row_sum_2); + vout2xCDEF = vfmsq_f32(vout2xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_2); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + const float32x4_t vscaled_input_zero_point_2 = + vdupq_n_f32((float)kc * quantization_params[2].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout0x89AB = + vmlaq_f32(vout0x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_0); + vout0xCDEF = + vmlaq_f32(vout0xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + vout1x89AB = + vmlaq_f32(vout1x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_1); + vout1xCDEF = + vmlaq_f32(vout1xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_1); + vout2x0123 = + vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2); + vout2x4567 = + vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2); + vout2x89AB = + vmlaq_f32(vout2x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_2); + vout2xCDEF = + vmlaq_f32(vout2xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_2); + const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + vout0x89AB = vmulq_lane_f32(vout0x89AB, vget_low_f32(vinput_scale01), 1); + vout1x89AB = vmulq_lane_f32(vout1x89AB, 
vget_high_f32(vinput_scale01), 1); + vout0xCDEF = vmulq_lane_f32(vout0xCDEF, vget_low_f32(vinput_scale01), 1); + vout1xCDEF = vmulq_lane_f32(vout1xCDEF, vget_high_f32(vinput_scale01), 1); + const float32x4_t vinput_scale2 = vld1q_dup_f32(&quantization_params[2].inv_scale); + vout2x0123 = vmulq_f32(vout2x0123, vinput_scale2); + vout2x4567 = vmulq_f32(vout2x4567, vinput_scale2); + vout2x89AB = vmulq_f32(vout2x89AB, vinput_scale2); + vout2xCDEF = vmulq_f32(vout2xCDEF, vinput_scale2); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale89AB = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scaleCDEF = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + #endif + const float32x4_t vbias89AB = 
vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x89AB = vfmaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vfmaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + vout2x89AB = vfmaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB); + #else + vout0x89AB = vmlaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vmlaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + vout2x89AB = vmlaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB); + #endif + const float32x4_t vbiasCDEF = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0xCDEF = vfmaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vfmaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + vout2xCDEF = vfmaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF); + #else + vout0xCDEF = vmlaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vmlaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + vout2xCDEF = vmlaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); + float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); + float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, 
voutput_min); + vfp16out1x89ABCDEF = vmaxq_f16(vfp16out1x89ABCDEF, voutput_min); + vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min); + vfp16out2x89ABCDEF = vmaxq_f16(vfp16out2x89ABCDEF, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out0x89ABCDEF = vminq_f16(vfp16out0x89ABCDEF, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + vfp16out1x89ABCDEF = vminq_f16(vfp16out1x89ABCDEF, voutput_max); + vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max); + vfp16out2x89ABCDEF = vminq_f16(vfp16out2x89ABCDEF, voutput_max); + if XNN_LIKELY(nc >= 16) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c0 + 8, vreinterpretq_u16_f16(vfp16out0x89ABCDEF)); + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); + vst1q_u16(c1 + 8, vreinterpretq_u16_f16(vfp16out1x89ABCDEF)); + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); + vst1q_u16(c2 + 8, vreinterpretq_u16_f16(vfp16out2x89ABCDEF)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); c0 += 8; + vfp16out0x01234567 = vfp16out0x89ABCDEF; + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); c1 += 8; + vfp16out1x01234567 = vfp16out1x89ABCDEF; + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); c2 += 8; + vfp16out2x01234567 = vfp16out2x89ABCDEF; + } + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567); + if (nc & 
4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + vfp16out2x0123 = vget_high_f16(vfp16out2x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..9124c0a2680 --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-3x8c4-minmax-neondotfp16arith.c @@ -0,0 +1,326 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x8c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 8 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. 
+ const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point); + int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2); + int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 8; + + // Inner accumulation loop along the 8 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 3x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16; + + // Load a 16x8 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + // Fourth crumb. 
+ const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + + // Multiply-accumulate: 3x16 * 16x8 --> 3x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x89AB, vget_high_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, 
vget_high_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 3x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8; + + // Load a 8x8 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + + // Multiply-accumulate: 3x8 * 8x8 --> 3x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 3x4 block of activations. + const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + + // Multiply-accumulate: 3x4 * 4x8 --> 3x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + float32x4_t vout2x0123 = vcvtq_f32_s32(vacc2x0123); + float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + + // Subtract out the scaled left-hand row sums. + const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]); + vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2); + vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + const float32x4_t vscaled_input_zero_point_2 = + vdupq_n_f32((float)kc * quantization_params[2].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + vout2x0123 = + vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2); + vout2x4567 = + vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2); + const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + const float32x4_t vinput_scale2 = vld1q_dup_f32(&quantization_params[2].inv_scale); + vout2x0123 = vmulq_f32(vout2x0123, vinput_scale2); + vout2x4567 = vmulq_f32(vout2x4567, vinput_scale2); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + #else 
+ vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); + vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max); + if XNN_LIKELY(nc >= 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) 
((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + + nc -= 8; + } else { + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + vfp16out2x0123 = vget_high_f16(vfp16out2x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x16c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..ef1a1bf17bb --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x16c4-minmax-neondotfp16arith.c @@ -0,0 +1,566 @@ +// clang-format off +// Auto-generated file. Do not edit! 
+// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x16c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 16 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. 
+ const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum89AB = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksumCDEF = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + int32x4_t vacc0x89AB = vmulq_s32(vksum89AB, vinput_zero_point0); + int32x4_t vacc0xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + int32x4_t vacc1x89AB = vmulq_s32(vksum89AB, vinput_zero_point1); + int32x4_t vacc1xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point1); + const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point); + int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2); + int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2); + int32x4_t vacc2x89AB = vmulq_s32(vksum89AB, vinput_zero_point2); + int32x4_t vacc2xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point2); + const int32x4_t vinput_zero_point3 = vld1q_dup_s32(&quantization_params[3].zero_point); + int32x4_t vacc3x0123 = vmulq_s32(vksum0123, vinput_zero_point3); + int32x4_t vacc3x4567 = vmulq_s32(vksum4567, vinput_zero_point3); + int32x4_t vacc3x89AB = vmulq_s32(vksum89AB, vinput_zero_point3); + int32x4_t vacc3xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point3); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 16; + + // Inner accumulation loop along the 16 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. 
+ while (k >= 16 * sizeof(int8_t)) { + // Load a 4x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16; + const int8x16_t va_3x16 = vld1q_s8(a3); a3 += 16; + + // Load a 16x16 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb89ABx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vbCDEFx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + const int8x16_t vb89ABx0123 = vandq_s8(vb89ABx16, vmask); + const int8x16_t vbCDEFx0123 = vandq_s8(vbCDEFx16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + const int8x16_t vb89ABx4567 = vandq_s8(vshrq_n_s8(vb89ABx16, 2), vmask); + const int8x16_t vbCDEFx4567 = vandq_s8(vshrq_n_s8(vbCDEFx16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + const int8x16_t vb89ABx89AB = vandq_s8(vshrq_n_s8(vb89ABx16, 4), vmask); + const int8x16_t vbCDEFx89AB = vandq_s8(vshrq_n_s8(vbCDEFx16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + const int8x16_t vb89ABxCDEF = vandq_s8(vshrq_n_s8(vb89ABx16, 6), vmask); + const int8x16_t vbCDEFxCDEF = vandq_s8(vshrq_n_s8(vbCDEFx16, 6), vmask); + + // Multiply-accumulate: 4x16 * 16x16 --> 4x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx0123, vget_low_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx4567, vget_low_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx89AB, vget_high_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABxCDEF, vget_high_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFxCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx0123, vget_low_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx4567, vget_low_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = 
vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx89AB, vget_high_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABxCDEF, vget_high_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFxCDEF, vget_high_s8(va_1x16), 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx0123, vget_low_s8(va_2x16), 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx0123, vget_low_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx4567, vget_low_s8(va_2x16), 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx4567, vget_low_s8(va_2x16), 1); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x89AB, vget_high_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, vget_high_s8(va_2x16), 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx89AB, vget_high_s8(va_2x16), 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx89AB, vget_high_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABxCDEF, vget_high_s8(va_2x16), 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFxCDEF, vget_high_s8(va_2x16), 1); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, 
vb0123x0123, vget_low_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x0123, vget_low_s8(va_3x16), 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx0123, vget_low_s8(va_3x16), 0); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx0123, vget_low_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x4567, vget_low_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, vget_low_s8(va_3x16), 1); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx4567, vget_low_s8(va_3x16), 1); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx4567, vget_low_s8(va_3x16), 1); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x89AB, vget_high_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x89AB, vget_high_s8(va_3x16), 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx89AB, vget_high_s8(va_3x16), 0); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx89AB, vget_high_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123xCDEF, vget_high_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567xCDEF, vget_high_s8(va_3x16), 1); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABxCDEF, vget_high_s8(va_3x16), 1); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFxCDEF, vget_high_s8(va_3x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x89AB = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567xCDEF = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 4x8 block of activations. 
+ const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x01234567 = vld1_s8(a3); a3 += 8; + + // Load a 8x16 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb01234567x89AB, 2), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb01234567xCDEF, 2), vmask); + + // Multiply-accumulate: 4x8 * 8x16 --> 4x16. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x01234567, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x01234567, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x01234567, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x01234567, 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x01234567, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x01234567, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x01234567, 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb0123x89AB, va3x01234567, 0); + vacc3xCDEF = 
vdotq_lane_s32(vacc3xCDEF, vb0123xCDEF, va3x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb4567x89AB, va0x01234567, 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb4567x89AB, va1x01234567, 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb4567xCDEF, va1x01234567, 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb4567x89AB, va2x01234567, 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb4567xCDEF, va2x01234567, 1); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb4567x0123, va3x01234567, 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb4567x89AB, va3x01234567, 1); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb4567xCDEF, va3x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + vb01234567x89AB = vshrq_n_s8(vb01234567x89AB, 4); + vb01234567xCDEF = vshrq_n_s8(vb01234567xCDEF, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 4x4 block of activations. 
+ const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4; + const int8x8_t va3x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a3)); a3 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + + // Multiply-accumulate: 4x4 * 4x16 --> 4x16. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x0123, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x0123, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x0123, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x0123, 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x0123, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x0123, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x0123, 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb0123x89AB, va3x0123, 0); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb0123xCDEF, va3x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout0x89AB = vcvtq_f32_s32(vacc0x89AB); + float32x4_t vout0xCDEF = 
vcvtq_f32_s32(vacc0xCDEF); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + float32x4_t vout1x89AB = vcvtq_f32_s32(vacc1x89AB); + float32x4_t vout1xCDEF = vcvtq_f32_s32(vacc1xCDEF); + float32x4_t vout2x0123 = vcvtq_f32_s32(vacc2x0123); + float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567); + float32x4_t vout2x89AB = vcvtq_f32_s32(vacc2x89AB); + float32x4_t vout2xCDEF = vcvtq_f32_s32(vacc2xCDEF); + float32x4_t vout3x0123 = vcvtq_f32_s32(vacc3x0123); + float32x4_t vout3x4567 = vcvtq_f32_s32(vacc3x4567); + float32x4_t vout3x89AB = vcvtq_f32_s32(vacc3x89AB); + float32x4_t vout3xCDEF = vcvtq_f32_s32(vacc3xCDEF); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + const float32x4_t kernel_zero_points_89AB = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_89AB = vaddq_f32(kernel_zero_points_89AB, vtwo); + const float32x4_t kernel_zero_points_CDEF = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_CDEF = vaddq_f32(kernel_zero_points_CDEF, vtwo); + + // Subtract out the scaled left-hand row sums. 
+ const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + vout0x89AB = vfmsq_f32(vout0x89AB, biased_kernel_zero_points_89AB, lh_row_sum_0); + vout0xCDEF = vfmsq_f32(vout0xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + vout1x89AB = vfmsq_f32(vout1x89AB, biased_kernel_zero_points_89AB, lh_row_sum_1); + vout1xCDEF = vfmsq_f32(vout1xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_1); + const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]); + vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2); + vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2); + vout2x89AB = vfmsq_f32(vout2x89AB, biased_kernel_zero_points_89AB, lh_row_sum_2); + vout2xCDEF = vfmsq_f32(vout2xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_2); + const float32x4_t lh_row_sum_3 = vld1q_dup_f32(&row_sum[3]); + vout3x0123 = vfmsq_f32(vout3x0123, biased_kernel_zero_points_0123, lh_row_sum_3); + vout3x4567 = vfmsq_f32(vout3x4567, biased_kernel_zero_points_4567, lh_row_sum_3); + vout3x89AB = vfmsq_f32(vout3x89AB, biased_kernel_zero_points_89AB, lh_row_sum_3); + vout3xCDEF = vfmsq_f32(vout3xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_3); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + const float32x4_t vscaled_input_zero_point_2 = + vdupq_n_f32((float)kc * quantization_params[2].zero_point); + const float32x4_t vscaled_input_zero_point_3 = + vdupq_n_f32((float)kc * quantization_params[3].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout0x89AB = + vmlaq_f32(vout0x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_0); + vout0xCDEF = + vmlaq_f32(vout0xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + vout1x89AB = + vmlaq_f32(vout1x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_1); + vout1xCDEF = + vmlaq_f32(vout1xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_1); + vout2x0123 = + vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2); + vout2x4567 = + vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2); + vout2x89AB = + vmlaq_f32(vout2x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_2); + vout2xCDEF = + vmlaq_f32(vout2xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_2); + vout3x0123 = + vmlaq_f32(vout3x0123, kernel_zero_points_0123, vscaled_input_zero_point_3); + vout3x4567 = + vmlaq_f32(vout3x4567, kernel_zero_points_4567, vscaled_input_zero_point_3); + vout3x89AB = + vmlaq_f32(vout3x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_3); + vout3xCDEF = + vmlaq_f32(vout3xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_3); + const float32x4_t vinput_scale01 = 
vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + vout0x89AB = vmulq_lane_f32(vout0x89AB, vget_low_f32(vinput_scale01), 1); + vout1x89AB = vmulq_lane_f32(vout1x89AB, vget_high_f32(vinput_scale01), 1); + vout0xCDEF = vmulq_lane_f32(vout0xCDEF, vget_low_f32(vinput_scale01), 1); + vout1xCDEF = vmulq_lane_f32(vout1xCDEF, vget_high_f32(vinput_scale01), 1); + const float32x4_t vinput_scale23 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[2].zero_point)); + vout2x0123 = vmulq_lane_f32(vout2x0123, vget_low_f32(vinput_scale23), 1); + vout3x0123 = vmulq_lane_f32(vout3x0123, vget_high_f32(vinput_scale23), 1); + vout2x4567 = vmulq_lane_f32(vout2x4567, vget_low_f32(vinput_scale23), 1); + vout3x4567 = vmulq_lane_f32(vout3x4567, vget_high_f32(vinput_scale23), 1); + vout2x89AB = vmulq_lane_f32(vout2x89AB, vget_low_f32(vinput_scale23), 1); + vout3x89AB = vmulq_lane_f32(vout3x89AB, vget_high_f32(vinput_scale23), 1); + vout2xCDEF = vmulq_lane_f32(vout2xCDEF, vget_low_f32(vinput_scale23), 1); + vout3xCDEF = vmulq_lane_f32(vout3xCDEF, vget_high_f32(vinput_scale23), 1); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale89AB = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scaleCDEF = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + 
vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vfmaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vmlaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vfmaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vmlaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); + #endif + const float32x4_t vbias89AB = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x89AB = vfmaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vfmaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + vout2x89AB = vfmaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB); + vout3x89AB = vfmaq_f32(vbias89AB, vout3x89AB, vfilter_output_scale89AB); + #else + vout0x89AB = vmlaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vmlaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + vout2x89AB = vmlaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB); + vout3x89AB = vmlaq_f32(vbias89AB, vout3x89AB, vfilter_output_scale89AB); + #endif + const float32x4_t vbiasCDEF = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0xCDEF = 
vfmaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vfmaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + vout2xCDEF = vfmaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF); + vout3xCDEF = vfmaq_f32(vbiasCDEF, vout3xCDEF, vfilter_output_scaleCDEF); + #else + vout0xCDEF = vmlaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vmlaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + vout2xCDEF = vmlaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF); + vout3xCDEF = vmlaq_f32(vbiasCDEF, vout3xCDEF, vfilter_output_scaleCDEF); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); + float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); + float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); + float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); + float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); + vfp16out1x89ABCDEF = vmaxq_f16(vfp16out1x89ABCDEF, voutput_min); + vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min); + vfp16out2x89ABCDEF = vmaxq_f16(vfp16out2x89ABCDEF, voutput_min); + vfp16out3x01234567 = vmaxq_f16(vfp16out3x01234567, voutput_min); + 
vfp16out3x89ABCDEF = vmaxq_f16(vfp16out3x89ABCDEF, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out0x89ABCDEF = vminq_f16(vfp16out0x89ABCDEF, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + vfp16out1x89ABCDEF = vminq_f16(vfp16out1x89ABCDEF, voutput_max); + vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max); + vfp16out2x89ABCDEF = vminq_f16(vfp16out2x89ABCDEF, voutput_max); + vfp16out3x01234567 = vminq_f16(vfp16out3x01234567, voutput_max); + vfp16out3x89ABCDEF = vminq_f16(vfp16out3x89ABCDEF, voutput_max); + if XNN_LIKELY(nc >= 16) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c0 + 8, vreinterpretq_u16_f16(vfp16out0x89ABCDEF)); + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); + vst1q_u16(c1 + 8, vreinterpretq_u16_f16(vfp16out1x89ABCDEF)); + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); + vst1q_u16(c2 + 8, vreinterpretq_u16_f16(vfp16out2x89ABCDEF)); + vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567)); + vst1q_u16(c3 + 8, vreinterpretq_u16_f16(vfp16out3x89ABCDEF)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); c0 += 8; + vfp16out0x01234567 = vfp16out0x89ABCDEF; + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); c1 += 8; + vfp16out1x01234567 = vfp16out1x89ABCDEF; + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); c2 += 8; + vfp16out2x01234567 = vfp16out2x89ABCDEF; + 
vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567)); c3 += 8; + vfp16out3x01234567 = vfp16out3x89ABCDEF; + } + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567); + float16x4_t vfp16out3x0123 = vget_low_f16(vfp16out3x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4; + vst1_u16(c3, vreinterpret_u16_f16(vfp16out3x0123)); c3 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + vfp16out2x0123 = vget_high_f16(vfp16out2x01234567); + vfp16out3x0123 = vget_high_f16(vfp16out3x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_f16(vfp16out3x0123), 0); c3 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2); + vfp16out3x0123 = vext_f16(vfp16out3x0123, vfp16out3x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0); + vst1_lane_u16(c3, vreinterpret_u16_f16(vfp16out3x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..a322ae261d6 --- /dev/null +++ 
b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-4x8c4-minmax-neondotfp16arith.c @@ -0,0 +1,384 @@
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/qs8-gemm/c4-neondot.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// NOTE(review): the header names were lost in the patch transcription
// (bare `#include` directives). Reconstructed as the standard set used by
// the sibling neondotfp16arith microkernels — confirm against the generator
// output.
#include <arm_neon.h>
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/gemm.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"


// QD8 (dynamically quantized int8 activations) x QC2W (2-bit, per-channel
// quantized weights) GEMM microkernel with f16 output, 4 rows (MR=4) by
// 8 columns (NR=8), processing 4 elements of `k` per NEON SDOT instruction.
//
// Weight layout consumed from `w`, per group of 8 columns:
//   [8 x int32 column sums] [8 x float kernel zero points] [packed 2-bit
//   weights for all of `kc`] [8 x float output scales] [8 x float biases].
// Four 2-bit "crumbs" are packed per weight byte and unpacked with
// mask 0x03 plus shifts of 2/4/6.
//
// `row_sum` holds one float per row of `a` (scaled sum of the row's int8
// activations); `quantization_params[m]` holds each row's zero point
// followed in memory by its inverse scale (the `vld1q_s32` +
// `vreinterpretq_f32_s32` of `&zero_point` below deliberately loads the
// {zero_point, scale} pairs of two consecutive rows in one go).
void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x8c4__neondotfp16arith(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    xnn_float16* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f16_minmax_params* restrict params,
    const float* row_sum,
    const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 4 * sizeof(int8_t));
  const int8_t* a0 = a;
  uint16_t* c0 = (uint16_t*) c;
  const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
  uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride);
  // Rows beyond `mr` alias the previous row so the kernel always computes
  // 4 rows; the duplicate rows' stores land on the same addresses.
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
  uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
  uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    a3 = a2;
    c3 = c2;
  }

  // Mask that isolates one 2-bit weight ("crumb") per byte.
  const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03));
  // Loop over groups of 8 columns.
  do {
    // Initialize the bias with the scaled left-hand weight sums.
    const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4;
    const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4;
    const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point);
    int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0);
    int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0);
    const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point);
    int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1);
    int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1);
    const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point);
    int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2);
    int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2);
    const int32x4_t vinput_zero_point3 = vld1q_dup_s32(&quantization_params[3].zero_point);
    int32x4_t vacc3x0123 = vmulq_s32(vksum0123, vinput_zero_point3);
    int32x4_t vacc3x4567 = vmulq_s32(vksum4567, vinput_zero_point3);
    // TODO: move kernel zero point after weights
    const void* kzp = w;
    w = (const float*)w + 8;

    // Inner accumulation loop along the 8 columns.
    size_t k = kc;
    // 4x partial unrolled loop to load 16 bytes at a time.
    while (k >= 16 * sizeof(int8_t)) {
      // Load a 4x16 block of activations.
      const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16;
      const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16;
      const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16;
      const int8x16_t va_3x16 = vld1q_s8(a3); a3 += 16;

      // Load a 16x8 block of weights.
      const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16;
      const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16;
      // First crumb.
      const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask);
      const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask);
      // Second crumb.
      const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask);
      const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask);
      // Third crumb.
      const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask);
      const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask);
      // Fourth crumb.
      const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask);
      const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask);

      // Multiply-accumulate: 4x16 * 16x8 --> 4x8.
      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0);

      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1);

      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0);

      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1);
      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0);

      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1);

      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0);

      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1);
      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0);

      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1);

      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x89AB, vget_high_s8(va_2x16), 0);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, vget_high_s8(va_2x16), 0);

      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1);
      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, vget_low_s8(va_3x16), 0);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x0123, vget_low_s8(va_3x16), 0);

      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x4567, vget_low_s8(va_3x16), 1);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, vget_low_s8(va_3x16), 1);

      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x89AB, vget_high_s8(va_3x16), 0);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x89AB, vget_high_s8(va_3x16), 0);

      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123xCDEF, vget_high_s8(va_3x16), 1);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567xCDEF, vget_high_s8(va_3x16), 1);

      k -= 16 * sizeof(int8_t);
    }
    // Handle up to 8 final positions of `k`.
    if XNN_UNLIKELY(k > 0) {
      int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16;
      int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16;
      // 2x partial unrolled loop to load 8 bytes at a time.
      while (k >= 8 * sizeof(int8_t)) {
        // Load a 4x8 block of activations.
        const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8;
        const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8;
        const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8;
        const int8x8_t va3x01234567 = vld1_s8(a3); a3 += 8;

        // Load a 8x8 block of weights.
        const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask);
        const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask);
        const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask);
        const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask);

        // Multiply-accumulate: 4x8 * 8x8 --> 4x8.
        vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0);
        vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0);
        vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0);
        vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0);
        vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0);
        vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0);
        vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x01234567, 0);
        vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x01234567, 0);
        vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1);
        vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1);
        vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1);
        vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1);
        vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1);
        vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1);
        vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb4567x0123, va3x01234567, 1);
        vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1);

        k -= 8 * sizeof(int8_t);
        // Expose the next two crumbs for a possible 4-wide tail.
        vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4);
        vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4);
      }
      // Handle up to 4 final positions of `k`
      if XNN_UNLIKELY(k != 0) {
        // Load a 4x4 block of activations.
        const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4;
        const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4;
        const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4;
        const int8x8_t va3x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a3)); a3 += 4;

        const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask);
        const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask);

        // Multiply-accumulate: 4x4 * 4x8 --> 4x8.
        vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0);
        vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0);
        vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0);
        vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0);
        vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0);
        vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0);
        vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x0123, 0);
        vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x0123, 0);
      }
    }

    float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123);
    float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567);
    float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123);
    float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567);
    float32x4_t vout2x0123 = vcvtq_f32_s32(vacc2x0123);
    float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567);
    float32x4_t vout3x0123 = vcvtq_f32_s32(vacc3x0123);
    float32x4_t vout3x4567 = vcvtq_f32_s32(vacc3x4567);
    // Bias of +2 applied to the kernel zero points — presumably recentering
    // the unsigned 2-bit crumbs; matches the generator template (confirm
    // against src/qs8-gemm/c4-neondot.c.in).
    const float32x4_t vtwo = vdupq_n_f32(2.0f);
    const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4;
    const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo);
    const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4;
    const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo);

    // Subtract out the scaled left-hand row sums.
    const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]);
    vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0);
    vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0);
    const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]);
    vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1);
    vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1);
    const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]);
    vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2);
    vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2);
    const float32x4_t lh_row_sum_3 = vld1q_dup_f32(&row_sum[3]);
    vout3x0123 = vfmsq_f32(vout3x0123, biased_kernel_zero_points_0123, lh_row_sum_3);
    vout3x4567 = vfmsq_f32(vout3x4567, biased_kernel_zero_points_4567, lh_row_sum_3);

    // Add the product of left/right-hand zero points and `kc`.
    const float32x4_t vscaled_input_zero_point_0 =
        vdupq_n_f32((float)kc * quantization_params[0].zero_point);
    const float32x4_t vscaled_input_zero_point_1 =
        vdupq_n_f32((float)kc * quantization_params[1].zero_point);
    const float32x4_t vscaled_input_zero_point_2 =
        vdupq_n_f32((float)kc * quantization_params[2].zero_point);
    const float32x4_t vscaled_input_zero_point_3 =
        vdupq_n_f32((float)kc * quantization_params[3].zero_point);
    vout0x0123 =
        vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0);
    vout0x4567 =
        vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0);
    vout1x0123 =
        vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1);
    vout1x4567 =
        vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1);
    vout2x0123 =
        vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2);
    vout2x4567 =
        vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2);
    vout3x0123 =
        vmlaq_f32(vout3x0123, kernel_zero_points_0123, vscaled_input_zero_point_3);
    vout3x4567 =
        vmlaq_f32(vout3x4567, kernel_zero_points_4567, vscaled_input_zero_point_3);
    // Load {zero_point, scale} pairs for rows 0/1 and 2/3 in one vector each;
    // lanes 1 and 3 (selected via vget_low/high + lane 1) are the scales.
    const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point));
    vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1);
    vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1);
    vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1);
    vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1);
    const float32x4_t vinput_scale23 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[2].zero_point));
    vout2x0123 = vmulq_lane_f32(vout2x0123, vget_low_f32(vinput_scale23), 1);
    vout3x0123 = vmulq_lane_f32(vout3x0123, vget_high_f32(vinput_scale23), 1);
    vout2x4567 = vmulq_lane_f32(vout2x4567, vget_low_f32(vinput_scale23), 1);
    vout3x4567 = vmulq_lane_f32(vout3x4567, vget_high_f32(vinput_scale23), 1);

    const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4;
    const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4;

    // Apply per-channel output scale and bias (fused FMA on AArch64).
    const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4;
    #if XNN_ARCH_ARM64
      vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123);
      vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123);
      vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123);
      vout3x0123 = vfmaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123);
    #else
      vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123);
      vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123);
      vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123);
      vout3x0123 = vmlaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123);
    #endif
    const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4;
    #if XNN_ARCH_ARM64
      vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567);
      vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567);
      vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567);
      vout3x4567 = vfmaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567);
    #else
      vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567);
      vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567);
      vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567);
      vout3x4567 = vmlaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567);
    #endif

    // Convert to f16 and clamp to [params->scalar.min, params->scalar.max].
    float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567));
    float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567));
    float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567));
    float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567));
    const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min));
    vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min);
    vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min);
    vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min);
    vfp16out3x01234567 = vmaxq_f16(vfp16out3x01234567, voutput_min);
    const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max));
    vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max);
    vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max);
    vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max);
    vfp16out3x01234567 = vminq_f16(vfp16out3x01234567, voutput_max);
    if XNN_LIKELY(nc >= 8) {
      vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567));
      vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567));
      vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567));
      vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567));

      // Rewind activation pointers for the next column group.
      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
      a1 = (const int8_t*) ((uintptr_t) a1 - kc);
      a2 = (const int8_t*) ((uintptr_t) a2 - kc);
      a3 = (const int8_t*) ((uintptr_t) a3 - kc);

      c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
      c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
      c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
      c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);

      nc -= 8;
    } else {
      // Partial-width tail: store 4, 2, then 1 remaining column(s).
      float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567);
      float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567);
      float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567);
      float16x4_t vfp16out3x0123 = vget_low_f16(vfp16out3x01234567);
      if (nc & 4) {
        vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4;
        vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4;
        vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4;
        vst1_u16(c3, vreinterpret_u16_f16(vfp16out3x0123)); c3 += 4;
        vfp16out0x0123 = vget_high_f16(vfp16out0x01234567);
        vfp16out1x0123 = vget_high_f16(vfp16out1x01234567);
        vfp16out2x0123 = vget_high_f16(vfp16out2x01234567);
        vfp16out3x0123 = vget_high_f16(vfp16out3x01234567);
      }
      if (nc & 2) {
        vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2;
        vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2;
        vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2;
        vst1_lane_u32((void*) c3, vreinterpret_u32_f16(vfp16out3x0123), 0); c3 += 2;
        vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2);
        vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2);
        vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2);
        vfp16out3x0123 = vext_f16(vfp16out3x0123, vfp16out3x0123, 2);
      }
      if (nc & 1) {
        vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0);
        vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0);
        vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0);
        vst1_lane_u16(c3, vreinterpret_u16_f16(vfp16out3x0123), 0);
      }
      nc = 0;
    }
  } while (nc != 0);
}
diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x16c4-minmax-neondotfp16arith.c
new file mode 100644
index 00000000000..e66e7bcd4f7
--- /dev/null
+++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x16c4-minmax-neondotfp16arith.c
@@ -0,0 +1,661 @@
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/qs8-gemm/c4-neondot.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <arm_neon.h>
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/gemm.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"


// QD8 (dynamically-quantized int8 activations) x QC2W (2-bit per-channel
// weights, packed 4 per byte) GEMM microkernel, 5 rows x 16 columns, using
// NEON SDOT with fp16 output.
//
// Weight layout per group of 16 columns (NOTE(review): inferred from the
// pointer arithmetic below — confirm against the packing routine):
//   16 x int32 scaled-weight column sums (ksums),
//   16 x float kernel zero points (TODO in original: move after weights),
//   packed 2-bit weight crumbs for the whole of kc,
//   16 x float filter output scales,
//   16 x float biases.
void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x16c4__neondotfp16arith(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    xnn_float16* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f16_minmax_params* restrict params,
    const float* row_sum,
    const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 5);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 4 * sizeof(int8_t));
  const int8_t* a0 = a;
  uint16_t* c0 = (uint16_t*) c;
  const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
  uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
  uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
  uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 4) {
    a3 = a2;
    c3 = c2;
  }
  const int8_t* a4 = (const int8_t*) ((uintptr_t) a3 + a_stride);
  uint16_t* c4 = (uint16_t*) ((uintptr_t) c3 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 4) {
    a4 = a3;
    c4 = c3;
  }

  // Mask for extracting a 2-bit crumb from each packed weight byte.
  const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03));
  // Loop over groups of 16 columns.
  do {
    // Initialize the bias with the scaled left-hand weight sums.
    const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4;
    const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4;
    const int32x4_t vksum89AB = vld1q_s32(w); w = (const int32_t*) w + 4;
    const int32x4_t vksumCDEF = vld1q_s32(w); w = (const int32_t*) w + 4;
    const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point);
    int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0);
    int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0);
    int32x4_t vacc0x89AB = vmulq_s32(vksum89AB, vinput_zero_point0);
    int32x4_t vacc0xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point0);
    const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point);
    int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1);
    int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1);
    int32x4_t vacc1x89AB = vmulq_s32(vksum89AB, vinput_zero_point1);
    int32x4_t vacc1xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point1);
    const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point);
    int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2);
    int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2);
    int32x4_t vacc2x89AB = vmulq_s32(vksum89AB, vinput_zero_point2);
    int32x4_t vacc2xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point2);
    const int32x4_t vinput_zero_point3 = vld1q_dup_s32(&quantization_params[3].zero_point);
    int32x4_t vacc3x0123 = vmulq_s32(vksum0123, vinput_zero_point3);
    int32x4_t vacc3x4567 = vmulq_s32(vksum4567, vinput_zero_point3);
    int32x4_t vacc3x89AB = vmulq_s32(vksum89AB, vinput_zero_point3);
    int32x4_t vacc3xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point3);
    const int32x4_t vinput_zero_point4 = vld1q_dup_s32(&quantization_params[4].zero_point);
    int32x4_t vacc4x0123 = vmulq_s32(vksum0123, vinput_zero_point4);
    int32x4_t vacc4x4567 = vmulq_s32(vksum4567, vinput_zero_point4);
    int32x4_t vacc4x89AB = vmulq_s32(vksum89AB, vinput_zero_point4);
    int32x4_t vacc4xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point4);
    // TODO: move kernel zero point after weights
    const void* kzp = w;
    w = (const float*) w + 16;

    // Inner accumulation loop along the 16 columns.
    size_t k = kc;
    // 4x partial unrolled loop to load 16 bytes at a time.
    while (k >= 16 * sizeof(int8_t)) {
      // Load a 5x16 block of activations.
      const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16;
      const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16;
      const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16;
      const int8x16_t va_3x16 = vld1q_s8(a3); a3 += 16;
      const int8x16_t va_4x16 = vld1q_s8(a4); a4 += 16;

      // Load a 16x16 block of weights.
      const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16;
      const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16;
      const int8x16_t vb89ABx16 = vld1q_s8(w); w = (const int8_t*) w + 16;
      const int8x16_t vbCDEFx16 = vld1q_s8(w); w = (const int8_t*) w + 16;
      // First crumb.
      const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask);
      const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask);
      const int8x16_t vb89ABx0123 = vandq_s8(vb89ABx16, vmask);
      const int8x16_t vbCDEFx0123 = vandq_s8(vbCDEFx16, vmask);
      // Second crumb.
      const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask);
      const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask);
      const int8x16_t vb89ABx4567 = vandq_s8(vshrq_n_s8(vb89ABx16, 2), vmask);
      const int8x16_t vbCDEFx4567 = vandq_s8(vshrq_n_s8(vbCDEFx16, 2), vmask);
      // Third crumb.
      const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask);
      const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask);
      const int8x16_t vb89ABx89AB = vandq_s8(vshrq_n_s8(vb89ABx16, 4), vmask);
      const int8x16_t vbCDEFx89AB = vandq_s8(vshrq_n_s8(vbCDEFx16, 4), vmask);
      // Fourth crumb.
      const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask);
      const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask);
      const int8x16_t vb89ABxCDEF = vandq_s8(vshrq_n_s8(vb89ABx16, 6), vmask);
      const int8x16_t vbCDEFxCDEF = vandq_s8(vshrq_n_s8(vbCDEFx16, 6), vmask);

      // Multiply-accumulate: 5x16 * 16x16 --> 5x16.
      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0);
      vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx0123, vget_low_s8(va_0x16), 0);
      vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx0123, vget_low_s8(va_0x16), 0);

      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1);
      vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx4567, vget_low_s8(va_0x16), 1);
      vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx4567, vget_low_s8(va_0x16), 1);

      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0);
      vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx89AB, vget_high_s8(va_0x16), 0);
      vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx89AB, vget_high_s8(va_0x16), 0);

      vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1);
      vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1);
      vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABxCDEF, vget_high_s8(va_0x16), 1);
      vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFxCDEF, vget_high_s8(va_0x16), 1);
      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0);
      vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx0123, vget_low_s8(va_1x16), 0);
      vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx0123, vget_low_s8(va_1x16), 0);

      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1);
      vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx4567, vget_low_s8(va_1x16), 1);
      vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx4567, vget_low_s8(va_1x16), 1);

      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0);
      vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx89AB, vget_high_s8(va_1x16), 0);
      vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx89AB, vget_high_s8(va_1x16), 0);

      vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1);
      vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1);
      vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABxCDEF, vget_high_s8(va_1x16), 1);
      vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFxCDEF, vget_high_s8(va_1x16), 1);
      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0);
      vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx0123, vget_low_s8(va_2x16), 0);
      vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx0123, vget_low_s8(va_2x16), 0);

      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1);
      vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx4567, vget_low_s8(va_2x16), 1);
      vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx4567, vget_low_s8(va_2x16), 1);

      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x89AB, vget_high_s8(va_2x16), 0);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, vget_high_s8(va_2x16), 0);
      vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx89AB, vget_high_s8(va_2x16), 0);
      vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx89AB, vget_high_s8(va_2x16), 0);

      vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1);
      vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1);
      vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABxCDEF, vget_high_s8(va_2x16), 1);
      vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFxCDEF, vget_high_s8(va_2x16), 1);
      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, vget_low_s8(va_3x16), 0);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x0123, vget_low_s8(va_3x16), 0);
      vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx0123, vget_low_s8(va_3x16), 0);
      vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx0123, vget_low_s8(va_3x16), 0);

      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x4567, vget_low_s8(va_3x16), 1);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, vget_low_s8(va_3x16), 1);
      vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx4567, vget_low_s8(va_3x16), 1);
      vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx4567, vget_low_s8(va_3x16), 1);

      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x89AB, vget_high_s8(va_3x16), 0);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x89AB, vget_high_s8(va_3x16), 0);
      vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx89AB, vget_high_s8(va_3x16), 0);
      vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx89AB, vget_high_s8(va_3x16), 0);

      vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123xCDEF, vget_high_s8(va_3x16), 1);
      vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567xCDEF, vget_high_s8(va_3x16), 1);
      vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABxCDEF, vget_high_s8(va_3x16), 1);
      vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFxCDEF, vget_high_s8(va_3x16), 1);
      vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, vget_low_s8(va_4x16), 0);
      vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x0123, vget_low_s8(va_4x16), 0);
      vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABx0123, vget_low_s8(va_4x16), 0);
      vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFx0123, vget_low_s8(va_4x16), 0);

      vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x4567, vget_low_s8(va_4x16), 1);
      vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, vget_low_s8(va_4x16), 1);
      vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABx4567, vget_low_s8(va_4x16), 1);
      vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFx4567, vget_low_s8(va_4x16), 1);

      vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x89AB, vget_high_s8(va_4x16), 0);
      vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x89AB, vget_high_s8(va_4x16), 0);
      vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABx89AB, vget_high_s8(va_4x16), 0);
      vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFx89AB, vget_high_s8(va_4x16), 0);

      vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123xCDEF, vget_high_s8(va_4x16), 1);
      vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567xCDEF, vget_high_s8(va_4x16), 1);
      vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABxCDEF, vget_high_s8(va_4x16), 1);
      vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFxCDEF, vget_high_s8(va_4x16), 1);

      k -= 16 * sizeof(int8_t);
    }
    // Handle up to 8 final positions of `k`.
    if XNN_UNLIKELY(k > 0) {
      int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16;
      int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16;
      int8x16_t vb01234567x89AB = vld1q_s8(w); w = (const int8_t*) w + 16;
      int8x16_t vb01234567xCDEF = vld1q_s8(w); w = (const int8_t*) w + 16;
      // 2x partial unrolled loop to load 8 bytes at a time.
      while (k >= 8 * sizeof(int8_t)) {
        // Load a 5x8 block of activations.
        const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8;
        const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8;
        const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8;
        const int8x8_t va3x01234567 = vld1_s8(a3); a3 += 8;
        const int8x8_t va4x01234567 = vld1_s8(a4); a4 += 8;

        // Load a 8x16 block of weights.
        const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask);
        const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask);
        const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask);
        const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask);
        const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask);
        const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask);
        const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb01234567x89AB, 2), vmask);
        const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb01234567xCDEF, 2), vmask);

        // Multiply-accumulate: 5x8 * 8x16 --> 5x16.
        vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0);
        vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0);
        vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x01234567, 0);
        vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x01234567, 0);
        vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0);
        vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0);
        vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x01234567, 0);
        vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x01234567, 0);
        vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0);
        vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0);
        vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x01234567, 0);
        vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x01234567, 0);
        vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x01234567, 0);
        vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x01234567, 0);
        vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb0123x89AB, va3x01234567, 0);
        vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb0123xCDEF, va3x01234567, 0);
        vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x01234567, 0);
        vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x01234567, 0);
        vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb0123x89AB, va4x01234567, 0);
        vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vb0123xCDEF, va4x01234567, 0);
        vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1);
        vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1);
        vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb4567x89AB, va0x01234567, 1);
        vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1);
        vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1);
        vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1);
        vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb4567x89AB, va1x01234567, 1);
        vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb4567xCDEF, va1x01234567, 1);
        vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1);
        vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1);
        vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb4567x89AB, va2x01234567, 1);
        vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb4567xCDEF, va2x01234567, 1);
        vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb4567x0123, va3x01234567, 1);
        vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1);
        vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb4567x89AB, va3x01234567, 1);
        vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb4567xCDEF, va3x01234567, 1);
        vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb4567x0123, va4x01234567, 1);
        vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, va4x01234567, 1);
        vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb4567x89AB, va4x01234567, 1);
        vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vb4567xCDEF, va4x01234567, 1);

        k -= 8 * sizeof(int8_t);
        vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4);
        vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4);
        vb01234567x89AB = vshrq_n_s8(vb01234567x89AB, 4);
        vb01234567xCDEF = vshrq_n_s8(vb01234567xCDEF, 4);
      }
      // Handle up to 4 final positions of `k`
      if XNN_UNLIKELY(k != 0) {
        // Load a 5x4 block of activations.
        const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4;
        const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4;
        const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4;
        const int8x8_t va3x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a3)); a3 += 4;
        const int8x8_t va4x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a4)); a4 += 4;

        const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask);
        const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask);
        const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask);
        const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask);

        // Multiply-accumulate: 5x4 * 4x16 --> 5x16.
        vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0);
        vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0);
        vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x0123, 0);
        vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x0123, 0);
        vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0);
        vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0);
        vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x0123, 0);
        vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x0123, 0);
        vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0);
        vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0);
        vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x0123, 0);
        vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x0123, 0);
        vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x0123, 0);
        vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x0123, 0);
        vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb0123x89AB, va3x0123, 0);
        vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb0123xCDEF, va3x0123, 0);
        vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x0123, 0);
        vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x0123, 0);
        vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb0123x89AB, va4x0123, 0);
        vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vb0123xCDEF, va4x0123, 0);
      }
    }

    // Convert the int32 accumulators to float for dequantization.
    float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123);
    float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567);
    float32x4_t vout0x89AB = vcvtq_f32_s32(vacc0x89AB);
    float32x4_t vout0xCDEF = vcvtq_f32_s32(vacc0xCDEF);
    float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123);
    float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567);
    float32x4_t vout1x89AB = vcvtq_f32_s32(vacc1x89AB);
    float32x4_t vout1xCDEF = vcvtq_f32_s32(vacc1xCDEF);
    float32x4_t vout2x0123 = vcvtq_f32_s32(vacc2x0123);
    float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567);
    float32x4_t vout2x89AB = vcvtq_f32_s32(vacc2x89AB);
    float32x4_t vout2xCDEF = vcvtq_f32_s32(vacc2xCDEF);
    float32x4_t vout3x0123 = vcvtq_f32_s32(vacc3x0123);
    float32x4_t vout3x4567 = vcvtq_f32_s32(vacc3x4567);
    float32x4_t vout3x89AB = vcvtq_f32_s32(vacc3x89AB);
    float32x4_t vout3xCDEF = vcvtq_f32_s32(vacc3xCDEF);
    float32x4_t vout4x0123 = vcvtq_f32_s32(vacc4x0123);
    float32x4_t vout4x4567 = vcvtq_f32_s32(vacc4x4567);
    float32x4_t vout4x89AB = vcvtq_f32_s32(vacc4x89AB);
    float32x4_t vout4xCDEF = vcvtq_f32_s32(vacc4xCDEF);
    // Bias the kernel zero points by 2 (the unsigned 2-bit crumbs span 0..3
    // around a midpoint of 2).
    const float32x4_t vtwo = vdupq_n_f32(2.0f);
    const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4;
    const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo);
    const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4;
    const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo);
    const float32x4_t kernel_zero_points_89AB = vld1q_f32(kzp); kzp = (const float*)kzp + 4;
    const float32x4_t biased_kernel_zero_points_89AB = vaddq_f32(kernel_zero_points_89AB, vtwo);
    const float32x4_t kernel_zero_points_CDEF = vld1q_f32(kzp); kzp = (const float*)kzp + 4;
    const float32x4_t biased_kernel_zero_points_CDEF = vaddq_f32(kernel_zero_points_CDEF, vtwo);

    // Subtract out the scaled left-hand row sums.
    const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]);
    vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0);
    vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0);
    vout0x89AB = vfmsq_f32(vout0x89AB, biased_kernel_zero_points_89AB, lh_row_sum_0);
    vout0xCDEF = vfmsq_f32(vout0xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_0);
    const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]);
    vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1);
    vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1);
    vout1x89AB = vfmsq_f32(vout1x89AB, biased_kernel_zero_points_89AB, lh_row_sum_1);
    vout1xCDEF = vfmsq_f32(vout1xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_1);
    const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]);
    vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2);
    vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2);
    vout2x89AB = vfmsq_f32(vout2x89AB, biased_kernel_zero_points_89AB, lh_row_sum_2);
    vout2xCDEF = vfmsq_f32(vout2xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_2);
    const float32x4_t lh_row_sum_3 = vld1q_dup_f32(&row_sum[3]);
    vout3x0123 = vfmsq_f32(vout3x0123, biased_kernel_zero_points_0123, lh_row_sum_3);
    vout3x4567 = vfmsq_f32(vout3x4567, biased_kernel_zero_points_4567, lh_row_sum_3);
    vout3x89AB = vfmsq_f32(vout3x89AB, biased_kernel_zero_points_89AB, lh_row_sum_3);
    vout3xCDEF = vfmsq_f32(vout3xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_3);
    const float32x4_t lh_row_sum_4 = vld1q_dup_f32(&row_sum[4]);
    vout4x0123 = vfmsq_f32(vout4x0123, biased_kernel_zero_points_0123, lh_row_sum_4);
    vout4x4567 = vfmsq_f32(vout4x4567, biased_kernel_zero_points_4567, lh_row_sum_4);
    vout4x89AB = vfmsq_f32(vout4x89AB, biased_kernel_zero_points_89AB, lh_row_sum_4);
    vout4xCDEF = vfmsq_f32(vout4xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_4);

    // Add the product of left/right-hand zero points and `kc`.
    const float32x4_t vscaled_input_zero_point_0 =
        vdupq_n_f32((float)kc * quantization_params[0].zero_point);
    const float32x4_t vscaled_input_zero_point_1 =
        vdupq_n_f32((float)kc * quantization_params[1].zero_point);
    const float32x4_t vscaled_input_zero_point_2 =
        vdupq_n_f32((float)kc * quantization_params[2].zero_point);
    const float32x4_t vscaled_input_zero_point_3 =
        vdupq_n_f32((float)kc * quantization_params[3].zero_point);
    const float32x4_t vscaled_input_zero_point_4 =
        vdupq_n_f32((float)kc * quantization_params[4].zero_point);
    vout0x0123 =
        vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0);
    vout0x4567 =
        vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0);
    vout0x89AB =
        vmlaq_f32(vout0x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_0);
    vout0xCDEF =
        vmlaq_f32(vout0xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_0);
    vout1x0123 =
        vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1);
    vout1x4567 =
        vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1);
    vout1x89AB =
        vmlaq_f32(vout1x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_1);
    vout1xCDEF =
        vmlaq_f32(vout1xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_1);
    vout2x0123 =
        vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2);
    vout2x4567 =
        vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2);
    vout2x89AB =
        vmlaq_f32(vout2x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_2);
    vout2xCDEF =
        vmlaq_f32(vout2xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_2);
    vout3x0123 =
        vmlaq_f32(vout3x0123, kernel_zero_points_0123, vscaled_input_zero_point_3);
    vout3x4567 =
        vmlaq_f32(vout3x4567, kernel_zero_points_4567, vscaled_input_zero_point_3);
    vout3x89AB =
        vmlaq_f32(vout3x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_3);
    vout3xCDEF =
        vmlaq_f32(vout3xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_3);
    vout4x0123 =
        vmlaq_f32(vout4x0123, kernel_zero_points_0123, vscaled_input_zero_point_4);
    vout4x4567 =
        vmlaq_f32(vout4x4567, kernel_zero_points_4567, vscaled_input_zero_point_4);
    vout4x89AB =
        vmlaq_f32(vout4x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_4);
    vout4xCDEF =
        vmlaq_f32(vout4xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_4);
    // Apply per-row input scales.  The 128-bit load starting at
    // quantization_params[i].zero_point picks up two adjacent
    // {zero_point, inv_scale} pairs; lanes 1 and 3 hold the inv_scale values
    // (NOTE(review): relies on the struct layout — confirm against
    // xnn_qd8_quantization_params).
    const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point));
    vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1);
    vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1);
    vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1);
    vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1);
    vout0x89AB = vmulq_lane_f32(vout0x89AB, vget_low_f32(vinput_scale01), 1);
    vout1x89AB = vmulq_lane_f32(vout1x89AB, vget_high_f32(vinput_scale01), 1);
    vout0xCDEF = vmulq_lane_f32(vout0xCDEF, vget_low_f32(vinput_scale01), 1);
    vout1xCDEF = vmulq_lane_f32(vout1xCDEF, vget_high_f32(vinput_scale01), 1);
    const float32x4_t vinput_scale23 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[2].zero_point));
    vout2x0123 = vmulq_lane_f32(vout2x0123, vget_low_f32(vinput_scale23), 1);
    vout3x0123 = vmulq_lane_f32(vout3x0123, vget_high_f32(vinput_scale23), 1);
    vout2x4567 = vmulq_lane_f32(vout2x4567, vget_low_f32(vinput_scale23), 1);
    vout3x4567 = vmulq_lane_f32(vout3x4567, vget_high_f32(vinput_scale23), 1);
    vout2x89AB = vmulq_lane_f32(vout2x89AB, vget_low_f32(vinput_scale23), 1);
    vout3x89AB = vmulq_lane_f32(vout3x89AB, vget_high_f32(vinput_scale23), 1);
    vout2xCDEF = vmulq_lane_f32(vout2xCDEF, vget_low_f32(vinput_scale23), 1);
    vout3xCDEF = vmulq_lane_f32(vout3xCDEF, vget_high_f32(vinput_scale23), 1);
    const float32x4_t vinput_scale4 = vld1q_dup_f32(&quantization_params[4].inv_scale);
    vout4x0123 = vmulq_f32(vout4x0123, vinput_scale4);
    vout4x4567 = vmulq_f32(vout4x4567, vinput_scale4);
    vout4x89AB = vmulq_f32(vout4x89AB, vinput_scale4);
    vout4xCDEF = vmulq_f32(vout4xCDEF, vinput_scale4);

    // Apply per-column filter output scales and biases.
    const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4;
    const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4;
    const float32x4_t vfilter_output_scale89AB = vld1q_f32(w); w = (const float*) w + 4;
    const float32x4_t vfilter_output_scaleCDEF = vld1q_f32(w); w = (const float*) w + 4;

    const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4;
    #if XNN_ARCH_ARM64
      vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123);
      vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123);
      vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123);
      vout3x0123 = vfmaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123);
      vout4x0123 = vfmaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123);
    #else
      vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123);
      vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123);
      vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123);
      vout3x0123 = vmlaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123);
      vout4x0123 = vmlaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123);
    #endif
    const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4;
    #if XNN_ARCH_ARM64
      vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567);
      vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567);
      vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567);
      vout3x4567 = vfmaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567);
      vout4x4567 = vfmaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567);
    #else
      vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567);
      vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567);
      vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567);
      vout3x4567 = vmlaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567);
      vout4x4567 = vmlaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567);
    #endif
    const float32x4_t vbias89AB = vld1q_f32(w); w = (const float*) w + 4;
    #if XNN_ARCH_ARM64
      vout0x89AB = vfmaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB);
      vout1x89AB = vfmaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB);
      vout2x89AB = vfmaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB);
      vout3x89AB = vfmaq_f32(vbias89AB, vout3x89AB, vfilter_output_scale89AB);
      vout4x89AB = vfmaq_f32(vbias89AB, vout4x89AB, vfilter_output_scale89AB);
    #else
      vout0x89AB = vmlaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB);
      vout1x89AB = vmlaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB);
      vout2x89AB = vmlaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB);
      vout3x89AB = vmlaq_f32(vbias89AB, vout3x89AB, vfilter_output_scale89AB);
      vout4x89AB = vmlaq_f32(vbias89AB, vout4x89AB, vfilter_output_scale89AB);
    #endif
    const float32x4_t vbiasCDEF = vld1q_f32(w); w = (const float*) w + 4;
    #if XNN_ARCH_ARM64
      vout0xCDEF = vfmaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF);
      vout1xCDEF = vfmaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF);
      vout2xCDEF = vfmaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF);
      vout3xCDEF = vfmaq_f32(vbiasCDEF, vout3xCDEF, vfilter_output_scaleCDEF);
      vout4xCDEF = vfmaq_f32(vbiasCDEF, vout4xCDEF, vfilter_output_scaleCDEF);
    #else
      vout0xCDEF = vmlaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF);
      vout1xCDEF = vmlaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF);
      vout2xCDEF = vmlaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF);
      vout3xCDEF = vmlaq_f32(vbiasCDEF, vout3xCDEF, vfilter_output_scaleCDEF);
      vout4xCDEF = vmlaq_f32(vbiasCDEF, vout4xCDEF, vfilter_output_scaleCDEF);
    #endif

    // Narrow to fp16 and clamp to [min, max].
    float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567));
    float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF));
    float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567));
    float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF));
    float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567));
    float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF));
    float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567));
    float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF));
    float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567));
    float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF));
    const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min));
    vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min);
    vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min);
    vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min);
    vfp16out1x89ABCDEF = vmaxq_f16(vfp16out1x89ABCDEF, voutput_min);
    vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min);
    vfp16out2x89ABCDEF = vmaxq_f16(vfp16out2x89ABCDEF, voutput_min);
    vfp16out3x01234567 = vmaxq_f16(vfp16out3x01234567, voutput_min);
    vfp16out3x89ABCDEF = vmaxq_f16(vfp16out3x89ABCDEF, voutput_min);
    vfp16out4x01234567 = vmaxq_f16(vfp16out4x01234567, voutput_min);
    vfp16out4x89ABCDEF = vmaxq_f16(vfp16out4x89ABCDEF, voutput_min);
    const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max));
    vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max);
    vfp16out0x89ABCDEF = vminq_f16(vfp16out0x89ABCDEF, voutput_max);
    vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max);
    vfp16out1x89ABCDEF = vminq_f16(vfp16out1x89ABCDEF, voutput_max);
    vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max);
    vfp16out2x89ABCDEF = vminq_f16(vfp16out2x89ABCDEF, voutput_max);
    vfp16out3x01234567 = vminq_f16(vfp16out3x01234567, voutput_max);
    vfp16out3x89ABCDEF = vminq_f16(vfp16out3x89ABCDEF, voutput_max);
    vfp16out4x01234567 = vminq_f16(vfp16out4x01234567, voutput_max);
    vfp16out4x89ABCDEF = vminq_f16(vfp16out4x89ABCDEF, voutput_max);
    if XNN_LIKELY(nc >= 16) {
      // Full 16-column store; rewind activation pointers for the next group.
      vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567));
      vst1q_u16(c0 + 8, vreinterpretq_u16_f16(vfp16out0x89ABCDEF));
      vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567));
      vst1q_u16(c1 + 8, vreinterpretq_u16_f16(vfp16out1x89ABCDEF));
      vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567));
      vst1q_u16(c2 + 8, vreinterpretq_u16_f16(vfp16out2x89ABCDEF));
      vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567));
      vst1q_u16(c3 + 8, vreinterpretq_u16_f16(vfp16out3x89ABCDEF));
      vst1q_u16(c4, vreinterpretq_u16_f16(vfp16out4x01234567));
      vst1q_u16(c4 + 8, vreinterpretq_u16_f16(vfp16out4x89ABCDEF));

      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
      a1 = (const int8_t*) ((uintptr_t) a1 - kc);
      a2 = (const int8_t*) ((uintptr_t) a2 - kc);
      a3 = (const int8_t*) ((uintptr_t) a3 - kc);
      a4 = (const int8_t*) ((uintptr_t) a4 - kc);

      c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
      c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
      c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
      c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
      c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride);

      nc -= 16;
    } else {
      // Tail store: write 8/4/2/1 remaining columns.
      if (nc & 8) {
        vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); c0 += 8;
        vfp16out0x01234567 = vfp16out0x89ABCDEF;
        vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); c1 += 8;
        vfp16out1x01234567 = vfp16out1x89ABCDEF;
        vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); c2 += 8;
        vfp16out2x01234567 = vfp16out2x89ABCDEF;
        vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567)); c3 += 8;
        vfp16out3x01234567 = vfp16out3x89ABCDEF;
        vst1q_u16(c4, vreinterpretq_u16_f16(vfp16out4x01234567)); c4 += 8;
        vfp16out4x01234567 = vfp16out4x89ABCDEF;
      }
      float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567);
      float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567);
      float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567);
      float16x4_t vfp16out3x0123 = vget_low_f16(vfp16out3x01234567);
      float16x4_t vfp16out4x0123 = vget_low_f16(vfp16out4x01234567);
      if (nc & 4) {
        vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4;
        vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4;
        vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4;
        vst1_u16(c3, vreinterpret_u16_f16(vfp16out3x0123)); c3 += 4;
        vst1_u16(c4, vreinterpret_u16_f16(vfp16out4x0123)); c4 += 4;
        vfp16out0x0123 = vget_high_f16(vfp16out0x01234567);
        vfp16out1x0123 = vget_high_f16(vfp16out1x01234567);
        vfp16out2x0123 = vget_high_f16(vfp16out2x01234567);
        vfp16out3x0123 = vget_high_f16(vfp16out3x01234567);
        vfp16out4x0123 = vget_high_f16(vfp16out4x01234567);
      }
      if (nc & 2) {
        vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2;
        vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2;
        vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2;
        vst1_lane_u32((void*) c3, vreinterpret_u32_f16(vfp16out3x0123), 0); c3 += 2;
        vst1_lane_u32((void*) c4, vreinterpret_u32_f16(vfp16out4x0123), 0); c4 += 2;
        vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2);
        vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2);
        vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2);
        vfp16out3x0123 = vext_f16(vfp16out3x0123, vfp16out3x0123, 2);
        vfp16out4x0123 = vext_f16(vfp16out4x0123, vfp16out4x0123, 2);
      }
      if (nc & 1) {
        vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0);
        vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0);
        vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0);
        vst1_lane_u16(c3, vreinterpret_u16_f16(vfp16out3x0123), 0);
        vst1_lane_u16(c4, vreinterpret_u16_f16(vfp16out4x0123), 0);
      }
      nc = 0;
    }
  } while (nc != 0);
}
diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c
new file mode 100644
index 00000000000..64da42e814b
--- /dev/null
+++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-5x8c4-minmax-neondotfp16arith.c
@@ -0,0 +1,443 @@
+// clang-format off
+// Auto-generated file. Do not edit!
+//   Template: src/qs8-gemm/c4-neondot.c.in
+//   Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+ +#include <arm_neon.h> +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x8c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 5); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const int8_t* a4 = (const int8_t*) ((uintptr_t) a3 + a_stride); + uint16_t* c4 = (uint16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 8 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. 
+ const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point); + int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2); + int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2); + const int32x4_t vinput_zero_point3 = vld1q_dup_s32(&quantization_params[3].zero_point); + int32x4_t vacc3x0123 = vmulq_s32(vksum0123, vinput_zero_point3); + int32x4_t vacc3x4567 = vmulq_s32(vksum4567, vinput_zero_point3); + const int32x4_t vinput_zero_point4 = vld1q_dup_s32(&quantization_params[4].zero_point); + int32x4_t vacc4x0123 = vmulq_s32(vksum0123, vinput_zero_point4); + int32x4_t vacc4x4567 = vmulq_s32(vksum4567, vinput_zero_point4); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 8; + + // Inner accumulation loop along the 8 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 5x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16; + const int8x16_t va_3x16 = vld1q_s8(a3); a3 += 16; + const int8x16_t va_4x16 = vld1q_s8(a4); a4 += 16; + + // Load a 16x8 block of weights. 
+ const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + + // Multiply-accumulate: 5x16 * 16x8 --> 5x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 
= vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x89AB, vget_high_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, vget_high_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, vget_low_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x0123, vget_low_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x4567, vget_low_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, vget_low_s8(va_3x16), 1); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x89AB, vget_high_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x89AB, vget_high_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123xCDEF, vget_high_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567xCDEF, vget_high_s8(va_3x16), 1); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, vget_low_s8(va_4x16), 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x0123, vget_low_s8(va_4x16), 0); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x4567, vget_low_s8(va_4x16), 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, vget_low_s8(va_4x16), 1); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x89AB, vget_high_s8(va_4x16), 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x89AB, vget_high_s8(va_4x16), 0); + + vacc4x0123 = 
vdotq_lane_s32(vacc4x0123, vb0123xCDEF, vget_high_s8(va_4x16), 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567xCDEF, vget_high_s8(va_4x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 5x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x01234567 = vld1_s8(a3); a3 += 8; + const int8x8_t va4x01234567 = vld1_s8(a4); a4 += 8; + + // Load a 8x8 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + + // Multiply-accumulate: 5x8 * 8x8 --> 5x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x01234567, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x01234567, 0); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x01234567, 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb4567x0123, va3x01234567, 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb4567x0123, va4x01234567, 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, va4x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 5x4 block of activations. 
+ const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4; + const int8x8_t va3x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a3)); a3 += 4; + const int8x8_t va4x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a4)); a4 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + + // Multiply-accumulate: 5x4 * 4x8 --> 5x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x0123, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x0123, 0); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x0123, 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + float32x4_t vout2x0123 = vcvtq_f32_s32(vacc2x0123); + float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567); + float32x4_t vout3x0123 = vcvtq_f32_s32(vacc3x0123); + float32x4_t vout3x4567 = vcvtq_f32_s32(vacc3x4567); + float32x4_t vout4x0123 = vcvtq_f32_s32(vacc4x0123); + float32x4_t vout4x4567 = vcvtq_f32_s32(vacc4x4567); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = 
(const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + + // Subtract out the scaled left-hand row sums. + const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]); + vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2); + vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2); + const float32x4_t lh_row_sum_3 = vld1q_dup_f32(&row_sum[3]); + vout3x0123 = vfmsq_f32(vout3x0123, biased_kernel_zero_points_0123, lh_row_sum_3); + vout3x4567 = vfmsq_f32(vout3x4567, biased_kernel_zero_points_4567, lh_row_sum_3); + const float32x4_t lh_row_sum_4 = vld1q_dup_f32(&row_sum[4]); + vout4x0123 = vfmsq_f32(vout4x0123, biased_kernel_zero_points_0123, lh_row_sum_4); + vout4x4567 = vfmsq_f32(vout4x4567, biased_kernel_zero_points_4567, lh_row_sum_4); + + // Add the product of left/right-hand zero points and `kc`. 
+ const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + const float32x4_t vscaled_input_zero_point_2 = + vdupq_n_f32((float)kc * quantization_params[2].zero_point); + const float32x4_t vscaled_input_zero_point_3 = + vdupq_n_f32((float)kc * quantization_params[3].zero_point); + const float32x4_t vscaled_input_zero_point_4 = + vdupq_n_f32((float)kc * quantization_params[4].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + vout2x0123 = + vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2); + vout2x4567 = + vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2); + vout3x0123 = + vmlaq_f32(vout3x0123, kernel_zero_points_0123, vscaled_input_zero_point_3); + vout3x4567 = + vmlaq_f32(vout3x4567, kernel_zero_points_4567, vscaled_input_zero_point_3); + vout4x0123 = + vmlaq_f32(vout4x0123, kernel_zero_points_0123, vscaled_input_zero_point_4); + vout4x4567 = + vmlaq_f32(vout4x4567, kernel_zero_points_4567, vscaled_input_zero_point_4); + const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + const float32x4_t vinput_scale23 = 
vreinterpretq_f32_s32(vld1q_s32(&quantization_params[2].zero_point)); + vout2x0123 = vmulq_lane_f32(vout2x0123, vget_low_f32(vinput_scale23), 1); + vout3x0123 = vmulq_lane_f32(vout3x0123, vget_high_f32(vinput_scale23), 1); + vout2x4567 = vmulq_lane_f32(vout2x4567, vget_low_f32(vinput_scale23), 1); + vout3x4567 = vmulq_lane_f32(vout3x4567, vget_high_f32(vinput_scale23), 1); + const float32x4_t vinput_scale4 = vld1q_dup_f32(&quantization_params[4].inv_scale); + vout4x0123 = vmulq_f32(vout4x0123, vinput_scale4); + vout4x4567 = vmulq_f32(vout4x4567, vinput_scale4); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vfmaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + vout4x0123 = vfmaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vmlaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + vout4x0123 = vmlaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vfmaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); 
+ vout4x4567 = vfmaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vmlaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); + vout4x4567 = vmlaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); + float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); + float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); + vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min); + vfp16out3x01234567 = vmaxq_f16(vfp16out3x01234567, voutput_min); + vfp16out4x01234567 = vmaxq_f16(vfp16out4x01234567, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max); + vfp16out3x01234567 = vminq_f16(vfp16out3x01234567, voutput_max); + vfp16out4x01234567 = vminq_f16(vfp16out4x01234567, voutput_max); + if XNN_LIKELY(nc >= 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c1, 
vreinterpretq_u16_f16(vfp16out1x01234567)); + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); + vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567)); + vst1q_u16(c4, vreinterpretq_u16_f16(vfp16out4x01234567)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); + + nc -= 8; + } else { + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567); + float16x4_t vfp16out3x0123 = vget_low_f16(vfp16out3x01234567); + float16x4_t vfp16out4x0123 = vget_low_f16(vfp16out4x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4; + vst1_u16(c3, vreinterpret_u16_f16(vfp16out3x0123)); c3 += 4; + vst1_u16(c4, vreinterpret_u16_f16(vfp16out4x0123)); c4 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + vfp16out2x0123 = vget_high_f16(vfp16out2x01234567); + vfp16out3x0123 = vget_high_f16(vfp16out3x01234567); + vfp16out4x0123 = vget_high_f16(vfp16out4x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_f16(vfp16out3x0123), 0); c3 += 2; + 
vst1_lane_u32((void*) c4, vreinterpret_u32_f16(vfp16out4x0123), 0); c4 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2); + vfp16out3x0123 = vext_f16(vfp16out3x0123, vfp16out3x0123, 2); + vfp16out4x0123 = vext_f16(vfp16out4x0123, vfp16out4x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0); + vst1_lane_u16(c3, vreinterpret_u16_f16(vfp16out3x0123), 0); + vst1_lane_u16(c4, vreinterpret_u16_f16(vfp16out4x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x16c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..d6b664e54f7 --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x16c4-minmax-neondotfp16arith.c @@ -0,0 +1,755 @@ +// clang-format off +// Auto-generated file. Do not edit! +// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include <arm_neon.h> +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x16c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 6); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const int8_t* a4 = (const int8_t*) ((uintptr_t) a3 + a_stride); + uint16_t* c4 = (uint16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + const int8_t* a5 = (const int8_t*) ((uintptr_t) a4 + a_stride); + uint16_t* c5 = (uint16_t*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr != 6) { + a5 = a4; + c5 = c4; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 16 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. 
+ const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum89AB = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksumCDEF = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + int32x4_t vacc0x89AB = vmulq_s32(vksum89AB, vinput_zero_point0); + int32x4_t vacc0xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + int32x4_t vacc1x89AB = vmulq_s32(vksum89AB, vinput_zero_point1); + int32x4_t vacc1xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point1); + const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point); + int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2); + int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2); + int32x4_t vacc2x89AB = vmulq_s32(vksum89AB, vinput_zero_point2); + int32x4_t vacc2xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point2); + const int32x4_t vinput_zero_point3 = vld1q_dup_s32(&quantization_params[3].zero_point); + int32x4_t vacc3x0123 = vmulq_s32(vksum0123, vinput_zero_point3); + int32x4_t vacc3x4567 = vmulq_s32(vksum4567, vinput_zero_point3); + int32x4_t vacc3x89AB = vmulq_s32(vksum89AB, vinput_zero_point3); + int32x4_t vacc3xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point3); + const int32x4_t vinput_zero_point4 = vld1q_dup_s32(&quantization_params[4].zero_point); + int32x4_t vacc4x0123 = vmulq_s32(vksum0123, vinput_zero_point4); + int32x4_t vacc4x4567 = vmulq_s32(vksum4567, vinput_zero_point4); + int32x4_t vacc4x89AB = vmulq_s32(vksum89AB, 
vinput_zero_point4); + int32x4_t vacc4xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point4); + const int32x4_t vinput_zero_point5 = vld1q_dup_s32(&quantization_params[5].zero_point); + int32x4_t vacc5x0123 = vmulq_s32(vksum0123, vinput_zero_point5); + int32x4_t vacc5x4567 = vmulq_s32(vksum4567, vinput_zero_point5); + int32x4_t vacc5x89AB = vmulq_s32(vksum89AB, vinput_zero_point5); + int32x4_t vacc5xCDEF = vmulq_s32(vksumCDEF, vinput_zero_point5); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 16; + + // Inner accumulation loop along the 16 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. + while (k >= 16 * sizeof(int8_t)) { + // Load a 6x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16; + const int8x16_t va_3x16 = vld1q_s8(a3); a3 += 16; + const int8x16_t va_4x16 = vld1q_s8(a4); a4 += 16; + const int8x16_t va_5x16 = vld1q_s8(a5); a5 += 16; + + // Load a 16x16 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb89ABx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vbCDEFx16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + const int8x16_t vb89ABx0123 = vandq_s8(vb89ABx16, vmask); + const int8x16_t vbCDEFx0123 = vandq_s8(vbCDEFx16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + const int8x16_t vb89ABx4567 = vandq_s8(vshrq_n_s8(vb89ABx16, 2), vmask); + const int8x16_t vbCDEFx4567 = vandq_s8(vshrq_n_s8(vbCDEFx16, 2), vmask); + // Third crumb. 
+ const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + const int8x16_t vb89ABx89AB = vandq_s8(vshrq_n_s8(vb89ABx16, 4), vmask); + const int8x16_t vbCDEFx89AB = vandq_s8(vshrq_n_s8(vbCDEFx16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + const int8x16_t vb89ABxCDEF = vandq_s8(vshrq_n_s8(vb89ABx16, 6), vmask); + const int8x16_t vbCDEFxCDEF = vandq_s8(vshrq_n_s8(vbCDEFx16, 6), vmask); + + // Multiply-accumulate: 6x16 * 16x16 --> 6x16. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx0123, vget_low_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx4567, vget_low_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABx89AB, vget_high_s8(va_0x16), 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFx89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb89ABxCDEF, vget_high_s8(va_0x16), 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vbCDEFxCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = 
vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx0123, vget_low_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx4567, vget_low_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABx89AB, vget_high_s8(va_1x16), 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFx89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb89ABxCDEF, vget_high_s8(va_1x16), 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vbCDEFxCDEF, vget_high_s8(va_1x16), 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx0123, vget_low_s8(va_2x16), 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx0123, vget_low_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx4567, vget_low_s8(va_2x16), 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx4567, vget_low_s8(va_2x16), 1); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, 
vb0123x89AB, vget_high_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, vget_high_s8(va_2x16), 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABx89AB, vget_high_s8(va_2x16), 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFx89AB, vget_high_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb89ABxCDEF, vget_high_s8(va_2x16), 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vbCDEFxCDEF, vget_high_s8(va_2x16), 1); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, vget_low_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x0123, vget_low_s8(va_3x16), 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx0123, vget_low_s8(va_3x16), 0); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx0123, vget_low_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x4567, vget_low_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, vget_low_s8(va_3x16), 1); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx4567, vget_low_s8(va_3x16), 1); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx4567, vget_low_s8(va_3x16), 1); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x89AB, vget_high_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x89AB, vget_high_s8(va_3x16), 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABx89AB, vget_high_s8(va_3x16), 0); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFx89AB, vget_high_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123xCDEF, vget_high_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567xCDEF, vget_high_s8(va_3x16), 1); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb89ABxCDEF, vget_high_s8(va_3x16), 1); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vbCDEFxCDEF, vget_high_s8(va_3x16), 1); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, 
vget_low_s8(va_4x16), 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x0123, vget_low_s8(va_4x16), 0); + vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABx0123, vget_low_s8(va_4x16), 0); + vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFx0123, vget_low_s8(va_4x16), 0); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x4567, vget_low_s8(va_4x16), 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, vget_low_s8(va_4x16), 1); + vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABx4567, vget_low_s8(va_4x16), 1); + vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFx4567, vget_low_s8(va_4x16), 1); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x89AB, vget_high_s8(va_4x16), 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x89AB, vget_high_s8(va_4x16), 0); + vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABx89AB, vget_high_s8(va_4x16), 0); + vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFx89AB, vget_high_s8(va_4x16), 0); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123xCDEF, vget_high_s8(va_4x16), 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567xCDEF, vget_high_s8(va_4x16), 1); + vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb89ABxCDEF, vget_high_s8(va_4x16), 1); + vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vbCDEFxCDEF, vget_high_s8(va_4x16), 1); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x0123, vget_low_s8(va_5x16), 0); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x0123, vget_low_s8(va_5x16), 0); + vacc5x89AB = vdotq_lane_s32(vacc5x89AB, vb89ABx0123, vget_low_s8(va_5x16), 0); + vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vbCDEFx0123, vget_low_s8(va_5x16), 0); + + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x4567, vget_low_s8(va_5x16), 1); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x4567, vget_low_s8(va_5x16), 1); + vacc5x89AB = vdotq_lane_s32(vacc5x89AB, vb89ABx4567, vget_low_s8(va_5x16), 1); + vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vbCDEFx4567, vget_low_s8(va_5x16), 1); + + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x89AB, vget_high_s8(va_5x16), 0); + 
vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x89AB, vget_high_s8(va_5x16), 0); + vacc5x89AB = vdotq_lane_s32(vacc5x89AB, vb89ABx89AB, vget_high_s8(va_5x16), 0); + vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vbCDEFx89AB, vget_high_s8(va_5x16), 0); + + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123xCDEF, vget_high_s8(va_5x16), 1); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567xCDEF, vget_high_s8(va_5x16), 1); + vacc5x89AB = vdotq_lane_s32(vacc5x89AB, vb89ABxCDEF, vget_high_s8(va_5x16), 1); + vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vbCDEFxCDEF, vget_high_s8(va_5x16), 1); + + k -= 16 * sizeof(int8_t); + } + // Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x89AB = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567xCDEF = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 6x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x01234567 = vld1_s8(a3); a3 += 8; + const int8x8_t va4x01234567 = vld1_s8(a4); a4 += 8; + const int8x8_t va5x01234567 = vld1_s8(a5); a5 += 8; + + // Load a 8x16 block of weights. 
+ const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb01234567x89AB, 2), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb01234567xCDEF, 2), vmask); + + // Multiply-accumulate: 6x8 * 8x16 --> 6x16. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x01234567, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x01234567, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x01234567, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x01234567, 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x01234567, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x01234567, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x01234567, 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb0123x89AB, va3x01234567, 0); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb0123xCDEF, va3x01234567, 0); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x01234567, 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x01234567, 0); + vacc4x89AB = vdotq_lane_s32(vacc4x89AB, 
vb0123x89AB, va4x01234567, 0); + vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vb0123xCDEF, va4x01234567, 0); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x0123, va5x01234567, 0); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb0123x4567, va5x01234567, 0); + vacc5x89AB = vdotq_lane_s32(vacc5x89AB, vb0123x89AB, va5x01234567, 0); + vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vb0123xCDEF, va5x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb4567x89AB, va0x01234567, 1); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb4567xCDEF, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb4567x89AB, va1x01234567, 1); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb4567xCDEF, va1x01234567, 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb4567x89AB, va2x01234567, 1); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb4567xCDEF, va2x01234567, 1); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb4567x0123, va3x01234567, 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb4567x89AB, va3x01234567, 1); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb4567xCDEF, va3x01234567, 1); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb4567x0123, va4x01234567, 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, va4x01234567, 1); + vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb4567x89AB, va4x01234567, 1); + vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vb4567xCDEF, va4x01234567, 1); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb4567x0123, va5x01234567, 1); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x4567, va5x01234567, 
1); + vacc5x89AB = vdotq_lane_s32(vacc5x89AB, vb4567x89AB, va5x01234567, 1); + vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vb4567xCDEF, va5x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + vb01234567x89AB = vshrq_n_s8(vb01234567x89AB, 4); + vb01234567xCDEF = vshrq_n_s8(vb01234567xCDEF, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 6x4 block of activations. + const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4; + const int8x8_t va3x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a3)); a3 += 4; + const int8x8_t va4x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a4)); a4 += 4; + const int8x8_t va5x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a5)); a5 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb0123x89AB = vandq_s8(vb01234567x89AB, vmask); + const int8x16_t vb0123xCDEF = vandq_s8(vb01234567xCDEF, vmask); + + // Multiply-accumulate: 6x4 * 4x16 --> 6x16. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc0x89AB = vdotq_lane_s32(vacc0x89AB, vb0123x89AB, va0x0123, 0); + vacc0xCDEF = vdotq_lane_s32(vacc0xCDEF, vb0123xCDEF, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + vacc1x89AB = vdotq_lane_s32(vacc1x89AB, vb0123x89AB, va1x0123, 0); + vacc1xCDEF = vdotq_lane_s32(vacc1xCDEF, vb0123xCDEF, va1x0123, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0); + vacc2x89AB = vdotq_lane_s32(vacc2x89AB, vb0123x89AB, va2x0123, 0); + vacc2xCDEF = vdotq_lane_s32(vacc2xCDEF, vb0123xCDEF, va2x0123, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x0123, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x0123, 0); + vacc3x89AB = vdotq_lane_s32(vacc3x89AB, vb0123x89AB, va3x0123, 0); + vacc3xCDEF = vdotq_lane_s32(vacc3xCDEF, vb0123xCDEF, va3x0123, 0); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x0123, 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x0123, 0); + vacc4x89AB = vdotq_lane_s32(vacc4x89AB, vb0123x89AB, va4x0123, 0); + vacc4xCDEF = vdotq_lane_s32(vacc4xCDEF, vb0123xCDEF, va4x0123, 0); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x0123, va5x0123, 0); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb0123x4567, va5x0123, 0); + vacc5x89AB = vdotq_lane_s32(vacc5x89AB, vb0123x89AB, va5x0123, 0); + vacc5xCDEF = vdotq_lane_s32(vacc5xCDEF, vb0123xCDEF, va5x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout0x89AB = vcvtq_f32_s32(vacc0x89AB); + float32x4_t vout0xCDEF = vcvtq_f32_s32(vacc0xCDEF); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + 
float32x4_t vout1x89AB = vcvtq_f32_s32(vacc1x89AB); + float32x4_t vout1xCDEF = vcvtq_f32_s32(vacc1xCDEF); + float32x4_t vout2x0123 = vcvtq_f32_s32(vacc2x0123); + float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567); + float32x4_t vout2x89AB = vcvtq_f32_s32(vacc2x89AB); + float32x4_t vout2xCDEF = vcvtq_f32_s32(vacc2xCDEF); + float32x4_t vout3x0123 = vcvtq_f32_s32(vacc3x0123); + float32x4_t vout3x4567 = vcvtq_f32_s32(vacc3x4567); + float32x4_t vout3x89AB = vcvtq_f32_s32(vacc3x89AB); + float32x4_t vout3xCDEF = vcvtq_f32_s32(vacc3xCDEF); + float32x4_t vout4x0123 = vcvtq_f32_s32(vacc4x0123); + float32x4_t vout4x4567 = vcvtq_f32_s32(vacc4x4567); + float32x4_t vout4x89AB = vcvtq_f32_s32(vacc4x89AB); + float32x4_t vout4xCDEF = vcvtq_f32_s32(vacc4xCDEF); + float32x4_t vout5x0123 = vcvtq_f32_s32(vacc5x0123); + float32x4_t vout5x4567 = vcvtq_f32_s32(vacc5x4567); + float32x4_t vout5x89AB = vcvtq_f32_s32(vacc5x89AB); + float32x4_t vout5xCDEF = vcvtq_f32_s32(vacc5xCDEF); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + const float32x4_t kernel_zero_points_89AB = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_89AB = vaddq_f32(kernel_zero_points_89AB, vtwo); + const float32x4_t kernel_zero_points_CDEF = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_CDEF = vaddq_f32(kernel_zero_points_CDEF, vtwo); + + // Subtract out the scaled left-hand row sums. 
+ const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + vout0x89AB = vfmsq_f32(vout0x89AB, biased_kernel_zero_points_89AB, lh_row_sum_0); + vout0xCDEF = vfmsq_f32(vout0xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + vout1x89AB = vfmsq_f32(vout1x89AB, biased_kernel_zero_points_89AB, lh_row_sum_1); + vout1xCDEF = vfmsq_f32(vout1xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_1); + const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]); + vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2); + vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2); + vout2x89AB = vfmsq_f32(vout2x89AB, biased_kernel_zero_points_89AB, lh_row_sum_2); + vout2xCDEF = vfmsq_f32(vout2xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_2); + const float32x4_t lh_row_sum_3 = vld1q_dup_f32(&row_sum[3]); + vout3x0123 = vfmsq_f32(vout3x0123, biased_kernel_zero_points_0123, lh_row_sum_3); + vout3x4567 = vfmsq_f32(vout3x4567, biased_kernel_zero_points_4567, lh_row_sum_3); + vout3x89AB = vfmsq_f32(vout3x89AB, biased_kernel_zero_points_89AB, lh_row_sum_3); + vout3xCDEF = vfmsq_f32(vout3xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_3); + const float32x4_t lh_row_sum_4 = vld1q_dup_f32(&row_sum[4]); + vout4x0123 = vfmsq_f32(vout4x0123, biased_kernel_zero_points_0123, lh_row_sum_4); + vout4x4567 = vfmsq_f32(vout4x4567, biased_kernel_zero_points_4567, lh_row_sum_4); + vout4x89AB = vfmsq_f32(vout4x89AB, biased_kernel_zero_points_89AB, lh_row_sum_4); + vout4xCDEF = vfmsq_f32(vout4xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_4); + 
const float32x4_t lh_row_sum_5 = vld1q_dup_f32(&row_sum[5]); + vout5x0123 = vfmsq_f32(vout5x0123, biased_kernel_zero_points_0123, lh_row_sum_5); + vout5x4567 = vfmsq_f32(vout5x4567, biased_kernel_zero_points_4567, lh_row_sum_5); + vout5x89AB = vfmsq_f32(vout5x89AB, biased_kernel_zero_points_89AB, lh_row_sum_5); + vout5xCDEF = vfmsq_f32(vout5xCDEF, biased_kernel_zero_points_CDEF, lh_row_sum_5); + + // Add the product of left/right-hand zero points and `kc`. + const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + const float32x4_t vscaled_input_zero_point_2 = + vdupq_n_f32((float)kc * quantization_params[2].zero_point); + const float32x4_t vscaled_input_zero_point_3 = + vdupq_n_f32((float)kc * quantization_params[3].zero_point); + const float32x4_t vscaled_input_zero_point_4 = + vdupq_n_f32((float)kc * quantization_params[4].zero_point); + const float32x4_t vscaled_input_zero_point_5 = + vdupq_n_f32((float)kc * quantization_params[5].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout0x89AB = + vmlaq_f32(vout0x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_0); + vout0xCDEF = + vmlaq_f32(vout0xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + vout1x89AB = + vmlaq_f32(vout1x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_1); + vout1xCDEF = + vmlaq_f32(vout1xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_1); + vout2x0123 = + vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2); + vout2x4567 = + 
vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2); + vout2x89AB = + vmlaq_f32(vout2x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_2); + vout2xCDEF = + vmlaq_f32(vout2xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_2); + vout3x0123 = + vmlaq_f32(vout3x0123, kernel_zero_points_0123, vscaled_input_zero_point_3); + vout3x4567 = + vmlaq_f32(vout3x4567, kernel_zero_points_4567, vscaled_input_zero_point_3); + vout3x89AB = + vmlaq_f32(vout3x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_3); + vout3xCDEF = + vmlaq_f32(vout3xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_3); + vout4x0123 = + vmlaq_f32(vout4x0123, kernel_zero_points_0123, vscaled_input_zero_point_4); + vout4x4567 = + vmlaq_f32(vout4x4567, kernel_zero_points_4567, vscaled_input_zero_point_4); + vout4x89AB = + vmlaq_f32(vout4x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_4); + vout4xCDEF = + vmlaq_f32(vout4xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_4); + vout5x0123 = + vmlaq_f32(vout5x0123, kernel_zero_points_0123, vscaled_input_zero_point_5); + vout5x4567 = + vmlaq_f32(vout5x4567, kernel_zero_points_4567, vscaled_input_zero_point_5); + vout5x89AB = + vmlaq_f32(vout5x89AB, kernel_zero_points_89AB, vscaled_input_zero_point_5); + vout5xCDEF = + vmlaq_f32(vout5xCDEF, kernel_zero_points_CDEF, vscaled_input_zero_point_5); + const float32x4_t vinput_scale01 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + vout0x89AB = vmulq_lane_f32(vout0x89AB, vget_low_f32(vinput_scale01), 1); + vout1x89AB = vmulq_lane_f32(vout1x89AB, vget_high_f32(vinput_scale01), 1); + vout0xCDEF = vmulq_lane_f32(vout0xCDEF, 
vget_low_f32(vinput_scale01), 1); + vout1xCDEF = vmulq_lane_f32(vout1xCDEF, vget_high_f32(vinput_scale01), 1); + const float32x4_t vinput_scale23 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[2].zero_point)); + vout2x0123 = vmulq_lane_f32(vout2x0123, vget_low_f32(vinput_scale23), 1); + vout3x0123 = vmulq_lane_f32(vout3x0123, vget_high_f32(vinput_scale23), 1); + vout2x4567 = vmulq_lane_f32(vout2x4567, vget_low_f32(vinput_scale23), 1); + vout3x4567 = vmulq_lane_f32(vout3x4567, vget_high_f32(vinput_scale23), 1); + vout2x89AB = vmulq_lane_f32(vout2x89AB, vget_low_f32(vinput_scale23), 1); + vout3x89AB = vmulq_lane_f32(vout3x89AB, vget_high_f32(vinput_scale23), 1); + vout2xCDEF = vmulq_lane_f32(vout2xCDEF, vget_low_f32(vinput_scale23), 1); + vout3xCDEF = vmulq_lane_f32(vout3xCDEF, vget_high_f32(vinput_scale23), 1); + const float32x4_t vinput_scale45 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[4].zero_point)); + vout4x0123 = vmulq_lane_f32(vout4x0123, vget_low_f32(vinput_scale45), 1); + vout5x0123 = vmulq_lane_f32(vout5x0123, vget_high_f32(vinput_scale45), 1); + vout4x4567 = vmulq_lane_f32(vout4x4567, vget_low_f32(vinput_scale45), 1); + vout5x4567 = vmulq_lane_f32(vout5x4567, vget_high_f32(vinput_scale45), 1); + vout4x89AB = vmulq_lane_f32(vout4x89AB, vget_low_f32(vinput_scale45), 1); + vout5x89AB = vmulq_lane_f32(vout5x89AB, vget_high_f32(vinput_scale45), 1); + vout4xCDEF = vmulq_lane_f32(vout4xCDEF, vget_low_f32(vinput_scale45), 1); + vout5xCDEF = vmulq_lane_f32(vout5xCDEF, vget_high_f32(vinput_scale45), 1); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale89AB = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scaleCDEF = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + 
vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vfmaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + vout4x0123 = vfmaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123); + vout5x0123 = vfmaq_f32(vbias0123, vout5x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vmlaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + vout4x0123 = vmlaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123); + vout5x0123 = vmlaq_f32(vbias0123, vout5x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vfmaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); + vout4x4567 = vfmaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567); + vout5x4567 = vfmaq_f32(vbias4567, vout5x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vmlaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); + vout4x4567 = vmlaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567); + vout5x4567 = vmlaq_f32(vbias4567, vout5x4567, vfilter_output_scale4567); + #endif + const float32x4_t vbias89AB = vld1q_f32(w); w = (const float*) w + 4; + 
#if XNN_ARCH_ARM64 + vout0x89AB = vfmaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vfmaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + vout2x89AB = vfmaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB); + vout3x89AB = vfmaq_f32(vbias89AB, vout3x89AB, vfilter_output_scale89AB); + vout4x89AB = vfmaq_f32(vbias89AB, vout4x89AB, vfilter_output_scale89AB); + vout5x89AB = vfmaq_f32(vbias89AB, vout5x89AB, vfilter_output_scale89AB); + #else + vout0x89AB = vmlaq_f32(vbias89AB, vout0x89AB, vfilter_output_scale89AB); + vout1x89AB = vmlaq_f32(vbias89AB, vout1x89AB, vfilter_output_scale89AB); + vout2x89AB = vmlaq_f32(vbias89AB, vout2x89AB, vfilter_output_scale89AB); + vout3x89AB = vmlaq_f32(vbias89AB, vout3x89AB, vfilter_output_scale89AB); + vout4x89AB = vmlaq_f32(vbias89AB, vout4x89AB, vfilter_output_scale89AB); + vout5x89AB = vmlaq_f32(vbias89AB, vout5x89AB, vfilter_output_scale89AB); + #endif + const float32x4_t vbiasCDEF = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0xCDEF = vfmaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vfmaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + vout2xCDEF = vfmaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF); + vout3xCDEF = vfmaq_f32(vbiasCDEF, vout3xCDEF, vfilter_output_scaleCDEF); + vout4xCDEF = vfmaq_f32(vbiasCDEF, vout4xCDEF, vfilter_output_scaleCDEF); + vout5xCDEF = vfmaq_f32(vbiasCDEF, vout5xCDEF, vfilter_output_scaleCDEF); + #else + vout0xCDEF = vmlaq_f32(vbiasCDEF, vout0xCDEF, vfilter_output_scaleCDEF); + vout1xCDEF = vmlaq_f32(vbiasCDEF, vout1xCDEF, vfilter_output_scaleCDEF); + vout2xCDEF = vmlaq_f32(vbiasCDEF, vout2xCDEF, vfilter_output_scaleCDEF); + vout3xCDEF = vmlaq_f32(vbiasCDEF, vout3xCDEF, vfilter_output_scaleCDEF); + vout4xCDEF = vmlaq_f32(vbiasCDEF, vout4xCDEF, vfilter_output_scaleCDEF); + vout5xCDEF = vmlaq_f32(vbiasCDEF, vout5xCDEF, vfilter_output_scaleCDEF); + #endif + + float16x8_t vfp16out0x01234567 = 
vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); + float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); + float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); + float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); + float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); + float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); + float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF)); + float16x8_t vfp16out5x01234567 = vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); + float16x8_t vfp16out5x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout5x89AB), vcvt_f16_f32(vout5xCDEF)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); + vfp16out1x89ABCDEF = vmaxq_f16(vfp16out1x89ABCDEF, voutput_min); + vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min); + vfp16out2x89ABCDEF = vmaxq_f16(vfp16out2x89ABCDEF, voutput_min); + vfp16out3x01234567 = vmaxq_f16(vfp16out3x01234567, voutput_min); + vfp16out3x89ABCDEF = vmaxq_f16(vfp16out3x89ABCDEF, voutput_min); + vfp16out4x01234567 = vmaxq_f16(vfp16out4x01234567, voutput_min); + vfp16out4x89ABCDEF = vmaxq_f16(vfp16out4x89ABCDEF, voutput_min); + vfp16out5x01234567 = 
vmaxq_f16(vfp16out5x01234567, voutput_min); + vfp16out5x89ABCDEF = vmaxq_f16(vfp16out5x89ABCDEF, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out0x89ABCDEF = vminq_f16(vfp16out0x89ABCDEF, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + vfp16out1x89ABCDEF = vminq_f16(vfp16out1x89ABCDEF, voutput_max); + vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max); + vfp16out2x89ABCDEF = vminq_f16(vfp16out2x89ABCDEF, voutput_max); + vfp16out3x01234567 = vminq_f16(vfp16out3x01234567, voutput_max); + vfp16out3x89ABCDEF = vminq_f16(vfp16out3x89ABCDEF, voutput_max); + vfp16out4x01234567 = vminq_f16(vfp16out4x01234567, voutput_max); + vfp16out4x89ABCDEF = vminq_f16(vfp16out4x89ABCDEF, voutput_max); + vfp16out5x01234567 = vminq_f16(vfp16out5x01234567, voutput_max); + vfp16out5x89ABCDEF = vminq_f16(vfp16out5x89ABCDEF, voutput_max); + if XNN_LIKELY(nc >= 16) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c0 + 8, vreinterpretq_u16_f16(vfp16out0x89ABCDEF)); + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); + vst1q_u16(c1 + 8, vreinterpretq_u16_f16(vfp16out1x89ABCDEF)); + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); + vst1q_u16(c2 + 8, vreinterpretq_u16_f16(vfp16out2x89ABCDEF)); + vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567)); + vst1q_u16(c3 + 8, vreinterpretq_u16_f16(vfp16out3x89ABCDEF)); + vst1q_u16(c4, vreinterpretq_u16_f16(vfp16out4x01234567)); + vst1q_u16(c4 + 8, vreinterpretq_u16_f16(vfp16out4x89ABCDEF)); + vst1q_u16(c5, vreinterpretq_u16_f16(vfp16out5x01234567)); + vst1q_u16(c5 + 8, vreinterpretq_u16_f16(vfp16out5x89ABCDEF)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const 
int8_t*) ((uintptr_t) a4 - kc); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); + c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride); + + nc -= 16; + } else { + if (nc & 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); c0 += 8; + vfp16out0x01234567 = vfp16out0x89ABCDEF; + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); c1 += 8; + vfp16out1x01234567 = vfp16out1x89ABCDEF; + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); c2 += 8; + vfp16out2x01234567 = vfp16out2x89ABCDEF; + vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567)); c3 += 8; + vfp16out3x01234567 = vfp16out3x89ABCDEF; + vst1q_u16(c4, vreinterpretq_u16_f16(vfp16out4x01234567)); c4 += 8; + vfp16out4x01234567 = vfp16out4x89ABCDEF; + vst1q_u16(c5, vreinterpretq_u16_f16(vfp16out5x01234567)); c5 += 8; + vfp16out5x01234567 = vfp16out5x89ABCDEF; + } + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567); + float16x4_t vfp16out3x0123 = vget_low_f16(vfp16out3x01234567); + float16x4_t vfp16out4x0123 = vget_low_f16(vfp16out4x01234567); + float16x4_t vfp16out5x0123 = vget_low_f16(vfp16out5x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4; + vst1_u16(c3, vreinterpret_u16_f16(vfp16out3x0123)); c3 += 4; + vst1_u16(c4, vreinterpret_u16_f16(vfp16out4x0123)); c4 += 4; + vst1_u16(c5, vreinterpret_u16_f16(vfp16out5x0123)); c5 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + vfp16out2x0123 
= vget_high_f16(vfp16out2x01234567); + vfp16out3x0123 = vget_high_f16(vfp16out3x01234567); + vfp16out4x0123 = vget_high_f16(vfp16out4x01234567); + vfp16out5x0123 = vget_high_f16(vfp16out5x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_f16(vfp16out3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_f16(vfp16out4x0123), 0); c4 += 2; + vst1_lane_u32((void*) c5, vreinterpret_u32_f16(vfp16out5x0123), 0); c5 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2); + vfp16out3x0123 = vext_f16(vfp16out3x0123, vfp16out3x0123, 2); + vfp16out4x0123 = vext_f16(vfp16out4x0123, vfp16out4x0123, 2); + vfp16out5x0123 = vext_f16(vfp16out5x0123, vfp16out5x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0); + vst1_lane_u16(c3, vreinterpret_u16_f16(vfp16out3x0123), 0); + vst1_lane_u16(c4, vreinterpret_u16_f16(vfp16out4x0123), 0); + vst1_lane_u16(c5, vreinterpret_u16_f16(vfp16out5x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c new file mode 100644 index 00000000000..91d94f2a2fc --- /dev/null +++ b/src/qd8-f16-qc2w-gemm/gen/qd8-f16-qc2w-gemm-6x8c4-minmax-neondotfp16arith.c @@ -0,0 +1,501 @@ +// clang-format off +// Auto-generated file. Do not edit! 
+// Template: src/qs8-gemm/c4-neondot.c.in +// Generator: tools/xngen +// +// Copyright 2020 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <arm_neon.h> + +#include "src/xnnpack/common.h" +#include "src/xnnpack/gemm.h" +#include "src/xnnpack/math.h" +#include "src/xnnpack/microparams.h" + + + +void xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith( + size_t mr, + size_t nc, + size_t kc, + const int8_t* restrict a, + size_t a_stride, + const void* restrict w, + xnn_float16* restrict c, + size_t cm_stride, + size_t cn_stride, + const struct xnn_f16_minmax_params* restrict params, + const float* row_sum, + const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS +{ + assert(mr != 0); + assert(mr <= 6); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(int8_t) == 0); + assert(a != NULL); + assert(w != NULL); + assert(c != NULL); + + kc = round_up_po2(kc, 4 * sizeof(int8_t)); + const int8_t* a0 = a; + uint16_t* c0 = (uint16_t*) c; + const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const int8_t* a4 = (const int8_t*) ((uintptr_t) a3 + a_stride); + uint16_t* c4 = (uint16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + const int8_t* a5 = (const int8_t*) ((uintptr_t) a4 + a_stride); + uint16_t* c5 = (uint16_t*) ((uintptr_t) c4 + cm_stride); + if 
XNN_UNPREDICTABLE(mr != 6) { + a5 = a4; + c5 = c4; + } + + const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); + // Loop over groups of 8 columns. + do { + // Initialize the bias with the scaled left-hand weight sums. + const int32x4_t vksum0123 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vksum4567 = vld1q_s32(w); w = (const int32_t*) w + 4; + const int32x4_t vinput_zero_point0 = vld1q_dup_s32(&quantization_params[0].zero_point); + int32x4_t vacc0x0123 = vmulq_s32(vksum0123, vinput_zero_point0); + int32x4_t vacc0x4567 = vmulq_s32(vksum4567, vinput_zero_point0); + const int32x4_t vinput_zero_point1 = vld1q_dup_s32(&quantization_params[1].zero_point); + int32x4_t vacc1x0123 = vmulq_s32(vksum0123, vinput_zero_point1); + int32x4_t vacc1x4567 = vmulq_s32(vksum4567, vinput_zero_point1); + const int32x4_t vinput_zero_point2 = vld1q_dup_s32(&quantization_params[2].zero_point); + int32x4_t vacc2x0123 = vmulq_s32(vksum0123, vinput_zero_point2); + int32x4_t vacc2x4567 = vmulq_s32(vksum4567, vinput_zero_point2); + const int32x4_t vinput_zero_point3 = vld1q_dup_s32(&quantization_params[3].zero_point); + int32x4_t vacc3x0123 = vmulq_s32(vksum0123, vinput_zero_point3); + int32x4_t vacc3x4567 = vmulq_s32(vksum4567, vinput_zero_point3); + const int32x4_t vinput_zero_point4 = vld1q_dup_s32(&quantization_params[4].zero_point); + int32x4_t vacc4x0123 = vmulq_s32(vksum0123, vinput_zero_point4); + int32x4_t vacc4x4567 = vmulq_s32(vksum4567, vinput_zero_point4); + const int32x4_t vinput_zero_point5 = vld1q_dup_s32(&quantization_params[5].zero_point); + int32x4_t vacc5x0123 = vmulq_s32(vksum0123, vinput_zero_point5); + int32x4_t vacc5x4567 = vmulq_s32(vksum4567, vinput_zero_point5); + // TODO: move kernel zero point after weights + const void* kzp = w; + w = (const float*)w + 8; + + // Inner accumulation loop along the 8 columns. + size_t k = kc; + // 4x partial unrolled loop to load 16 bytes at a time. 
+ while (k >= 16 * sizeof(int8_t)) { + // Load a 6x16 block of activations. + const int8x16_t va_0x16 = vld1q_s8(a0); a0 += 16; + const int8x16_t va_1x16 = vld1q_s8(a1); a1 += 16; + const int8x16_t va_2x16 = vld1q_s8(a2); a2 += 16; + const int8x16_t va_3x16 = vld1q_s8(a3); a3 += 16; + const int8x16_t va_4x16 = vld1q_s8(a4); a4 += 16; + const int8x16_t va_5x16 = vld1q_s8(a5); a5 += 16; + + // Load a 16x8 block of weights. + const int8x16_t vb0123x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + const int8x16_t vb4567x16 = vld1q_s8(w); w = (const int8_t*) w + 16; + // First crumb. + const int8x16_t vb0123x0123 = vandq_s8(vb0123x16, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vb4567x16, vmask); + // Second crumb. + const int8x16_t vb0123x4567 = vandq_s8(vshrq_n_s8(vb0123x16, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb4567x16, 2), vmask); + // Third crumb. + const int8x16_t vb0123x89AB = vandq_s8(vshrq_n_s8(vb0123x16, 4), vmask); + const int8x16_t vb4567x89AB = vandq_s8(vshrq_n_s8(vb4567x16, 4), vmask); + // Fourth crumb. + const int8x16_t vb0123xCDEF = vandq_s8(vshrq_n_s8(vb0123x16, 6), vmask); + const int8x16_t vb4567xCDEF = vandq_s8(vshrq_n_s8(vb4567x16, 6), vmask); + + // Multiply-accumulate: 6x16 * 16x8 --> 6x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, vget_low_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x0123, vget_low_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x4567, vget_low_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, vget_low_s8(va_0x16), 1); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x89AB, vget_high_s8(va_0x16), 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x89AB, vget_high_s8(va_0x16), 0); + + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123xCDEF, vget_high_s8(va_0x16), 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567xCDEF, vget_high_s8(va_0x16), 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, vget_low_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x0123, vget_low_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x4567, vget_low_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, vget_low_s8(va_1x16), 1); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x89AB, vget_high_s8(va_1x16), 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x89AB, vget_high_s8(va_1x16), 0); + + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123xCDEF, vget_high_s8(va_1x16), 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567xCDEF, vget_high_s8(va_1x16), 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, vget_low_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x0123, vget_low_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x4567, vget_low_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, vget_low_s8(va_2x16), 1); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x89AB, vget_high_s8(va_2x16), 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x89AB, vget_high_s8(va_2x16), 0); + + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123xCDEF, vget_high_s8(va_2x16), 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567xCDEF, vget_high_s8(va_2x16), 1); + vacc3x0123 = 
vdotq_lane_s32(vacc3x0123, vb0123x0123, vget_low_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x0123, vget_low_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x4567, vget_low_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, vget_low_s8(va_3x16), 1); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x89AB, vget_high_s8(va_3x16), 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x89AB, vget_high_s8(va_3x16), 0); + + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123xCDEF, vget_high_s8(va_3x16), 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567xCDEF, vget_high_s8(va_3x16), 1); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, vget_low_s8(va_4x16), 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x0123, vget_low_s8(va_4x16), 0); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x4567, vget_low_s8(va_4x16), 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, vget_low_s8(va_4x16), 1); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x89AB, vget_high_s8(va_4x16), 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x89AB, vget_high_s8(va_4x16), 0); + + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123xCDEF, vget_high_s8(va_4x16), 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567xCDEF, vget_high_s8(va_4x16), 1); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x0123, vget_low_s8(va_5x16), 0); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x0123, vget_low_s8(va_5x16), 0); + + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x4567, vget_low_s8(va_5x16), 1); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x4567, vget_low_s8(va_5x16), 1); + + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x89AB, vget_high_s8(va_5x16), 0); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x89AB, vget_high_s8(va_5x16), 0); + + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123xCDEF, vget_high_s8(va_5x16), 1); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567xCDEF, vget_high_s8(va_5x16), 1); + + k -= 16 * sizeof(int8_t); + } + // 
Handle up to 8 final positions of `k`. + if XNN_UNLIKELY(k > 0) { + int8x16_t vb01234567x0123 = vld1q_s8(w); w = (const int8_t*) w + 16; + int8x16_t vb01234567x4567 = vld1q_s8(w); w = (const int8_t*) w + 16; + // 2x partial unrolled loop to load 8 bytes at a time. + while (k >= 8 * sizeof(int8_t)) { + // Load a 6x8 block of activations. + const int8x8_t va0x01234567 = vld1_s8(a0); a0 += 8; + const int8x8_t va1x01234567 = vld1_s8(a1); a1 += 8; + const int8x8_t va2x01234567 = vld1_s8(a2); a2 += 8; + const int8x8_t va3x01234567 = vld1_s8(a3); a3 += 8; + const int8x8_t va4x01234567 = vld1_s8(a4); a4 += 8; + const int8x8_t va5x01234567 = vld1_s8(a5); a5 += 8; + + // Load a 8x8 block of weights. + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + const int8x16_t vb4567x0123 = vandq_s8(vshrq_n_s8(vb01234567x0123, 2), vmask); + const int8x16_t vb4567x4567 = vandq_s8(vshrq_n_s8(vb01234567x4567, 2), vmask); + + // Multiply-accumulate: 6x8 * 8x8 --> 6x8. 
+ vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x01234567, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x01234567, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x01234567, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x01234567, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x01234567, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x01234567, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x01234567, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x01234567, 0); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x01234567, 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x01234567, 0); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x0123, va5x01234567, 0); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb0123x4567, va5x01234567, 0); + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb4567x0123, va0x01234567, 1); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb4567x4567, va0x01234567, 1); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb4567x0123, va1x01234567, 1); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb4567x4567, va1x01234567, 1); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb4567x0123, va2x01234567, 1); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb4567x4567, va2x01234567, 1); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb4567x0123, va3x01234567, 1); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb4567x4567, va3x01234567, 1); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb4567x0123, va4x01234567, 1); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb4567x4567, va4x01234567, 1); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb4567x0123, va5x01234567, 1); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb4567x4567, va5x01234567, 1); + + k -= 8 * sizeof(int8_t); + vb01234567x0123 = vshrq_n_s8(vb01234567x0123, 4); + vb01234567x4567 = vshrq_n_s8(vb01234567x4567, 4); + } + // Handle up to 4 final positions of `k` + if XNN_UNLIKELY(k != 0) { + // Load a 6x4 block of activations. 
+ const int8x8_t va0x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a0)); a0 += 4; + const int8x8_t va1x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a1)); a1 += 4; + const int8x8_t va2x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a2)); a2 += 4; + const int8x8_t va3x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a3)); a3 += 4; + const int8x8_t va4x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a4)); a4 += 4; + const int8x8_t va5x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a5)); a5 += 4; + + const int8x16_t vb0123x0123 = vandq_s8(vb01234567x0123, vmask); + const int8x16_t vb0123x4567 = vandq_s8(vb01234567x4567, vmask); + + // Multiply-accumulate: 6x4 * 4x8 --> 6x8. + vacc0x0123 = vdotq_lane_s32(vacc0x0123, vb0123x0123, va0x0123, 0); + vacc0x4567 = vdotq_lane_s32(vacc0x4567, vb0123x4567, va0x0123, 0); + vacc1x0123 = vdotq_lane_s32(vacc1x0123, vb0123x0123, va1x0123, 0); + vacc1x4567 = vdotq_lane_s32(vacc1x4567, vb0123x4567, va1x0123, 0); + vacc2x0123 = vdotq_lane_s32(vacc2x0123, vb0123x0123, va2x0123, 0); + vacc2x4567 = vdotq_lane_s32(vacc2x4567, vb0123x4567, va2x0123, 0); + vacc3x0123 = vdotq_lane_s32(vacc3x0123, vb0123x0123, va3x0123, 0); + vacc3x4567 = vdotq_lane_s32(vacc3x4567, vb0123x4567, va3x0123, 0); + vacc4x0123 = vdotq_lane_s32(vacc4x0123, vb0123x0123, va4x0123, 0); + vacc4x4567 = vdotq_lane_s32(vacc4x4567, vb0123x4567, va4x0123, 0); + vacc5x0123 = vdotq_lane_s32(vacc5x0123, vb0123x0123, va5x0123, 0); + vacc5x4567 = vdotq_lane_s32(vacc5x4567, vb0123x4567, va5x0123, 0); + } + } + + float32x4_t vout0x0123 = vcvtq_f32_s32(vacc0x0123); + float32x4_t vout0x4567 = vcvtq_f32_s32(vacc0x4567); + float32x4_t vout1x0123 = vcvtq_f32_s32(vacc1x0123); + float32x4_t vout1x4567 = vcvtq_f32_s32(vacc1x4567); + float32x4_t vout2x0123 = vcvtq_f32_s32(vacc2x0123); + float32x4_t vout2x4567 = vcvtq_f32_s32(vacc2x4567); + float32x4_t vout3x0123 = vcvtq_f32_s32(vacc3x0123); + float32x4_t vout3x4567 = 
vcvtq_f32_s32(vacc3x4567); + float32x4_t vout4x0123 = vcvtq_f32_s32(vacc4x0123); + float32x4_t vout4x4567 = vcvtq_f32_s32(vacc4x4567); + float32x4_t vout5x0123 = vcvtq_f32_s32(vacc5x0123); + float32x4_t vout5x4567 = vcvtq_f32_s32(vacc5x4567); + const float32x4_t vtwo = vdupq_n_f32(2.0f); + const float32x4_t kernel_zero_points_0123 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_0123 = vaddq_f32(kernel_zero_points_0123, vtwo); + const float32x4_t kernel_zero_points_4567 = vld1q_f32(kzp); kzp = (const float*)kzp + 4; + const float32x4_t biased_kernel_zero_points_4567 = vaddq_f32(kernel_zero_points_4567, vtwo); + + // Subtract out the scaled left-hand row sums. + const float32x4_t lh_row_sum_0 = vld1q_dup_f32(&row_sum[0]); + vout0x0123 = vfmsq_f32(vout0x0123, biased_kernel_zero_points_0123, lh_row_sum_0); + vout0x4567 = vfmsq_f32(vout0x4567, biased_kernel_zero_points_4567, lh_row_sum_0); + const float32x4_t lh_row_sum_1 = vld1q_dup_f32(&row_sum[1]); + vout1x0123 = vfmsq_f32(vout1x0123, biased_kernel_zero_points_0123, lh_row_sum_1); + vout1x4567 = vfmsq_f32(vout1x4567, biased_kernel_zero_points_4567, lh_row_sum_1); + const float32x4_t lh_row_sum_2 = vld1q_dup_f32(&row_sum[2]); + vout2x0123 = vfmsq_f32(vout2x0123, biased_kernel_zero_points_0123, lh_row_sum_2); + vout2x4567 = vfmsq_f32(vout2x4567, biased_kernel_zero_points_4567, lh_row_sum_2); + const float32x4_t lh_row_sum_3 = vld1q_dup_f32(&row_sum[3]); + vout3x0123 = vfmsq_f32(vout3x0123, biased_kernel_zero_points_0123, lh_row_sum_3); + vout3x4567 = vfmsq_f32(vout3x4567, biased_kernel_zero_points_4567, lh_row_sum_3); + const float32x4_t lh_row_sum_4 = vld1q_dup_f32(&row_sum[4]); + vout4x0123 = vfmsq_f32(vout4x0123, biased_kernel_zero_points_0123, lh_row_sum_4); + vout4x4567 = vfmsq_f32(vout4x4567, biased_kernel_zero_points_4567, lh_row_sum_4); + const float32x4_t lh_row_sum_5 = vld1q_dup_f32(&row_sum[5]); + vout5x0123 = vfmsq_f32(vout5x0123, biased_kernel_zero_points_0123, 
lh_row_sum_5); + vout5x4567 = vfmsq_f32(vout5x4567, biased_kernel_zero_points_4567, lh_row_sum_5); + + // Add the product of left/right-hand zero points and `kc`. + const float32x4_t vscaled_input_zero_point_0 = + vdupq_n_f32((float)kc * quantization_params[0].zero_point); + const float32x4_t vscaled_input_zero_point_1 = + vdupq_n_f32((float)kc * quantization_params[1].zero_point); + const float32x4_t vscaled_input_zero_point_2 = + vdupq_n_f32((float)kc * quantization_params[2].zero_point); + const float32x4_t vscaled_input_zero_point_3 = + vdupq_n_f32((float)kc * quantization_params[3].zero_point); + const float32x4_t vscaled_input_zero_point_4 = + vdupq_n_f32((float)kc * quantization_params[4].zero_point); + const float32x4_t vscaled_input_zero_point_5 = + vdupq_n_f32((float)kc * quantization_params[5].zero_point); + vout0x0123 = + vmlaq_f32(vout0x0123, kernel_zero_points_0123, vscaled_input_zero_point_0); + vout0x4567 = + vmlaq_f32(vout0x4567, kernel_zero_points_4567, vscaled_input_zero_point_0); + vout1x0123 = + vmlaq_f32(vout1x0123, kernel_zero_points_0123, vscaled_input_zero_point_1); + vout1x4567 = + vmlaq_f32(vout1x4567, kernel_zero_points_4567, vscaled_input_zero_point_1); + vout2x0123 = + vmlaq_f32(vout2x0123, kernel_zero_points_0123, vscaled_input_zero_point_2); + vout2x4567 = + vmlaq_f32(vout2x4567, kernel_zero_points_4567, vscaled_input_zero_point_2); + vout3x0123 = + vmlaq_f32(vout3x0123, kernel_zero_points_0123, vscaled_input_zero_point_3); + vout3x4567 = + vmlaq_f32(vout3x4567, kernel_zero_points_4567, vscaled_input_zero_point_3); + vout4x0123 = + vmlaq_f32(vout4x0123, kernel_zero_points_0123, vscaled_input_zero_point_4); + vout4x4567 = + vmlaq_f32(vout4x4567, kernel_zero_points_4567, vscaled_input_zero_point_4); + vout5x0123 = + vmlaq_f32(vout5x0123, kernel_zero_points_0123, vscaled_input_zero_point_5); + vout5x4567 = + vmlaq_f32(vout5x4567, kernel_zero_points_4567, vscaled_input_zero_point_5); + const float32x4_t vinput_scale01 = 
vreinterpretq_f32_s32(vld1q_s32(&quantization_params[0].zero_point)); + vout0x0123 = vmulq_lane_f32(vout0x0123, vget_low_f32(vinput_scale01), 1); + vout1x0123 = vmulq_lane_f32(vout1x0123, vget_high_f32(vinput_scale01), 1); + vout0x4567 = vmulq_lane_f32(vout0x4567, vget_low_f32(vinput_scale01), 1); + vout1x4567 = vmulq_lane_f32(vout1x4567, vget_high_f32(vinput_scale01), 1); + const float32x4_t vinput_scale23 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[2].zero_point)); + vout2x0123 = vmulq_lane_f32(vout2x0123, vget_low_f32(vinput_scale23), 1); + vout3x0123 = vmulq_lane_f32(vout3x0123, vget_high_f32(vinput_scale23), 1); + vout2x4567 = vmulq_lane_f32(vout2x4567, vget_low_f32(vinput_scale23), 1); + vout3x4567 = vmulq_lane_f32(vout3x4567, vget_high_f32(vinput_scale23), 1); + const float32x4_t vinput_scale45 = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[4].zero_point)); + vout4x0123 = vmulq_lane_f32(vout4x0123, vget_low_f32(vinput_scale45), 1); + vout5x0123 = vmulq_lane_f32(vout5x0123, vget_high_f32(vinput_scale45), 1); + vout4x4567 = vmulq_lane_f32(vout4x4567, vget_low_f32(vinput_scale45), 1); + vout5x4567 = vmulq_lane_f32(vout5x4567, vget_high_f32(vinput_scale45), 1); + + const float32x4_t vfilter_output_scale0123 = vld1q_f32(w); w = (const float*) w + 4; + const float32x4_t vfilter_output_scale4567 = vld1q_f32(w); w = (const float*) w + 4; + + const float32x4_t vbias0123 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x0123 = vfmaq_f32(vbias0123, vout0x0123, vfilter_output_scale0123); + vout1x0123 = vfmaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vfmaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vfmaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + vout4x0123 = vfmaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123); + vout5x0123 = vfmaq_f32(vbias0123, vout5x0123, vfilter_output_scale0123); + #else + vout0x0123 = vmlaq_f32(vbias0123, vout0x0123, 
vfilter_output_scale0123); + vout1x0123 = vmlaq_f32(vbias0123, vout1x0123, vfilter_output_scale0123); + vout2x0123 = vmlaq_f32(vbias0123, vout2x0123, vfilter_output_scale0123); + vout3x0123 = vmlaq_f32(vbias0123, vout3x0123, vfilter_output_scale0123); + vout4x0123 = vmlaq_f32(vbias0123, vout4x0123, vfilter_output_scale0123); + vout5x0123 = vmlaq_f32(vbias0123, vout5x0123, vfilter_output_scale0123); + #endif + const float32x4_t vbias4567 = vld1q_f32(w); w = (const float*) w + 4; + #if XNN_ARCH_ARM64 + vout0x4567 = vfmaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vfmaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vfmaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vfmaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); + vout4x4567 = vfmaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567); + vout5x4567 = vfmaq_f32(vbias4567, vout5x4567, vfilter_output_scale4567); + #else + vout0x4567 = vmlaq_f32(vbias4567, vout0x4567, vfilter_output_scale4567); + vout1x4567 = vmlaq_f32(vbias4567, vout1x4567, vfilter_output_scale4567); + vout2x4567 = vmlaq_f32(vbias4567, vout2x4567, vfilter_output_scale4567); + vout3x4567 = vmlaq_f32(vbias4567, vout3x4567, vfilter_output_scale4567); + vout4x4567 = vmlaq_f32(vbias4567, vout4x4567, vfilter_output_scale4567); + vout5x4567 = vmlaq_f32(vbias4567, vout5x4567, vfilter_output_scale4567); + #endif + + float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); + float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); + float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); + float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); + float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); + float16x8_t vfp16out5x01234567 = 
vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); + const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min)); + vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); + vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); + vfp16out2x01234567 = vmaxq_f16(vfp16out2x01234567, voutput_min); + vfp16out3x01234567 = vmaxq_f16(vfp16out3x01234567, voutput_min); + vfp16out4x01234567 = vmaxq_f16(vfp16out4x01234567, voutput_min); + vfp16out5x01234567 = vmaxq_f16(vfp16out5x01234567, voutput_min); + const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max)); + vfp16out0x01234567 = vminq_f16(vfp16out0x01234567, voutput_max); + vfp16out1x01234567 = vminq_f16(vfp16out1x01234567, voutput_max); + vfp16out2x01234567 = vminq_f16(vfp16out2x01234567, voutput_max); + vfp16out3x01234567 = vminq_f16(vfp16out3x01234567, voutput_max); + vfp16out4x01234567 = vminq_f16(vfp16out4x01234567, voutput_max); + vfp16out5x01234567 = vminq_f16(vfp16out5x01234567, voutput_max); + if XNN_LIKELY(nc >= 8) { + vst1q_u16(c0, vreinterpretq_u16_f16(vfp16out0x01234567)); + vst1q_u16(c1, vreinterpretq_u16_f16(vfp16out1x01234567)); + vst1q_u16(c2, vreinterpretq_u16_f16(vfp16out2x01234567)); + vst1q_u16(c3, vreinterpretq_u16_f16(vfp16out3x01234567)); + vst1q_u16(c4, vreinterpretq_u16_f16(vfp16out4x01234567)); + vst1q_u16(c5, vreinterpretq_u16_f16(vfp16out5x01234567)); + + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); + 
c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride); + + nc -= 8; + } else { + float16x4_t vfp16out0x0123 = vget_low_f16(vfp16out0x01234567); + float16x4_t vfp16out1x0123 = vget_low_f16(vfp16out1x01234567); + float16x4_t vfp16out2x0123 = vget_low_f16(vfp16out2x01234567); + float16x4_t vfp16out3x0123 = vget_low_f16(vfp16out3x01234567); + float16x4_t vfp16out4x0123 = vget_low_f16(vfp16out4x01234567); + float16x4_t vfp16out5x0123 = vget_low_f16(vfp16out5x01234567); + if (nc & 4) { + vst1_u16(c0, vreinterpret_u16_f16(vfp16out0x0123)); c0 += 4; + vst1_u16(c1, vreinterpret_u16_f16(vfp16out1x0123)); c1 += 4; + vst1_u16(c2, vreinterpret_u16_f16(vfp16out2x0123)); c2 += 4; + vst1_u16(c3, vreinterpret_u16_f16(vfp16out3x0123)); c3 += 4; + vst1_u16(c4, vreinterpret_u16_f16(vfp16out4x0123)); c4 += 4; + vst1_u16(c5, vreinterpret_u16_f16(vfp16out5x0123)); c5 += 4; + vfp16out0x0123 = vget_high_f16(vfp16out0x01234567); + vfp16out1x0123 = vget_high_f16(vfp16out1x01234567); + vfp16out2x0123 = vget_high_f16(vfp16out2x01234567); + vfp16out3x0123 = vget_high_f16(vfp16out3x01234567); + vfp16out4x0123 = vget_high_f16(vfp16out4x01234567); + vfp16out5x0123 = vget_high_f16(vfp16out5x01234567); + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vfp16out0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vfp16out1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vfp16out2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_f16(vfp16out3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_f16(vfp16out4x0123), 0); c4 += 2; + vst1_lane_u32((void*) c5, vreinterpret_u32_f16(vfp16out5x0123), 0); c5 += 2; + vfp16out0x0123 = vext_f16(vfp16out0x0123, vfp16out0x0123, 2); + vfp16out1x0123 = vext_f16(vfp16out1x0123, vfp16out1x0123, 2); + vfp16out2x0123 = vext_f16(vfp16out2x0123, vfp16out2x0123, 2); + vfp16out3x0123 = vext_f16(vfp16out3x0123, vfp16out3x0123, 2); + vfp16out4x0123 = vext_f16(vfp16out4x0123, vfp16out4x0123, 2); + 
vfp16out5x0123 = vext_f16(vfp16out5x0123, vfp16out5x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vreinterpret_u16_f16(vfp16out0x0123), 0); + vst1_lane_u16(c1, vreinterpret_u16_f16(vfp16out1x0123), 0); + vst1_lane_u16(c2, vreinterpret_u16_f16(vfp16out2x0123), 0); + vst1_lane_u16(c3, vreinterpret_u16_f16(vfp16out3x0123), 0); + vst1_lane_u16(c4, vreinterpret_u16_f16(vfp16out4x0123), 0); + vst1_lane_u16(c5, vreinterpret_u16_f16(vfp16out5x0123), 0); + } + nc = 0; + } + } while (nc != 0); +} diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x16c4-minmax-neondotfp16arith.c index 6e1bf870ebb..e42e957697e 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -157,7 +158,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x16c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c4-minmax-neondotfp16arith.c index fd53ed58d4d..22d47172e8d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void 
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -118,7 +119,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c4__neondotfp16arith( #endif float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x16c4-minmax-neondotfp16arith.c index 073eecf2b92..c9b9493b8e6 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -199,7 +200,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x16c4__neondotfp16arith( float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c4-minmax-neondotfp16arith.c index 55db2004f12..dd42d46cecb 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c4-minmax-neondotfp16arith.c +++ 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -143,7 +144,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x16c4-minmax-neondotfp16arith.c index 3f12d664f7f..c78f428c002 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -243,7 +244,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x16c4__neondotfp16arith( float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c4-minmax-neondotfp16arith.c 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c4-minmax-neondotfp16arith.c index 3f3ded600a1..cba415f805b 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -170,7 +171,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x16c4-minmax-neondotfp16arith.c index 6f62b72524f..5662891d22a 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -285,7 +286,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x16c4__neondotfp16arith( float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); - const float16x8_t voutput_min = 
vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c4-minmax-neondotfp16arith.c index 392edcfe01a..69a68da641d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -195,7 +196,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c4__neondotfp16arith( float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x16c4-minmax-neondotfp16arith.c index 94fc5547482..576d4dbbbaa 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -329,7 +330,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x16c4__neondotfp16arith( float16x8_t vfp16out3x89ABCDEF = 
vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c4-minmax-neondotfp16arith.c index f1f8796d82a..8d54efa52ba 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -222,7 +223,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c4__neondotfp16arith( float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x16c4-minmax-neondotfp16arith.c index 3249f6006ae..fd6af146b85 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x16c4-minmax-neondotfp16arith.c +++ 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -371,7 +372,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x16c4__neondotfp16arith( float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF)); float16x8_t vfp16out5x01234567 = vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); float16x8_t vfp16out5x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout5x89AB), vcvt_f16_f32(vout5xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c4-minmax-neondotfp16arith.c index 442c8572c12..e8a1416ac06 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -247,7 +248,6 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c4__neondotfp16arith( float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); float16x8_t vfp16out5x01234567 = vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git 
a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x16c4-minmax-neondotfp16arith.c index a7eccf2cbd0..24cdfbe6293 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -148,7 +149,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x16c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c index 6917af5902f..8831c62a1d9 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -113,7 +114,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c4__neondotfp16arith( #endif float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) 
¶ms->scalar.max)); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x16c4-minmax-neondotfp16arith.c index 9061e829ba5..81626428297 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -190,7 +191,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x16c4__neondotfp16arith( float16x8_t vfp16out0x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout0x89AB), vcvt_f16_f32(vout0xCDEF)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c4-minmax-neondotfp16arith.c index c4b0afa7183..e4289711c2e 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -138,7 +139,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); - const float16x8_t voutput_min = 
vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x16c4-minmax-neondotfp16arith.c index 7d3a72989cc..86268a1c4de 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -234,7 +235,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x16c4__neondotfp16arith( float16x8_t vfp16out1x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout1x89AB), vcvt_f16_f32(vout1xCDEF)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c4-minmax-neondotfp16arith.c index 30802eda4fb..1415514c6eb 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -165,7 +166,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c4__neondotfp16arith( float16x8_t vfp16out0x01234567 = 
vcombine_f16(vcvt_f16_f32(vout0x0123), vcvt_f16_f32(vout0x4567)); float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-neondotfp16arith.c index 9bc97fc9fc7..3dbdc98ee93 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -276,7 +277,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x16c4__neondotfp16arith( float16x8_t vfp16out2x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout2x89AB), vcvt_f16_f32(vout2xCDEF)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-neondotfp16arith.c index 26f38d21eb8..7bc7b2f2f71 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-neondotfp16arith.c +++ 
b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -190,7 +191,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c4__neondotfp16arith( float16x8_t vfp16out1x01234567 = vcombine_f16(vcvt_f16_f32(vout1x0123), vcvt_f16_f32(vout1x4567)); float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x16c4-minmax-neondotfp16arith.c index ae8e83297e0..002bdc41fe7 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -320,7 +321,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x16c4__neondotfp16arith( float16x8_t vfp16out3x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout3x89AB), vcvt_f16_f32(vout3xCDEF)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git 
a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c4-minmax-neondotfp16arith.c index 0c175150ea3..623c712807e 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -217,7 +218,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c4__neondotfp16arith( float16x8_t vfp16out2x01234567 = vcombine_f16(vcvt_f16_f32(vout2x0123), vcvt_f16_f32(vout2x4567)); float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x16c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x16c4-minmax-neondotfp16arith.c index 0da6565f40d..b7a488e6b84 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x16c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x16c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x16c4__neondotfp16arith( size_t mr, size_t nc, @@ -362,7 +363,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x16c4__neondotfp16arith( float16x8_t vfp16out4x89ABCDEF = vcombine_f16(vcvt_f16_f32(vout4x89AB), vcvt_f16_f32(vout4xCDEF)); float16x8_t vfp16out5x01234567 = vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); float16x8_t vfp16out5x89ABCDEF = 
vcombine_f16(vcvt_f16_f32(vout5x89AB), vcvt_f16_f32(vout5xCDEF)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out0x89ABCDEF = vmaxq_f16(vfp16out0x89ABCDEF, voutput_min); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c4-minmax-neondotfp16arith.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c4-minmax-neondotfp16arith.c index 10c038b8ae9..e93f0b47431 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c4-minmax-neondotfp16arith.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c4-minmax-neondotfp16arith.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c4__neondotfp16arith( size_t mr, size_t nc, @@ -242,7 +243,6 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c4__neondotfp16arith( float16x8_t vfp16out3x01234567 = vcombine_f16(vcvt_f16_f32(vout3x0123), vcvt_f16_f32(vout3x4567)); float16x8_t vfp16out4x01234567 = vcombine_f16(vcvt_f16_f32(vout4x0123), vcvt_f16_f32(vout4x4567)); float16x8_t vfp16out5x01234567 = vcombine_f16(vcvt_f16_f32(vout5x0123), vcvt_f16_f32(vout5x4567)); - const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); vfp16out0x01234567 = vmaxq_f16(vfp16out0x01234567, voutput_min); vfp16out1x01234567 = vmaxq_f16(vfp16out1x01234567, voutput_min); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c4-minmax-neondot.c index f08a3819119..3997983dc2d 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x8c4-minmax-neondot.c 
b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x8c4-minmax-neondot.c index 95ab2ed57d4..cff85ed8854 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x8c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x16c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x16c4-minmax-neondot.c index 88c32e74c77..9fd4239d631 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x16c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x8c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x8c4-minmax-neondot.c index a2c165ab2d8..70217e5b5a5 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x8c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-2x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x16c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x16c4-minmax-neondot.c index b1e9becf336..d4c5c1d95f0 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x16c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x8c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x8c4-minmax-neondot.c index 3e235dfea28..7c942fff119 100644 
--- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x8c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-3x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x16c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x16c4-minmax-neondot.c index 19c3f33c487..51100703dbf 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x16c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x8c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x8c4-minmax-neondot.c index f2e6a20a6dc..6f5988459af 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x8c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c4-minmax-neondot.c index 2159fbf68f2..f8de4c6bbcd 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x8c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x8c4-minmax-neondot.c index 3bdc8b26bbc..965016a76fe 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x8c4-minmax-neondot.c +++ 
b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x16c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x16c4-minmax-neondot.c index b55326f2294..37802198ee2 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x16c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x8c4-minmax-neondot.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x8c4-minmax-neondot.c index a1897699df4..37423dc7577 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x8c4-minmax-neondot.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-6x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x16c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x16c4-minmax-neondot.c index 730b2502350..e83e799aff0 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_1x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x8c4-minmax-neondot.c index fa4f9cfdded..aec490fdb48 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include 
"src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_1x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x16c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x16c4-minmax-neondot.c index 4343d4c550f..858f698dcf9 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_2x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x8c4-minmax-neondot.c index c5f12fdca37..c7d4b81d66f 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-2x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_2x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x16c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x16c4-minmax-neondot.c index 39d0500c10d..d4e55cdfcd4 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_3x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x8c4-minmax-neondot.c index bde4b0d77ce..776c7a26396 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-3x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_3x8c4__neondot( size_t mr, size_t nc, 
diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x16c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x16c4-minmax-neondot.c index b2cb4108958..4093b2f7435 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_4x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x8c4-minmax-neondot.c index 21224c59e90..5bb33a42beb 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-4x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_4x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x16c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x16c4-minmax-neondot.c index 95b22afb9a2..539d7ae8c9c 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_5x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x8c4-minmax-neondot.c index 379a099860c..b8d3a2f2515 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-5x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_5x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x16c4-minmax-neondot.c 
b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x16c4-minmax-neondot.c index 4a7d27f027c..dc8782aca2a 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_6x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x8c4-minmax-neondot.c index f6685ada854..7db6f9b5568 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-6x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_6x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x16c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x16c4-minmax-neondot.c index ca27bce14cc..a73884e49a5 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_7x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x8c4-minmax-neondot.c index 11a1a108d9f..8340d40addb 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-7x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_7x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x16c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x16c4-minmax-neondot.c index 063930fe17a..3ead575cac5 100644 
--- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_8x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x8c4-minmax-neondot.c b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x8c4-minmax-neondot.c index 1cffd56f536..a2ea893dc8c 100644 --- a/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-8x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_8x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c index 11d59e0b59e..656181e2afb 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c4-minmax-neondot.c index d21bba71c39..a7b0936370e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-neondot.c index 0ac71c89d81..4c69d44fd2a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-neondot.c +++ 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c4-minmax-neondot.c index 650423db3e2..70d8f779551 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-neondot.c index 56960989f73..5c88163454b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-neondot.c index 26857057f9f..da2d1590e35 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c4-minmax-neondot.c index 33e921d9517..b77ddace803 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include 
"src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-neondot.c index 4f5a76072d3..2134bfdde19 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c4-minmax-neondot.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c4-minmax-neondot.c index d0f6a380980..786ed0a1dcb 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-neondot.c index 893ddce2814..592c989f045 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-neondot.c index f6266354b20..f732e642813 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__neondot( size_t mr, size_t nc, 
diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-neondot.c index 1e76acc5c85..720ca0c5901 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-neondot.c index 987f8dac4dd..d6e103c140a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-neondot.c index a4f125c45fd..72f75c34695 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-neondot.c index f344988c4b5..119d0a6f647 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-neondot.c 
b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-neondot.c index 8a7d9729b32..5d1c78da505 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-neondot.c index d08a6021dbe..b3d6773fcb3 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-neondot.c index 9ab30b68ffb..93670d16bc5 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-neondot.c index cbc7f52e5e1..25f4c39684d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-neondot.c index 7a4a0131640..c075aba3534 100644 
--- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-neondot.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-neondot.c index d80cfe88a7c..bb2ce4c73c0 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-neondot.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-neondot.c @@ -19,6 +19,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-gemm/c4-neondot.c.in b/src/qs8-gemm/c4-neondot.c.in index 8952a7c2c58..67fb3f5503f 100644 --- a/src/qs8-gemm/c4-neondot.c.in +++ b/src/qs8-gemm/c4-neondot.c.in @@ -7,9 +7,9 @@ $ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" $assert NR % 8 == 0 $assert 8 <= NR <= 16 $assert REQUANTIZATION in ["FP32", "RNDNU"] or not REQUANTIZATION -$assert DATATYPE in ["QC8", "QS8", "QS8_QC2", "QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QB4_F16", "QB4_F32"] +$assert DATATYPE in ["QC8", "QS8", "QS8_QC2", "QD8_F16", "QD8_BF16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16", "QB4_F16", "QB4_F32"] $assert DATATYPE not in ["QC8", "QS8_QC2"] or REQUANTIZATION == "FP32" -$assert not DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QB4_F16", "QB4_F32"] or not REQUANTIZATION +$assert not DATATYPE in ["QD8_F16", "QD8_BF16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16", "QB4_F16", "QB4_F32"] or not REQUANTIZATION #include #include #include @@ -41,12 +41,16 @@ $def END(): $SET_INDENT(INDENT - 1) $return _ + '}' $# -$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8": "qs8", "QS8_QC2": "qs8_qc2w", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC4_F32": 
"qd8_f32_qc4w", "QC2_F32": "qd8_f32_qc2w", "QB4_F16": "qd8_f16_qb4w", "QB4_F32": "qd8_f32_qb4w"}[DATATYPE] +$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8": "qs8", "QS8_QC2": "qs8_qc2w", "QD8_F16" : "qd8_f16_qc8w", "QD8_BF16" : "qd8_bf16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC4_F32": "qd8_f32_qc4w", "QC2_F32": "qd8_f32_qc2w", "QC2_F16": "qd8_f16_qc2w", "QB4_F16": "qd8_f16_qb4w", "QB4_F32": "qd8_f32_qb4w"}[DATATYPE] $REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else "" $PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if REQUANTIZATION == "FP32" else "neon") -$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8": "union xnn_qs8_conv_minmax_params", "QS8_QC2": "union xnn_qs8_qc8w_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params", "QC2_F32": "struct xnn_f32_minmax_params", "QB4_F16": "struct xnn_f16_qb4w_minmax_params", "QB4_F32": "struct xnn_f32_qb4w_minmax_params"}[DATATYPE] -$OUT_T = {"QC8": "int8_t", "QS8_QC2": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC4_F32": "float", "QC2_F32": "float", "QB4_F16": "xnn_float16", "QB4_F32": "float", "QS8": "int8_t"}[DATATYPE] -$ISA = "fp16arith" if DATATYPE in ["QC4_F16", "QD8_F16", "QB4_F16"] else "" +$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8": "union xnn_qs8_conv_minmax_params", "QS8_QC2": "union xnn_qs8_qc8w_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_BF16": "struct xnn_bf16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params", "QC2_F32": "struct xnn_f32_minmax_params", "QC2_F16": "struct xnn_f16_minmax_params", "QB4_F16": "struct xnn_f16_qb4w_minmax_params", "QB4_F32": "struct xnn_f32_qb4w_minmax_params"}[DATATYPE] 
+$OUT_T = {"QC8": "int8_t", "QS8_QC2": "int8_t", "QD8_F16": "xnn_float16", "QD8_BF16": "xnn_bfloat16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC4_F32": "float", "QC2_F32": "float", "QC2_F16": "xnn_float16", "QB4_F16": "xnn_float16", "QB4_F32": "float", "QS8": "int8_t"}[DATATYPE] +$VCVT_X16_F32 = "vcvt_bf16_f32" if DATATYPE in ["QD8_BF16"] else "vcvt_f16_f32" +$VMAXQ_X16 = "vmaxq_bf16" if DATATYPE in ["QD8_BF16"] else "vmaxq_f16" +$VMINQ_X16 = "vminq_bf16" if DATATYPE in ["QD8_BF16"] else "vminq_f16" + +$ISA = "fp16arith" if DATATYPE in ["QC4_F16", "QC2_F16", "QD8_F16", "QB4_F16"] else "bf16" if DATATYPE in ["QD8_BF16"] else "" $BLOCKWISE = DATATYPE in ["QB4_F16", "QB4_F32"] void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c4__neondot${ISA}( size_t mr, @@ -58,9 +62,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c ${OUT_T}* restrict c, size_t cm_stride, size_t cn_stride, - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QB4_F16", "QB4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_BF16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16", "QB4_F16", "QB4_F32"]: const ${PARAMS_TYPE}* restrict params, - $if DATATYPE in ["QC2_F32"]: + $if DATATYPE in ["QC2_F32", "QC2_F16"]: const float* row_sum, const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS $else: @@ -77,13 +81,13 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c kc = round_up_po2(kc, 4 * sizeof(int8_t)); const int8_t* a0 = a; - $if DATATYPE in ["QD8_F16", "QC4_F16", "QB4_F16"]: + $if DATATYPE in ["QD8_F16", "QD8_BF16", "QC4_F16", "QC2_F16", "QB4_F16"]: uint16_t* c0 = (uint16_t*) c; $else: ${OUT_T}* c0 = c; $for M in range(1, MR): const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride); - $if DATATYPE in ["QD8_F16", "QC4_F16", "QB4_F16"]: + $if DATATYPE in ["QD8_F16", "QD8_BF16", "QC4_F16", "QC2_F16", "QB4_F16"]: uint16_t* c${M} = 
(uint16_t*) ((uintptr_t) c${M-1} + cm_stride); $else: ${OUT_T}* c${M} = (${OUT_T}*) ((uintptr_t) c${M-1} + cm_stride); @@ -110,11 +114,11 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c assert(bl % 32 == 0); $if DATATYPE in ["QC4_F16", "QC4_F32", "QB4_F16", "QB4_F32"]: const int8x16_t vmask = vmovq_n_s8(INT8_C(0xF0)); - $if DATATYPE in ["QC2_F32"]: + $if DATATYPE in ["QC2_F32", "QC2_F16"]: const int8x16_t vmask = vmovq_n_s8(INT8_C(0x03)); // Loop over groups of ${NR} columns. do { - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QB4_F16", "QB4_F32"]: + $if DATATYPE in ["QD8_F16", "QD8_BF16", "QD8_F32", "QC4_F16", "QC4_F32", "QB4_F16", "QB4_F32"]: // Initialize accumulators with bias. ${NR} bias values are loaded from the // weight matrix, at the start of the group of ${NR} columns. $for M in range(0, MR, 2): @@ -148,7 +152,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c const int32x4_t vksum${ABC[N:N+4]} = vld1q_s32(w); w = (const int32_t*) w + 4; int32x4_t vacc${M}x${ABC[N:N+4]} = vmulq_lane_s32(vksum${ABC[N:N+4]}, vget_low_s32(vinput_zero_point${ABC[M:M+2]}), 0); int32x4_t vacc${M+1}x${ABC[N:N+4]} = vmulq_lane_s32(vksum${ABC[N:N+4]}, vget_high_s32(vinput_zero_point${ABC[M:M+2]}), 0); - $elif DATATYPE in ["QC2_F32"]: + $elif DATATYPE in ["QC2_F32", "QC2_F16"]: // Initialize the bias with the scaled left-hand weight sums. $for N in range(0, NR, 4): const int32x4_t vksum${ABC[N:N+4]} = vld1q_s32(w); w = (const int32_t*) w + 4; @@ -184,7 +188,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c ${_}size_t k = bl; $else: size_t k = kc; - $if DATATYPE in ["QS8_QC2", "QC2_F32"]: + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: ${_}// 4x partial unrolled loop to load 16 bytes at a time. ${_}while (k >= 16 * sizeof(int8_t)) { ${_}// Load a ${MR}x16 block of activations. 
@@ -198,25 +202,25 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c $for N in range(0, NR, 4): $if DATATYPE == "QS8_QC2": ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[0:4]} = vshrq_n_s8(vshlq_n_s8(vb${ABC[N:N+4]}x16, 6), 6); - $if DATATYPE == "QC2_F32": + $if DATATYPE in ["QC2_F32", "QC2_F16"]: ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[0:4]} = vandq_s8(vb${ABC[N:N+4]}x16, vmask); ${_}// Second crumb. $for N in range(0, NR, 4): $if DATATYPE == "QS8_QC2": ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[4:8]} = vshrq_n_s8(vshlq_n_s8(vb${ABC[N:N+4]}x16, 4), 6); - $if DATATYPE == "QC2_F32": + $if DATATYPE in ["QC2_F32", "QC2_F16"]: ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[4:8]} = vandq_s8(vshrq_n_s8(vb${ABC[N:N+4]}x16, 2), vmask); ${_}// Third crumb. $for N in range(0, NR, 4): $if DATATYPE == "QS8_QC2": ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[8:12]} = vshrq_n_s8(vshlq_n_s8(vb${ABC[N:N+4]}x16, 2), 6); - $if DATATYPE == "QC2_F32": + $if DATATYPE in ["QC2_F32", "QC2_F16"]: ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[8:12]} = vandq_s8(vshrq_n_s8(vb${ABC[N:N+4]}x16, 4), vmask); ${_}// Fourth crumb. $for N in range(0, NR, 4): $if DATATYPE == "QS8_QC2": ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[12:16]} = vshrq_n_s8(vb${ABC[N:N+4]}x16, 6); - $if DATATYPE == "QC2_F32": + $if DATATYPE in ["QC2_F32", "QC2_F16"]: ${_}const int8x16_t vb${ABC[N:N+4]}x${ABC[12:16]} = vandq_s8(vshrq_n_s8(vb${ABC[N:N+4]}x16, 6), vmask); ${_}// Multiply-accumulate: ${MR}x16 * 16x${NR} --> ${MR}x${NR}. @@ -235,7 +239,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c ${_}k -= 16 * sizeof(int8_t); ${_}} - $if DATATYPE in ["QC2_F32", "QS8_QC2"]: + $if DATATYPE in ["QC2_F32", "QC2_F16", "QS8_QC2"]: // Handle up to 8 final positions of `k`. 
${_}if XNN_UNLIKELY(k > 0) { $for K in range(0, 8, 8): @@ -248,8 +252,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c ${_}const int8x8_t va${M}x01234567 = vld1_s8(a${M}); a${M} += 8; ${_}// Load a 8x${NR} block of weights. - $if DATATYPE in ["QC4_F16", "QC4_F32", "QB4_F16", "QB4_F32", "QC2_F32", "QS8_QC2"]: - $if DATATYPE not in ["QC2_F32", "QS8_QC2"]: + $if DATATYPE in ["QC4_F16", "QC4_F32", "QB4_F16", "QB4_F32", "QC2_F32", "QC2_F16", "QS8_QC2"]: + $if DATATYPE not in ["QC2_F32", "QC2_F16", "QS8_QC2"]: $for K in range(0, 8, 8): $for N in range(0, NR, 4): ${_}const int8x16_t vb${ABC[K:K+8]}x${ABC[N:N+4]} = vld1q_s8(w); w = (const int8_t*) w + 16; @@ -257,14 +261,14 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c $for N in range(0, NR, 4): $if DATATYPE == "QS8_QC2": ${_}const int8x16_t vb${ABC[K:K+4]}x${ABC[N:N+4]} = vshrq_n_s8(vshlq_n_s8(vb${ABC[K:K+8]}x${ABC[N:N+4]}, 6), 6); - $elif DATATYPE == "QC2_F32": + $elif DATATYPE in ["QC2_F32", "QC2_F16"]: ${_}const int8x16_t vb${ABC[K:K+4]}x${ABC[N:N+4]} = vandq_s8(vb${ABC[K:K+8]}x${ABC[N:N+4]}, vmask); $else: ${_}const int8x16_t vb${ABC[K:K+4]}x${ABC[N:N+4]} = vshlq_n_s8(vb${ABC[K:K+8]}x${ABC[N:N+4]}, 4); $for N in range(0, NR, 4): $if DATATYPE == "QS8_QC2": ${_}const int8x16_t vb${ABC[K+4:K+8]}x${ABC[N:N+4]} = vshrq_n_s8(vshlq_n_s8(vb${ABC[K:K+8]}x${ABC[N:N+4]}, 4), 6); - $elif DATATYPE == "QC2_F32": + $elif DATATYPE in ["QC2_F32", "QC2_F16"]: ${_}const int8x16_t vb${ABC[K+4:K+8]}x${ABC[N:N+4]} = vandq_s8(vshrq_n_s8(vb${ABC[K:K+8]}x${ABC[N:N+4]}, 2), vmask); $else: ${_}const int8x16_t vb${ABC[K+4:K+8]}x${ABC[N:N+4]} = vandq_s8(vb${ABC[K:K+8]}x${ABC[N:N+4]}, vmask); @@ -280,7 +284,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c ${_}vacc${M}x${ABC[N:N+4]} = vdotq_lane_s32(vacc${M}x${ABC[N:N+4]}, vb${ABC[K:K+4]}x${ABC[N:N+4]}, va${M}x01234567, ${K//4}); ${_}k -= 8 * sizeof(int8_t); - $if DATATYPE in ["QS8_QC2", 
"QC2_F32"]: + $if DATATYPE in ["QS8_QC2", "QC2_F32", "QC2_F16"]: $for N in range(0, NR, 4): ${_}vb01234567x${ABC[N:N+4]} = vshrq_n_s8(vb01234567x${ABC[N:N+4]}, 4); ${_}} @@ -290,11 +294,11 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c $for M in range(MR): ${_}const int8x8_t va${M}x0123 = vreinterpret_s8_s32(vld1_dup_s32((const int32_t*)a${M})); a${M} += 4; - $if DATATYPE in ["QC2_F32", "QS8_QC2"]: + $if DATATYPE in ["QC2_F32", "QC2_F16", "QS8_QC2"]: $for N in range(0, NR, 4): $if DATATYPE == "QS8_QC2": ${_}const int8x16_t vb0123x${ABC[N:N+4]} = vshrq_n_s8(vshlq_n_s8(vb01234567x${ABC[N:N+4]}, 6), 6); - $if DATATYPE == "QC2_F32": + $if DATATYPE in ["QC2_F32", "QC2_F16"]: ${_}const int8x16_t vb0123x${ABC[N:N+4]} = vandq_s8(vb01234567x${ABC[N:N+4]}, vmask); $else: ${_}// Load a 4x${NR} block of weights. @@ -312,7 +316,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c $for N in range(0, NR, 4): ${_}vacc${M}x${ABC[N:N+4]} = vdotq_lane_s32(vacc${M}x${ABC[N:N+4]}, vb0123x${ABC[N:N+4]}, va${M}x0123, 0); ${_}} - $if DATATYPE in ["QC2_F32", "QS8_QC2"]: + $if DATATYPE in ["QC2_F32", "QC2_F16", "QS8_QC2"]: ${_}} $if BLOCKWISE: $for N in range(0, NR, 4): @@ -331,14 +335,14 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c ${_}vout${M+1}x${ABC[N:N+4]} = vfmaq_f32(vout${M+1}x${ABC[N:N+4]}, vf${M+1}x${ABC[N:N+4]}, vfilter_output_scale${ABC[N:N+4]}); } - $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32"] or BLOCKWISE: + $if DATATYPE in ["QD8_F16", "QD8_BF16", "QD8_F32", "QC4_F16", "QC4_F32", "QC2_F32", "QC2_F16"] or BLOCKWISE: $for M in range(MR): $for N in range(0, NR, 4): $if DATATYPE in ["QC4_F16", "QC4_F32"]: float32x4_t vout${M}x${ABC[N:N+4]} = vcvtq_n_f32_s32(vacc${M}x${ABC[N:N+4]}, 4); $elif not BLOCKWISE: float32x4_t vout${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]}); - $if DATATYPE in ["QC2_F32"]: + $if DATATYPE in ["QC2_F32", 
"QC2_F16"]: const float32x4_t vtwo = vdupq_n_f32(2.0f); $for N in range(0, NR, 4): const float32x4_t kernel_zero_points_${ABC[N:N+4]} = vld1q_f32(kzp); kzp = (const float*)kzp + 4; @@ -387,19 +391,18 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c vout${M}x${ABC[N:N+4]} = vmlaq_f32(vbias${ABC[N:N+4]}, vout${M}x${ABC[N:N+4]}, vfilter_output_scale${ABC[N:N+4]}); #endif - $if DATATYPE in ["QD8_F16", "QC4_F16", "QB4_F16"]: + $if DATATYPE in ["QD8_F16", "QD8_BF16", "QC4_F16", "QC2_F16", "QB4_F16"]: $for M in range(0, MR): $for N in range(0, NR, 8): - float16x8_t vfp16out${M}x${ABC[N:N+8]} = vcombine_f16(vcvt_f16_f32(vout${M}x${ABC[N:N+4]}), vcvt_f16_f32(vout${M}x${ABC[N+4:N+8]})); - + float16x8_t vfp16out${M}x${ABC[N:N+8]} = vcombine_f16(${VCVT_X16_F32}(vout${M}x${ABC[N:N+4]}), ${VCVT_X16_F32}(vout${M}x${ABC[N+4:N+8]})); const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.min)); $for M in range(0, MR): $for N in range(0, NR, 8): - vfp16out${M}x${ABC[N:N+8]} = vmaxq_f16(vfp16out${M}x${ABC[N:N+8]}, voutput_min); + vfp16out${M}x${ABC[N:N+8]} = ${VMAXQ_X16}(vfp16out${M}x${ABC[N:N+8]}, voutput_min); const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) ¶ms->scalar.max)); $for M in range(0, MR): $for N in range(0, NR, 8): - vfp16out${M}x${ABC[N:N+8]} = vminq_f16(vfp16out${M}x${ABC[N:N+8]}, voutput_max); + vfp16out${M}x${ABC[N:N+8]} = ${VMINQ_X16}(vfp16out${M}x${ABC[N:N+8]}, voutput_max); if XNN_LIKELY(nc >= ${NR}) { $for M in range(MR): vst1q_u16(c${M}, vreinterpretq_u16_f16(vfp16out${M}x${ABC[0:8]})); diff --git a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x16c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x16c4-minmax-fp32-neondot.c index f9ad604e649..489f77af24a 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include 
"src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x8c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x8c4-minmax-fp32-neondot.c index e8778b69858..59a88d6cc94 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x8c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x16c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x16c4-minmax-fp32-neondot.c index 6f1d303a772..61619db5332 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_4x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x8c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x8c4-minmax-fp32-neondot.c index bcc86301b99..a4677c653dd 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x8c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_4x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x16c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x16c4-minmax-fp32-neondot.c index 3e6f204f0fa..9a32d579c88 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_6x16c4__neondot( size_t mr, size_t nc, diff --git 
a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x8c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x8c4-minmax-fp32-neondot.c index 644c9043c42..a221e85675e 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x8c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_6x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x16c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x16c4-minmax-fp32-neondot.c index 52b595b3adc..184504706aa 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_8x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x8c4-minmax-fp32-neondot.c b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x8c4-minmax-fp32-neondot.c index 8cbaa5694fe..0b09654cb5e 100644 --- a/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x8c4-minmax-fp32-neondot.c +++ b/src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_8x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-neondot.c index c2f8d5dfd08..fdb9dc2157c 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c4-minmax-fp32-neondot.c index 
7034bd0f668..a0b6d70be50 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c4-minmax-fp32-neondot.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-neondot.c index ff1046899c1..8dc56e1548a 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-neondot.c index e7a30092bd3..dbf589ae2dd 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-neondot.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-neondot.c index 4a21a4a8db3..a83ade1d66d 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x8c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x8c4-minmax-fp32-neondot.c index 491bf964bdb..c98d41c252d 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x8c4-minmax-fp32-neondot.c +++ 
b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-neondot.c index 423e4263a79..75c17ca1480 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-neondot.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__neondot( size_t mr, size_t nc, diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x8c4-minmax-fp32-neondot.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x8c4-minmax-fp32-neondot.c index 14860325313..cd43af9cba2 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x8c4-minmax-fp32-neondot.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x8c4-minmax-fp32-neondot.c @@ -20,6 +20,7 @@ #include "src/xnnpack/microparams.h" + void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x8c4__neondot( size_t mr, size_t nc, diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index f511abbf56e..28bcdb6e611 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -4174,6 +4174,20 @@ DECLARE_QD8_F32_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc2w_gemm_minmax_u const struct xnn_f16_minmax_params* params, const float* row_sum, \ const struct xnn_qd8_quantization_params* quantization_params); +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x8c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x8c4__neondotfp16arith) 
+DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x8c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith) + +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x16c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x16c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x16c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x16c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x16c4__neondotfp16arith) +DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x16c4__neondotfp16arith) + DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c8__avx2_madd) DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x8c8__avx2_madd) DECLARE_QD8_F16_QC2W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x8c8__avx2_madd) diff --git a/test/qd8-f16-qc2w-gemm-minmax.cc b/test/qd8-f16-qc2w-gemm-minmax.cc index d5568abd633..628978ab08a 100644 --- a/test/qd8-f16-qc2w-gemm-minmax.cc +++ b/test/qd8-f16-qc2w-gemm-minmax.cc @@ -488,6 +488,237 @@ INSTANTIATE_TEST_SUITE_P( }); +#if XNN_ENABLE_ARM_DOTPROD && XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_1X8C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + 
xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_2X8C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_3X8C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_4X8C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + 
INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_5X8C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_6X8C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_1X16C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_2X16C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/2, /*nr=*/16, 
/*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_3X16C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_4X16C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_5X16C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + 
tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F16_QC2W_GEMM_MINMAX_6X16C4__NEONDOTFP16ARITH, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/16, + /*adj_k_block=*/16, + /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/4, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x16c4__neondotfp16arith, + xnn_init_f16_minmax_scalar_params, + xnn_pack_qd8_qc2w_gemm_goi_w); + }, + xnn_arch_arm_neon_dot | xnn_arch_arm_neon_fp16_arith)), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + #if XNN_ENABLE_AVX2 && (XNN_ARCH_X86 || XNN_ARCH_X86_64) INSTANTIATE_TEST_SUITE_P( QD8_F16_QC2W_GEMM_MINMAX_1X8C8__AVX2_MADD, GemmTest, diff --git a/test/qd8-f16-qc2w-gemm-minmax.yaml b/test/qd8-f16-qc2w-gemm-minmax.yaml index d410630e3b1..d0f2bac0bca 100644 --- a/test/qd8-f16-qc2w-gemm-minmax.yaml +++ b/test/qd8-f16-qc2w-gemm-minmax.yaml @@ -45,6 +45,69 @@ k-block: 4 planes: 4 +# ARM NEONDOT +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x8c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x8c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x8c4__neondotfp16arith + init: 
xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x8c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x8c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 + +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x16c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_2x16c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_3x16c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_4x16c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_5x16c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 +- name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_6x16c4__neondotfp16arith + init: xnn_init_f16_minmax_scalar_params + pack: xnn_pack_qd8_qc2w_gemm_goi_w + k-block: 16 + planes: 4 + # AVX2 MADD # TODO(fbarchard): use signed inputs - name: xnn_qd8_f16_qc2w_gemm_minmax_ukernel_1x8c8__avx2_madd