Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 179 additions & 6 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

#include <omp.h>

Expand Down Expand Up @@ -589,6 +590,157 @@ int distance_compute_blas_query_bs = 4096;
int distance_compute_blas_database_bs = 1024;
int distance_compute_min_k_reservoir = 100;

// Database-parallel KNN: parallelizes over database segments instead of
// queries, for the case where nx < nthreads and the database is large.
static constexpr size_t kDbParallelMinVectors = 10000;

// Search the k nearest neighbors of the nx queries x among the ny database
// vectors y, parallelizing over database segments (one result-heap set per
// segment) instead of over queries.
//
// C is CMax<float, int64_t> for L2 (keeps the k smallest distances) or
// CMin<float, int64_t> for inner product (keeps the k largest scores).
// vals / ids are output buffers of size nx * k.
// y_norms: optional precomputed squared L2 norms of y (L2 only); computed
// here when null.
template <class C>
static void knn_db_parallel_impl(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        size_t k,
        float* vals,
        int64_t* ids,
        const float* y_norms) {
    using T = typename C::T;
    using TI = typename C::TI;

    int nt = omp_get_max_threads();
    size_t segment = (ny + nt - 1) / nt;
    // number of non-empty segments (<= nt; can be smaller due to rounding)
    size_t nseg = (ny + segment - 1) / segment;
    const size_t bs_y = distance_compute_blas_database_bs;

    // Per-segment result heaps: nt segments x nx queries x k results
    std::vector<T> all_dis(static_cast<size_t>(nt) * nx * k);
    std::vector<TI> all_ids(static_cast<size_t>(nt) * nx * k);

    // Initialize every heap slot up front (id = -1 sentinel) so the merge
    // below never reads uninitialized memory, even for heap slots that
    // receive no results.
    for (size_t h = 0; h < static_cast<size_t>(nt) * nx; h++) {
        heap_heapify<C>(k, all_dis.data() + h * k, all_ids.data() + h * k);
    }

    // Precompute norms for L2
    std::unique_ptr<float[]> x_norms_storage;
    std::unique_ptr<float[]> y_norms_storage;
    const float* x_norms = nullptr;
    if constexpr (C::is_max) {
        x_norms_storage.reset(new float[nx]);
        fvec_norms_L2sqr(x_norms_storage.get(), x, d, nx);
        x_norms = x_norms_storage.get();

        if (!y_norms) {
            y_norms_storage.reset(new float[ny]);
            fvec_norms_L2sqr(y_norms_storage.get(), y, d, ny);
            y_norms = y_norms_storage.get();
        }
    }

    // Parallelize over segment indices rather than assuming one segment per
    // thread id: the OpenMP runtime may run a parallel region with fewer
    // threads than omp_get_max_threads() (dynamic threads, nested
    // parallelism), and a tid-based partition would then silently skip part
    // of the database.
#pragma omp parallel for schedule(dynamic)
    for (int64_t s = 0; s < static_cast<int64_t>(nseg); s++) {
        size_t j_begin = static_cast<size_t>(s) * segment;
        size_t j_end = std::min(j_begin + segment, ny);
        size_t local_ny = j_end - j_begin;

        T* my_dis = all_dis.data() + static_cast<size_t>(s) * nx * k;
        TI* my_ids = all_ids.data() + static_cast<size_t>(s) * nx * k;

        size_t max_block = std::min(bs_y, local_ny);
        std::unique_ptr<float[]> ip_block(new float[nx * max_block]);

        for (size_t jj0 = 0; jj0 < local_ny; jj0 += bs_y) {
            size_t jj1 = std::min(jj0 + bs_y, local_ny);
            size_t block_ny = jj1 - jj0;

            // ip_block[i * block_ny + j] = <x_i, y_{j_begin + jj0 + j}>
            {
                float one = 1, zero = 0;
                FINTEGER nyi = block_ny, nxi = nx, di = d;
                sgemm_("Transpose",
                       "Not transpose",
                       &nyi,
                       &nxi,
                       &di,
                       &one,
                       y + (j_begin + jj0) * d,
                       &di,
                       x,
                       &di,
                       &zero,
                       ip_block.get(),
                       &nyi);
            }

            for (size_t i = 0; i < nx; i++) {
                T* heap_dis = my_dis + i * k;
                TI* heap_ids = my_ids + i * k;
                const float* ip_line = ip_block.get() + i * block_ny;
                // cache the heap threshold so it is re-read only after an
                // insertion actually changes it
                T thresh = heap_dis[0];

                for (size_t jj = 0; jj < block_ny; jj++) {
                    size_t global_j = j_begin + jj0 + jj;
                    float ip = ip_line[jj];
                    T dis;

                    if constexpr (C::is_max) {
                        // ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>,
                        // clamped at 0 against rounding errors
                        dis = x_norms[i] + y_norms[global_j] - 2 * ip;
                        if (dis < 0)
                            dis = 0;
                    } else {
                        dis = ip;
                    }

                    if (C::cmp(thresh, dis)) {
                        heap_replace_top<C>(
                                k, heap_dis, heap_ids, dis, global_j);
                        thresh = heap_dis[0];
                    }
                }
            }
        }
    }

    // Merge per-segment heaps into output, parallelized over queries
#pragma omp parallel for
    for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
        heap_heapify<C>(k, vals + i * k, ids + i * k);

        for (size_t t = 0; t < nseg; t++) {
            T* t_dis = all_dis.data() + (t * nx + static_cast<size_t>(i)) * k;
            TI* t_ids = all_ids.data() + (t * nx + static_cast<size_t>(i)) * k;
            T* out_dis = vals + i * k;
            TI* out_ids = ids + i * k;

            for (size_t j = 0; j < k; j++) {
                // id == -1 marks an unfilled heap entry
                if (t_ids[j] >= 0 && C::cmp(out_dis[0], t_dis[j])) {
                    heap_replace_top<C>(
                            k, out_dis, out_ids, t_dis[j], t_ids[j]);
                }
            }
        }

        heap_reorder<C>(k, vals + i * k, ids + i * k);
    }
}

// Heuristic: decide whether the database-parallel KNN path should be used.
// It only pays off when there are fewer queries than OpenMP threads and the
// database is large enough to give every thread a full BLAS block.
static bool should_use_db_parallel(
        size_t nx,
        size_t ny,
        const IDSelector* sel) {
    // ID selection is only supported by the standard per-query path.
    if (sel != nullptr) {
        return false;
    }
    const int nt = omp_get_max_threads();
    if (nt <= 1 || nx >= static_cast<size_t>(nt)) {
        return false;
    }
    const size_t blas_floor = static_cast<size_t>(nt) *
            static_cast<size_t>(distance_compute_blas_database_bs);
    return ny >= std::max(kDbParallelMinVectors, blas_floor);
}

void knn_inner_product(
const float* x,
const float* y,
Expand All @@ -613,9 +765,25 @@ void knn_inner_product(
return;
}

Run_search_inner_product r;
dispatch_knn_ResultHandler(
nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
if (should_use_db_parallel(nx, ny, sel)) {
knn_db_parallel_impl<CMin<float, int64_t>>(
x, y, d, nx, ny, k, vals, ids, nullptr);
} else {
Run_search_inner_product r;
dispatch_knn_ResultHandler(
nx,
vals,
ids,
k,
METRIC_INNER_PRODUCT,
sel,
r,
x,
y,
d,
nx,
ny);
}

if (imin != 0) {
for (size_t i = 0; i < nx * k; i++) {
Expand Down Expand Up @@ -662,9 +830,14 @@ void knn_L2sqr(
return;
}

Run_search_L2sqr r;
dispatch_knn_ResultHandler(
nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
if (should_use_db_parallel(nx, ny, sel)) {
knn_db_parallel_impl<CMax<float, int64_t>>(
x, y, d, nx, ny, k, vals, ids, y_norm2);
} else {
Run_search_L2sqr r;
dispatch_knn_ResultHandler(
nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
}

if (imin != 0) {
for (size_t i = 0; i < nx * k; i++) {
Expand Down
83 changes: 83 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,89 @@ def test_with_blas_reservoir_ip(self):
self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=150)


class TestDbParallelSearch(unittest.TestCase):
    """Exercise the database-parallel brute-force search path, which kicks
    in when the number of queries is below the OpenMP thread count and the
    database is large. Checks both IP and L2 metrics, several values of k,
    and consistency across thread counts."""

    def _check(self, metric_type, nq, nb, k):
        d = 64
        np.random.seed(1234)
        xb = np.random.random((nb, d)).astype('float32')
        xq = np.random.random((nq, d)).astype('float32')

        index = faiss.IndexFlat(d, metric_type)
        index.add(xb)
        D, I = index.search(xq, k)

        # brute-force ground truth with numpy
        if metric_type == faiss.METRIC_L2:
            diffs = xq[:, None, :] - xb[None, :, :]
            all_dis = (diffs ** 2).sum(axis=2)
            order = all_dis.argsort(axis=1)
        else:
            all_dis = xq @ xb.T
            order = all_dis.argsort(axis=1)[:, ::-1]
        Iref = order[:, :k]

        Dref = np.take_along_axis(all_dis, Iref, axis=1)
        np.testing.assert_almost_equal(Dref, D, decimal=5)

    def test_ip_single_query(self):
        """nx=1 triggers the db-parallel path whenever nthreads > 1."""
        self._check(faiss.METRIC_INNER_PRODUCT, nq=1, nb=20000, k=10)

    def test_ip_few_queries(self):
        """nx=4, the typical db-parallel scenario."""
        self._check(faiss.METRIC_INNER_PRODUCT, nq=4, nb=20000, k=10)

    def test_l2_single_query(self):
        self._check(faiss.METRIC_L2, nq=1, nb=20000, k=10)

    def test_l2_few_queries(self):
        self._check(faiss.METRIC_L2, nq=4, nb=20000, k=10)

    def test_ip_k1(self):
        """k=1 goes through Top1BlockResultHandler on the original path."""
        self._check(faiss.METRIC_INNER_PRODUCT, nq=2, nb=20000, k=1)

    def test_l2_k1(self):
        self._check(faiss.METRIC_L2, nq=2, nb=20000, k=1)

    def test_ip_large_k(self):
        """k=200 goes through ReservoirBlockResultHandler originally."""
        self._check(faiss.METRIC_INNER_PRODUCT, nq=2, nb=20000, k=200)

    def test_l2_large_k(self):
        self._check(faiss.METRIC_L2, nq=2, nb=20000, k=200)

    def test_thread_scaling(self):
        """Search results must not depend on the OpenMP thread count."""
        d, nb, nq, k = 64, 30000, 2, 10
        np.random.seed(42)
        xb = np.random.random((nb, d)).astype('float32')
        xq = np.random.random((nq, d)).astype('float32')

        index = faiss.IndexFlatIP(d)
        index.add(xb)

        saved_nt = faiss.omp_get_max_threads()
        try:
            faiss.omp_set_num_threads(1)
            Dref, Iref = index.search(xq, k)

            for nt in [2, 4, 8]:
                faiss.omp_set_num_threads(nt)
                D, I = index.search(xq, k)
                np.testing.assert_array_equal(Iref, I)
                np.testing.assert_almost_equal(Dref, D, decimal=5)
        finally:
            # always restore the global thread setting for later tests
            faiss.omp_set_num_threads(saved_nt)


class TestIndexFlatL2(unittest.TestCase):
def test_indexflat_l2_sync_norms_1(self):
d = 32
Expand Down