diff --git a/docs/source/api/type_traits.rst b/docs/source/api/type_traits.rst index f0b082bc3..31a5f2585 100644 --- a/docs/source/api/type_traits.rst +++ b/docs/source/api/type_traits.rst @@ -32,9 +32,15 @@ Type Traits =========== `xsimd` provides a few type traits to interact with scalar and batch types in an -uniformeous manner. +uniform manner. +Combined traits: + ++---------------------------------------+----------------------------------------------------+ +| :cpp:class:`batch_traits` | batch types and proprties | ++---------------------------------------+----------------------------------------------------+ + Type check: +---------------------------------------+----------------------------------------------------+ diff --git a/include/xsimd/arch/common/xsimd_common_logical.hpp b/include/xsimd/arch/common/xsimd_common_logical.hpp index b8128158d..69ab75305 100644 --- a/include/xsimd/arch/common/xsimd_common_logical.hpp +++ b/include/xsimd/arch/common/xsimd_common_logical.hpp @@ -212,6 +212,35 @@ namespace xsimd res |= 1ul << i; return res; } + + // select + namespace detail + { + template + using is_batch_bool_register_same = std::is_same::register_type, typename batch::register_type>; + } + + template ::value, int>::type = 3> + XSIMD_INLINE batch_bool select(batch_bool const& cond, batch_bool const& true_br, batch_bool const& false_br, requires_arch) + { + using register_type = typename batch_bool::register_type; + // Do not cast, but rather reinterpret the masks as batches. + const auto true_v = batch { static_cast(true_br) }; + const auto false_v = batch { static_cast(false_br) }; + return batch_bool { select(cond, true_v, false_v) }; + } + + template ::value, int>::type = 3> + XSIMD_INLINE batch_bool select(batch_bool const& cond, batch_bool const& true_br, batch_bool const& false_br, requires_arch) + { + return (true_br & cond) | (bitwise_andnot(false_br, cond)); + } + + template + XSIMD_INLINE batch_bool select(batch_bool_constant const& cond, batch_bool const& true_br, batch_bool const& false_br, requires_arch) + { + return (true_br & cond) | (false_br & ~cond); + } } } diff --git a/include/xsimd/arch/xsimd_scalar.hpp b/include/xsimd/arch/xsimd_scalar.hpp index 9d24ed01f..f4ffb8c80 100644 --- a/include/xsimd/arch/xsimd_scalar.hpp +++ b/include/xsimd/arch/xsimd_scalar.hpp @@ -83,53 +83,16 @@ namespace xsimd using std::tgamma; using std::trunc; - XSIMD_INLINE signed char abs(signed char v) + template + XSIMD_INLINE constexpr typename std::enable_if::value && std::is_signed::value, T>::type + abs(T v) noexcept { return v < 0 ? -v : v; } - namespace detail - { - // Use templated type here to prevent automatic instantiation that may - // ends up in a warning - template - XSIMD_INLINE char abs(char_type v, std::true_type) - { - return v; - } - template - XSIMD_INLINE char abs(char_type v, std::false_type) - { - return v < 0 ? -v : v; - } - } - - XSIMD_INLINE char abs(char v) - { - return detail::abs(v, std::is_unsigned::type {}); - } - - XSIMD_INLINE short abs(short v) - { - return v < 0 ? -v : v; - } - XSIMD_INLINE unsigned char abs(unsigned char v) - { - return v; - } - XSIMD_INLINE unsigned short abs(unsigned short v) - { - return v; - } - XSIMD_INLINE unsigned int abs(unsigned int v) - { - return v; - } - XSIMD_INLINE unsigned long abs(unsigned long v) - { - return v; - } - XSIMD_INLINE unsigned long long abs(unsigned long long v) + template + XSIMD_INLINE constexpr typename std::enable_if::value && std::is_unsigned::value, T>::type + abs(T v) noexcept { return v; } @@ -1235,6 +1198,19 @@ namespace xsimd { return cond ? true_br : false_br; } + + template + XSIMD_INLINE constexpr bool batch_bool_cast(bool b) noexcept + { + return b; + } + + template + XSIMD_INLINE constexpr T_out batch_cast(T_in const& val) noexcept + { + static_assert(!std::is_same::value, "cannot convert to bool, use !x or x != 0"); + return static_cast(val); + } } #endif diff --git a/include/xsimd/types/xsimd_api.hpp b/include/xsimd/types/xsimd_api.hpp index 12bd9d95e..13db0a4d7 100644 --- a/include/xsimd/types/xsimd_api.hpp +++ b/include/xsimd/types/xsimd_api.hpp @@ -2099,6 +2099,27 @@ namespace xsimd return kernel::select(cond, true_br, false_br, A {}); } + /** + * @ingroup batch_bool_logical + * + * Ternary operator for conditions: selects values from the batches \c true_br or \c false_br + * depending on the boolean values in the constant batch \c cond. Equivalent to + * \code{.cpp} + * for(std::size_t i = 0; i < N; ++i) + * res[i] = cond[i] ? true_br[i] : false_br[i]; + * \endcode + * @param cond batch condition. + * @param true_br batch values for truthy condition. + * @param false_br batch value for falsy condition. + * @return the result of the selection. + */ + template + XSIMD_INLINE batch_bool select(batch_bool const& cond, batch_bool const& true_br, batch_bool const& false_br) noexcept + { + detail::static_check_supported_config(); + return kernel::select(cond, true_br, false_br, A {}); + } + /** * @ingroup batch_cond * @@ -2141,6 +2162,27 @@ namespace xsimd return kernel::select(cond, true_br, false_br, A {}); } + /** + * @ingroup batch_cond + * + * Ternary operator for mask batches: selects values from the masks \c true_br or \c false_br + * depending on the boolean values in the constant batch \c cond. Equivalent to + * \code{.cpp} + * for(std::size_t i = 0; i < N; ++i) + * res[i] = cond[i] ? true_br[i] : false_br[i]; + * \endcode + * @param cond constant batch condition. + * @param true_br batch values for truthy condition. + * @param false_br batch value for falsy condition. + * @return the result of the selection. + */ + template + XSIMD_INLINE batch_bool select(batch_bool_constant const& cond, batch_bool const& true_br, batch_bool const& false_br) noexcept + { + detail::static_check_supported_config(); + return kernel::select(cond, true_br, false_br, A {}); + } + /** * @ingroup batch_data_transfer * diff --git a/include/xsimd/types/xsimd_traits.hpp b/include/xsimd/types/xsimd_traits.hpp index 20c97f89f..f11fb39e1 100644 --- a/include/xsimd/types/xsimd_traits.hpp +++ b/include/xsimd/types/xsimd_traits.hpp @@ -233,59 +233,133 @@ namespace xsimd /** * @ingroup batch_traits * - * type traits that inherits from @c std::true_type for @c batch<...> types and from - * @c std::false_type otherwise. + * type traits that provide information about a batch or scalar type. * * @tparam T type to analyze. */ + template - struct is_batch; + struct batch_traits + { + using scalar_type = T; ///< T if scalar, or type of the scalar element for the batch T. + using mask_type = bool; ///< Mask type for T: bool for scalars, or batch_bool for batch types. + static constexpr bool is_batch = false; ///< True if T is @c batch<...>. + static constexpr bool is_batch_bool = false; ///< True if T is @c batch_bool<...>. + static constexpr bool is_any_batch = false; ///< True if T is @c batch<...> or @c batch_bool<...>. + static constexpr bool is_complex = detail::is_complex::value; ///< True if T is complex or a batch of complex values. + }; + +#if __cplusplus < 201703L + template + constexpr bool batch_traits::is_batch; template - struct is_batch : std::false_type + constexpr bool batch_traits::is_batch_bool; + template + constexpr bool batch_traits::is_any_batch; + template + constexpr bool batch_traits::is_complex; +#endif + + template + struct batch_traits> { + using scalar_type = T; + using mask_type = typename batch::batch_bool_type; + + static constexpr bool is_batch = true; + static constexpr bool is_batch_bool = false; + static constexpr bool is_any_batch = true; + static constexpr bool is_complex = detail::is_complex::value; }; +#if __cplusplus < 201703L + template + constexpr bool batch_traits>::is_batch; + template + constexpr bool batch_traits>::is_batch_bool; + template + constexpr bool batch_traits>::is_any_batch; + template + constexpr bool batch_traits>::is_complex; +#endif + template - struct is_batch> : std::true_type + struct batch_traits> { + using scalar_type = bool; + using mask_type = batch_bool; + + static constexpr bool is_batch = false; + static constexpr bool is_batch_bool = true; + static constexpr bool is_any_batch = true; + static constexpr bool is_complex = false; }; +#if __cplusplus < 201703L + template + constexpr bool batch_traits>::is_batch; + template + constexpr bool batch_traits>::is_batch_bool; + template + constexpr bool batch_traits>::is_any_batch; + template + constexpr bool batch_traits>::is_complex; +#endif + /** * @ingroup batch_traits * - * type traits that inherits from @c std::true_type for @c batch_bool<...> types and from + * type traits that inherits from @c std::true_type for @c batch<...> types and from * @c std::false_type otherwise. * * @tparam T type to analyze. */ template - struct is_batch_bool : std::false_type + struct is_batch : std::integral_constant::is_batch> { }; - template - struct is_batch_bool> : std::true_type + /** + * @ingroup batch_traits + * + * type traits that inherits from @c std::true_type for @c batch_bool<...> types and from + * @c std::false_type otherwise. + * + * @tparam T type to analyze. + */ + + template + struct is_batch_bool : std::integral_constant::is_batch_bool> { }; /** * @ingroup batch_traits * - * type traits that inherits from @c std::true_type for @c batch> + * type traits that inherits from @c std::true_type for @c batch<...> or batch_bool<...> * types and from @c std::false_type otherwise. * * @tparam T type to analyze. */ template - struct is_batch_complex : std::false_type + struct is_any_batch : std::integral_constant::is_any_batch> { }; - template - struct is_batch_complex, A>> : std::true_type + /** + * @ingroup batch_traits + * + * type traits that inherits from @c std::true_type for @c batch> + * types and from @c std::false_type otherwise. + * + * @tparam T type to analyze. + */ + + template + struct is_batch_complex : std::integral_constant::is_batch && batch_traits::is_complex> { }; @@ -300,12 +374,7 @@ namespace xsimd template struct scalar_type { - using type = T; - }; - template - struct scalar_type> - { - using type = T; + using type = typename batch_traits::scalar_type; }; template @@ -322,12 +391,7 @@ namespace xsimd template struct mask_type { - using type = bool; - }; - template - struct mask_type> - { - using type = typename batch::batch_bool_type; + using type = typename batch_traits::mask_type; }; template @@ -364,7 +428,6 @@ namespace xsimd } template using widen_t = typename detail::widen::type; - } #endif diff --git a/test/test_batch_cast.cpp b/test/test_batch_cast.cpp index 9605fe5b1..d5c0dc376 100644 --- a/test/test_batch_cast.cpp +++ b/test/test_batch_cast.cpp @@ -341,6 +341,7 @@ struct batch_cast_test T_out scalar_ref = static_cast(in_test_value); T_out scalar_res = res.get(0); CHECK_SCALAR_EQ(scalar_ref, scalar_res); + CHECK_SCALAR_EQ(scalar_ref, xsimd::batch_cast(in_test_value)); } } @@ -356,11 +357,13 @@ struct batch_cast_test B_common_out all_true_res = xsimd::batch_bool_cast(all_true_in); INFO(name); CHECK_SCALAR_EQ(all_true_res.get(0), true); + CHECK_SCALAR_EQ(xsimd::batch_bool_cast(true), true); B_common_in all_false_in(false); B_common_out all_false_res = xsimd::batch_bool_cast(all_false_in); INFO(name); CHECK_SCALAR_EQ(all_false_res.get(0), false); + CHECK_SCALAR_EQ(xsimd::batch_bool_cast(false), false); } }; diff --git a/test/test_select.cpp b/test/test_select.cpp index 6b450afd6..837fca314 100644 --- a/test/test_select.cpp +++ b/test/test_select.cpp @@ -18,22 +18,26 @@ template struct select_test { using batch_type = B; + using batch_bool_type = typename B::batch_bool_type; using value_type = typename B::value_type; using arch_type = typename B::arch_type; static constexpr size_t size = B::size; - using vector_type = std::vector; + static constexpr size_t nb_input = size * 10000; + using vector_type = std::array; + using vector_bool_type = std::array; - size_t nb_input; vector_type lhs_input; vector_type rhs_input; vector_type expected; vector_type res; + vector_bool_type lhs_input_b; + vector_bool_type rhs_input_b; + vector_bool_type expected_b; + vector_bool_type res_b; + select_test() { - nb_input = size * 10000; - lhs_input.resize(nb_input); - rhs_input.resize(nb_input); auto clamp = [](double v) { return static_cast(std::min(v, static_cast(std::numeric_limits::max()))); @@ -42,9 +46,9 @@ struct select_test { lhs_input[i] = clamp(i / 4 + 1.2 * std::sqrt(i + 0.25)); rhs_input[i] = clamp(10.2 / (i + 2) + 0.25); + lhs_input_b[i] = (int)lhs_input[i] % 2; + rhs_input_b[i] = (int)rhs_input[i] % 2; } - expected.resize(nb_input); - res.resize(nb_input); } void test_select_dynamic() @@ -52,18 +56,27 @@ struct select_test for (size_t i = 0; i < nb_input; ++i) { expected[i] = lhs_input[i] > value_type(3) ? lhs_input[i] : rhs_input[i]; + expected_b[i] = lhs_input[i] > value_type(3) ? lhs_input_b[i] : rhs_input_b[i]; } - batch_type lhs_in, rhs_in, out; + batch_type lhs_in, rhs_in; + batch_bool_type lhs_in_b, rhs_in_b; for (size_t i = 0; i < nb_input; i += size) { detail::load_batch(lhs_in, lhs_input, i); detail::load_batch(rhs_in, rhs_input, i); - out = xsimd::select(lhs_in > value_type(3), lhs_in, rhs_in); + const auto out = xsimd::select(lhs_in > value_type(3), lhs_in, rhs_in); detail::store_batch(out, res, i); + + detail::load_batch(lhs_in_b, lhs_input_b, i); + detail::load_batch(rhs_in_b, rhs_input_b, i); + const auto out_b = xsimd::select(lhs_in > value_type(3), lhs_in_b, rhs_in_b); + detail::store_batch(out_b, res_b, i); } size_t diff = detail::get_nb_diff(res, expected); + size_t diff_b = detail::get_nb_diff(res_b, expected_b); CHECK_EQ(diff, 0); + CHECK_EQ(diff_b, 0); } struct pattern { @@ -77,25 +90,35 @@ struct select_test for (size_t i = 0; i < nb_input; ++i) { expected[i] = mask.get(i % size) ? lhs_input[i] : rhs_input[i]; + expected_b[i] = mask.get(i % size) ? lhs_input_b[i] : rhs_input_b[i]; } - batch_type lhs_in, rhs_in, out; + batch_type lhs_in, rhs_in; + batch_bool_type lhs_in_b, rhs_in_b; for (size_t i = 0; i < nb_input; i += size) { detail::load_batch(lhs_in, lhs_input, i); detail::load_batch(rhs_in, rhs_input, i); - out = xsimd::select(mask, lhs_in, rhs_in); + const auto out = xsimd::select(mask, lhs_in, rhs_in); detail::store_batch(out, res, i); + + detail::load_batch(lhs_in_b, lhs_input_b, i); + detail::load_batch(rhs_in_b, rhs_input_b, i); + const auto out_b = xsimd::select(mask, lhs_in_b, rhs_in_b); + detail::store_batch(out_b, res_b, i); } size_t diff = detail::get_nb_diff(res, expected); + size_t diff_b = detail::get_nb_diff(res_b, expected_b); CHECK_EQ(diff, 0); + CHECK_EQ(diff_b, 0); } }; TEST_CASE_TEMPLATE("[select]", B, BATCH_TYPES) { - select_test Test; - SUBCASE("select_dynamic") { Test.test_select_dynamic(); } - SUBCASE("select_static") { Test.test_select_static(); } + // Allocate on heap to avoid stack overflow from excessively large object. + std::unique_ptr> Test { new select_test }; + SUBCASE("select_dynamic") { Test->test_select_dynamic(); } + SUBCASE("select_static") { Test->test_select_static(); } } #endif