-
Notifications
You must be signed in to change notification settings - Fork 0
feat: add FlipSign SIMD implementation for SSE/AVX/AVX2 #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
The head ref may contain hidden characters: "opencode/\u8865\u9F50-flipsign-simd"
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1426,4 +1426,37 @@ KacsWalk(float* data, uint64_t len) { | |
| return avx::KacsWalk(data, len); | ||
| #endif | ||
| } | ||
|
|
||
| void | ||
| FlipSign(const uint8_t* flip, float* data, uint64_t dim) { | ||
| #if defined(ENABLE_AVX2) | ||
| constexpr uint64_t kFloatsPerChunk = 8; | ||
| uint64_t i = 0; | ||
| for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) { | ||
| uint32_t bit_idx = static_cast<uint32_t>(i); | ||
| uint16_t flip_bits = 0; | ||
|
Comment on lines
+1436
to
+1437
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Same potential overflow of Keep |
||
| for (int j = 0; j < 8; j++) { | ||
| flip_bits |= ((flip[(bit_idx + j) / 8] >> ((bit_idx + j) % 8)) & 1) << j; | ||
| } | ||
|
|
||
| alignas(32) uint32_t mask[8]; | ||
| for (int j = 0; j < 8; j++) { | ||
| mask[j] = (flip_bits & (1 << j)) ? 0x80000000 : 0; | ||
| } | ||
|
|
||
| __m256i sign_mask = _mm256_load_si256(reinterpret_cast<const __m256i*>(mask)); | ||
| __m256 vec = _mm256_loadu_ps(&data[i]); | ||
| vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(vec), sign_mask)); | ||
| _mm256_storeu_ps(&data[i], vec); | ||
| } | ||
| for (; i < dim; i++) { | ||
| bool mask = (flip[i / 8] & (1 << (i % 8))) != 0; | ||
| if (mask) { | ||
| data[i] = -data[i]; | ||
| } | ||
| } | ||
| #else | ||
| return avx::FlipSign(flip, data, dim); | ||
| #endif | ||
| } | ||
| } // namespace vsag::avx2 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -463,6 +463,27 @@ TEST_CASE("SIMD test for flip_sign", "[ut][simd]") { | |
|
|
||
| generic::FlipSign(flip, gt_data, dim); | ||
|
|
||
| if (SimdStatus::SupportSSE()) { | ||
| auto* sse_data = sse_datas.data() + i * dim; | ||
| sse::FlipSign(flip, sse_data, dim); | ||
| for (int j = 0; j < dim; j++) { | ||
| REQUIRE(std::abs(gt_data[j] - sse_data[j]) < delta); | ||
| } | ||
| } | ||
| if (SimdStatus::SupportAVX()) { | ||
| auto* avx_data = avx_datas.data() + i * dim; | ||
| avx::FlipSign(flip, avx_data, dim); | ||
|
Comment on lines
+466
to
+475
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
suggestion (testing): Add explicit tests that exercise byte-boundary and tail cases for SSE/AVX/AVX2 FlipSign.
The current checks validate SSE/AVX/AVX2 against the generic path, but they don’t guarantee we hit the edge cases most prone to SIMD bugs: (a) flips whose bit indices cross byte boundaries (e.g., dims 7, 8, 9, 15, 16, …) and (b) dims that trigger the scalar cleanup (non‑multiples of the SIMD width). Please add or extend a test that (1) uses hand-crafted flip masks around byte boundaries (e.g., flipping only indices 7–9, 15–17, etc.) and (2) explicitly exercises non-multiple dims for each ISA (e.g., 1, 3, 5, 7, 9 for SSE; 5, 9, 13 for AVX/AVX2) so we cover those tail and bit-indexing regions deterministically.
Suggested implementation:
TEST_CASE("FlipSign SIMD byte-boundary and tail cases", "[simd][flip_sign][boundary]") {
constexpr float delta = 1e-5f;
// Indices chosen to cross byte boundaries: bits 7–9, 15–17, ...
const std::array<int, 6> boundary_indices = {7, 8, 9, 15, 16, 17};
auto make_flip_mask = [&](int dim) {
const std::size_t flip_size = (dim + 7) / 8;
std::vector<uint8_t> flip(flip_size, uint8_t{0});
for (int idx : boundary_indices) {
if (idx >= dim) {
continue;
}
const int byte = idx / 8;
const int bit = idx % 8;
flip[byte] |= static_cast<uint8_t>(1u << bit);
}
return flip;
};
auto run_dim = [&](int dim) {
auto flip = make_flip_mask(dim);
std::vector<float> base(dim);
for (int i = 0; i < dim; ++i) {
base[i] = static_cast<float>(i + 1);
}
std::vector<float> gt_data = base;
generic::FlipSign(flip.data(), gt_data.data(), dim);
if (SimdStatus::SupportSSE()) {
std::vector<float> sse_data = base;
sse::FlipSign(flip.data(), sse_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - sse_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX()) {
std::vector<float> avx_data = base;
avx::FlipSign(flip.data(), avx_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - avx_data[j]) < delta);
}
}
if (SimdStatus::SupportAVX2()) {
std::vector<float> avx2_data = base;
avx2::FlipSign(flip.data(), avx2_data.data(), dim);
for (int j = 0; j < dim; ++j) {
REQUIRE(std::abs(gt_data[j] - avx2_data[j]) < delta);
}
}
};
// Dims that are not multiples of typical SIMD widths, to force scalar tails.
const std::array<int, 5> sse_dims = {1, 3, 5, 7, 9};
const std::array<int, 3> avx_dims = {5, 9, 13};
for (int dim : sse_dims) {
run_dim(dim);
}
for (int dim : avx_dims) {
run_dim(dim);
}
}
std::vector<uint8_t> flips = fixtures::GenerateVectors<uint8_t>(count, flip_size);
BENCHMARK_SIMD_FLIP_SIGN(generic, FlipSign);
|
||
| for (int j = 0; j < dim; j++) { | ||
| REQUIRE(std::abs(gt_data[j] - avx_data[j]) < delta); | ||
| } | ||
| } | ||
| if (SimdStatus::SupportAVX2()) { | ||
| auto* avx2_data = avx2_datas.data() + i * dim; | ||
| avx2::FlipSign(flip, avx2_data, dim); | ||
| for (int j = 0; j < dim; j++) { | ||
| REQUIRE(std::abs(gt_data[j] - avx2_data[j]) < delta); | ||
| } | ||
| } | ||
| if (SimdStatus::SupportAVX512()) { | ||
| auto* avx512_data = avx512_datas.data() + i * dim; | ||
| avx512::FlipSign(flip, avx512_data, dim); | ||
|
|
@@ -505,6 +526,15 @@ TEST_CASE("SIMD FlipSign Benchmark", "[ut][simd][!benchmark]") { | |
| std::vector<uint8_t> flips = fixtures::GenerateVectors<uint8_t>(count, flip_size); | ||
|
|
||
| BENCHMARK_SIMD_FLIP_SIGN(generic, FlipSign); | ||
| if (SimdStatus::SupportSSE()) { | ||
| BENCHMARK_SIMD_FLIP_SIGN(sse, FlipSign); | ||
| } | ||
| if (SimdStatus::SupportAVX()) { | ||
| BENCHMARK_SIMD_FLIP_SIGN(avx, FlipSign); | ||
| } | ||
| if (SimdStatus::SupportAVX2()) { | ||
| BENCHMARK_SIMD_FLIP_SIGN(avx2, FlipSign); | ||
| } | ||
| if (SimdStatus::SupportAVX512()) { | ||
| BENCHMARK_SIMD_FLIP_SIGN(avx512, FlipSign); | ||
| } | ||
|
|
@@ -549,6 +579,24 @@ TEST_CASE("SIMD FlipSign Correctness with Patterns", "[ut][simd]") { | |
|
|
||
| generic::FlipSign(test.flip_pattern.data(), gt_data.data(), dim); | ||
|
|
||
| if (SimdStatus::SupportSSE()) { | ||
| sse::FlipSign(test.flip_pattern.data(), sse_data.data(), dim); | ||
| for (int i = 0; i < dim; i++) { | ||
| REQUIRE(std::abs(gt_data[i] - sse_data[i]) < delta); | ||
| } | ||
| } | ||
| if (SimdStatus::SupportAVX()) { | ||
| avx::FlipSign(test.flip_pattern.data(), avx_data.data(), dim); | ||
| for (int i = 0; i < dim; i++) { | ||
| REQUIRE(std::abs(gt_data[i] - avx_data[i]) < delta); | ||
| } | ||
| } | ||
| if (SimdStatus::SupportAVX2()) { | ||
| avx2::FlipSign(test.flip_pattern.data(), avx2_data.data(), dim); | ||
| for (int i = 0; i < dim; i++) { | ||
| REQUIRE(std::abs(gt_data[i] - avx2_data[i]) < delta); | ||
| } | ||
| } | ||
| if (SimdStatus::SupportAVX512()) { | ||
| avx512::FlipSign(test.flip_pattern.data(), avx512_data.data(), dim); | ||
| for (int i = 0; i < dim; i++) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1353,4 +1353,39 @@ KacsWalk(float* data, uint64_t len) { | |
| return generic::KacsWalk(data, len); | ||
| #endif | ||
| } | ||
|
|
||
| void | ||
| FlipSign(const uint8_t* flip, float* data, uint64_t dim) { | ||
| #if defined(ENABLE_SSE) | ||
| constexpr uint64_t kFloatsPerChunk = 4; | ||
| uint64_t i = 0; | ||
| for (; i + kFloatsPerChunk <= dim; i += kFloatsPerChunk) { | ||
| uint32_t bit_idx = static_cast<uint32_t>(i); | ||
| uint8_t flip_bits = 0; | ||
|
Comment on lines
+1363
to
+1364
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Avoid narrowing Casting |
||
| flip_bits |= ((flip[bit_idx / 8] >> (bit_idx % 8)) & 1) << 0; | ||
| flip_bits |= ((flip[(bit_idx + 1) / 8] >> ((bit_idx + 1) % 8)) & 1) << 1; | ||
| flip_bits |= ((flip[(bit_idx + 2) / 8] >> ((bit_idx + 2) % 8)) & 1) << 2; | ||
| flip_bits |= ((flip[(bit_idx + 3) / 8] >> ((bit_idx + 3) % 8)) & 1) << 3; | ||
|
|
||
| alignas(16) uint32_t mask[4]; | ||
| mask[0] = (flip_bits & 1) ? 0x80000000 : 0; | ||
| mask[1] = (flip_bits & 2) ? 0x80000000 : 0; | ||
| mask[2] = (flip_bits & 4) ? 0x80000000 : 0; | ||
| mask[3] = (flip_bits & 8) ? 0x80000000 : 0; | ||
|
|
||
| __m128i sign_mask_int = _mm_load_si128(reinterpret_cast<const __m128i*>(mask)); | ||
| __m128 vec = _mm_loadu_ps(&data[i]); | ||
| vec = _mm_xor_ps(vec, _mm_castsi128_ps(sign_mask_int)); | ||
| _mm_storeu_ps(&data[i], vec); | ||
| } | ||
| for (; i < dim; i++) { | ||
| bool mask = (flip[i / 8] & (1 << (i % 8))) != 0; | ||
| if (mask) { | ||
| data[i] = -data[i]; | ||
| } | ||
| } | ||
| #else | ||
| return generic::FlipSign(flip, data, dim); | ||
| #endif | ||
| } | ||
| } // namespace vsag::sse | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Using a 32-bit `bit_idx` for indexing can overflow for large dimensions.
Casting `i` to `uint32_t` here means bit selection breaks once `dim` exceeds `2^32 - 1`, as in the SSE path. Keeping `bit_idx` (and the related indexing math) as `uint64_t`/`size_t` avoids this overflow while preserving performance on typical platforms.