From 9fd71ef43dead5378f6fbcbe60e8adf6a50633d4 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 18 Mar 2026 14:09:31 -0700 Subject: [PATCH 1/6] add CAGRA ANN to replace brute force kmeans for prediction + test cases to find inflection point --- cpp/bench/ann/CMakeLists.txt | 27 ++ .../src/cuvs/cuvs_cluster_assignment_bench.cu | 278 ++++++++++++++++++ .../cluster/detail/kmeans_predict_cagra.cuh | 113 +++++++ cpp/src/cluster/kmeans_balanced.cuh | 40 +++ cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 75 ++++- 5 files changed, 519 insertions(+), 14 deletions(-) create mode 100644 cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu create mode 100644 cpp/src/cluster/detail/kmeans_predict_cagra.cuh diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 4e4527267c..3823c6b0a7 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -255,6 +255,33 @@ if(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE) ) endif() +# Cluster assignment benchmark: brute force vs CAGRA for assigning vectors to clusters (IVF training) +if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) + add_executable(CUVS_CLUSTER_ASSIGNMENT_BENCH src/cuvs/cuvs_cluster_assignment_bench.cu) + target_link_libraries( + CUVS_CLUSTER_ASSIGNMENT_BENCH + PRIVATE cuvs + benchmark::benchmark + $<$:CUDA::nvtx3> + $ + ) + target_include_directories( + CUVS_CLUSTER_ASSIGNMENT_BENCH + PUBLIC "$" + "$" + PRIVATE "$" + ) + set_target_properties( + CUVS_CLUSTER_ASSIGNMENT_BENCH + PROPERTIES CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) + install(TARGETS CUVS_CLUSTER_ASSIGNMENT_BENCH COMPONENT ann_bench DESTINATION bin/ann) + add_dependencies(CUVS_ANN_BENCH_ALL CUVS_CLUSTER_ASSIGNMENT_BENCH) +endif() + if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) ConfigureAnnBench( NAME diff --git a/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu new file mode 100644 index 0000000000..253edc4c24 --- /dev/null +++ 
b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu @@ -0,0 +1,278 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + * Benchmark: brute force vs CAGRA-based cluster assignment for IVF training. + * Compares time to assign N vectors to K clusters (nearest centroid) using + * (1) brute force 1-NN and (2) CAGRA build on centroids + k=1 search. + */ +#include + +// kmeans_balanced.cuh lives in src/cluster/, not in public include/ +#include "cluster/kmeans_balanced.cuh" +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +using namespace cuvs::cluster::kmeans_balanced; + +void init_random_data(raft::resources const& handle, + float* X, + int64_t n_rows, + int64_t dim, + float* centroids, + int64_t n_clusters) +{ + raft::random::RngState rng(12345ULL); + raft::random::uniform(handle, rng, X, n_rows * dim, float(-1), float(1)); + raft::random::uniform(handle, rng, centroids, n_clusters * dim, float(-1), float(1)); + raft::resource::sync_stream(handle); +} + +} // namespace + +static void BM_ClusterAssignment_BruteForce(benchmark::State& state) +{ + int64_t n_rows = static_cast(state.range(0)); + int64_t n_clusters = static_cast(state.range(1)); + int64_t dim = static_cast(state.range(2)); + + raft::device_resources handle; + rmm::device_uvector X(static_cast(n_rows) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector centroids(static_cast(n_clusters) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector labels(static_cast(n_rows), + raft::resource::get_cuda_stream(handle)); + + init_random_data(handle, X.data(), n_rows, dim, centroids.data(), n_clusters); + + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + + auto X_view = raft::make_device_matrix_view(X.data(), n_rows, dim); + auto centers_view = + 
raft::make_device_matrix_view(centroids.data(), n_clusters, dim); + auto labels_view = raft::make_device_vector_view(labels.data(), n_rows); + + for (auto _ : state) { + predict(handle, params, X_view, centers_view, labels_view); + raft::resource::sync_stream(handle); + } + state.SetItemsProcessed(state.iterations() * n_rows); +} + +static void BM_ClusterAssignment_CAGRA(benchmark::State& state) +{ + int64_t n_rows = static_cast(state.range(0)); + int64_t n_clusters = static_cast(state.range(1)); + int64_t dim = static_cast(state.range(2)); + + raft::device_resources handle; + rmm::device_uvector X(static_cast(n_rows) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector centroids(static_cast(n_clusters) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector labels(static_cast(n_rows), + raft::resource::get_cuda_stream(handle)); + + init_random_data(handle, X.data(), n_rows, dim, centroids.data(), n_clusters); + + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + + auto X_view = raft::make_device_matrix_view(X.data(), n_rows, dim); + auto centers_view = + raft::make_device_matrix_view(centroids.data(), n_clusters, dim); + auto labels_view = raft::make_device_vector_view(labels.data(), n_rows); + + for (auto _ : state) { + predict_cagra(handle, params, X_view, centers_view, labels_view); + raft::resource::sync_stream(handle); + } + state.SetItemsProcessed(state.iterations() * n_rows); +} + +// N = vectors to assign, K = number of clusters, D = dimension +// Small: 10K vectors, 1K clusters, 128 dim +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({10000, 1000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({10000, 1000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Medium: 100K vectors, 4K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({100000, 4000, 
128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({100000, 4000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Large K: 100K vectors, 16K clusters (brute force starts to hurt) +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({100000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({100000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Very large K: 500K vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({500000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({500000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Larger N: amortize CAGRA build over more queries +// 1M vectors, 4K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 4000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 4000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 16K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 2M vectors, 16K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({2000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({2000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 2M 
vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({2000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({2000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 5M vectors, 16K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({5000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({5000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 5M vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({5000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({5000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Hundreds of thousands of centroids (K = 100K, 200K, 500K, 1M) +// 1M vectors, 100K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 2M vectors, 100K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({2000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({2000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 200K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 200000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 200000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 500K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 500000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 500000, 
128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 500K vectors, 1M clusters (very large K) +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({500000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({500000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 1M clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +BENCHMARK_MAIN(); diff --git a/cpp/src/cluster/detail/kmeans_predict_cagra.cuh b/cpp/src/cluster/detail/kmeans_predict_cagra.cuh new file mode 100644 index 0000000000..8d2d5879e6 --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_predict_cagra.cuh @@ -0,0 +1,113 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + * Cluster assignment via CAGRA: assign each data point to nearest centroid using an + * approximate nearest neighbor search (CAGRA) over the centroids instead of brute force. + * Used for scaling IVF training when the number of clusters K is very large. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cuvs::cluster::kmeans::detail { + +/** + * @brief Assign each row in X to the nearest centroid using CAGRA (1-NN search over centroids). + * + * Builds a CAGRA index on the centroids and runs k=1 search with X as queries. The returned + * neighbor indices are the cluster labels. This is approximate and faster than brute force + * when the number of clusters K is large. + * + * Supports the same metrics as CAGRA: L2Expanded, L2SqrtExpanded, InnerProduct, CosineExpanded. + * Centroids and (after mapping) query data must be float. 
+ * + * @param[in] handle RAFT resources + * @param[in] params Balanced params (metric used for assignment) + * @param[in] X Data to assign [n_rows, dim] + * @param[in] centroids Cluster centers [n_clusters, dim] (device, row-major) + * @param[out] labels Output cluster index per row [n_rows] + * @param[in] mapping_op Optional mapping from DataT to float (e.g. for quantized input) + */ +template +std::enable_if_t> predict_cagra( + raft::resources const& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + MappingOpT mapping_op = raft::identity_op()) +{ + using namespace cuvs::neighbors::cagra; + + RAFT_EXPECTS(X.extent(0) == labels.extent(0), "X rows and labels size must match"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "X dim and centroids dim must match"); + RAFT_EXPECTS(centroids.extent(0) >= 1, "Need at least one centroid"); + + auto stream = raft::resource::get_cuda_stream(handle); + int64_t n_rows = static_cast(X.extent(0)); + int64_t n_clusters = static_cast(centroids.extent(0)); + int64_t dim = static_cast(centroids.extent(1)); + + // CAGRA graph degree cannot exceed n_clusters - 1 + size_t graph_degree = std::min(64, std::max(1, n_clusters - 1)); + size_t inter_degree = std::min(128, std::max(1, n_clusters - 1)); + + index_params build_params; + build_params.metric = params.metric; + build_params.graph_degree = graph_degree; + build_params.intermediate_graph_degree = inter_degree; + build_params.attach_dataset_on_build = true; + + // Build CAGRA index on centroids (centroids are [n_clusters, dim]) + auto centers_view = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, dim); + auto cagra_index = cuvs::neighbors::cagra::build(handle, build_params, centers_view); + + // Queries: convert X to float if needed + rmm::device_uvector queries_buf(0, stream); + raft::device_matrix_view queries_view(nullptr, 0, 0); + if 
constexpr (std::is_same_v) { + queries_view = raft::make_device_matrix_view( + reinterpret_cast(X.data_handle()), n_rows, dim); + } else { + queries_buf.resize(static_cast(n_rows) * static_cast(dim), stream); + auto queries_mat = raft::make_device_matrix_view( + queries_buf.data(), n_rows, dim); + raft::linalg::map(handle, raft::make_const_mdspan(X), queries_mat, mapping_op); + queries_view = raft::make_device_matrix_view( + queries_buf.data(), n_rows, dim); + } + + // Search k=1 + search_params search_params; + search_params.max_queries = 0; // auto + + auto neighbors = raft::make_device_matrix(handle, n_rows, 1); + auto distances = raft::make_device_matrix(handle, n_rows, 1); + + cuvs::neighbors::cagra::search( + handle, search_params, cagra_index, queries_view, neighbors.view(), distances.view()); + + // Copy neighbor indices (column 0) to labels with cast to LabelT + auto neighbors_col = raft::make_device_vector_view( + neighbors.data_handle(), n_rows); + auto labels_view = raft::make_device_vector_view(labels.data_handle(), n_rows); + raft::linalg::map( + handle, raft::make_const_mdspan(neighbors_col), labels_view, raft::cast_op()); +} + +} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 0c0df03397..d36890e8cc 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -7,6 +7,7 @@ #include "../neighbors/detail/ann_utils.cuh" #include "detail/kmeans_balanced.cuh" +#include "detail/kmeans_predict_cagra.cuh" #include #include #include @@ -157,6 +158,45 @@ void predict(const raft::resources& handle, raft::resource::get_workspace_resource(handle)); } +/** + * @brief Assign each sample to nearest centroid using CAGRA (ANN-based, approximate). + * + * Same contract as predict() but builds a CAGRA index on centroids and runs k=1 search. + * Faster than brute force when the number of clusters K is large; results are approximate. 
+ * Only supported when centroids are float and metric is L2Expanded, L2SqrtExpanded, + * InnerProduct, or CosineExpanded. + * + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters (metric) + * @param[in] X Dataset for which to infer the closest clusters + * @param[in] centroids The input centroids [dim = n_clusters x n_features] + * @param[out] labels The output labels [dim = n_samples] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to float + */ +template +void predict_cagra(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + MappingOpT mapping_op = raft::identity_op()) +{ + RAFT_EXPECTS(X.extent(0) == labels.extent(0), + "Number of rows in dataset and labels are different"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) <= + static_cast(std::numeric_limits::max()), + "The chosen label type cannot represent all cluster labels"); + + cuvs::cluster::kmeans::detail::predict_cagra( + handle, params, X, centroids, labels, mapping_op); +} + namespace helpers { /** diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index b2da2bb821..521b8c112e 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -22,6 +22,10 @@ // TODO (cjnolet): This should be using an exposed API instead of circumventing the public APIs. #include "../../cluster/kmeans_balanced.cuh" #include +#include + +// Use CAGRA for cluster assignment in extend when n_lists >= this (faster for large K). 
+constexpr uint32_t kUseAnnForClusterAssignmentMinClusters = 4096; #include #include @@ -1111,21 +1115,64 @@ void extend(raft::resources const& handle, cudaMemcpyDefault, stream)); vec_batches.prefetch_next_batch(); - for (const auto& batch : vec_batches) { - auto batch_data_view = raft::make_device_matrix_view( - batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_data_labels.data() + batch.offset(), batch.size()); - auto centers_view = raft::make_device_matrix_view( + + cuvs::cluster::kmeans::balanced_params kmeans_params; + kmeans_params.metric = index->metric(); + + if (n_clusters >= kUseAnnForClusterAssignmentMinClusters) { + // Use CAGRA for cluster assignment when K is large (build once, search per batch). + auto centers_view = raft::make_device_matrix_view( cluster_centers.data(), n_clusters, index->dim()); - cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.metric = index->metric(); - cuvs::cluster::kmeans::predict( - handle, kmeans_params, batch_data_view, centers_view, batch_labels_view); - vec_batches.prefetch_next_batch(); - // User needs to make sure kernel finishes its work before we overwrite batch in the next - // iteration if different streams are used for kernel and copy. 
- raft::resource::sync_stream(handle); + cuvs::neighbors::cagra::index_params cagra_params; + cagra_params.metric = index->metric(); + cagra_params.graph_degree = + std::min(64, std::max(1, static_cast(n_clusters) - 1)); + cagra_params.intermediate_graph_degree = + std::min(128, std::max(1, static_cast(n_clusters) - 1)); + cagra_params.attach_dataset_on_build = true; + auto cagra_idx = cuvs::neighbors::cagra::build(handle, cagra_params, centers_view); + cuvs::neighbors::cagra::search_params search_params; + + for (const auto& batch : vec_batches) { + auto batch_size = batch.size(); + rmm::device_uvector queries_float( + static_cast(batch_size) * static_cast(index->dim()), stream, + device_memory); + auto batch_view = raft::make_device_matrix_view( + batch.data(), batch_size, index->dim()); + raft::linalg::map(handle, + raft::make_const_mdspan(batch_view), + raft::make_device_matrix_view(queries_float.data(), + batch_size, + index->dim()), + utils::mapping{}); + auto queries_view = raft::make_device_matrix_view( + queries_float.data(), batch_size, index->dim()); + auto neighbors = raft::make_device_matrix(handle, batch_size, 1); + auto distances = raft::make_device_matrix(handle, batch_size, 1); + cuvs::neighbors::cagra::search( + handle, search_params, cagra_idx, queries_view, neighbors.view(), distances.view()); + raft::copy(handle, + raft::make_device_vector_view(new_data_labels.data() + batch.offset(), + batch_size), + raft::make_device_vector_view(neighbors.data_handle(), + batch_size)); + vec_batches.prefetch_next_batch(); + raft::resource::sync_stream(handle); + } + } else { + for (const auto& batch : vec_batches) { + auto batch_data_view = raft::make_device_matrix_view( + batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( + new_data_labels.data() + batch.offset(), batch.size()); + auto centers_view = raft::make_device_matrix_view( + cluster_centers.data(), n_clusters, index->dim()); + 
cuvs::cluster::kmeans::predict( + handle, kmeans_params, batch_data_view, centers_view, batch_labels_view); + vec_batches.prefetch_next_batch(); + raft::resource::sync_stream(handle); + } } } From 80070c4f920f3781219712867713b4c3fc66e64a Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 18 Mar 2026 15:00:03 -0700 Subject: [PATCH 2/6] fix test cases + fix threshold --- .../ann/src/cuvs/cuvs_cluster_assignment_bench.cu | 15 ++++++++------- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 3 ++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu index 253edc4c24..6b5cc4312a 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu @@ -245,7 +245,7 @@ BENCHMARK(BM_ClusterAssignment_CAGRA) ->Unit(benchmark::kMillisecond) ->UseRealTime(); -// 1M vectors, 500K clusters +// 1M vectors, 500K clusters (~2 vectors per cluster) BENCHMARK(BM_ClusterAssignment_BruteForce) ->Args({1000000, 500000, 128}) ->Unit(benchmark::kMillisecond) @@ -255,23 +255,24 @@ BENCHMARK(BM_ClusterAssignment_CAGRA) ->Unit(benchmark::kMillisecond) ->UseRealTime(); -// 500K vectors, 1M clusters (very large K) +// 1M clusters with N > K (realistic: many vectors per cluster) +// 2M vectors, 1M clusters (~2 per cluster) BENCHMARK(BM_ClusterAssignment_BruteForce) - ->Args({500000, 1000000, 128}) + ->Args({2000000, 1000000, 128}) ->Unit(benchmark::kMillisecond) ->UseRealTime(); BENCHMARK(BM_ClusterAssignment_CAGRA) - ->Args({500000, 1000000, 128}) + ->Args({2000000, 1000000, 128}) ->Unit(benchmark::kMillisecond) ->UseRealTime(); -// 1M vectors, 1M clusters +// 5M vectors, 1M clusters (~5 per cluster) BENCHMARK(BM_ClusterAssignment_BruteForce) - ->Args({1000000, 1000000, 128}) + ->Args({5000000, 1000000, 128}) ->Unit(benchmark::kMillisecond) ->UseRealTime(); BENCHMARK(BM_ClusterAssignment_CAGRA) - ->Args({1000000, 
1000000, 128}) + ->Args({5000000, 1000000, 128}) ->Unit(benchmark::kMillisecond) ->UseRealTime(); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 521b8c112e..b910970025 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -25,7 +25,8 @@ #include // Use CAGRA for cluster assignment in extend when n_lists >= this (faster for large K). -constexpr uint32_t kUseAnnForClusterAssignmentMinClusters = 4096; +// Set from cluster-assignment benchmark: CAGRA wins over brute force for K >= ~200K (e.g. N=1M). +constexpr uint32_t kUseAnnForClusterAssignmentMinClusters = 200000; #include #include From fa6fd704dc82437dbc4c09548592795f3ebc25db Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 18 Mar 2026 16:55:10 -0700 Subject: [PATCH 3/6] add test cases for entire IVF-PQ pipeline --- cpp/bench/ann/CMakeLists.txt | 21 +++ .../ann/src/cuvs/cuvs_ivf_pq_build_bench.cu | 149 ++++++++++++++++++ cpp/include/cuvs/neighbors/ivf_pq.hpp | 13 ++ cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 13 +- cpp/src/neighbors/ivf_pq_impl.hpp | 11 +- cpp/src/neighbors/ivf_pq_index.cu | 28 +++- 6 files changed, 223 insertions(+), 12 deletions(-) create mode 100644 cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 3823c6b0a7..b74679507e 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -280,6 +280,27 @@ if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) ) install(TARGETS CUVS_CLUSTER_ASSIGNMENT_BENCH COMPONENT ann_bench DESTINATION bin/ann) add_dependencies(CUVS_ANN_BENCH_ALL CUVS_CLUSTER_ASSIGNMENT_BENCH) + + # Full IVF-PQ build benchmark: times the full build with brute-force vs CAGRA cluster assignment (each path forced via index_params::use_ann_for_cluster_assignment) + add_executable(CUVS_IVFPQ_BUILD_BENCH src/cuvs/cuvs_ivf_pq_build_bench.cu) + target_link_libraries( + CUVS_IVFPQ_BUILD_BENCH + PRIVATE cuvs benchmark::benchmark + ) + target_include_directories( +
CUVS_IVFPQ_BUILD_BENCH + PUBLIC "$" + "$" + ) + set_target_properties( + CUVS_IVFPQ_BUILD_BENCH + PROPERTIES CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) + install(TARGETS CUVS_IVFPQ_BUILD_BENCH COMPONENT ann_bench DESTINATION bin/ann) + add_dependencies(CUVS_ANN_BENCH_ALL CUVS_IVFPQ_BUILD_BENCH) endif() if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu new file mode 100644 index 0000000000..ed57a9dae2 --- /dev/null +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu @@ -0,0 +1,149 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + * Benchmark: full IVF-PQ build path and compare brute-force vs CAGRA cluster assignment. + * For each scenario we run one benchmark that does both a brute and a CAGRA full build, + * reports brute_ms and cagra_ms, and speedup = brute_ms / cagra_ms (>1 means CAGRA faster). + * + * n_lists = number of cluster centroids. We use at least 5 vectors per cluster + * (n_vectors >= 5 * n_lists). "Time" = wall time for one iteration (brute + CAGRA build). + * + * Full build includes kmeans, PQ codebook training, and assignment. Threshold 200K in + * ivf_pq_build.cuh was chosen from assignment-only benchmarks. 
+ */ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +void init_random_dataset(raft::resources const& handle, + float* data, + int64_t n_rows, + int64_t dim) +{ + raft::random::RngState rng(12345ULL); + raft::random::uniform(handle, rng, data, n_rows * dim, float(-1), float(1)); + raft::resource::sync_stream(handle); +} + +} // namespace + +static void BM_IVFPQ_Build_Speedup(benchmark::State& state) +{ + int64_t n_rows = static_cast(state.range(0)); + uint32_t n_lists = static_cast(state.range(1)); + int64_t dim = static_cast(state.range(2)); + + raft::device_resources handle; + rmm::device_uvector dataset(static_cast(n_rows) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + init_random_dataset(handle, dataset.data(), n_rows, dim); + + cuvs::neighbors::ivf_pq::index_params params_brute; + params_brute.n_lists = n_lists; + params_brute.kmeans_n_iters = 3; + params_brute.kmeans_trainset_fraction = 0.2; + params_brute.add_data_on_build = true; + params_brute.metric = cuvs::distance::DistanceType::L2Expanded; + params_brute.use_ann_for_cluster_assignment = false; + + cuvs::neighbors::ivf_pq::index_params params_cagra = params_brute; + params_cagra.use_ann_for_cluster_assignment = true; + + raft::resource::set_cuda_stream_pool(handle, std::make_shared(1)); + + auto dataset_view = raft::make_device_matrix_view( + dataset.data(), n_rows, dim); + + double total_brute_ms = 0.0, total_cagra_ms = 0.0; + int64_t iterations = 0; + + for (auto _ : state) { + auto start = std::chrono::steady_clock::now(); + auto idx_brute = cuvs::neighbors::ivf_pq::build(handle, params_brute, dataset_view); + benchmark::DoNotOptimize(idx_brute.size()); + raft::resource::sync_stream(handle); + auto end = std::chrono::steady_clock::now(); + total_brute_ms += 1e-6 * std::chrono::duration(end - start).count(); + + start = std::chrono::steady_clock::now(); + auto idx_cagra = 
cuvs::neighbors::ivf_pq::build(handle, params_cagra, dataset_view); + benchmark::DoNotOptimize(idx_cagra.size()); + raft::resource::sync_stream(handle); + end = std::chrono::steady_clock::now(); + total_cagra_ms += 1e-6 * std::chrono::duration(end - start).count(); + + ++iterations; + } + + if (total_cagra_ms > 0) { + state.counters["speedup"] = total_brute_ms / total_cagra_ms; + } + state.counters["brute_ms"] = benchmark::Counter(total_brute_ms, + benchmark::Counter::kAvgIterations); + state.counters["cagra_ms"] = benchmark::Counter(total_cagra_ms, + benchmark::Counter::kAvgIterations); +} + +constexpr int64_t kDim = 128; + +// At least 5 vectors per cluster (n_vectors = 5 * n_lists). One row per config: brute_ms, cagra_ms, speedup. + +// 1. 64K centroids, 5 vecs/cluster +BENCHMARK(BM_IVFPQ_Build_Speedup) + ->Args({327680, 65536, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); + +// 2. 200K centroids, 5 vecs/cluster +BENCHMARK(BM_IVFPQ_Build_Speedup) + ->Args({1000000, 200000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); + +// 3. 400K centroids, 5 vecs/cluster +BENCHMARK(BM_IVFPQ_Build_Speedup) + ->Args({2000000, 400000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); + +// 4. 600K centroids, 5 vecs/cluster +BENCHMARK(BM_IVFPQ_Build_Speedup) + ->Args({3000000, 600000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); + +// 5. 800K centroids, 5 vecs/cluster +BENCHMARK(BM_IVFPQ_Build_Speedup) + ->Args({4000000, 800000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); + +// 6. 
1M centroids, 5 vecs/cluster +BENCHMARK(BM_IVFPQ_Build_Speedup) + ->Args({5000000, 1000000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); + +BENCHMARK_MAIN(); diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 710a08cd0c..e40f73ec0f 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -128,6 +128,14 @@ struct index_params : cuvs::neighbors::index_params { */ uint32_t max_train_points_per_pq_code = 256; + /** + * Override for cluster assignment during extend (and add_data_on_build). + * - std::nullopt (default): use heuristic (CAGRA when n_lists >= 200k, else brute force). + * - true: use CAGRA for assignment regardless of n_lists (for benchmarking). + * - false: use brute-force assignment regardless of n_lists (for benchmarking). + */ + std::optional use_ann_for_cluster_assignment; + /** * Creates index_params based on shape of the input dataset. * Usage example: @@ -418,6 +426,9 @@ class index_iface { const raft::resources& res) const = 0; virtual raft::device_matrix_view centers_half( const raft::resources& res) const = 0; + + /** When set, overrides heuristic for using CAGRA vs brute force in extend (cluster assignment). */ + virtual std::optional use_ann_for_cluster_assignment() const = 0; }; /** @@ -653,6 +664,8 @@ class index : public index_iface, cuvs::neighbors::index { */ uint32_t get_list_size_in_bytes(uint32_t label) const override; + std::optional use_ann_for_cluster_assignment() const noexcept override; + /** * @brief Construct index from implementation pointer. 
* diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index b910970025..d9d94399f2 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -963,7 +963,8 @@ auto clone(const raft::resources& res, const index& source) -> index source.pq_bits(), source.pq_dim(), source.conservative_memory_allocation(), - source.codes_layout()); + source.codes_layout(), + source.use_ann_for_cluster_assignment()); // Copy the independent parts using mutable accessors raft::copy(res, impl->list_sizes(), source.list_sizes()); @@ -1120,7 +1121,9 @@ void extend(raft::resources const& handle, cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); - if (n_clusters >= kUseAnnForClusterAssignmentMinClusters) { + bool use_cagra = index->use_ann_for_cluster_assignment().value_or( + n_clusters >= kUseAnnForClusterAssignmentMinClusters); + if (use_cagra) { // Use CAGRA for cluster assignment when K is large (build once, search per batch). auto centers_view = raft::make_device_matrix_view( cluster_centers.data(), n_clusters, index->dim()); @@ -1300,7 +1303,8 @@ auto build(raft::resources const& handle, params.pq_bits, params.pq_dim == 0 ? 
index::calculate_pq_dim(dim) : params.pq_dim, params.conservative_memory_allocation, - params.codes_layout); + params.codes_layout, + params.use_ann_for_cluster_assignment); auto stream = raft::resource::get_cuda_stream(handle); utils::memzero( @@ -1608,7 +1612,8 @@ auto build( index_params.pq_bits, pq_dim, index_params.conservative_memory_allocation, - index_params.codes_layout); + index_params.codes_layout, + index_params.use_ann_for_cluster_assignment); utils::memzero( impl->accum_sorted_sizes().data_handle(), impl->accum_sorted_sizes().size(), stream); diff --git a/cpp/src/neighbors/ivf_pq_impl.hpp b/cpp/src/neighbors/ivf_pq_impl.hpp index 5c96755808..18b86894e6 100644 --- a/cpp/src/neighbors/ivf_pq_impl.hpp +++ b/cpp/src/neighbors/ivf_pq_impl.hpp @@ -20,7 +20,8 @@ class index_impl : public index_iface { uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout = list_layout::INTERLEAVED); + list_layout codes_layout = list_layout::INTERLEAVED, + std::optional use_ann_for_cluster_assignment = std::nullopt); ~index_impl() = default; index_impl(index_impl&&) = default; @@ -71,9 +72,11 @@ class index_impl : public index_iface { const raft::resources& res) const override; uint32_t get_list_size_in_bytes(uint32_t label) const override; + std::optional use_ann_for_cluster_assignment() const noexcept override; protected: cuvs::distance::DistanceType metric_; + std::optional use_ann_for_cluster_assignment_; codebook_gen codebook_kind_; list_layout codes_layout_; uint32_t dim_; @@ -113,7 +116,8 @@ class owning_impl : public index_impl { uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout = list_layout::INTERLEAVED); + list_layout codes_layout = list_layout::INTERLEAVED, + std::optional use_ann_for_cluster_assignment = std::nullopt); ~owning_impl() = default; owning_impl(owning_impl&&) = default; @@ -159,7 +163,8 @@ class view_impl : public index_impl { raft::device_matrix_view 
centers_view, raft::device_matrix_view centers_rot_view, raft::device_matrix_view rotation_matrix_view, - list_layout codes_layout = list_layout::INTERLEAVED); + list_layout codes_layout = list_layout::INTERLEAVED, + std::optional use_ann_for_cluster_assignment = std::nullopt); ~view_impl() = default; view_impl(view_impl&&) = default; diff --git a/cpp/src/neighbors/ivf_pq_index.cu b/cpp/src/neighbors/ivf_pq_index.cu index fe230cd8aa..e05f43526d 100644 --- a/cpp/src/neighbors/ivf_pq_index.cu +++ b/cpp/src/neighbors/ivf_pq_index.cu @@ -25,7 +25,8 @@ index_impl::index_impl(raft::resources const& handle, uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout) + list_layout codes_layout, + std::optional use_ann_for_cluster_assignment) : metric_(metric), codebook_kind_(codebook_kind), codes_layout_(codes_layout), @@ -33,6 +34,7 @@ index_impl::index_impl(raft::resources const& handle, pq_bits_(pq_bits), pq_dim_(pq_dim == 0 ? index::calculate_pq_dim(dim) : pq_dim), conservative_memory_allocation_(conservative_memory_allocation), + use_ann_for_cluster_assignment_(use_ann_for_cluster_assignment), lists_(n_lists), list_sizes_{raft::make_device_vector(handle, n_lists)}, data_ptrs_{raft::make_device_vector(handle, n_lists)}, @@ -121,6 +123,12 @@ bool index_impl::conservative_memory_allocation() const noexcept return conservative_memory_allocation_; } +template +std::optional index_impl::use_ann_for_cluster_assignment() const noexcept +{ + return use_ann_for_cluster_assignment_; +} + template std::vector>>& index_impl::lists() noexcept { @@ -197,7 +205,8 @@ owning_impl::owning_impl(raft::resources const& handle, uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout) + list_layout codes_layout, + std::optional use_ann_for_cluster_assignment) : index_impl(handle, metric, codebook_kind, @@ -206,7 +215,8 @@ owning_impl::owning_impl(raft::resources const& handle, pq_bits, pq_dim, 
conservative_memory_allocation, - codes_layout), + codes_layout, + use_ann_for_cluster_assignment), pq_centers_{raft::make_device_mdarray( handle, index::make_pq_centers_extents(dim, pq_dim, pq_bits, codebook_kind, n_lists))}, centers_{ @@ -247,7 +257,8 @@ view_impl::view_impl( raft::device_matrix_view centers_view, raft::device_matrix_view centers_rot_view, raft::device_matrix_view rotation_matrix_view, - list_layout codes_layout) + list_layout codes_layout, + std::optional use_ann_for_cluster_assignment) : index_impl(handle, metric, codebook_kind, @@ -256,7 +267,8 @@ view_impl::view_impl( pq_bits, pq_dim, conservative_memory_allocation, - codes_layout), + codes_layout, + use_ann_for_cluster_assignment), pq_centers_view_(pq_centers_view), centers_view_(centers_view), centers_rot_view_(centers_rot_view), @@ -594,6 +606,12 @@ uint32_t index::get_list_size_in_bytes(uint32_t label) const return impl_->get_list_size_in_bytes(label); } +template +std::optional index::use_ann_for_cluster_assignment() const noexcept +{ + return impl_->use_ann_for_cluster_assignment(); +} + template void index_impl::check_consistency() { From d8543a6c76b77433a0df242f20f93204eccf6bc2 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 18 Mar 2026 17:27:03 -0700 Subject: [PATCH 4/6] rename variable to use_ann_for_extend since previous changes are only applied to accelerate extend() --- .../ann/src/cuvs/cuvs_ivf_pq_build_bench.cu | 4 ++-- cpp/include/cuvs/neighbors/ivf_pq.hpp | 8 ++++---- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 8 ++++---- cpp/src/neighbors/ivf_pq_impl.hpp | 10 +++++----- cpp/src/neighbors/ivf_pq_index.cu | 20 +++++++++---------- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu index ed57a9dae2..0eb23e32b5 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu @@ -60,10 +60,10 @@ static void 
BM_IVFPQ_Build_Speedup(benchmark::State& state) params_brute.kmeans_trainset_fraction = 0.2; params_brute.add_data_on_build = true; params_brute.metric = cuvs::distance::DistanceType::L2Expanded; - params_brute.use_ann_for_cluster_assignment = false; + params_brute.use_ann_for_extend = false; cuvs::neighbors::ivf_pq::index_params params_cagra = params_brute; - params_cagra.use_ann_for_cluster_assignment = true; + params_cagra.use_ann_for_extend = true; raft::resource::set_cuda_stream_pool(handle, std::make_shared(1)); diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index e40f73ec0f..16d5839d3b 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -129,12 +129,12 @@ struct index_params : cuvs::neighbors::index_params { uint32_t max_train_points_per_pq_code = 256; /** - * Override for cluster assignment during extend (and add_data_on_build). + * Override for cluster assignment during extend() (and add_data_on_build). * - std::nullopt (default): use heuristic (CAGRA when n_lists >= 200k, else brute force). * - true: use CAGRA for assignment regardless of n_lists (for benchmarking). * - false: use brute-force assignment regardless of n_lists (for benchmarking). */ - std::optional use_ann_for_cluster_assignment; + std::optional use_ann_for_extend; /** * Creates index_params based on shape of the input dataset. @@ -428,7 +428,7 @@ class index_iface { const raft::resources& res) const = 0; /** When set, overrides heuristic for using CAGRA vs brute force in extend (cluster assignment). 
*/ - virtual std::optional use_ann_for_cluster_assignment() const = 0; + virtual std::optional use_ann_for_extend() const = 0; }; /** @@ -664,7 +664,7 @@ class index : public index_iface, cuvs::neighbors::index { */ uint32_t get_list_size_in_bytes(uint32_t label) const override; - std::optional use_ann_for_cluster_assignment() const noexcept override; + std::optional use_ann_for_extend() const noexcept override; /** * @brief Construct index from implementation pointer. diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index d9d94399f2..bd49868b81 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -964,7 +964,7 @@ auto clone(const raft::resources& res, const index& source) -> index source.pq_dim(), source.conservative_memory_allocation(), source.codes_layout(), - source.use_ann_for_cluster_assignment()); + source.use_ann_for_extend()); // Copy the independent parts using mutable accessors raft::copy(res, impl->list_sizes(), source.list_sizes()); @@ -1121,7 +1121,7 @@ void extend(raft::resources const& handle, cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); - bool use_cagra = index->use_ann_for_cluster_assignment().value_or( + bool use_cagra = index->use_ann_for_extend().value_or( n_clusters >= kUseAnnForClusterAssignmentMinClusters); if (use_cagra) { // Use CAGRA for cluster assignment when K is large (build once, search per batch). @@ -1304,7 +1304,7 @@ auto build(raft::resources const& handle, params.pq_dim == 0 ? 
index::calculate_pq_dim(dim) : params.pq_dim, params.conservative_memory_allocation, params.codes_layout, - params.use_ann_for_cluster_assignment); + params.use_ann_for_extend); auto stream = raft::resource::get_cuda_stream(handle); utils::memzero( @@ -1613,7 +1613,7 @@ auto build( pq_dim, index_params.conservative_memory_allocation, index_params.codes_layout, - index_params.use_ann_for_cluster_assignment); + index_params.use_ann_for_extend); utils::memzero( impl->accum_sorted_sizes().data_handle(), impl->accum_sorted_sizes().size(), stream); diff --git a/cpp/src/neighbors/ivf_pq_impl.hpp b/cpp/src/neighbors/ivf_pq_impl.hpp index 18b86894e6..8c44442be0 100644 --- a/cpp/src/neighbors/ivf_pq_impl.hpp +++ b/cpp/src/neighbors/ivf_pq_impl.hpp @@ -21,7 +21,7 @@ class index_impl : public index_iface { uint32_t pq_dim, bool conservative_memory_allocation, list_layout codes_layout = list_layout::INTERLEAVED, - std::optional use_ann_for_cluster_assignment = std::nullopt); + std::optional use_ann_for_extend = std::nullopt); ~index_impl() = default; index_impl(index_impl&&) = default; @@ -72,11 +72,11 @@ class index_impl : public index_iface { const raft::resources& res) const override; uint32_t get_list_size_in_bytes(uint32_t label) const override; - std::optional use_ann_for_cluster_assignment() const noexcept override; + std::optional use_ann_for_extend() const noexcept override; protected: cuvs::distance::DistanceType metric_; - std::optional use_ann_for_cluster_assignment_; + std::optional use_ann_for_extend_; codebook_gen codebook_kind_; list_layout codes_layout_; uint32_t dim_; @@ -117,7 +117,7 @@ class owning_impl : public index_impl { uint32_t pq_dim, bool conservative_memory_allocation, list_layout codes_layout = list_layout::INTERLEAVED, - std::optional use_ann_for_cluster_assignment = std::nullopt); + std::optional use_ann_for_extend = std::nullopt); ~owning_impl() = default; owning_impl(owning_impl&&) = default; @@ -164,7 +164,7 @@ class view_impl : public 
index_impl { raft::device_matrix_view centers_rot_view, raft::device_matrix_view rotation_matrix_view, list_layout codes_layout = list_layout::INTERLEAVED, - std::optional use_ann_for_cluster_assignment = std::nullopt); + std::optional use_ann_for_extend = std::nullopt); ~view_impl() = default; view_impl(view_impl&&) = default; diff --git a/cpp/src/neighbors/ivf_pq_index.cu b/cpp/src/neighbors/ivf_pq_index.cu index e05f43526d..6939f6e3b2 100644 --- a/cpp/src/neighbors/ivf_pq_index.cu +++ b/cpp/src/neighbors/ivf_pq_index.cu @@ -26,7 +26,7 @@ index_impl::index_impl(raft::resources const& handle, uint32_t pq_dim, bool conservative_memory_allocation, list_layout codes_layout, - std::optional use_ann_for_cluster_assignment) + std::optional use_ann_for_extend) : metric_(metric), codebook_kind_(codebook_kind), codes_layout_(codes_layout), @@ -34,7 +34,7 @@ index_impl::index_impl(raft::resources const& handle, pq_bits_(pq_bits), pq_dim_(pq_dim == 0 ? index::calculate_pq_dim(dim) : pq_dim), conservative_memory_allocation_(conservative_memory_allocation), - use_ann_for_cluster_assignment_(use_ann_for_cluster_assignment), + use_ann_for_extend_(use_ann_for_extend), lists_(n_lists), list_sizes_{raft::make_device_vector(handle, n_lists)}, data_ptrs_{raft::make_device_vector(handle, n_lists)}, @@ -124,9 +124,9 @@ bool index_impl::conservative_memory_allocation() const noexcept } template -std::optional index_impl::use_ann_for_cluster_assignment() const noexcept +std::optional index_impl::use_ann_for_extend() const noexcept { - return use_ann_for_cluster_assignment_; + return use_ann_for_extend_; } template @@ -206,7 +206,7 @@ owning_impl::owning_impl(raft::resources const& handle, uint32_t pq_dim, bool conservative_memory_allocation, list_layout codes_layout, - std::optional use_ann_for_cluster_assignment) + std::optional use_ann_for_extend) : index_impl(handle, metric, codebook_kind, @@ -216,7 +216,7 @@ owning_impl::owning_impl(raft::resources const& handle, pq_dim, 
conservative_memory_allocation, codes_layout, - use_ann_for_cluster_assignment), + use_ann_for_extend), pq_centers_{raft::make_device_mdarray( handle, index::make_pq_centers_extents(dim, pq_dim, pq_bits, codebook_kind, n_lists))}, centers_{ @@ -258,7 +258,7 @@ view_impl::view_impl( raft::device_matrix_view centers_rot_view, raft::device_matrix_view rotation_matrix_view, list_layout codes_layout, - std::optional use_ann_for_cluster_assignment) + std::optional use_ann_for_extend) : index_impl(handle, metric, codebook_kind, @@ -268,7 +268,7 @@ view_impl::view_impl( pq_dim, conservative_memory_allocation, codes_layout, - use_ann_for_cluster_assignment), + use_ann_for_extend), pq_centers_view_(pq_centers_view), centers_view_(centers_view), centers_rot_view_(centers_rot_view), @@ -607,9 +607,9 @@ uint32_t index::get_list_size_in_bytes(uint32_t label) const } template -std::optional index::use_ann_for_cluster_assignment() const noexcept +std::optional index::use_ann_for_extend() const noexcept { - return impl_->use_ann_for_cluster_assignment(); + return impl_->use_ann_for_extend(); } template From d26e91658ca2136adb320c03bdc183bcb868a34d Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 19 Mar 2026 19:18:32 -0700 Subject: [PATCH 5/6] Replace brute force with CAGRA ANN in build() --- cpp/bench/ann/CMakeLists.txt | 2 +- .../src/cuvs/cuvs_cluster_assignment_bench.cu | 22 +- .../ann/src/cuvs/cuvs_ivf_pq_build_bench.cu | 208 +++++++++++++----- cpp/include/cuvs/cluster/kmeans.hpp | 15 ++ cpp/include/cuvs/neighbors/ivf_pq.hpp | 18 +- cpp/src/cluster/detail/kmeans_balanced.cuh | 44 +++- .../cluster/detail/kmeans_predict_cagra.cuh | 176 +++++++++------ cpp/src/cluster/kmeans_balanced.cuh | 40 ---- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 55 +++-- 9 files changed, 367 insertions(+), 213 deletions(-) diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index b74679507e..4208bb3247 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ 
b/cpp/bench/ann/CMakeLists.txt @@ -281,7 +281,7 @@ if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) install(TARGETS CUVS_CLUSTER_ASSIGNMENT_BENCH COMPONENT ann_bench DESTINATION bin/ann) add_dependencies(CUVS_ANN_BENCH_ALL CUVS_CLUSTER_ASSIGNMENT_BENCH) - # Full IVF-PQ build benchmark: validates extend() uses CAGRA when n_lists >= 200k + # IVF-PQ build benchmarks: k-means fit and extend cluster-assignment (brute vs CAGRA) add_executable(CUVS_IVFPQ_BUILD_BENCH src/cuvs/cuvs_ivf_pq_build_bench.cu) target_link_libraries( CUVS_IVFPQ_BUILD_BENCH diff --git a/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu index 6b5cc4312a..e4e358a4dc 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu @@ -19,8 +19,11 @@ #include #include +#include #include +#include + namespace { using namespace cuvs::cluster::kmeans_balanced; @@ -90,13 +93,22 @@ static void BM_ClusterAssignment_CAGRA(benchmark::State& state) cuvs::cluster::kmeans::balanced_params params; params.metric = cuvs::distance::DistanceType::L2Expanded; - auto X_view = raft::make_device_matrix_view(X.data(), n_rows, dim); - auto centers_view = - raft::make_device_matrix_view(centroids.data(), n_clusters, dim); - auto labels_view = raft::make_device_vector_view(labels.data(), n_rows); + // Same timing as former predict_cagra: rebuild CAGRA on centroids + 1-NN each iteration + // (rebuild=true). float X/centroids only; predict_cagra_with_index_reuse matches that path. 
+ std::optional> cagra_index_opt; for (auto _ : state) { - predict_cagra(handle, params, X_view, centers_view, labels_view); + cuvs::cluster::kmeans::detail::predict_cagra_with_index_reuse( + handle, + params, + centroids.data(), + n_clusters, + dim, + X.data(), + n_rows, + labels.data(), + &cagra_index_opt, + true); raft::resource::sync_stream(handle); } state.SetItemsProcessed(state.iterations() * n_rows); diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu index 0eb23e32b5..1c77d7d36d 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu @@ -2,15 +2,15 @@ * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 * - * Benchmark: full IVF-PQ build path and compare brute-force vs CAGRA cluster assignment. - * For each scenario we run one benchmark that does both a brute and a CAGRA full build, - * reports brute_ms and cagra_ms, and speedup = brute_ms / cagra_ms (>1 means CAGRA faster). + * IVF-PQ benchmarks (two scenarios): * - * n_lists = number of cluster centroids. We use at least 5 vectors per cluster - * (n_vectors >= 5 * n_lists). "Time" = wall time for one iteration (brute + CAGRA build). + * 1) Full build — k-means *fit* ANN vs brute (shifting centroids). Both runs use brute cluster + * assignment for extend/add_data_on_build (`use_ann_for_extend = false`) so the comparison + * isolates the balanced k-means training path, not extend-time assignment. * - * Full build includes kmeans, PQ codebook training, and assignment. Threshold 200K in - * ivf_pq_build.cuh was chosen from assignment-only benchmarks. + * 2) Extend only — brute vs CAGRA for assigning vectors to *fixed* trained centroids. Empty + * trained indices are restored each iteration via deserialize (setup not timed). This matches + * the assumption that centroids do not move during extend, unlike k-means fit. 
*/ #include @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include #include @@ -41,9 +43,29 @@ void init_random_dataset(raft::resources const& handle, raft::resource::sync_stream(handle); } +/** Serialize an empty trained index to a string for benchmark reset. */ +std::string serialize_index_blob(raft::resources const& handle, + cuvs::neighbors::ivf_pq::index const& index) +{ + std::ostringstream os(std::ios::binary); + cuvs::neighbors::ivf_pq::serialize(handle, os, index); + raft::resource::sync_stream(handle); + return os.str(); +} + +void deserialize_index_from_blob(raft::resources const& handle, + std::string const& blob, + cuvs::neighbors::ivf_pq::index* index) +{ + std::istringstream is(blob, std::ios::binary); + cuvs::neighbors::ivf_pq::deserialize(handle, is, index); + raft::resource::sync_stream(handle); +} + } // namespace -static void BM_IVFPQ_Build_Speedup(benchmark::State& state) +/** Full IVF-PQ build: compare brute vs ANN for balanced k-means *fit* only. 
*/ +static void BM_IVFPQ_Build_KMeansFit_Speedup(benchmark::State& state) { int64_t n_rows = static_cast(state.range(0)); uint32_t n_lists = static_cast(state.range(1)); @@ -54,93 +76,181 @@ static void BM_IVFPQ_Build_Speedup(benchmark::State& state) raft::resource::get_cuda_stream(handle)); init_random_dataset(handle, dataset.data(), n_rows, dim); - cuvs::neighbors::ivf_pq::index_params params_brute; - params_brute.n_lists = n_lists; - params_brute.kmeans_n_iters = 3; - params_brute.kmeans_trainset_fraction = 0.2; - params_brute.add_data_on_build = true; - params_brute.metric = cuvs::distance::DistanceType::L2Expanded; - params_brute.use_ann_for_extend = false; + cuvs::neighbors::ivf_pq::index_params params_bf_fit; + params_bf_fit.n_lists = n_lists; + params_bf_fit.kmeans_n_iters = 3; + params_bf_fit.kmeans_trainset_fraction = 0.2; + params_bf_fit.add_data_on_build = true; + params_bf_fit.metric = cuvs::distance::DistanceType::L2Expanded; + params_bf_fit.use_ann_for_extend = false; + params_bf_fit.use_ann_for_fit = false; - cuvs::neighbors::ivf_pq::index_params params_cagra = params_brute; - params_cagra.use_ann_for_extend = true; + cuvs::neighbors::ivf_pq::index_params params_ann_fit = params_bf_fit; + params_ann_fit.use_ann_for_fit = true; raft::resource::set_cuda_stream_pool(handle, std::make_shared(1)); auto dataset_view = raft::make_device_matrix_view( dataset.data(), n_rows, dim); - double total_brute_ms = 0.0, total_cagra_ms = 0.0; - int64_t iterations = 0; + double total_bf_ms = 0.0, total_ann_ms = 0.0; for (auto _ : state) { auto start = std::chrono::steady_clock::now(); - auto idx_brute = cuvs::neighbors::ivf_pq::build(handle, params_brute, dataset_view); - benchmark::DoNotOptimize(idx_brute.size()); + auto idx_bf = cuvs::neighbors::ivf_pq::build(handle, params_bf_fit, dataset_view); + benchmark::DoNotOptimize(idx_bf.size()); raft::resource::sync_stream(handle); auto end = std::chrono::steady_clock::now(); - total_brute_ms += 1e-6 * 
std::chrono::duration(end - start).count(); + total_bf_ms += 1e-6 * std::chrono::duration(end - start).count(); start = std::chrono::steady_clock::now(); - auto idx_cagra = cuvs::neighbors::ivf_pq::build(handle, params_cagra, dataset_view); - benchmark::DoNotOptimize(idx_cagra.size()); + auto idx_ann = cuvs::neighbors::ivf_pq::build(handle, params_ann_fit, dataset_view); + benchmark::DoNotOptimize(idx_ann.size()); raft::resource::sync_stream(handle); end = std::chrono::steady_clock::now(); - total_cagra_ms += 1e-6 * std::chrono::duration(end - start).count(); + total_ann_ms += 1e-6 * std::chrono::duration(end - start).count(); + } + + if (total_ann_ms > 0) { + state.counters["speedup_fit"] = total_bf_ms / total_ann_ms; + } + state.counters["bf_fit_ms"] = + benchmark::Counter(total_bf_ms, benchmark::Counter::kAvgIterations); + state.counters["ann_fit_ms"] = + benchmark::Counter(total_ann_ms, benchmark::Counter::kAvgIterations); +} + +/** + * extend() only: brute vs CAGRA cluster assignment with fixed centroids (trained empty index). + * Deserialize is not timed. 
+ */ +static void BM_IVFPQ_Extend_ClusterAssign_Speedup(benchmark::State& state) +{ + int64_t n_rows = static_cast(state.range(0)); + uint32_t n_lists = static_cast(state.range(1)); + int64_t dim = static_cast(state.range(2)); + + raft::device_resources handle; + rmm::device_uvector dataset(static_cast(n_rows) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + init_random_dataset(handle, dataset.data(), n_rows, dim); + + cuvs::neighbors::ivf_pq::index_params params_common; + params_common.n_lists = n_lists; + params_common.kmeans_n_iters = 3; + params_common.kmeans_trainset_fraction = 0.2; + params_common.add_data_on_build = false; + params_common.metric = cuvs::distance::DistanceType::L2Expanded; + params_common.use_ann_for_fit = false; + + cuvs::neighbors::ivf_pq::index_params params_bf_ext = params_common; + params_bf_ext.use_ann_for_extend = false; + cuvs::neighbors::ivf_pq::index_params params_cagra_ext = params_common; + params_cagra_ext.use_ann_for_extend = true; + + raft::resource::set_cuda_stream_pool(handle, std::make_shared(1)); + + auto dataset_view = raft::make_device_matrix_view( + dataset.data(), n_rows, dim); + + auto idx_bf = cuvs::neighbors::ivf_pq::build(handle, params_bf_ext, dataset_view); + auto idx_cag = cuvs::neighbors::ivf_pq::build(handle, params_cagra_ext, dataset_view); + std::string blob_bf = serialize_index_blob(handle, idx_bf); + std::string blob_cag = serialize_index_blob(handle, idx_cag); + + double total_bf_ms = 0.0, total_cagra_ms = 0.0; + + for (auto _ : state) { + state.PauseTiming(); + cuvs::neighbors::ivf_pq::index empty_bf(handle); + cuvs::neighbors::ivf_pq::index empty_cag(handle); + deserialize_index_from_blob(handle, blob_bf, &empty_bf); + deserialize_index_from_blob(handle, blob_cag, &empty_cag); + state.ResumeTiming(); + + auto start = std::chrono::steady_clock::now(); + cuvs::neighbors::ivf_pq::extend(handle, dataset_view, std::nullopt, &empty_bf); + raft::resource::sync_stream(handle); + auto end = 
std::chrono::steady_clock::now(); + total_bf_ms += 1e-6 * std::chrono::duration(end - start).count(); - ++iterations; + start = std::chrono::steady_clock::now(); + cuvs::neighbors::ivf_pq::extend(handle, dataset_view, std::nullopt, &empty_cag); + raft::resource::sync_stream(handle); + end = std::chrono::steady_clock::now(); + total_cagra_ms += 1e-6 * std::chrono::duration(end - start).count(); } if (total_cagra_ms > 0) { - state.counters["speedup"] = total_brute_ms / total_cagra_ms; + state.counters["speedup_extend"] = total_bf_ms / total_cagra_ms; } - state.counters["brute_ms"] = benchmark::Counter(total_brute_ms, - benchmark::Counter::kAvgIterations); - state.counters["cagra_ms"] = benchmark::Counter(total_cagra_ms, - benchmark::Counter::kAvgIterations); + state.counters["bf_extend_ms"] = + benchmark::Counter(total_bf_ms, benchmark::Counter::kAvgIterations); + state.counters["cagra_extend_ms"] = + benchmark::Counter(total_cagra_ms, benchmark::Counter::kAvgIterations); } constexpr int64_t kDim = 128; -// At least 5 vectors per cluster (n_vectors = 5 * n_lists). One row per config: brute_ms, cagra_ms, speedup. - -// 1. 64K centroids, 5 vecs/cluster -BENCHMARK(BM_IVFPQ_Build_Speedup) +// Full build: k-means fit brute vs ANN (same problem sizes as before). +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) ->Args({327680, 65536, kDim}) ->Unit(benchmark::kMillisecond) ->UseRealTime() ->ArgNames({"n_vectors", "n_lists", "dim"}); - -// 2. 200K centroids, 5 vecs/cluster -BENCHMARK(BM_IVFPQ_Build_Speedup) +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) ->Args({1000000, 200000, kDim}) ->Unit(benchmark::kMillisecond) ->UseRealTime() ->ArgNames({"n_vectors", "n_lists", "dim"}); - -// 3. 400K centroids, 5 vecs/cluster -BENCHMARK(BM_IVFPQ_Build_Speedup) +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) ->Args({2000000, 400000, kDim}) ->Unit(benchmark::kMillisecond) ->UseRealTime() ->ArgNames({"n_vectors", "n_lists", "dim"}); - -// 4. 
600K centroids, 5 vecs/cluster -BENCHMARK(BM_IVFPQ_Build_Speedup) +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) ->Args({3000000, 600000, kDim}) ->Unit(benchmark::kMillisecond) ->UseRealTime() ->ArgNames({"n_vectors", "n_lists", "dim"}); - -// 5. 800K centroids, 5 vecs/cluster -BENCHMARK(BM_IVFPQ_Build_Speedup) +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) ->Args({4000000, 800000, kDim}) ->Unit(benchmark::kMillisecond) ->UseRealTime() ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) + ->Args({5000000, 1000000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); -// 6. 1M centroids, 5 vecs/cluster -BENCHMARK(BM_IVFPQ_Build_Speedup) +// extend(): fixed-centroid assignment brute vs CAGRA. +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) + ->Args({327680, 65536, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) + ->Args({1000000, 200000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) + ->Args({2000000, 400000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) + ->Args({3000000, 600000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) + ->Args({4000000, 800000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) ->Args({5000000, 1000000, kDim}) ->Unit(benchmark::kMillisecond) ->UseRealTime() diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index a839cecf56..9259bb3736 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ 
b/cpp/include/cuvs/cluster/kmeans.hpp @@ -125,6 +125,21 @@ struct balanced_params : base_params { * Number of training iterations */ uint32_t n_iters = 20; + + /** + * Use approximate nearest neighbor (CAGRA) for cluster assignment during k-means fit. + * When true and n_clusters is large, assignment uses a CAGRA index over centroids to speed + * up the E-step. The index is rebuilt every `ann_rebuild_interval` iterations to limit + * rebuild cost (batched centroid updates). + */ + bool use_ann_for_fit = false; + + /** + * Rebuild the ANN index used for fit assignment every this many iterations (when + * use_ann_for_fit is true). Larger values reduce index build cost but use staler + * centroids for assignment in between rebuilds. + */ + uint32_t ann_rebuild_interval = 3; }; /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 16d5839d3b..4190eb8666 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -129,13 +129,21 @@ struct index_params : cuvs::neighbors::index_params { uint32_t max_train_points_per_pq_code = 256; /** - * Override for cluster assignment during extend() (and add_data_on_build). - * - std::nullopt (default): use heuristic (CAGRA when n_lists >= 200k, else brute force). - * - true: use CAGRA for assignment regardless of n_lists (for benchmarking). - * - false: use brute-force assignment regardless of n_lists (for benchmarking). + * Use CAGRA vs brute force for cluster assignment during extend() (and add_data_on_build). + * - std::nullopt (default): brute-force assignment. + * - true: use CAGRA for assignment. + * - false: brute-force assignment (explicit). */ std::optional use_ann_for_extend; + /** + * Use CAGRA vs brute force for balanced k-means *fit* (training centroids; they move each EM iter). + * - std::nullopt (default): brute-force assignment during fit. + * - true: ANN-based assignment during fit (CAGRA with periodic index rebuilds). 
+ * - false: brute-force assignment during fit (explicit). + */ + std::optional use_ann_for_fit; + /** * Creates index_params based on shape of the input dataset. * Usage example: @@ -427,7 +435,7 @@ class index_iface { virtual raft::device_matrix_view centers_half( const raft::resources& res) const = 0; - /** When set, overrides heuristic for using CAGRA vs brute force in extend (cluster assignment). */ + /** Stored IVF-PQ build param: CAGRA vs brute for extend/add_data cluster assignment (nullopt => brute). */ virtual std::optional use_ann_for_extend() const = 0; }; diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index f5dc759725..78f7dd480d 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -7,7 +7,9 @@ #include "../../distance/fused_distance_nn.cuh" #include "kmeans_common.cuh" +#include "kmeans_predict_cagra.cuh" #include +#include #include "../../core/nvtx.hpp" #include "../../distance/distance.cuh" @@ -696,6 +698,7 @@ void balancing_em_iters(const raft::resources& handle, { auto stream = raft::resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; + std::optional> cagra_index_opt; for (uint32_t iter = 0; iter < n_iters; iter++) { // Balancing step - move the centers around to equalize cluster sizes // (but not on the first iteration) @@ -731,18 +734,35 @@ void balancing_em_iters(const raft::resources& handle, } default: break; } - // E: Expectation step - predict labels - predict(handle, - params, - cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - mapping_op, - device_memory, - dataset_norm); + // E: Expectation step - predict labels (optionally via ANN with batched index rebuild) + const bool use_ann = params.use_ann_for_fit && + n_clusters >= cuvs::cluster::kmeans::detail::kMinClustersForAnnFit && + std::is_same_v && std::is_same_v; + if (use_ann) { + bool rebuild = (iter % 
params.ann_rebuild_interval == 0); + cuvs::cluster::kmeans::detail::predict_cagra_with_index_reuse(handle, + params, + cluster_centers, + n_clusters, + dim, + reinterpret_cast(dataset), + n_rows, + cluster_labels, + &cagra_index_opt, + rebuild); + } else { + predict(handle, + params, + cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + mapping_op, + device_memory, + dataset_norm); + } // M: Maximization step - calculate optimal cluster centers calc_centers_and_sizes(handle, cluster_centers, diff --git a/cpp/src/cluster/detail/kmeans_predict_cagra.cuh b/cpp/src/cluster/detail/kmeans_predict_cagra.cuh index 8d2d5879e6..31a27da330 100644 --- a/cpp/src/cluster/detail/kmeans_predict_cagra.cuh +++ b/cpp/src/cluster/detail/kmeans_predict_cagra.cuh @@ -5,11 +5,17 @@ * Cluster assignment via CAGRA: assign each data point to nearest centroid using an * approximate nearest neighbor search (CAGRA) over the centroids instead of brute force. * Used for scaling IVF training when the number of clusters K is very large. + * + * Shared helpers (build index on centroids, 1-NN search -> labels) are used by: + * - predict_cagra_with_index_reuse (k-means fit: optional index reuse; pass rebuild=true every call + * for one-shot assign on float data, same work as a former predict_cagra path) + * - ivf_pq extend (batched queries: build once + search_cagra_1nn per batch) */ #pragma once #include #include +#include #include #include #include @@ -17,97 +23,123 @@ #include #include -#include +#include namespace cuvs::cluster::kmeans::detail { +/** Default search params for 1-NN centroid assignment (max_queries = auto). */ +inline cuvs::neighbors::cagra::search_params default_cagra_centroid_search_params() +{ + cuvs::neighbors::cagra::search_params p; + p.max_queries = 0; // auto + return p; +} + /** - * @brief Assign each row in X to the nearest centroid using CAGRA (1-NN search over centroids). 
- * - * Builds a CAGRA index on the centroids and runs k=1 search with X as queries. The returned - * neighbor indices are the cluster labels. This is approximate and faster than brute force - * when the number of clusters K is large. - * - * Supports the same metrics as CAGRA: L2Expanded, L2SqrtExpanded, InnerProduct, CosineExpanded. - * Centroids and (after mapping) query data must be float. - * - * @param[in] handle RAFT resources - * @param[in] params Balanced params (metric used for assignment) - * @param[in] X Data to assign [n_rows, dim] - * @param[in] centroids Cluster centers [n_clusters, dim] (device, row-major) - * @param[out] labels Output cluster index per row [n_rows] - * @param[in] mapping_op Optional mapping from DataT to float (e.g. for quantized input) + * @brief Build a CAGRA index on centroid vectors (shared by extend, search_cagra_1nn callers, and + * predict_cagra_with_index_reuse). */ -template -std::enable_if_t> predict_cagra( +inline cuvs::neighbors::cagra::index build_cagra_index_for_centroids( raft::resources const& handle, cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) + raft::device_matrix_view centroids) { using namespace cuvs::neighbors::cagra; - - RAFT_EXPECTS(X.extent(0) == labels.extent(0), "X rows and labels size must match"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "X dim and centroids dim must match"); - RAFT_EXPECTS(centroids.extent(0) >= 1, "Need at least one centroid"); - - auto stream = raft::resource::get_cuda_stream(handle); - int64_t n_rows = static_cast(X.extent(0)); - int64_t n_clusters = static_cast(centroids.extent(0)); - int64_t dim = static_cast(centroids.extent(1)); - - // CAGRA graph degree cannot exceed n_clusters - 1 + int64_t n_clusters = centroids.extent(0); size_t graph_degree = std::min(64, std::max(1, n_clusters - 1)); size_t inter_degree = 
std::min(128, std::max(1, n_clusters - 1)); - index_params build_params; - build_params.metric = params.metric; - build_params.graph_degree = graph_degree; + build_params.metric = params.metric; + build_params.graph_degree = graph_degree; build_params.intermediate_graph_degree = inter_degree; - build_params.attach_dataset_on_build = true; - - // Build CAGRA index on centroids (centroids are [n_clusters, dim]) - auto centers_view = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, dim); - auto cagra_index = cuvs::neighbors::cagra::build(handle, build_params, centers_view); + build_params.attach_dataset_on_build = true; + return build(handle, build_params, centroids); +} - // Queries: convert X to float if needed - rmm::device_uvector queries_buf(0, stream); - raft::device_matrix_view queries_view(nullptr, 0, 0); - if constexpr (std::is_same_v) { - queries_view = raft::make_device_matrix_view( - reinterpret_cast(X.data_handle()), n_rows, dim); - } else { - queries_buf.resize(static_cast(n_rows) * static_cast(dim), stream); - auto queries_mat = raft::make_device_matrix_view( - queries_buf.data(), n_rows, dim); - raft::linalg::map(handle, raft::make_const_mdspan(X), queries_mat, mapping_op); - queries_view = raft::make_device_matrix_view( - queries_buf.data(), n_rows, dim); +/** + * @brief Run 1-NN search with an existing CAGRA index and write cluster labels (and optional + * per-query distances). Queries must be float row-major [n_queries, dim]. + * + * Uses explicit row_major matrix view and raw label pointer so this compiles when + * raft::make_device_*_view returns layout_c_contiguous mdspan (not assignable to the view types + * required by cagra::search / device_vector_view). 
+ */ +template +void search_cagra_1nn(raft::resources const& handle, + cuvs::neighbors::cagra::search_params const& search_params, + cuvs::neighbors::cagra::index const& cagra_index, + raft::device_matrix_view queries, + LabelT* labels_out, + int64_t n_labels, + float* distances_out = nullptr) +{ + using namespace cuvs::neighbors::cagra; + int64_t n_rows = queries.extent(0); + RAFT_EXPECTS(n_labels == n_rows, "search_cagra_1nn: labels length must match n_queries"); + auto neighbors = raft::make_device_matrix(handle, n_rows, 1); + auto distances = raft::make_device_matrix(handle, n_rows, 1); + search(handle, search_params, cagra_index, queries, neighbors.view(), distances.view()); + auto neighbors_col = + raft::make_device_vector_view(neighbors.data_handle(), n_rows); + auto labels_view = raft::make_device_vector_view(labels_out, n_rows); + raft::linalg::map( + handle, raft::make_const_mdspan(neighbors_col), labels_view, raft::cast_op()); + if (distances_out != nullptr) { + raft::copy(handle, + raft::make_device_vector_view(distances_out, n_rows), + raft::make_device_vector_view(distances.data_handle(), n_rows)); } +} - // Search k=1 - search_params search_params; - search_params.max_queries = 0; // auto +/** Minimum number of clusters to use ANN for k-means fit (below this, brute is faster). */ +constexpr uint32_t kMinClustersForAnnFit = 5000; - auto neighbors = raft::make_device_matrix(handle, n_rows, 1); - auto distances = raft::make_device_matrix(handle, n_rows, 1); +/** + * @brief Assign each row to nearest centroid using CAGRA, reusing or rebuilding the index. + * + * When rebuild is true (or index is empty), builds the index on current centroids and stores it + * in *index_opt. Otherwise skips build and searches with the existing index: centroid vectors in + * memory may have shifted since that build (k-means M-step), so the graph still indexes a stale + * snapshot — assignments are intentionally approximate between rebuilds. 
+ * + * For a one-shot assign on float data (same work as building a fresh index then searching once), + * pass rebuild=true each time (e.g. each benchmark iteration). For k-means fit, pass + * rebuild=(iter % ann_rebuild_interval == 0) to amortize builds. Centroids and dataset must be + * float. For k-means ANN path, call only when use_ann_for_fit and n_clusters >= kMinClustersForAnnFit. + */ +template +void predict_cagra_with_index_reuse( + raft::resources const& handle, + cuvs::cluster::kmeans::balanced_params const& params, + const float* centers, + IdxT n_clusters, + IdxT dim, + const float* dataset, + IdxT n_rows, + LabelT* labels, + std::optional>* index_opt, + bool rebuild) +{ + RAFT_EXPECTS(centers != nullptr && dataset != nullptr && labels != nullptr && index_opt != nullptr, + "predict_cagra_with_index_reuse: null argument"); + RAFT_EXPECTS(n_clusters >= 1 && dim >= 1 && n_rows >= 1, "predict_cagra_with_index_reuse: bad extents"); - cuvs::neighbors::cagra::search( - handle, search_params, cagra_index, queries_view, neighbors.view(), distances.view()); + raft::device_matrix_view centers_view( + centers, static_cast(n_clusters), static_cast(dim)); + raft::device_matrix_view queries_view( + dataset, static_cast(n_rows), static_cast(dim)); - // Copy neighbor indices (column 0) to labels with cast to LabelT - auto neighbors_col = raft::make_device_vector_view( - neighbors.data_handle(), n_rows); - auto labels_view = raft::make_device_vector_view(labels.data_handle(), n_rows); - raft::linalg::map( - handle, raft::make_const_mdspan(neighbors_col), labels_view, raft::cast_op()); + if (rebuild || !index_opt->has_value()) { + *index_opt = build_cagra_index_for_centroids(handle, params, centers_view); + } + + search_cagra_1nn(handle, + default_cagra_centroid_search_params(), + index_opt->value(), + queries_view, + labels, + static_cast(n_rows), + nullptr); } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_balanced.cuh 
b/cpp/src/cluster/kmeans_balanced.cuh index d36890e8cc..0c0df03397 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -7,7 +7,6 @@ #include "../neighbors/detail/ann_utils.cuh" #include "detail/kmeans_balanced.cuh" -#include "detail/kmeans_predict_cagra.cuh" #include #include #include @@ -158,45 +157,6 @@ void predict(const raft::resources& handle, raft::resource::get_workspace_resource(handle)); } -/** - * @brief Assign each sample to nearest centroid using CAGRA (ANN-based, approximate). - * - * Same contract as predict() but builds a CAGRA index on centroids and runs k=1 search. - * Faster than brute force when the number of clusters K is large; results are approximate. - * Only supported when centroids are float and metric is L2Expanded, L2SqrtExpanded, - * InnerProduct, or CosineExpanded. - * - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters (metric) - * @param[in] X Dataset for which to infer the closest clusters - * @param[in] centroids The input centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to float - */ -template -void predict_cagra(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) <= - static_cast(std::numeric_limits::max()), - "The chosen label type cannot represent all cluster labels"); - - cuvs::cluster::kmeans::detail::predict_cagra( - handle, params, X, centroids, labels, 
mapping_op); -} - namespace helpers { /** diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index bd49868b81..4b4ca10373 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -22,11 +22,6 @@ // TODO (cjnolet): This should be using an exposed API instead of circumventing the public APIs. #include "../../cluster/kmeans_balanced.cuh" #include -#include - -// Use CAGRA for cluster assignment in extend when n_lists >= this (faster for large K). -// Set from cluster-assignment benchmark: CAGRA wins over brute force for K >= ~200K (e.g. N=1M). -constexpr uint32_t kUseAnnForClusterAssignmentMinClusters = 200000; #include #include @@ -1121,21 +1116,18 @@ void extend(raft::resources const& handle, cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); - bool use_cagra = index->use_ann_for_extend().value_or( - n_clusters >= kUseAnnForClusterAssignmentMinClusters); + // Default: brute-force assignment; set use_ann_for_extend to opt in to CAGRA. + bool use_cagra = index->use_ann_for_extend().value_or(false); if (use_cagra) { // Use CAGRA for cluster assignment when K is large (build once, search per batch). - auto centers_view = raft::make_device_matrix_view( - cluster_centers.data(), n_clusters, index->dim()); - cuvs::neighbors::cagra::index_params cagra_params; - cagra_params.metric = index->metric(); - cagra_params.graph_degree = - std::min(64, std::max(1, static_cast(n_clusters) - 1)); - cagra_params.intermediate_graph_degree = - std::min(128, std::max(1, static_cast(n_clusters) - 1)); - cagra_params.attach_dataset_on_build = true; - auto cagra_idx = cuvs::neighbors::cagra::build(handle, cagra_params, centers_view); - cuvs::neighbors::cagra::search_params search_params; + // Same centroid index + 1-NN search path as kmeans::detail (extend batches; fit uses + // predict_cagra_with_index_reuse). 
+ raft::device_matrix_view centers_view( + cluster_centers.data(), static_cast(n_clusters), static_cast(index->dim())); + auto cagra_idx = + cuvs::cluster::kmeans::detail::build_cagra_index_for_centroids( + handle, kmeans_params, centers_view); + auto search_params = cuvs::cluster::kmeans::detail::default_cagra_centroid_search_params(); for (const auto& batch : vec_batches) { auto batch_size = batch.size(); @@ -1150,17 +1142,16 @@ void extend(raft::resources const& handle, batch_size, index->dim()), utils::mapping{}); - auto queries_view = raft::make_device_matrix_view( - queries_float.data(), batch_size, index->dim()); - auto neighbors = raft::make_device_matrix(handle, batch_size, 1); - auto distances = raft::make_device_matrix(handle, batch_size, 1); - cuvs::neighbors::cagra::search( - handle, search_params, cagra_idx, queries_view, neighbors.view(), distances.view()); - raft::copy(handle, - raft::make_device_vector_view(new_data_labels.data() + batch.offset(), - batch_size), - raft::make_device_vector_view(neighbors.data_handle(), - batch_size)); + raft::device_matrix_view queries_view( + queries_float.data(), static_cast(batch_size), static_cast(index->dim())); + cuvs::cluster::kmeans::detail::search_cagra_1nn( + handle, + search_params, + cagra_idx, + queries_view, + new_data_labels.data() + batch.offset(), + static_cast(batch_size), + nullptr); vec_batches.prefetch_next_batch(); raft::resource::sync_stream(handle); } @@ -1382,6 +1373,12 @@ auto build(raft::resources const& handle, cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = static_cast((int)impl->metric()); + // ANN for k-means fit (centroids change each iteration): only when use_ann_for_fit == true. + // Default (nullopt / false) is brute-force assignment so benchmarks can sweep cluster counts. 
+ if (params.use_ann_for_fit.value_or(false)) { + kmeans_params.use_ann_for_fit = true; + kmeans_params.ann_rebuild_interval = 3; + } if (impl->metric() == distance::DistanceType::CosineExpanded) { raft::linalg::row_normalize( From ad054af709e42d2598ab2444df4ff6a2d9b06241 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 24 Mar 2026 11:57:23 -0700 Subject: [PATCH 6/6] add test cases --- .../ann/src/cuvs/cuvs_ivf_pq_build_bench.cu | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu index 1c77d7d36d..76f162ad13 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu @@ -203,6 +203,16 @@ BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) ->Unit(benchmark::kMillisecond) ->UseRealTime() ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) + ->Args({1500000, 300000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) + ->Args({1750000, 350000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); BENCHMARK(BM_IVFPQ_Build_KMeansFit_Speedup) ->Args({2000000, 400000, kDim}) ->Unit(benchmark::kMillisecond) @@ -235,6 +245,16 @@ BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) ->Unit(benchmark::kMillisecond) ->UseRealTime() ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) + ->Args({1500000, 300000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); +BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) + ->Args({1750000, 350000, kDim}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime() + ->ArgNames({"n_vectors", "n_lists", "dim"}); BENCHMARK(BM_IVFPQ_Extend_ClusterAssign_Speedup) ->Args({2000000, 400000, kDim}) 
->Unit(benchmark::kMillisecond)