Skip to content

Commit e0d4133

Browse files
authored
[REVIEW] Add a public API for CAGRA graph optimize (rapidsai#1260)
This change exposes CAGRA's graph optimization functionality through a new public API in the `cuvs::neighbors::cagra::helpers` namespace. The optimize function allows users to optimize a custom KNN graphs to create CAGRA search graphs. Key Changes: - Added `cpp/include/cuvs/neighbors/cagra_optimize.hpp` with a public API declaration. As cagra_optimize assumes both input and output are host matrices, this PR exposes the API with host inputs only. Given this, users need to explicitly manage device-to-host transfers if KNN graph is not a RAFT host matrix - Modified previously unused `cpp/src/neighbors/cagra_optimize.cu` with implementation that forwards to internal CAGRA optimize. Removed the version using device matrix for `knn_graph`, which generates runtime segfaults - Added unit tests for the new API This PR partially resolves rapidsai#456 and completes one of the steps for rapidsai#1146 Authors: - Rui Lan (https://github.com/abc99lr) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#1260
1 parent 2922045 commit e0d4133

5 files changed

Lines changed: 127 additions & 17 deletions

File tree

cpp/include/cuvs/neighbors/cagra.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2686,3 +2686,4 @@ auto distribute(const raft::resources& clique, const std::string& filename)
26862686
} // namespace cuvs::neighbors::cagra
26872687

26882688
#include <cuvs/neighbors/cagra_index_wrapper.hpp>
2689+
#include <cuvs/neighbors/cagra_optimize.hpp>
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include <raft/core/device_mdspan.hpp>
20+
#include <raft/core/host_mdspan.hpp>
21+
#include <raft/core/mdspan_types.hpp>
22+
#include <raft/core/resources.hpp>
23+
24+
namespace cuvs::neighbors::cagra::helpers {
25+
26+
/**
27+
* @brief Optimize a KNN graph into a CAGRA graph.
28+
*
29+
* This function optimizes a k-NN graph to create a CAGRA graph.
30+
* The input/output graphs must be on host memory.
31+
*
32+
* Usage example:
33+
* @code{.cpp}
34+
* raft::resources res;
35+
* auto h_knn = raft::make_host_matrix<uint32_t, int64_t>(N, K_in);
36+
* // Fill h_knn with KNN graph
37+
* auto h_out = raft::make_host_matrix<uint32_t, int64_t>(N, K_out);
38+
* cuvs::neighbors::cagra::helpers::optimize(res, h_knn.view(), h_out.view());
39+
* @endcode
40+
*
41+
* @param[in] handle RAFT resources
42+
* @param[in] knn_graph Input KNN graph on host [n_rows, k_in]
43+
* @param[out] new_graph Output CAGRA graph on host [n_rows, k_out]
44+
*/
45+
void optimize(raft::resources const& handle,
46+
raft::host_matrix_view<uint32_t, int64_t, raft::row_major> knn_graph,
47+
raft::host_matrix_view<uint32_t, int64_t, raft::row_major> new_graph);
48+
49+
} // namespace cuvs::neighbors::cagra::helpers
Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,27 +15,15 @@
1515
*/
1616

1717
#include "cagra.cuh"
18-
#include <cuvs/neighbors/cagra.hpp>
18+
#include <cuvs/neighbors/cagra_optimize.hpp>
1919

20-
namespace cuvs::neighbors::cagra {
20+
namespace cuvs::neighbors::cagra::helpers {
2121

22-
void optimize(raft::resources const& handle,
23-
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> knn_graph,
24-
raft::host_matrix_view<uint32_t, int64_t, raft::row_major> new_graph)
25-
{
26-
cuvs::neighbors::cagra::optimize<
27-
uint32_t,
28-
raft::host_device_accessor<std::experimental::default_accessor<uint32_t>,
29-
raft::memory_type::device>>(handle, knn_graph, new_graph);
30-
}
3122
void optimize(raft::resources const& handle,
3223
raft::host_matrix_view<uint32_t, int64_t, raft::row_major> knn_graph,
3324
raft::host_matrix_view<uint32_t, int64_t, raft::row_major> new_graph)
3425
{
35-
cuvs::neighbors::cagra::optimize<
36-
uint32_t,
37-
raft::host_device_accessor<std::experimental::default_accessor<uint32_t>,
38-
raft::memory_type::host>>(handle, knn_graph, new_graph);
26+
cuvs::neighbors::cagra::optimize(handle, knn_graph, new_graph);
3927
}
4028

41-
} // namespace cuvs::neighbors::cagra
29+
} // namespace cuvs::neighbors::cagra::helpers

cpp/tests/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ if(BUILD_TESTS)
166166
1 PERCENT 100
167167
)
168168

169+
ConfigureTest(
170+
NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST PATH neighbors/ann_cagra/test_optimize_uint32_t.cu GPUS 1
171+
PERCENT 100
172+
)
173+
169174
ConfigureTest(
170175
NAME NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST PATH neighbors/ann_cagra/test_half_uint32_t.cu GPUS 1
171176
PERCENT 100
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <cuvs/neighbors/cagra.hpp>
18+
#include <gtest/gtest.h>
19+
#include <raft/core/host_mdspan.hpp>
20+
#include <raft/core/resources.hpp>
21+
22+
// This test targets public API exposure and basic invariants only (shapes, in-range indices).
23+
// Detailed optimization correctness is exercised by CAGRA build tests.
24+
25+
namespace {
26+
27+
using IdxT = uint32_t;
28+
29+
// Helper to create a simple synthetic KNN graph (ring-like neighbors)
30+
auto make_ring_knn_host(int64_t num_rows, int64_t kin)
31+
{
32+
auto knn_graph = raft::make_host_matrix<IdxT, int64_t>(num_rows, kin);
33+
for (int64_t i = 0; i < num_rows; ++i) {
34+
for (int64_t j = 0; j < kin; ++j) {
35+
knn_graph(i, j) = static_cast<IdxT>((i + j + 1) % num_rows);
36+
}
37+
}
38+
return knn_graph;
39+
}
40+
41+
TEST(CagraOptimize, HostToHostOptimizesGraph)
42+
{
43+
raft::resources res;
44+
45+
constexpr int64_t num_rows = 8;
46+
constexpr int64_t kin = 8;
47+
constexpr int64_t kout = 4;
48+
49+
auto knn_graph = make_ring_knn_host(num_rows, kin);
50+
auto optimized_graph = raft::make_host_matrix<IdxT, int64_t>(num_rows, kout);
51+
52+
// Test the optimize API
53+
cuvs::neighbors::cagra::helpers::optimize(res, knn_graph.view(), optimized_graph.view());
54+
55+
// Check basic invariants
56+
ASSERT_EQ(optimized_graph.extent(0), num_rows);
57+
ASSERT_EQ(optimized_graph.extent(1), kout);
58+
59+
// Check that all neighbors are valid indices
60+
for (int64_t i = 0; i < num_rows; ++i) {
61+
for (int64_t j = 0; j < kout; ++j) {
62+
EXPECT_LT(optimized_graph(i, j), static_cast<IdxT>(num_rows));
63+
}
64+
}
65+
}
66+
67+
} // namespace

0 commit comments

Comments
 (0)