From eb7d3c0336e8fb251e07a3860e2a0a0f2ede398a Mon Sep 17 00:00:00 2001
From: Volodymyr Kysenko
Date: Fri, 20 Mar 2026 12:44:29 -0700
Subject: [PATCH] Add SIMD mask and select operations.

This change introduces a `mask` template to represent boolean masks for SIMD
vectors. It includes:

- Definition of the `mask` template in `vec.h`.
- Overloaded comparison operators (`==`, `!=`, `<`, `<=`, `>`, `>=`) for
  `vec` types, returning `mask` types.
- Overloaded bitwise operators (`&`, `|`, `^`, `~`) for `mask` types.
- A `select` function to conditionally choose elements from two vectors
  based on a `mask`.
- Specialized implementations for ARM NEON, x86 SSE2, AVX, and AVX512.
- Generic implementations for multi_vec.
- New test cases for compare and select operations.

PiperOrigin-RevId: 886929409
---
 ynnpack/base/simd/arm_neon_base.h    |  97 ++++++++++++++++++++
 ynnpack/base/simd/generic.inc        |  47 ++++++++++
 ynnpack/base/simd/test/arm_neon.cc   |   7 ++
 ynnpack/base/simd/test/generic.h     |  87 ++++++++++++++++++
 ynnpack/base/simd/test/multi_vec.cc  |   7 ++
 ynnpack/base/simd/test/x86_avx.cc    |   7 ++
 ynnpack/base/simd/test/x86_avx512.cc |   7 ++
 ynnpack/base/simd/test/x86_sse2.cc   |   7 ++
 ynnpack/base/simd/vec.h              | 103 ++++++++++++++++++++++
 ynnpack/base/simd/x86_avx512.h       | 127 +++++++++++++++++++++++++++
 ynnpack/base/simd/x86_avx_base.h     | 110 +++++++++++++++++++++++
 ynnpack/base/simd/x86_sse2_base.h    | 107 ++++++++++++++++++++++
 12 files changed, 713 insertions(+)

diff --git a/ynnpack/base/simd/arm_neon_base.h b/ynnpack/base/simd/arm_neon_base.h
index b1bfb1af541..eed9efee12b 100644
--- a/ynnpack/base/simd/arm_neon_base.h
+++ b/ynnpack/base/simd/arm_neon_base.h
@@ -173,6 +173,103 @@ using s16x8 = vec<int16_t, 8>;
 using u8x16 = vec<uint8_t, 16>;
 using s8x16 = vec<int8_t, 16>;
 
+using mf32x4 = mask<float, 4>;
+
+template <>
+struct mask<float, 4> {
+  static constexpr std::integral_constant<size_t, 4> N = {};
+  uint32x4_t m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(uint32x4_t m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x)
+      : m(vdupq_n_u32(x ? 0xFFFFFFFF : 0)) {}
+};
+
+YNN_ALWAYS_INLINE mf32x4 operator==(f32x4 a, f32x4 b) {
+  return mf32x4{vceqq_f32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator!=(f32x4 a, f32x4 b) {
+  return mf32x4{vmvnq_u32(vceqq_f32(a.v, b.v))};
+}
+YNN_ALWAYS_INLINE mf32x4 operator<(f32x4 a, f32x4 b) {
+  return mf32x4{vcltq_f32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator<=(f32x4 a, f32x4 b) {
+  return mf32x4{vcleq_f32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator>(f32x4 a, f32x4 b) {
+  return mf32x4{vcgtq_f32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator>=(f32x4 a, f32x4 b) {
+  return mf32x4{vcgeq_f32(a.v, b.v)};
+}
+
+YNN_ALWAYS_INLINE mf32x4 operator&(mf32x4 a, mf32x4 b) {
+  return mf32x4{vandq_u32(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator|(mf32x4 a, mf32x4 b) {
+  return mf32x4{vorrq_u32(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator^(mf32x4 a, mf32x4 b) {
+  return mf32x4{veorq_u32(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator~(mf32x4 a) { return mf32x4{vmvnq_u32(a.m)}; }
+
+YNN_ALWAYS_INLINE f32x4 select(mf32x4 m, f32x4 a, f32x4 b) {
+  return f32x4{vbslq_f32(m.m, a.v, b.v)};
+}
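An illustrative composition of the operators above (an editorial sketch, not
part of the patch; `clamp_f32x4` is a hypothetical helper, not a function in
this library):

  // Branch-free clamp: compare + select, no per-lane branching. vbslq_f32
  // picks bits from `a` where the mask bits are set, else from `b`.
  YNN_ALWAYS_INLINE f32x4 clamp_f32x4(f32x4 x, f32x4 lo, f32x4 hi) {
    x = select(x < lo, lo, x);  // lanes below lo are raised to lo
    x = select(x > hi, hi, x);  // lanes above hi are lowered to hi
    return x;
  }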
+using ms32x4 = mask<int32_t, 4>;
+
+template <>
+struct mask<int32_t, 4> {
+  static constexpr std::integral_constant<size_t, 4> N = {};
+  uint32x4_t m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(uint32x4_t m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x)
+      : m(vdupq_n_u32(x ? 0xFFFFFFFF : 0)) {}
+};
+
+YNN_ALWAYS_INLINE ms32x4 operator==(s32x4 a, s32x4 b) {
+  return ms32x4{vceqq_s32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator!=(s32x4 a, s32x4 b) {
+  return ms32x4{vmvnq_u32(vceqq_s32(a.v, b.v))};
+}
+YNN_ALWAYS_INLINE ms32x4 operator<(s32x4 a, s32x4 b) {
+  return ms32x4{vcltq_s32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator<=(s32x4 a, s32x4 b) {
+  return ms32x4{vcleq_s32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator>(s32x4 a, s32x4 b) {
+  return ms32x4{vcgtq_s32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator>=(s32x4 a, s32x4 b) {
+  return ms32x4{vcgeq_s32(a.v, b.v)};
+}
+
+YNN_ALWAYS_INLINE ms32x4 operator&(ms32x4 a, ms32x4 b) {
+  return ms32x4{vandq_u32(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator|(ms32x4 a, ms32x4 b) {
+  return ms32x4{vorrq_u32(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator^(ms32x4 a, ms32x4 b) {
+  return ms32x4{veorq_u32(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator~(ms32x4 a) { return ms32x4{vmvnq_u32(a.m)}; }
+
+YNN_ALWAYS_INLINE s32x4 select(ms32x4 m, s32x4 a, s32x4 b) {
+  return s32x4{vbslq_s32(m.m, a.v, b.v)};
+}
+
+YNN_ALWAYS_INLINE ms32x4 cast(mf32x4 from, int32_t) { return ms32x4{from.m}; }
+YNN_ALWAYS_INLINE mf32x4 cast(ms32x4 from, float) { return mf32x4{from.m}; }
+
 namespace internal {
 
 YNN_ALWAYS_INLINE int32x4x2_t vtrn(int32x4_t a, int32x4_t b) {
diff --git a/ynnpack/base/simd/generic.inc b/ynnpack/base/simd/generic.inc
index a6195dc8d12..25d08439ed3 100644
--- a/ynnpack/base/simd/generic.inc
+++ b/ynnpack/base/simd/generic.inc
@@ -202,6 +202,53 @@ template <typename T, size_t N>
 YNN_ALWAYS_INLINE vec<T, N> sub_sat(vec<T, N> a, vec<T, N> b) {
   return {sub_sat(a.lo(), b.lo()), sub_sat(a.hi(), b.hi())};
 }
+
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator==(vec<T, N> a, vec<T, N> b) {
+  return {a.lo() == b.lo(), a.hi() == b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator!=(vec<T, N> a, vec<T, N> b) {
+  return {a.lo() != b.lo(), a.hi() != b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator<(vec<T, N> a, vec<T, N> b) {
+  return {a.lo() < b.lo(), a.hi() < b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator<=(vec<T, N> a, vec<T, N> b) {
+  return {a.lo() <= b.lo(), a.hi() <= b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator>(vec<T, N> a, vec<T, N> b) {
+  return {a.lo() > b.lo(), a.hi() > b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator>=(vec<T, N> a, vec<T, N> b) {
+  return {a.lo() >= b.lo(), a.hi() >= b.hi()};
+}
+
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator&(mask<T, N> a, mask<T, N> b) {
+  return {a.lo() & b.lo(), a.hi() & b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator|(mask<T, N> a, mask<T, N> b) {
+  return {a.lo() | b.lo(), a.hi() | b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator^(mask<T, N> a, mask<T, N> b) {
+  return {a.lo() ^ b.lo(), a.hi() ^ b.hi()};
+}
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE mask<T, N> operator~(mask<T, N> a) {
+  return {~a.lo(), ~a.hi()};
+}
+
+template <typename T, size_t N>
+YNN_ALWAYS_INLINE vec<T, N> select(mask<T, N> m, vec<T, N> a, vec<T, N> b) {
+  return {select(m.lo(), a.lo(), b.lo()), select(m.hi(), a.hi(), b.hi())};
+}
 template <typename T, size_t N>
 YNN_ALWAYS_INLINE vec<T, N> floor(vec<T, N> a) {
   return {floor(a.lo()), floor(a.hi())};
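To make the recursion above concrete (an editorial sketch, not part of the
patch; `less_than_expanded` is a hypothetical name, and `vec<float, 8>` is
assumed to be a two-half multi_vec on a 128-bit target):

  // What `a < b` does for a two-level vector: compare each half with the
  // native overload (e.g. vcltq_f32 on NEON), then rebuild the mask from the
  // two submasks. The recursion bottoms out at the native mask types.
  YNN_ALWAYS_INLINE mask<float, 8> less_than_expanded(vec<float, 8> a,
                                                      vec<float, 8> b) {
    return {a.lo() < b.lo(), a.hi() < b.hi()};
  }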
diff --git a/ynnpack/base/simd/test/arm_neon.cc b/ynnpack/base/simd/test/arm_neon.cc
index 53bfe1368e9..440f22d0558 100644
--- a/ynnpack/base/simd/test/arm_neon.cc
+++ b/ynnpack/base/simd/test/arm_neon.cc
@@ -60,6 +60,13 @@ TEST_PARTIAL_LOAD_STORE(arm_neon, bf16, 8);
 TEST_PARTIAL_LOAD_STORE(arm_neon, f32, 4);
 TEST_PARTIAL_LOAD_STORE(arm_neon, s32, 4);
 
+TEST_COMPARE_EQ(arm_neon, f32, 4);
+TEST_COMPARE_LT(arm_neon, f32, 4);
+TEST_COMPARE_GT(arm_neon, f32, 4);
+TEST_COMPARE_EQ(arm_neon, s32, 4);
+TEST_COMPARE_LT(arm_neon, s32, 4);
+TEST_COMPARE_GT(arm_neon, s32, 4);
+
 TEST_PARTIAL_LOAD_STORE(arm_neon, u8, 8);
 
 TEST_ADD(arm_neon, u8, 16);
diff --git a/ynnpack/base/simd/test/generic.h b/ynnpack/base/simd/test/generic.h
index 122578d9e3d..b8cbda89d0f 100644
--- a/ynnpack/base/simd/test/generic.h
+++ b/ynnpack/base/simd/test/generic.h
@@ -215,6 +215,93 @@ void test_partial_store() {
     test_partial_store<type, N>();                          \
   }
 
+struct cmp_eq_op {
+  template <typename T>
+  bool operator()(T a, T b) {
+    return a == b;
+  }
+  template <typename T, size_t N>
+  mask<T, N> operator()(vec<T, N> a, vec<T, N> b) {
+    return a == b;
+  }
+};
+
+struct cmp_lt_op {
+  template <typename T>
+  bool operator()(T a, T b) {
+    return a < b;
+  }
+  template <typename T, size_t N>
+  mask<T, N> operator()(vec<T, N> a, vec<T, N> b) {
+    return a < b;
+  }
+};
+
+struct cmp_gt_op {
+  template <typename T>
+  bool operator()(T a, T b) {
+    return a > b;
+  }
+  template <typename T, size_t N>
+  mask<T, N> operator()(vec<T, N> a, vec<T, N> b) {
+    return a > b;
+  }
+};
+
+template <typename scalar, size_t N, typename Op>
+void test_compare_select_op() {
+  using vector = vec<scalar, N>;
+  using mask_t = mask<scalar, N>;
+  Op op;
+
+  ReplicableRandomDevice rng;
+  for (auto _ : FuzzTest(std::chrono::milliseconds(100))) {
+    scalar a[N];
+    scalar b[N];
+    scalar c[N];
+    scalar d[N];
+    fill_random(a, N, rng);
+    fill_random(b, N, rng);
+    fill_random(c, N, rng);
+    fill_random(d, N, rng);
+
+    // Make some elements equal to test == and <= / >= properly.
+    for (size_t i = 0; i < N; ++i) {
+      if (rng() % 3 == 0) a[i] = b[i];
+    }
+
+    mask_t m = op(load(a, vector::N), load(b, vector::N));
+    vector sel = select(m, load(c, vector::N), load(d, vector::N));
+
+    scalar res[N];
+    store(res, sel);
+
+    for (size_t i = 0; i < N; ++i) {
+      bool is_true = op(a[i], b[i]);
+      if constexpr (std::is_floating_point_v<scalar>) {
+        if (std::isnan(a[i]) || std::isnan(b[i])) continue;
+      }
+
+      ASSERT_EQ(res[i], is_true ? c[i] : d[i]);
+    }
+  }
+}
+
+#define TEST_COMPARE_EQ(test_class, type, N)                \
+  TEST_F(test_class, compare_eq_##type##x##N) {             \
+    test_compare_select_op<type, N, cmp_eq_op>();           \
+  }
+
+#define TEST_COMPARE_LT(test_class, type, N)                \
+  TEST_F(test_class, compare_lt_##type##x##N) {             \
+    test_compare_select_op<type, N, cmp_lt_op>();           \
+  }
+
+#define TEST_COMPARE_GT(test_class, type, N)                \
+  TEST_F(test_class, compare_gt_##type##x##N) {             \
+    test_compare_select_op<type, N, cmp_gt_op>();           \
+  }
+
 template <typename scalar, size_t N, typename Op>
 void test_op() {
   using vector = vec<scalar, N>;
diff --git a/ynnpack/base/simd/test/multi_vec.cc b/ynnpack/base/simd/test/multi_vec.cc
index c411c66255c..6719784f52a 100644
--- a/ynnpack/base/simd/test/multi_vec.cc
+++ b/ynnpack/base/simd/test/multi_vec.cc
@@ -36,6 +36,13 @@ TEST_PARTIAL_LOAD_STORE(multi_vec, bf16, 4);
 TEST_PARTIAL_LOAD_STORE(multi_vec, f32, 2);
 TEST_PARTIAL_LOAD_STORE(multi_vec, s32, 2);
 
+TEST_COMPARE_EQ(multi_vec, f32, 2);
+TEST_COMPARE_LT(multi_vec, f32, 2);
+TEST_COMPARE_GT(multi_vec, f32, 2);
+TEST_COMPARE_EQ(multi_vec, s32, 2);
+TEST_COMPARE_LT(multi_vec, s32, 2);
+TEST_COMPARE_GT(multi_vec, s32, 2);
+
 TEST_ADD(multi_vec, u8, 8);
 TEST_ADD(multi_vec, s8, 8);
 TEST_ADD(multi_vec, s16, 4);
diff --git a/ynnpack/base/simd/test/x86_avx.cc b/ynnpack/base/simd/test/x86_avx.cc
index 31e7e283f89..3b50f2fcf1b 100644
--- a/ynnpack/base/simd/test/x86_avx.cc
+++ b/ynnpack/base/simd/test/x86_avx.cc
@@ -52,6 +52,13 @@ TEST_PARTIAL_LOAD_STORE(x86_avx, bf16, 16);
 TEST_PARTIAL_LOAD_STORE(x86_avx, f32, 8);
 TEST_PARTIAL_LOAD_STORE(x86_avx, s32, 8);
 
+TEST_COMPARE_EQ(x86_avx, f32, 8);
+TEST_COMPARE_LT(x86_avx, f32, 8);
+TEST_COMPARE_GT(x86_avx, f32, 8);
+TEST_COMPARE_EQ(x86_avx, s32, 8);
+TEST_COMPARE_LT(x86_avx, s32, 8);
+TEST_COMPARE_GT(x86_avx, s32, 8);
+
 TEST_ADD(x86_avx, f32, 8);
 TEST_SUBTRACT(x86_avx, f32, 8);
 TEST_MULTIPLY(x86_avx, f32, 8);
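For reference, a sketch of what one of these macros expands to (assuming
gtest's TEST_F and the `f32` scalar alias used elsewhere in these tests):

  // TEST_COMPARE_EQ(x86_avx, f32, 8) becomes:
  TEST_F(x86_avx, compare_eq_f32x8) {
    test_compare_select_op<f32, 8, cmp_eq_op>();
  }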
diff --git a/ynnpack/base/simd/test/x86_avx512.cc b/ynnpack/base/simd/test/x86_avx512.cc
index 82c7f1e5924..176447a8677 100644
--- a/ynnpack/base/simd/test/x86_avx512.cc
+++ b/ynnpack/base/simd/test/x86_avx512.cc
@@ -68,6 +68,13 @@ TEST_PARTIAL_LOAD_STORE(x86_avx512, bf16, 32);
 TEST_PARTIAL_LOAD_STORE(x86_avx512, f32, 16);
 TEST_PARTIAL_LOAD_STORE(x86_avx512, s32, 16);
 
+TEST_COMPARE_EQ(x86_avx512, f32, 16);
+TEST_COMPARE_LT(x86_avx512, f32, 16);
+TEST_COMPARE_GT(x86_avx512, f32, 16);
+TEST_COMPARE_EQ(x86_avx512, s32, 16);
+TEST_COMPARE_LT(x86_avx512, s32, 16);
+TEST_COMPARE_GT(x86_avx512, s32, 16);
+
 TEST_ADD(x86_avx512, u8, 64);
 TEST_ADD(x86_avx512, s8, 64);
 TEST_ADD(x86_avx512, f32, 16);
diff --git a/ynnpack/base/simd/test/x86_sse2.cc b/ynnpack/base/simd/test/x86_sse2.cc
index 1f5fb3855b2..ffd3e5f36d5 100644
--- a/ynnpack/base/simd/test/x86_sse2.cc
+++ b/ynnpack/base/simd/test/x86_sse2.cc
@@ -48,6 +48,13 @@ TEST_PARTIAL_LOAD_STORE(x86_sse2, bf16, 8);
 TEST_PARTIAL_LOAD_STORE(x86_sse2, f32, 4);
 TEST_PARTIAL_LOAD_STORE(x86_sse2, s32, 4);
 
+TEST_COMPARE_EQ(x86_sse2, f32, 4);
+TEST_COMPARE_LT(x86_sse2, f32, 4);
+TEST_COMPARE_GT(x86_sse2, f32, 4);
+TEST_COMPARE_EQ(x86_sse2, s32, 4);
+TEST_COMPARE_LT(x86_sse2, s32, 4);
+TEST_COMPARE_GT(x86_sse2, s32, 4);
+
 TEST_ADD(x86_sse2, u8, 16);
 TEST_ADD(x86_sse2, s8, 16);
 TEST_ADD(x86_sse2, s16, 8);
diff --git a/ynnpack/base/simd/vec.h b/ynnpack/base/simd/vec.h
index 6f4ebbba5f6..e069e5119de 100644
--- a/ynnpack/base/simd/vec.h
+++ b/ynnpack/base/simd/vec.h
@@ -35,6 +35,37 @@ struct undef {
   static constexpr std::integral_constant<size_t, N_> N = {};
 };
 
+template <typename T, size_t N_>
+struct mask {
+  static constexpr std::integral_constant<size_t, N_> N = {};
+  using submask = mask<T, N_ / 2>;
+
+  submask m[2];
+
+  YNN_ALWAYS_INLINE submask& lo() { return m[0]; }
+  YNN_ALWAYS_INLINE const submask& lo() const { return m[0]; }
+  YNN_ALWAYS_INLINE submask& hi() { return m[1]; }
+  YNN_ALWAYS_INLINE const submask& hi() const { return m[1]; }
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(bool x) : m{submask{x}, submask{x}} {}
+  YNN_ALWAYS_INLINE mask(submask m0, submask m1) : m{m0, m1} {}
+
+  YNN_ALWAYS_INLINE submask& operator[](size_t i) { return m[i]; }
+  YNN_ALWAYS_INLINE const submask& operator[](size_t i) const { return m[i]; }
+};
+
+template <typename T>
+struct mask<T, 1> {
+  static constexpr std::integral_constant<size_t, 1> N = {};
+  using value_type = bool;
+
+  bool v;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(bool x) : v(x) {}
+};
+
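The recursive layout above mirrors multi_vec: each non-native mask is a pair
of half-width submasks. A sketch of the resulting type tree (illustrative
only; it assumes a build where `mask<float, 4>` is the only native
specialization, as with the SSE2 backend below):

  // mask<float, 16> -> two mask<float, 8> -> four native mask<float, 4>.
  static_assert(std::is_same_v<mask<float, 16>::submask, mask<float, 8>>);
  static_assert(std::is_same_v<mask<float, 8>::submask, mask<float, 4>>);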
 // The idea here is to provide the minimal wrappers around various platform
 // specific intrinsics that allow overloading behavior based on type and vector
 // length. For example, suppose you want to implement the following generic
@@ -143,6 +174,31 @@ vec<T, N> sub_sat(vec<T, N> a, vec<T, N> b);
 template <typename T, size_t N>
 vec<T, N> operator<<(vec<T, N> a, int b);
 
+template <typename T, size_t N>
+mask<T, N> operator==(vec<T, N> a, vec<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator!=(vec<T, N> a, vec<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator<(vec<T, N> a, vec<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator<=(vec<T, N> a, vec<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator>(vec<T, N> a, vec<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator>=(vec<T, N> a, vec<T, N> b);
+
+template <typename T, size_t N>
+mask<T, N> operator&(mask<T, N> a, mask<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator|(mask<T, N> a, mask<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator^(mask<T, N> a, mask<T, N> b);
+template <typename T, size_t N>
+mask<T, N> operator~(mask<T, N> a);
+
+template <typename T, size_t N>
+vec<T, N> select(mask<T, N> m, vec<T, N> a, vec<T, N> b);
+
 template <typename T, size_t N>
 std::array<vec<T, N>, 4> transpose(std::array<vec<T, N>, 4> x);
 template <typename T, size_t N>
@@ -305,6 +361,53 @@ YNN_ALWAYS_INLINE vec<T, 1> sub_sat(vec<T, 1> a, vec<T, 1> b) {
   return vec<T, 1>{sub_sat(a.v, b.v)};
 }
 
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator==(vec<T, 1> a, vec<T, 1> b) {
+  return mask<T, 1>{a.v == b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator!=(vec<T, 1> a, vec<T, 1> b) {
+  return mask<T, 1>{a.v != b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator<(vec<T, 1> a, vec<T, 1> b) {
+  return mask<T, 1>{a.v < b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator<=(vec<T, 1> a, vec<T, 1> b) {
+  return mask<T, 1>{a.v <= b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator>(vec<T, 1> a, vec<T, 1> b) {
+  return mask<T, 1>{a.v > b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator>=(vec<T, 1> a, vec<T, 1> b) {
+  return mask<T, 1>{a.v >= b.v};
+}
+
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator&(mask<T, 1> a, mask<T, 1> b) {
+  return mask<T, 1>{a.v && b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator|(mask<T, 1> a, mask<T, 1> b) {
+  return mask<T, 1>{a.v || b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator^(mask<T, 1> a, mask<T, 1> b) {
+  return mask<T, 1>{a.v != b.v};
+}
+template <typename T>
+YNN_ALWAYS_INLINE mask<T, 1> operator~(mask<T, 1> a) {
+  return mask<T, 1>{!a.v};
+}
+
+template <typename T>
+YNN_ALWAYS_INLINE vec<T, 1> select(mask<T, 1> m, vec<T, 1> a, vec<T, 1> b) {
+  return vec<T, 1>{m.v ? a.v : b.v};
+}
+
 template <typename T, typename To>
 YNN_ALWAYS_INLINE vec<To, 1> cast(vec<T, 1> from, To) {
   return vec<To, 1>{static_cast<To>(from.v)};
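For N == 1 the mask degenerates to a plain bool and `select` to the ternary
operator, which makes the semantics easy to state. A minimal sketch (not part
of the patch; `min_sketch` is a hypothetical helper):

  // Elementwise min via compare + select. For floating point, (a < b) is
  // false when either operand is NaN, so NaN lanes fall through to b; the
  // fuzz test above skips NaN lanes for exactly this kind of reason.
  template <typename T>
  YNN_ALWAYS_INLINE vec<T, 1> min_sketch(vec<T, 1> a, vec<T, 1> b) {
    return select(a < b, a, b);
  }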
diff --git a/ynnpack/base/simd/x86_avx512.h b/ynnpack/base/simd/x86_avx512.h
index 4035a7d329d..3339103b08a 100644
--- a/ynnpack/base/simd/x86_avx512.h
+++ b/ynnpack/base/simd/x86_avx512.h
@@ -205,6 +205,133 @@ using u8x64 = vec<uint8_t, 64>;
 using s8x64 = vec<int8_t, 64>;
 using f32x64 = vec<float, 64>;
 
+using mf32x16 = mask<float, 16>;
+
+template <>
+struct mask<float, 16> {
+  static constexpr std::integral_constant<size_t, 16> N = {};
+  __mmask16 m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(__mmask16 m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x) : m(x ? 0xFFFF : 0) {}
+  YNN_ALWAYS_INLINE mask(mask<float, 8> m0, mask<float, 8> m1)
+      : m(static_cast<__mmask16>(_mm256_movemask_ps(m0.m)) |
+          (static_cast<__mmask16>(_mm256_movemask_ps(m1.m)) << 8)) {}
+
+  YNN_ALWAYS_INLINE mask<float, 8> lo() const {
+    return mask<float, 8>{_mm256_castsi256_ps(
+        _mm256_maskz_set1_epi32(static_cast<__mmask8>(m), -1))};
+  }
+  YNN_ALWAYS_INLINE mask<float, 8> hi() const {
+    return mask<float, 8>{_mm256_castsi256_ps(
+        _mm256_maskz_set1_epi32(static_cast<__mmask8>(m >> 8), -1))};
+  }
+};
+
+YNN_ALWAYS_INLINE mf32x16 operator==(f32x16 a, f32x16 b) {
+  return mf32x16{_mm512_cmp_ps_mask(a.v, b.v, _CMP_EQ_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator!=(f32x16 a, f32x16 b) {
+  return mf32x16{_mm512_cmp_ps_mask(a.v, b.v, _CMP_NEQ_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator<(f32x16 a, f32x16 b) {
+  return mf32x16{_mm512_cmp_ps_mask(a.v, b.v, _CMP_LT_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator<=(f32x16 a, f32x16 b) {
+  return mf32x16{_mm512_cmp_ps_mask(a.v, b.v, _CMP_LE_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator>(f32x16 a, f32x16 b) {
+  return mf32x16{_mm512_cmp_ps_mask(a.v, b.v, _CMP_GT_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator>=(f32x16 a, f32x16 b) {
+  return mf32x16{_mm512_cmp_ps_mask(a.v, b.v, _CMP_GE_OQ)};
+}
+
+YNN_ALWAYS_INLINE mf32x16 operator&(mf32x16 a, mf32x16 b) {
+  return mf32x16{static_cast<__mmask16>(a.m & b.m)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator|(mf32x16 a, mf32x16 b) {
+  return mf32x16{static_cast<__mmask16>(a.m | b.m)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator^(mf32x16 a, mf32x16 b) {
+  return mf32x16{static_cast<__mmask16>(a.m ^ b.m)};
+}
+YNN_ALWAYS_INLINE mf32x16 operator~(mf32x16 a) {
+  return mf32x16{static_cast<__mmask16>(~a.m)};
+}
+
+YNN_ALWAYS_INLINE f32x16 select(mf32x16 m, f32x16 a, f32x16 b) {
+  return f32x16{_mm512_mask_blend_ps(m.m, b.v, a.v)};
+}
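Note that unlike the NEON, SSE2, and AVX masks, which are full-width vectors
of all-ones/all-zeros lanes, an AVX512 mask is one bit per lane in a
`__mmask16`, so the bitwise operators above compile to scalar bit operations.
A quick sketch (illustrative only; `even_lanes_of` is a hypothetical helper):

  // One bit per lane: mask(true) sets all 16 bits, and operator& is a plain
  // integer AND on the __mmask16 value.
  YNN_ALWAYS_INLINE mf32x16 even_lanes_of(mf32x16 m) {
    return m & mf32x16{static_cast<__mmask16>(0x5555)};  // lanes 0, 2, 4, ...
  }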
+using ms32x16 = mask<int32_t, 16>;
+
+template <>
+struct mask<int32_t, 16> {
+  static constexpr std::integral_constant<size_t, 16> N = {};
+  __mmask16 m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(__mmask16 m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x) : m(x ? 0xFFFF : 0) {}
+  YNN_ALWAYS_INLINE mask(mask<int32_t, 8> m0, mask<int32_t, 8> m1)
+      : m(static_cast<__mmask16>(
+              _mm256_movemask_ps(_mm256_castsi256_ps(m0.m))) |
+          (static_cast<__mmask16>(_mm256_movemask_ps(_mm256_castsi256_ps(m1.m)))
+           << 8)) {}
+
+  YNN_ALWAYS_INLINE mask<int32_t, 8> lo() const {
+    return mask<int32_t, 8>{
+        _mm256_maskz_set1_epi32(static_cast<__mmask8>(m), -1)};
+  }
+  YNN_ALWAYS_INLINE mask<int32_t, 8> hi() const {
+    return mask<int32_t, 8>{
+        _mm256_maskz_set1_epi32(static_cast<__mmask8>(m >> 8), -1)};
+  }
+};
+
+YNN_ALWAYS_INLINE ms32x16 operator==(s32x16 a, s32x16 b) {
+  return ms32x16{_mm512_cmp_epi32_mask(a.v, b.v, _MM_CMPINT_EQ)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator!=(s32x16 a, s32x16 b) {
+  return ms32x16{_mm512_cmp_epi32_mask(a.v, b.v, _MM_CMPINT_NE)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator<(s32x16 a, s32x16 b) {
+  return ms32x16{_mm512_cmp_epi32_mask(a.v, b.v, _MM_CMPINT_LT)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator<=(s32x16 a, s32x16 b) {
+  return ms32x16{_mm512_cmp_epi32_mask(a.v, b.v, _MM_CMPINT_LE)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator>(s32x16 a, s32x16 b) {
+  return ms32x16{_mm512_cmp_epi32_mask(a.v, b.v, _MM_CMPINT_NLE)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator>=(s32x16 a, s32x16 b) {
+  return ms32x16{_mm512_cmp_epi32_mask(a.v, b.v, _MM_CMPINT_NLT)};
+}
+
+YNN_ALWAYS_INLINE ms32x16 operator&(ms32x16 a, ms32x16 b) {
+  return ms32x16{static_cast<__mmask16>(a.m & b.m)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator|(ms32x16 a, ms32x16 b) {
+  return ms32x16{static_cast<__mmask16>(a.m | b.m)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator^(ms32x16 a, ms32x16 b) {
+  return ms32x16{static_cast<__mmask16>(a.m ^ b.m)};
+}
+YNN_ALWAYS_INLINE ms32x16 operator~(ms32x16 a) {
+  return ms32x16{static_cast<__mmask16>(~a.m)};
+}
+
+YNN_ALWAYS_INLINE s32x16 select(ms32x16 m, s32x16 a, s32x16 b) {
+  return s32x16{_mm512_mask_blend_epi32(m.m, b.v, a.v)};
+}
+
+YNN_ALWAYS_INLINE ms32x16 cast(mf32x16 from, int32_t) {
+  return ms32x16{from.m};
+}
+YNN_ALWAYS_INLINE mf32x16 cast(ms32x16 from, float) { return mf32x16{from.m}; }
+
 YNN_ALWAYS_INLINE f32x16 load_aligned(const float* ptr, decltype(f32x16::N),
                                       f32x16 = {}) {
   return f32x16{_mm512_load_ps(ptr)};
diff --git a/ynnpack/base/simd/x86_avx_base.h b/ynnpack/base/simd/x86_avx_base.h
index b3f26709ae6..dac51a410ed 100644
--- a/ynnpack/base/simd/x86_avx_base.h
+++ b/ynnpack/base/simd/x86_avx_base.h
@@ -199,6 +199,116 @@ using s16x16 = vec<int16_t, 16>;
 using u8x32 = vec<uint8_t, 32>;
 using s8x32 = vec<int8_t, 32>;
 
+using mf32x8 = mask<float, 8>;
+
+template <>
+struct mask<float, 8> {
+  static constexpr std::integral_constant<size_t, 8> N = {};
+  __m256 m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(__m256 m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x)
+      : m(_mm256_castsi256_ps(_mm256_set1_epi32(x ? -1 : 0))) {}
+  YNN_ALWAYS_INLINE mask(mask<float, 4> m0, mask<float, 4> m1)
+      : m(internal::concat(m0.m, m1.m)) {}
+
+  YNN_ALWAYS_INLINE mask<float, 4> lo() const {
+    return mask<float, 4>{_mm256_castps256_ps128(m)};
+  }
+  YNN_ALWAYS_INLINE mask<float, 4> hi() const {
+    return mask<float, 4>{_mm256_extractf128_ps(m, 1)};
+  }
+};
+
+YNN_ALWAYS_INLINE mf32x8 operator==(f32x8 a, f32x8 b) {
+  return mf32x8{_mm256_cmp_ps(a.v, b.v, _CMP_EQ_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator!=(f32x8 a, f32x8 b) {
+  return mf32x8{_mm256_cmp_ps(a.v, b.v, _CMP_NEQ_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator<(f32x8 a, f32x8 b) {
+  return mf32x8{_mm256_cmp_ps(a.v, b.v, _CMP_LT_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator<=(f32x8 a, f32x8 b) {
+  return mf32x8{_mm256_cmp_ps(a.v, b.v, _CMP_LE_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator>(f32x8 a, f32x8 b) {
+  return mf32x8{_mm256_cmp_ps(a.v, b.v, _CMP_GT_OQ)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator>=(f32x8 a, f32x8 b) {
+  return mf32x8{_mm256_cmp_ps(a.v, b.v, _CMP_GE_OQ)};
+}
+
+YNN_ALWAYS_INLINE mf32x8 operator&(mf32x8 a, mf32x8 b) {
+  return mf32x8{_mm256_and_ps(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator|(mf32x8 a, mf32x8 b) {
+  return mf32x8{_mm256_or_ps(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator^(mf32x8 a, mf32x8 b) {
+  return mf32x8{_mm256_xor_ps(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x8 operator~(mf32x8 a) {
+  return mf32x8{_mm256_xor_ps(a.m, _mm256_castsi256_ps(_mm256_set1_epi32(-1)))};
+}
+
+YNN_ALWAYS_INLINE f32x8 select(mf32x8 m, f32x8 a, f32x8 b) {
+  return f32x8{_mm256_blendv_ps(b.v, a.v, m.m)};
+}
+
+using ms32x8 = mask<int32_t, 8>;
+
+template <>
+struct mask<int32_t, 8> {
+  static constexpr std::integral_constant<size_t, 8> N = {};
+  __m256i m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(__m256i m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x) : m(_mm256_set1_epi32(x ? -1 : 0)) {}
+  YNN_ALWAYS_INLINE mask(mask<int32_t, 4> m0, mask<int32_t, 4> m1)
+      : m(internal::concat(m0.m, m1.m)) {}
+
+  YNN_ALWAYS_INLINE mask<int32_t, 4> lo() const {
+    return mask<int32_t, 4>{_mm256_castsi256_si128(m)};
+  }
+  YNN_ALWAYS_INLINE mask<int32_t, 4> hi() const {
+    return mask<int32_t, 4>{
+        _mm_castps_si128(_mm256_extractf128_ps(_mm256_castsi256_ps(m), 1))};
+  }
+};
+
+YNN_ALWAYS_INLINE ms32x8 operator&(ms32x8 a, ms32x8 b) {
+  return ms32x8{_mm256_castps_si256(
+      _mm256_and_ps(_mm256_castsi256_ps(a.m), _mm256_castsi256_ps(b.m)))};
+}
+YNN_ALWAYS_INLINE ms32x8 operator|(ms32x8 a, ms32x8 b) {
+  return ms32x8{_mm256_castps_si256(
+      _mm256_or_ps(_mm256_castsi256_ps(a.m), _mm256_castsi256_ps(b.m)))};
+}
+YNN_ALWAYS_INLINE ms32x8 operator^(ms32x8 a, ms32x8 b) {
+  return ms32x8{_mm256_castps_si256(
+      _mm256_xor_ps(_mm256_castsi256_ps(a.m), _mm256_castsi256_ps(b.m)))};
+}
+YNN_ALWAYS_INLINE ms32x8 operator~(ms32x8 a) {
+  return ms32x8{_mm256_castps_si256(_mm256_xor_ps(
+      _mm256_castsi256_ps(a.m), _mm256_castsi256_ps(_mm256_set1_epi32(-1))))};
+}
+
+YNN_ALWAYS_INLINE s32x8 select(ms32x8 m, s32x8 a, s32x8 b) {
+  return s32x8{_mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(b.v),
+                                                    _mm256_castsi256_ps(a.v),
+                                                    _mm256_castsi256_ps(m.m)))};
+}
+
+YNN_ALWAYS_INLINE ms32x8 cast(mf32x8 from, int32_t) {
+  return ms32x8{_mm256_castps_si256(from.m)};
+}
+YNN_ALWAYS_INLINE mf32x8 cast(ms32x8 from, float) {
+  return mf32x8{_mm256_castsi256_ps(from.m)};
+}
+
 namespace internal {
 
 // These overloads are x86-specific helpers for implementing templated
diff --git a/ynnpack/base/simd/x86_sse2_base.h b/ynnpack/base/simd/x86_sse2_base.h
index cad61b1e362..019b225d0b2 100644
--- a/ynnpack/base/simd/x86_sse2_base.h
+++ b/ynnpack/base/simd/x86_sse2_base.h
@@ -157,6 +157,113 @@ using s16x8 = vec<int16_t, 8>;
 using u8x16 = vec<uint8_t, 16>;
 using s8x16 = vec<int8_t, 16>;
 
+using mf32x4 = mask<float, 4>;
+
+template <>
+struct mask<float, 4> {
+  static constexpr std::integral_constant<size_t, 4> N = {};
+  __m128 m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(__m128 m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x)
+      : m(_mm_castsi128_ps(_mm_set1_epi32(x ? -1 : 0))) {}
+};
+
+YNN_ALWAYS_INLINE mf32x4 operator==(f32x4 a, f32x4 b) {
+  return mf32x4{_mm_cmpeq_ps(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator!=(f32x4 a, f32x4 b) {
+  return mf32x4{_mm_cmpneq_ps(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator<(f32x4 a, f32x4 b) {
+  return mf32x4{_mm_cmplt_ps(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator<=(f32x4 a, f32x4 b) {
+  return mf32x4{_mm_cmple_ps(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator>(f32x4 a, f32x4 b) {
+  return mf32x4{_mm_cmpnle_ps(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator>=(f32x4 a, f32x4 b) {
+  return mf32x4{_mm_cmpnlt_ps(a.v, b.v)};
+}
+
+YNN_ALWAYS_INLINE mf32x4 operator&(mf32x4 a, mf32x4 b) {
+  return mf32x4{_mm_and_ps(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator|(mf32x4 a, mf32x4 b) {
+  return mf32x4{_mm_or_ps(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator^(mf32x4 a, mf32x4 b) {
+  return mf32x4{_mm_xor_ps(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE mf32x4 operator~(mf32x4 a) {
+  return mf32x4{_mm_xor_ps(a.m, _mm_castsi128_ps(_mm_set1_epi32(-1)))};
+}
+
+YNN_ALWAYS_INLINE f32x4 select(mf32x4 m, f32x4 a, f32x4 b) {
+  return f32x4{_mm_or_ps(_mm_and_ps(m.m, a.v), _mm_andnot_ps(m.m, b.v))};
+}
+
+using ms32x4 = mask<int32_t, 4>;
+
+template <>
+struct mask<int32_t, 4> {
+  static constexpr std::integral_constant<size_t, 4> N = {};
+  __m128i m;
+
+  mask() = default;
+  YNN_ALWAYS_INLINE explicit mask(__m128i m) : m(m) {}
+  YNN_ALWAYS_INLINE explicit mask(bool x) : m(_mm_set1_epi32(x ? -1 : 0)) {}
+};
+
+YNN_ALWAYS_INLINE ms32x4 operator==(s32x4 a, s32x4 b) {
+  return ms32x4{_mm_cmpeq_epi32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator!=(s32x4 a, s32x4 b) {
+  return ms32x4{_mm_xor_si128(_mm_cmpeq_epi32(a.v, b.v), _mm_set1_epi32(-1))};
+}
+YNN_ALWAYS_INLINE ms32x4 operator<(s32x4 a, s32x4 b) {
+  return ms32x4{_mm_cmplt_epi32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator<=(s32x4 a, s32x4 b) {
+  return ms32x4{
+      _mm_or_si128(_mm_cmplt_epi32(a.v, b.v), _mm_cmpeq_epi32(a.v, b.v))};
+}
+YNN_ALWAYS_INLINE ms32x4 operator>(s32x4 a, s32x4 b) {
+  return ms32x4{_mm_cmpgt_epi32(a.v, b.v)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator>=(s32x4 a, s32x4 b) {
+  return ms32x4{
+      _mm_or_si128(_mm_cmpgt_epi32(a.v, b.v), _mm_cmpeq_epi32(a.v, b.v))};
+}
+
+YNN_ALWAYS_INLINE ms32x4 operator&(ms32x4 a, ms32x4 b) {
+  return ms32x4{_mm_and_si128(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator|(ms32x4 a, ms32x4 b) {
+  return ms32x4{_mm_or_si128(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator^(ms32x4 a, ms32x4 b) {
+  return ms32x4{_mm_xor_si128(a.m, b.m)};
+}
+YNN_ALWAYS_INLINE ms32x4 operator~(ms32x4 a) {
+  return ms32x4{_mm_xor_si128(a.m, _mm_set1_epi32(-1))};
+}
+
+YNN_ALWAYS_INLINE s32x4 select(ms32x4 m, s32x4 a, s32x4 b) {
+  return s32x4{
+      _mm_or_si128(_mm_and_si128(m.m, a.v), _mm_andnot_si128(m.m, b.v))};
+}
+
+YNN_ALWAYS_INLINE ms32x4 cast(mf32x4 from, int32_t) {
+  return ms32x4{_mm_castps_si128(from.m)};
+}
+YNN_ALWAYS_INLINE mf32x4 cast(ms32x4 from, float) {
+  return mf32x4{_mm_castsi128_ps(from.m)};
+}
+
 namespace internal {
 
 // These overloads are x86-specific helpers for implementing templated