Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -640,18 +640,14 @@ if(NOT BUILD_CPU_ONLY)
src/cluster/kmeans_fit_double.cu
src/cluster/kmeans_fit_float.cu
src/cluster/kmeans_auto_find_k_float.cu
src/cluster/kmeans_fit_predict_double.cu
src/cluster/kmeans_fit_predict_float.cu
src/cluster/kmeans_predict_double.cu
src/cluster/kmeans_predict_float.cu
src/cluster/kmeans_balanced_fit_float.cu
src/cluster/kmeans_balanced_fit_half.cu
src/cluster/kmeans_balanced_fit_predict_float.cu
src/cluster/kmeans_balanced_predict_float.cu
src/cluster/kmeans_balanced_predict_half.cu
src/cluster/kmeans_balanced_fit_int8.cu
src/cluster/kmeans_balanced_fit_uint8.cu
src/cluster/kmeans_balanced_fit_predict_int8.cu
src/cluster/kmeans_balanced_predict_float.cu
src/cluster/kmeans_balanced_predict_half.cu
src/cluster/kmeans_balanced_predict_int8.cu
src/cluster/kmeans_balanced_predict_uint8.cu
src/cluster/kmeans_transform_double.cu
Expand Down
353 changes: 36 additions & 317 deletions cpp/include/cuvs/cluster/kmeans.hpp

Large diffs are not rendered by default.

28 changes: 25 additions & 3 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ void kmeans_fit_main(raft::resources const& handle,
raft::device_matrix_view<DataT, IndexT> centroidsRawData,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter,
rmm::device_uvector<char>& workspace)
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<IndexT, IndexT>> labels = std::nullopt)
{
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope("kmeans_fit_main");
raft::default_logger().set_level(params.verbosity);
Expand Down Expand Up @@ -443,6 +444,13 @@ void kmeans_fit_main(raft::resources const& handle,
inertia,
std::make_optional(weight));

if (labels.has_value()) {
raft::linalg::map(handle,
labels.value(),
raft::key_op{},
raft::make_const_mdspan(minClusterAndDistance.view()));
}

RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ",
n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0],
inertia[0]);
Expand Down Expand Up @@ -727,7 +735,8 @@ void kmeans_fit(raft::resources const& handle,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
raft::device_matrix_view<DataT, IndexT> centroids,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter)
raft::host_scalar_view<IndexT> n_iter,
std::optional<raft::device_vector_view<IndexT, IndexT>> labels = std::nullopt)
{
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope("kmeans_fit");
auto n_samples = X.extent(0);
Expand Down Expand Up @@ -794,6 +803,15 @@ void kmeans_fit(raft::resources const& handle,
std::mt19937 gen(pams.rng_state.seed);
inertia[0] = std::numeric_limits<DataT>::max();

std::optional<raft::device_vector<IndexT, IndexT>> labels_iter;
std::optional<raft::device_vector_view<IndexT, IndexT>> labels_iter_view;
if (labels.has_value() && n_init > 1) {
labels_iter = raft::make_device_vector<IndexT, IndexT>(handle, n_samples);
labels_iter_view = std::make_optional(labels_iter->view());
} else if (labels.has_value()) {
labels_iter_view = std::make_optional(labels.value());
}

for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) {
cuvs::cluster::kmeans::params iter_params = pams;
iter_params.rng_state.seed = gen();
Expand Down Expand Up @@ -845,14 +863,18 @@ void kmeans_fit(raft::resources const& handle,
centroidsRawData.view(),
raft::make_host_scalar_view<DataT>(&iter_inertia),
raft::make_host_scalar_view<IndexT>(&n_current_iter),
workspace);
workspace,
labels_iter_view);
if (iter_inertia < inertia[0]) {
inertia[0] = iter_inertia;
n_iter[0] = n_current_iter;
raft::copy(
handle,
raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features),
raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features));
if (labels.has_value() && n_init > 1) {
raft::copy(handle, labels.value(), labels_iter_view.value());
}
}
RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d",
seed_iter + 1,
Expand Down
22 changes: 11 additions & 11 deletions cpp/src/cluster/detail/kmeans_auto_find_k.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -45,8 +45,8 @@ void compute_dispersion(raft::resources const& handle,

params.n_clusters = val;

cuvs::cluster::kmeans::fit_predict(
handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter);
cuvs::cluster::kmeans::fit(
handle, params, X, std::nullopt, centroids_view, residual, n_iter, std::make_optional(labels));

detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace);

Expand Down Expand Up @@ -212,14 +212,14 @@ void find_k(raft::resources const& handle,
raft::make_device_matrix_view<value_t, idx_t>(centroids.data_handle(), best_k[0], d);

params.n_clusters = best_k[0];
cuvs::cluster::kmeans::fit_predict(handle,
params,
X,
std::nullopt,
std::make_optional(centroids_view),
labels.view(),
residual,
n_iter);
cuvs::cluster::kmeans::fit(handle,
params,
X,
std::nullopt,
centroids_view,
residual,
n_iter,
std::make_optional(labels.view()));
}
}
} // namespace cuvs::cluster::kmeans::detail
16 changes: 13 additions & 3 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,8 @@ auto build_fine_clusters(const raft::resources& handle,
* @param[out] inertia (optional) If non-null, the sum of squared distances of samples to
* their closest cluster center is written here.
* Only supported when T == MathT (float/double).
* @param[out] labels_ret (optional) If non-null, the cluster labels are written to this buffer
* instead of an internally allocated temporary. [dim = n_rows]
*/
template <typename T, typename MathT, typename IdxT, typename MappingOpT>
void build_hierarchical(const raft::resources& handle,
Expand All @@ -1025,7 +1027,8 @@ void build_hierarchical(const raft::resources& handle,
MathT* cluster_centers,
IdxT n_clusters,
MappingOpT mapping_op,
MathT* inertia = nullptr)
MathT* inertia = nullptr,
uint32_t* labels_ret = nullptr)
{
auto stream = raft::resource::get_cuda_stream(handle);
using LabelT = uint32_t;
Expand Down Expand Up @@ -1141,7 +1144,14 @@ void build_hierarchical(const raft::resources& handle,
RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters.");

rmm::device_uvector<CounterT> cluster_sizes(n_clusters, stream, device_memory);
rmm::device_uvector<LabelT> labels(n_rows, stream, device_memory);
std::optional<rmm::device_uvector<LabelT>> labels_buf = std::nullopt;
LabelT* labels_ptr = nullptr;
if (labels_ret == nullptr) {
labels_buf = rmm::device_uvector<LabelT>(n_rows, stream, device_memory);
labels_ptr = labels_buf.value().data();
} else {
labels_ptr = labels_ret;
}

// Fine-tuning k-means for all clusters
//
Expand All @@ -1159,7 +1169,7 @@ void build_hierarchical(const raft::resources& handle,
n_rows,
n_clusters,
cluster_centers,
labels.data(),
labels_ptr,
cluster_sizes.data(),
5,
MathT{0.2},
Expand Down
10 changes: 9 additions & 1 deletion cpp/src/cluster/detail/kmeans_mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,8 @@ void fit(const raft::resources& handle,
raft::device_matrix_view<DataT, IndexT> centroids,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter,
rmm::device_uvector<char>& workspace)
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<IndexT, IndexT>> labels = std::nullopt)
{
const auto& comm = raft::resource::get_comms(handle);
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
Expand Down Expand Up @@ -745,6 +746,13 @@ void fit(const raft::resources& handle,
priorClusteringCost = curClusteringCost;
}

if (labels.has_value()) {
raft::linalg::map(handle,
labels.value(),
raft::key_op{},
raft::make_const_mdspan(minClusterAndDistance.view()));
}

raft::resource::sync_stream(handle, stream);
if (sqrdNormError < params.tol) done = true;

Expand Down
20 changes: 11 additions & 9 deletions cpp/src/cluster/detail/spectral.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -51,14 +51,16 @@ void fit_predict(raft::resources const& handle,
config.n_components,
raft::resource::get_cuda_stream(handle));

cuvs::cluster::kmeans::fit_predict(handle,
kmeans_config,
embedding_row_major.view(),
std::nullopt,
std::nullopt,
labels,
raft::make_host_scalar_view(&inertia),
raft::make_host_scalar_view(&n_iter));
auto centroids =
raft::make_device_matrix<DataT, int>(handle, kmeans_config.n_clusters, config.n_components);
cuvs::cluster::kmeans::fit(handle,
kmeans_config,
embedding_row_major.view(),
std::nullopt,
centroids.view(),
raft::make_host_scalar_view(&inertia),
raft::make_host_scalar_view(&n_iter),
labels);
}

void fit_predict(raft::resources const& handle,
Expand Down
149 changes: 0 additions & 149 deletions cpp/src/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,155 +84,6 @@ EXTERN_TEMPLATE_FIT_MAIN(float, int64_t)
EXTERN_TEMPLATE_FIT_MAIN(float, int)

#undef EXTERN_TEMPLATE_FIT_MAIN
/**
* @brief Find clusters with k-means algorithm.
* Initial centroids are chosen with k-means++ algorithm. Empty
* clusters are reinitialized by choosing new centroids with
* k-means++ algorithm.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/cluster/kmeans.cuh>
* #include <cuvs/cluster/kmeans_types.hpp>
* using namespace cuvs::cluster;
* ...
* raft::resources handle;
* cuvs::cluster::kmeans::params params;
* int n_features = 15, n_iter;
* float inertia;
* auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);
*
* kmeans::fit(handle,
* params,
* X,
* std::nullopt,
* centroids.view(),
* raft::make_host_scalar_view(&inertia),
* raft::make_host_scalar_view(&n_iter));
* @endcode
*
* @tparam DataT the type of data used for weights, distances.
* @tparam IndexT the type of data used for indexing.
* @param[in] handle The raft handle.
* @param[in] params Parameters for KMeans model.
* @param[in] X Training instances to cluster. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation in X.
* [len = n_samples]
* @param[inout] centroids [in] When init is InitMethod::Array, use
* centroids as the initial cluster centers.
* [out] The generated centroids from the
* kmeans algorithm are stored at the address
* pointed by 'centroids'.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center. Host-side scalar of type DataT.
* @param[out] n_iter Number of iterations run. Host-side scalar of type IndexT.
*/
template <typename DataT, typename IndexT>
void fit(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const DataT, IndexT> X,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
raft::device_matrix_view<DataT, IndexT> centroids,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter);

// Helper macro: emits an explicit instantiation *declaration* (`extern template`)
// of fit<DataT, IndexT>, so including translation units do not implicitly
// instantiate the template themselves.
// NOTE(review): the matching explicit instantiation definitions are presumably
// compiled in separate source files — not visible here, confirm.
#define EXTERN_TEMPLATE_FIT(DataT, IndexT) \
extern template void fit<DataT, IndexT>( \
raft::resources const& handle, \
const kmeans::params& params, \
raft::device_matrix_view<const DataT, IndexT> X, \
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight, \
raft::device_matrix_view<DataT, IndexT> centroids, \
raft::host_scalar_view<DataT> inertia, \
raft::host_scalar_view<IndexT> n_iter);

// (DataT, IndexT) combinations declared for fit.
EXTERN_TEMPLATE_FIT(double, int)
EXTERN_TEMPLATE_FIT(double, int64_t)
EXTERN_TEMPLATE_FIT(float, int)
EXTERN_TEMPLATE_FIT(float, int64_t)

// The helper macro is only needed for the declarations above.
#undef EXTERN_TEMPLATE_FIT
/**
* @brief Predict the closest cluster each sample in X belongs to.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/cluster/kmeans.cuh>
* #include <cuvs/cluster/kmeans_types.hpp>
* using namespace cuvs::cluster;
* ...
* raft::resources handle;
* cuvs::cluster::kmeans::params params;
* int n_features = 15, n_iter;
* float inertia;
* auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);
*
* kmeans::fit(handle,
* params,
* X,
* std::nullopt,
* centroids.view(),
* raft::make_host_scalar_view(&inertia),
* raft::make_host_scalar_view(&n_iter));
* ...
* auto labels = raft::make_device_vector<int, int>(handle, X.extent(0));
*
* kmeans::predict(handle,
* params,
* X,
* std::nullopt,
* centroids.view(),
* labels.view(),
* false,
* raft::make_host_scalar_view(&inertia));
* @endcode
*
* @tparam DataT the type of data used for weights, distances.
* @tparam IndexT the type of data used for indexing.
* @param[in] handle The raft handle.
* @param[in] params Parameters for KMeans model.
* @param[in] X New data to predict.
* [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation in X.
* [len = n_samples]
* @param[in] centroids Cluster centroids. The data must be in
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] labels Index of the cluster each sample in X
* belongs to.
* [len = n_samples]
* @param[in] normalize_weight True if the weights should be normalized
* @param[out] inertia Sum of squared distances of samples to
* their closest cluster center.
*/
template <typename DataT, typename IndexT>
void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const DataT, IndexT> X,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_vector_view<IndexT, IndexT> labels,
bool normalize_weight,
raft::host_scalar_view<DataT> inertia);

// Helper macro: emits an explicit instantiation *declaration* (`extern template`)
// of predict<DataT, IndexT>, so including translation units do not implicitly
// instantiate the template themselves.
// NOTE(review): the matching explicit instantiation definitions are presumably
// compiled in separate source files — not visible here, confirm.
#define EXTERN_TEMPLATE_PREDICT(DataT, IndexT) \
extern template void predict<DataT, IndexT>( \
raft::resources const& handle, \
const kmeans::params& params, \
raft::device_matrix_view<const DataT, IndexT> X, \
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight, \
raft::device_matrix_view<const DataT, IndexT> centroids, \
raft::device_vector_view<IndexT, IndexT> labels, \
bool normalize_weight, \
raft::host_scalar_view<DataT> inertia);

// (DataT, IndexT) combinations declared for predict.
EXTERN_TEMPLATE_PREDICT(double, int)
EXTERN_TEMPLATE_PREDICT(double, int64_t)
EXTERN_TEMPLATE_PREDICT(float, int)
EXTERN_TEMPLATE_PREDICT(float, int64_t)

// The helper macro is only needed for the declarations above.
#undef EXTERN_TEMPLATE_PREDICT

/**
* @brief Transform X to a cluster-distance space.
Expand Down
Loading
Loading