Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 110 additions & 22 deletions faiss/IndexIVFRaBitQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,86 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
}
}

/// Build the 4-bit FastScan lookup table (LUT) for one query residual,
/// using caller-owned scratch buffers so no heap allocation happens per
/// call (the buffers keep their capacity across invocations).
///
/// @param residual       query residual w.r.t. the assigned coarse centroid,
///                       d floats
/// @param query_factors  output: per-query scalar factors produced by
///                       rabitq_utils::compute_query_factors (c1/c2, norms, ...)
/// @param lut_out        output: M * 16 floats; entry [m*16 + c] is the
///                       partial distance contribution of 4-bit code c for
///                       the m-th group of 4 dimensions
/// @param original_query optional un-residualized query; when non-null and
///                       the metric is inner product, it is used to refill
///                       qr_norm_L2sqr / q_dot_c from the original vector
/// @param rotated_q      scratch, must hold at least d floats; overwritten
/// @param rotated_qq     scratch, must hold at least d bytes (the qb-bit
///                       quantized rotated query); overwritten
void IndexIVFRaBitQFastScan::compute_residual_LUT(
        const float* residual,
        QueryFactorsData& query_factors,
        float* lut_out,
        const float* original_query,
        std::vector<float>& rotated_q,
        std::vector<uint8_t>& rotated_qq) const {
    // The LUT math below assumes query codes fit in at most 8 bits.
    FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);

    // Quantize the (rotated) residual query; fills rotated_q / rotated_qq
    // and returns the scalar correction factors.
    query_factors = rabitq_utils::compute_query_factors(
            residual,
            d,
            nullptr,
            qb,
            centered,
            metric_type,
            rotated_q,
            rotated_qq);

    if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
        original_query != nullptr) {
        // For IP search the factors must reference the original query, not
        // the residual: ||q||^2 and q.c = ||q||^2 - <q, q - c>.
        query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
        query_factors.q_dot_c = query_factors.qr_norm_L2sqr -
                fvec_inner_product(original_query, residual, d);
    }

    // Extra (beyond-sign) bits per dimension. When present, keep a copy of
    // the rotated float query — presumably consumed later by an extended-bit
    // refinement/rescoring pass (NOTE(review): confirm against the scanner).
    const size_t ex_bits = rabitq.nb_bits - 1;
    if (ex_bits > 0) {
        query_factors.rotated_q.assign(rotated_q.begin(), rotated_q.end());
    }

    if (centered) {
        // Centered variant: each LUT entry accumulates, over the 4 dims of
        // the group, either the quantized query value (db bit = 0) or its
        // complement max_code_value - value (db bit = 1) — an XOR-style
        // per-bit contribution.
        const float max_code_value = (1 << qb) - 1;

        for (size_t m = 0; m < M; m++) {
            const size_t dim_start = m * 4;

            for (int code_val = 0; code_val < 16; code_val++) {
                float xor_contribution = 0.0f;

                for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
                    const size_t dim_idx = dim_start + dim_offset;

                    // Guard the tail group when d is not a multiple of 4.
                    if (dim_idx < static_cast<size_t>(d)) {
                        const bool db_bit = (code_val >> dim_offset) & 1;
                        const float query_value = rotated_qq[dim_idx];

                        xor_contribution += db_bit
                                ? (max_code_value - query_value)
                                : query_value;
                    }
                }

                lut_out[m * 16 + code_val] = xor_contribution;
            }
        }
    } else {
        // Non-centered variant: entry = c1 * <code bits, quantized query>
        //                              + c2 * popcount(code bits),
        // with c1/c2 coming from compute_query_factors above.
        for (size_t m = 0; m < M; m++) {
            const size_t dim_start = m * 4;

            for (int code_val = 0; code_val < 16; code_val++) {
                float inner_product = 0.0f;
                int popcount = 0;

                for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
                    const size_t dim_idx = dim_start + dim_offset;

                    // Only set db bits inside the valid dimension range
                    // contribute.
                    if (dim_idx < static_cast<size_t>(d) &&
                        ((code_val >> dim_offset) & 1)) {
                        inner_product += rotated_qq[dim_idx];
                        popcount++;
                    }
                }
                lut_out[m * 16 + code_val] = query_factors.c1 * inner_product +
                        query_factors.c2 * popcount;
            }
        }
    }
}

void IndexIVFRaBitQFastScan::search_preassigned(
idx_t n,
const float* x,
Expand Down Expand Up @@ -406,34 +486,42 @@ void IndexIVFRaBitQFastScan::compute_LUT(
if (n * cq_nprobe > 0) {
memset(biases.get(), 0, sizeof(float) * n * cq_nprobe);
}
std::unique_ptr<float[]> xrel(new float[n * cq_nprobe * d]);
// Use per-thread buffers instead of one O(n * nprobe * d) allocation.
// rotated_q / rotated_qq keep their capacity across iterations so the
// allocator is only hit once per thread.
#pragma omp parallel if (n * cq_nprobe > 1000)
{
std::vector<float> xij_buf(d);
std::vector<float> rotated_q(d);
std::vector<uint8_t> rotated_qq(d);

#pragma omp parallel for if (n * cq_nprobe > 1000)
for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
idx_t i = ij / cq_nprobe;
float* xij = &xrel[ij * d];
idx_t cij = cq.ids[ij];
#pragma omp for
for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
idx_t i = ij / cq_nprobe;
idx_t cij = cq.ids[ij];

if (cij >= 0) {
quantizer->compute_residual(x + i * d, xij, cij);
if (cij >= 0) {
quantizer->compute_residual(x + i * d, xij_buf.data(), cij);

// Create QueryFactorsData for this query-list combination
QueryFactorsData query_factors_data;
QueryFactorsData query_factors_data;

compute_residual_LUT(
xij,
query_factors_data,
dis_tables.get() + ij * dim12,
x + i * d);
compute_residual_LUT(
xij_buf.data(),
query_factors_data,
dis_tables.get() + ij * dim12,
x + i * d,
rotated_q,
rotated_qq);

// Store query factors using compact indexing (ij directly)
if (context.query_factors != nullptr) {
context.query_factors[ij] = query_factors_data;
}
if (context.query_factors != nullptr) {
context.query_factors[ij] = std::move(query_factors_data);
}

} else {
memset(xij, -1, sizeof(float) * d);
memset(dis_tables.get() + ij * dim12, -1, sizeof(float) * dim12);
} else {
memset(dis_tables.get() + ij * dim12,
-1,
sizeof(float) * dim12);
}
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions faiss/IndexIVFRaBitQFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
float* lut_out,
const float* original_query = nullptr) const;

/// Overload that accepts caller-owned scratch buffers to avoid
/// per-call heap allocation. rotated_q and rotated_qq must each
/// have at least d elements. Their contents are overwritten.
void compute_residual_LUT(
const float* residual,
QueryFactorsData& query_factors,
float* lut_out,
const float* original_query,
std::vector<float>& rotated_q,
std::vector<uint8_t>& rotated_qq) const;

/// Decode FastScan code to RaBitQ residual vector with explicit
/// dp_multiplier
void decode_fastscan_to_residual(
Expand Down
Loading