diff --git a/faiss/IndexIVFRaBitQFastScan.cpp b/faiss/IndexIVFRaBitQFastScan.cpp
index 752df158cb..251941b39b 100644
--- a/faiss/IndexIVFRaBitQFastScan.cpp
+++ b/faiss/IndexIVFRaBitQFastScan.cpp
@@ -354,6 +354,86 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
     }
 }
 
+void IndexIVFRaBitQFastScan::compute_residual_LUT(
+        const float* residual,
+        QueryFactorsData& query_factors,
+        float* lut_out,
+        const float* original_query,
+        std::vector<float>& rotated_q,
+        std::vector<float>& rotated_qq) const {
+    FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
+
+    query_factors = rabitq_utils::compute_query_factors(
+            residual,
+            d,
+            nullptr,
+            qb,
+            centered,
+            metric_type,
+            rotated_q,
+            rotated_qq);
+
+    if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
+        original_query != nullptr) {
+        query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
+        query_factors.q_dot_c = query_factors.qr_norm_L2sqr -
+                fvec_inner_product(original_query, residual, d);
+    }
+
+    const size_t ex_bits = rabitq.nb_bits - 1;
+    if (ex_bits > 0) {
+        query_factors.rotated_q.assign(rotated_q.begin(), rotated_q.end());
+    }
+
+    if (centered) {
+        const float max_code_value = (1 << qb) - 1;
+
+        for (size_t m = 0; m < M; m++) {
+            const size_t dim_start = m * 4;
+
+            for (int code_val = 0; code_val < 16; code_val++) {
+                float xor_contribution = 0.0f;
+
+                for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                    const size_t dim_idx = dim_start + dim_offset;
+
+                    if (dim_idx < static_cast<size_t>(d)) {
+                        const bool db_bit = (code_val >> dim_offset) & 1;
+                        const float query_value = rotated_qq[dim_idx];
+
+                        xor_contribution += db_bit
+                                ? (max_code_value - query_value)
+                                : query_value;
+                    }
+                }
+
+                lut_out[m * 16 + code_val] = xor_contribution;
+            }
+        }
+    } else {
+        for (size_t m = 0; m < M; m++) {
+            const size_t dim_start = m * 4;
+
+            for (int code_val = 0; code_val < 16; code_val++) {
+                float inner_product = 0.0f;
+                int popcount = 0;
+
+                for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                    const size_t dim_idx = dim_start + dim_offset;
+
+                    if (dim_idx < static_cast<size_t>(d) &&
+                        ((code_val >> dim_offset) & 1)) {
+                        inner_product += rotated_qq[dim_idx];
+                        popcount++;
+                    }
+                }
+                lut_out[m * 16 + code_val] = query_factors.c1 * inner_product +
+                        query_factors.c2 * popcount;
+            }
+        }
+    }
+}
+
 void IndexIVFRaBitQFastScan::search_preassigned(
         idx_t n,
         const float* x,
@@ -406,34 +486,42 @@ void IndexIVFRaBitQFastScan::compute_LUT(
     if (n * cq_nprobe > 0) {
         memset(biases.get(), 0, sizeof(float) * n * cq_nprobe);
     }
-    std::unique_ptr<float[]> xrel(new float[n * cq_nprobe * d]);
+    // Use per-thread buffers instead of one O(n * nprobe * d) allocation.
+    // rotated_q / rotated_qq keep their capacity across iterations so the
+    // allocator is only hit once per thread.
+#pragma omp parallel if (n * cq_nprobe > 1000)
+    {
+        std::vector<float> xij_buf(d);
+        std::vector<float> rotated_q(d);
+        std::vector<float> rotated_qq(d);
 
-#pragma omp parallel for if (n * cq_nprobe > 1000)
-    for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
-        idx_t i = ij / cq_nprobe;
-        float* xij = &xrel[ij * d];
-        idx_t cij = cq.ids[ij];
+#pragma omp for
+        for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
+            idx_t i = ij / cq_nprobe;
+            idx_t cij = cq.ids[ij];
 
-        if (cij >= 0) {
-            quantizer->compute_residual(x + i * d, xij, cij);
+            if (cij >= 0) {
+                quantizer->compute_residual(x + i * d, xij_buf.data(), cij);
 
-            // Create QueryFactorsData for this query-list combination
-            QueryFactorsData query_factors_data;
+                QueryFactorsData query_factors_data;
 
-            compute_residual_LUT(
-                    xij,
-                    query_factors_data,
-                    dis_tables.get() + ij * dim12,
-                    x + i * d);
+                compute_residual_LUT(
+                        xij_buf.data(),
+                        query_factors_data,
+                        dis_tables.get() + ij * dim12,
+                        x + i * d,
+                        rotated_q,
+                        rotated_qq);
 
-            // Store query factors using compact indexing (ij directly)
-            if (context.query_factors != nullptr) {
-                context.query_factors[ij] = query_factors_data;
-            }
+                if (context.query_factors != nullptr) {
+                    context.query_factors[ij] = std::move(query_factors_data);
+                }
 
-        } else {
-            memset(xij, -1, sizeof(float) * d);
-            memset(dis_tables.get() + ij * dim12, -1, sizeof(float) * dim12);
+            } else {
+                memset(dis_tables.get() + ij * dim12,
+                       -1,
+                       sizeof(float) * dim12);
+            }
         }
     }
 }
diff --git a/faiss/IndexIVFRaBitQFastScan.h b/faiss/IndexIVFRaBitQFastScan.h
index e9d3684c06..5afe57db27 100644
--- a/faiss/IndexIVFRaBitQFastScan.h
+++ b/faiss/IndexIVFRaBitQFastScan.h
@@ -121,6 +121,17 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
             float* lut_out,
             const float* original_query = nullptr) const;
 
+    /// Overload that accepts caller-owned scratch buffers to avoid
+    /// per-call heap allocation. rotated_q and rotated_qq must each
+    /// have at least d elements. Their contents are overwritten.
+    void compute_residual_LUT(
+            const float* residual,
+            QueryFactorsData& query_factors,
+            float* lut_out,
+            const float* original_query,
+            std::vector<float>& rotated_q,
+            std::vector<float>& rotated_qq) const;
+
     /// Decode FastScan code to RaBitQ residual vector with explicit
     /// dp_multiplier
     void decode_fastscan_to_residual(