diff --git a/include/xsimd/types/xsimd_batch.hpp b/include/xsimd/types/xsimd_batch.hpp index b3b704666..38acbe339 100644 --- a/include/xsimd/types/xsimd_batch.hpp +++ b/include/xsimd/types/xsimd_batch.hpp @@ -296,6 +296,7 @@ namespace xsimd static constexpr std::size_t size = sizeof(types::simd_register) / sizeof(T); ///< Number of scalar elements in this batch. using value_type = bool; ///< Type of the scalar elements within this batch. + using operand_type = T; using arch_type = A; ///< SIMD Architecture abstracted by this batch. using register_type = typename base_type::register_type; ///< SIMD register type abstracted by this batch. using batch_type = batch; ///< Associated batch type this batch represents logical operations for. diff --git a/include/xsimd/types/xsimd_batch_constant.hpp b/include/xsimd/types/xsimd_batch_constant.hpp index 60e27493d..878283ec0 100644 --- a/include/xsimd/types/xsimd_batch_constant.hpp +++ b/include/xsimd/types/xsimd_batch_constant.hpp @@ -12,6 +12,8 @@ #ifndef XSIMD_BATCH_CONSTANT_HPP #define XSIMD_BATCH_CONSTANT_HPP +#include + #include "./xsimd_batch.hpp" #include "./xsimd_utils.hpp" @@ -31,6 +33,7 @@ namespace xsimd using batch_type = batch_bool; static constexpr std::size_t size = sizeof...(Values); using value_type = bool; + using operand_type = T; static_assert(sizeof...(Values) == batch_type::size, "consistent batch size"); public: @@ -44,7 +47,7 @@ namespace xsimd */ constexpr operator batch_type() const noexcept { return as_batch_bool(); } - constexpr bool get(size_t i) const noexcept + constexpr bool get(std::size_t i) const noexcept { return std::array { { Values... } }[i]; } @@ -76,7 +79,7 @@ namespace xsimd constexpr bool operator()(bool x, bool y) const { return x ^ y; } }; - template + template static constexpr batch_bool_constant::type::value, std::tuple_element::type::value)...> apply(detail::index_sequence) { @@ -88,7 +91,7 @@ namespace xsimd -> decltype(apply...>, std::tuple...>>(detail::make_index_sequence())) { static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches"); - return apply...>, std::tuple...>>(detail::make_index_sequence()); + return {}; } public: @@ -148,13 +151,13 @@ namespace xsimd /** * @brief Get the @p i th element of this @p batch_constant */ - constexpr T get(size_t i) const noexcept + constexpr T get(std::size_t i) const noexcept { return get(i, std::array { Values... }); } private: - constexpr T get(size_t i, std::array const& values) const noexcept + constexpr T get(std::size_t i, std::array const& values) const noexcept { return values[i]; } @@ -191,8 +194,16 @@ namespace xsimd { constexpr T operator()(T x, T y) const { return x ^ y; } }; + struct binary_rshift + { + constexpr T operator()(T x, T y) const { return x >> y; } + }; + struct binary_lshift + { + constexpr T operator()(T x, T y) const { return x << y; } + }; - template + template static constexpr batch_constant::type::value, std::tuple_element::type::value)...> apply(detail::index_sequence) { @@ -204,7 +215,7 @@ namespace xsimd -> decltype(apply...>, std::tuple...>>(detail::make_index_sequence())) { static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches"); - return apply...>, std::tuple...>>(detail::make_index_sequence()); + return {}; } public: @@ -224,9 +235,68 @@ namespace xsimd MAKE_BINARY_OP(&, binary_and) MAKE_BINARY_OP(|, binary_or) MAKE_BINARY_OP(^, binary_xor) + MAKE_BINARY_OP(<<, binary_lshift) + MAKE_BINARY_OP(>>, binary_rshift) #undef MAKE_BINARY_OP + struct boolean_eq + { + constexpr bool operator()(T x, T y) const { return x == y; } + }; + struct boolean_ne + { + constexpr bool operator()(T x, T y) const { return x != y; } + }; + struct boolean_gt + { + constexpr bool operator()(T x, T y) const { return x > y; } + }; + struct boolean_ge + { + constexpr bool operator()(T x, T y) const { return x >= y; } + }; + struct boolean_lt + { + constexpr bool operator()(T x, T y) const { return x < y; } + }; + struct boolean_le + { + constexpr bool operator()(T x, T y) const { return x <= y; } + }; + + template + static constexpr batch_bool_constant::type::value, std::tuple_element::type::value)...> + apply_bool(detail::index_sequence) + { + return {}; + } + + template + static constexpr auto apply_bool(batch_constant, batch_constant) + -> decltype(apply_bool...>, std::tuple...>>(detail::make_index_sequence())) + { + static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches"); + return {}; + } + +#define MAKE_BINARY_BOOL_OP(OP, NAME) \ + template \ + constexpr auto operator OP(batch_constant other) const \ + -> decltype(apply_bool(*this, other)) \ + { \ + return {}; \ + } + + MAKE_BINARY_BOOL_OP(==, boolean_eq) + MAKE_BINARY_BOOL_OP(!=, boolean_ne) + MAKE_BINARY_BOOL_OP(<, boolean_lt) + MAKE_BINARY_BOOL_OP(<=, boolean_le) + MAKE_BINARY_BOOL_OP(>, boolean_gt) + MAKE_BINARY_BOOL_OP(>=, boolean_ge) + +#undef MAKE_BINARY_BOOL_OP + constexpr batch_constant operator-() const { return {}; diff --git a/test/test_batch_constant.cpp b/test/test_batch_constant.cpp index f5e1a09b4..5b0e23028 100644 --- a/test/test_batch_constant.cpp +++ b/test/test_batch_constant.cpp @@ -136,6 +136,14 @@ struct constant_batch_test constexpr auto n12_lxor_n3 = n12 ^ n3; static_assert(std::is_same::value, "n12 ^ n3 == n15"); + constexpr auto n96 = xsimd::make_batch_constant, arch_type>(); + constexpr auto n12_lshift_n3 = n12 << n3; + static_assert(std::is_same::value, "n12 << n3 == n96"); + + constexpr auto n1 = xsimd::make_batch_constant, arch_type>(); + constexpr auto n12_rshift_n3 = n12 >> n3; + static_assert(std::is_same::value, "n12 >> n3 == n1"); + constexpr auto n12_uadd = +n12; static_assert(std::is_same::value, "+n12 == n12"); @@ -146,6 +154,30 @@ struct constant_batch_test constexpr auto n12_usub = -n12; constexpr auto n12_usub_ = xsimd::make_batch_constant, arch_type>(); static_assert(std::is_same::value, "-n12 == n12_usub"); + + // comparison operators + using true_batch_type = decltype(xsimd::make_batch_bool_constant()); + using false_batch_type = decltype(xsimd::make_batch_bool_constant()); + + static_assert(std::is_same::value, "same type"); + + static_assert(std::is_same::value, "n12 == n12"); + static_assert(std::is_same::value, "n12 == n3"); + + static_assert(std::is_same::value, "n12 != n12"); + static_assert(std::is_same::value, "n12 != n3"); + + static_assert(std::is_same::value, "n12 < n12"); + static_assert(std::is_same::value, "n12 < n3"); + + static_assert(std::is_same n12), false_batch_type>::value, "n12 > n12"); + static_assert(std::is_same n3), true_batch_type>::value, "n12 > n3"); + + static_assert(std::is_same::value, "n12 <= n12"); + static_assert(std::is_same::value, "n12 <= n3"); + + static_assert(std::is_same= n12), true_batch_type>::value, "n12 >= n12"); + static_assert(std::is_same= n3), true_batch_type>::value, "n12 >= n3"); } };