feat: add FlipSign SIMD implementation for SSE/AVX/AVX2#10
Hidden character warning
Conversation
- Add FlipSign function to sse namespace using _mm_xor_ps - Add FlipSign function to avx namespace using _mm256_xor_ps - Add FlipSign function to avx2 namespace using _mm256_xor_si256 - Update rabitq_simd.h with declarations for sse and avx2 namespaces - Add unit tests for SSE/AVX/AVX2 FlipSign implementations
Reviewer's GuideAdds SIMD-optimized FlipSign implementations for SSE, AVX, and AVX2, wires them into the shared SIMD header, and extends unit tests and benchmarks to validate and compare the new paths against the generic implementation. Class diagram for SIMD FlipSign functions in SSE/AVX/AVX2 namespacesclassDiagram
namespace_generic <|.. namespace_sse : fallback
namespace_sse <|.. namespace_avx : fallback
namespace_avx <|.. namespace_avx2 : fallback
class namespace_generic {
<<namespace>>
+FlipSign(flip uint8_t*, data float*, dim uint64_t)
}
class namespace_sse {
<<namespace>>
+FlipSign(flip uint8_t*, data float*, dim uint64_t)
+KacsWalk(data float*, len uint64_t)
+VecRescale(data float*, dim uint64_t, val float)
+RotateOp(data float*, idx int, dim_ int, step int)
+FHTRotate(data float*, len uint64_t)
}
class namespace_avx {
<<namespace>>
+FlipSign(flip uint8_t*, data float*, dim uint64_t)
+KacsWalk(data float*, len uint64_t)
+VecRescale(data float*, dim uint64_t, val float)
+RotateOp(data float*, idx int, dim_ int, step int)
}
class namespace_avx2 {
<<namespace>>
+FlipSign(flip uint8_t*, data float*, dim uint64_t)
+KacsWalk(data float*, len uint64_t)
+VecRescale(data float*, dim uint64_t, val float)
}
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've found 4 issues, and left some high level feedback:
- The AVX and AVX2
FlipSignimplementations are nearly identical (bit extraction + mask construction); consider factoring this into a shared helper to avoid duplication and keep future changes in sync. - In
avx::FlipSignthe#elsebranch callssse::FlipSign, which assumes SSE is available whenever AVX is not; if that's not guaranteed by build flags, it may be safer to fall back directly togeneric::FlipSignfor consistency with other paths. - The bit extraction in
FlipSign(SSE/AVX/AVX2) repeatedly indexesflip[(bit_idx + j) / 8]within the inner loop; you could precompute the base byte and intra-byte offsets to reduce repeated division/modulo and simplify the logic.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The AVX and AVX2 `FlipSign` implementations are nearly identical (bit extraction + mask construction); consider factoring this into a shared helper to avoid duplication and keep future changes in sync.
- In `avx::FlipSign` the `#else` branch calls `sse::FlipSign`, which assumes SSE is available whenever AVX is not; if that's not guaranteed by build flags, it may be safer to fall back directly to `generic::FlipSign` for consistency with other paths.
- The bit extraction in `FlipSign` (SSE/AVX/AVX2) repeatedly indexes `flip[(bit_idx + j) / 8]` within the inner loop; you could precompute the base byte and intra-byte offsets to reduce repeated division/modulo and simplify the logic.
## Individual Comments
### Comment 1
<location path="src/simd/sse.cpp" line_range="1363-1364" />
<code_context>
+ constexpr uint64_t kFloatsPerChunk = 8;
+ uint64_t i = 0;
+ for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
+ uint32_t bit_idx = static_cast<uint32_t>(i);
+ uint16_t flip_bits = 0;
+ for (int j = 0; j < 8; j++) {
</code_context>
<issue_to_address>
**issue (bug_risk):** Avoid narrowing `i` to 32 bits for bit index computations to prevent incorrect behavior on very large `dim`.
Casting `i` (uint64_t) to `uint32_t` will wrap when `dim >= 2^32`, causing incorrect bit selection in `flip`. Unless `dim < 2^32` is guaranteed, keep `bit_idx` as `uint64_t`/`size_t` and remove the cast, and apply the same treatment to all `bit_idx` uses in the SIMD loops.
</issue_to_address>
### Comment 2
<location path="src/simd/avx.cpp" line_range="1286-1287" />
<code_context>
+ constexpr uint64_t kFloatsPerChunk = 8;
+ uint64_t i = 0;
+ for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
+ uint32_t bit_idx = static_cast<uint32_t>(i);
+ uint16_t flip_bits = 0;
+ for (int j = 0; j < 8; j++) {
+ flip_bits |= ((flip[(bit_idx + j) / 8] >> ((bit_idx + j) % 8)) & 1) << j;
</code_context>
<issue_to_address>
**issue (bug_risk):** Using a 32-bit `bit_idx` for indexing can overflow for large dimensions.
Casting `i` to `uint32_t` here means bit selection breaks once `dim` exceeds `2^32 - 1`, as in the SSE path. Keeping `bit_idx` (and the related indexing math) as `uint64_t`/`size_t` avoids this overflow while preserving performance on typical platforms.
</issue_to_address>
### Comment 3
<location path="src/simd/avx2.cpp" line_range="1436-1437" />
<code_context>
+ constexpr uint64_t kFloatsPerChunk = 8;
+ uint64_t i = 0;
+ for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
+ uint32_t bit_idx = static_cast<uint32_t>(i);
+ uint16_t flip_bits = 0;
+ for (int j = 0; j < 8; j++) {
+ flip_bits |= ((flip[(bit_idx + j) / 8] >> ((bit_idx + j) % 8)) & 1) << j;
</code_context>
<issue_to_address>
**issue (bug_risk):** Same potential overflow of `bit_idx` in AVX2 path as in SSE/AVX implementations.
Keep `bit_idx` as a 64-bit value instead of truncating to `uint32_t` so very large `dim` values don’t silently wrap and behavior stays consistent with the other implementations.
</issue_to_address>
### Comment 4
<location path="src/simd/rabitq_simd_test.cpp" line_range="466-475" />
<code_context>
generic::FlipSign(flip, gt_data, dim);
+ if (SimdStatus::SupportSSE()) {
+ auto* sse_data = sse_datas.data() + i * dim;
+ sse::FlipSign(flip, sse_data, dim);
+ for (int j = 0; j < dim; j++) {
+ REQUIRE(std::abs(gt_data[j] - sse_data[j]) < delta);
+ }
+ }
+ if (SimdStatus::SupportAVX()) {
+ auto* avx_data = avx_datas.data() + i * dim;
+ avx::FlipSign(flip, avx_data, dim);
+ for (int j = 0; j < dim; j++) {
+ REQUIRE(std::abs(gt_data[j] - avx_data[j]) < delta);
+ }
+ }
+ if (SimdStatus::SupportAVX2()) {
+ auto* avx2_data = avx2_datas.data() + i * dim;
+ avx2::FlipSign(flip, avx2_data, dim);
</code_context>
<issue_to_address>
**suggestion (testing):** Add explicit tests that exercise byte-boundary and tail cases for SSE/AVX/AVX2 FlipSign
The current checks validate SSE/AVX/AVX2 against the generic path, but they don’t guarantee we hit the edge cases most prone to SIMD bugs: (a) flips whose bit indices cross byte boundaries (e.g., dims 7, 8, 9, 15, 16, …) and (b) dims that trigger the scalar cleanup (non‑multiples of the SIMD width). Please add or extend a test that (1) uses hand-crafted flip masks around byte boundaries (e.g., flipping only indices 7–9, 15–17, etc.) and (2) explicitly exercises non-multiple dims for each ISA (e.g., 1, 3, 5, 7, 9 for SSE; 5, 9, 13 for AVX/AVX2) so we cover those tail and bit-indexing regions deterministically.
Suggested implementation:
```cpp
TEST_CASE("FlipSign SIMD byte-boundary and tail cases", "[simd][flip_sign][boundary]") {
constexpr float delta = 1e-5f;
// Indices chosen to cross byte boundaries: bits 7–9, 15–17, ...
const std::array<int, 6> boundary_indices = {7, 8, 9, 15, 16, 17};
auto make_flip_mask = [&](int dim) {
const std::size_t flip_size = (dim + 7) / 8;
std::vector<uint8_t> flip(flip_size, uint8_t{0});
for (int idx : boundary_indices) {
if (idx >= dim) {
continue;
}
const int byte = idx / 8;
const int bit = idx % 8;
flip[byte] |= static_cast<uint8_t>(1u << bit);
}
return flip;
};
auto run_dim = [&](int dim) {
auto flip = make_flip_mask(dim);
std::vector<float> base(dim);
for (int i = 0; i < dim; ++i) {
base[i] = static_cast<float>(i + 1);
}
std::vector<float> gt_data = base;
generic::FlipSign(flip.data(), gt_data.data(), dim);
if (SimdStatus::SupportSSE()) {
std::vector<float> sse_data = base;
sse::FlipSign(flip.data(), sse_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - sse_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX()) {
std::vector<float> avx_data = base;
avx::FlipSign(flip.data(), avx_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - avx_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX2()) {
std::vector<float> avx2_data = base;
avx2::FlipSign(flip.data(), avx2_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - avx2_data[j]) < delta);
}
}
};
// Dims that are not multiples of typical SIMD widths, to force scalar tails.
const std::array<int, 5> sse_dims = {1, 3, 5, 7, 9};
const std::array<int, 3> avx_dims = {5, 9, 13};
for (int dim : sse_dims) {
run_dim(dim);
}
for (int dim : avx_dims) {
run_dim(dim);
}
}
std::vector<uint8_t> flips = fixtures::GenerateVectors<uint8_t>(count, flip_size);
BENCHMARK_SIMD_FLIP_SIGN(generic, FlipSign);
```
1. Ensure `#include <array>` is present near the top of `src/simd/rabitq_simd_test.cpp`, since the new test uses `std::array`.
2. If the file wraps tests in a namespace (e.g., `namespace milvus::simd { ... }`), verify that the new `TEST_CASE` is placed inside the same namespace scope (the replacement above assumes it is already within the correct scope).
3. If the existing tests use a specific float type alias (e.g., `using value_t = float;`), you may want to replace `float` in the new test with that alias for consistency.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| uint32_t bit_idx = static_cast<uint32_t>(i); | ||
| uint8_t flip_bits = 0; |
There was a problem hiding this comment.
issue (bug_risk): Avoid narrowing i to 32 bits for bit index computations to prevent incorrect behavior on very large dim.
Casting i (uint64_t) to uint32_t will wrap when dim >= 2^32, causing incorrect bit selection in flip. Unless dim < 2^32 is guaranteed, keep bit_idx as uint64_t/size_t and remove the cast, and apply the same treatment to all bit_idx uses in the SIMD loops.
| uint32_t bit_idx = static_cast<uint32_t>(i); | ||
| uint16_t flip_bits = 0; |
There was a problem hiding this comment.
issue (bug_risk): Using a 32-bit bit_idx for indexing can overflow for large dimensions.
Casting i to uint32_t here means bit selection breaks once dim exceeds 2^32 - 1, as in the SSE path. Keeping bit_idx (and the related indexing math) as uint64_t/size_t avoids this overflow while preserving performance on typical platforms.
| uint32_t bit_idx = static_cast<uint32_t>(i); | ||
| uint16_t flip_bits = 0; |
There was a problem hiding this comment.
issue (bug_risk): Same potential overflow of bit_idx in AVX2 path as in SSE/AVX implementations.
Keep bit_idx as a 64-bit value instead of truncating to uint32_t so very large dim values don’t silently wrap and behavior stays consistent with the other implementations.
| if (SimdStatus::SupportSSE()) { | ||
| auto* sse_data = sse_datas.data() + i * dim; | ||
| sse::FlipSign(flip, sse_data, dim); | ||
| for (int j = 0; j < dim; j++) { | ||
| REQUIRE(std::abs(gt_data[j] - sse_data[j]) < delta); | ||
| } | ||
| } | ||
| if (SimdStatus::SupportAVX()) { | ||
| auto* avx_data = avx_datas.data() + i * dim; | ||
| avx::FlipSign(flip, avx_data, dim); |
There was a problem hiding this comment.
suggestion (testing): Add explicit tests that exercise byte-boundary and tail cases for SSE/AVX/AVX2 FlipSign
The current checks validate SSE/AVX/AVX2 against the generic path, but they don’t guarantee we hit the edge cases most prone to SIMD bugs: (a) flips whose bit indices cross byte boundaries (e.g., dims 7, 8, 9, 15, 16, …) and (b) dims that trigger the scalar cleanup (non‑multiples of the SIMD width). Please add or extend a test that (1) uses hand-crafted flip masks around byte boundaries (e.g., flipping only indices 7–9, 15–17, etc.) and (2) explicitly exercises non-multiple dims for each ISA (e.g., 1, 3, 5, 7, 9 for SSE; 5, 9, 13 for AVX/AVX2) so we cover those tail and bit-indexing regions deterministically.
Suggested implementation:
TEST_CASE("FlipSign SIMD byte-boundary and tail cases", "[simd][flip_sign][boundary]") {
constexpr float delta = 1e-5f;
// Indices chosen to cross byte boundaries: bits 7–9, 15–17, ...
const std::array<int, 6> boundary_indices = {7, 8, 9, 15, 16, 17};
auto make_flip_mask = [&](int dim) {
const std::size_t flip_size = (dim + 7) / 8;
std::vector<uint8_t> flip(flip_size, uint8_t{0});
for (int idx : boundary_indices) {
if (idx >= dim) {
continue;
}
const int byte = idx / 8;
const int bit = idx % 8;
flip[byte] |= static_cast<uint8_t>(1u << bit);
}
return flip;
};
auto run_dim = [&](int dim) {
auto flip = make_flip_mask(dim);
std::vector<float> base(dim);
for (int i = 0; i < dim; ++i) {
base[i] = static_cast<float>(i + 1);
}
std::vector<float> gt_data = base;
generic::FlipSign(flip.data(), gt_data.data(), dim);
if (SimdStatus::SupportSSE()) {
std::vector<float> sse_data = base;
sse::FlipSign(flip.data(), sse_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - sse_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX()) {
std::vector<float> avx_data = base;
avx::FlipSign(flip.data(), avx_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - avx_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX2()) {
std::vector<float> avx2_data = base;
avx2::FlipSign(flip.data(), avx2_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - avx2_data[j]) < delta);
}
}
};
// Dims that are not multiples of typical SIMD widths, to force scalar tails.
const std::array<int, 5> sse_dims = {1, 3, 5, 7, 9};
const std::array<int, 3> avx_dims = {5, 9, 13};
for (int dim : sse_dims) {
run_dim(dim);
}
for (int dim : avx_dims) {
run_dim(dim);
}
}
std::vector<uint8_t> flips = fixtures::GenerateVectors<uint8_t>(count, flip_size);
BENCHMARK_SIMD_FLIP_SIGN(generic, FlipSign);
- Ensure
#include <array>is present near the top ofsrc/simd/rabitq_simd_test.cpp, since the new test usesstd::array. - If the file wraps tests in a namespace (e.g.,
namespace milvus::simd { ... }), verify that the newTEST_CASEis placed inside the same namespace scope (the replacement above assumes it is already within the correct scope). - If the existing tests use a specific float type alias (e.g.,
using value_t = float;), you may want to replacefloatin the new test with that alias for consistency.
Summary
Test Results
Summary by Sourcery
Add SIMD-optimized FlipSign implementations for SSE, AVX, and AVX2 and integrate them into the SIMD interface, tests, and benchmarks.
New Features:
Tests: