Skip to content

Commit fb969cc

Browse files
authored
[FEA] Build Single Linkage API (rapidsai#820)
- Expose functions for building the dendrogram on the mutual reachability graph Authors: - Tarang Jain (https://github.com/tarang-jain) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Jinsol Park (https://github.com/jinsolp) - Divye Gala (https://github.com/divyegala) URL: rapidsai#820
1 parent 62e5d0d commit fb969cc

6 files changed

Lines changed: 402 additions & 144 deletions

File tree

cpp/include/cuvs/cluster/agglomerative.hpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,7 +18,9 @@
1818

1919
#include <cuvs/distance/distance.hpp>
2020
#include <optional>
21+
#include <variant>
2122

23+
#include <raft/core/device_coo_matrix.hpp>
2224
#include <raft/core/device_mdspan.hpp>
2325
#include <raft/core/resources.hpp>
2426

@@ -119,6 +121,58 @@ void single_linkage(
119121
cuvs::cluster::agglomerative::Linkage linkage = cuvs::cluster::agglomerative::Linkage::KNN_GRAPH,
120122
std::optional<int> c = std::make_optional<int>(DEFAULT_CONST_C));
121123

124+
namespace helpers {
125+
126+
namespace linkage_graph_params {
127+
/** Specialized parameters to build the KNN graph with regular distances */
128+
struct distance_params {
129+
/** a constant used when constructing linkage from knn graph. Allows the indirect control of k.
130+
* The algorithm will set `k = log(n) + c` */
131+
int c = DEFAULT_CONST_C;
132+
133+
/** strategy for constructing the linkage. PAIRWISE uses more memory but can be faster for smaller
134+
* datasets. KNN_GRAPH allows the memory usage to be controlled (using parameter c) */
135+
cuvs::cluster::agglomerative::Linkage dist_type =
136+
cuvs::cluster::agglomerative::Linkage::KNN_GRAPH;
137+
};
138+
139+
/** Specialized parameters to build the Mutual Reachability graph */
140+
struct mutual_reachability_params {
141+
/** this neighborhood will be selected for core distances. */
142+
int min_samples;
143+
144+
/** weight applied when internal distance is chosen for mutual reachability (value of 1.0 disables
145+
* the weighting) */
146+
float alpha = 1.0;
147+
};
148+
} // namespace linkage_graph_params
149+
/**
150+
* Given a dataset, builds the KNN graph, connects graph components and builds a linkage
151+
* (dendrogram). Returns the Minimum Spanning Tree edges sorted by weight and the dendrogram.
152+
* @param[in] handle raft handle for resource reuse
153+
* @param[in] X data points (size n_rows * d)
154+
* @param[in] linkage_graph_params linkage params or mutual reachability params for building the KNN
155+
* graph
156+
* @param[in] metric distance metric to use
157+
* @param[out] out_mst output MST sorted by edge weights (size n_rows - 1)
158+
* @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2)
159+
* @param[out] out_distances distances for output
160+
* @param[out] out_sizes cluster sizes of output
161+
* @param[out] core_dists (optional) core distances (size n_rows). Must be supplied in the Mutual
162+
* Reachability space
163+
*/
164+
void build_linkage(
165+
raft::resources const& handle,
166+
raft::device_matrix_view<const float, int, raft::row_major> X,
167+
std::variant<linkage_graph_params::distance_params,
168+
linkage_graph_params::mutual_reachability_params> linkage_graph_params,
169+
cuvs::distance::DistanceType metric,
170+
raft::device_coo_matrix_view<float, int, int, size_t> out_mst,
171+
raft::device_matrix_view<int, int> dendrogram,
172+
raft::device_vector_view<float, int> out_distances,
173+
raft::device_vector_view<int, int> out_sizes,
174+
std::optional<raft::device_vector_view<float, int>> core_dists);
175+
} // namespace helpers
122176
/**
123177
* @}
124178
*/

cpp/src/cluster/detail/single_linkage.cuh

Lines changed: 195 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,51 +16,138 @@
1616

1717
#pragma once
1818

19+
#include "../../neighbors/detail/reachability.cuh"
1920
#include "agglomerative.cuh"
2021
#include "connectivities.cuh"
2122
#include "mst.cuh"
2223
#include <cuvs/cluster/agglomerative.hpp>
2324
#include <raft/core/resource/cuda_stream.hpp>
25+
#include <raft/sparse/coo.hpp>
2426
#include <raft/util/cudart_utils.hpp>
2527

2628
#include <rmm/device_uvector.hpp>
2729

2830
namespace cuvs::cluster::agglomerative::detail {
2931

32+
/**
33+
* Constructs a linkage by computing the minimum spanning tree and dendrogram in the Mutual
34+
* Reachability space. Returns mst edges sorted by weight and the dendrogram.
35+
* @tparam value_t
36+
* @tparam value_idx
37+
* @tparam nnz_t
38+
* @param[in] handle raft handle for resource reuse
39+
* @param[in] X data points (size m * n)
40+
* @param[in] metric distance metric to use
41+
* @param[in] min_samples this neighborhood will be selected for core distances
42+
* @param[in] alpha weight applied when internal distance is chosen for mutual reachability (value
43+
* of 1.0 disables the weighting)
44+
* @param[out] core_dists core distances (size m)
45+
* @param[out] out_mst output MST sorted by edge weights (size m - 1)
46+
* @param[out] out_dendrogram output dendrogram
47+
* @param[out] out_distances distances of output
48+
* @param[out] out_sizes cluster sizes of output
49+
*/
50+
template <typename value_t = float, typename value_idx = int, typename nnz_t = size_t>
51+
void build_mr_linkage(raft::resources const& handle,
52+
raft::device_matrix_view<const value_t, value_idx, raft::row_major> X,
53+
value_idx min_samples,
54+
float alpha,
55+
cuvs::distance::DistanceType metric,
56+
raft::device_vector_view<value_t, value_idx> core_dists,
57+
raft::device_coo_matrix_view<value_t, value_idx, value_idx, nnz_t> out_mst,
58+
raft::device_matrix_view<value_idx, value_idx> out_dendrogram,
59+
raft::device_vector_view<value_t, value_idx> out_distances,
60+
raft::device_vector_view<value_idx, value_idx> out_sizes)
61+
{
62+
size_t m = X.extent(0);
63+
size_t n = X.extent(1);
64+
auto mutual_reachability_indptr = raft::make_device_vector<value_idx, value_idx>(handle, m + 1);
65+
raft::sparse::COO<value_t, value_idx, nnz_t> mutual_reachability_coo(
66+
raft::resource::get_cuda_stream(handle), min_samples * m * 2);
67+
68+
// Replace this with mutual reachability graph construction from within all_neighbors wrapper.
69+
// Reference: https://github.com/rapidsai/cuvs/issues/982
70+
cuvs::neighbors::detail::reachability::mutual_reachability_graph<value_idx, value_t, nnz_t>(
71+
handle,
72+
X.data_handle(),
73+
m,
74+
n,
75+
metric,
76+
min_samples,
77+
alpha,
78+
mutual_reachability_indptr.data_handle(),
79+
core_dists.data_handle(),
80+
mutual_reachability_coo);
81+
82+
// auto color = raft::make_device_vector<value_idx, value_idx>(handle, static_cast<value_idx>(m));
83+
rmm::device_uvector<value_idx> color(m, raft::resource::get_cuda_stream(handle));
84+
cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t>
85+
reduction_op(core_dists.data_handle(), m);
86+
87+
size_t nnz = m * min_samples;
88+
89+
detail::build_sorted_mst<value_idx, value_t>(handle,
90+
X.data_handle(),
91+
mutual_reachability_indptr.data_handle(),
92+
mutual_reachability_coo.cols(),
93+
mutual_reachability_coo.vals(),
94+
m,
95+
n,
96+
out_mst.structure_view().get_rows().data(),
97+
out_mst.structure_view().get_cols().data(),
98+
out_mst.get_elements().data(),
99+
color.data(),
100+
mutual_reachability_coo.nnz,
101+
reduction_op,
102+
metric,
103+
10);
104+
105+
/**
106+
* Perform hierarchical labeling
107+
*/
108+
size_t n_edges = m - 1;
109+
110+
detail::build_dendrogram_host<value_idx, value_t>(handle,
111+
out_mst.structure_view().get_rows().data(),
112+
out_mst.structure_view().get_cols().data(),
113+
out_mst.get_elements().data(),
114+
n_edges,
115+
out_dendrogram.data_handle(),
116+
out_distances.data_handle(),
117+
out_sizes.data_handle());
118+
}
119+
30120
static const size_t EMPTY = 0;
31121

32122
/**
33-
* Single-linkage clustering, capable of constructing a KNN graph to
34-
* scale the algorithm beyond the n^2 memory consumption of implementations
35-
* that use the fully-connected graph of pairwise distances by connecting
36-
* a knn graph when k is not large enough to connect it.
37-
38-
* @tparam value_idx
123+
* Constructs a linkage by computing the minimum spanning tree and dendrogram from a distance
124+
* graph built according to dist_type. Returns mst edges sorted by weight and the dendrogram.
39125
* @tparam value_t
126+
* @tparam value_idx
127+
* @tparam nnz_t
40128
* @tparam dist_type method to use for constructing connectivities graph
41-
* @param[in] handle raft handle
42-
* @param[in] X dense input matrix in row-major layout
43-
* @param[in] m number of rows in X
44-
* @param[in] n number of columns in X
45-
* @param[in] metric distance metrix to use when constructing connectivities graph
46-
* @param[out] out struct containing output dendrogram and cluster assignments
47-
* @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
48-
control
49-
* of k. The algorithm will set `k = log(n) + c`
50-
* @param[in] n_clusters number of clusters to assign data samples
129+
* @param[in] handle raft handle for resource reuse
130+
* @param[in] X data points (size m * n)
131+
* @param[in] c a constant used when constructing linkage from knn graph. Allows the indirect
132+
* control of k. The algorithm will set `k = log(n) + c`
133+
* @param[in] metric distance metric to use
134+
* @param[out] out_mst output MST sorted by edge weights (size m - 1)
135+
* @param[out] out_dendrogram output dendrogram
136+
* @param[out] out_distances distances of output
137+
* @param[out] out_sizes cluster sizes of output
51138
*/
52-
template <typename value_idx, typename value_t, Linkage dist_type>
53-
void single_linkage(raft::resources const& handle,
54-
const value_t* X,
55-
size_t m,
56-
size_t n,
57-
cuvs::distance::DistanceType metric,
58-
single_linkage_output<value_idx>* out,
59-
int c,
60-
size_t n_clusters)
139+
template <typename value_t, typename value_idx, typename nnz_t, Linkage dist_type>
140+
void build_dist_linkage(raft::resources const& handle,
141+
raft::device_matrix_view<const value_t, value_idx, raft::row_major> X,
142+
int c,
143+
cuvs::distance::DistanceType metric,
144+
raft::device_coo_matrix_view<value_t, value_idx, value_idx, nnz_t> out_mst,
145+
raft::device_matrix_view<value_idx, value_idx> out_dendrogram,
146+
raft::device_vector_view<value_t, value_idx> out_distances,
147+
raft::device_vector_view<value_idx, value_idx> out_sizes)
61148
{
62-
ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points");
63-
149+
size_t m = X.extent(0);
150+
size_t n = X.extent(1);
64151
auto stream = raft::resource::get_cuda_stream(handle);
65152

66153
rmm::device_uvector<value_idx> indptr(EMPTY, stream);
@@ -70,51 +157,109 @@ void single_linkage(raft::resources const& handle,
70157
/**
71158
* 1. Construct distance graph
72159
*/
73-
detail::get_distance_graph<value_idx, value_t, dist_type>(
74-
handle, X, m, n, metric, indptr, indices, pw_dists, c);
75-
76-
rmm::device_uvector<value_idx> mst_rows(m - 1, stream);
77-
rmm::device_uvector<value_idx> mst_cols(m - 1, stream);
78-
rmm::device_uvector<value_t> mst_data(m - 1, stream);
160+
detail::get_distance_graph<value_idx, value_t, dist_type>(handle,
161+
X.data_handle(),
162+
static_cast<value_idx>(m),
163+
static_cast<value_idx>(n),
164+
metric,
165+
indptr,
166+
indices,
167+
pw_dists,
168+
c);
79169

80170
/**
81171
* 2. Construct MST, sorted by weights
82172
*/
83173
rmm::device_uvector<value_idx> color(m, stream);
84174
cuvs::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(m);
175+
176+
size_t n_edges = m - 1;
177+
85178
detail::build_sorted_mst<value_idx, value_t>(handle,
86-
X,
179+
X.data_handle(),
87180
indptr.data(),
88181
indices.data(),
89182
pw_dists.data(),
90183
m,
91184
n,
92-
mst_rows.data(),
93-
mst_cols.data(),
94-
mst_data.data(),
185+
out_mst.structure_view().get_rows().data(),
186+
out_mst.structure_view().get_cols().data(),
187+
out_mst.get_elements().data(),
95188
color.data(),
96189
indices.size(),
97190
op,
98191
metric);
99-
100192
pw_dists.release();
101193

102194
/**
103195
* Perform hierarchical labeling
104196
*/
105-
size_t n_edges = mst_rows.size();
106-
107-
rmm::device_uvector<value_t> out_delta(n_edges, stream);
108-
rmm::device_uvector<value_idx> out_size(n_edges, stream);
109-
// Create dendrogram
110197
detail::build_dendrogram_host<value_idx, value_t>(handle,
111-
mst_rows.data(),
112-
mst_cols.data(),
113-
mst_data.data(),
198+
out_mst.structure_view().get_rows().data(),
199+
out_mst.structure_view().get_cols().data(),
200+
out_mst.get_elements().data(),
114201
n_edges,
115-
out->children,
116-
out_delta.data(),
117-
out_size.data());
202+
out_dendrogram.data_handle(),
203+
out_distances.data_handle(),
204+
out_sizes.data_handle());
205+
}
206+
207+
/**
208+
* Single-linkage clustering, capable of constructing a KNN graph to
209+
* scale the algorithm beyond the n^2 memory consumption of implementations
210+
* that use the fully-connected graph of pairwise distances by connecting
211+
* a knn graph when k is not large enough to connect it.
212+
213+
* @tparam value_idx
214+
* @tparam value_t
215+
* @tparam dist_type method to use for constructing connectivities graph
216+
* @param[in] handle raft handle
217+
* @param[in] X dense input matrix in row-major layout
218+
* @param[in] m number of rows in X
219+
* @param[in] n number of columns in X
220+
* @param[in] metric distance metric to use when constructing connectivities graph
221+
* @param[out] out struct containing output dendrogram and cluster assignments
222+
* @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
223+
control
224+
* of k. The algorithm will set `k = log(n) + c`
225+
* @param[in] n_clusters number of clusters to assign data samples
226+
*/
227+
template <typename value_idx, typename value_t, Linkage dist_type>
228+
void single_linkage(raft::resources const& handle,
229+
const value_t* X,
230+
size_t m,
231+
size_t n,
232+
cuvs::distance::DistanceType metric,
233+
single_linkage_output<value_idx>* out,
234+
int c,
235+
size_t n_clusters)
236+
{
237+
ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points");
238+
239+
value_idx n_edges = m - 1;
240+
auto mst_rows = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
241+
auto mst_cols = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
242+
auto mst_weights = raft::make_device_vector<value_t, value_idx>(handle, n_edges);
243+
auto structure_view =
244+
raft::make_device_coordinate_structure_view<value_idx, value_idx, value_idx>(
245+
mst_rows.data_handle(), mst_cols.data_handle(), m, m, n_edges);
246+
auto mst_view = raft::make_device_coo_matrix_view<value_t, value_idx, value_idx, value_idx>(
247+
mst_weights.data_handle(), structure_view);
248+
249+
auto out_delta = raft::make_device_vector<value_t, value_idx>(handle, n_edges);
250+
auto out_sizes = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
251+
252+
build_dist_linkage<value_t, value_idx, value_idx, dist_type>(
253+
handle,
254+
raft::make_device_matrix_view<const value_t, value_idx, raft::row_major>(
255+
X, static_cast<value_idx>(m), static_cast<value_idx>(n)),
256+
c,
257+
metric,
258+
mst_view,
259+
raft::make_device_matrix_view<value_idx, value_idx, raft::row_major>(out->children, n_edges, 2),
260+
out_delta.view(),
261+
out_sizes.view());
262+
118263
detail::extract_flattened_clusters(handle, out->labels, out->children, n_clusters, m);
119264

120265
out->m = m;

0 commit comments

Comments
 (0)