Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/openzl/shared/bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
#include "openzl/common/debug.h"
#include "openzl/shared/portability.h"

#if ZL_HAS_BMI2
# include <immintrin.h>
#endif

ZL_BEGIN_C_DECLS

ZL_INLINE bool ZL_32bits(void)
Expand Down Expand Up @@ -482,6 +486,92 @@ ZL_INLINE bool ZL_convertIntToDouble(ZL_IEEEDouble* dbl, int64_t value)

#endif

// ---------------------------------------------------------------------------
// bitDeposit: scatter contiguous source bits into positions given by mask
// (PDEP)
// ---------------------------------------------------------------------------

ZL_INLINE uint64_t ZL_bitDeposit64_fallback(uint64_t src, uint64_t mask)
{
uint64_t result = 0;
uint64_t srcBit = 1;
while (mask != 0) {
uint64_t const lowestBit = mask & (~mask + 1);
if (src & srcBit) {
result |= lowestBit;
}
mask &= ~lowestBit;
srcBit <<= 1;
}
return result;
}

ZL_INLINE uint64_t ZL_bitDeposit64(uint64_t src, uint64_t mask)
{
#if ZL_HAS_BMI2
return _pdep_u64(src, mask);
#else
return ZL_bitDeposit64_fallback(src, mask);
#endif
}

ZL_INLINE uint32_t ZL_bitDeposit32_fallback(uint32_t src, uint32_t mask)
{
return (uint32_t)ZL_bitDeposit64_fallback((uint64_t)src, (uint64_t)mask);
}

ZL_INLINE uint32_t ZL_bitDeposit32(uint32_t src, uint32_t mask)
{
#if ZL_HAS_BMI2
return _pdep_u32(src, mask);
#else
return ZL_bitDeposit32_fallback(src, mask);
#endif
}

// ---------------------------------------------------------------------------
// bitExtract: collect bits from positions given by mask into contiguous result
// (PEXT)
// ---------------------------------------------------------------------------

ZL_INLINE uint64_t ZL_bitExtract64_fallback(uint64_t src, uint64_t mask)
{
uint64_t result = 0;
uint64_t dstBit = 1;
while (mask != 0) {
uint64_t const lowestBit = mask & (~mask + 1);
if (src & lowestBit) {
result |= dstBit;
}
mask &= ~lowestBit;
dstBit <<= 1;
}
return result;
}

ZL_INLINE uint64_t ZL_bitExtract64(uint64_t src, uint64_t mask)
{
#if ZL_HAS_BMI2
return _pext_u64(src, mask);
#else
return ZL_bitExtract64_fallback(src, mask);
#endif
}

ZL_INLINE uint32_t ZL_bitExtract32_fallback(uint32_t src, uint32_t mask)
{
return (uint32_t)ZL_bitExtract64_fallback((uint64_t)src, (uint64_t)mask);
}

ZL_INLINE uint32_t ZL_bitExtract32(uint32_t src, uint32_t mask)
{
#if ZL_HAS_BMI2
return _pext_u32(src, mask);
#else
return ZL_bitExtract32_fallback(src, mask);
#endif
}

ZL_END_C_DECLS

#endif // ZSTRONG_COMMON_BITS_H
121 changes: 121 additions & 0 deletions tests/unittest/common/test_bits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <gtest/gtest.h>

#include <random>

#include "openzl/shared/bits.h"
#include "openzl/shared/mem.h"

Expand Down Expand Up @@ -234,4 +236,123 @@ TEST(BitsTest, convertDoubleToInt)
ASSERT_EQ(testDouble(-0.5).first, false);
ASSERT_EQ(testDouble(-0.99999).first, false);
}
TEST(BitsTest, bitDeposit32)
{
// Zero mask -> zero result
ASSERT_EQ(ZL_bitDeposit32(0xFFFFFFFF, 0), 0u);
// All-ones mask -> identity
ASSERT_EQ(ZL_bitDeposit32(0x12345678, 0xFFFFFFFF), 0x12345678u);
// Scatter low bits into mask positions
// src=0b111 deposits into mask=0b10101010: bits {1,3,5} set -> 0b00101010
ASSERT_EQ(ZL_bitDeposit32(0b111, 0b10101010), 0b00101010u);
ASSERT_EQ(ZL_bitDeposit32(0b1010, 0b11110000), 0b10100000u);
// Single bit
ASSERT_EQ(ZL_bitDeposit32(1, 0x80000000u), 0x80000000u);
ASSERT_EQ(ZL_bitDeposit32(0, 0x80000000u), 0u);
// Verify fallback matches optimized
std::mt19937 rng32(42);
std::uniform_int_distribution<uint32_t> dist32;
for (int i = 0; i < 10000; ++i) {
uint32_t src = dist32(rng32);
uint32_t mask = dist32(rng32);
ASSERT_EQ(
ZL_bitDeposit32(src, mask),
ZL_bitDeposit32_fallback(src, mask));
}
}

TEST(BitsTest, bitDeposit64)
{
ASSERT_EQ(ZL_bitDeposit64(0xFFFFFFFFFFFFFFFFull, 0), 0ull);
ASSERT_EQ(
ZL_bitDeposit64(0x123456789ABCDEFull, 0xFFFFFFFFFFFFFFFFull),
0x123456789ABCDEFull);
ASSERT_EQ(ZL_bitDeposit64(0b111, 0b10101010), 0b00101010ull);
ASSERT_EQ(ZL_bitDeposit64(1, 0x8000000000000000ull), 0x8000000000000000ull);
std::mt19937 rng(42);
std::uniform_int_distribution<uint64_t> dist64;
for (int i = 0; i < 10000; ++i) {
uint64_t src = dist64(rng);
uint64_t mask = dist64(rng);
ASSERT_EQ(
ZL_bitDeposit64(src, mask),
ZL_bitDeposit64_fallback(src, mask));
}
}

TEST(BitsTest, bitExtract32)
{
// Zero mask -> zero result
ASSERT_EQ(ZL_bitExtract32(0xFFFFFFFF, 0), 0u);
// All-ones mask -> identity
ASSERT_EQ(ZL_bitExtract32(0x12345678, 0xFFFFFFFF), 0x12345678u);
// Gather bits from mask positions into contiguous low bits
// mask=0b10101010 selects bits {1,3,5,7}; src=0b10101000 has {3,5,7} set
ASSERT_EQ(ZL_bitExtract32(0b10101000, 0b10101010), 0b1110u);
ASSERT_EQ(ZL_bitExtract32(0b10100000, 0b11110000), 0b1010u);
// Single bit
ASSERT_EQ(ZL_bitExtract32(0x80000000u, 0x80000000u), 1u);
ASSERT_EQ(ZL_bitExtract32(0u, 0x80000000u), 0u);
// Verify fallback matches optimized
std::mt19937 rng32(123);
std::uniform_int_distribution<uint32_t> dist32;
for (int i = 0; i < 10000; ++i) {
uint32_t src = dist32(rng32);
uint32_t mask = dist32(rng32);
ASSERT_EQ(
ZL_bitExtract32(src, mask),
ZL_bitExtract32_fallback(src, mask));
}
}

TEST(BitsTest, bitExtract64)
{
ASSERT_EQ(ZL_bitExtract64(0xFFFFFFFFFFFFFFFFull, 0), 0ull);
ASSERT_EQ(
ZL_bitExtract64(0x123456789ABCDEFull, 0xFFFFFFFFFFFFFFFFull),
0x123456789ABCDEFull);
ASSERT_EQ(ZL_bitExtract64(0b10101000, 0b10101010), 0b1110ull);
ASSERT_EQ(
ZL_bitExtract64(0x8000000000000000ull, 0x8000000000000000ull),
1ull);
std::mt19937 rng(123);
std::uniform_int_distribution<uint64_t> dist64;
for (int i = 0; i < 10000; ++i) {
uint64_t src = dist64(rng);
uint64_t mask = dist64(rng);
ASSERT_EQ(
ZL_bitExtract64(src, mask),
ZL_bitExtract64_fallback(src, mask));
}
}

TEST(BitsTest, bitDepositExtractRoundTrip32)
{
// PEXT(PDEP(src, mask), mask) == src & ((1 << popcount(mask)) - 1)
std::mt19937 rng(456);
std::uniform_int_distribution<uint32_t> dist32;
for (int i = 0; i < 10000; ++i) {
uint32_t src = dist32(rng);
uint32_t mask = dist32(rng);
int nbits = ZL_popcount64(mask);
uint32_t srcMasked = src & ((nbits >= 32) ? ~0u : ((1u << nbits) - 1));
ASSERT_EQ(ZL_bitExtract32(ZL_bitDeposit32(src, mask), mask), srcMasked);
}
}

TEST(BitsTest, bitDepositExtractRoundTrip64)
{
// PEXT(PDEP(src, mask), mask) == src & ((1 << popcount(mask)) - 1)
std::mt19937 rng(456);
std::uniform_int_distribution<uint64_t> dist64;
for (int i = 0; i < 10000; ++i) {
uint64_t src = dist64(rng);
uint64_t mask = dist64(rng);
int nbits = ZL_popcount64(mask);
uint64_t srcMasked =
src & ((nbits >= 64) ? ~0ull : ((1ull << nbits) - 1));
ASSERT_EQ(ZL_bitExtract64(ZL_bitDeposit64(src, mask), mask), srcMasked);
}
}

} // namespace
Loading