From d0884ca845d99caaf2fbd7fd1b5ed2678263382a Mon Sep 17 00:00:00 2001 From: Aaron Preus Date: Sun, 15 Feb 2026 18:12:18 +0100 Subject: [PATCH 1/4] fix; add missing initializer --- ...x_short_function_test_van_mises_distance_sq_derivative.cpp | 4 ++++ ...rac_approx_short_test_van_mises_distance_sq_derivative.cpp | 2 ++ .../tests/unit_tests/gm_to_dirac_short_test_derivative.cpp | 3 +++ 3 files changed, 9 insertions(+) diff --git a/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_function_test_van_mises_distance_sq_derivative.cpp b/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_function_test_van_mises_distance_sq_derivative.cpp index 6b85386..213f7e0 100644 --- a/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_function_test_van_mises_distance_sq_derivative.cpp +++ b/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_function_test_van_mises_distance_sq_derivative.cpp @@ -83,6 +83,10 @@ class gradVanMisesDistanceSqDynamicWeight; }; +gradient_van_mises_distance_sq_dynamic_weight +dirac_to_dirac_approx_short_function_test_modified_van_mises_distance_sq_derivative:: + gradVanMisesDistanceSqDynamicWeight; + static double wXcallbackWrapper(const gsl_vector* x, void* params) { DiracToDiracVariableWeightOptimizationParams* p = static_cast(params); diff --git a/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp b/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp index fe92700..6c05395 100644 --- a/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp +++ b/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp @@ -56,6 +56,8 @@ class dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative gradVanMisesDistanceSqConstWeight; }; +gradient_van_mises_distance_sq_const_weight dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative::gradVanMisesDistanceSqConstWeight; + TEST_P( dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, parameterized_test_modified_van_mises_distance_sq_derivative) { diff --git a/lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp b/lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp index c0fa815..a6e3068 100644 --- a/lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp +++ b/lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp @@ -53,6 +53,9 @@ class gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative static gradient_van_mises_distance gradVanMisesDistance; }; +gradient_van_mises_distance +gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative::gradVanMisesDistance; + TEST_P(gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, parameterized_test_modified_van_mises_distance_sq_derivative) { GmToDiracTestCaseParams p = GetParam(); From 1fadcc17193530ff803f21713f6d9696326eea6a Mon Sep 17 00:00:00 2001 From: Aaron Preus Date: Sun, 15 Feb 2026 18:12:58 +0100 Subject: [PATCH 2/4] update from deprecated Google Benchmark access --- .../tests/benchmark/benchmark_D_E_cache_performance.cpp | 2 +- .../tests/benchmark/benchmark_dirac_to_dirac_approx_short.cpp | 2 +- .../benchmark_dirac_to_dirac_approx_short_threaded.cpp | 2 +- .../tests/benchmark/benchmark_gm_to_dirac_short.cpp | 4 ++-- .../benchmark/benchmark_squared_euclidean_distance_utils.cpp | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/dirac_to_dirac/tests/benchmark/benchmark_D_E_cache_performance.cpp b/lib/dirac_to_dirac/tests/benchmark/benchmark_D_E_cache_performance.cpp index 7d4b521..5fc79ff 100644 --- a/lib/dirac_to_dirac/tests/benchmark/benchmark_D_E_cache_performance.cpp +++ b/lib/dirac_to_dirac/tests/benchmark/benchmark_D_E_cache_performance.cpp @@ -85,7 +85,7 @@ static const long long maxN = 25; static const long long minN = 1; static const long long stepN = 1; -static void D_E_CustomArguments(benchmark::internal::Benchmark* b) { +static void D_E_CustomArguments(benchmark::Benchmark* b) { for (int L = minL; L <= maxL; L += stepL) for (int N = minN; N <= maxN; N += stepN) b->Args({L, N}); } diff --git a/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short.cpp b/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short.cpp index 9e04433..1170ffa 100644 --- a/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short.cpp +++ b/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short.cpp @@ -133,7 +133,7 @@ static const float minRatioMN = 20.0f; static const float maxRatioMN = 100.0f; static const float stepRatioMN = 5.0f; -static void D2D_CustomArguments(benchmark::internal::Benchmark* b) { +static void D2D_CustomArguments(benchmark::Benchmark* b) { for (float MN = minRatioMN; MN <= maxRatioMN; MN += stepRatioMN) for (float M = minM; M <= maxM; M += stepM) for (float N = minN; N <= maxN; N += stepN) diff --git a/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short_threaded.cpp b/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short_threaded.cpp index f6f0317..5cd0c5e 100644 --- a/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short_threaded.cpp +++ b/lib/dirac_to_dirac/tests/benchmark/benchmark_dirac_to_dirac_approx_short_threaded.cpp @@ -137,7 +137,7 @@ static const float minRatioMN = 100.0f; static const float maxRatioMN = 100.0f; static const float stepRatioMN = 10.0f; -static void D2D_CustomArguments(benchmark::internal::Benchmark* b) { +static void D2D_CustomArguments(benchmark::Benchmark* b) { for (float MN = minRatioMN; MN <= maxRatioMN; MN += stepRatioMN) for (float M = minM; M <= maxM; M += stepM) for (float N = minN; N <= maxN; N += stepN) diff --git a/lib/gm_to_dirac/tests/benchmark/benchmark_gm_to_dirac_short.cpp b/lib/gm_to_dirac/tests/benchmark/benchmark_gm_to_dirac_short.cpp index 9751775..62ffde5 100644 --- a/lib/gm_to_dirac/tests/benchmark/benchmark_gm_to_dirac_short.cpp +++ b/lib/gm_to_dirac/tests/benchmark/benchmark_gm_to_dirac_short.cpp @@ -119,13 +119,13 @@ static const long long minAcc = 60; static const long long maxAcc = 130; static void gm_to_dirac_short_CustomArguments_LN( - benchmark::internal::Benchmark* b) { + benchmark::Benchmark* b) { for (int L = minL; L <= maxL; L += stepL) for (int N = minN; N <= maxN; N += stepN) b->Args({L, N}); } static void gm_to_dirac_short_CustomArguments_N( - benchmark::internal::Benchmark* b) { + benchmark::Benchmark* b) { for (int N = minN; N <= maxN; N += stepN) b->Args({N}); } diff --git a/lib/math_utils/tests/benchmark/benchmark_squared_euclidean_distance_utils.cpp b/lib/math_utils/tests/benchmark/benchmark_squared_euclidean_distance_utils.cpp index f4ab354..08ebd3d 100644 --- a/lib/math_utils/tests/benchmark/benchmark_squared_euclidean_distance_utils.cpp +++ b/lib/math_utils/tests/benchmark/benchmark_squared_euclidean_distance_utils.cpp @@ -72,13 +72,13 @@ static const long long minN = 1; static const long long stepN = 1; static void SquaredEuclideanDistanceUtilsBenchmark_CustomArgumentsLM( - benchmark::internal::Benchmark* b) { + benchmark::Benchmark* b) { for (int L = minL; L <= maxL; L += stepL) for (int M = minM; M <= maxM; M += stepM) for (int N = minN; N <= maxN; N += stepN) b->Args({L, M, N}); } static void SquaredEuclideanDistanceUtilsBenchmark_CustomArgumentsLL( - benchmark::internal::Benchmark* b) { + benchmark::Benchmark* b) { for (int L = minL; L <= maxL; L += stepL) for (int N = minN; N <= maxN; N += stepN) b->Args({L, 0, N}); } From 61f6a26b994d1686c1d36f0a651592199fa98395 Mon Sep 17 00:00:00 2001 From: Aaron Preus Date: Mon, 16 Feb 2026 10:18:21 +0100 Subject: [PATCH 3/4] expand wrapper by default options implementation --- lib/CMakeLists.txt | 2 ++ lib/options/approximate_options.h | 2 +- lib/options/wrappers/approximate_options_c.cpp | 1 + lib/options/wrappers/approximate_options_c.h | 16 ++++++++++++++++ 4 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 lib/options/wrappers/approximate_options_c.cpp create mode 100644 lib/options/wrappers/approximate_options_c.h diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 7c5c877..15cce1e 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -40,6 +40,7 @@ file(GLOB GM_TO_DIRAC_C_WRAPPER_SOURCES ${GM_TO_DIRAC_C_WRAPPER_DIR}/*.cpp) set(COMMON_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/debug/capture_time.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gsl_minimizer/gsl_minimizer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/options/wrappers/approximate_options_c.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/gsl_multivariative_gradient ) # common includes @@ -53,6 +54,7 @@ set(COMMON_INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}/math_utils ${CMAKE_CURRENT_SOURCE_DIR}/cache_manager ${CMAKE_CURRENT_SOURCE_DIR}/options + ${CMAKE_CURRENT_SOURCE_DIR}/options/wrappers ) # static library target diff --git a/lib/options/approximate_options.h b/lib/options/approximate_options.h index 406f059..84c61b4 100644 --- a/lib/options/approximate_options.h +++ b/lib/options/approximate_options.h @@ -14,4 +14,4 @@ struct ApproximateOptions { bool verbose = false; // True if verbose output is needed }; -#endif \ No newline at end of file +#endif // APPROXIMATE_OPTIONS_H \ No newline at end of file diff --git a/lib/options/wrappers/approximate_options_c.cpp b/lib/options/wrappers/approximate_options_c.cpp new file mode 100644 index 0000000..d8ab09c --- /dev/null +++ b/lib/options/wrappers/approximate_options_c.cpp @@ -0,0 +1 @@ +#include "approximate_options_c.h" \ No newline at end of file diff --git a/lib/options/wrappers/approximate_options_c.h b/lib/options/wrappers/approximate_options_c.h new file mode 100644 index 0000000..dfe67fe --- /dev/null +++ b/lib/options/wrappers/approximate_options_c.h @@ -0,0 +1,16 @@ +#ifndef APPROXIMATE_OPTIONS_C_H +#define APPROXIMATE_OPTIONS_C_H + +#include "approximate_options.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ApproximateOptions default_approximate_options() { return ApproximateOptions{}; } + +#ifdef __cplusplus +} +#endif + +#endif // APPROXIMATE_OPTIONS_C_H \ No newline at end of file From daeea976cc0f3bb10c45e623e0ba2632553e23ce Mon Sep 17 00:00:00 2001 From: Aaron Preus Date: Thu, 19 Feb 2026 20:47:21 +0100 Subject: [PATCH 4/4] expand gsl_utils by weight-helper, view-helper; expand interface to support calling distance and gradient implementations expand tests --- .../dirac_to_dirac_approx_function_i.h | 108 ++++++- lib/dirac_to_dirac/dirac_to_dirac_approx_i.h | 109 ++++++- .../dirac_to_dirac_approx_short.cpp | 260 ++++++++++------ .../dirac_to_dirac_approx_short.h | 162 ++++++++-- .../dirac_to_dirac_approx_short_function.cpp | 112 ++++++- .../dirac_to_dirac_approx_short_function.h | 83 ++++++ .../dirac_to_dirac_approx_short_thread.cpp | 246 +++++++++++----- .../dirac_to_dirac_approx_short_thread.h | 147 ++++++++-- ..._test_van_mises_distance_sq_derivative.cpp | 49 +++- .../wrappers/dirac_to_dirac_approx_short_c.h | 32 ++ .../dirac_to_dirac_approx_short_function_c.h | 24 ++ .../dirac_to_dirac_approx_short_thread_c.h | 33 +++ lib/gm_to_dirac/gm_to_dirac_approx_i.h | 107 ++++++- ...ac_approx_standard_normal_distribution_i.h | 106 ++++++- lib/gm_to_dirac/gm_to_dirac_short.cpp | 143 +++++++-- lib/gm_to_dirac/gm_to_dirac_short.h | 128 ++++++-- ..._dirac_short_standard_normal_deviation.cpp | 106 +++++-- ...to_dirac_short_standard_normal_deviation.h | 97 ++++-- .../gm_to_dirac_short_test_derivative.cpp | 45 ++- .../unit_tests/gm_to_dirac_test_case_params.h | 0 .../wrappers/gm_to_dirac_short_c.h | 32 ++ ..._dirac_short_standard_normal_deviation_c.h | 36 +++ lib/gsl_types/gsl_vector_matrix_types.h | 16 + lib/gsl_utils/gsl_utils_view_helper.h | 277 ++++++++++++++++++ lib/gsl_utils/gsl_utils_weight_helper.h | 53 ++++ 25 files changed, 2180 insertions(+), 331 deletions(-) rename lib/{dirac_to_dirac => gm_to_dirac}/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp (69%) rename lib/{dirac_to_dirac => gm_to_dirac}/tests/unit_tests/gm_to_dirac_test_case_params.h (100%) create mode 100644 lib/gsl_utils/gsl_utils_view_helper.h create mode 100644 lib/gsl_utils/gsl_utils_weight_helper.h diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_function_i.h b/lib/dirac_to_dirac/dirac_to_dirac_approx_function_i.h index 19d9266..9df3dd2 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_function_i.h +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_function_i.h @@ -15,7 +15,7 @@ * @brief interface for the gausian mixture to dirac approximation with a custom * weight function * - * @tparam T type of the vector (float, double, long double) + * @tparam T type of the vector (float, double) */ template class dirac_to_dirac_approx_function_i { @@ -48,6 +48,43 @@ class dirac_to_dirac_approx_function_i { GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on x and y + * + * @param distance pointer to distance value to be calculated + * @param y input data points + * @param M number of elements in y + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wXcallback callback for the weight function + * @param wXDcallback callback for the gradient of the weight function + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, const T* y, size_t M, + size_t L, size_t N, size_t bMax, + T* x, wXf wXcallback, + wXd wXDcallback) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param gradient pointer to gradient to be calculated + * @param y input data points + * @param M number of elements in y + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wXcallback callback for the weight function + * @param wXDcallback callback for the gradient of the weight function + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + T* gradient, const T* y, size_t M, size_t L, size_t N, size_t bMax, T* x, + wXf wXcallback, wXd wXDcallback) = 0; + /** * @brief reduce the data points using gsl vectors * @@ -67,6 +104,42 @@ class dirac_to_dirac_approx_function_i { wXd wXDcallback, GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on x and y + * + * @param distance pointer to distance value to be calculated + * @param y input data points + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wXcallback callback for the weight function + * @param wXDcallback callback for the gradient of the weight function + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, + const GSLVectorType* y, size_t L, + size_t N, size_t bMax, + GSLVectorType* x, wXf wXcallback, + wXd wXDcallback) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param gradient pointer to gradient to be calculated + * @param y input data points + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wXcallback callback for the weight function + * @param wXDcallback callback for the gradient of the weight function + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + GSLMatrixType* gradient, const GSLVectorType* y, size_t L, size_t N, + size_t bMax, GSLVectorType* x, wXf wXcallback, wXd wXDcallback) = 0; + /** * @brief reduce the data points using gsl matricies where possible * @@ -84,6 +157,39 @@ class dirac_to_dirac_approx_function_i { GSLMatrixType* x, wXf wXcallback, wXd wXDcallback, GslminimizerResult* result, const ApproximateOptions& options) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param distance pointer to distance value to be calculated + * @param y input data points + * @param L number of elements in x + * @param bMax bMax + * @param x input data points + * @param wXcallback callback for the weight function + * @param wXDcallback callback for the gradient of the weight function + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, GSLMatrixType* y, + size_t L, size_t bMax, + GSLMatrixType* x, wXf wXcallback, + wXd wXDcallback) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param gradient pointer to gradient to be calculated + * @param y input data points + * @param L number of elements in x + * @param bMax bMax + * @param x input data points + * @param wXcallback callback for the weight function + * @param wXDcallback callback for the gradient of the weight function + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + GSLMatrixType* gradient, GSLMatrixType* y, size_t L, size_t bMax, + GSLMatrixType* x, wXf wXcallback, wXd wXDcallback) = 0; }; #endif // DIRAC_TO_DIRAC_APPROX_FUNCTION_I_H diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_i.h b/lib/dirac_to_dirac/dirac_to_dirac_approx_i.h index 9bb70b3..86b7490 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_i.h +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_i.h @@ -14,7 +14,7 @@ /** * @brief interface for the gausian mixture to dirac approximation * - * @tparam T type of the vector (float, double, long double) + * @tparam T type of the vector (float, double) */ template class dirac_to_dirac_approx_i { @@ -22,6 +22,7 @@ class dirac_to_dirac_approx_i { using GSLVectorType = typename GSLTemplateTypeAlias::VectorType; using GSLVectorViewType = typename GSLTemplateTypeAlias::VectorViewType; using GSLMatrixType = typename GSLTemplateTypeAlias::MatrixType; + using GSLMatrixViewType = typename GSLTemplateTypeAlias::MatrixViewType; virtual ~dirac_to_dirac_approx_i() = default; @@ -45,6 +46,43 @@ class dirac_to_dirac_approx_i { GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on x and y + * + * @param distance pointer to distance value to be calculated + * @param y input data points + * @param M number of elements in y + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wX weights for the x data points + * @param wY weights for the y data points + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, const T* y, size_t M, + size_t L, size_t N, size_t bMax, + T* x, const T* wX, + const T* wY) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param gradient pointer to gradient to be calculated + * @param y input data points + * @param M number of elements in y + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wX weights for the x data points + * @param wY weights for the y data points + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + T* gradient, const T* y, size_t M, size_t L, size_t N, size_t bMax, T* x, + const T* wX, const T* wY) = 0; + /** * @brief reduce the data points using gsl vectors * @@ -65,6 +103,41 @@ class dirac_to_dirac_approx_i { GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on x and y + * + * @param distance pointer to distance value to be calculated + * @param y input data points + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wX weights for the x data points + * @param wY weights for the y data points + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq( + T* distance, const GSLVectorType* y, size_t L, size_t N, size_t bMax, + GSLVectorType* x, const GSLVectorType* wX, const GSLVectorType* wY) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param gradient pointer to gradient to be calculated + * @param y input data points + * @param L number of elements in x + * @param N dimension of the data + * @param bMax bMax + * @param x input data points + * @param wX weights for the x data points + * @param wY weights for the y data points + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + GSLMatrixType* gradient, const GSLVectorType* y, size_t L, size_t N, + size_t bMax, GSLVectorType* x, const GSLVectorType* wX, + const GSLVectorType* wY) = 0; + /** * @brief reduce the data points using gsl matricies where possible * @@ -82,6 +155,40 @@ class dirac_to_dirac_approx_i { GSLMatrixType* x, const GSLVectorType* wX, const GSLVectorType* wY, GslminimizerResult* result, const ApproximateOptions& options) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param distance pointer to distance value to be calculated + * @param y input data points + * @param L number of elements in x + * @param bMax bMax + * @param x input data points + * @param wX weights for the x data points + * @param wY weights for the y data points + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, GSLMatrixType* y, + size_t L, size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX, + const GSLVectorType* wY) = 0; + + /** + * @brief calculate modified van mises distance based on x and y + * + * @param gradient pointer to gradient to be calculated + * @param y input data points + * @param L number of elements in x + * @param bMax bMax + * @param x input data points + * @param wX weights for the x data points + * @param wY weights for the y data points + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + GSLMatrixType* gradient, GSLMatrixType* y, size_t L, size_t bMax, + GSLMatrixType* x, const GSLVectorType* wX, const GSLVectorType* wY) = 0; }; #endif // DIRAC_TO_DIRAC_APPROX_I_H \ No newline at end of file diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_short.cpp b/lib/dirac_to_dirac/dirac_to_dirac_approx_short.cpp index 4fb7eb3..1b8063a 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_short.cpp +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_short.cpp @@ -13,10 +13,11 @@ #include "capture_time.h" #include "dirac_to_dirac_optimization_params.h" #include "gsl_minimizer.h" +#include "gsl_utils_view_helper.h" +#include "gsl_utils_weight_helper.h" #include "math_util_defs.h" #include "squared_euclidean_distance_utils.h" -#define eps 0.000001 #define diagamma_1 -0.5772156649015328606065120900824 template @@ -26,25 +27,45 @@ bool dirac_to_dirac_approx_short::approximate( const ApproximateOptions& options) { assert(x != nullptr); assert(y != nullptr); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::vector_view_from_array(x, L * N); - GSLVectorViewType yFlat = - GSLTemplateTypeAlias::vector_view_from_array(y, M * N); - - GSLVectorType* wXVector = nullptr; - GSLVectorViewType wXVectorView; - if (wX) { - wXVectorView = GSLTemplateTypeAlias::vector_view_from_array(wX, L); - wXVector = &(wXVectorView.vector); - } - GSLVectorType* wYVector = nullptr; - GSLVectorViewType wYVectorView; - if (wY) { - wYVectorView = GSLTemplateTypeAlias::vector_view_from_array(wY, M); - wYVector = &(wYVectorView.vector); - } - return approximate(&(yFlat.vector), L, N, bMax, &(xFlat.vector), wXVector, - wYVector, result, options); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + return approximate(vectorViewY, L, N, bMax, vectorViewX, vectorViewWX, + vectorViewWY, result, options); +} + +template +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq( + T* distance, const T* y, size_t M, size_t L, size_t N, size_t bMax, T* x, + const T* wX, const T* wY) { + assert(x != nullptr); + assert(y != nullptr); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + modified_van_mises_distance_sq(distance, vectorViewY, L, N, bMax, vectorViewX, + vectorViewWX, vectorViewWY); +} + +template +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq_derivative( + T* gradient, const T* y, size_t M, size_t L, size_t N, size_t bMax, T* x, + const T* wX, const T* wY) { + assert(x != nullptr); + assert(y != nullptr); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + GSLMatrixView matrixViewGradient(gradient, L, N); + modified_van_mises_distance_sq_derivative(matrixViewGradient, vectorViewY, L, + N, bMax, vectorViewX, vectorViewWX, + vectorViewWY); } template @@ -56,12 +77,38 @@ bool dirac_to_dirac_approx_short::approximate( assert(x->size1 == L); size_t N = y->size2; - GSLVectorViewType yFlat = - GSLTemplateTypeAlias::flatten_matrix_to_vector(y); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::flatten_matrix_to_vector(x); - return approximate(&(yFlat.vector), L, N, bMax, &(xFlat.vector), wX, wY, - result, options); + GSLVectorView vectorViewX(x); + GSLVectorView vectorViewY(y); + return approximate(vectorViewY, L, N, bMax, vectorViewX, wX, wY, result, + options); +} + +template +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq( + T* distance, GSLMatrixType* y, size_t L, size_t bMax, GSLMatrixType* x, + const GSLVectorType* wX, const GSLVectorType* wY) { + assert(x->size2 == y->size2); + assert(x->size1 == L); + + size_t N = y->size2; + GSLVectorView vectorViewX(x); + GSLVectorView vectorViewY(y); + modified_van_mises_distance_sq(distance, vectorViewY, L, N, bMax, vectorViewX, + wX, wY); +} + +template +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq_derivative( + GSLMatrixType* gradient, GSLMatrixType* y, size_t L, size_t bMax, + GSLMatrixType* x, const GSLVectorType* wX, const GSLVectorType* wY) { + assert(x->size2 == y->size2); + assert(x->size1 == L); + + size_t N = y->size2; + GSLVectorView vectorViewX(x); + GSLVectorView vectorViewY(y); + modified_van_mises_distance_sq_derivative(gradient, vectorViewY, L, N, bMax, + vectorViewX, wX, wY); } template @@ -184,11 +231,10 @@ inline void dirac_to_dirac_approx_short::combined_distance_metric( } template -inline void dirac_to_dirac_approx_short::correctMean(const gsl_vector* meanY, - gsl_vector* x, - const gsl_vector* wX, - size_t L, size_t N) { - std::vector mean(N, 0.0); +inline void dirac_to_dirac_approx_short::correctMean( + const GSLVectorType* meanY, GSLVectorType* x, const GSLVectorType* wX, + size_t L, size_t N) { + std::vector mean(N, T(0)); for (size_t i = 0; i < L; i++) { for (size_t k = 0; k < N; k++) { mean[k] += wX->data[i] * x->data[i * N + k]; @@ -209,47 +255,22 @@ bool dirac_to_dirac_approx_short::approximate( const gsl_vector_float* y, size_t L, size_t N, size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX, const gsl_vector_float* wY, GslminimizerResult* result, const ApproximateOptions& options) { - gsl_vector* yDouble = gsl_vector_alloc(y->size); - gsl_vector* xDouble = gsl_vector_alloc(x->size); - gsl_vector* wXDouble = nullptr; - gsl_vector* wYDouble = nullptr; - - for (size_t i = 0; i < y->size; ++i) { - yDouble->data[i] = static_cast(y->data[i]); - } - if (wX) { - wXDouble = gsl_vector_alloc(L); - for (size_t i = 0; i < L; ++i) { - wXDouble->data[i] = static_cast(wX->data[i]); - } - } - if (wY) { - const size_t M = y->size / N; - wYDouble = gsl_vector_alloc(M); - for (size_t i = 0; i < M; ++i) { - wYDouble->data[i] = static_cast(wY->data[i]); - } - } - - if (options.initialX) { - for (size_t i = 0; i < x->size; ++i) { - xDouble->data[i] = static_cast(x->data[i]); - } - } + const size_t M = y->size / N; + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); dirac_to_dirac_approx_short doubleApprox; - bool success = doubleApprox.approximate(yDouble, L, N, bMax, xDouble, - wXDouble, wYDouble, result, options); + bool success = + doubleApprox.approximate(vectorViewY, L, N, bMax, vectorViewX, + vectorViewWX, vectorViewWY, result, options); - for (size_t i = 0; i < x->size; ++i) { - x->data[i] = static_cast(xDouble->data[i]); + const size_t xSize = x->size; + for (size_t i = 0; i < xSize; ++i) { + x->data[i] = static_cast(vectorViewX.get()->data[i]); } - gsl_vector_free(yDouble); - gsl_vector_free(xDouble); - if (wXDouble) gsl_vector_free(wXDouble); - if (wYDouble) gsl_vector_free(wYDouble); - return success; } @@ -271,27 +292,12 @@ bool dirac_to_dirac_approx_short::approximate( } } - const gsl_vector* localWX; - const bool freeWx = wX == nullptr; - if (freeWx) { - gsl_vector* tmpWx = gsl_vector_alloc(L); - gsl_vector_set_all(tmpWx, 1.00 / static_cast(L)); - localWX = tmpWx; - } else { - localWX = wX; - } - const gsl_vector* localWY; - const bool freeWy = wY == nullptr; - if (freeWy) { - gsl_vector* tmpWy = gsl_vector_alloc(M); - gsl_vector_set_all(tmpWy, 1.00 / static_cast(M)); - localWY = tmpWy; - } else { - localWY = wY; - } + GSLWeightHelper wXHelper(wX, L); + GSLWeightHelper wYHelper(wY, M); DiracToDiracConstWeightOptimizationParams params = - DiracToDiracConstWeightOptimizationParams(y, N, M, L, bMax, c_b(bMax)); + DiracToDiracConstWeightOptimizationParams(wXHelper, wYHelper, y, N, M, L, + bMax, c_b(bMax)); gsl_minimizer gslMinimizer( options.maxIterations, options.xtolAbs, options.xtolRel, options.ftolAbs, @@ -300,11 +306,89 @@ bool dirac_to_dirac_approx_short::approximate( const int status = gslMinimizer.minimize(x, result, options.verbose); correctMean(params.meanY, x, params.wX, L, N); - if (freeWx) gsl_vector_free(const_cast(localWX)); - if (freeWy) gsl_vector_free(const_cast(localWY)); return status == GSL_SUCCESS; } +template <> +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq( + float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax, + gsl_vector_float* x, const gsl_vector_float* wX, + const gsl_vector_float* wY) { + double distanceDouble = 0.00; + const size_t M = y->size / N; + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + dirac_to_dirac_approx_short doubleApprox; + doubleApprox.modified_van_mises_distance_sq(&distanceDouble, vectorViewY, L, + N, bMax, vectorViewX, + vectorViewWX, vectorViewWY); + *distance = static_cast(distanceDouble); +} + +template <> +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq( + double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax, + gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY) { + const size_t M = y->size / N; + GSLWeightHelper wXHelper(wX, L); + GSLWeightHelper wYHelper(wY, M); + DiracToDiracConstWeightOptimizationParams optiParams = + DiracToDiracConstWeightOptimizationParams(wXHelper, wYHelper, y, N, M, L, bMax, + c_b(bMax)); + *distance = modified_van_mises_distance_sq(x, &optiParams); +} + +template <> +void dirac_to_dirac_approx_short:: + modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient, + const gsl_vector_float* y, + size_t L, size_t N, size_t bMax, + gsl_vector_float* x, + const gsl_vector_float* wX, + const gsl_vector_float* wY) { + gsl_matrix* gradientDouble = + gsl_matrix_alloc(gradient->size1, gradient->size2); + + const size_t M = y->size / N; + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + dirac_to_dirac_approx_short doubleApprox; + doubleApprox.modified_van_mises_distance_sq_derivative( + gradientDouble, vectorViewY, L, N, bMax, vectorViewX, vectorViewWX, + vectorViewWY); + + const size_t gradientSize = gradient->size1 * gradient->size2; + for (size_t i = 0; i < gradientSize; i++) + gradient->data[i] = static_cast(gradientDouble->data[i]); + + gsl_matrix_free(gradientDouble); +} + +template <> +void dirac_to_dirac_approx_short< + double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient, + const gsl_vector* y, + size_t L, size_t N, + size_t bMax, + gsl_vector* x, + const gsl_vector* wX, + const gsl_vector* wY) { + const size_t M = y->size / N; + GSLWeightHelper wXHelper(wX, L); + GSLWeightHelper wYHelper(wY, M); + DiracToDiracConstWeightOptimizationParams optiParams = + DiracToDiracConstWeightOptimizationParams(wXHelper, wYHelper, y, N, M, L, bMax, + c_b(bMax)); + gsl_vector_view flatGradient = + gsl_vector_view_array(gradient->data, gradient->size1 * gradient->size2); + modified_van_mises_distance_sq_derivative(x, &optiParams, + &(flatGradient.vector)); +} + template class dirac_to_dirac_approx_short; template class dirac_to_dirac_approx_short; diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_short.h b/lib/dirac_to_dirac/dirac_to_dirac_approx_short.h index 7a89e37..22c6230 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_short.h +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_short.h @@ -12,41 +12,109 @@ class dirac_to_dirac_approx_short : public dirac_to_dirac_approx_i { using GSLVectorViewType = typename dirac_to_dirac_approx_i::GSLVectorViewType; using GSLMatrixType = typename dirac_to_dirac_approx_i::GSLMatrixType; + using GSLMatrixViewType = + typename dirac_to_dirac_approx_i::GSLMatrixViewType; // clang-format off bool approximate(const T* y, - size_t M, - size_t L, - size_t N, - size_t bMax, - T* x, - const T* wX = nullptr, - const T* wY = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t M, + size_t L, + size_t N, + size_t bMax, + T* x, + const T* wX = nullptr, + const T* wY = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + const T *y, + size_t M, + size_t L, + size_t N, + size_t bMax, + T *x, + const T *wX = nullptr, + const T *wY = nullptr) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(T* gradient, + const T *y, + size_t M, + size_t L, + size_t N, + size_t bMax, + T *x, + const T *wX = nullptr, + const T *wY = nullptr) override; // clang-format on // clang-format off bool approximate(const GSLVectorType* y, - size_t L, - size_t N, - size_t bMax, - GSLVectorType* x, - const GSLVectorType* wX = nullptr, - const GSLVectorType* wY = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t L, + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX = nullptr, + const GSLVectorType* wY = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + const GSLVectorType *y, + size_t L, + size_t N, + size_t bMax, + GSLVectorType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + const GSLVectorType *y, + size_t L, + size_t N, + size_t bMax, + GSLVectorType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; // clang-format on // clang-format off bool approximate(GSLMatrixType* y, - size_t L, - size_t bMax, - GSLMatrixType* x, - const GSLVectorType* wX = nullptr, - const GSLVectorType* wY = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t L, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX = nullptr, + const GSLVectorType* wY = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + GSLMatrixType *y, + size_t L, + size_t bMax, + GSLMatrixType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + GSLMatrixType *y, + size_t L, + size_t bMax, + GSLMatrixType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; // clang-format on private: @@ -59,12 +127,18 @@ class dirac_to_dirac_approx_short : public dirac_to_dirac_approx_i { static void combined_distance_metric(const gsl_vector* x, void* params, double* f, gsl_vector* grad); - static inline void correctMean(const gsl_vector* meanY, gsl_vector* x, - const gsl_vector* wX, size_t L, size_t N); + static inline void correctMean(const GSLVectorType* meanY, GSLVectorType* x, + const GSLVectorType* wX, size_t L, size_t N); FRIEND_TEST( dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, parameterized_test_modified_van_mises_distance_sq_derivative); + FRIEND_TEST( + dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_distance); + FRIEND_TEST( + dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_gradient); FRIEND_TEST(dirac_to_dirac_approx_short_test_combined, parameterized_test_combined); friend class benchmark_dirac_to_dirac_approx_short; @@ -73,14 +147,44 @@ class dirac_to_dirac_approx_short : public dirac_to_dirac_approx_i { template <> bool dirac_to_dirac_approx_short::approximate( const gsl_vector_float* y, size_t L, size_t N, size_t bMax, - gsl_vector_float* x, const GSLVectorType* wX, const GSLVectorType* wY, + gsl_vector_float* x, const gsl_vector_float* wX, const gsl_vector_float* wY, GslminimizerResult* result, const ApproximateOptions& options); template <> bool dirac_to_dirac_approx_short::approximate( const gsl_vector* y, size_t L, size_t N, size_t bMax, gsl_vector* x, - const GSLVectorType* wX, const GSLVectorType* wY, - GslminimizerResult* result, const ApproximateOptions& options); + const gsl_vector* wX, const gsl_vector* wY, GslminimizerResult* result, + const ApproximateOptions& options); + +template <> +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq( + float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax, + gsl_vector_float* x, const gsl_vector_float* wX, + const gsl_vector_float* wY); + +template <> +void dirac_to_dirac_approx_short::modified_van_mises_distance_sq( + double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax, + gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY); + +template <> +void dirac_to_dirac_approx_short:: + modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient, + const gsl_vector_float* y, + size_t L, size_t N, size_t bMax, + gsl_vector_float* x, + const gsl_vector_float* wX, + const gsl_vector_float* wY); + +template <> +void dirac_to_dirac_approx_short< + double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient, + const gsl_vector* y, + size_t L, size_t N, + size_t bMax, + gsl_vector* x, + const gsl_vector* wX, + const gsl_vector* wY); extern template class dirac_to_dirac_approx_short; extern template class dirac_to_dirac_approx_short; diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.cpp b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.cpp index 1419531..6fac4a4 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.cpp +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.cpp @@ -13,6 +13,7 @@ #include "capture_time.h" #include "dirac_to_dirac_optimization_params.h" #include "gsl_minimizer.h" +#include "gsl_utils_view_helper.h" #include "math_util_defs.h" template @@ -22,15 +23,44 @@ bool dirac_to_dirac_approx_short_function::approximate( const ApproximateOptions& options) { assert(x != nullptr); assert(y != nullptr); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::vector_view_from_array(x, L * N); - GSLVectorViewType yFlat = - GSLTemplateTypeAlias::vector_view_from_array(y, M * N); - return approximate(&(yFlat.vector), L, N, bMax, &(xFlat.vector), wXcallback, + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + return approximate(vectorViewY, L, N, bMax, vectorViewX, wXcallback, wXDcallback, result, options); } +template +void dirac_to_dirac_approx_short_function::modified_van_mises_distance_sq( + T* distance, const T* y, size_t M, size_t L, size_t N, size_t bMax, T* x, + wXf wXcallback, wXd wXDcallback) { + assert(x != nullptr); + assert(y != nullptr); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + modified_van_mises_distance_sq(distance, vectorViewY, L, N, bMax, vectorViewX, + wXcallback, wXDcallback); +} + +template +void dirac_to_dirac_approx_short_function< + T>::modified_van_mises_distance_sq_derivative(T* gradient, const T* y, + size_t M, size_t L, size_t N, + size_t bMax, T* x, + wXf wXcallback, + wXd wXDcallback) { + assert(x != nullptr); + assert(y != nullptr); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLMatrixView matrixViewGradient(gradient, L, N); + modified_van_mises_distance_sq_derivative(matrixViewGradient, vectorViewY, L, + N, bMax, vectorViewX, wXcallback, + wXDcallback); +} + template bool dirac_to_dirac_approx_short_function::approximate( GSLMatrixType* y, size_t L, size_t bMax, GSLMatrixType* x, wXf wXcallback, @@ -40,14 +70,43 @@ bool dirac_to_dirac_approx_short_function::approximate( assert(x->size1 == L); size_t N = y->size2; - GSLVectorViewType yFlat = - GSLTemplateTypeAlias::flatten_matrix_to_vector(y); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::flatten_matrix_to_vector(x); - return approximate(&(yFlat.vector), L, N, bMax, &(xFlat.vector), wXcallback, + GSLVectorView vectorViewY(y); + GSLVectorView vectorViewX(x); + return approximate(vectorViewY, L, N, bMax, vectorViewX, wXcallback, wXDcallback, result, options); } +template +void dirac_to_dirac_approx_short_function::modified_van_mises_distance_sq( + T* distance, GSLMatrixType* y, size_t L, size_t bMax, GSLMatrixType* x, + wXf wXcallback, wXd wXDcallback) { + assert(x->size2 == y->size2); + assert(x->size1 == L); + + size_t N = y->size2; + GSLVectorView vectorViewY(y); + GSLVectorView vectorViewX(x); + modified_van_mises_distance_sq(distance, vectorViewY, L, N, bMax, vectorViewX, + wXcallback, wXDcallback); +} + +template +void dirac_to_dirac_approx_short_function< + T>::modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + GSLMatrixType* y, size_t L, + size_t bMax, GSLMatrixType* x, + wXf wXcallback, + wXd wXDcallback) { + assert(x->size2 == y->size2); + assert(x->size1 == L); + + size_t N = y->size2; + GSLVectorView vectorViewY(y); + GSLVectorView vectorViewX(x); + modified_van_mises_distance_sq_derivative( + gradient, vectorViewY, L, N, bMax, vectorViewX, wXcallback, wXDcallback); +} + template inline double dirac_to_dirac_approx_short_function::c_b(size_t bMax) { return 100.00; @@ -258,4 +317,37 @@ bool dirac_to_dirac_approx_short_function::approximate( return status == GSL_SUCCESS; } +template <> +void dirac_to_dirac_approx_short_function< + double>::modified_van_mises_distance_sq(double* distance, + const gsl_vector* y, size_t L, + size_t N, size_t bMax, + gsl_vector* x, wXf wXcallback, + wXd wXDcallback) { + const size_t M = y->size / N; + DiracToDiracVariableWeightOptimizationParams optiParams = + DiracToDiracVariableWeightOptimizationParams(wXcallback, wXDcallback, y, + N, M, L, bMax, c_b(bMax)); + *distance = modified_van_mises_distance_sq(x, &optiParams); +} + +template <> +void dirac_to_dirac_approx_short_function< + double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient, + const gsl_vector* y, + size_t L, size_t N, + size_t bMax, + gsl_vector* x, + wXf wXcallback, + wXd wXDcallback) { + const size_t M = y->size / N; + DiracToDiracVariableWeightOptimizationParams optiParams = + DiracToDiracVariableWeightOptimizationParams(wXcallback, wXDcallback, y, + N, M, L, bMax, c_b(bMax)); + gsl_vector_view flatGradient = + gsl_vector_view_array(gradient->data, gradient->size1 * gradient->size2); + modified_van_mises_distance_sq_derivative(x, &optiParams, + &(flatGradient.vector)); +} + template class dirac_to_dirac_approx_short_function; \ No newline at end of file diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.h b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.h index 0f6cb32..9689917 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.h +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_function.h @@ -31,6 +31,29 @@ class dirac_to_dirac_approx_short_function const ApproximateOptions& options = ApproximateOptions{}) override; // clang-format on + // clang-format off + void modified_van_mises_distance_sq(T* distance, + const T *y, + size_t M, + size_t L, + size_t N, + size_t bMax, + T *x, + wXf wXcallback, + wXd wXDcallback) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(T* gradient, + const T *y, + size_t M, + size_t L, + size_t N, + size_t bMax, + T *x, + wXf wXcallback, + wXd wXDcallback) override; + // clang-format off bool approximate(const GSLVectorType* y, size_t L, @@ -43,6 +66,28 @@ class dirac_to_dirac_approx_short_function const ApproximateOptions& options = ApproximateOptions{}) override; // clang-format on + // clang-format off + void modified_van_mises_distance_sq(T* distance, + const GSLVectorType *y, + size_t L, + size_t N, + size_t bMax, + GSLVectorType *x, + wXf wXcallback, + wXd wXDcallback) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + const GSLVectorType *y, + size_t L, + size_t N, + size_t bMax, + GSLVectorType *x, + wXf wXcallback, + wXd wXDcallback) override; + // clang-format on + // clang-format off bool approximate(GSLMatrixType* y, size_t L, @@ -54,6 +99,26 @@ class dirac_to_dirac_approx_short_function const ApproximateOptions& options = ApproximateOptions{}) override; // clang-format on + // clang-format off + void modified_van_mises_distance_sq(T* distance, + GSLMatrixType *y, + size_t L, + size_t bMax, + GSLMatrixType *x, + wXf wXcallback, + wXd wXDcallback) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + GSLMatrixType *y, + size_t L, + size_t bMax, + GSLMatrixType *x, + wXf wXcallback, + wXd wXDcallback) override; + // clang-format on + private: static double c_b(size_t bMax); static double modified_van_mises_distance_sq(const gsl_vector* x, @@ -75,6 +140,24 @@ bool dirac_to_dirac_approx_short_function::approximate( wXf wXcallback, wXd wXDcallback, GslminimizerResult* result, const ApproximateOptions& options); +template <> +void dirac_to_dirac_approx_short_function< + double>::modified_van_mises_distance_sq(double* distance, + const gsl_vector* y, size_t L, + size_t N, size_t bMax, + gsl_vector* x, wXf wXcallback, + wXd wXDcallback); + +template <> +void dirac_to_dirac_approx_short_function< + double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient, + const gsl_vector* y, + size_t L, size_t N, + size_t bMax, + gsl_vector* x, + wXf wXcallback, + wXd wXDcallback); + extern template class dirac_to_dirac_approx_short_function; #endif // DIRAC_TO_DIRAC_SHORT_FUNCTION_H \ No newline at end of file diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.cpp b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.cpp index 848d103..81dc2e9 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.cpp +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.cpp @@ -14,6 +14,8 @@ #include "capture_time.h" #include "dirac_to_dirac_optimization_params.h" #include "gsl_minimizer.h" +#include "gsl_utils_weight_helper.h" +#include "gsl_utils_view_helper.h" #include "math_util_defs.h" #define eps 0.000001 @@ -26,25 +28,47 @@ bool dirac_to_dirac_approx_short_thread::approximate( const ApproximateOptions& options) { assert(x != nullptr); assert(y != nullptr); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::vector_view_from_array(x, L * N); - GSLVectorViewType yFlat = - GSLTemplateTypeAlias::vector_view_from_array(y, M * N); - - GSLVectorType* wXVector = nullptr; - GSLVectorViewType wXVectorView; - if (wX) { - wXVectorView = GSLTemplateTypeAlias::vector_view_from_array(wX, L); - wXVector = &(wXVectorView.vector); - } - GSLVectorType* wYVector = nullptr; - GSLVectorViewType wYVectorView; - if (wY) { - wYVectorView = GSLTemplateTypeAlias::vector_view_from_array(wY, M); - wYVector = &(wYVectorView.vector); - } - return approximate(&(yFlat.vector), L, N, bMax, &(xFlat.vector), wXVector, - wYVector, result, options); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + return approximate(vectorViewY, L, N, bMax, vectorViewX, vectorViewWX, + vectorViewWY, result, options); +} + +template +void dirac_to_dirac_approx_short_thread::modified_van_mises_distance_sq( + T* distance, const T* y, size_t M, size_t L, size_t N, size_t bMax, T* x, + const T* wX, const T* wY) { + assert(x != nullptr); + assert(y != nullptr); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + modified_van_mises_distance_sq(distance, vectorViewY, L, N, bMax, vectorViewX, + vectorViewWX, vectorViewWY); +} + +template +void dirac_to_dirac_approx_short_thread< + T>::modified_van_mises_distance_sq_derivative(T* gradient, const T* y, + size_t M, size_t L, size_t N, + size_t bMax, T* x, + const T* wX, const T* wY) { + assert(x != nullptr); + assert(y != nullptr); + + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + GSLMatrixView matrixViewGradient(gradient, L, N); + modified_van_mises_distance_sq_derivative(matrixViewGradient, vectorViewY, L, + N, bMax, vectorViewX, vectorViewWX, + vectorViewWY); } template @@ -56,14 +80,43 @@ bool dirac_to_dirac_approx_short_thread::approximate( assert(x->size1 == L); size_t N = y->size2; - GSLVectorViewType yFlat = - GSLTemplateTypeAlias::flatten_matrix_to_vector(y); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::flatten_matrix_to_vector(x); - return approximate(&(yFlat.vector), L, N, bMax, &(xFlat.vector), wX, wY, + GSLVectorView vectorViewY(y); + GSLVectorView vectorViewX(x); + return approximate(vectorViewY, L, N, bMax, vectorViewX, wX, wY, result, options); } +template +void dirac_to_dirac_approx_short_thread::modified_van_mises_distance_sq( + T* distance, GSLMatrixType* y, size_t L, size_t bMax, GSLMatrixType* x, + const GSLVectorType* wX, const GSLVectorType* wY) { + assert(x->size2 == y->size2); + assert(x->size1 == L); + + size_t N = y->size2; + GSLVectorView vectorViewY(y); + GSLVectorView vectorViewX(x); + modified_van_mises_distance_sq(distance, vectorViewY, L, N, bMax, + vectorViewX, wX, wY); +} + +template +void dirac_to_dirac_approx_short_thread< + T>::modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + GSLMatrixType* y, size_t L, + size_t bMax, GSLMatrixType* x, + const GSLVectorType* wX, + const GSLVectorType* wY) { + assert(x->size2 == y->size2); + assert(x->size1 == L); + + size_t N = y->size2; + GSLVectorView vectorViewY(y); + GSLVectorView vectorViewX(x); + modified_van_mises_distance_sq_derivative(gradient, vectorViewY, L, N, + bMax, vectorViewX, wX, wY); +} + template inline double dirac_to_dirac_approx_short_thread::c_b(size_t bMax) { return 100.00; @@ -210,41 +263,21 @@ bool dirac_to_dirac_approx_short_thread::approximate( const gsl_vector_float* y, size_t L, size_t N, size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX, const gsl_vector_float* wY, GslminimizerResult* result, const ApproximateOptions& options) { - gsl_vector* yDouble = gsl_vector_alloc(y->size); - gsl_vector* xDouble = gsl_vector_alloc(x->size); - gsl_vector* wXDouble = nullptr; - gsl_vector* wYDouble = nullptr; - - for (size_t i = 0; i < y->size; ++i) { - yDouble->data[i] = static_cast(y->data[i]); - } - if (wX) { - wXDouble = gsl_vector_alloc(L); - for (size_t i = 0; i < L; i++) { - wXDouble->data[i] = static_cast(wX->data[i]); - } - } - if (wY) { - const size_t M = y->size / N; - wYDouble = gsl_vector_alloc(M); - for (size_t i = 0; i < M; i++) { - wYDouble->data[i] = static_cast(wY->data[i]); - } - } - + const size_t M = y->size / N; + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); dirac_to_dirac_approx_short_thread doubleApprox; - bool success = doubleApprox.approximate(yDouble, L, N, bMax, xDouble, - wXDouble, wYDouble, result, options); + bool success = doubleApprox.approximate(vectorViewY, L, N, bMax, + vectorViewX, vectorViewWX, + vectorViewWY, result, options); - for (size_t i = 0; i < x->size; ++i) { - x->data[i] = static_cast(xDouble->data[i]); + const size_t xSize = x->size; + for (size_t i = 0; i < xSize; ++i) { + x->data[i] = static_cast(vectorViewX.get()->data[i]); } - gsl_vector_free(yDouble); - gsl_vector_free(xDouble); - if (wXDouble) gsl_vector_free(wXDouble); - if (wYDouble) gsl_vector_free(wYDouble); - return success; } @@ -266,28 +299,12 @@ bool dirac_to_dirac_approx_short_thread::approximate( } } - const gsl_vector* localWX; - const bool freeWx = wX == nullptr; - if (freeWx) { - gsl_vector* tmpWx = gsl_vector_alloc(L); - gsl_vector_set_all(tmpWx, 1.00 / static_cast(L)); - localWX = tmpWx; - } else { - localWX = wX; - } - const gsl_vector* localWY; - const bool freeWy = wY == nullptr; - if (freeWy) { - gsl_vector* tmpWy = gsl_vector_alloc(M); - gsl_vector_set_all(tmpWy, 1.00 / static_cast(M)); - localWY = tmpWy; - } else { - localWY = wY; - } + GSLWeightHelper wXHelper(wX, L); + GSLWeightHelper wYHelper(wY, M); // Set up optimization parameters DiracToDiracConstWeightOptimizationParams params = - DiracToDiracConstWeightOptimizationParams(localWX, localWY, y, N, M, L, + DiracToDiracConstWeightOptimizationParams(wXHelper, wYHelper, y, N, M, L, bMax, c_b(bMax)); gsl_minimizer gslMinimizer = gsl_minimizer( @@ -301,5 +318,86 @@ bool dirac_to_dirac_approx_short_thread::approximate( return status == GSL_SUCCESS; } +template <> +void dirac_to_dirac_approx_short_thread::modified_van_mises_distance_sq( + float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax, + gsl_vector_float* x, const gsl_vector_float* wX, + const gsl_vector_float* wY) { + double distanceDouble = 0.00; + const size_t M = y->size / N; + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + dirac_to_dirac_approx_short_thread doubleApprox; + doubleApprox.modified_van_mises_distance_sq(&distanceDouble, vectorViewY, + L, N, bMax, vectorViewX, + vectorViewWX, vectorViewWY); + *distance = static_cast(distanceDouble); +} + +template <> +void dirac_to_dirac_approx_short_thread::modified_van_mises_distance_sq( + double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax, + gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY) { + const size_t M = y->size / N; + GSLWeightHelper wXHelper(wX, L); + GSLWeightHelper wYHelper(wY, M); + DiracToDiracConstWeightOptimizationParams optiParams = + DiracToDiracConstWeightOptimizationParams(wXHelper, wYHelper, y, N, M, L, bMax, + c_b(bMax)); + *distance = modified_van_mises_distance_sq(x, &optiParams); +} + +template <> +void dirac_to_dirac_approx_short_thread:: + modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient, + const gsl_vector_float* y, + size_t L, size_t N, size_t bMax, + gsl_vector_float* x, + const gsl_vector_float* wX, + const gsl_vector_float* wY) { + gsl_matrix* gradientDouble = + gsl_matrix_alloc(gradient->size1, gradient->size2); + + const size_t M = y->size / N; + GSLVectorView vectorViewY(y, M * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWY(wY, M); + GSLVectorView vectorViewWX(wX, L); + dirac_to_dirac_approx_short_thread doubleApprox; + + doubleApprox.modified_van_mises_distance_sq_derivative( + gradientDouble, vectorViewY, L, N, bMax, vectorViewX, + vectorViewWX, vectorViewWY); + + const size_t gradientSize = gradient->size1 * gradient->size2; + for (size_t i = 0; i < gradientSize; i++) + gradient->data[i] = static_cast(gradientDouble->data[i]); + + gsl_matrix_free(gradientDouble); +} + +template <> +void dirac_to_dirac_approx_short_thread< + double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient, + const gsl_vector* y, + size_t L, size_t N, + size_t bMax, + gsl_vector* x, + const gsl_vector* wX, + const gsl_vector* wY) { + const size_t M = y->size / N; + GSLWeightHelper wXHelper(wX, L); + GSLWeightHelper wYHelper(wY, M); + DiracToDiracConstWeightOptimizationParams optiParams = + DiracToDiracConstWeightOptimizationParams(wXHelper, wYHelper, y, N, M, L, bMax, + c_b(bMax)); + gsl_vector_view flatGradient = + gsl_vector_view_array(gradient->data, gradient->size1 * gradient->size2); + modified_van_mises_distance_sq_derivative(x, &optiParams, + &(flatGradient.vector)); +} + template class dirac_to_dirac_approx_short_thread; template class dirac_to_dirac_approx_short_thread; \ No newline at end of file diff --git a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.h b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.h index 1ae7756..df69d20 100644 --- a/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.h +++ b/lib/dirac_to_dirac/dirac_to_dirac_approx_short_thread.h @@ -12,41 +12,108 @@ class dirac_to_dirac_approx_short_thread : public dirac_to_dirac_approx_i { using GSLVectorViewType = typename dirac_to_dirac_approx_i::GSLVectorViewType; using GSLMatrixType = typename dirac_to_dirac_approx_i::GSLMatrixType; + using GSLMatrixViewType = typename dirac_to_dirac_approx_i::GSLMatrixViewType; // clang-format off bool approximate(const T* y, - size_t M, - size_t L, - size_t N, - size_t bMax, - T* x, - const T* wX = nullptr, - const T* wY = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t M, + size_t L, + size_t N, + size_t bMax, + T* x, + const T* wX = nullptr, + const T* wY = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + const T *y, + size_t M, + size_t L, + size_t N, + size_t bMax, + T *x, + const T *wX = nullptr, + const T *wY = nullptr) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(T* gradient, + const T *y, + size_t M, + size_t L, + size_t N, + size_t bMax, + T *x, + const T *wX = nullptr, + const T *wY = nullptr) override; // clang-format on // clang-format off bool approximate(const GSLVectorType* y, - size_t L, - size_t N, - size_t bMax, - GSLVectorType* x, - const GSLVectorType* wX = nullptr, - const GSLVectorType* wY = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t L, + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX = nullptr, + const GSLVectorType* wY = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + const GSLVectorType *y, + size_t L, + size_t N, + size_t bMax, + GSLVectorType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + const GSLVectorType *y, + size_t L, + size_t N, + size_t bMax, + GSLVectorType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; // clang-format on // clang-format off bool approximate(GSLMatrixType* y, - size_t L, - size_t bMax, - GSLMatrixType* x, - const GSLVectorType* wX = nullptr, - const GSLVectorType* wY = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t L, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX = nullptr, + const GSLVectorType* wY = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + GSLMatrixType *y, + size_t L, + size_t bMax, + GSLMatrixType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + GSLMatrixType *y, + size_t L, + size_t bMax, + GSLMatrixType *x, + const GSLVectorType *wX = nullptr, + const GSLVectorType *wY = nullptr) override; // clang-format on private: @@ -83,7 +150,37 @@ bool dirac_to_dirac_approx_short_thread::approximate( const GSLVectorType* wX, const GSLVectorType* wY, GslminimizerResult* result, const ApproximateOptions& options); +template <> +void dirac_to_dirac_approx_short_thread::modified_van_mises_distance_sq( + float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax, + gsl_vector_float* x, const gsl_vector_float* wX, + const gsl_vector_float* wY); + +template <> +void dirac_to_dirac_approx_short_thread::modified_van_mises_distance_sq( + double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax, + gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY); + +template <> +void dirac_to_dirac_approx_short_thread:: + modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient, + const gsl_vector_float* y, + size_t L, size_t N, size_t bMax, + gsl_vector_float* x, + const gsl_vector_float* wX, + const gsl_vector_float* wY); + +template <> +void dirac_to_dirac_approx_short_thread< + double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient, + const gsl_vector* y, + size_t L, size_t N, + size_t bMax, + gsl_vector* x, + const gsl_vector* wX, + const gsl_vector* wY); + extern template class dirac_to_dirac_approx_short_thread; extern template class dirac_to_dirac_approx_short_thread; -#endif // DIRAC_TO_DIRAC_SHORT_H +#endif // DIRAC_TO_DIRAC_SHORT_THREAD_H diff --git a/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp b/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp index 6c05395..aaa0d41 100644 --- a/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp +++ b/lib/dirac_to_dirac/tests/unit_tests/dirac_to_dirac_approx_short_test_van_mises_distance_sq_derivative.cpp @@ -6,6 +6,7 @@ #include "dirac_to_dirac_test_case_params.h" #include "gradient_van_mises_distance_sq_const_weight.h" #include "gsl_utils_allocation.h" +#include "gsl_utils_view_helper.h" #include "gtest_compare_vec.h" class dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative @@ -56,7 +57,9 @@ class dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative gradVanMisesDistanceSqConstWeight; }; -gradient_van_mises_distance_sq_const_weight dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative::gradVanMisesDistanceSqConstWeight; +gradient_van_mises_distance_sq_const_weight + dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative:: + gradVanMisesDistanceSqConstWeight; TEST_P( dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, @@ -78,6 +81,50 @@ TEST_P( } } +TEST_P( + dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_distance) { + DiracToDiracTestCaseParams p = GetParam(); + + const double c_b = dirac_to_dirac_approx_short::c_b(p.bMax); + DiracToDiracConstWeightOptimizationParams params(y, p.N, p.M, p.L, p.bMax, + c_b); + + // wrapper + double distance_wrapper = 0; + auto d2d = dirac_to_dirac_approx_short(); + d2d.modified_van_mises_distance_sq(&distance_wrapper, y, p.L, p.N, p.bMax, x); + // internal impl + double distance_internal = 1; + distance_internal = + dirac_to_dirac_approx_short::modified_van_mises_distance_sq( + x, ¶ms); + + ASSERT_TRUE(distance_wrapper == distance_internal); +} + +TEST_P( + dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_gradient) { + DiracToDiracTestCaseParams p = GetParam(); + + const double c_b = dirac_to_dirac_approx_short::c_b(p.bMax); + DiracToDiracConstWeightOptimizationParams params(y, p.N, p.M, p.L, p.bMax, + c_b); + + // wrapper + auto d2d = dirac_to_dirac_approx_short(); + GSLMatrixView numericalGradView(numericalGrad, p.L, p.N); + d2d.modified_van_mises_distance_sq_derivative(numericalGradView, y, p.L, p.N, + p.bMax, x); + + // internal impl + dirac_to_dirac_approx_short< + double>::modified_van_mises_distance_sq_derivative(x, ¶ms, + analyticalGrad); + ASSERT_TRUE(assert_gsl_vectors_close(analyticalGrad, numericalGrad, eps)); +} + INSTANTIATE_TEST_SUITE_P( ModifiedVanMisesDistanceDerivativeParameterizedTest, dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, diff --git a/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_c.h b/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_c.h index 88d9b25..533626b 100644 --- a/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_c.h +++ b/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_c.h @@ -28,6 +28,22 @@ DLL_EXPORT bool dirac_to_dirac_approx_short_double_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void +dirac_to_dirac_approx_short_double_modified_van_mises_distance_sq( + void* instance, double* distance, const double* y, size_t M, size_t L, + size_t N, size_t bMax, double* x, const double* wX, const double* wY) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq(distance, y, M, L, N, bMax, x, wX, wY); +} + +DLL_EXPORT void +dirac_to_dirac_approx_short_double_modified_van_mises_distance_sq_derivative( + void* instance, double* gradient, const double* y, size_t M, size_t L, + size_t N, size_t bMax, double* x, const double* wX, const double* wY) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq_derivative(gradient, y, M, L, N, bMax, x, wX, wY); +} + DLL_EXPORT void* create_dirac_to_dirac_approx_short_float() { return new dirac_to_dirac_approx_short(); } @@ -45,6 +61,22 @@ DLL_EXPORT bool dirac_to_dirac_approx_short_float_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void +dirac_to_dirac_approx_short_float_modified_van_mises_distance_sq( + void* instance, float* distance, const float* y, size_t M, size_t L, + size_t N, size_t bMax, float* x, const float* wX, const float* wY) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq(distance, y, M, L, N, bMax, x, wX, wY); +} + +DLL_EXPORT void +dirac_to_dirac_approx_short_float_modified_van_mises_distance_sq_derivative( + void* instance, float* gradient, const float* y, size_t M, size_t L, + size_t N, size_t bMax, float* x, const float* wX, const float* wY) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq_derivative(gradient, y, M, L, N, bMax, x, wX, wY); +} + } // extern "C" #endif // DIRAC_TO_DIRAC_APPROX_SHORT_C_H diff --git a/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_function_c.h b/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_function_c.h index 6d4e556..ad9e4ac 100644 --- a/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_function_c.h +++ b/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_function_c.h @@ -32,6 +32,30 @@ DLL_EXPORT bool dirac_to_dirac_approx_short_function_double_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void +dirac_to_dirac_approx_short_function_double_modified_van_mises_distance_sq( + void* instance, double* distance, const double* y, size_t M, size_t L, + size_t N, size_t bMax, double* x, + typename dirac_to_dirac_approx_function_i::wXf wXcallback, + typename dirac_to_dirac_approx_function_i::wXd wXDcallback) { + auto* obj = + static_cast*>(instance); + obj->modified_van_mises_distance_sq(distance, y, M, L, N, bMax, x, + wXcallback, wXDcallback); +} + +DLL_EXPORT void +dirac_to_dirac_approx_short_function_double_modified_van_mises_distance_sq_derivative( + void* instance, double* gradient, const double* y, size_t M, size_t L, + size_t N, size_t bMax, double* x, + typename dirac_to_dirac_approx_function_i::wXf wXcallback, + typename dirac_to_dirac_approx_function_i::wXd wXDcallback) { + auto* obj = + static_cast*>(instance); + obj->modified_van_mises_distance_sq_derivative(gradient, y, M, L, N, bMax, x, + wXcallback, wXDcallback); +} + } // extern "C" #endif // DIRAC_TO_DIRAC_APPROX_SHORT_FUNCTION_C_H diff --git a/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_thread_c.h b/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_thread_c.h index 30348b2..0478a66 100644 --- a/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_thread_c.h +++ b/lib/dirac_to_dirac/wrappers/dirac_to_dirac_approx_short_thread_c.h @@ -30,6 +30,23 @@ DLL_EXPORT bool dirac_to_dirac_approx_short_thread_double_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void +dirac_to_dirac_approx_short_thread_double_modified_van_mises_distance_sq( + void* instance, double* distance, const double* y, size_t M, size_t L, + size_t N, size_t bMax, double* x, const double* wX, const double* wY) { + auto* obj = + static_cast*>(instance); + obj->modified_van_mises_distance_sq(distance, y, M, L, N, bMax, x, wX, wY); +} + +DLL_EXPORT void +dirac_to_dirac_approx_short_thread_double_modified_van_mises_distance_sq_derivative( + void* instance, double* gradient, const double* y, size_t M, size_t L, + size_t N, size_t bMax, double* x, const double* wX, const double* wY) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq_derivative(gradient, y, M, L, N, bMax, x, wX, wY); +} + DLL_EXPORT void* create_dirac_to_dirac_approx_short_thread_float() { return new dirac_to_dirac_approx_short_thread(); } @@ -48,6 +65,22 @@ DLL_EXPORT bool dirac_to_dirac_approx_short_thread_float_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void +dirac_to_dirac_approx_short_thread_float_modified_van_mises_distance_sq( + void* instance, float* distance, const float* y, size_t M, size_t L, + size_t N, size_t bMax, float* x, const float* wX, const float* wY) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq(distance, y, M, L, N, bMax, x, wX, wY); +} + +DLL_EXPORT void +dirac_to_dirac_approx_short_thread_float_modified_van_mises_distance_sq_derivative( + void* instance, float* gradient, const float* y, size_t M, size_t L, + size_t N, size_t bMax, float* x, const float* wX, const float* wY) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq_derivative(gradient, y, M, L, N, bMax, x, wX, wY); +} + } // extern "C" #endif // DIRAC_TO_DIRAC_APPROX_SHORT_THREAD_C_H diff --git a/lib/gm_to_dirac/gm_to_dirac_approx_i.h b/lib/gm_to_dirac/gm_to_dirac_approx_i.h index 0f5b7aa..aef53fe 100644 --- a/lib/gm_to_dirac/gm_to_dirac_approx_i.h +++ b/lib/gm_to_dirac/gm_to_dirac_approx_i.h @@ -14,7 +14,7 @@ /** * @brief interface for the gausian mixture to dirac approximation * - * @tparam T type of the vector (float, double, long double) + * @tparam T type of the vector (float, double) */ template class gm_to_dirac_approx_i { @@ -41,6 +41,41 @@ class gm_to_dirac_approx_i { T* x, const T* wX, GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param distance pointer to distance value to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(const T* covDiag, T* distance, + size_t L, size_t N, size_t bMax, + T* x, const T* wX) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param gradient pointer to gradient to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative(const T* covDiag, + T* gradient, size_t L, + size_t N, size_t bMax, + T* x, const T* wX) = 0; + /** * @brief approximate using gsl vectors * @@ -59,6 +94,41 @@ class gm_to_dirac_approx_i { const GSLVectorType* wX, GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param distance pointer to distance value to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(const GSLVectorType* covDiag, + T* distance, size_t L, size_t N, + size_t bMax, GSLVectorType* x, + const GSLVectorType* wX) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param gradient pointer to gradient to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + const GSLVectorType* covDiag, GSLVectorType* gradient, size_t L, size_t N, + size_t bMax, GSLVectorType* x, const GSLVectorType* wX) = 0; + /** * @brief approximate using gsl matricies where possible * @@ -76,6 +146,41 @@ class gm_to_dirac_approx_i { size_t bMax, GSLMatrixType* x, const GSLVectorType* wX, GslminimizerResult* result, const ApproximateOptions& options) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param distance pointer to distance value to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(const GSLVectorType* covDiag, + T* distance, size_t L, size_t N, + size_t bMax, GSLMatrixType* x, + const GSLVectorType* wX) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param gradient pointer to gradient to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative( + const GSLVectorType* covDiag, GSLMatrixType* gradient, size_t L, size_t N, + size_t bMax, GSLMatrixType* x, const GSLVectorType* wX) = 0; }; #endif // GM_TO_DIRAC_I_H diff --git a/lib/gm_to_dirac/gm_to_dirac_approx_standard_normal_distribution_i.h b/lib/gm_to_dirac/gm_to_dirac_approx_standard_normal_distribution_i.h index 6db963d..04613dc 100644 --- a/lib/gm_to_dirac/gm_to_dirac_approx_standard_normal_distribution_i.h +++ b/lib/gm_to_dirac/gm_to_dirac_approx_standard_normal_distribution_i.h @@ -14,7 +14,7 @@ /** * @brief interface for the gausian mixture to dirac approximation * - * @tparam T type of the vector (float, double, long double) + * @tparam T type of the vector (float, double) */ template class gm_to_dirac_approx_standard_normal_distribution_i { @@ -28,7 +28,7 @@ class gm_to_dirac_approx_standard_normal_distribution_i { /** * @brief approximate using raw pointers * - * @param L number of data points for apprioximation + * @param L number of data points for approximation * @param N dimension of the data * @param bMax bMax * @param x first guess for the approximation and return value @@ -40,6 +40,40 @@ class gm_to_dirac_approx_standard_normal_distribution_i { GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param distance pointer to distance value to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, size_t L, size_t N, + size_t bMax, T* x, + const T* wX) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param gradient pointer to gradient to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative(T* gradient, size_t L, + size_t N, size_t bMax, + T* x, const T* wX) = 0; + /** * @brief approximate using gsl vectors * @@ -56,6 +90,40 @@ class gm_to_dirac_approx_standard_normal_distribution_i { const GSLVectorType* wX, GslminimizerResult* result, const ApproximateOptions& options) = 0; + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param distance pointer to distance value to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, size_t L, size_t N, + size_t bMax, GSLVectorType* x, + const GSLVectorType* wX) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param gradient pointer to gradient to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative(GSLVectorType* gradient, size_t L, + size_t N, size_t bMax, + GSLVectorType* x, const GSLVectorType* wX) = 0; + /** * @brief approximate using gsl matricies where possible * @@ -71,6 +139,40 @@ class gm_to_dirac_approx_standard_normal_distribution_i { virtual bool approximate(size_t L, size_t N, size_t bMax, GSLMatrixType* x, const GSLVectorType* wX, GslminimizerResult* result, const ApproximateOptions& options) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param distance pointer to distance value to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq(T* distance, size_t L, size_t N, + size_t bMax, GSLMatrixType* x, + const GSLVectorType* wX) = 0; + + /** + * @brief calculate modified van mises distance based on standard normal + * deviation and x + * + * @param gradient pointer to gradient to be calculated + * @param L number of data points for approximation + * @param N dimension of the data + * @param bMax bMax + * @param x first guess for the approximation and return value + * @param result minimizer result + * @param options options for minimizer + * @return true, on success, false otherwise + */ + virtual void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, size_t L, + size_t N, size_t bMax, + GSLMatrixType* x, const GSLVectorType* wX) = 0; }; #endif // GM_TO_DIRAC_APPROX_STANDARD_NORMAL_DISTRIBUTION_I_H diff --git a/lib/gm_to_dirac/gm_to_dirac_short.cpp b/lib/gm_to_dirac/gm_to_dirac_short.cpp index 1fc7acd..e70b4f9 100644 --- a/lib/gm_to_dirac/gm_to_dirac_short.cpp +++ b/lib/gm_to_dirac/gm_to_dirac_short.cpp @@ -18,6 +18,8 @@ #include "capture_time.h" #include "gsl_minimizer.h" +#include "gsl_utils_view_helper.h" +#include "gsl_utils_weight_helper.h" #include "math_util_defs.h" template @@ -28,20 +30,36 @@ bool gm_to_dirac_short::approximate(const T* covDiag, size_t L, size_t N, assert(x != nullptr); assert(covDiag != nullptr); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::vector_view_from_array(x, L * N); - GSLVectorViewType covDiagView = - GSLTemplateTypeAlias::vector_view_from_array(covDiag, N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWX(wX, L); + GSLVectorView vectorViewCovDiag(covDiag, N); + return approximate(vectorViewCovDiag, L, N, bMax, vectorViewX, vectorViewWX, + result, options); +} - GSLVectorType* wXVector = nullptr; - GSLVectorViewType wXVectorView; - if (wX) { - wXVectorView = GSLTemplateTypeAlias::vector_view_from_array(wX, L); - wXVector = &(wXVectorView.vector); - } - approximate(&(covDiagView.vector), L, N, bMax, &(xFlat.vector), wXVector, - result, options); - return true; +template +void gm_to_dirac_short::modified_van_mises_distance_sq(const T* covDiag, + T* distance, size_t L, + size_t N, size_t bMax, + T* x, const T* wX) { + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWX(wX, L); + GSLVectorView vectorViewCovDiag(covDiag, N); + return modified_van_mises_distance_sq(vectorViewCovDiag, distance, L, N, bMax, + vectorViewX, vectorViewWX); +} + +template +void gm_to_dirac_short::modified_van_mises_distance_sq_derivative( + const T* covDiag, T* gradient, size_t L, size_t N, size_t bMax, T* x, + const T* wX) { + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWX(wX, L); + GSLVectorView vectorViewCovDiag(covDiag, N); + GSLVectorView vectorViewGradient(gradient, L * N); + return modified_van_mises_distance_sq_derivative( + vectorViewCovDiag, vectorViewGradient, L, N, bMax, vectorViewX, + vectorViewWX); } template @@ -52,9 +70,31 @@ bool gm_to_dirac_short::approximate(const GSLVectorType* covDiag, size_t L, const ApproximateOptions& options) { assert(x->size1 == L); assert(x->size2 == N); - GSLVectorViewType xFlat = - GSLTemplateTypeAlias::flatten_matrix_to_vector(x); - return approximate(covDiag, L, N, bMax, &(xFlat.vector), wX, result, options); + GSLVectorView vectorViewX(x); + return approximate(covDiag, L, N, bMax, vectorViewX, wX, result, options); +} + +template +void gm_to_dirac_short::modified_van_mises_distance_sq( + const GSLVectorType* covDiag, T* distance, size_t L, size_t N, size_t bMax, + GSLMatrixType* x, const GSLVectorType* wX) { + assert(x->size1 == L); + assert(x->size2 == N); + GSLVectorView vectorViewX(x); + return modified_van_mises_distance_sq(covDiag, distance, L, N, bMax, + vectorViewX, wX); +} + +template +void gm_to_dirac_short::modified_van_mises_distance_sq_derivative( + const GSLVectorType* covDiag, GSLMatrixType* gradient, size_t L, size_t N, + size_t bMax, GSLMatrixType* x, const GSLVectorType* wX) { + assert(x->size1 == L); + assert(x->size2 == N); + GSLVectorView vectorViewX(x); + GSLVectorView vectorViewGradient(gradient); + return modified_van_mises_distance_sq_derivative(covDiag, vectorViewGradient, + L, N, bMax, vectorViewX, wX); } template @@ -155,17 +195,8 @@ bool gm_to_dirac_short::approximate(const gsl_vector* covDiag, size_t L, gsl_rng_free(r); } - const gsl_vector* localWX; - const bool freeWx = wX == nullptr; - if (freeWx) { - gsl_vector* tmpWx = gsl_vector_alloc(L); - gsl_vector_set_all(tmpWx, 1.00 / static_cast(L)); - localWX = tmpWx; - } else { - localWX = wX; - } - - GMToDiracConstWeightOptimizationParams params(covDiag, localWX, N, L, bMax, + GSLWeightHelper wXHelper(wX, L); + GMToDiracConstWeightOptimizationParams params(covDiag, wXHelper, N, L, bMax, c_b(bMax)); gsl_minimizer gslMinimizer( @@ -177,10 +208,64 @@ bool gm_to_dirac_short::approximate(const gsl_vector* covDiag, size_t L, correctMean(x, params.wX, L, N); - if (freeWx) gsl_vector_free(const_cast(localWX)); - return status == GSL_SUCCESS; } +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq( + const gsl_vector_float* covDiag, float* distance, size_t L, size_t N, + size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX) { + double distanceDouble = 0.00; + GSLVectorView vectorViewCovDiag(covDiag, L * N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWX(wX, L); + gm_to_dirac_short doubleApprox; + doubleApprox.modified_van_mises_distance_sq(vectorViewCovDiag, + &distanceDouble, L, N, bMax, + vectorViewX, vectorViewWX); + *distance = static_cast(distanceDouble); +} + +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq( + const gsl_vector* covDiag, double* distance, size_t L, size_t N, + size_t bMax, gsl_vector* x, const gsl_vector* wX) { + GSLWeightHelper wXHelper(wX, L); + GMToDiracConstWeightOptimizationParams optiParams = + GMToDiracConstWeightOptimizationParams(covDiag, wXHelper, N, L, bMax, + c_b(bMax)); + *distance = modified_van_mises_distance_sq(x, &optiParams); +} + +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq_derivative( + const gsl_vector_float* covDiag, gsl_vector_float* gradient, size_t L, + size_t N, size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX) { + gsl_vector* gradientDouble = gsl_vector_alloc(gradient->size); + + GSLVectorView vectorViewCovDiag(covDiag, N); + GSLVectorView vectorViewX(x, L * N); + GSLVectorView vectorViewWX(wX, L); + gm_to_dirac_short doubleApprox; + doubleApprox.modified_van_mises_distance_sq_derivative( + vectorViewCovDiag, gradientDouble, L, N, bMax, vectorViewX, vectorViewWX); + + for (size_t i = 0; i < gradient->size; i++) + gradient->data[i] = static_cast(gradientDouble->data[i]); + + gsl_vector_free(gradientDouble); +} + +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq_derivative( + const gsl_vector* covDiag, gsl_vector* gradient, size_t L, size_t N, + size_t bMax, gsl_vector* x, const gsl_vector* wX) { + GSLWeightHelper wXHelper(wX, L); + GMToDiracConstWeightOptimizationParams optiParams = + GMToDiracConstWeightOptimizationParams(covDiag, wXHelper, N, L, bMax, + c_b(bMax)); + modified_van_mises_distance_sq_derivative(x, &optiParams, gradient); +} + template class gm_to_dirac_short; template class gm_to_dirac_short; \ No newline at end of file diff --git a/lib/gm_to_dirac/gm_to_dirac_short.h b/lib/gm_to_dirac/gm_to_dirac_short.h index d743fa8..db6e003 100644 --- a/lib/gm_to_dirac/gm_to_dirac_short.h +++ b/lib/gm_to_dirac/gm_to_dirac_short.h @@ -15,35 +15,95 @@ class gm_to_dirac_short : public gm_to_dirac_approx_i { // clang-format off bool approximate(const T* covDiag, - size_t L, - size_t N, - size_t bMax, - T* x, - const T* wX = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t L, + size_t N, + size_t bMax, + T* x, + const T* wX = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(const T* covDiag, + T* distance, + size_t L, + size_t N, + size_t bMax, + T* x, + const T* wX) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(const T* covDiag, + T* gradient, + size_t L, + size_t N, + size_t bMax, + T* x, + const T* wX) override; // clang-format on // clang-format off bool approximate(const GSLVectorType* covDiag, - size_t L, - size_t N, - size_t bMax, - GSLVectorType* x, - const GSLVectorType* wX = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t L, + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(const GSLVectorType* covDiag, + T* distance, + size_t L, + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(const GSLVectorType* covDiag, + GSLVectorType* gradient, + size_t L, + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX) override; // clang-format on // clang-format off bool approximate(const GSLVectorType* covDiag, - size_t L, - size_t N, - size_t bMax, - GSLMatrixType* x, - const GSLVectorType* wX = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t L, + size_t N, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(const GSLVectorType* covDiag, + T* distance, + size_t L, + size_t N, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(const GSLVectorType* covDiag, + GSLMatrixType* gradient, + size_t L, + size_t N, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX) override; // clang-format on private: @@ -74,6 +134,12 @@ class gm_to_dirac_short : public gm_to_dirac_approx_i { FRIEND_TEST( gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, parameterized_test_modified_van_mises_distance_sq_derivative); + FRIEND_TEST( + gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_distance); + FRIEND_TEST( + gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_gradient); }; #include "gm_to_dirac_short.tpp" @@ -93,6 +159,26 @@ bool gm_to_dirac_short::approximate(const gsl_vector* covDiag, size_t L, GslminimizerResult* result, const ApproximateOptions& options); +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq( + const gsl_vector_float* covDiag, float* distance, size_t L, size_t N, + size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX); + +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq( + const gsl_vector* covDiag, double* distance, size_t L, size_t N, + size_t bMax, gsl_vector* x, const gsl_vector* wX); + +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq_derivative( + const gsl_vector_float* covDiag, gsl_vector_float* gradient, size_t L, + size_t N, size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX); + +template <> +void gm_to_dirac_short::modified_van_mises_distance_sq_derivative( + const gsl_vector* covDiag, gsl_vector* gradient, size_t L, size_t N, + size_t bMax, gsl_vector* x, const gsl_vector* wX); + extern template class gm_to_dirac_short; extern template class gm_to_dirac_short; diff --git a/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.cpp b/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.cpp index 5d942e9..1d61a63 100644 --- a/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.cpp +++ b/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.cpp @@ -1,60 +1,118 @@ #include "gm_to_dirac_short_standard_normal_deviation.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - #include #include #include #include +#include #include "capture_time.h" #include "gm_to_dirac_short.h" #include "gsl_minimizer.h" +#include "gsl_utils_view_helper.h" #include "math_util_defs.h" -// #define USE_CACHE_MANAGER - template bool gm_to_dirac_short_standard_normal_deviation::approximate( size_t L, size_t N, size_t bMax, T* x, const T* wX, GslminimizerResult* result, const ApproximateOptions& options) { - std::vector covDiag(N, 1.0); + std::vector covDiag(N, T(1)); gm_to_dirac_short gmToDiracInstance; return gmToDiracInstance.approximate(covDiag.data(), L, N, bMax, x, wX, result, options); } +template +void gm_to_dirac_short_standard_normal_deviation< + T>::modified_van_mises_distance_sq(T* distance, size_t L, size_t N, + size_t bMax, T* x, const T* wX) { + std::vector covDiag(N, T(1)); + gm_to_dirac_short gmToDiracInstance; + gmToDiracInstance.modified_van_mises_distance_sq(covDiag.data(), distance, L, + N, bMax, x, wX); +} + +template +void gm_to_dirac_short_standard_normal_deviation< + T>::modified_van_mises_distance_sq_derivative(T* gradient, size_t L, + size_t N, size_t bMax, T* x, + const T* wX) { + std::vector covDiag(N, T(1)); + gm_to_dirac_short gmToDiracInstance; + gmToDiracInstance.modified_van_mises_distance_sq_derivative( + covDiag.data(), gradient, L, N, bMax, x, wX); +} + template bool gm_to_dirac_short_standard_normal_deviation::approximate( size_t L, size_t N, size_t bMax, GSLVectorType* x, const GSLVectorType* wX, GslminimizerResult* result, const ApproximateOptions& options) { - std::vector covDiag(N, 1.0); - GSLVectorViewType covDiagView = - GSLTemplateTypeAlias::vector_view_from_array(covDiag.data(), N); + std::vector covDiag(N, T(1)); + GSLVectorView covDiagView(covDiag.data(), N); gm_to_dirac_short gmToDiracInstance; - return gmToDiracInstance.approximate(&(covDiagView.vector), L, N, bMax, x, wX, - result, options); + return gmToDiracInstance.approximate(covDiagView, L, N, bMax, x, wX, result, + options); +} + +template +void gm_to_dirac_short_standard_normal_deviation< + T>::modified_van_mises_distance_sq(T* distance, size_t L, size_t N, + size_t bMax, GSLVectorType* x, + const GSLVectorType* wX) { + std::vector covDiag(N, T(1)); + GSLVectorView covDiagView(covDiag.data(), N); + gm_to_dirac_short gmToDiracInstance; + gmToDiracInstance.modified_van_mises_distance_sq(covDiagView, distance, L, N, + bMax, x, wX); +} + +template +void gm_to_dirac_short_standard_normal_deviation< + T>::modified_van_mises_distance_sq_derivative(GSLVectorType* gradient, + size_t L, size_t N, + size_t bMax, GSLVectorType* x, + const GSLVectorType* wX) { + std::vector covDiag(N, T(1)); + GSLVectorView covDiagView(covDiag.data(), N); + gm_to_dirac_short gmToDiracInstance; + gmToDiracInstance.modified_van_mises_distance_sq_derivative( + covDiagView, gradient, L, N, bMax, x, wX); } template bool gm_to_dirac_short_standard_normal_deviation::approximate( size_t L, size_t N, size_t bMax, GSLMatrixType* x, const GSLVectorType* wX, GslminimizerResult* result, const ApproximateOptions& options) { - std::vector covDiag(N, 1.0); - GSLVectorViewType covDiagView = - GSLTemplateTypeAlias::vector_view_from_array(covDiag.data(), N); + std::vector covDiag(N, T(1)); + GSLVectorView covDiagView(covDiag.data(), N); gm_to_dirac_short gmToDiracInstance; - return gmToDiracInstance.approximate(&(covDiagView.vector), L, N, bMax, x, wX, - result, options); + return gmToDiracInstance.approximate(covDiagView, L, N, bMax, x, wX, result, + options); +} + +template +void gm_to_dirac_short_standard_normal_deviation< + T>::modified_van_mises_distance_sq(T* distance, size_t L, size_t N, + size_t bMax, GSLMatrixType* x, + const GSLVectorType* wX) { + std::vector covDiag(N, T(1)); + GSLVectorView covDiagView(covDiag.data(), N); + gm_to_dirac_short gmToDiracInstance; + gmToDiracInstance.modified_van_mises_distance_sq(covDiagView, distance, L, N, + bMax, x, wX); +} + +template +void gm_to_dirac_short_standard_normal_deviation< + T>::modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + size_t L, size_t N, + size_t bMax, GSLMatrixType* x, + const GSLVectorType* wX) { + std::vector covDiag(N, T(1)); + GSLVectorView covDiagView(covDiag.data(), N); + gm_to_dirac_short gmToDiracInstance; + gmToDiracInstance.modified_van_mises_distance_sq_derivative( + covDiagView, gradient, L, N, bMax, x, wX); } template class gm_to_dirac_short_standard_normal_deviation; diff --git a/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.h b/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.h index e195d62..e5c0722 100644 --- a/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.h +++ b/lib/gm_to_dirac/gm_to_dirac_short_standard_normal_deviation.h @@ -1,8 +1,6 @@ #ifndef GM_TO_DIRAC_SHORT_STANDARD_NORMAL_DEVIATION_H #define GM_TO_DIRAC_SHORT_STANDARD_NORMAL_DEVIATION_H -#include - #include "gm_to_dirac_approx_standard_normal_distribution_i.h" #include "gm_to_dirac_optimization_params.h" @@ -13,41 +11,92 @@ class gm_to_dirac_short_standard_normal_deviation using GSLVectorType = typename gm_to_dirac_approx_standard_normal_distribution_i< T>::GSLVectorType; - using GSLVectorViewType = - typename gm_to_dirac_approx_standard_normal_distribution_i< - T>::GSLVectorViewType; using GSLMatrixType = typename gm_to_dirac_approx_standard_normal_distribution_i< T>::GSLMatrixType; // clang-format off bool approximate(size_t L, - size_t N, - size_t bMax, - T* x, - const T* wX, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t N, + size_t bMax, + T* x, + const T* wX, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; // clang-format on - + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + size_t L, + size_t N, + size_t bMax, + T* x, + const T* wX) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(T* gradient, + size_t L, + size_t N, + size_t bMax, + T* x, + const T* wX) override; + // clang-format on + // clang-format off bool approximate(size_t L, - size_t N, - size_t bMax, - GSLVectorType* x, - const GSLVectorType* wX = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + size_t L, + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLVectorType* gradient, + size_t L, + size_t N, + size_t bMax, + GSLVectorType* x, + const GSLVectorType* wX) override; // clang-format on // clang-format off bool approximate(size_t L, - size_t N, - size_t bMax, - GSLMatrixType* x, - const GSLVectorType* wX = nullptr, - GslminimizerResult* result = nullptr, - const ApproximateOptions& options = ApproximateOptions{}) override; + size_t N, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX = nullptr, + GslminimizerResult* result = nullptr, + const ApproximateOptions& options = ApproximateOptions{}) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq(T* distance, + size_t L, + size_t N, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX) override; + // clang-format on + + // clang-format off + void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient, + size_t L, + size_t N, + size_t bMax, + GSLMatrixType* x, + const GSLVectorType* wX) override; // clang-format on }; diff --git a/lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp b/lib/gm_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp similarity index 69% rename from lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp rename to lib/gm_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp index a6e3068..5befb85 100644 --- a/lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp +++ b/lib/gm_to_dirac/tests/unit_tests/gm_to_dirac_short_test_derivative.cpp @@ -54,7 +54,8 @@ class gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative }; gradient_van_mises_distance -gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative::gradVanMisesDistance; + gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative:: + gradVanMisesDistance; TEST_P(gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, parameterized_test_modified_van_mises_distance_sq_derivative) { @@ -74,6 +75,48 @@ TEST_P(gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, } } +TEST_P( + gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_distance) { + GmToDiracTestCaseParams p = GetParam(); + + const double c_b = gm_to_dirac_short::c_b(p.bMax); + GMToDiracConstWeightOptimizationParams params(covDiag, wX, p.N, p.L, p.bMax, + c_b); + + // wrapper + double distance_wrapper = 0; + auto gm2d = gm_to_dirac_short(); + gm2d.modified_van_mises_distance_sq(covDiag, &distance_wrapper, p.L, p.N, + p.bMax, x, wX); + // internal impl + double distance_internal = 1; + distance_internal = + gm_to_dirac_short::modified_van_mises_distance_sq(x, ¶ms); + + ASSERT_TRUE(distance_wrapper == distance_internal); +} + +TEST_P( + gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, + parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_gradient) { + GmToDiracTestCaseParams p = GetParam(); + + const double c_b = gm_to_dirac_short::c_b(p.bMax); + GMToDiracConstWeightOptimizationParams params(covDiag, wX, p.N, p.L, p.bMax, + c_b); + + // wrapper + auto gm2d = gm_to_dirac_short(); + gm2d.modified_van_mises_distance_sq_derivative(covDiag, numericalGrad, p.L, p.N, + p.bMax, x, wX); + + // internal impl + gm_to_dirac_short::modified_van_mises_distance_sq_derivative( + x, ¶ms, analyticalGrad); + ASSERT_TRUE(assert_gsl_vectors_close(analyticalGrad, numericalGrad, eps)); +} + INSTANTIATE_TEST_SUITE_P( ModifiedVanMisesDistanceDerivativeParameterizedTest, gm_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative, diff --git a/lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_test_case_params.h b/lib/gm_to_dirac/tests/unit_tests/gm_to_dirac_test_case_params.h similarity index 100% rename from lib/dirac_to_dirac/tests/unit_tests/gm_to_dirac_test_case_params.h rename to lib/gm_to_dirac/tests/unit_tests/gm_to_dirac_test_case_params.h diff --git a/lib/gm_to_dirac/wrappers/gm_to_dirac_short_c.h b/lib/gm_to_dirac/wrappers/gm_to_dirac_short_c.h index a5934bc..7783162 100644 --- a/lib/gm_to_dirac/wrappers/gm_to_dirac_short_c.h +++ b/lib/gm_to_dirac/wrappers/gm_to_dirac_short_c.h @@ -28,6 +28,22 @@ DLL_EXPORT bool gm_to_dirac_short_double_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void gm_to_dirac_short_double_modified_van_mises_distance_sq( + void* instance, const double* covDiag, double* distance, size_t L, size_t N, + size_t bMax, double* x, const double* wX) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq(covDiag, distance, L, N, bMax, x, wX); +} + +DLL_EXPORT void +gm_to_dirac_short_double_modified_van_mises_distance_sq_derivative( + void* instance, const double* covDiag, double* gradient, size_t L, size_t N, + size_t bMax, double* x, const double* wX) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq_derivative(covDiag, gradient, L, N, bMax, + x, wX); +} + DLL_EXPORT void* create_gm_to_dirac_short_float() { return new gm_to_dirac_short(); } @@ -45,6 +61,22 @@ DLL_EXPORT bool gm_to_dirac_short_float_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void gm_to_dirac_short_float_modified_van_mises_distance_sq( + void* instance, const float* covDiag, float* distance, size_t L, size_t N, + size_t bMax, float* x, const float* wX) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq(covDiag, distance, L, N, bMax, x, wX); +} + +DLL_EXPORT void +gm_to_dirac_short_float_modified_van_mises_distance_sq_derivative( + void* instance, const float* covDiag, float* gradient, size_t L, size_t N, + size_t bMax, float* x, const float* wX) { + auto* obj = static_cast*>(instance); + obj->modified_van_mises_distance_sq_derivative(covDiag, gradient, L, N, bMax, + x, wX); +} + } // extern "C" #endif // GM_TO_DIRAC_SHORT_C_H diff --git a/lib/gm_to_dirac/wrappers/gm_to_dirac_short_standard_normal_deviation_c.h b/lib/gm_to_dirac/wrappers/gm_to_dirac_short_standard_normal_deviation_c.h index f5636ae..b69ee9c 100644 --- a/lib/gm_to_dirac/wrappers/gm_to_dirac_short_standard_normal_deviation_c.h +++ b/lib/gm_to_dirac/wrappers/gm_to_dirac_short_standard_normal_deviation_c.h @@ -31,6 +31,24 @@ DLL_EXPORT bool gm_to_dirac_short_standard_normal_deviation_double_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void +gm_to_dirac_short_standard_normal_deviation_double_modified_van_mises_distance_sq( + void* instance, double* distance, size_t L, size_t N, size_t bMax, + double* x, const double* wX) { + auto* obj = static_cast*>( + instance); + obj->modified_van_mises_distance_sq(distance, L, N, bMax, x, wX); +} + +DLL_EXPORT void +gm_to_dirac_short_standard_normal_deviation_double_modified_van_mises_distance_sq_derivative( + void* instance, double* gradient, size_t L, size_t N, size_t bMax, + double* x, const double* wX) { + auto* obj = static_cast*>( + instance); + obj->modified_van_mises_distance_sq_derivative(gradient, L, N, bMax, x, wX); +} + DLL_EXPORT void* create_gm_to_dirac_short_standard_normal_deviation_float() { return new gm_to_dirac_short_standard_normal_deviation(); } @@ -50,6 +68,24 @@ DLL_EXPORT bool gm_to_dirac_short_standard_normal_deviation_float_approximate( options ? *options : ApproximateOptions{}); } +DLL_EXPORT void +gm_to_dirac_short_standard_normal_deviation_float_modified_van_mises_distance_sq( + void* instance, float* distance, size_t L, size_t N, size_t bMax, float* x, + const float* wX) { + auto* obj = static_cast*>( + instance); + obj->modified_van_mises_distance_sq(distance, L, N, bMax, x, wX); +} + +DLL_EXPORT void +gm_to_dirac_short_standard_normal_deviation_float_modified_van_mises_distance_sq_derivative( + void* instance, float* gradient, size_t L, size_t N, size_t bMax, float* x, + const float* wX) { + auto* obj = static_cast*>( + instance); + obj->modified_van_mises_distance_sq_derivative(gradient, L, N, bMax, x, wX); +} + } // extern "C" #endif // GM_TO_DIRAC_SHORT_STANDARD_NORMAL_DEVIATION_C_H diff --git a/lib/gsl_types/gsl_vector_matrix_types.h b/lib/gsl_types/gsl_vector_matrix_types.h index 67824c8..c5822ab 100644 --- a/lib/gsl_types/gsl_vector_matrix_types.h +++ b/lib/gsl_types/gsl_vector_matrix_types.h @@ -19,6 +19,10 @@ class GSLTemplateTypeAlias { typename std::conditional::value, gsl_matrix, gsl_matrix_float>::type; + using MatrixViewType = + typename std::conditional::value, gsl_matrix_view, + gsl_matrix_float_view>::type; + using VectorType = typename std::conditional::value, gsl_vector, gsl_vector_float>::type; @@ -63,6 +67,18 @@ class GSLTemplateTypeAlias { } } + // Static method to create a matrix view from a raw pointer + static MatrixViewType matrix_view_from_array(const T* data, size_t size1, size_t size2) { + return matrix_view_from_array(const_cast(data), size1, size2); + } + static MatrixViewType matrix_view_from_array(T* data, size_t size1, size_t size2) { + if constexpr (std::is_same::value) { + return gsl_matrix_view_array(data, size1, size2); + } else { + return gsl_matrix_float_view_array(data, size1, size2); + } + } + // Static method to allocate / free a matrix static MatrixType* allocate_matrix(size_t rows, size_t cols) { if constexpr (std::is_same::value) { diff --git a/lib/gsl_utils/gsl_utils_view_helper.h b/lib/gsl_utils/gsl_utils_view_helper.h new file mode 100644 index 0000000..4121c58 --- /dev/null +++ b/lib/gsl_utils/gsl_utils_view_helper.h @@ -0,0 +1,277 @@ +#ifndef GSL_UTILS_VIEW_HELPER_H +#define GSL_UTILS_VIEW_HELPER_H + +#include + +#include "gsl_vector_matrix_types.h" + +template +class GSLViewHelper { + static_assert(std::is_same::value || std::is_same::value, + "Only float and double supported"); + + public: + using GSLVectorType = typename GSLTemplateTypeAlias::VectorType; + using GSLVectorViewType = typename GSLTemplateTypeAlias::VectorViewType; + using GSLMatrixType = typename GSLTemplateTypeAlias::MatrixType; + using GSLMatrixViewType = typename GSLTemplateTypeAlias::MatrixViewType; + + using GSLType = + typename std::conditional::type; + + using ViewType = typename std::conditional::type; + + /**************************************************************************/ + /********************************* pointer ********************************/ + /**************************************************************************/ + template + GSLViewHelper(const U* ptr, size_t size) { + static_assert(!IsMatrix, "Vector constructor used for matrix"); + static_assert(is_float_or_double(), "Only float/double allowed"); + + if (!ptr) { + _ptr = nullptr; + return; + } + + construct_vector_from_ptr(ptr, size); + } + + template + GSLViewHelper(const U* ptr, size_t rows, size_t cols) { + static_assert(IsMatrix, "Matrix constructor used for vector"); + static_assert(is_float_or_double(), "Only float/double allowed"); + + if (!ptr) { + _ptr = nullptr; + return; + } + + construct_matrix_from_ptr(ptr, rows, cols); + } + + /**************************************************************************/ + /********************************* vector *********************************/ + /**************************************************************************/ + GSLViewHelper(const gsl_vector* v, size_t rows = 0, size_t cols = 0) { + if (!v) { + _ptr = nullptr; + return; + } + + if constexpr (!IsMatrix) { + // internal storage = vector + construct_vector_from_vector(v); + } else { + // internal storage = matrix + construct_matrix_from_vector(v, rows, cols); + } + } + + GSLViewHelper(const gsl_vector_float* v, size_t rows = 0, size_t cols = 0) { + if (!v) { + _ptr = nullptr; + return; + } + + if constexpr (!IsMatrix) { + // internal storage = vector + construct_vector_from_vector(v); + } else { + // internal storage = matrix + construct_matrix_from_vector(v, rows, cols); + } + } + + /**************************************************************************/ + /********************************* matrix *********************************/ + /**************************************************************************/ + GSLViewHelper(const gsl_matrix* m) { + if (!m) { + _ptr = nullptr; + return; + } + + if constexpr (IsMatrix) { + // internal storage = matrix + construct_matrix_from_matrix(m); + } else { + // internal storage = vector + construct_vector_from_matrix(m); + } + } + + GSLViewHelper(const gsl_matrix_float* m) { + if (!m) { + _ptr = nullptr; + return; + } + + if constexpr (IsMatrix) { + // internal storage = matrix + construct_matrix_from_matrix(m); + } else { + // internal storage = vector + construct_vector_from_matrix(m); + } + } + + /**************************************************************************/ + /******************************* destructor *******************************/ + /**************************************************************************/ + ~GSLViewHelper() { + if (!_freeMemory || !_ptr) return; + + if constexpr (IsMatrix) + GSLTemplateTypeAlias::free_matrix(_ptr); + else + GSLTemplateTypeAlias::free_vector(_ptr); + } + + /**************************************************************************/ + /********************************* access *********************************/ + /**************************************************************************/ + GSLType* get() { return _ptr; } + const GSLType* get() const { return _ptr; } + + operator GSLType*() { return _ptr; } + operator const GSLType*() const { return _ptr; } + + private: + template + void construct_vector_from_ptr(const U* ptr, size_t size) { + if constexpr (std::is_same::value) { + _view = GSLTemplateTypeAlias::vector_view_from_array(ptr, size); + _ptr = &_view.vector; + } else { + _ptr = GSLTemplateTypeAlias::allocate_vector(size); + _freeMemory = true; + + for (size_t i = 0; i < size; ++i) _ptr->data[i] = static_cast(ptr[i]); + } + } + + template + void construct_vector_from_vector( + const typename GSLTemplateTypeAlias::VectorType* v) { + if (!v) { + _ptr = nullptr; + return; + } + + if constexpr (std::is_same::value) { + _ptr = const_cast(v); + } else { + _ptr = GSLTemplateTypeAlias::allocate_vector(v->size); + _freeMemory = true; + + for (size_t i = 0; i < v->size; ++i) + _ptr->data[i] = static_cast(v->data[i]); + } + } + + template + void construct_matrix_from_vector( + const typename GSLTemplateTypeAlias::VectorType* v, size_t rows, + size_t cols) { + if (!v) { + _ptr = nullptr; + return; + } + + if (rows == 0 || cols == 0) + throw std::runtime_error( + "Matrix construction from vector requires rows and cols"); + + if (v->size != rows * cols) + throw std::runtime_error("Size mismatch in reshape"); + + if constexpr (std::is_same::value) { + _view = + GSLTemplateTypeAlias::matrix_view_from_array(v->data, rows, cols); + + _ptr = &_view.matrix; + } else { + _ptr = GSLTemplateTypeAlias::allocate_matrix(rows, cols); + _freeMemory = true; + + for (size_t i = 0; i < v->size; ++i) + _ptr->data[i] = static_cast(v->data[i]); + } + } + + template + void construct_matrix_from_ptr(const U* ptr, size_t rows, size_t cols) { + if constexpr (std::is_same::value) { + _view = GSLTemplateTypeAlias::matrix_view_from_array(ptr, rows, cols); + _ptr = &_view.matrix; + } else { + _ptr = GSLTemplateTypeAlias::allocate_matrix(rows, cols); + _freeMemory = true; + + size_t total = rows * cols; + for (size_t i = 0; i < total; ++i) _ptr->data[i] = static_cast(ptr[i]); + } + } + + template + void construct_matrix_from_matrix( + const typename GSLTemplateTypeAlias::MatrixType* m) { + if (!m) { + _ptr = nullptr; + return; + } + + if constexpr (std::is_same::value) { + _ptr = const_cast(m); + } else { + _ptr = GSLTemplateTypeAlias::allocate_matrix(m->size1, m->size2); + _freeMemory = true; + + size_t total = m->size1 * m->size2; + for (size_t i = 0; i < total; ++i) + _ptr->data[i] = static_cast(m->data[i]); + } + } + + template + void construct_vector_from_matrix( + const typename GSLTemplateTypeAlias::MatrixType* m) { + if (!m) { + _ptr = nullptr; + return; + } + + if constexpr (std::is_same::value) { + _view = GSLTemplateTypeAlias::flatten_matrix_to_vector(m); + _ptr = &_view.vector; + } else { + const size_t total = m->size1 * m->size2; + _ptr = GSLTemplateTypeAlias::allocate_vector(total); + _freeMemory = true; + + for (size_t i = 0; i < total; ++i) + _ptr->data[i] = static_cast(m->data[i]); + } + } + + template + static constexpr bool is_float_or_double() { + return std::is_same::value || std::is_same::value; + } + + private: + GSLType* _ptr = nullptr; + bool _freeMemory = false; + + ViewType _view{}; +}; + +template +using GSLVectorView = GSLViewHelper; + +template +using GSLMatrixView = GSLViewHelper; + +#endif // GSL_UTILS_VIEW_HELPER_H \ No newline at end of file diff --git a/lib/gsl_utils/gsl_utils_weight_helper.h b/lib/gsl_utils/gsl_utils_weight_helper.h new file mode 100644 index 0000000..43342ae --- /dev/null +++ b/lib/gsl_utils/gsl_utils_weight_helper.h @@ -0,0 +1,53 @@ +#ifndef GSL_UTILS_WEIGHT_HELPER_H +#define GSL_UTILS_WEIGHT_HELPER_H + +#include +#include + +#include "gsl_vector_matrix_types.h" + +template +class GSLWeightHelper { + static_assert(std::is_same::value || + std::is_same::value, + "Only float and double supported"); + + public: + using GSLVectorType = + typename GSLTemplateTypeAlias::VectorType; + + GSLWeightHelper(const GSLVectorType* v, size_t size) { + if (size == 0) + throw std::runtime_error("Weight vector size must be > 0"); + + if (!v) { + _ownedPtr = GSLTemplateTypeAlias::allocate_vector(size); + _freeMemory = true; + + const T weight = static_cast(1.0) / static_cast(size); + for (size_t i = 0; i < size; ++i) + _ownedPtr->data[i] = weight; + + _ptr = _ownedPtr; + } else { + _ptr = v; + } + } + + ~GSLWeightHelper() { + if (_freeMemory && _ownedPtr) { + GSLTemplateTypeAlias::free_vector(_ownedPtr); + } + } + + const GSLVectorType* get() const { return _ptr; } + + operator const GSLVectorType*() const { return _ptr; } + + private: + const GSLVectorType* _ptr = nullptr; + GSLVectorType* _ownedPtr = nullptr; + bool _freeMemory = false; +}; + +#endif // GSL_UTILS_WEIGHT_HELPER_H