Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions ynnpack/base/simd/arm_neon_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,103 @@ using s16x8 = vec<int16_t, 8>;
using u8x16 = vec<uint8_t, 16>;
using s8x16 = vec<int8_t, 16>;

using mf32x4 = mask<float, 4>;

template <>
struct mask<float, 4> {
static constexpr std::integral_constant<size_t, 4> N = {};
uint32x4_t m;

mask() = default;
YNN_ALWAYS_INLINE explicit mask(uint32x4_t m) : m(m) {}
YNN_ALWAYS_INLINE explicit mask(bool x)
: m(vdupq_n_u32(x ? 0xFFFFFFFF : 0)) {}
};

// Lane-wise comparisons of two f32x4 values. Every operator wraps the
// uint32x4_t returned by the matching NEON compare intrinsic in an mf32x4.
YNN_ALWAYS_INLINE mf32x4 operator==(f32x4 a, f32x4 b) {
  const uint32x4_t eq = vceqq_f32(a.v, b.v);
  return mf32x4{eq};
}
// Computed as the bitwise complement of the equality result.
YNN_ALWAYS_INLINE mf32x4 operator!=(f32x4 a, f32x4 b) {
  const uint32x4_t eq = vceqq_f32(a.v, b.v);
  const uint32x4_t ne = vmvnq_u32(eq);
  return mf32x4{ne};
}
YNN_ALWAYS_INLINE mf32x4 operator<(f32x4 a, f32x4 b) {
  const uint32x4_t lt = vcltq_f32(a.v, b.v);
  return mf32x4{lt};
}
YNN_ALWAYS_INLINE mf32x4 operator<=(f32x4 a, f32x4 b) {
  const uint32x4_t le = vcleq_f32(a.v, b.v);
  return mf32x4{le};
}
YNN_ALWAYS_INLINE mf32x4 operator>(f32x4 a, f32x4 b) {
  const uint32x4_t gt = vcgtq_f32(a.v, b.v);
  return mf32x4{gt};
}
YNN_ALWAYS_INLINE mf32x4 operator>=(f32x4 a, f32x4 b) {
  const uint32x4_t ge = vcgeq_f32(a.v, b.v);
  return mf32x4{ge};
}

// Bitwise logic on float masks, applied directly to the underlying
// uint32x4_t words.
YNN_ALWAYS_INLINE mf32x4 operator&(mf32x4 a, mf32x4 b) {
  const uint32x4_t both = vandq_u32(a.m, b.m);
  return mf32x4{both};
}
YNN_ALWAYS_INLINE mf32x4 operator|(mf32x4 a, mf32x4 b) {
  const uint32x4_t either = vorrq_u32(a.m, b.m);
  return mf32x4{either};
}
YNN_ALWAYS_INLINE mf32x4 operator^(mf32x4 a, mf32x4 b) {
  const uint32x4_t differing = veorq_u32(a.m, b.m);
  return mf32x4{differing};
}
YNN_ALWAYS_INLINE mf32x4 operator~(mf32x4 a) {
  const uint32x4_t inverted = vmvnq_u32(a.m);
  return mf32x4{inverted};
}

// Chooses between `a` and `b` under control of mask `m` by forwarding all
// three operands to vbslq_f32, with the mask as the selector argument.
YNN_ALWAYS_INLINE f32x4 select(mf32x4 m, f32x4 a, f32x4 b) {
  const float32x4_t picked = vbslq_f32(m.m, a.v, b.v);
  return f32x4{picked};
}

using ms32x4 = mask<int32_t, 4>;

template <>
struct mask<int32_t, 4> {
static constexpr std::integral_constant<size_t, 4> N = {};
uint32x4_t m;

mask() = default;
YNN_ALWAYS_INLINE explicit mask(uint32x4_t m) : m(m) {}
YNN_ALWAYS_INLINE explicit mask(bool x)
: m(vdupq_n_u32(x ? 0xFFFFFFFF : 0)) {}
};

// Lane-wise comparisons of two s32x4 values. Every operator wraps the
// uint32x4_t returned by the matching NEON compare intrinsic in an ms32x4.
YNN_ALWAYS_INLINE ms32x4 operator==(s32x4 a, s32x4 b) {
  const uint32x4_t eq = vceqq_s32(a.v, b.v);
  return ms32x4{eq};
}
// Computed as the bitwise complement of the equality result.
YNN_ALWAYS_INLINE ms32x4 operator!=(s32x4 a, s32x4 b) {
  const uint32x4_t eq = vceqq_s32(a.v, b.v);
  const uint32x4_t ne = vmvnq_u32(eq);
  return ms32x4{ne};
}
YNN_ALWAYS_INLINE ms32x4 operator<(s32x4 a, s32x4 b) {
  const uint32x4_t lt = vcltq_s32(a.v, b.v);
  return ms32x4{lt};
}
YNN_ALWAYS_INLINE ms32x4 operator<=(s32x4 a, s32x4 b) {
  const uint32x4_t le = vcleq_s32(a.v, b.v);
  return ms32x4{le};
}
YNN_ALWAYS_INLINE ms32x4 operator>(s32x4 a, s32x4 b) {
  const uint32x4_t gt = vcgtq_s32(a.v, b.v);
  return ms32x4{gt};
}
YNN_ALWAYS_INLINE ms32x4 operator>=(s32x4 a, s32x4 b) {
  const uint32x4_t ge = vcgeq_s32(a.v, b.v);
  return ms32x4{ge};
}

// Bitwise logic on int32 masks, applied directly to the underlying
// uint32x4_t words.
YNN_ALWAYS_INLINE ms32x4 operator&(ms32x4 a, ms32x4 b) {
  const uint32x4_t both = vandq_u32(a.m, b.m);
  return ms32x4{both};
}
YNN_ALWAYS_INLINE ms32x4 operator|(ms32x4 a, ms32x4 b) {
  const uint32x4_t either = vorrq_u32(a.m, b.m);
  return ms32x4{either};
}
YNN_ALWAYS_INLINE ms32x4 operator^(ms32x4 a, ms32x4 b) {
  const uint32x4_t differing = veorq_u32(a.m, b.m);
  return ms32x4{differing};
}
YNN_ALWAYS_INLINE ms32x4 operator~(ms32x4 a) {
  const uint32x4_t inverted = vmvnq_u32(a.m);
  return ms32x4{inverted};
}

// Chooses between `a` and `b` under control of mask `m` by forwarding all
// three operands to vbslq_s32, with the mask as the selector argument.
YNN_ALWAYS_INLINE s32x4 select(ms32x4 m, s32x4 a, s32x4 b) {
  const int32x4_t picked = vbslq_s32(m.m, a.v, b.v);
  return s32x4{picked};
}

// Mask conversions between the float and int32 flavours. The unnamed second
// parameter only selects the destination element type; both mask types store
// the same uint32x4_t, so the words are carried over unchanged.
YNN_ALWAYS_INLINE ms32x4 cast(mf32x4 from, int32_t) {
  const uint32x4_t words = from.m;
  return ms32x4{words};
}
YNN_ALWAYS_INLINE mf32x4 cast(ms32x4 from, float) {
  const uint32x4_t words = from.m;
  return mf32x4{words};
}

namespace internal {

YNN_ALWAYS_INLINE int32x4x2_t vtrn(int32x4_t a, int32x4_t b) {
Expand Down
47 changes: 47 additions & 0 deletions ynnpack/base/simd/generic.inc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,53 @@ template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> sub_sat(vec<T, N> a, vec<T, N> b) {
return {sub_sat(a.lo(), b.lo()), sub_sat(a.hi(), b.hi())};
}

// Comparisons for vectors that are handled as a low half and a high half:
// each half is compared on its own and the resulting pair of half-masks
// initializes the full mask.
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator==(vec<T, N> a, vec<T, N> b) {
  const auto lo_mask = a.lo() == b.lo();
  const auto hi_mask = a.hi() == b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator!=(vec<T, N> a, vec<T, N> b) {
  const auto lo_mask = a.lo() != b.lo();
  const auto hi_mask = a.hi() != b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator<(vec<T, N> a, vec<T, N> b) {
  const auto lo_mask = a.lo() < b.lo();
  const auto hi_mask = a.hi() < b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator<=(vec<T, N> a, vec<T, N> b) {
  const auto lo_mask = a.lo() <= b.lo();
  const auto hi_mask = a.hi() <= b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator>(vec<T, N> a, vec<T, N> b) {
  const auto lo_mask = a.lo() > b.lo();
  const auto hi_mask = a.hi() > b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator>=(vec<T, N> a, vec<T, N> b) {
  const auto lo_mask = a.lo() >= b.lo();
  const auto hi_mask = a.hi() >= b.hi();
  return {lo_mask, hi_mask};
}

// Bitwise logic for masks that are handled as a low half and a high half:
// the operation is applied to each half independently.
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator&(mask<T, N> a, mask<T, N> b) {
  const auto lo_mask = a.lo() & b.lo();
  const auto hi_mask = a.hi() & b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator|(mask<T, N> a, mask<T, N> b) {
  const auto lo_mask = a.lo() | b.lo();
  const auto hi_mask = a.hi() | b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator^(mask<T, N> a, mask<T, N> b) {
  const auto lo_mask = a.lo() ^ b.lo();
  const auto hi_mask = a.hi() ^ b.hi();
  return {lo_mask, hi_mask};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE mask<T, N> operator~(mask<T, N> a) {
  const auto lo_mask = ~a.lo();
  const auto hi_mask = ~a.hi();
  return {lo_mask, hi_mask};
}

// select() for vectors handled as a low half and a high half: each half of
// the mask picks between the matching halves of `a` and `b`.
template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> select(mask<T, N> m, vec<T, N> a, vec<T, N> b) {
  const auto lo_part = select(m.lo(), a.lo(), b.lo());
  const auto hi_part = select(m.hi(), a.hi(), b.hi());
  return {lo_part, hi_part};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> floor(vec<T, N> a) {
return {floor(a.lo()), floor(a.hi())};
Expand Down
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ TEST_PARTIAL_LOAD_STORE(arm_neon, bf16, 8);
TEST_PARTIAL_LOAD_STORE(arm_neon, f32, 4);
TEST_PARTIAL_LOAD_STORE(arm_neon, s32, 4);

// Comparison (==, <, >) plus select tests for the 4-lane f32 and s32 vectors.
TEST_COMPARE_EQ(arm_neon, f32, 4);
TEST_COMPARE_LT(arm_neon, f32, 4);
TEST_COMPARE_GT(arm_neon, f32, 4);
TEST_COMPARE_EQ(arm_neon, s32, 4);
TEST_COMPARE_LT(arm_neon, s32, 4);
TEST_COMPARE_GT(arm_neon, s32, 4);

TEST_PARTIAL_LOAD_STORE(arm_neon, u8, 8);

TEST_ADD(arm_neon, u8, 16);
Expand Down
87 changes: 87 additions & 0 deletions ynnpack/base/simd/test/generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,93 @@ void test_partial_store() {
test_partial_store<type, N>(); \
}

struct cmp_eq_op {
template <typename T>
bool operator()(T a, T b) {
return a == b;
}
template <typename T, size_t N>
mask<T, N> operator()(vec<T, N> a, vec<T, N> b) {
return a == b;
}
};

struct cmp_lt_op {
template <typename T>
bool operator()(T a, T b) {
return a < b;
}
template <typename T, size_t N>
mask<T, N> operator()(vec<T, N> a, vec<T, N> b) {
return a < b;
}
};

struct cmp_gt_op {
template <typename T>
bool operator()(T a, T b) {
return a > b;
}
template <typename T, size_t N>
mask<T, N> operator()(vec<T, N> a, vec<T, N> b) {
return a > b;
}
};

// Fuzz-tests a lane-wise comparison `Op` together with `select`.
//
// Each iteration fills four arrays a, b, c, d of N scalars with random data,
// copies b[i] into a[i] for the lanes where `rng() % 3 == 0` so that equal
// inputs actually occur, evaluates select(op(a, b), c, d) on the vector
// types, and checks every lane against the scalar reference
// `op(a[i], b[i]) ? c[i] : d[i]`.
//
// Template parameters:
//   scalar - element type of the vector under test.
//   N      - number of lanes.
//   Op     - default-constructible functor callable on both scalars
//            (returning bool) and vectors (returning mask<scalar, N>).
template <typename scalar, size_t N, typename Op>
void test_compare_select_op() {
  using vector = vec<scalar, N>;
  using mask_t = mask<scalar, N>;
  Op op;

  ReplicableRandomDevice rng;
  for (auto _ : FuzzTest(std::chrono::milliseconds(100))) {
    scalar a[N];
    scalar b[N];
    scalar c[N];
    scalar d[N];
    fill_random(a, N, rng);
    fill_random(b, N, rng);
    fill_random(c, N, rng);
    fill_random(d, N, rng);

    // Make some elements equal to test == and <= / >= properly
    for (size_t i = 0; i < N; ++i) {
      if (rng() % 3 == 0) a[i] = b[i];
    }

    mask_t m = op(load(a, vector::N), load(b, vector::N));
    vector sel = select(m, load(c, vector::N), load(d, vector::N));

    scalar res[N];
    store(res, sel);

    for (size_t i = 0; i < N; ++i) {
      if constexpr (std::is_floating_point_v<scalar>) {
        // Lanes whose comparison inputs are NaN are not checked.
        if (std::isnan(a[i]) || std::isnan(b[i])) continue;
      }

      const scalar expected = op(a[i], b[i]) ? c[i] : d[i];
      if constexpr (std::is_floating_point_v<scalar>) {
        // NaN never compares equal to itself, so if the random fill put a NaN
        // into c/d, ASSERT_EQ would fail even though select picked the right
        // lane. Compare NaN-ness instead in that case.
        if (std::isnan(expected)) {
          ASSERT_TRUE(std::isnan(res[i])) << "lane " << i;
          continue;
        }
      }

      ASSERT_EQ(res[i], expected) << "lane " << i;
    }
  }
}

// Registers a test named compare_eq_<type>x<N> in fixture `test_class` that
// runs test_compare_select_op with the equality comparator.
#define TEST_COMPARE_EQ(test_class, type, N) \
TEST_F(test_class, compare_eq_##type##x##N) { \
test_compare_select_op<type, N, cmp_eq_op>(); \
}

// Registers a test named compare_lt_<type>x<N> in fixture `test_class` that
// runs test_compare_select_op with the less-than comparator.
#define TEST_COMPARE_LT(test_class, type, N) \
TEST_F(test_class, compare_lt_##type##x##N) { \
test_compare_select_op<type, N, cmp_lt_op>(); \
}

// Registers a test named compare_gt_<type>x<N> in fixture `test_class` that
// runs test_compare_select_op with the greater-than comparator.
#define TEST_COMPARE_GT(test_class, type, N) \
TEST_F(test_class, compare_gt_##type##x##N) { \
test_compare_select_op<type, N, cmp_gt_op>(); \
}

template <typename scalar, size_t N, typename Op>
void test_op() {
using vector = vec<scalar, N>;
Expand Down
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/multi_vec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ TEST_PARTIAL_LOAD_STORE(multi_vec, bf16, 4);
TEST_PARTIAL_LOAD_STORE(multi_vec, f32, 2);
TEST_PARTIAL_LOAD_STORE(multi_vec, s32, 2);

// Comparison (==, <, >) plus select tests for the 2-lane f32 and s32 vectors.
TEST_COMPARE_EQ(multi_vec, f32, 2);
TEST_COMPARE_LT(multi_vec, f32, 2);
TEST_COMPARE_GT(multi_vec, f32, 2);
TEST_COMPARE_EQ(multi_vec, s32, 2);
TEST_COMPARE_LT(multi_vec, s32, 2);
TEST_COMPARE_GT(multi_vec, s32, 2);

TEST_ADD(multi_vec, u8, 8);
TEST_ADD(multi_vec, s8, 8);
TEST_ADD(multi_vec, s16, 4);
Expand Down
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/x86_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ TEST_PARTIAL_LOAD_STORE(x86_avx, bf16, 16);
TEST_PARTIAL_LOAD_STORE(x86_avx, f32, 8);
TEST_PARTIAL_LOAD_STORE(x86_avx, s32, 8);

// Comparison (==, <, >) plus select tests for the 8-lane f32 and s32 vectors.
TEST_COMPARE_EQ(x86_avx, f32, 8);
TEST_COMPARE_LT(x86_avx, f32, 8);
TEST_COMPARE_GT(x86_avx, f32, 8);
TEST_COMPARE_EQ(x86_avx, s32, 8);
TEST_COMPARE_LT(x86_avx, s32, 8);
TEST_COMPARE_GT(x86_avx, s32, 8);

TEST_ADD(x86_avx, f32, 8);
TEST_SUBTRACT(x86_avx, f32, 8);
TEST_MULTIPLY(x86_avx, f32, 8);
Expand Down
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/x86_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ TEST_PARTIAL_LOAD_STORE(x86_avx512, bf16, 32);
TEST_PARTIAL_LOAD_STORE(x86_avx512, f32, 16);
TEST_PARTIAL_LOAD_STORE(x86_avx512, s32, 16);

// Comparison (==, <, >) plus select tests for the 16-lane f32 and s32 vectors.
TEST_COMPARE_EQ(x86_avx512, f32, 16);
TEST_COMPARE_LT(x86_avx512, f32, 16);
TEST_COMPARE_GT(x86_avx512, f32, 16);
TEST_COMPARE_EQ(x86_avx512, s32, 16);
TEST_COMPARE_LT(x86_avx512, s32, 16);
TEST_COMPARE_GT(x86_avx512, s32, 16);

TEST_ADD(x86_avx512, u8, 64);
TEST_ADD(x86_avx512, s8, 64);
TEST_ADD(x86_avx512, f32, 16);
Expand Down
7 changes: 7 additions & 0 deletions ynnpack/base/simd/test/x86_sse2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ TEST_PARTIAL_LOAD_STORE(x86_sse2, bf16, 8);
TEST_PARTIAL_LOAD_STORE(x86_sse2, f32, 4);
TEST_PARTIAL_LOAD_STORE(x86_sse2, s32, 4);

// Comparison (==, <, >) plus select tests for the 4-lane f32 and s32 vectors.
TEST_COMPARE_EQ(x86_sse2, f32, 4);
TEST_COMPARE_LT(x86_sse2, f32, 4);
TEST_COMPARE_GT(x86_sse2, f32, 4);
TEST_COMPARE_EQ(x86_sse2, s32, 4);
TEST_COMPARE_LT(x86_sse2, s32, 4);
TEST_COMPARE_GT(x86_sse2, s32, 4);

TEST_ADD(x86_sse2, u8, 16);
TEST_ADD(x86_sse2, s8, 16);
TEST_ADD(x86_sse2, s16, 8);
Expand Down
Loading
Loading