Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 73 additions & 45 deletions src/core/algorithm/hnsw/hnsw_algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,20 @@ void HnswAlgorithm<EntityType>::add_neighbors(node_id_t id, level_t level,
// ============================================================================

// mmap/contiguous variant: resolve vectors via get_vector_ptr and use
// LinearPool or BlockHeap for visited tracking + top-k maintenance.
// HeapType must expose reset/set_visited/check_visited/push_block/has_next/pop.
template <typename EntityType, typename HeapType>
// LinearPool or BlockHeap for top-k maintenance.
// VisitImpl supplies inlined visited tracking (VisitBitMap/VisitByteMap/...).
template <typename EntityType, typename HeapType, typename VisitImpl>
void fast_search_neighbors(const EntityType &entity, HeapType &pool,
VisitFilter &visit, HnswDistCalculator &dc,
uint32_t topk, uint32_t ef, node_id_t entry_point,
dist_t entry_dist, uint32_t prefetch_lines,
uint32_t prefetch_offset) {
HnswDistCalculator &dc, uint32_t topk, uint32_t ef,
node_id_t entry_point, dist_t entry_dist,
uint32_t prefetch_lines, uint32_t prefetch_offset,
typename VisitImpl::Context *visit_ctx) {
const uint32_t max_deg = entity.max_degree(0); // level 0 only
const uint32_t cap = std::max(topk, ef);
pool.reset(static_cast<int32_t>(cap), static_cast<int32_t>(max_deg));
visit.clear();

visit.set_visited(entry_point);
VisitImpl::clear(visit_ctx);
VisitImpl::set_visited(visit_ctx, entry_point);
pool.push_block(&entry_dist, &entry_point, 1);

uint32_t buf_capacity = max_deg;
Expand Down Expand Up @@ -227,8 +227,8 @@ void fast_search_neighbors(const EntityType &entity, HeapType &pool,
// Phase 1: scan first `po` neighbors with prefetch.
for (; i < po; ++i) {
node_id_t node = neighbors[i];
if (visit.visited(node)) continue;
visit.set_visited(node);
if (VisitImpl::visited(visit_ctx, node)) continue;
VisitImpl::set_visited(visit_ctx, node);
const void *vec_ptr = entity.get_vector_ptr(node);
const char *p = reinterpret_cast<const char *>(vec_ptr);
for (uint32_t cl = 0; cl < prefetch_lines; ++cl) {
Expand All @@ -242,8 +242,8 @@ void fast_search_neighbors(const EntityType &entity, HeapType &pool,
// Phase 2: scan remaining neighbors.
for (; i < neighbors.size(); ++i) {
node_id_t node = neighbors[i];
if (visit.visited(node)) continue;
visit.set_visited(node);
if (VisitImpl::visited(visit_ctx, node)) continue;
VisitImpl::set_visited(visit_ctx, node);
neighbor_ids[unvisited_count] = node;
neighbor_vecs[unvisited_count] = entity.get_vector_ptr(node);
unvisited_count++;
Expand All @@ -257,6 +257,51 @@ void fast_search_neighbors(const EntityType &entity, HeapType &pool,
}
}

// Callable bundle that carries every argument of fast_search_neighbors so the
// visited-tracking implementation (VisitImpl) and the heap type can both be
// chosen before the search loop is instantiated.
//
// The struct is brace-initialized by its caller, so the member order below is
// part of its interface and must not be rearranged.
template <typename EntityType>
struct HnswFastSearchRunner {
  const EntityType &entity;  // graph/vector storage handed to the search
  HnswDistCalculator &dc;    // distance calculator handed to the search
  HnswContext *ctx;          // source of block_pool(), pool() and po()
  node_id_t entry_point;     // node the search starts from
  dist_t entry_dist;         // distance associated with entry_point
  uint32_t topk_v;           // requested top-k, forwarded unchanged
  uint32_t ef_v;             // ef value, forwarded unchanged
  uint32_t prefetch_lines;   // prefetch line count, forwarded unchanged
  bool avx2_ok;              // true -> BlockHeap, false -> LinearPool<float>
  TopkHeap &topk;            // destination filled by copy_pool_to_topk

  // Entry point: hands this runner to dispatch_visit_filter.
  // NOTE(review): dispatch_visit_filter is defined elsewhere; it presumably
  // calls back operator()<VisitImpl> with the matching context -- confirm.
  void run(const VisitFilter &visit_filter) const {
    dispatch_visit_filter(visit_filter, *this);
  }

  // Callback for a concrete VisitImpl: selects the heap type from avx2_ok.
  template <typename VisitImpl>
  void operator()(typename VisitImpl::Context *visit_ctx) const {
    if (!avx2_ok) {
      run_with_heap<LinearPool<float>, VisitImpl>(visit_ctx);
      return;
    }
    run_with_heap<BlockHeap, VisitImpl>(visit_ctx);
  }

 private:
  // Runs the search on the context pool that corresponds to HeapType, then
  // copies that pool's content into `topk`.
  template <typename HeapType, typename VisitImpl>
  void run_with_heap(typename VisitImpl::Context *visit_ctx) const {
    // Shared tail of both heap variants; only the pool accessor differs.
    const auto search_and_collect = [&](auto &heap) {
      fast_search_neighbors<EntityType, HeapType, VisitImpl>(
          entity, heap, dc, topk_v, ef_v, entry_point, entry_dist,
          prefetch_lines, ctx->po(), visit_ctx);
      copy_pool_to_topk(heap, topk);
    };

    if constexpr (std::is_same_v<HeapType, BlockHeap>) {
      search_and_collect(ctx->block_pool());
    } else {
      search_and_collect(ctx->pool());
    }
  }
};

// ============================================================================
// dual_heap_search_neighbors: shared core for the fallback dual-heap path.
//
Expand Down Expand Up @@ -393,7 +438,6 @@ void HnswAlgorithm<EntityType>::search_neighbors(level_t level,
HnswDistCalculator &dc = ctx->dist_calculator();

if (!use_pool || ctx->filter().is_valid() || level != 0) {
// Dual-heap path: add_node, filtered search, or upper-level scan.
auto run_with_filter = [&](auto &&filter) {
dual_heap_search_neighbors<EntityType, MemBlockType>(
entity, level, entry_point, dist, topk, ctx, dc,
Expand All @@ -409,38 +453,22 @@ void HnswAlgorithm<EntityType>::search_neighbors(level_t level,
auto filter = [](node_id_t) { return false; };
run_with_filter(filter);
}
} else if constexpr (std::is_same_v<MemBlockType, MmapMemoryBlock>) {
const uint32_t prefetch_lines =
ctx->pl() > 0 ? ctx->pl() : (entity.vector_size() + 63) / 64;
const uint32_t topk_v = static_cast<uint32_t>(ctx->topk());
const uint32_t ef_v = ctx->ef();
const bool avx2_ok =
zvec::ailego::internal::CpuFeatures::static_flags_.AVX2;

HnswFastSearchRunner<EntityType>{entity, dc, ctx, *entry_point,
*dist, topk_v, ef_v, prefetch_lines,
avx2_ok, topk}
.run(ctx->visit_filter());
} else {
// Pool-based path for level-0 unfiltered search.
if constexpr (std::is_same_v<MemBlockType, MmapMemoryBlock>) {
const uint32_t prefetch_lines =
ctx->pl() > 0 ? ctx->pl() : (entity.vector_size() + 63) / 64;

// Fast path: direct pointer access via get_vector_ptr.
// BlockHeap (AVX2) or LinearPool (scalar) for top-k tracking.
const uint32_t topk_v = static_cast<uint32_t>(ctx->topk());
const uint32_t ef_v = ctx->ef();
const bool avx2_ok =
zvec::ailego::internal::CpuFeatures::static_flags_.AVX2;

auto &visit = ctx->visit_filter();

if (avx2_ok) {
auto &bpool = ctx->block_pool();
fast_search_neighbors(entity, bpool, visit, dc, topk_v, ef_v,
*entry_point, *dist, prefetch_lines, ctx->po());
copy_pool_to_topk(bpool, topk);
} else {
auto &lpool = ctx->pool();
fast_search_neighbors(entity, lpool, visit, dc, topk_v, ef_v,
*entry_point, *dist, prefetch_lines, ctx->po());
copy_pool_to_topk(lpool, topk);
}
} else {
// BufferPool entities: fallback to dual-heap path.
auto filter = [](node_id_t) { return false; };
dual_heap_search_neighbors<EntityType, MemBlockType>(
entity, level, entry_point, dist, topk, ctx, dc, filter);
}
auto filter = [](node_id_t) { return false; };
dual_heap_search_neighbors<EntityType, MemBlockType>(
entity, level, entry_point, dist, topk, ctx, dc, filter);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithm/hnsw/hnsw_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ class HnswContext : public IndexContext {
uint32_t reserve_max_doc_cnt_{kMinReserveDocCnt};
uint32_t topk_{0};
uint32_t group_topk_{0};
uint32_t filter_mode_{VisitFilter::ByteMap};
uint32_t filter_mode_{VisitFilter::BitMap};
float negative_probability_{HnswEntity::kDefaultBFNegativeProbability};
uint32_t ef_{HnswEntity::kDefaultEf};
uint32_t po_{8};
Expand Down
114 changes: 71 additions & 43 deletions src/core/algorithm/vamana/vamana_algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,26 @@ int VamanaAlgorithm<EntityType>::search(VamanaContext *ctx) const {
// wrappers via get_vector_typed to pin pages.
//
// Both accept either BlockHeap or LinearPool as `HeapType` because the
// two expose the same reset(n, ef, block_size) / push_block(dists, ids, n)
// two expose the same reset(capacity, block_size) / push_block(dists, ids, n)
// surface (LinearPool adapts via push_block and ignores the block_size hint).
// ============================================================================

// mmap/contiguous variant: resolve vectors via get_vector_ptr
// and dispatch to the classic pointer-array batch_dist.
template <typename EntityType, typename HeapType>
template <typename EntityType, typename HeapType, typename VisitImpl>
void fast_greedy_search(const EntityType &entity, HeapType &pool,
VisitFilter &visit, VamanaDistCalculator &dc,
uint32_t topk, uint32_t ef, node_id_t entry_point,
uint32_t prefetch_lines, uint32_t prefetch_offset) {
VamanaDistCalculator &dc, uint32_t topk, uint32_t ef,
node_id_t entry_point, uint32_t prefetch_lines,
uint32_t prefetch_offset,
typename VisitImpl::Context *visit_ctx) {
const uint32_t max_deg = entity.max_degree();
const uint32_t cap = std::max(topk, ef);
pool.reset(static_cast<int32_t>(cap), static_cast<int32_t>(max_deg));
visit.clear();

VisitImpl::clear(visit_ctx);

dist_t ep_dist = dc.batch_dist(entry_point);
visit.set_visited(entry_point);
VisitImpl::set_visited(visit_ctx, entry_point);
pool.push_block(&ep_dist, &entry_point, 1);

uint32_t buf_capacity = max_deg;
Expand Down Expand Up @@ -161,8 +163,8 @@ void fast_greedy_search(const EntityType &entity, HeapType &pool,

for (; i < po; ++i) {
node_id_t node = neighbors[i];
if (visit.visited(node)) continue;
visit.set_visited(node);
if (VisitImpl::visited(visit_ctx, node)) continue;
VisitImpl::set_visited(visit_ctx, node);
const void *vec_ptr = entity.get_vector_ptr(node);
const char *p = reinterpret_cast<const char *>(vec_ptr);
for (uint32_t cl = 0; cl < prefetch_lines; ++cl) {
Expand All @@ -174,8 +176,8 @@ void fast_greedy_search(const EntityType &entity, HeapType &pool,
}
for (; i < neighbors.size(); ++i) {
node_id_t node = neighbors[i];
if (visit.visited(node)) continue;
visit.set_visited(node);
if (VisitImpl::visited(visit_ctx, node)) continue;
VisitImpl::set_visited(visit_ctx, node);
neighbor_ids[unvisited_count] = node;
neighbor_vecs[unvisited_count] = entity.get_vector_ptr(node);
unvisited_count++;
Expand All @@ -188,6 +190,50 @@ void fast_greedy_search(const EntityType &entity, HeapType &pool,
}
}

// Callable bundle that carries every argument of fast_greedy_search so the
// visited-tracking implementation (VisitImpl) and the heap type can both be
// chosen before the search loop is instantiated.
//
// The struct is brace-initialized by its caller, so the member order below is
// part of its interface and must not be rearranged.
template <typename EntityType>
struct VamanaFastGreedyRunner {
  const EntityType &entity;  // graph/vector storage handed to the search
  VamanaDistCalculator &dc;  // distance calculator handed to the search
  VamanaContext *ctx;        // source of block_pool(), pool() and po()
  uint32_t topk_v;           // requested top-k, forwarded unchanged
  uint32_t ef_v;             // ef value, forwarded unchanged
  node_id_t entry_point;     // node the search starts from
  uint32_t prefetch_lines;   // prefetch line count, forwarded unchanged
  bool avx2_ok;              // true -> BlockHeap, false -> LinearPool<float>
  TopkHeap &topk_heap;       // destination filled by copy_pool_to_topk

  // Entry point: hands this runner to dispatch_visit_filter.
  // NOTE(review): dispatch_visit_filter is defined elsewhere; it presumably
  // calls back operator()<VisitImpl> with the matching context -- confirm.
  void run(const VisitFilter &visit_filter) const {
    dispatch_visit_filter(visit_filter, *this);
  }

  // Callback for a concrete VisitImpl: selects the heap type from avx2_ok.
  template <typename VisitImpl>
  void operator()(typename VisitImpl::Context *visit_ctx) const {
    if (!avx2_ok) {
      run_with_heap<LinearPool<float>, VisitImpl>(visit_ctx);
      return;
    }
    run_with_heap<BlockHeap, VisitImpl>(visit_ctx);
  }

 private:
  // Runs the greedy search on the context pool that corresponds to HeapType,
  // then copies that pool's content into `topk_heap`.
  template <typename HeapType, typename VisitImpl>
  void run_with_heap(typename VisitImpl::Context *visit_ctx) const {
    // Shared tail of both heap variants; only the pool accessor differs.
    const auto search_and_collect = [&](auto &heap) {
      fast_greedy_search<EntityType, HeapType, VisitImpl>(
          entity, heap, dc, topk_v, ef_v, entry_point, prefetch_lines,
          ctx->po(), visit_ctx);
      copy_pool_to_topk(heap, topk_heap);
    };

    if constexpr (std::is_same_v<HeapType, BlockHeap>) {
      search_and_collect(ctx->block_pool());
    } else {
      search_and_collect(ctx->pool());
    }
  }
};

// ============================================================================
// dual_heap_greedy_search: shared core for the fallback dual-heap path.
//
Expand Down Expand Up @@ -322,8 +368,6 @@ void VamanaAlgorithm<EntityType>::greedy_search(node_id_t entry_point,
ctx->pl() > 0 ? ctx->pl() : (entity.vector_size() + 63) / 64;

if (!use_pool || index_filter.is_valid()) {
// Fallback path used by add_node (use_pool=false) and filtered search.
// Dispatched to dual_heap_greedy_search (plain batch_dist).
auto run_with_filter = [&](auto &&filter) {
dual_heap_greedy_search<EntityType, MemBlockType>(
entity, ctx, dc, entry_point, std::forward<decltype(filter)>(filter));
Expand All @@ -338,37 +382,21 @@ void VamanaAlgorithm<EntityType>::greedy_search(node_id_t entry_point,
auto filter = [](node_id_t) { return false; };
run_with_filter(filter);
}
} else if constexpr (std::is_same_v<MemBlockType, MmapMemoryBlock>) {
const uint32_t topk_v = static_cast<uint32_t>(ctx->topk());
const uint32_t ef_v = ctx->ef();
const bool avx2_ok =
zvec::ailego::internal::CpuFeatures::static_flags_.AVX2;
auto &topk_heap = ctx->topk_heap();

VamanaFastGreedyRunner<EntityType>{entity, dc, ctx,
topk_v, ef_v, entry_point,
prefetch_lines, avx2_ok, topk_heap}
.run(ctx->visit_filter());
} else {
// Fast pool-based path for mmap/contiguous entities that support
// direct pointer access. BlockHeap (AVX2) or LinearPool (scalar)
// are used for top-k tracking. BufferPool entities fall back to
// dual_heap_greedy_search since they lack direct pointer access.
if constexpr (std::is_same_v<MemBlockType, MmapMemoryBlock>) {
const uint32_t topk_v = static_cast<uint32_t>(ctx->topk());
const uint32_t ef_v = ctx->ef();
const bool avx2_ok =
zvec::ailego::internal::CpuFeatures::static_flags_.AVX2;
auto &topk_heap = ctx->topk_heap();

auto &visit = ctx->visit_filter();

if (avx2_ok) {
auto &bpool = ctx->block_pool();
fast_greedy_search(entity, bpool, visit, dc, topk_v, ef_v, entry_point,
prefetch_lines, ctx->po());
copy_pool_to_topk(bpool, topk_heap);
} else {
auto &lpool = ctx->pool();
fast_greedy_search(entity, lpool, visit, dc, topk_v, ef_v, entry_point,
prefetch_lines, ctx->po());
copy_pool_to_topk(lpool, topk_heap);
}
} else {
// BufferPool entities: fallback to dual-heap path.
auto filter = [](node_id_t) { return false; };
dual_heap_greedy_search<EntityType, MemBlockType>(entity, ctx, dc,
entry_point, filter);
}
auto filter = [](node_id_t) { return false; };
dual_heap_greedy_search<EntityType, MemBlockType>(entity, ctx, dc,
entry_point, filter);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithm/vamana/vamana_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ class VamanaContext : public IndexContext {
//! Cached build-time distance offset (see build_distance_offset()).
float build_distance_offset_{0.0f};

VisitFilter::Mode filter_mode_{VisitFilter::ByteMap};
VisitFilter::Mode filter_mode_{VisitFilter::BitMap};
float filter_negative_prob_{VamanaEntity::kDefaultBFNegativeProbability};

LinearPool<dist_t> pool_;
Expand Down
Loading
Loading