diff --git a/include/xsimd/arch/xsimd_sse2.hpp b/include/xsimd/arch/xsimd_sse2.hpp index 63893cdbb..009e39c23 100644 --- a/include/xsimd/arch/xsimd_sse2.hpp +++ b/include/xsimd/arch/xsimd_sse2.hpp @@ -1068,10 +1068,45 @@ namespace xsimd } // load_masked + template ::value>::type> + XSIMD_INLINE batch load_masked(T const* mem, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.mask() == 0x1) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return mm_loadu_si16(mem); + } + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return mm_loadu_si32(mem); + } + XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return mm_loadu_si64(mem); + } + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2 && mask.mask() == 0x3) + { + return mm_loadu_si32(mem); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4 && mask.mask() == 0x3) + { + return mm_loadu_si64(mem); + } + else + { + return load_masked(mem, mask, convert {}, Mode {}, common {}); + } + } template XSIMD_INLINE batch load_masked(float const* mem, batch_bool_constant mask, Mode, requires_arch) noexcept { - XSIMD_IF_CONSTEXPR(mask.countr_one() == 2) + XSIMD_IF_CONSTEXPR(mask.mask() == 0x1) + { + return _mm_load_ss(mem); + } + else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2) { return _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<__m64 const*>(mem)); } @@ -1089,7 +1124,7 @@ namespace xsimd { XSIMD_IF_CONSTEXPR(mask.countr_one() == 1) { - return _mm_move_sd(_mm_setzero_pd(), _mm_load_sd(mem)); + return _mm_load_sd(mem); } else XSIMD_IF_CONSTEXPR(mask.countl_one() == 1) { @@ -1105,7 +1140,11 @@ namespace xsimd template XSIMD_INLINE void store_masked(float* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept { - XSIMD_IF_CONSTEXPR(mask.countr_one() == 2) + XSIMD_IF_CONSTEXPR(mask.mask() == 0x1) + { + _mm_store_ss(mem, src); + } + else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2) { _mm_storel_pi(reinterpret_cast<__m64*>(mem), src); }