From 4c3295984bc8a978d427ba31d11dda4628a9a764 Mon Sep 17 00:00:00 2001 From: luoxiaojian Date: Tue, 16 Jun 2026 14:50:07 +0800 Subject: [PATCH 1/2] perf(hnsw,vamana): VisitImpl dispatch on fast pool search path Dispatch fast BlockHeap/LinearPool search through friend dispatch_visit_filter without exposing VisitFilter::context() or is_allocated(). Default streamer visit filter mode is BitMap. Co-authored-by: Cursor --- src/core/algorithm/hnsw/hnsw_algorithm.cc | 118 +++++++++++------- src/core/algorithm/hnsw/hnsw_context.h | 2 +- src/core/algorithm/vamana/vamana_algorithm.cc | 114 ++++++++++------- src/core/algorithm/vamana/vamana_context.h | 2 +- src/core/utility/visit_filter.h | 34 ++++- .../record_quantized_int8/common.h | 2 +- .../squared_euclidean.cc | 4 +- .../uniform_int8/squared_euclidean.cc | 4 +- 8 files changed, 184 insertions(+), 96 deletions(-) diff --git a/src/core/algorithm/hnsw/hnsw_algorithm.cc b/src/core/algorithm/hnsw/hnsw_algorithm.cc index 8c6fcfe17..604ce7877 100644 --- a/src/core/algorithm/hnsw/hnsw_algorithm.cc +++ b/src/core/algorithm/hnsw/hnsw_algorithm.cc @@ -185,20 +185,20 @@ void HnswAlgorithm::add_neighbors(node_id_t id, level_t level, // ============================================================================ // mmap/contiguous variant: resolve vectors via get_vector_ptr and use -// LinearPool or BlockHeap for visited tracking + top-k maintenance. -// HeapType must expose reset/set_visited/check_visited/push_block/has_next/pop. -template +// LinearPool or BlockHeap for top-k maintenance. +// VisitImpl supplies inlined visited tracking (VisitBitMap/VisitByteMap/...). +template void fast_search_neighbors(const EntityType &entity, HeapType &pool, - VisitFilter &visit, HnswDistCalculator &dc, - uint32_t topk, uint32_t ef, node_id_t entry_point, - dist_t entry_dist, uint32_t prefetch_lines, - uint32_t prefetch_offset) { + HnswDistCalculator &dc, uint32_t topk, uint32_t ef, + node_id_t entry_point, dist_t entry_dist, + uint32_t prefetch_lines, uint32_t prefetch_offset, + typename VisitImpl::Context *visit_ctx) { const uint32_t max_deg = entity.max_degree(0); // level 0 only const uint32_t cap = std::max(topk, ef); pool.reset(static_cast(cap), static_cast(max_deg)); - visit.clear(); - visit.set_visited(entry_point); + VisitImpl::clear(visit_ctx); + VisitImpl::set_visited(visit_ctx, entry_point); pool.push_block(&entry_dist, &entry_point, 1); uint32_t buf_capacity = max_deg; @@ -227,8 +227,8 @@ void fast_search_neighbors(const EntityType &entity, HeapType &pool, // Phase 1: scan first `po` neighbors with prefetch. for (; i < po; ++i) { node_id_t node = neighbors[i]; - if (visit.visited(node)) continue; - visit.set_visited(node); + if (VisitImpl::visited(visit_ctx, node)) continue; + VisitImpl::set_visited(visit_ctx, node); const void *vec_ptr = entity.get_vector_ptr(node); const char *p = reinterpret_cast(vec_ptr); for (uint32_t cl = 0; cl < prefetch_lines; ++cl) { @@ -242,8 +242,8 @@ void fast_search_neighbors(const EntityType &entity, HeapType &pool, // Phase 2: scan remaining neighbors. for (; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; - if (visit.visited(node)) continue; - visit.set_visited(node); + if (VisitImpl::visited(visit_ctx, node)) continue; + VisitImpl::set_visited(visit_ctx, node); neighbor_ids[unvisited_count] = node; neighbor_vecs[unvisited_count] = entity.get_vector_ptr(node); unvisited_count++; @@ -257,6 +257,51 @@ void fast_search_neighbors(const EntityType &entity, HeapType &pool, } } +template +struct HnswFastSearchRunner { + const EntityType &entity; + HnswDistCalculator &dc; + HnswContext *ctx; + node_id_t entry_point; + dist_t entry_dist; + uint32_t topk_v; + uint32_t ef_v; + uint32_t prefetch_lines; + bool avx2_ok; + TopkHeap &topk; + + void run(const VisitFilter &visit_filter) const { + dispatch_visit_filter(visit_filter, *this); + } + + template + void operator()(typename VisitImpl::Context *visit_ctx) const { + if (avx2_ok) { + run_with_heap(visit_ctx); + } else { + run_with_heap, VisitImpl>(visit_ctx); + } + } + + private: + template + void run_with_heap(typename VisitImpl::Context *visit_ctx) const { + if constexpr (std::is_same_v) { + auto &pool = ctx->block_pool(); + fast_search_neighbors( + entity, pool, dc, topk_v, ef_v, entry_point, entry_dist, + prefetch_lines, ctx->po(), visit_ctx); + copy_pool_to_topk(pool, topk); + } else { + auto &pool = ctx->pool(); + fast_search_neighbors, VisitImpl>( + entity, pool, dc, topk_v, ef_v, entry_point, entry_dist, + prefetch_lines, ctx->po(), visit_ctx); + copy_pool_to_topk(pool, topk); + } + } +}; + // ============================================================================ // dual_heap_search_neighbors: shared core for the fallback dual-heap path. // @@ -393,7 +438,6 @@ void HnswAlgorithm::search_neighbors(level_t level, HnswDistCalculator &dc = ctx->dist_calculator(); if (!use_pool || ctx->filter().is_valid() || level != 0) { - // Dual-heap path: add_node, filtered search, or upper-level scan. auto run_with_filter = [&](auto &&filter) { dual_heap_search_neighbors( entity, level, entry_point, dist, topk, ctx, dc, @@ -409,38 +453,22 @@ void HnswAlgorithm::search_neighbors(level_t level, auto filter = [](node_id_t) { return false; }; run_with_filter(filter); } + } else if constexpr (std::is_same_v) { + const uint32_t prefetch_lines = + ctx->pl() > 0 ? ctx->pl() : (entity.vector_size() + 63) / 64; + const uint32_t topk_v = static_cast(ctx->topk()); + const uint32_t ef_v = ctx->ef(); + const bool avx2_ok = + zvec::ailego::internal::CpuFeatures::static_flags_.AVX2; + + HnswFastSearchRunner{entity, dc, ctx, *entry_point, + *dist, topk_v, ef_v, prefetch_lines, + avx2_ok, topk} + .run(ctx->visit_filter()); } else { - // Pool-based path for level-0 unfiltered search. - if constexpr (std::is_same_v) { - const uint32_t prefetch_lines = - ctx->pl() > 0 ? ctx->pl() : (entity.vector_size() + 63) / 64; - - // Fast path: direct pointer access via get_vector_ptr. - // BlockHeap (AVX2) or LinearPool (scalar) for top-k tracking. - const uint32_t topk_v = static_cast(ctx->topk()); - const uint32_t ef_v = ctx->ef(); - const bool avx2_ok = - zvec::ailego::internal::CpuFeatures::static_flags_.AVX2; - - auto &visit = ctx->visit_filter(); - - if (avx2_ok) { - auto &bpool = ctx->block_pool(); - fast_search_neighbors(entity, bpool, visit, dc, topk_v, ef_v, - *entry_point, *dist, prefetch_lines, ctx->po()); - copy_pool_to_topk(bpool, topk); - } else { - auto &lpool = ctx->pool(); - fast_search_neighbors(entity, lpool, visit, dc, topk_v, ef_v, - *entry_point, *dist, prefetch_lines, ctx->po()); - copy_pool_to_topk(lpool, topk); - } - } else { - // BufferPool entities: fallback to dual-heap path. - auto filter = [](node_id_t) { return false; }; - dual_heap_search_neighbors( - entity, level, entry_point, dist, topk, ctx, dc, filter); - } + auto filter = [](node_id_t) { return false; }; + dual_heap_search_neighbors( + entity, level, entry_point, dist, topk, ctx, dc, filter); } } diff --git a/src/core/algorithm/hnsw/hnsw_context.h b/src/core/algorithm/hnsw/hnsw_context.h index bc97bbf6a..6d04cd41f 100644 --- a/src/core/algorithm/hnsw/hnsw_context.h +++ b/src/core/algorithm/hnsw/hnsw_context.h @@ -537,7 +537,7 @@ class HnswContext : public IndexContext { uint32_t reserve_max_doc_cnt_{kMinReserveDocCnt}; uint32_t topk_{0}; uint32_t group_topk_{0}; - uint32_t filter_mode_{VisitFilter::ByteMap}; + uint32_t filter_mode_{VisitFilter::BitMap}; float negative_probability_{HnswEntity::kDefaultBFNegativeProbability}; uint32_t ef_{HnswEntity::kDefaultEf}; uint32_t po_{8}; diff --git a/src/core/algorithm/vamana/vamana_algorithm.cc b/src/core/algorithm/vamana/vamana_algorithm.cc index bf0fe5e6e..db9d3857c 100644 --- a/src/core/algorithm/vamana/vamana_algorithm.cc +++ b/src/core/algorithm/vamana/vamana_algorithm.cc @@ -116,24 +116,26 @@ int VamanaAlgorithm::search(VamanaContext *ctx) const { // wrappers via get_vector_typed to pin pages. // // Both accept either BlockHeap or LinearPool as `HeapType` because the -// two expose the same reset(n, ef, block_size) / push_block(dists, ids, n) +// two expose the same reset(capacity, block_size) / push_block(dists, ids, n) // surface (LinearPool adapts via push_block and ignores the block_size hint). // ============================================================================ // mmap/contiguous variant: resolve vectors via get_vector_ptr // and dispatch to the classic pointer-array batch_dist. -template +template void fast_greedy_search(const EntityType &entity, HeapType &pool, - VisitFilter &visit, VamanaDistCalculator &dc, - uint32_t topk, uint32_t ef, node_id_t entry_point, - uint32_t prefetch_lines, uint32_t prefetch_offset) { + VamanaDistCalculator &dc, uint32_t topk, uint32_t ef, + node_id_t entry_point, uint32_t prefetch_lines, + uint32_t prefetch_offset, + typename VisitImpl::Context *visit_ctx) { const uint32_t max_deg = entity.max_degree(); const uint32_t cap = std::max(topk, ef); pool.reset(static_cast(cap), static_cast(max_deg)); - visit.clear(); + + VisitImpl::clear(visit_ctx); dist_t ep_dist = dc.batch_dist(entry_point); - visit.set_visited(entry_point); + VisitImpl::set_visited(visit_ctx, entry_point); pool.push_block(&ep_dist, &entry_point, 1); uint32_t buf_capacity = max_deg; @@ -161,8 +163,8 @@ void fast_greedy_search(const EntityType &entity, HeapType &pool, for (; i < po; ++i) { node_id_t node = neighbors[i]; - if (visit.visited(node)) continue; - visit.set_visited(node); + if (VisitImpl::visited(visit_ctx, node)) continue; + VisitImpl::set_visited(visit_ctx, node); const void *vec_ptr = entity.get_vector_ptr(node); const char *p = reinterpret_cast(vec_ptr); for (uint32_t cl = 0; cl < prefetch_lines; ++cl) { @@ -174,8 +176,8 @@ void fast_greedy_search(const EntityType &entity, HeapType &pool, } for (; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; - if (visit.visited(node)) continue; - visit.set_visited(node); + if (VisitImpl::visited(visit_ctx, node)) continue; + VisitImpl::set_visited(visit_ctx, node); neighbor_ids[unvisited_count] = node; neighbor_vecs[unvisited_count] = entity.get_vector_ptr(node); unvisited_count++; @@ -188,6 +190,50 @@ void fast_greedy_search(const EntityType &entity, HeapType &pool, } } +template +struct VamanaFastGreedyRunner { + const EntityType &entity; + VamanaDistCalculator &dc; + VamanaContext *ctx; + uint32_t topk_v; + uint32_t ef_v; + node_id_t entry_point; + uint32_t prefetch_lines; + bool avx2_ok; + TopkHeap &topk_heap; + + void run(const VisitFilter &visit_filter) const { + dispatch_visit_filter(visit_filter, *this); + } + + template + void operator()(typename VisitImpl::Context *visit_ctx) const { + if (avx2_ok) { + run_with_heap(visit_ctx); + } else { + run_with_heap, VisitImpl>(visit_ctx); + } + } + + private: + template + void run_with_heap(typename VisitImpl::Context *visit_ctx) const { + if constexpr (std::is_same_v) { + auto &pool = ctx->block_pool(); + fast_greedy_search( + entity, pool, dc, topk_v, ef_v, entry_point, prefetch_lines, + ctx->po(), visit_ctx); + copy_pool_to_topk(pool, topk_heap); + } else { + auto &pool = ctx->pool(); + fast_greedy_search, VisitImpl>( + entity, pool, dc, topk_v, ef_v, entry_point, prefetch_lines, + ctx->po(), visit_ctx); + copy_pool_to_topk(pool, topk_heap); + } + } +}; + // ============================================================================ // dual_heap_greedy_search: shared core for the fallback dual-heap path. // @@ -322,8 +368,6 @@ void VamanaAlgorithm::greedy_search(node_id_t entry_point, ctx->pl() > 0 ? ctx->pl() : (entity.vector_size() + 63) / 64; if (!use_pool || index_filter.is_valid()) { - // Fallback path used by add_node (use_pool=false) and filtered search. - // Dispatched to dual_heap_greedy_search (plain batch_dist). auto run_with_filter = [&](auto &&filter) { dual_heap_greedy_search( entity, ctx, dc, entry_point, std::forward(filter)); @@ -338,37 +382,21 @@ void VamanaAlgorithm::greedy_search(node_id_t entry_point, auto filter = [](node_id_t) { return false; }; run_with_filter(filter); } + } else if constexpr (std::is_same_v) { + const uint32_t topk_v = static_cast(ctx->topk()); + const uint32_t ef_v = ctx->ef(); + const bool avx2_ok = + zvec::ailego::internal::CpuFeatures::static_flags_.AVX2; + auto &topk_heap = ctx->topk_heap(); + + VamanaFastGreedyRunner{entity, dc, ctx, + topk_v, ef_v, entry_point, + prefetch_lines, avx2_ok, topk_heap} + .run(ctx->visit_filter()); } else { - // Fast pool-based path for mmap/contiguous entities that support - // direct pointer access. BlockHeap (AVX2) or LinearPool (scalar) - // are used for top-k tracking. BufferPool entities fall back to - // dual_heap_greedy_search since they lack direct pointer access. - if constexpr (std::is_same_v) { - const uint32_t topk_v = static_cast(ctx->topk()); - const uint32_t ef_v = ctx->ef(); - const bool avx2_ok = - zvec::ailego::internal::CpuFeatures::static_flags_.AVX2; - auto &topk_heap = ctx->topk_heap(); - - auto &visit = ctx->visit_filter(); - - if (avx2_ok) { - auto &bpool = ctx->block_pool(); - fast_greedy_search(entity, bpool, visit, dc, topk_v, ef_v, entry_point, - prefetch_lines, ctx->po()); - copy_pool_to_topk(bpool, topk_heap); - } else { - auto &lpool = ctx->pool(); - fast_greedy_search(entity, lpool, visit, dc, topk_v, ef_v, entry_point, - prefetch_lines, ctx->po()); - copy_pool_to_topk(lpool, topk_heap); - } - } else { - // BufferPool entities: fallback to dual-heap path. - auto filter = [](node_id_t) { return false; }; - dual_heap_greedy_search(entity, ctx, dc, - entry_point, filter); - } + auto filter = [](node_id_t) { return false; }; + dual_heap_greedy_search(entity, ctx, dc, + entry_point, filter); } } diff --git a/src/core/algorithm/vamana/vamana_context.h b/src/core/algorithm/vamana/vamana_context.h index a57635357..42357cc59 100644 --- a/src/core/algorithm/vamana/vamana_context.h +++ b/src/core/algorithm/vamana/vamana_context.h @@ -339,7 +339,7 @@ class VamanaContext : public IndexContext { //! Cached build-time distance offset (see build_distance_offset()). float build_distance_offset_{0.0f}; - VisitFilter::Mode filter_mode_{VisitFilter::ByteMap}; + VisitFilter::Mode filter_mode_{VisitFilter::BitMap}; float filter_negative_prob_{VamanaEntity::kDefaultBFNegativeProbability}; LinearPool pool_; diff --git a/src/core/utility/visit_filter.h b/src/core/utility/visit_filter.h index 959a0b450..68366db27 100644 --- a/src/core/utility/visit_filter.h +++ b/src/core/utility/visit_filter.h @@ -358,6 +358,13 @@ class VisitByteMap { // visit list will be called with high frequency, // so using switch instead of std::function or virtual class // funtion point, lambda, virtual class all cannot be inlined +class VisitFilter; + +// Dispatch once at the call site; `fn` must provide +// template void operator()(VisitImpl::Context*) const +template +void dispatch_visit_filter(const VisitFilter &visit_filter, Fn &&fn); + class VisitFilter { public: enum Mode { @@ -415,8 +422,10 @@ class VisitFilter { return mode_; } - private: + template + friend void dispatch_visit_filter(const VisitFilter &visit_filter, Fn &&fn); + VisitFilter(const VisitFilter &) = delete; VisitFilter &operator=(const VisitFilter &) = delete; @@ -424,5 +433,28 @@ class VisitFilter { void *ctx_{nullptr}; }; +template +void dispatch_visit_filter(const VisitFilter &visit_filter, Fn &&fn) { + if (visit_filter.ctx_ == nullptr) { + return; + } + switch (visit_filter.mode_) { + case VisitBloomFilter::mode: + std::forward(fn).template operator()( + static_cast(visit_filter.ctx_)); + break; + case VisitBitMap::mode: + std::forward(fn).template operator()( + static_cast(visit_filter.ctx_)); + break; + case VisitByteMap::mode: + std::forward(fn).template operator()( + static_cast(visit_filter.ctx_)); + break; + default: + break; + } +} + } // namespace core } // namespace zvec diff --git a/src/turbo/avx512_vnni/record_quantized_int8/common.h b/src/turbo/avx512_vnni/record_quantized_int8/common.h index e72c8f014..1fdde05c3 100644 --- a/src/turbo/avx512_vnni/record_quantized_int8/common.h +++ b/src/turbo/avx512_vnni/record_quantized_int8/common.h @@ -288,7 +288,7 @@ static ailego_force_inline void ip_int8_batch_avx512_vnni( const void *const *vectors, const void *query, size_t n, size_t dim, float *distances) { static constexpr size_t batch_size = 2; - static constexpr size_t prefetch_step = 2; + const size_t prefetch_step = dim > 256 ? 2 : 4; size_t i = 0; for (; i + batch_size <= n; i += batch_size) { std::array prefetch_ptrs; diff --git a/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc b/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc index 16590af10..8678263aa 100644 --- a/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc +++ b/src/turbo/avx512_vnni/record_quantized_int8/squared_euclidean.cc @@ -82,8 +82,8 @@ void squared_euclidean_int8_batch_distance(const void *const *vectors, if (original_dim <= 0) { return; } - static constexpr size_t batch_size = 12; - static constexpr size_t prefetch_step = 2; + static constexpr size_t batch_size = 2; + const size_t prefetch_step = original_dim > 256 ? 2 : 4; size_t i = 0; float *dist_ptr = distances; const int8_t *const *data_ptrs_ptr = diff --git a/src/turbo/avx512_vnni/uniform_int8/squared_euclidean.cc b/src/turbo/avx512_vnni/uniform_int8/squared_euclidean.cc index 1d6c0a0f4..f1d786b18 100644 --- a/src/turbo/avx512_vnni/uniform_int8/squared_euclidean.cc +++ b/src/turbo/avx512_vnni/uniform_int8/squared_euclidean.cc @@ -180,8 +180,8 @@ void uniform_squared_euclidean_int8_batch_distance(const void *const *vectors, const void *query, size_t n, size_t dim, float *distances) { - static constexpr size_t batch_size = 4; - static constexpr size_t prefetch_step = 2; + static constexpr size_t batch_size = 2; + const size_t prefetch_step = dim > 256 ? 2 : 4; size_t i = 0; for (; i + batch_size <= n; i += batch_size) { From 70a0f9dc65ebc29239932745847bb33e6d47c03e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E7=AE=80?= Date: Wed, 17 Jun 2026 14:51:20 +0800 Subject: [PATCH 2/2] lint --- src/core/utility/visit_filter.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/utility/visit_filter.h b/src/core/utility/visit_filter.h index 68366db27..a17b24740 100644 --- a/src/core/utility/visit_filter.h +++ b/src/core/utility/visit_filter.h @@ -44,7 +44,7 @@ class VisitBloomFilter { static constexpr int N = 5; struct Context { Context() - : mt(std::chrono::system_clock::now().time_since_epoch().count()) {}; + : mt(std::chrono::system_clock::now().time_since_epoch().count()) {} VisitFilterHeader h; std::mt19937 mt; ailego::BloomFilter *filter{nullptr}; @@ -374,7 +374,7 @@ class VisitFilter { ByteMap = VisitByteMap::mode }; - VisitFilter() : mode_(0), ctx_(nullptr) {}; + VisitFilter() : mode_(0), ctx_(nullptr) {} inline bool visited(id_t idx) { PROXIMA_HNSW_VISITFILTER_CALL_IMPL(visited, idx);