[REVIEW] Add KMeans.fit_predict to Python API #1956

Open
jrbourbeau wants to merge 2 commits into rapidsai:main from jrbourbeau:kmeans_fit_predict_python

@jrbourbeau
Member

This PR handles the Python portion of #1944

Signed-off-by: James Bourbeau <jbourbeau@nvidia.com>
…it_predict_python

Signed-off-by: James Bourbeau <jbourbeau@nvidia.com>
Comment on lines +307 to +323
    centroids_out, _, n_iter = fit(
        params,
        X,
        centroids=centroids,
        sample_weights=sample_weights,
        resources=resources,
    )
    labels_out, inertia_pred = predict(
        params,
        X,
        centroids_out,
        sample_weights=sample_weights,
        labels=labels,
        normalize_weight=normalize_weight,
        resources=resources,
    )
    return FitPredictOutput(labels_out, centroids_out, inertia_pred, n_iter)
Contributor

This would need to call the missing C function "cuvsKMeansFitPredict", because #1939 adds improvements to the fit_predict function that make it faster than calling the two functions independently.
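
To make the point about composition concrete, here is a minimal pure-Python sketch of a `fit_predict` built by calling `fit` and then `predict`, mirroring the shape of the diff above. It is a toy Lloyd's algorithm, not the cuVS implementation: the signatures are simplified assumptions (no `params`, `sample_weights`, or `resources`), and `FitPredictOutput` here is a hypothetical stand-in whose field order follows the diff. The thread does not say what the optimizations in #1939 are; the sketch only shows that the composed version makes one extra full assignment pass over `X` after fitting.

```python
from collections import namedtuple

# Hypothetical stand-in for the PR's FitPredictOutput (field order follows the diff).
FitPredictOutput = namedtuple(
    "FitPredictOutput", ["labels", "centroids", "inertia", "n_iter"]
)


def _sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def predict(X, centroids):
    """Assign each point to its nearest centroid; return (labels, inertia)."""
    labels, inertia = [], 0.0
    for point in X:
        dists = [_sq_dist(point, c) for c in centroids]
        best = min(range(len(centroids)), key=dists.__getitem__)
        labels.append(best)
        inertia += dists[best]
    return labels, inertia


def fit(X, centroids, max_iter=20):
    """Toy Lloyd's iterations; return (centroids, inertia, n_iter)."""
    centroids = [list(c) for c in centroids]
    inertia, n_iter = 0.0, 0
    for n_iter in range(1, max_iter + 1):
        labels, inertia = predict(X, centroids)
        new_centroids = []
        for k, old in enumerate(centroids):
            members = [p for p, lab in zip(X, labels) if lab == k]
            if members:
                new_centroids.append(
                    [sum(col) / len(members) for col in zip(*members)]
                )
            else:
                new_centroids.append(old)  # keep an empty cluster's centroid
        if new_centroids == centroids:
            break  # converged: centroids did not move
        centroids = new_centroids
    return centroids, inertia, n_iter


def fit_predict(X, centroids, max_iter=20):
    """Composed version: fit, then a separate predict pass over X."""
    centroids_out, _, n_iter = fit(X, centroids, max_iter=max_iter)
    labels, inertia = predict(X, centroids_out)  # extra full pass over X
    return FitPredictOutput(labels, centroids_out, inertia, n_iter)
```

In this toy version the labels computed inside the last `fit` iteration are discarded and then recomputed by `predict`, which is the kind of duplicated work a fused implementation could avoid.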

Member Author

Thanks for pointing me to #1939, @lowener. Do the changes in that PR break the code here? I'd like to keep the scope of this PR to just adding the Python API. It seems this could be updated to a more optimized version after the C API is added.

Member

@jrbourbeau, given that we're likely to merge @lowener's PR first in sequence, could you base your branch on his and build from his changes? Then we can merge yours in shortly after.

Contributor

FYI, the latest update on #1939 is that the C++ fit_predict() function will be removed and the labels will be returned from the fit() function as an optional output.
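
As a rough illustration of what "labels as an optional output of fit()" could look like, here is a pure-Python sketch under stated assumptions. The `return_labels` flag, the return tuple, and the helper names are all hypothetical; the thread does not show the actual C++ or Python signature from #1939. The idea shown is that the assignment step already computed during fitting can be handed back to the caller, so no separate predict pass is needed when the loop converges.

```python
def _sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _assign(X, centroids):
    """Nearest-centroid assignment; return (labels, inertia)."""
    labels, inertia = [], 0.0
    for point in X:
        dists = [_sq_dist(point, c) for c in centroids]
        best = min(range(len(centroids)), key=dists.__getitem__)
        labels.append(best)
        inertia += dists[best]
    return labels, inertia


def _update(X, labels, centroids):
    """Recompute each centroid as the mean of its members."""
    new_centroids = []
    for k, old in enumerate(centroids):
        members = [p for p, lab in zip(X, labels) if lab == k]
        if members:
            new_centroids.append([sum(col) / len(members) for col in zip(*members)])
        else:
            new_centroids.append(list(old))  # keep an empty cluster's centroid
    return new_centroids


def fit(X, centroids, max_iter=20, return_labels=False):
    """Toy Lloyd's fit; `return_labels` is a hypothetical flag for this sketch.

    Returns (centroids, inertia, n_iter, labels), with labels=None unless requested.
    """
    centroids = [list(c) for c in centroids]
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        labels, inertia = _assign(X, centroids)
        new_centroids = _update(X, labels, centroids)
        if new_centroids == centroids:
            break  # converged: labels already match the final centroids
        centroids = new_centroids
    else:
        # Iteration budget exhausted: centroids moved after the last
        # assignment, so refresh labels and inertia once.
        labels, inertia = _assign(X, centroids)
    return centroids, inertia, n_iter, (labels if return_labels else None)
```

A caller that wants the old fit_predict behavior would then write something like `centroids, inertia, n_iter, labels = fit(X, init, return_labels=True)`.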

@cjnolet changed the base branch from main to release/26.04 on March 25, 2026 18:39
@cjnolet requested review from a team as code owners on March 25, 2026 18:39
@cjnolet requested a review from bdice on March 25, 2026 18:39
@tfeher changed the base branch from release/26.04 to main on April 1, 2026 18:33
@aamijar added the non-breaking (Introduces a non-breaking change) and improvement (Improves an existing functionality) labels on Apr 1, 2026
Labels

improvement: Improves an existing functionality
non-breaking: Introduces a non-breaking change
