Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1432,11 +1432,19 @@ size_t InvertedListScanner::iterate_codes(
size_t nup = 0;
list_size = 0;

const bool has_cb = it->has_search_callbacks_;

if (!keep_max) {
for (; it->is_available(); it->next()) {
auto id_and_codes = it->get_id_and_codes();
float dis = distance_to_code(id_and_codes.second);
if (has_cb) {
it->on_distance_computed(id_and_codes.first, dis);
}
if (dis < simi[0]) {
if (has_cb) {
it->on_heap_changed(id_and_codes.first, idxi[0]);
}
maxheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
nup++;
}
Expand All @@ -1446,7 +1454,13 @@ size_t InvertedListScanner::iterate_codes(
for (; it->is_available(); it->next()) {
auto id_and_codes = it->get_id_and_codes();
float dis = distance_to_code(id_and_codes.second);
if (has_cb) {
it->on_distance_computed(id_and_codes.first, dis);
}
if (dis > simi[0]) {
if (has_cb) {
it->on_heap_changed(id_and_codes.first, idxi[0]);
}
minheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
nup++;
}
Expand Down
18 changes: 18 additions & 0 deletions faiss/invlists/InvertedLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ struct InvertedListsIterator {
virtual bool is_available() const = 0;
virtual void next() = 0;
virtual std::pair<idx_t, const uint8_t*> get_id_and_codes() = 0;

/// Opt-in switch for the search callbacks declared below.
///
/// InvertedListScanner::iterate_codes reads this flag once, before it
/// starts scanning, and only when it is true invokes
/// on_distance_computed() and on_heap_changed() (both virtual calls).
/// When false (the default) neither callback is called, so iterators
/// that do not need them only pay for a loop-invariant boolean test.
/// Derived classes that override the callbacks should set this to true
/// in their constructor; because the value is sampled before the scan
/// starts, changing it during a scan does not affect that scan.
bool has_search_callbacks_ = false;

/// Hook invoked by iterate_codes right after it computes the distance
/// for the entry returned by the most recent get_id_and_codes(), and
/// before that distance is compared against the current heap top.
/// Only invoked when has_search_callbacks_ is true.
///
/// @param vid       id of the scanned vector (first element of the pair
///                  returned by get_id_and_codes()).
/// @param distance  distance computed by the scanner for that vector.
///
/// The default implementation does nothing.
virtual void on_distance_computed(idx_t /* vid */, float /* distance */) {}

/// Hook invoked by iterate_codes when a scanned vector is about to
/// replace the top of the top-K result heap; the call happens just
/// before the heap is updated.
/// Only invoked when has_search_callbacks_ is true.
///
/// @param new_id      id of the vector entering the heap.
/// @param evicted_id  id stored at the heap top at the time of the call,
///                    i.e. the entry being displaced.
///                    NOTE(review): presumably a placeholder (e.g. -1)
///                    while the heap is not yet full — confirm.
///
/// The default implementation does nothing.
virtual void on_heap_changed(idx_t /* new_id */, idx_t /* evicted_id */) {}
};

/** Table of inverted lists
Expand Down
121 changes: 121 additions & 0 deletions tests/test_ivf_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,124 @@ TEST(IVF, search1_out_of_range_key) {

EXPECT_THROW(idx.search1(xq.data(), handler), faiss::FaissException);
}

// Iterator that enables search callbacks and tracks invocations.
class CallbackTrackingIterator : public TestInvertedListIterator {
public:
CallbackTrackingIterator(
size_t list_no,
TestContext* context,
size_t& distance_count,
size_t& heap_count)
: TestInvertedListIterator(list_no, context),
distance_count_{distance_count},
heap_count_{heap_count} {
has_search_callbacks_ = true;
}

void on_distance_computed(faiss::idx_t id, float distance) override {
EXPECT_GE(id, 0) << "vector ID should be non-negative";
EXPECT_GE(distance, 0.0f) << "L2 distance should be non-negative";
distance_count_++;
}

void on_heap_changed(faiss::idx_t new_id, faiss::idx_t evicted_id)
override {
EXPECT_GE(new_id, 0) << "new heap entry ID should be non-negative";
(void)evicted_id; // may be -1 when heap not yet full
heap_count_++;
}

private:
size_t& distance_count_;
size_t& heap_count_;
};

// InvertedLists test double whose iterators are CallbackTrackingIterator
// instances. Every iterator it creates reports into the same pair of
// caller-owned counters.
class CallbackTrackingInvertedLists : public TestInvertedLists {
   public:
    // nlist_in and code_size_in are forwarded to TestInvertedLists.
    // distance_count and heap_count are stored by reference and handed to
    // each iterator, so they must outlive this object and its iterators.
    CallbackTrackingInvertedLists(
            size_t nlist_in,
            size_t code_size_in,
            size_t& distance_count,
            size_t& heap_count)
            : TestInvertedLists(nlist_in, code_size_in),
              distance_count_{distance_count},
              heap_count_{heap_count} {}

    // Records list_no in the context's lists_probed set and returns a new
    // callback-tracking iterator over that list.
    // `context` must point to a TestContext; it is dereferenced
    // unconditionally.
    // NOTE(review): the iterator is heap-allocated and returned as a raw
    // pointer; presumably the caller takes ownership and deletes it —
    // confirm against the InvertedLists::get_iterator contract.
    faiss::InvertedListsIterator* get_iterator(size_t list_no, void* context)
            const override {
        // Named cast: void* -> TestContext* is exactly what static_cast is
        // for, and unlike a C-style cast it cannot silently degrade into a
        // reinterpret_cast/const_cast if the types change.
        auto* testContext = static_cast<TestContext*>(context);
        testContext->lists_probed.insert(list_no);
        return new CallbackTrackingIterator(
                list_no, testContext, distance_count_, heap_count_);
    }

   private:
    size_t& distance_count_;
    size_t& heap_count_;
};

// Test: on_distance_computed and on_heap_changed fire during search
// when has_search_callbacks_ is true.
//
// Builds a small IVFFlat index backed by CallbackTrackingInvertedLists,
// runs a single query and checks that both callbacks were invoked and that
// heap changes never outnumber distance computations.
TEST(IVF, search_callbacks) {
    constexpr int d = 8;
    constexpr int nb = 200;
    constexpr int nlist = 4;

    // Fixed seed so the generated data is reproducible.
    std::mt19937 rng(42);
    std::uniform_real_distribution<> distrib;

    // Run single-threaded: the counters below are plain size_t values and
    // are not synchronized.
    omp_set_num_threads(1);

    faiss::IndexFlatL2 quantizer(d);
    faiss::IndexIVFFlat index(&quantizer, d, nlist);

    // Counters shared (by reference) with every iterator the inverted
    // lists create.
    size_t distance_count = 0;
    size_t heap_count = 0;
    CallbackTrackingInvertedLists invlists(
            nlist, index.code_size, distance_count, heap_count);
    index.replace_invlists(&invlists);

    // Train
    constexpr size_t nt = 100;
    std::vector<float> trainvecs(nt * d);
    for (size_t i = 0; i < nt * d; i++) {
        trainvecs[i] = distrib(rng);
    }
    index.train(nt, trainvecs.data());

    // Populate via context
    TestContext context;
    std::vector<float> database(nb * d);
    for (size_t i = 0; i < nb * d; i++) {
        database[i] = distrib(rng);
    }
    std::vector<faiss::idx_t> coarse_idx(nb);
    index.quantizer->assign(nb, database.data(), coarse_idx.data());
    // Every vector gets the same id; the test only counts callbacks.
    std::vector<faiss::idx_t> xids(nb, 42);
    index.add_core(
            nb, database.data(), xids.data(), coarse_idx.data(), &context);

    // Search
    constexpr faiss::idx_t k = 5;
    constexpr size_t nprobe = 2;
    std::vector<float> query(d);
    for (int i = 0; i < d; i++) {
        query[i] = distrib(rng);
    }
    std::vector<float> distances(k);
    std::vector<faiss::idx_t> labels(k);
    faiss::SearchParametersIVF params;
    params.inverted_list_context = &context;
    params.nprobe = nprobe;

    index.search(1, query.data(), k, distances.data(), labels.data(), &params);

    // Compare against size_t{0} rather than the int literal 0: the gtest
    // comparison helper would otherwise compare unsigned with signed and
    // trip -Wsign-compare.
    EXPECT_GT(distance_count, size_t{0})
            << "on_distance_computed should fire for scored vectors";
    EXPECT_GT(heap_count, size_t{0})
            << "on_heap_changed should fire when vectors enter the heap";
    EXPECT_GE(distance_count, heap_count)
            << "not every distance computation leads to a heap change";
}
Loading