Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion faiss/IndexIVFFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ void IndexIVFFastScan::init_fastscan(

void IndexIVFFastScan::init_code_packer() {
auto bil = dynamic_cast<BlockInvertedLists*>(invlists);
FAISS_THROW_IF_NOT(bil);
if (!bil) {
// invlists is not block-packed (e.g., when own_invlists=false).
// Nothing to do — the caller manages inverted lists externally.
return;
}
delete bil->packer; // in case there was one before
bil->packer = get_CodePacker();
}
Expand Down
8 changes: 8 additions & 0 deletions faiss/IndexIVFRaBitQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,10 @@ struct RaBitInvertedListScanner : InvertedListScanner {
// Stats tracking for multi-bit two-stage search
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
// n_multibit_evaluations: candidates requiring full multi-bit distance
#ifndef NDEBUG
size_t local_1bit_evaluations = 0;
size_t local_multibit_evaluations = 0;
#endif

for (size_t j = 0; j < list_size; j++) {
if (sel != nullptr) {
Expand All @@ -246,7 +248,9 @@ struct RaBitInvertedListScanner : InvertedListScanner {
}
}

#ifndef NDEBUG
local_1bit_evaluations++;
#endif

// Stage 1: Compute distance bound using 1-bit codes
// For L2 (min-heap): use lower_bound to safely skip if it's
Expand All @@ -269,7 +273,9 @@ struct RaBitInvertedListScanner : InvertedListScanner {
handler.threshold,
keep_max);
if (should_refine) {
#ifndef NDEBUG
local_multibit_evaluations++;
#endif
// Lower bound is promising, compute full distance
float dis = distance_to_code(codes);
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
Expand All @@ -281,11 +287,13 @@ struct RaBitInvertedListScanner : InvertedListScanner {
codes += code_size;
}

#ifndef NDEBUG
// Update global stats atomically
#pragma omp atomic
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
#pragma omp atomic
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
#endif

return nup;
}
Expand Down
8 changes: 6 additions & 2 deletions faiss/IndexIVFRaBitQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,14 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner {
handler->ntotal = ntotal;
handler->id_map = ids;

// RaBitQ needs list context for factor lookup
// RaBitQ needs list context for factor lookup.
// If invlists is unavailable (e.g., own_invlists=false), fall back
// to the codes pointer which already contains the block data.
std::vector<int> probe_map = {0};
handler->set_list_context(list_no, probe_map);
if (!handler->list_codes_ptr) {
handler->list_codes_ptr = codes;
}

scanner->accumulate_loop(
1,
Expand Down Expand Up @@ -701,7 +706,6 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner {
curr_labels.data(),
k);
}

return handler->num_updates();
}
};
Expand Down
92 changes: 47 additions & 45 deletions faiss/IndexIVFRaBitQFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,23 +244,35 @@ void IVFRaBitQHeapHandler<C, SL>::handle(
"Query factors not available: FastScanDistancePostProcessing with query_factors required");
}

size_t probe_rank = probe_indices[local_q];
size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
size_t storage_idx = q * nprobe + probe_rank;
const size_t probe_rank = probe_indices[local_q];
const size_t storage_idx = q * cached_nprobe + probe_rank;
const auto& query_factors = context->query_factors[storage_idx];

const float one_a =
this->normalizers ? (1.0f / this->normalizers[2 * q]) : 1.0f;
const float bias = this->normalizers ? this->normalizers[2 * q + 1] : 0.0f;

uint64_t idx_base = this->j0 + b * 32;
const uint64_t idx_base = this->j0 + b * 32;
if (idx_base >= this->ntotal) {
return;
}
size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
const size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);

// Hoist aux pointer base out of loop: all 32 elements in this block share
// the same block base. Only the per-element offset (j * storage_size)
// varies.
const uint8_t* aux_base = this->list_codes_ptr +
(idx_base / index->bbs) * full_block_size + packed_block_size;

// Cache index fields used in the inner loop.
const bool centered = index->centered;
const size_t qb = index->qb;
const size_t d = index->d;

#ifndef NDEBUG
size_t local_1bit_evaluations = 0;
size_t local_multibit_evaluations = 0;
#endif

for (size_t j = 0; j < max_positions; j++) {
const int64_t result_id = this->adjust_id(b, j);
Expand All @@ -274,40 +286,36 @@ void IVFRaBitQHeapHandler<C, SL>::handle(
this->scan_cnt++;

const float normalized_distance = d32tab[j] * one_a + bias;
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
list_codes_ptr,
idx_base + j,
index->bbs,
packed_block_size,
full_block_size,
storage_size);
const uint8_t* base_ptr = aux_base + j * storage_size;

if (is_multibit) {
#ifndef NDEBUG
local_1bit_evaluations++;
#endif
const SignBitFactorsWithError& full_factors =
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);

float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
normalized_distance,
full_factors,
query_factors,
index->centered,
index->qb,
index->d);
centered,
qb,
d);

const bool is_similarity =
index->metric_type == MetricType::METRIC_INNER_PRODUCT;
bool should_refine = rabitq_utils::should_refine_candidate(
dist_1bit,
full_factors.f_error,
query_factors.g_error,
heap_dis[0],
is_similarity);
if (should_refine) {
#ifndef NDEBUG
local_multibit_evaluations++;
size_t local_offset = this->j0 + b * 32 + j;
#endif
size_t local_offset = idx_base + j;
float dist_full = compute_full_multibit_distance(
result_id, local_q, q, local_offset);
local_q, q, local_offset, base_ptr);
if (Cfloat::cmp(heap_dis[0], dist_full)) {
heap_replace_top<Cfloat>(
k, heap_dis, heap_ids, dist_full, result_id);
Expand All @@ -322,9 +330,9 @@ void IVFRaBitQHeapHandler<C, SL>::handle(
normalized_distance,
db_factors,
query_factors,
index->centered,
index->qb,
index->d);
centered,
qb,
d);
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
heap_replace_top<Cfloat>(
k, heap_dis, heap_ids, adjusted_distance, result_id);
Expand All @@ -333,10 +341,12 @@ void IVFRaBitQHeapHandler<C, SL>::handle(
}
}

#ifndef NDEBUG
#pragma omp atomic
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
#pragma omp atomic
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
#endif
}

template <class C, SIMDLevel SL>
Expand All @@ -345,7 +355,12 @@ void IVFRaBitQHeapHandler<C, SL>::set_list_context(
const std::vector<int>& probe_map) {
current_list_no = list_no;
probe_indices = probe_map;
list_codes_ptr = index->invlists->get_codes(list_no);
cached_nprobe =
context && context->nprobe > 0 ? context->nprobe : index->nprobe;
is_similarity = index->metric_type == MetricType::METRIC_INNER_PRODUCT;
if (index->invlists) {
this->list_codes_ptr = index->invlists->get_codes(list_no);
}
}

template <class C, SIMDLevel SL>
Expand All @@ -363,44 +378,31 @@ void IVFRaBitQHeapHandler<C, SL>::end() {

template <class C, SIMDLevel SL>
float IVFRaBitQHeapHandler<C, SL>::compute_full_multibit_distance(
size_t /*db_idx*/,
size_t local_q,
size_t global_q,
size_t local_offset) {
size_t local_offset,
const uint8_t* aux_ptr) {
const size_t ex_bits = index->rabitq.nb_bits - 1;
const size_t dim = index->d;

const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
list_codes_ptr,
local_offset,
index->bbs,
packed_block_size,
full_block_size,
storage_size);

const size_t ex_code_size = (dim * ex_bits + 7) / 8;
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
const uint8_t* ex_code = aux_ptr + sizeof(SignBitFactorsWithError);
const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
aux_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);

size_t probe_rank = probe_indices[local_q];
size_t nprobe_val = context->nprobe > 0 ? context->nprobe : index->nprobe;
size_t storage_idx_val = global_q * nprobe_val + probe_rank;
const size_t probe_rank = probe_indices[local_q];
const size_t storage_idx_val = global_q * cached_nprobe + probe_rank;
const auto& query_factors = context->query_factors[storage_idx_val];

// Use list_codes_ptr (already set by set_list_context) and the
// pre-allocated unpack_buf to avoid per-refinement ScopedCodes
// re-acquisition and heap allocation.
packer->unpack_1(list_codes_ptr, local_offset, unpack_buf.data());
// Unpack PQ4-interleaved sign bits for this vector into a linear buffer.
packer->unpack_1(this->list_codes_ptr, local_offset, unpack_buf.data());

return rabitq_utils::compute_full_multibit_distance(
unpack_buf.data(),
ex_code,
ex_fac,
query_factors.rotated_q.data(),
(index->metric_type == MetricType::METRIC_INNER_PRODUCT)
? query_factors.q_dot_c
: query_factors.qr_to_c_L2sqr,
is_similarity ? query_factors.q_dot_c : query_factors.qr_to_c_L2sqr,
dim,
ex_bits,
index->metric_type);
Expand Down
8 changes: 8 additions & 0 deletions faiss/IndexRaBitQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ struct Run_search_with_dc_res {
// n_1bit_evaluations: candidates evaluated using 1-bit lower
// bound n_multibit_evaluations: candidates requiring full
// multi-bit distance
#ifndef NDEBUG
size_t local_1bit_evaluations = 0;
size_t local_multibit_evaluations = 0;
#endif

if (ex_bits == 0) {
// 1-bit: Standard single-stage search (no stats tracking)
Expand Down Expand Up @@ -142,7 +144,9 @@ struct Run_search_with_dc_res {
const uint8_t* code =
index->codes.data() + i * index->code_size;

#ifndef NDEBUG
local_1bit_evaluations++;
#endif

// Stage 1: Compute distance bound using 1-bit codes
// For L2 (min-heap): use lower_bound (est -
Expand All @@ -168,7 +172,9 @@ struct Run_search_with_dc_res {
resi.threshold,
is_similarity);
if (should_refine) {
#ifndef NDEBUG
local_multibit_evaluations++;
#endif
// Compute full multi-bit distance
float dist_full =
dc->distance_to_code_full(code);
Expand All @@ -178,12 +184,14 @@ struct Run_search_with_dc_res {
}
}

#ifndef NDEBUG
// Update global stats atomically
#pragma omp atomic
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
#pragma omp atomic
rabitq_stats.n_multibit_evaluations +=
local_multibit_evaluations;
#endif

resi.end();
}
Expand Down
8 changes: 8 additions & 0 deletions faiss/IndexRaBitQFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,20 @@ struct RaBitQHeapHandler
const uint8_t* aux_base = rabitq_index->codes.get() +
block_idx * full_block_size + packed_block_size;

#ifndef NDEBUG
size_t local_1bit_evaluations = 0;
size_t local_multibit_evaluations = 0;
#endif

for (size_t i = 0; i < max_vectors; i++) {
const size_t db_idx = base_db_idx + i;
const float normalized_distance = d32tab[i] * one_a + bias;
const uint8_t* base_ptr = aux_base + i * storage_size;

if (is_multi_bit) {
#ifndef NDEBUG
local_1bit_evaluations++;
#endif

const SignBitFactorsWithError& full_factors =
*reinterpret_cast<const SignBitFactorsWithError*>(
Expand All @@ -252,7 +256,9 @@ struct RaBitQHeapHandler
is_similarity);

if (should_refine) {
#ifndef NDEBUG
local_multibit_evaluations++;
#endif
float dist_full = compute_full_multibit_distance(db_idx, q);

if (Cfloat::cmp(heap_dis[0], dist_full)) {
Expand Down Expand Up @@ -281,10 +287,12 @@ struct RaBitQHeapHandler
}
}

#ifndef NDEBUG
#pragma omp atomic
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
#pragma omp atomic
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
#endif
}

void begin(const float* norms) override {
Expand Down
10 changes: 7 additions & 3 deletions faiss/impl/fast_scan/rabitq_result_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ struct IVFRaBitQHeapHandler : ResultHandlerCompare<C, true, SL> {
int64_t* heap_labels; // [nq * k]
const size_t nq, k;
size_t current_list_no = 0;
const uint8_t* list_codes_ptr = nullptr; // raw block data for list
std::vector<int>
probe_indices; // probe index for each query in current batch
const FastScanDistancePostProcessing*
Expand All @@ -64,6 +63,11 @@ struct IVFRaBitQHeapHandler : ResultHandlerCompare<C, true, SL> {
// instance is confined to one search slice and not entered concurrently.
std::vector<uint8_t> unpack_buf; // reusable buffer for unpack_1

// Cached per-list values (set in set_list_context, avoid recomputing in
// handle)
size_t cached_nprobe = 0;
bool is_similarity = false; // metric == INNER_PRODUCT

// Use float-based comparator for heap operations
using Cfloat = typename std::conditional<
C::is_max,
Expand Down Expand Up @@ -98,10 +102,10 @@ struct IVFRaBitQHeapHandler : ResultHandlerCompare<C, true, SL> {

private:
float compute_full_multibit_distance(
size_t db_idx,
size_t local_q,
size_t global_q,
size_t local_offset);
size_t local_offset,
const uint8_t* aux_ptr);
};

} // namespace simd_result_handlers
Expand Down
Loading