Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/simd/avx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1276,4 +1276,37 @@ KacsWalk(float* data, uint64_t len) {
return sse::FHTRotate(data, len);
#endif
}

/**
 * @brief Negate, in place, every element of `data` whose bit is set in `flip`.
 *
 * Element k is selected by bit (k % 8) of byte flip[k / 8], least-significant bit first.
 * When ENABLE_AVX is not defined the call is forwarded to sse::FlipSign.
 *
 * @param flip Packed selection bitmap; its first (dim + 7) / 8 bytes are read.
 * @param data Array of `dim` floats, updated in place.
 * @param dim  Number of elements in `data`.
 */
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim) {
#if defined(ENABLE_AVX)
    constexpr uint64_t kFloatsPerChunk = 8;
    constexpr uint64_t kBitsPerByte = 8;
    constexpr uint32_t kSignBit = 0x80000000U;
    // The chunk loop reads one bitmap byte per chunk, which is only valid while a chunk
    // covers exactly eight elements.
    static_assert(kFloatsPerChunk == kBitsPerByte, "a chunk must map onto one bitmap byte");

    uint64_t i = 0;
    for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
        // i is a multiple of 8, so the selection bits for data[i .. i + 7] are the whole
        // byte flip[i / 8]. The index is kept 64-bit: narrowing it to 32 bits would wrap
        // and pick the wrong bitmap bytes once i reaches 2^32.
        const uint32_t flip_bits = flip[i / kBitsPerByte];

        // Expand each selection bit into a 32-bit lane that holds only the float sign bit.
        alignas(32) uint32_t mask[kFloatsPerChunk];
        for (uint64_t lane = 0; lane < kFloatsPerChunk; lane++) {
            mask[lane] = ((flip_bits >> lane) & 1U) != 0 ? kSignBit : 0U;
        }

        // XOR with the sign bit negates the selected lanes and leaves the others unchanged.
        const __m256i sign_mask_int = _mm256_load_si256(reinterpret_cast<const __m256i*>(mask));
        __m256 vec = _mm256_loadu_ps(&data[i]);
        vec = _mm256_xor_ps(vec, _mm256_castsi256_ps(sign_mask_int));
        _mm256_storeu_ps(&data[i], vec);
    }
    // Scalar tail for the remaining dim % 8 elements.
    for (; i < dim; i++) {
        const bool flip_this = ((flip[i / kBitsPerByte] >> (i % kBitsPerByte)) & 1U) != 0;
        if (flip_this) {
            data[i] = -data[i];
        }
    }
#else
    return sse::FlipSign(flip, data, dim);
#endif
}
} // namespace vsag::avx
33 changes: 33 additions & 0 deletions src/simd/avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1426,4 +1426,37 @@ KacsWalk(float* data, uint64_t len) {
return avx::KacsWalk(data, len);
#endif
}

/**
 * @brief Negate, in place, every element of `data` whose bit is set in `flip`.
 *
 * Element k is selected by bit (k % 8) of byte flip[k / 8], least-significant bit first.
 * When ENABLE_AVX2 is not defined the call is forwarded to avx::FlipSign.
 *
 * @param flip Packed selection bitmap; its first (dim + 7) / 8 bytes are read.
 * @param data Array of `dim` floats, updated in place.
 * @param dim  Number of elements in `data`.
 */
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim) {
#if defined(ENABLE_AVX2)
    constexpr uint64_t kFloatsPerChunk = 8;
    constexpr uint64_t kBitsPerByte = 8;
    constexpr uint32_t kSignBit = 0x80000000U;
    // The chunk loop reads one bitmap byte per chunk, which is only valid while a chunk
    // covers exactly eight elements.
    static_assert(kFloatsPerChunk == kBitsPerByte, "a chunk must map onto one bitmap byte");

    uint64_t i = 0;
    for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
        // i is a multiple of 8, so the selection bits for data[i .. i + 7] are the whole
        // byte flip[i / 8]. The index is kept 64-bit: narrowing it to 32 bits would wrap
        // and pick the wrong bitmap bytes once i reaches 2^32.
        const uint32_t flip_bits = flip[i / kBitsPerByte];

        // Expand each selection bit into a 32-bit lane that holds only the float sign bit.
        alignas(32) uint32_t mask[kFloatsPerChunk];
        for (uint64_t lane = 0; lane < kFloatsPerChunk; lane++) {
            mask[lane] = ((flip_bits >> lane) & 1U) != 0 ? kSignBit : 0U;
        }

        // Integer XOR on the raw float bits toggles the sign of the selected lanes only.
        const __m256i sign_mask = _mm256_load_si256(reinterpret_cast<const __m256i*>(mask));
        __m256 vec = _mm256_loadu_ps(&data[i]);
        vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(vec), sign_mask));
        _mm256_storeu_ps(&data[i], vec);
    }
    // Scalar tail for the remaining dim % 8 elements.
    for (; i < dim; i++) {
        const bool flip_this = ((flip[i / kBitsPerByte] >> (i % kBitsPerByte)) & 1U) != 0;
        if (flip_this) {
            data[i] = -data[i];
        }
    }
#else
    return avx::FlipSign(flip, data, dim);
#endif
}
} // namespace vsag::avx2
7 changes: 7 additions & 0 deletions src/simd/rabitq_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ VecRescale(float* data, uint64_t dim, float val);

void
KacsWalk(float* data, uint64_t len);

// Negates data[k] in place for every k < dim whose bit is set in the packed bitmap `flip`
// (bit k % 8 of byte flip[k / 8]); elements whose bit is clear are left unchanged.
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim);

} // namespace avx2

namespace avx {
Expand Down Expand Up @@ -99,6 +103,9 @@ RotateOp(float* data, int idx, int dim_, int step);
void
VecRescale(float* data, uint64_t dim, float val);

// Negates data[k] in place for every k < dim whose bit is set in the packed bitmap `flip`
// (bit k % 8 of byte flip[k / 8]); elements whose bit is clear are left unchanged.
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim);

void
KacsWalk(float* data, uint64_t len);
} // namespace sse
Expand Down
48 changes: 48 additions & 0 deletions src/simd/rabitq_simd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,27 @@ TEST_CASE("SIMD test for flip_sign", "[ut][simd]") {

generic::FlipSign(flip, gt_data, dim);

if (SimdStatus::SupportSSE()) {
auto* sse_data = sse_datas.data() + i * dim;
sse::FlipSign(flip, sse_data, dim);
for (int j = 0; j < dim; j++) {
REQUIRE(std::abs(gt_data[j] - sse_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX()) {
auto* avx_data = avx_datas.data() + i * dim;
avx::FlipSign(flip, avx_data, dim);
Comment on lines +466 to +475
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add explicit tests that exercise byte-boundary and tail cases for SSE/AVX/AVX2 FlipSign

The current checks validate SSE/AVX/AVX2 against the generic path, but they don’t guarantee we hit the edge cases most prone to SIMD bugs: (a) flips whose bit indices cross byte boundaries (e.g., dims 7, 8, 9, 15, 16, …) and (b) dims that trigger the scalar cleanup (non‑multiples of the SIMD width). Please add or extend a test that (1) uses hand-crafted flip masks around byte boundaries (e.g., flipping only indices 7–9, 15–17, etc.) and (2) explicitly exercises non-multiple dims for each ISA (e.g., 1, 3, 5, 7, 9 for SSE; 5, 9, 13 for AVX/AVX2) so we cover those tail and bit-indexing regions deterministically.

Suggested implementation:

// Deterministic FlipSign check: a hand-built flip bitmap with bits set on both sides of the
// first two byte boundaries (indices 7-9 and 15-17) is applied over dims that mix full SIMD
// chunks with scalar tails. Each ISA reported as supported by SimdStatus is compared against
// generic::FlipSign.
TEST_CASE("FlipSign SIMD byte-boundary and tail cases", "[simd][flip_sign][boundary]") {
    constexpr float delta = 1e-5f;

    // Bit indices on both sides of the first two byte boundaries of the flip bitmap.
    const std::array<int, 6> boundary_indices = {7, 8, 9, 15, 16, 17};

    // Builds a (dim + 7) / 8 byte bitmap in which only the in-range boundary bits are set.
    auto make_flip_mask = [&](int dim) {
        const std::size_t flip_size = (dim + 7) / 8;
        std::vector<uint8_t> flip(flip_size, uint8_t{0});
        for (int idx : boundary_indices) {
            if (idx >= dim) {
                continue;
            }
            flip[idx / 8] |= static_cast<uint8_t>(1u << (idx % 8));
        }
        return flip;
    };

    auto run_dim = [&](int dim) {
        const auto flip = make_flip_mask(dim);

        // Non-zero values, so a missing or spurious sign flip always changes the result.
        std::vector<float> base(dim);
        for (int i = 0; i < dim; ++i) {
            base[i] = static_cast<float>(i + 1);
        }

        std::vector<float> gt_data = base;
        generic::FlipSign(flip.data(), gt_data.data(), dim);

        auto require_matches_gt = [&](const std::vector<float>& actual) {
            for (int j = 0; j < dim; ++j) {
                REQUIRE(std::abs(gt_data[j] - actual[j]) < delta);
            }
        };

        if (SimdStatus::SupportSSE()) {
            std::vector<float> sse_data = base;
            sse::FlipSign(flip.data(), sse_data.data(), dim);
            require_matches_gt(sse_data);
        }

        if (SimdStatus::SupportAVX()) {
            std::vector<float> avx_data = base;
            avx::FlipSign(flip.data(), avx_data.data(), dim);
            require_matches_gt(avx_data);
        }

        if (SimdStatus::SupportAVX2()) {
            std::vector<float> avx2_data = base;
            avx2::FlipSign(flip.data(), avx2_data.data(), dim);
            require_matches_gt(avx2_data);
        }
    };

    // 1..13 exercise scalar tails after zero or more full 4-float / 8-float chunks.
    // 15..24 are needed as well: bits 15-17 only fall inside the vector once dim >= 16,
    // and all three are in range only from dim 18 on.
    const std::array<int, 12> dims = {1, 3, 5, 7, 9, 13, 15, 16, 17, 18, 19, 24};

    for (int dim : dims) {
        run_dim(dim);
    }
}

    std::vector<uint8_t> flips = fixtures::GenerateVectors<uint8_t>(count, flip_size);

    BENCHMARK_SIMD_FLIP_SIGN(generic, FlipSign);
  1. Ensure #include <array> is present near the top of src/simd/rabitq_simd_test.cpp, since the new test uses std::array.
  2. If the file wraps its tests in a namespace (e.g., namespace vsag { ... }, the namespace used by the SIMD sources in this change), verify that the new TEST_CASE is placed inside the same namespace scope (the replacement above assumes it is already within the correct scope).
  3. If the existing tests use a specific float type alias (e.g., using value_t = float;), you may want to replace float in the new test with that alias for consistency.

for (int j = 0; j < dim; j++) {
REQUIRE(std::abs(gt_data[j] - avx_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX2()) {
auto* avx2_data = avx2_datas.data() + i * dim;
avx2::FlipSign(flip, avx2_data, dim);
for (int j = 0; j < dim; j++) {
REQUIRE(std::abs(gt_data[j] - avx2_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX512()) {
auto* avx512_data = avx512_datas.data() + i * dim;
avx512::FlipSign(flip, avx512_data, dim);
Expand Down Expand Up @@ -505,6 +526,15 @@ TEST_CASE("SIMD FlipSign Benchmark", "[ut][simd][!benchmark]") {
std::vector<uint8_t> flips = fixtures::GenerateVectors<uint8_t>(count, flip_size);

BENCHMARK_SIMD_FLIP_SIGN(generic, FlipSign);
if (SimdStatus::SupportSSE()) {
BENCHMARK_SIMD_FLIP_SIGN(sse, FlipSign);
}
if (SimdStatus::SupportAVX()) {
BENCHMARK_SIMD_FLIP_SIGN(avx, FlipSign);
}
if (SimdStatus::SupportAVX2()) {
BENCHMARK_SIMD_FLIP_SIGN(avx2, FlipSign);
}
if (SimdStatus::SupportAVX512()) {
BENCHMARK_SIMD_FLIP_SIGN(avx512, FlipSign);
}
Expand Down Expand Up @@ -549,6 +579,24 @@ TEST_CASE("SIMD FlipSign Correctness with Patterns", "[ut][simd]") {

generic::FlipSign(test.flip_pattern.data(), gt_data.data(), dim);

if (SimdStatus::SupportSSE()) {
sse::FlipSign(test.flip_pattern.data(), sse_data.data(), dim);
for (int i = 0; i < dim; i++) {
REQUIRE(std::abs(gt_data[i] - sse_data[i]) < delta);
}
}
if (SimdStatus::SupportAVX()) {
avx::FlipSign(test.flip_pattern.data(), avx_data.data(), dim);
for (int i = 0; i < dim; i++) {
REQUIRE(std::abs(gt_data[i] - avx_data[i]) < delta);
}
}
if (SimdStatus::SupportAVX2()) {
avx2::FlipSign(test.flip_pattern.data(), avx2_data.data(), dim);
for (int i = 0; i < dim; i++) {
REQUIRE(std::abs(gt_data[i] - avx2_data[i]) < delta);
}
}
if (SimdStatus::SupportAVX512()) {
avx512::FlipSign(test.flip_pattern.data(), avx512_data.data(), dim);
for (int i = 0; i < dim; i++) {
Expand Down
35 changes: 35 additions & 0 deletions src/simd/sse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1353,4 +1353,39 @@ KacsWalk(float* data, uint64_t len) {
return generic::KacsWalk(data, len);
#endif
}

/**
 * @brief Negate, in place, every element of `data` whose bit is set in `flip`.
 *
 * Element k is selected by bit (k % 8) of byte flip[k / 8], least-significant bit first.
 * When ENABLE_SSE is not defined the call is forwarded to generic::FlipSign.
 *
 * @param flip Packed selection bitmap; its first (dim + 7) / 8 bytes are read.
 * @param data Array of `dim` floats, updated in place.
 * @param dim  Number of elements in `data`.
 */
void
FlipSign(const uint8_t* flip, float* data, uint64_t dim) {
#if defined(ENABLE_SSE)
    constexpr uint64_t kFloatsPerChunk = 4;
    constexpr uint64_t kBitsPerByte = 8;
    constexpr uint32_t kSignBit = 0x80000000U;
    // The chunk loop takes all of a chunk's bits from a single byte, which requires the
    // chunk size to divide the byte size.
    static_assert(kBitsPerByte % kFloatsPerChunk == 0, "a chunk must not straddle two bytes");

    uint64_t i = 0;
    for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) {
        // i is a multiple of 4, so the selection bits for data[i .. i + 3] share one byte:
        // the low nibble when i % 8 == 0, the high nibble when i % 8 == 4. The index is
        // kept 64-bit: narrowing it to 32 bits would wrap and pick the wrong bitmap bytes
        // once i reaches 2^32.
        const uint32_t flip_bits =
            (static_cast<uint32_t>(flip[i / kBitsPerByte]) >> (i % kBitsPerByte)) & 0xFU;

        // Expand each selection bit into a 32-bit lane that holds only the float sign bit.
        alignas(16) uint32_t mask[kFloatsPerChunk];
        for (uint64_t lane = 0; lane < kFloatsPerChunk; lane++) {
            mask[lane] = ((flip_bits >> lane) & 1U) != 0 ? kSignBit : 0U;
        }

        // XOR with the sign bit negates the selected lanes and leaves the others unchanged.
        const __m128i sign_mask_int = _mm_load_si128(reinterpret_cast<const __m128i*>(mask));
        __m128 vec = _mm_loadu_ps(&data[i]);
        vec = _mm_xor_ps(vec, _mm_castsi128_ps(sign_mask_int));
        _mm_storeu_ps(&data[i], vec);
    }
    // Scalar tail for the remaining dim % 4 elements.
    for (; i < dim; i++) {
        const bool flip_this = ((flip[i / kBitsPerByte] >> (i % kBitsPerByte)) & 1U) != 0;
        if (flip_this) {
            data[i] = -data[i];
        }
    }
#else
    return generic::FlipSign(flip, data, dim);
#endif
}
} // namespace vsag::sse
Loading