// FlipSign kernels for the SSE / AVX / AVX2 back ends (src/simd/sse.cpp,
// src/simd/avx.cpp and src/simd/avx2.cpp respectively).
//
// Contract shared by every variant:
//   flip - packed bit flags; element i is controlled by bit (i % 8) of
//          flip[i / 8], so bytes flip[0 .. (dim - 1) / 8] are read.
//   data - `dim` floats, modified in place; data[i] has its sign bit toggled
//          when its flag is set and is left untouched otherwise.
//   dim  - number of floats in `data`.
//
// When a back end is not compiled in (its ENABLE_* macro is undefined) the call
// is forwarded to the next narrower implementation:
//   avx2 -> avx -> sse -> generic.
// The unit tests in src/simd/rabitq_simd_test.cpp compare each variant against
// generic::FlipSign.

#include <cstdint>

#if defined(ENABLE_SSE) || defined(ENABLE_AVX) || defined(ENABLE_AVX2)
#include <immintrin.h>
#endif

namespace vsag::generic {

// Reference implementation the SIMD variants fall back to; defined outside
// this patch.
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim);

}  // namespace vsag::generic

// ---------------------------------------------------------------------------
// src/simd/sse.cpp
// ---------------------------------------------------------------------------
namespace vsag::sse {

// Toggles the sign of data[i] for every set flag bit, four floats per step.
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim) {
#if defined(ENABLE_SSE)
    constexpr uint64_t kFloatsPerChunk = 4;
    // Lane k tests bit k of the (already shifted) flag byte.
    const __m128i lane_bit = _mm_setr_epi32(0x1, 0x2, 0x4, 0x8);

    uint64_t i = 0;
    for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
        // i is a multiple of 4, so the four flags are the low (i % 8 == 0) or
        // the high (i % 8 == 4) nibble of one byte. Indexing with the 64-bit
        // counter avoids truncating the index for very large dim.
        const int shifted = flip[i / 8] >> (i % 8);
        const __m128i bits = _mm_set1_epi32(shifted);
        // All-ones in lanes whose flag is set, zero elsewhere.
        const __m128i selected = _mm_cmpeq_epi32(_mm_and_si128(bits, lane_bit), lane_bit);
        // Shifting all-ones left by 31 keeps only the IEEE-754 sign bit.
        const __m128 sign_mask = _mm_castsi128_ps(_mm_slli_epi32(selected, 31));
        _mm_storeu_ps(data + i, _mm_xor_ps(_mm_loadu_ps(data + i), sign_mask));
    }
    // Scalar tail for the last dim % 4 elements.
    for (; i < dim; ++i) {
        if (((flip[i / 8] >> (i % 8)) & 1U) != 0) {
            data[i] = -data[i];
        }
    }
#else
    return generic::FlipSign(flip, data, dim);
#endif
}

}  // namespace vsag::sse

// ---------------------------------------------------------------------------
// src/simd/avx.cpp
// ---------------------------------------------------------------------------
namespace vsag::avx {

// Toggles the sign of data[i] for every set flag bit, eight floats per step.
// AVX has no 256-bit integer compare, so the mask is built as two 128-bit
// halves and then combined.
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim) {
#if defined(ENABLE_AVX)
    constexpr uint64_t kFloatsPerChunk = 8;
    // Lanes 0-3 test bits 0-3, lanes 4-7 test bits 4-7 of the flag byte.
    const __m128i low_lane_bit = _mm_setr_epi32(0x01, 0x02, 0x04, 0x08);
    const __m128i high_lane_bit = _mm_setr_epi32(0x10, 0x20, 0x40, 0x80);

    uint64_t i = 0;
    for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
        // i is a multiple of 8, so the eight flags are exactly one byte.
        const __m128i bits = _mm_set1_epi32(flip[i / 8]);
        const __m128i low_selected =
            _mm_cmpeq_epi32(_mm_and_si128(bits, low_lane_bit), low_lane_bit);
        const __m128i high_selected =
            _mm_cmpeq_epi32(_mm_and_si128(bits, high_lane_bit), high_lane_bit);
        // All-ones << 31 leaves only the sign bit in the selected lanes.
        const __m256i sign_mask =
            _mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(low_selected, 31)),
                                    _mm_slli_epi32(high_selected, 31),
                                    1);
        const __m256 flipped =
            _mm256_xor_ps(_mm256_loadu_ps(data + i), _mm256_castsi256_ps(sign_mask));
        _mm256_storeu_ps(data + i, flipped);
    }
    // Scalar tail for the last dim % 8 elements.
    for (; i < dim; ++i) {
        if (((flip[i / 8] >> (i % 8)) & 1U) != 0) {
            data[i] = -data[i];
        }
    }
#else
    return sse::FlipSign(flip, data, dim);
#endif
}

}  // namespace vsag::avx

// ---------------------------------------------------------------------------
// src/simd/avx2.cpp
// ---------------------------------------------------------------------------
namespace vsag::avx2 {

// Toggles the sign of data[i] for every set flag bit, eight floats per step,
// using 256-bit integer operations.
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim) {
#if defined(ENABLE_AVX2)
    constexpr uint64_t kFloatsPerChunk = 8;
    // Lane k tests bit k of the flag byte.
    const __m256i lane_bit = _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80);

    uint64_t i = 0;
    for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
        // i is a multiple of 8, so the eight flags are exactly one byte.
        const __m256i bits = _mm256_set1_epi32(flip[i / 8]);
        const __m256i selected = _mm256_cmpeq_epi32(_mm256_and_si256(bits, lane_bit), lane_bit);
        // All-ones << 31 leaves only the sign bit in the selected lanes.
        const __m256i sign_mask = _mm256_slli_epi32(selected, 31);
        const __m256i values = _mm256_castps_si256(_mm256_loadu_ps(data + i));
        _mm256_storeu_ps(data + i, _mm256_castsi256_ps(_mm256_xor_si256(values, sign_mask)));
    }
    // Scalar tail for the last dim % 8 elements.
    for (; i < dim; ++i) {
        if (((flip[i / 8] >> (i % 8)) & 1U) != 0) {
            data[i] = -data[i];
        }
    }
#else
    return avx::FlipSign(flip, data, dim);
#endif
}

}  // namespace vsag::avx2