11/*
2- * Copyright (c) 2021-2024 , NVIDIA CORPORATION.
2+ * Copyright (c) 2021-2025 , NVIDIA CORPORATION.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1616
1717#pragma once
1818
19+ #include " ../../neighbors/detail/reachability.cuh"
1920#include " agglomerative.cuh"
2021#include " connectivities.cuh"
2122#include " mst.cuh"
2223#include < cuvs/cluster/agglomerative.hpp>
2324#include < raft/core/resource/cuda_stream.hpp>
25+ #include < raft/sparse/coo.hpp>
2426#include < raft/util/cudart_utils.hpp>
2527
2628#include < rmm/device_uvector.hpp>
2729
2830namespace cuvs ::cluster::agglomerative::detail {
2931
32+ /* *
33+ * Constructs a linkage by computing the minimum spanning tree and dendrogram in the Mutual
34+ * Reachability space. Returns mst edges sorted by weight and the dendrogram.
35+ * @tparam value_t
36+ * @tparam value_idx
37+ * @tparam nnz_t
38+ * @param[in] handle raft handle for resource reuse
39+ * @param[in] X data points (size m * n)
40+ * @param[in] metric distance metric to use
41+ * @param[in] min_samples this neighborhood will be selected for core distances
42+ * @param[in] alpha weight applied when internal distance is chosen for mutual reachability (value
43+ * of 1.0 disables the weighting)
44+ * @param[out] core_dists core distances (size m)
45+ * @param[out] out_mst output MST sorted by edge weights (size m - 1)
46+ * @param[out] out_dendrogram output dendrogram
47+ * @param[out] out_distances distances of output
48+ * @param[out] out_sizes cluster sizes of output
49+ */
50+ template <typename value_t = float , typename value_idx = int , typename nnz_t = size_t >
51+ void build_mr_linkage (raft::resources const & handle,
52+ raft::device_matrix_view<const value_t , value_idx, raft::row_major> X,
53+ value_idx min_samples,
54+ float alpha,
55+ cuvs::distance::DistanceType metric,
56+ raft::device_vector_view<value_t , value_idx> core_dists,
57+ raft::device_coo_matrix_view<value_t , value_idx, value_idx, nnz_t > out_mst,
58+ raft::device_matrix_view<value_idx, value_idx> out_dendrogram,
59+ raft::device_vector_view<value_t , value_idx> out_distances,
60+ raft::device_vector_view<value_idx, value_idx> out_sizes)
61+ {
62+ size_t m = X.extent (0 );
63+ size_t n = X.extent (1 );
64+ auto mutual_reachability_indptr = raft::make_device_vector<value_idx, value_idx>(handle, m + 1 );
65+ raft::sparse::COO<value_t , value_idx, nnz_t > mutual_reachability_coo (
66+ raft::resource::get_cuda_stream (handle), min_samples * m * 2 );
67+
68+ // Replace this with mutual reachability graph cronstruction from within all_neighbors wrapper.
69+ // Reference: https://github.com/rapidsai/cuvs/issues/982
70+ cuvs::neighbors::detail::reachability::mutual_reachability_graph<value_idx, value_t , nnz_t >(
71+ handle,
72+ X.data_handle (),
73+ m,
74+ n,
75+ metric,
76+ min_samples,
77+ alpha,
78+ mutual_reachability_indptr.data_handle (),
79+ core_dists.data_handle (),
80+ mutual_reachability_coo);
81+
82+ // auto color = raft::make_device_vector<value_idx, value_idx>(handle, static_cast<value_idx>(m));
83+ rmm::device_uvector<value_idx> color (m, raft::resource::get_cuda_stream (handle));
84+ cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t >
85+ reduction_op (core_dists.data_handle (), m);
86+
87+ size_t nnz = m * min_samples;
88+
89+ detail::build_sorted_mst<value_idx, value_t >(handle,
90+ X.data_handle (),
91+ mutual_reachability_indptr.data_handle (),
92+ mutual_reachability_coo.cols (),
93+ mutual_reachability_coo.vals (),
94+ m,
95+ n,
96+ out_mst.structure_view ().get_rows ().data (),
97+ out_mst.structure_view ().get_cols ().data (),
98+ out_mst.get_elements ().data (),
99+ color.data (),
100+ mutual_reachability_coo.nnz ,
101+ reduction_op,
102+ metric,
103+ 10 );
104+
105+ /* *
106+ * Perform hierarchical labeling
107+ */
108+ size_t n_edges = m - 1 ;
109+
110+ detail::build_dendrogram_host<value_idx, value_t >(handle,
111+ out_mst.structure_view ().get_rows ().data (),
112+ out_mst.structure_view ().get_cols ().data (),
113+ out_mst.get_elements ().data (),
114+ n_edges,
115+ out_dendrogram.data_handle (),
116+ out_distances.data_handle (),
117+ out_sizes.data_handle ());
118+ }
119+
// Element count passed when constructing device buffers that start out empty
// (e.g. `indptr (EMPTY, stream)` further down in this file).
30120static const size_t EMPTY = 0 ;
31121
32122/* *
33- * Single-linkage clustering, capable of constructing a KNN graph to
34- * scale the algorithm beyond the n^2 memory consumption of implementations
35- * that use the fully-connected graph of pairwise distances by connecting
36- * a knn graph when k is not large enough to connect it.
37-
38- * @tparam value_idx
123+ * Constructs a linkage by computing the minimum spanning tree and dendrogram of the distance
124+ * graph built from X. Returns mst edges sorted by weight and the dendrogram.
39125 * @tparam value_t
126+ * @tparam value_idx
127+ * @tparam nnz_t
40128 * @tparam dist_type method to use for constructing connectivities graph
41- * @param[in] handle raft handle
42- * @param[in] X dense input matrix in row-major layout
43- * @param[in] m number of rows in X
44- * @param[in] n number of columns in X
45- * @param[in] metric distance metrix to use when constructing connectivities graph
46- * @param[out] out struct containing output dendrogram and cluster assignments
47- * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
48- control
49- * of k. The algorithm will set `k = log(n) + c`
50- * @param[in] n_clusters number of clusters to assign data samples
129+ * @param[in] handle raft handle for resource reuse
130+ * @param[in] X data points (size m * n)
131+ * @param[in] c a constant used when constructing linkage from knn graph. Allows the indirect
132+ * control of k. The algorithm will set `k = log(n) + c`
133+ * @param[in] metric distance metric to use
134+ * @param[out] out_mst output MST sorted by edge weights (size m - 1)
135+ * @param[out] out_dendrogram output dendrogram
136+ * @param[out] out_distances distances of output
137+ * @param[out] out_sizes cluster sizes of output
51138 */
52- template <typename value_idx, typename value_t , Linkage dist_type>
53- void single_linkage (raft::resources const & handle,
54- const value_t * X,
55- size_t m ,
56- size_t n ,
57- cuvs::distance::DistanceType metric ,
58- single_linkage_output <value_idx>* out ,
59- int c ,
60- size_t n_clusters )
139+ template <typename value_t , typename value_idx, typename nnz_t , Linkage dist_type>
140+ void build_dist_linkage (raft::resources const & handle,
141+ raft::device_matrix_view< const value_t , value_idx, raft::row_major> X,
142+ int c ,
143+ cuvs::distance::DistanceType metric ,
144+ raft::device_coo_matrix_view< value_t , value_idx, value_idx, nnz_t > out_mst ,
145+ raft::device_matrix_view <value_idx, value_idx> out_dendrogram ,
146+ raft::device_vector_view< value_t , value_idx> out_distances ,
147+ raft::device_vector_view<value_idx, value_idx> out_sizes )
61148{
62- ASSERT (n_clusters <= m, " n_clusters must be less than or equal to the number of data points " );
63-
149+ size_t m = X. extent ( 0 );
150+ size_t n = X. extent ( 1 );
64151 auto stream = raft::resource::get_cuda_stream (handle);
65152
66153 rmm::device_uvector<value_idx> indptr (EMPTY, stream);
@@ -70,51 +157,109 @@ void single_linkage(raft::resources const& handle,
70157 /* *
71158 * 1. Construct distance graph
72159 */
73- detail::get_distance_graph<value_idx, value_t , dist_type>(
74- handle, X, m, n, metric, indptr, indices, pw_dists, c);
75-
76- rmm::device_uvector<value_idx> mst_rows (m - 1 , stream);
77- rmm::device_uvector<value_idx> mst_cols (m - 1 , stream);
78- rmm::device_uvector<value_t > mst_data (m - 1 , stream);
160+ detail::get_distance_graph<value_idx, value_t , dist_type>(handle,
161+ X.data_handle (),
162+ static_cast <value_idx>(m),
163+ static_cast <value_idx>(n),
164+ metric,
165+ indptr,
166+ indices,
167+ pw_dists,
168+ c);
79169
80170 /* *
81171 * 2. Construct MST, sorted by weights
82172 */
83173 rmm::device_uvector<value_idx> color (m, stream);
84174 cuvs::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t > op (m);
175+
176+ size_t n_edges = m - 1 ;
177+
85178 detail::build_sorted_mst<value_idx, value_t >(handle,
86- X,
179+ X. data_handle () ,
87180 indptr.data (),
88181 indices.data (),
89182 pw_dists.data (),
90183 m,
91184 n,
92- mst_rows .data (),
93- mst_cols .data (),
94- mst_data .data (),
185+ out_mst. structure_view (). get_rows () .data (),
186+ out_mst. structure_view (). get_cols () .data (),
187+ out_mst. get_elements () .data (),
95188 color.data (),
96189 indices.size (),
97190 op,
98191 metric);
99-
100192 pw_dists.release ();
101193
102194 /* *
103195 * Perform hierarchical labeling
104196 */
105- size_t n_edges = mst_rows.size ();
106-
107- rmm::device_uvector<value_t > out_delta (n_edges, stream);
108- rmm::device_uvector<value_idx> out_size (n_edges, stream);
109- // Create dendrogram
110197 detail::build_dendrogram_host<value_idx, value_t >(handle,
111- mst_rows .data (),
112- mst_cols .data (),
113- mst_data .data (),
198+ out_mst. structure_view (). get_rows () .data (),
199+ out_mst. structure_view (). get_cols () .data (),
200+ out_mst. get_elements () .data (),
114201 n_edges,
115- out->children ,
116- out_delta.data (),
117- out_size.data ());
202+ out_dendrogram.data_handle (),
203+ out_distances.data_handle (),
204+ out_sizes.data_handle ());
205+ }
206+
207+ /* *
208+ * Single-linkage clustering, capable of constructing a KNN graph to
209+ * scale the algorithm beyond the n^2 memory consumption of implementations
210+ * that use the fully-connected graph of pairwise distances by connecting
211+ * a knn graph when k is not large enough to connect it.
212+
213+ * @tparam value_idx
214+ * @tparam value_t
215+ * @tparam dist_type method to use for constructing connectivities graph
216+ * @param[in] handle raft handle
217+ * @param[in] X dense input matrix in row-major layout
218+ * @param[in] m number of rows in X
219+ * @param[in] n number of columns in X
220+ * @param[in] metric distance metric to use when constructing connectivities graph
221+ * @param[out] out struct containing output dendrogram and cluster assignments
222+ * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
223+ control
224+ * of k. The algorithm will set `k = log(n) + c`
225+ * @param[in] n_clusters number of clusters to assign data samples
226+ */
227+ template <typename value_idx, typename value_t , Linkage dist_type>
228+ void single_linkage (raft::resources const & handle,
229+ const value_t * X,
230+ size_t m,
231+ size_t n,
232+ cuvs::distance::DistanceType metric,
233+ single_linkage_output<value_idx>* out,
234+ int c,
235+ size_t n_clusters)
236+ {
237+ ASSERT (n_clusters <= m, " n_clusters must be less than or equal to the number of data points" );
238+
239+ value_idx n_edges = m - 1 ;
240+ auto mst_rows = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
241+ auto mst_cols = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
242+ auto mst_weights = raft::make_device_vector<value_t , value_idx>(handle, n_edges);
243+ auto structure_view =
244+ raft::make_device_coordinate_structure_view<value_idx, value_idx, value_idx>(
245+ mst_rows.data_handle (), mst_cols.data_handle (), m, m, n_edges);
246+ auto mst_view = raft::make_device_coo_matrix_view<value_t , value_idx, value_idx, value_idx>(
247+ mst_weights.data_handle (), structure_view);
248+
249+ auto out_delta = raft::make_device_vector<value_t , value_idx>(handle, n_edges);
250+ auto out_sizes = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
251+
252+ build_dist_linkage<value_t , value_idx, value_idx, dist_type>(
253+ handle,
254+ raft::make_device_matrix_view<const value_t , value_idx, raft::row_major>(
255+ X, static_cast <value_idx>(m), static_cast <value_idx>(n)),
256+ c,
257+ metric,
258+ mst_view,
259+ raft::make_device_matrix_view<value_idx, value_idx, raft::row_major>(out->children , n_edges, 2 ),
260+ out_delta.view (),
261+ out_sizes.view ());
262+
118263 detail::extract_flattened_clusters (handle, out->labels , out->children , n_clusters, m);
119264
120265 out->m = m;
0 commit comments