From d877a928230b380fa5aa163f6d477ce21575b5b6 Mon Sep 17 00:00:00 2001 From: aamijar Date: Tue, 26 May 2026 00:47:49 +0000 Subject: [PATCH 1/2] pca-row-major --- c/include/cuvs/preprocessing/pca.h | 38 +++-- c/src/preprocessing/pca.cpp | 125 ++++++++++------ cpp/include/cuvs/preprocessing/pca.hpp | 105 +++++++++++++- cpp/src/preprocessing/pca/detail/pca.cuh | 30 ++-- cpp/src/preprocessing/pca/pca.cu | 120 ++++++++-------- cpp/tests/preprocessing/pca.cu | 158 +++++++++++++++++++-- python/cuvs/cuvs/preprocessing/pca/pca.pyx | 115 ++++++++++----- python/cuvs/cuvs/tests/test_pca.py | 114 +++++++++++++-- 8 files changed, 612 insertions(+), 193 deletions(-) diff --git a/c/include/cuvs/preprocessing/pca.h b/c/include/cuvs/preprocessing/pca.h index 3ed144d166..44818afece 100644 --- a/c/include/cuvs/preprocessing/pca.h +++ b/c/include/cuvs/preprocessing/pca.h @@ -85,6 +85,9 @@ CUVS_EXPORT cuvsError_t cuvsPcaParamsDestroy(cuvsPcaParams_t params); * Computes the principal components, explained variances, singular values, and column means * from the input data. * + * The layout of `input` (C-contiguous / row-major or F-contiguous / col-major) is detected + * from its DLPack strides; `components` must use the same layout as `input`. 
+ * * @code {.c} * #include * #include @@ -98,9 +101,9 @@ CUVS_EXPORT cuvsError_t cuvsPcaParamsDestroy(cuvsPcaParams_t params); * cuvsPcaParamsCreate(&params); * params->n_components = 2; * - * // Assume populated DLManagedTensor objects (col-major, float32, device memory) - * DLManagedTensor input; // [n_rows x n_cols] - * DLManagedTensor components; // [n_components x n_cols] + * // Assume populated DLManagedTensor objects (float32, device memory) + * DLManagedTensor input; // [n_rows x n_cols] (C- or F-contiguous) + * DLManagedTensor components; // [n_components x n_cols] (same layout as input) * DLManagedTensor explained_var; // [n_components] * DLManagedTensor explained_var_ratio; // [n_components] * DLManagedTensor singular_vals; // [n_components] @@ -117,8 +120,8 @@ CUVS_EXPORT cuvsError_t cuvsPcaParamsDestroy(cuvsPcaParams_t params); * * @param[in] res cuvsResources_t opaque C handle * @param[in] params PCA parameters - * @param[inout] input input data [n_rows x n_cols] (col-major, float32, device) - * @param[out] components principal components [n_components x n_cols] (col-major, float32, device) + * @param[inout] input input data [n_rows x n_cols] (C- or F-contiguous, float32, device) + * @param[out] components principal components [n_components x n_cols] (same layout as input) * @param[out] explained_var explained variances [n_components] (float32, device) * @param[out] explained_var_ratio explained variance ratios [n_components] (float32, device) * @param[out] singular_vals singular values [n_components] (float32, device) @@ -142,12 +145,14 @@ CUVS_EXPORT cuvsError_t cuvsPcaFit(cuvsResources_t res, * @brief Perform PCA fit and transform in a single operation. * * Computes the principal components and transforms the input data into the eigenspace. + * The layout of `input` (C- or F-contiguous) is detected from its DLPack strides; all + * other matrix tensors must use the same layout. 
* * @param[in] res cuvsResources_t opaque C handle * @param[in] params PCA parameters - * @param[inout] input input data [n_rows x n_cols] (col-major, float32, device) - * @param[out] trans_input transformed data [n_rows x n_components] (col-major, float32, device) - * @param[out] components principal components [n_components x n_cols] (col-major, float32, device) + * @param[inout] input input data [n_rows x n_cols] (C- or F-contiguous, float32, device) + * @param[out] trans_input transformed data [n_rows x n_components] (same layout as input) + * @param[out] components principal components [n_components x n_cols] (same layout as input) * @param[out] explained_var explained variances [n_components] (float32, device) * @param[out] explained_var_ratio explained variance ratios [n_components] (float32, device) * @param[out] singular_vals singular values [n_components] (float32, device) @@ -172,14 +177,16 @@ CUVS_EXPORT cuvsError_t cuvsPcaFitTransform(cuvsResources_t res, * @brief Perform PCA transform operation. * * Transforms the input data into the eigenspace using previously computed principal components. + * The layout of `input` (C- or F-contiguous) is detected from its DLPack strides; all other + * matrix tensors must use the same layout. 
* * @param[in] res cuvsResources_t opaque C handle * @param[in] params PCA parameters - * @param[inout] input data to transform [n_rows x n_cols] (col-major, float32, device) - * @param[in] components principal components [n_components x n_cols] (col-major, float32, device) + * @param[inout] input data to transform [n_rows x n_cols] (C- or F-contiguous, float32, device) + * @param[in] components principal components [n_components x n_cols] (same layout as input) * @param[in] singular_vals singular values [n_components] (float32, device) * @param[in] mu column means [n_cols] (float32, device) - * @param[out] trans_input transformed data [n_rows x n_components] (col-major, float32, device) + * @param[out] trans_input transformed data [n_rows x n_components] (same layout as input) * @return cuvsError_t */ CUVS_EXPORT cuvsError_t cuvsPcaTransform(cuvsResources_t res, @@ -194,14 +201,17 @@ CUVS_EXPORT cuvsError_t cuvsPcaTransform(cuvsResources_t res, * @brief Perform PCA inverse transform operation. * * Transforms data from the eigenspace back to the original space. + * The layout of `trans_input` (C- or F-contiguous) is detected from its DLPack strides; + * all other matrix tensors must use the same layout. 
* * @param[in] res cuvsResources_t opaque C handle * @param[in] params PCA parameters - * @param[in] trans_input transformed data [n_rows x n_components] (col-major, float32, device) - * @param[in] components principal components [n_components x n_cols] (col-major, float32, device) + * @param[in] trans_input transformed data [n_rows x n_components] (C- or F-contiguous, + * float32, device) + * @param[in] components principal components [n_components x n_cols] (same layout as trans_input) * @param[in] singular_vals singular values [n_components] (float32, device) * @param[in] mu column means [n_cols] (float32, device) - * @param[out] output reconstructed data [n_rows x n_cols] (col-major, float32, device) + * @param[out] output reconstructed data [n_rows x n_cols] (same layout as trans_input) * @return cuvsError_t */ CUVS_EXPORT cuvsError_t cuvsPcaInverseTransform(cuvsResources_t res, diff --git a/c/src/preprocessing/pca.cpp b/c/src/preprocessing/pca.cpp index 64acfeb30b..30dd98fb64 100644 --- a/c/src/preprocessing/pca.cpp +++ b/c/src/preprocessing/pca.cpp @@ -40,6 +40,7 @@ cuvs::preprocessing::pca::params to_cpp_params(const cuvsPcaParams& c_params) return cpp_params; } +template void _fit(cuvsResources_t res, const cuvsPcaParams& params, DLManagedTensor* input_tensor, @@ -54,12 +55,12 @@ void _fit(cuvsResources_t res, auto res_ptr = reinterpret_cast(res); auto cpp_params = to_cpp_params(params); - using matrix_type = raft::device_matrix_view; + using matrix_type = raft::device_matrix_view; using vector_type = raft::device_vector_view; using scalar_type = raft::device_scalar_view; - auto input = cuvs::core::from_dlpack(input_tensor); - auto components = cuvs::core::from_dlpack(components_tensor); + auto input = cuvs::core::from_dlpack(input_tensor); + auto components = cuvs::core::from_dlpack(components_tensor); auto explained_var = cuvs::core::from_dlpack(explained_var_tensor); auto explained_var_ratio = cuvs::core::from_dlpack(explained_var_ratio_tensor); auto 
singular_vals = cuvs::core::from_dlpack(singular_vals_tensor); @@ -78,6 +79,7 @@ void _fit(cuvsResources_t res, flip_signs_based_on_U); } +template void _fit_transform(cuvsResources_t res, const cuvsPcaParams& params, DLManagedTensor* input_tensor, @@ -93,13 +95,13 @@ void _fit_transform(cuvsResources_t res, auto res_ptr = reinterpret_cast(res); auto cpp_params = to_cpp_params(params); - using matrix_type = raft::device_matrix_view; + using matrix_type = raft::device_matrix_view; using vector_type = raft::device_vector_view; using scalar_type = raft::device_scalar_view; - auto input = cuvs::core::from_dlpack(input_tensor); - auto trans_input = cuvs::core::from_dlpack(trans_input_tensor); - auto components = cuvs::core::from_dlpack(components_tensor); + auto input = cuvs::core::from_dlpack(input_tensor); + auto trans_input = cuvs::core::from_dlpack(trans_input_tensor); + auto components = cuvs::core::from_dlpack(components_tensor); auto explained_var = cuvs::core::from_dlpack(explained_var_tensor); auto explained_var_ratio = cuvs::core::from_dlpack(explained_var_ratio_tensor); auto singular_vals = cuvs::core::from_dlpack(singular_vals_tensor); @@ -119,6 +121,7 @@ void _fit_transform(cuvsResources_t res, flip_signs_based_on_U); } +template void _transform(cuvsResources_t res, const cuvsPcaParams& params, DLManagedTensor* input_tensor, @@ -130,7 +133,7 @@ void _transform(cuvsResources_t res, auto res_ptr = reinterpret_cast(res); auto cpp_params = to_cpp_params(params); - using matrix_type = raft::device_matrix_view; + using matrix_type = raft::device_matrix_view; using vector_type = raft::device_vector_view; auto input = cuvs::core::from_dlpack(input_tensor); @@ -143,6 +146,7 @@ void _transform(cuvsResources_t res, *res_ptr, cpp_params, input, components, singular_vals, mu, trans_input); } +template void _inverse_transform(cuvsResources_t res, const cuvsPcaParams& params, DLManagedTensor* trans_input_tensor, @@ -154,7 +158,7 @@ void _inverse_transform(cuvsResources_t 
res, auto res_ptr = reinterpret_cast(res); auto cpp_params = to_cpp_params(params); - using matrix_type = raft::device_matrix_view; + using matrix_type = raft::device_matrix_view; using vector_type = raft::device_vector_view; auto trans_input = cuvs::core::from_dlpack(trans_input_tensor); @@ -205,19 +209,32 @@ extern "C" cuvsError_t cuvsPcaFit(cuvsResources_t res, "PCA input must be float32 (kDLFloat, 32 bits)"); RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(input->dl_tensor), "PCA input must be device-accessible memory"); - RAFT_EXPECTS(cuvs::core::is_f_contiguous(input), - "PCA input must be col-major (Fortran-contiguous)"); - - _fit(res, - *params, - input, - components, - explained_var, - explained_var_ratio, - singular_vals, - mu, - noise_vars, - flip_signs_based_on_U); + + if (cuvs::core::is_f_contiguous(input)) { + _fit(res, + *params, + input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); + } else if (cuvs::core::is_c_contiguous(input)) { + _fit(res, + *params, + input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); + } else { + RAFT_FAIL("PCA input must be contiguous (C- or F-order)"); + } }); } @@ -239,20 +256,34 @@ extern "C" cuvsError_t cuvsPcaFitTransform(cuvsResources_t res, "PCA input must be float32 (kDLFloat, 32 bits)"); RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(input->dl_tensor), "PCA input must be device-accessible memory"); - RAFT_EXPECTS(cuvs::core::is_f_contiguous(input), - "PCA input must be col-major (Fortran-contiguous)"); - - _fit_transform(res, - *params, - input, - trans_input, - components, - explained_var, - explained_var_ratio, - singular_vals, - mu, - noise_vars, - flip_signs_based_on_U); + + if (cuvs::core::is_f_contiguous(input)) { + _fit_transform(res, + *params, + input, + trans_input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + 
noise_vars, + flip_signs_based_on_U); + } else if (cuvs::core::is_c_contiguous(input)) { + _fit_transform(res, + *params, + input, + trans_input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); + } else { + RAFT_FAIL("PCA input must be contiguous (C- or F-order)"); + } }); } @@ -270,10 +301,14 @@ extern "C" cuvsError_t cuvsPcaTransform(cuvsResources_t res, "PCA input must be float32 (kDLFloat, 32 bits)"); RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(input->dl_tensor), "PCA input must be device-accessible memory"); - RAFT_EXPECTS(cuvs::core::is_f_contiguous(input), - "PCA input must be col-major (Fortran-contiguous)"); - _transform(res, *params, input, components, singular_vals, mu, trans_input); + if (cuvs::core::is_f_contiguous(input)) { + _transform(res, *params, input, components, singular_vals, mu, trans_input); + } else if (cuvs::core::is_c_contiguous(input)) { + _transform(res, *params, input, components, singular_vals, mu, trans_input); + } else { + RAFT_FAIL("PCA input must be contiguous (C- or F-order)"); + } }); } @@ -291,9 +326,15 @@ extern "C" cuvsError_t cuvsPcaInverseTransform(cuvsResources_t res, "PCA trans_input must be float32 (kDLFloat, 32 bits)"); RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(trans_input->dl_tensor), "PCA trans_input must be device-accessible memory"); - RAFT_EXPECTS(cuvs::core::is_f_contiguous(trans_input), - "PCA trans_input must be col-major (Fortran-contiguous)"); - _inverse_transform(res, *params, trans_input, components, singular_vals, mu, output); + if (cuvs::core::is_f_contiguous(trans_input)) { + _inverse_transform( + res, *params, trans_input, components, singular_vals, mu, output); + } else if (cuvs::core::is_c_contiguous(trans_input)) { + _inverse_transform( + res, *params, trans_input, components, singular_vals, mu, output); + } else { + RAFT_FAIL("PCA trans_input must be contiguous (C- or F-order)"); + } }); } diff --git 
a/cpp/include/cuvs/preprocessing/pca.hpp b/cpp/include/cuvs/preprocessing/pca.hpp index 361b3f0075..29259994e4 100644 --- a/cpp/include/cuvs/preprocessing/pca.hpp +++ b/cpp/include/cuvs/preprocessing/pca.hpp @@ -58,7 +58,7 @@ struct params { */ /** - * @brief Perform PCA fit operation. + * @brief Perform PCA fit operation (col-major input). * * Computes the principal components, explained variances, singular values, and column means * from the input data. @@ -111,7 +111,36 @@ void fit(raft::resources const& handle, bool flip_signs_based_on_U = false); /** - * @brief Perform PCA fit and transform operations. + * @brief Perform PCA fit operation (row-major input). + * + * Same as the col-major overload, but operates natively on row-major (C-contiguous) data + * with no internal copy/transpose of the input. The output `components` matrix is also + * row-major. + * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[inout] input input data [n_rows x n_cols] (row-major). Modified temporarily. + * @param[out] components principal components [n_components x n_cols] (row-major) + * @param[out] explained_var explained variances [n_components] + * @param[out] explained_var_ratio explained variance ratios [n_components] + * @param[out] singular_vals singular values [n_components] + * @param[out] mu column means [n_cols] + * @param[out] noise_vars noise variance (scalar) + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +void fit(raft::resources const& handle, + const params& config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false); + +/** + * @brief Perform PCA fit and transform operations (col-major). 
* * Computes the principal components and transforms the input data into the eigenspace * in a single operation. @@ -141,7 +170,36 @@ void fit_transform(raft::resources const& handle, bool flip_signs_based_on_U = false); /** - * @brief Perform PCA transform operation. + * @brief Perform PCA fit and transform operations (row-major). + * + * Same as the col-major overload but operates natively on row-major data. + * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[inout] input input data [n_rows x n_cols] (row-major). Modified temporarily. + * @param[out] trans_input transformed data [n_rows x n_components] (row-major) + * @param[out] components principal components [n_components x n_cols] (row-major) + * @param[out] explained_var explained variances [n_components] + * @param[out] explained_var_ratio explained variance ratios [n_components] + * @param[out] singular_vals singular values [n_components] + * @param[out] mu column means [n_cols] + * @param[out] noise_vars noise variance (scalar) + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +void fit_transform(raft::resources const& handle, + const params& config, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false); + +/** + * @brief Perform PCA transform operation (col-major). * * Transforms the input data into the eigenspace using previously computed principal components. * @@ -163,7 +221,27 @@ void transform(raft::resources const& handle, raft::device_matrix_view trans_input); /** - * @brief Perform PCA inverse transform operation. + * @brief Perform PCA transform operation (row-major). 
+ * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[inout] input data to transform [n_rows x n_cols] (row-major). Modified temporarily + * (mean-centered then restored). + * @param[in] components principal components [n_components x n_cols] (row-major) + * @param[in] singular_vals singular values [n_components] + * @param[in] mu column means [n_cols] + * @param[out] trans_input transformed data [n_rows x n_components] (row-major) + */ +void transform(raft::resources const& handle, + const params& config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view trans_input); + +/** + * @brief Perform PCA inverse transform operation (col-major). * * Transforms data from the eigenspace back to the original space. * @@ -183,6 +261,25 @@ void inverse_transform(raft::resources const& handle, raft::device_vector_view mu, raft::device_matrix_view output); +/** + * @brief Perform PCA inverse transform operation (row-major). 
+ * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[in] trans_input transformed data [n_rows x n_components] (row-major) + * @param[in] components principal components [n_components x n_cols] (row-major) + * @param[in] singular_vals singular values [n_components] + * @param[in] mu column means [n_cols] + * @param[out] output reconstructed data [n_rows x n_cols] (row-major) + */ +void inverse_transform(raft::resources const& handle, + const params& config, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view output); + /** @} */ // end group pca } // namespace pca diff --git a/cpp/src/preprocessing/pca/detail/pca.cuh b/cpp/src/preprocessing/pca/detail/pca.cuh index 8cd9e38a72..df73326581 100644 --- a/cpp/src/preprocessing/pca/detail/pca.cuh +++ b/cpp/src/preprocessing/pca/detail/pca.cuh @@ -27,11 +27,11 @@ inline auto to_raft_params(const params& config) -> raft::linalg::paramsPCA return prms; } -template +template void fit(raft::resources const& handle, const params& config, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -52,12 +52,12 @@ void fit(raft::resources const& handle, flip_signs_based_on_U); } -template +template void fit_transform(raft::resources const& handle, const params& config, - raft::device_matrix_view input, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -79,27 +79,27 @@ void 
fit_transform(raft::resources const& handle, flip_signs_based_on_U); } -template +template void transform(raft::resources const& handle, const params& config, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view trans_input) + raft::device_matrix_view trans_input) { auto raft_prms = to_raft_params(config); raft::linalg::pca_transform(handle, raft_prms, input, components, singular_vals, mu, trans_input); } -template +template void inverse_transform(raft::resources const& handle, const params& config, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view output) + raft::device_matrix_view output) { auto raft_prms = to_raft_params(config); raft::linalg::pca_inverse_transform( diff --git a/cpp/src/preprocessing/pca/pca.cu b/cpp/src/preprocessing/pca/pca.cu index ac944fddd7..2b18ee77b4 100644 --- a/cpp/src/preprocessing/pca/pca.cu +++ b/cpp/src/preprocessing/pca/pca.cu @@ -9,90 +9,94 @@ namespace cuvs::preprocessing::pca { -#define CUVS_INST_PCA_FIT(DataT, IndexT) \ - void fit(raft::resources const& handle, \ - const params& config, \ - raft::device_matrix_view input, \ - raft::device_matrix_view components, \ - raft::device_vector_view explained_var, \ - raft::device_vector_view explained_var_ratio, \ - raft::device_vector_view singular_vals, \ - raft::device_vector_view mu, \ - raft::device_scalar_view noise_vars, \ - bool flip_signs_based_on_U) \ - { \ - detail::fit(handle, \ - config, \ - input, \ - components, \ - explained_var, \ - explained_var_ratio, \ - singular_vals, \ - mu, \ - noise_vars, \ - flip_signs_based_on_U); \ +#define CUVS_INST_PCA_FIT(DataT, IndexT, LayoutT) \ + void 
fit(raft::resources const& handle, \ + const params& config, \ + raft::device_matrix_view input, \ + raft::device_matrix_view components, \ + raft::device_vector_view explained_var, \ + raft::device_vector_view explained_var_ratio, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_scalar_view noise_vars, \ + bool flip_signs_based_on_U) \ + { \ + detail::fit(handle, \ + config, \ + input, \ + components, \ + explained_var, \ + explained_var_ratio, \ + singular_vals, \ + mu, \ + noise_vars, \ + flip_signs_based_on_U); \ } -CUVS_INST_PCA_FIT(float, int64_t); +CUVS_INST_PCA_FIT(float, int64_t, raft::col_major); +CUVS_INST_PCA_FIT(float, int64_t, raft::row_major); #undef CUVS_INST_PCA_FIT -#define CUVS_INST_PCA_FIT_TRANSFORM(DataT, IndexT) \ - void fit_transform(raft::resources const& handle, \ - const params& config, \ - raft::device_matrix_view input, \ - raft::device_matrix_view trans_input, \ - raft::device_matrix_view components, \ - raft::device_vector_view explained_var, \ - raft::device_vector_view explained_var_ratio, \ - raft::device_vector_view singular_vals, \ - raft::device_vector_view mu, \ - raft::device_scalar_view noise_vars, \ - bool flip_signs_based_on_U) \ - { \ - detail::fit_transform(handle, \ - config, \ - input, \ - trans_input, \ - components, \ - explained_var, \ - explained_var_ratio, \ - singular_vals, \ - mu, \ - noise_vars, \ - flip_signs_based_on_U); \ +#define CUVS_INST_PCA_FIT_TRANSFORM(DataT, IndexT, LayoutT) \ + void fit_transform(raft::resources const& handle, \ + const params& config, \ + raft::device_matrix_view input, \ + raft::device_matrix_view trans_input, \ + raft::device_matrix_view components, \ + raft::device_vector_view explained_var, \ + raft::device_vector_view explained_var_ratio, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_scalar_view noise_vars, \ + bool flip_signs_based_on_U) \ + { \ + detail::fit_transform(handle, \ + 
config, \ + input, \ + trans_input, \ + components, \ + explained_var, \ + explained_var_ratio, \ + singular_vals, \ + mu, \ + noise_vars, \ + flip_signs_based_on_U); \ } -CUVS_INST_PCA_FIT_TRANSFORM(float, int64_t); +CUVS_INST_PCA_FIT_TRANSFORM(float, int64_t, raft::col_major); +CUVS_INST_PCA_FIT_TRANSFORM(float, int64_t, raft::row_major); #undef CUVS_INST_PCA_FIT_TRANSFORM -#define CUVS_INST_PCA_TRANSFORM(DataT, IndexT) \ +#define CUVS_INST_PCA_TRANSFORM(DataT, IndexT, LayoutT) \ void transform(raft::resources const& handle, \ const params& config, \ - raft::device_matrix_view input, \ - raft::device_matrix_view components, \ + raft::device_matrix_view input, \ + raft::device_matrix_view components, \ raft::device_vector_view singular_vals, \ raft::device_vector_view mu, \ - raft::device_matrix_view trans_input) \ + raft::device_matrix_view trans_input) \ { \ detail::transform(handle, config, input, components, singular_vals, mu, trans_input); \ } -CUVS_INST_PCA_TRANSFORM(float, int64_t); +CUVS_INST_PCA_TRANSFORM(float, int64_t, raft::col_major); +CUVS_INST_PCA_TRANSFORM(float, int64_t, raft::row_major); #undef CUVS_INST_PCA_TRANSFORM -#define CUVS_INST_PCA_INVERSE_TRANSFORM(DataT, IndexT) \ +#define CUVS_INST_PCA_INVERSE_TRANSFORM(DataT, IndexT, LayoutT) \ void inverse_transform(raft::resources const& handle, \ const params& config, \ - raft::device_matrix_view trans_input, \ - raft::device_matrix_view components, \ + raft::device_matrix_view trans_input, \ + raft::device_matrix_view components, \ raft::device_vector_view singular_vals, \ raft::device_vector_view mu, \ - raft::device_matrix_view output) \ + raft::device_matrix_view output) \ { \ detail::inverse_transform(handle, config, trans_input, components, singular_vals, mu, output); \ } -CUVS_INST_PCA_INVERSE_TRANSFORM(float, int64_t); +CUVS_INST_PCA_INVERSE_TRANSFORM(float, int64_t, raft::col_major); +CUVS_INST_PCA_INVERSE_TRANSFORM(float, int64_t, raft::row_major); #undef CUVS_INST_PCA_INVERSE_TRANSFORM 
} // namespace cuvs::preprocessing::pca diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index 4430972911..6b99d80b67 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -42,11 +42,12 @@ template /** * @brief Run fit_transform followed by inverse_transform. * + * Templated on layout to exercise both col-major and row-major paths. * Intermediate buffers are managed internally unless the caller provides * pre-allocated pointers via the optional parameters, in which case the * results are written there directly. */ -template +template void pca_roundtrip(raft::resources const& handle, T* input, int n_rows, @@ -79,19 +80,17 @@ void pca_roundtrip(raft::resources const& handle, rmm::device_uvector mu(n_cols, stream); rmm::device_uvector nv(1, stream); - auto input_view = - raft::make_device_matrix_view(input, n_rows, n_cols); + auto input_view = raft::make_device_matrix_view(input, n_rows, n_cols); auto trans_view = - raft::make_device_matrix_view(trans_ptr, n_rows, n_components); + raft::make_device_matrix_view(trans_ptr, n_rows, n_components); auto comp_view = - raft::make_device_matrix_view(comp_ptr, n_components, n_cols); - auto ev_view = raft::make_device_vector_view(ev_ptr, n_components); - auto evr_view = raft::make_device_vector_view(evr.data(), n_components); - auto sv_view = raft::make_device_vector_view(sv.data(), n_components); - auto mu_view = raft::make_device_vector_view(mu.data(), n_cols); - auto nv_view = raft::make_device_scalar_view(nv.data()); - auto output_view = - raft::make_device_matrix_view(output, n_rows, n_cols); + raft::make_device_matrix_view(comp_ptr, n_components, n_cols); + auto ev_view = raft::make_device_vector_view(ev_ptr, n_components); + auto evr_view = raft::make_device_vector_view(evr.data(), n_components); + auto sv_view = raft::make_device_vector_view(sv.data(), n_components); + auto mu_view = raft::make_device_vector_view(mu.data(), n_cols); + auto nv_view = 
raft::make_device_scalar_view(nv.data()); + auto output_view = raft::make_device_matrix_view(output, n_rows, n_cols); fit_transform( handle, prms, input_view, trans_view, comp_view, ev_view, evr_view, sv_view, mu_view, nv_view); @@ -263,4 +262,139 @@ TEST_P(PcaTestF, Result) { this->testPca(); } INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestF, ::testing::ValuesIn(inputsf2)); +/** + * Row-major end-to-end test: runs fit_transform + inverse_transform on row-major + * inputs and verifies the reconstruction matches the original. Also checks that + * row-major and col-major inputs (representing the same logical data) produce + * the same explained variances and reconstructed data. + */ +template +class PcaRowMajorTest : public ::testing::TestWithParam> { + public: + PcaRowMajorTest() + : params_(::testing::TestWithParam>::GetParam()), + stream(raft::resource::get_cuda_stream(handle)) + { + } + + protected: + // Convert col-major data of shape (n_rows, n_cols) into row-major in dst. + // Both buffers live on device. 
+ void to_row_major(const T* col_major_src, T* row_major_dst, int n_rows, int n_cols) + { + std::vector host_col(n_rows * n_cols); + std::vector host_row(n_rows * n_cols); + raft::update_host(host_col.data(), col_major_src, n_rows * n_cols, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (int i = 0; i < n_rows; ++i) { + for (int j = 0; j < n_cols; ++j) { + host_row[i * n_cols + j] = host_col[j * n_rows + i]; + } + } + raft::update_device(row_major_dst, host_row.data(), n_rows * n_cols, stream); + } + + void testRowMajorRoundtrip() + { + int n_rows = params_.n_row2; + int n_cols = params_.n_col2; + int len = n_rows * n_cols; + + rmm::device_uvector data_col(len, stream); + rmm::device_uvector data_row(len, stream); + rmm::device_uvector data_back_row(len, stream); + + raft::random::Rng r(params_.seed, raft::random::GenPC); + r.uniform(data_col.data(), len, T(-1.0), T(1.0), stream); + to_row_major(data_col.data(), data_row.data(), n_rows, n_cols); + + pca_roundtrip( + handle, data_row.data(), n_rows, n_cols, data_back_row.data(), n_cols, params_.algo, stream); + + ASSERT_TRUE(devArrMatch(data_row.data(), + data_back_row.data(), + len, + cuvs::CompareApprox(params_.tolerance), + stream)); + } + + void testLayoutsAgreeNumerically() + { + int n_rows = params_.n_row2; + int n_cols = params_.n_col2; + int len = n_rows * n_cols; + + rmm::device_uvector data_col(len, stream); + rmm::device_uvector data_row(len, stream); + + raft::random::Rng r(params_.seed, raft::random::GenPC); + r.uniform(data_col.data(), len, T(-1.0), T(1.0), stream); + to_row_major(data_col.data(), data_row.data(), n_rows, n_cols); + + // Run col-major path + rmm::device_uvector col_back(len, stream); + rmm::device_uvector col_ev(n_cols, stream); + rmm::device_uvector col_components(n_cols * n_cols, stream); + rmm::device_uvector col_trans(len, stream); + pca_roundtrip(handle, + data_col.data(), + n_rows, + n_cols, + col_back.data(), + n_cols, + params_.algo, + stream, + 
col_components.data(), + col_ev.data(), + col_trans.data()); + + // Run row-major path on the same logical data + rmm::device_uvector row_back(len, stream); + rmm::device_uvector row_ev(n_cols, stream); + rmm::device_uvector row_components(n_cols * n_cols, stream); + rmm::device_uvector row_trans(len, stream); + pca_roundtrip(handle, + data_row.data(), + n_rows, + n_cols, + row_back.data(), + n_cols, + params_.algo, + stream, + row_components.data(), + row_ev.data(), + row_trans.data()); + + // Explained variances and reconstructions should agree across layouts. + ASSERT_TRUE(devArrMatch( + col_ev.data(), row_ev.data(), n_cols, cuvs::CompareApprox(params_.tolerance), stream)); + + // Reconstructions are stored in their native layouts; compare element-by- + // element after a host-side reshape of the row-major result back to col. + std::vector col_back_h(len); + std::vector row_back_h(len); + raft::update_host(col_back_h.data(), col_back.data(), len, stream); + raft::update_host(row_back_h.data(), row_back.data(), len, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (int i = 0; i < n_rows; ++i) { + for (int j = 0; j < n_cols; ++j) { + T col_val = col_back_h[j * n_rows + i]; + T row_val = row_back_h[i * n_cols + j]; + ASSERT_NEAR(col_val, row_val, params_.tolerance) << "Mismatch at (" << i << "," << j << ")"; + } + } + } + + private: + raft::device_resources handle; + cudaStream_t stream; + PcaInputs params_; +}; + +typedef PcaRowMajorTest PcaRowMajorTestF; +TEST_P(PcaRowMajorTestF, Roundtrip) { this->testRowMajorRoundtrip(); } +TEST_P(PcaRowMajorTestF, AgreesWithColMajor) { this->testLayoutsAgreeNumerically(); } + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaRowMajorTestF, ::testing::ValuesIn(inputsf2)); + } // end namespace cuvs::preprocessing::pca diff --git a/python/cuvs/cuvs/preprocessing/pca/pca.pyx b/python/cuvs/cuvs/preprocessing/pca/pca.pyx index b751e13c6d..8067ef8287 100644 --- a/python/cuvs/cuvs/preprocessing/pca/pca.pyx +++ 
b/python/cuvs/cuvs/preprocessing/pca/pca.pyx @@ -18,7 +18,6 @@ from pylibraft.common.cai_wrapper import wrap_array from cuvs.common.exceptions import check_cuvs from cuvs.common.resources import auto_sync_resources -from cuvs.neighbors.common import _check_input_array SOLVER_TYPES = { "cov_eig_dq": cuvsPcaSolver.CUVS_PCA_COV_EIG_DQ, @@ -116,11 +115,30 @@ FitTransformOutput = namedtuple( ) -def _to_f_order(ary): - """Ensure a device array is Fortran-contiguous (col-major).""" +def _ensure_device_contiguous(ary, dtype=np.dtype("float32")): + """Ensure a device array is contiguous in either C- or F-order. + + Returns a tuple (arr, order) where ``order`` is ``"C"`` or ``"F"``. + No copy is performed if the input is already contiguous in some order. + """ if hasattr(ary, "__cuda_array_interface__"): - return cp.asfortranarray(cp.asarray(ary)) - return np.asfortranarray(np.asarray(ary)) + ary = cp.asarray(ary, dtype=dtype) + else: + ary = cp.asarray(np.asarray(ary, dtype=dtype)) + + if ary.flags.c_contiguous: + return ary, "C" + if ary.flags.f_contiguous: + return ary, "F" + return cp.ascontiguousarray(ary), "C" + + +def _validate_pca_input(x_ai, expected_dtypes): + """Verify ``x_ai`` has a supported dtype and is contiguous (C or F).""" + if x_ai.dtype not in expected_dtypes: + raise TypeError("dtype %s not supported" % x_ai.dtype) + if not (x_ai.c_contiguous or x_ai.f_contiguous): + raise ValueError("Input must be contiguous in C- or F-order") @auto_sync_resources @@ -131,13 +149,17 @@ def fit(Params params, X, resources=None): Computes the principal components, explained variances, singular values, and column means from the input data. + The input layout (C-contiguous / row-major or F-contiguous / col-major) + is preserved natively; no internal copy/transpose is performed. Output + arrays use the same layout as the input. + Parameters ---------- params : Params PCA parameters. ``params.copy`` should be True if you intend to reuse *X* after this call. 
X : device array-like, shape (n_samples, n_features), float32 - Input data (will be converted to col-major device memory). + Input data. Must be contiguous in either C- or F-order. {resources_docstring} Returns @@ -145,7 +167,7 @@ def fit(Params params, X, resources=None): FitOutput Named tuple with fields: ``components``, ``explained_var``, ``explained_var_ratio``, ``singular_vals``, ``mu``, - ``noise_vars``. + ``noise_vars``. ``components`` matches the layout of *X*. Examples -------- @@ -159,12 +181,12 @@ def fit(Params params, X, resources=None): """ n_components = params.n_components - X_f = _to_f_order(X) - x_ai = wrap_array(X_f) - _check_input_array(x_ai, [np.dtype("float32")], exp_row_major=False) + X_arr, order = _ensure_device_contiguous(X) + x_ai = wrap_array(X_arr) + _validate_pca_input(x_ai, [np.dtype("float32")]) n_rows, n_cols = x_ai.shape - components = cp.empty((n_components, n_cols), dtype="float32", order="F") + components = cp.empty((n_components, n_cols), dtype="float32", order=order) explained_var = cp.empty((n_components,), dtype="float32") explained_var_ratio = cp.empty((n_components,), dtype="float32") singular_vals = cp.empty((n_components,), dtype="float32") @@ -201,12 +223,15 @@ def fit_transform(Params params, X, resources=None): """ Compute PCA and transform the input data in a single operation. + The input layout (C- or F-contiguous) is preserved natively; output + arrays use the same layout. + Parameters ---------- params : Params PCA parameters. X : device array-like, shape (n_samples, n_features), float32 - Input data (will be converted to col-major device memory). + Input data. Must be contiguous in either C- or F-order. {resources_docstring} Returns @@ -214,7 +239,8 @@ def fit_transform(Params params, X, resources=None): FitTransformOutput Named tuple with fields: ``trans_input``, ``components``, ``explained_var``, ``explained_var_ratio``, ``singular_vals``, - ``mu``, ``noise_vars``. + ``mu``, ``noise_vars``. 
``trans_input`` and ``components`` match + the layout of *X*. Examples -------- @@ -228,13 +254,14 @@ def fit_transform(Params params, X, resources=None): """ n_components = params.n_components - X_f = _to_f_order(X) - x_ai = wrap_array(X_f) - _check_input_array(x_ai, [np.dtype("float32")], exp_row_major=False) + X_arr, order = _ensure_device_contiguous(X) + x_ai = wrap_array(X_arr) + _validate_pca_input(x_ai, [np.dtype("float32")]) n_rows, n_cols = x_ai.shape - trans_input = cp.empty((n_rows, n_components), dtype="float32", order="F") - components = cp.empty((n_components, n_cols), dtype="float32", order="F") + trans_input = cp.empty((n_rows, n_components), + dtype="float32", order=order) + components = cp.empty((n_components, n_cols), dtype="float32", order=order) explained_var = cp.empty((n_components,), dtype="float32") explained_var_ratio = cp.empty((n_components,), dtype="float32") singular_vals = cp.empty((n_components,), dtype="float32") @@ -270,6 +297,18 @@ def fit_transform(Params params, X, resources=None): noise_vars) +def _match_layout(arr, order): + """Return ``arr`` in the requested layout (``"C"`` or ``"F"``). + + Avoids a copy when the array is already in the target layout.""" + arr = cp.asarray(arr) + if order == "C" and arr.flags.c_contiguous: + return arr + if order == "F" and arr.flags.f_contiguous: + return arr + return cp.asarray(arr, order=order) + + @auto_sync_resources @auto_convert_output def transform(Params params, X, components, singular_vals, mu, @@ -278,7 +317,9 @@ def transform(Params params, X, components, singular_vals, mu, Transform data into the PCA eigenspace. Uses previously computed principal components from :func:`fit` or - :func:`fit_transform`. + :func:`fit_transform`. The input layout (C- or F-contiguous) of *X* + determines the layout used internally; ``components`` and + ``trans_input`` are aligned to that layout. 
Parameters ---------- @@ -293,7 +334,7 @@ def transform(Params params, X, components, singular_vals, mu, mu : device array-like, shape (n_features,) Column means from a prior fit. trans_input : optional device array, shape (n_samples, n_components) - Pre-allocated output buffer (col-major, float32). + Pre-allocated output buffer (float32). Layout is matched to *X*. {resources_docstring} Returns @@ -312,26 +353,26 @@ def transform(Params params, X, components, singular_vals, mu, """ n_components = params.n_components - X_f = _to_f_order(X) - x_ai = wrap_array(X_f) - _check_input_array(x_ai, [np.dtype("float32")], exp_row_major=False) + X_arr, order = _ensure_device_contiguous(X) + x_ai = wrap_array(X_arr) + _validate_pca_input(x_ai, [np.dtype("float32")]) n_rows = x_ai.shape[0] - components_f = _to_f_order(components) + components_arr = _match_layout(components, order) singular_vals_arr = cp.asarray(singular_vals) mu_arr = cp.asarray(mu) if trans_input is None: trans_input = cp.empty( - (n_rows, n_components), dtype="float32", order="F" + (n_rows, n_components), dtype="float32", order=order ) else: - trans_input = _to_f_order(trans_input) + trans_input = _match_layout(trans_input, order) cdef cydlpack.DLManagedTensor* x_dlpack = \ cydlpack.dlpack_c(x_ai) cdef cydlpack.DLManagedTensor* comp_dlpack = \ - cydlpack.dlpack_c(wrap_array(components_f)) + cydlpack.dlpack_c(wrap_array(components_arr)) cdef cydlpack.DLManagedTensor* sv_dlpack = \ cydlpack.dlpack_c(wrap_array(singular_vals_arr)) cdef cydlpack.DLManagedTensor* mu_dlpack = \ @@ -355,6 +396,9 @@ def inverse_transform(Params params, trans_input, components, """ Transform data from the PCA eigenspace back to the original space. + The layout (C- or F-contiguous) of ``trans_input`` is preserved; + ``components`` and ``output`` are aligned to that layout. 
+ Parameters ---------- params : Params @@ -368,7 +412,8 @@ def inverse_transform(Params params, trans_input, components, mu : device array-like, shape (n_features,) Column means from a prior fit. output : optional device array, shape (n_samples, n_features) - Pre-allocated output buffer (col-major, float32). + Pre-allocated output buffer (float32). Layout is matched to + ``trans_input``. {resources_docstring} Returns @@ -387,22 +432,22 @@ def inverse_transform(Params params, trans_input, components, ... params, result.trans_input, result.components, ... result.singular_vals, result.mu) """ - trans_f = _to_f_order(trans_input) - trans_ai = wrap_array(trans_f) - _check_input_array(trans_ai, [np.dtype("float32")], exp_row_major=False) + trans_arr, order = _ensure_device_contiguous(trans_input) + trans_ai = wrap_array(trans_arr) + _validate_pca_input(trans_ai, [np.dtype("float32")]) n_rows = trans_ai.shape[0] - components_f = _to_f_order(components) - comp_ai = wrap_array(components_f) + components_arr = _match_layout(components, order) + comp_ai = wrap_array(components_arr) n_cols = comp_ai.shape[1] singular_vals_arr = cp.asarray(singular_vals) mu_arr = cp.asarray(mu) if output is None: - output = cp.empty((n_rows, n_cols), dtype="float32", order="F") + output = cp.empty((n_rows, n_cols), dtype="float32", order=order) else: - output = _to_f_order(output) + output = _match_layout(output, order) cdef cydlpack.DLManagedTensor* trans_dlpack = \ cydlpack.dlpack_c(trans_ai) diff --git a/python/cuvs/cuvs/tests/test_pca.py b/python/cuvs/cuvs/tests/test_pca.py index b045a62186..e491c57876 100644 --- a/python/cuvs/cuvs/tests/test_pca.py +++ b/python/cuvs/cuvs/tests/test_pca.py @@ -7,15 +7,25 @@ from cuvs.preprocessing import pca +def _as_order(arr, order): + if order == "C": + return cp.ascontiguousarray(arr) + return cp.asfortranarray(arr) + + @pytest.mark.parametrize("n_rows", [256, 512]) @pytest.mark.parametrize("n_cols", [32, 64]) -def 
test_fit_transform_inverse_transform(n_rows, n_cols): +@pytest.mark.parametrize("order", ["C", "F"]) +def test_fit_transform_inverse_transform(n_rows, n_cols, order): """ fit_transform with all components then inverse_transform - should reconstruct the original data near-losslessly. + should reconstruct the original data near-losslessly, regardless of + whether the input is C- or F-contiguous. """ n_components = n_cols - X = cp.random.random_sample((n_rows, n_cols), dtype=cp.float32) + X = _as_order( + cp.random.random_sample((n_rows, n_cols), dtype=cp.float32), order + ) params = pca.Params(n_components=n_components) result = pca.fit_transform(params, X) @@ -23,6 +33,12 @@ def test_fit_transform_inverse_transform(n_rows, n_cols): assert result.trans_input.shape == (n_rows, n_components) assert result.components.shape == (n_components, n_cols) + expected_c = order == "C" + assert result.trans_input.flags.c_contiguous == expected_c + assert result.trans_input.flags.f_contiguous == (not expected_c) + assert result.components.flags.c_contiguous == expected_c + assert result.components.flags.f_contiguous == (not expected_c) + reconstructed = pca.inverse_transform( params, result.trans_input, @@ -31,21 +47,25 @@ def test_fit_transform_inverse_transform(n_rows, n_cols): result.mu, ) - max_err = float(cp.max(cp.abs(cp.asfortranarray(X) - reconstructed))) + max_err = float(cp.max(cp.abs(X - reconstructed))) assert max_err < 1e-3, ( - f"Reconstruction error {max_err} too large for lossless case" + f"Reconstruction error {max_err} too large for lossless case " + f"(order={order})" ) @pytest.mark.parametrize("n_rows", [256, 512]) @pytest.mark.parametrize("n_cols", [32, 64]) -def test_fit_then_transform(n_rows, n_cols): +@pytest.mark.parametrize("order", ["C", "F"]) +def test_fit_then_transform(n_rows, n_cols, order): """ fit() then transform() separately should give the same result - as fit_transform() when copy=True. 
+ as fit_transform() when copy=True, regardless of input layout. """ n_components = n_cols - X = cp.random.random_sample((n_rows, n_cols), dtype=cp.float32) + X = _as_order( + cp.random.random_sample((n_rows, n_cols), dtype=cp.float32), order + ) params = pca.Params(n_components=n_components, copy=True) fit_result = pca.fit(params, X) @@ -71,18 +91,22 @@ def test_fit_then_transform(n_rows, n_cols): fit_result.mu, ) - max_err = float(cp.max(cp.abs(cp.asfortranarray(X) - reconstructed))) + max_err = float(cp.max(cp.abs(X - reconstructed))) assert max_err < 1e-3, ( - f"Reconstruction error {max_err} too large for lossless case" + f"Reconstruction error {max_err} too large for lossless case " + f"(order={order})" ) @pytest.mark.parametrize("n_rows", [512]) @pytest.mark.parametrize("n_cols", [64]) @pytest.mark.parametrize("n_components", [8, 16]) -def test_dim_reduction(n_rows, n_cols, n_components): +@pytest.mark.parametrize("order", ["C", "F"]) +def test_dim_reduction(n_rows, n_cols, n_components, order): """With fewer components, reconstruction should have bounded error.""" - X = cp.random.random_sample((n_rows, n_cols), dtype=cp.float32) + X = _as_order( + cp.random.random_sample((n_rows, n_cols), dtype=cp.float32), order + ) params = pca.Params(n_components=n_components) result = pca.fit_transform(params, X) @@ -97,13 +121,77 @@ def test_dim_reduction(n_rows, n_cols, n_components): result.mu, ) - max_err = float(cp.max(cp.abs(cp.asfortranarray(X) - reconstructed))) + max_err = float(cp.max(cp.abs(X - reconstructed))) assert max_err > 1e-5, ( "Reconstruction error should be non-zero with fewer components" ) assert max_err < 2.0, f"Reconstruction error {max_err} should be bounded" +def test_row_major_no_copy(): + """ + When the user passes a row-major (C-contiguous) array, the data pointer + must be preserved -- i.e. the implementation must NOT silently reorder + the input into Fortran layout. 
+ """ + n_rows, n_cols = 128, 16 + X = cp.random.random_sample((n_rows, n_cols), dtype=cp.float32) + assert X.flags.c_contiguous + + original_ptr = X.data.ptr + + params = pca.Params(n_components=8) + pca.fit(params, X) + + # The user's array was contiguous and float32 -- the C ABI should accept + # it as-is, not reallocate or reorder it. + assert X.data.ptr == original_ptr + assert X.flags.c_contiguous + + +def test_layouts_agree_numerically(): + """ + Running PCA on the same data in C-order and F-order should yield the + same explained variances (within tolerance), and the same reconstruction. + Components individually may differ in sign convention but reconstruction + is invariant to sign flip of components. + """ + n_rows, n_cols, n_components = 256, 32, 16 + rng = cp.random.default_rng(42) + X = rng.standard_normal((n_rows, n_cols), dtype=cp.float32) + + X_c = cp.ascontiguousarray(X) + X_f = cp.asfortranarray(X) + + params = pca.Params(n_components=n_components) + res_c = pca.fit_transform(params, X_c) + res_f = pca.fit_transform(params, X_f) + + cp.testing.assert_allclose( + res_c.explained_var, res_f.explained_var, rtol=1e-4, atol=1e-4 + ) + cp.testing.assert_allclose( + res_c.singular_vals, res_f.singular_vals, rtol=1e-4, atol=1e-4 + ) + cp.testing.assert_allclose(res_c.mu, res_f.mu, rtol=1e-5, atol=1e-5) + + recon_c = pca.inverse_transform( + params, + res_c.trans_input, + res_c.components, + res_c.singular_vals, + res_c.mu, + ) + recon_f = pca.inverse_transform( + params, + res_f.trans_input, + res_f.components, + res_f.singular_vals, + res_f.mu, + ) + cp.testing.assert_allclose(recon_c, recon_f, rtol=1e-3, atol=1e-3) + + def test_explained_variance(): """ When all components are kept, explained_var_ratio should sum From 6408c8edfa030e794ec06d950f948e3b044e0eba Mon Sep 17 00:00:00 2001 From: aamijar Date: Tue, 26 May 2026 01:11:05 +0000 Subject: [PATCH 2/2] pin raft --- cpp/cmake/thirdparty/get_raft.cmake | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 8ecf3686be..9f948398ab 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -1,6 +1,6 @@ # ============================================================================= # cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # cmake-format: on @@ -60,8 +60,8 @@ endfunction() # To use a different RAFT locally, set the CMake variable # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${RAFT_VERSION}.00 - FORK ${RAFT_FORK} - PINNED_TAG ${RAFT_PINNED_TAG} + FORK aamijar + PINNED_TAG pca-row-major ENABLE_MNMG_DEPENDENCIES OFF ENABLE_NVTX OFF BUILD_STATIC_DEPS ${CUVS_STATIC_RAPIDS_LIBRARIES}