From 47bc414fabd6e9f8346d45dcfec14eb0be437310 Mon Sep 17 00:00:00 2001 From: David Fang Date: Thu, 22 Jan 2026 01:29:21 +0100 Subject: [PATCH 01/33] add Euler method --- src/shammath/include/shammath/integrator.hpp | 41 ++++++++++++++++++++ src/shampylib/src/math/pyshammath.cpp | 24 ++++++++++++ 2 files changed, 65 insertions(+) diff --git a/src/shammath/include/shammath/integrator.hpp b/src/shammath/include/shammath/integrator.hpp index e6fea94cc7..5fccc0bd49 100644 --- a/src/shammath/include/shammath/integrator.hpp +++ b/src/shammath/include/shammath/integrator.hpp @@ -11,6 +11,7 @@ /** * @file integrator.hpp + * @author David Fang (david.fang@ikmail.com) * @author Timothée David--Cléris (tim.shamrock@proton.me) * @brief * @@ -30,4 +31,44 @@ namespace shammath { return acc; } + /** + * @brief Euler solving of ODE + * The ode has the form + * \f{eqnarray*}{ + * u'(x) &=& f(u,x) \\ + * u(x_0) &=& u_0 + * \f} + * and will be solved between start and end with step $\mathrm{d}t$. + * + * @param start Lower bound of integration + * @param end Higher bound of integration + * @param step Step of integration $\mathrm{d}t$ + * @param ode Ode function $f$ + * @param x0 Initial coordinate $x_0$ + * @param u0 Initial value $u_0$ + */ + template + inline constexpr std::pair, std::vector> euler_ode( + T start, T end, T step, Lambda &&ode, T x0, T u0) { + std::vector U = {u0}; + std::vector X = {x0}; + + T u_prev = u0; + T u = u0; + for (T x = x0 + step; x < end; x += step) { + u = u_prev + ode(u_prev, x) * step; + X.push_back(x); + U.push_back(u); + u_prev = u; + }; + u_prev = u0; + for (T x = x0 - step; x > start; x -= step) { + u = u_prev - ode(u_prev, x) * step; + X.insert(X.begin(), x); + U.insert(U.begin(), u); + u_prev = u; + } + return {X, U}; + } + } // namespace shammath diff --git a/src/shampylib/src/math/pyshammath.cpp b/src/shampylib/src/math/pyshammath.cpp index 2b1c513f56..7c20728313 100644 --- a/src/shampylib/src/math/pyshammath.cpp +++ b/src/shampylib/src/math/pyshammath.cpp @@ -9,6 +9,7 @@ /** * @file pyshammath.cpp + * @author David Fang (david.fang@ikmail.com) * @author Timothée David--Cléris (tim.shamrock@proton.me) * @author Yann Bernard (yann.bernard@univ-grenoble-alpes.fr) * @brief @@ -19,6 +20,7 @@ #include "shambindings/pybindaliases.hpp" #include "shambindings/pytypealias.hpp" #include "shammath/derivatives.hpp" +#include "shammath/integrator.hpp" #include "shammath/matrix.hpp" #include "shammath/matrix_op.hpp" #include "shammath/paving_function.hpp" @@ -849,4 +851,26 @@ Register_pymod(pysham_mathinit) { "SymTensorCollection_f64_1_1(\n t1={}\n)", py::str(py::cast(c.t1)).cast()); }); + + math_module.def( + "euler_ode", + [](f64 start, f64 end, f64 step, std::function &&ode, f64 x0, f64 u0) { + return shammath::euler_ode(start, end, step, ode, x0, u0); + }, + py::kw_only(), + py::arg("start"), + py::arg("end"), + py::arg("step"), + py::arg("ode"), + py::arg("x0"), + py::arg("u0"), + R"pbdoc( + Solve ODE with Euler method + start : Lower bound of integration + end : Higher bound of integration + step : Step of integration + ode : Ode function + x0 : Initial coordinate + u0 : Initial value + )pbdoc"); } From db367cf902166465ef2642df40309bb7c279657c Mon Sep 17 00:00:00 2001 From: David Fang Date: Sat, 24 Jan 2026 16:07:04 +0100 Subject: [PATCH 02/33] add Cholesky decomposition --- src/shammath/include/shammath/matrix_op.hpp | 44 +++++++++++++++++++++ src/tests/shammath/matrixTests.cpp | 22 +++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index dde113f365..37843c3a42 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -11,6 +11,7 @@ /** * @file matrix_op.hpp + * @author David Fang (david.fang@ikmail.com) * @author Léodasce Sewanou (leodasce.sewanou@ens-lyon.fr) * @author Timothée David--Cléris (tim.shamrock@proton.me) * @author Yann Bernard (yann.bernard@univ-grenoble-alpes.fr) @@ -653,4 +654,47 @@ namespace shammath { } } + /** + * @brief This function performs Cholesky decomposition. From a (real) symmetric, + definite-positive square matrix $M$, return a lower triangular matrix $L$ such that + \f[ + M = L L^T + \f] + * @param M a square symmetric, definite-positive matrix + * @param L the output matrix to store the lower triangular matrix obtained by Cholesky + decomposition + */ + template< + class T, + class Extents1, + class Extents2, + class Layout1, + class Layout2, + class Accessor1, + class Accessor2> + inline void Cholesky_decomp( + const std::mdspan &M, + const std::mdspan &L) { + + SHAM_ASSERT(M.extent(1) == M.extent(0)); + SHAM_ASSERT(M.extent(0) == L.extent(0)); + SHAM_ASSERT(L.extent(1) == L.extent(0)); + + for (int i = 0; i <= M.extent(0); i++) { + T sum_ik = 0.0; + for (int k = 0; k < i; k++) { + sum_ik += L(i, k) * L(i, k); + } + L(i, i) = sycl::sqrt(M(i, i) - sum_ik); + for (int j = i + 1; j < M.extent(1); j++) { + T sum_ikjk = 0.0; + for (int k = 0; k < i; k++) { + sum_ikjk += L(i, k) * L(j, k); + } + L(j, i) = (M(i, j) - sum_ikjk) / L(i, i); + L(i, j) = 0.0; + } + } + } + } // namespace shammath diff --git a/src/tests/shammath/matrixTests.cpp b/src/tests/shammath/matrixTests.cpp index 07cff95c83..73d1c698f5 100644 --- a/src/tests/shammath/matrixTests.cpp +++ b/src/tests/shammath/matrixTests.cpp @@ -490,3 +490,25 @@ TestStart(Unittest, "shammath/matrix::mat_gemv", test_mat_gemv, 1) { shammath::mat_gemv(a, B.get_mdspan(), x.get_mdspan(), b, y.get_mdspan()); REQUIRE_EQUAL(y.data, ex_res.data); } + +TestStart(Unittest, "shammath/matrix::Cholesky_decomp", Cholesky_decomp, 1) { + shammath::mat M{ + // clang-format off + 1, 1, 1, 1, + 1, 5, 5, 5, + 1, 5, 14, 14, + 1, 5, 14, 15, + // clang-format on + }; + shammath::mat L; + shammath::mat ex_res{ + // clang-format off + 1,0,0,0, + 1,2,0,0, + 1,2,3,0, + 1,2,3,1 + // clang-format on + }; + shammath::Cholesky_decomp(M.get_mdspan(), L.get_mdspan()); + REQUIRE_EQUAL(L.data, ex_res.data); +} From c9dd6fdca6a428c3d148de0f0791e93dbf80112e Mon Sep 17 00:00:00 2001 From: David Fang Date: Sat, 24 Jan 2026 18:33:59 +0100 Subject: [PATCH 03/33] add Cholesky solving --- src/shammath/include/shammath/matrix_op.hpp | 54 +++++++++++++++++++++ src/tests/shammath/matrixTests.cpp | 30 ++++++++++++ 2 files changed, 84 insertions(+) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 37843c3a42..9d6778f3de 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -697,4 +697,58 @@ namespace shammath { } } + /** +* @brief This function solves a system of linear equations with Cholesky decomposition. The +system must have the form +\f[ + Mx = y +\f] +where $M$ is a (real) symmetric, definite-positive square matrix. +* @param M a square symmetric, definite-positive matrix +* @param y a vector, right hand side of the system +* @param x the ouput vector to store the solution of the system +*/ + template< + class T, + class Extents1, + class Extents2, + class Extents3, + class Layout1, + class Layout2, + class Layout3, + class Accessor1, + class Accessor2, + class Accessor3> + inline void Cholesky_solve( + const std::mdspan &M, + const std::mdspan &y, + const std::mdspan &x) { + + SHAM_ASSERT(M.extent(1) == M.extent(0)); + SHAM_ASSERT(M.extent(1) == x.extent(0)); + SHAM_ASSERT(M.extent(0) == y.extent(0)); + + std::vector a_storage(M.extent(0)); + std::vector L_storage(M.extent(0) * M.extent(1)); + + std::mdspan L{L_storage.data()}; + std::mdspan a{a_storage.data()}; + Cholesky_decomp(M, L); + + for (int i = 0; i < M.extent(0); i++) { + T sum = 0.0; + for (int k = 0; k < i; k++) { + sum += L(i, k) * a(k); + } + a(i) = (y(i) - sum) / L(i, i); + } + for (int i = M.extent(0) - 1; i >= 0; i--) { + T sum = 0.0; + for (int k = i + 1; k < M.extent(0); k++) { + sum += L(k, i) * x(k); + } + x(i) = (a(i) - sum) / L(i, i); + } + } + } // namespace shammath diff --git a/src/tests/shammath/matrixTests.cpp b/src/tests/shammath/matrixTests.cpp index 73d1c698f5..3ea58240f5 100644 --- a/src/tests/shammath/matrixTests.cpp +++ b/src/tests/shammath/matrixTests.cpp @@ -8,6 +8,7 @@ // -------------------------------------------------------// #include "shambase/aliases_float.hpp" +#include "shamcomm/logs.hpp" #include "shammath/matrix.hpp" #include "shammath/matrix_op.hpp" #include "shamtest/details/TestResult.hpp" @@ -512,3 +513,32 @@ TestStart(Unittest, "shammath/matrix::Cholesky_decomp", Cholesky_decomp, 1) { shammath::Cholesky_decomp(M.get_mdspan(), L.get_mdspan()); REQUIRE_EQUAL(L.data, ex_res.data); } + +TestStart(Unittest, "shammath/matrix::Cholesky_solve", Cholesky_solve, 1) { + shammath::mat M{ + // clang-format off + 6,15,55, + 15,55,225, + 55,225,979 + // clang-format on + }; + + shammath::vec y{ + // clang-format off + 76,295,1259 + // clang-format on + }; + + shammath::vec x; + shammath::vec ex_res{ + // clang-format off + 1,1,1 + // clang-format on + }; + shammath::Cholesky_solve(M.get_mdspan(), y.get_mdspan(), x.get_mdspan()); + REQUIRE_EQUAL_CUSTOM_COMP_NAMED("", x.data, ex_res.data, [](const auto &p1, const auto &p2) { + return sycl::pow(p1[0] - p2[0], 2) + sycl::pow(p1[0] - p2[0], 2) + + sycl::pow(p1[0] - p2[0], 2) + < 1e-9; + }); +} From 04d385c4e45f6fb381d0e8f27dc8f4eda938d5d2 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sat, 24 Jan 2026 20:59:34 +0100 Subject: [PATCH 04/33] oops --- src/shammath/include/shammath/matrix_op.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 9d6778f3de..6a21e68058 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -680,7 +680,7 @@ namespace shammath { SHAM_ASSERT(M.extent(0) == L.extent(0)); SHAM_ASSERT(L.extent(1) == L.extent(0)); - for (int i = 0; i <= M.extent(0); i++) { + for (int i = 0; i < M.extent(0); i++) { T sum_ik = 0.0; for (int k = 0; k < i; k++) { sum_ik += L(i, k) * L(i, k); From c13398a308a8d6571d9ac1a3ab74ac9d612f46d1 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sat, 24 Jan 2026 21:01:20 +0100 Subject: [PATCH 05/33] typo --- src/tests/shammath/matrixTests.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tests/shammath/matrixTests.cpp b/src/tests/shammath/matrixTests.cpp index 3ea58240f5..549d6190e1 100644 --- a/src/tests/shammath/matrixTests.cpp +++ b/src/tests/shammath/matrixTests.cpp @@ -492,7 +492,7 @@ TestStart(Unittest, "shammath/matrix::mat_gemv", test_mat_gemv, 1) { REQUIRE_EQUAL(y.data, ex_res.data); } -TestStart(Unittest, "shammath/matrix::Cholesky_decomp", Cholesky_decomp, 1) { +TestStart(Unittest, "shammath/matrix::Cholesky_decomp", test_Cholesky_decomp, 1) { shammath::mat M{ // clang-format off 1, 1, 1, 1, @@ -514,7 +514,7 @@ TestStart(Unittest, "shammath/matrix::Cholesky_decomp", Cholesky_decomp, 1) { REQUIRE_EQUAL(L.data, ex_res.data); } -TestStart(Unittest, "shammath/matrix::Cholesky_solve", Cholesky_solve, 1) { +TestStart(Unittest, "shammath/matrix::Cholesky_solve", test_Cholesky_solve, 1) { shammath::mat M{ // clang-format off 6,15,55, From 979b3e54482be2c7ce33f9766b71fd9542e22ab7 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sat, 24 Jan 2026 21:28:45 +0100 Subject: [PATCH 06/33] add matrix tranpose --- src/shammath/include/shammath/matrix_op.hpp | 28 +++++++++++++++++++++ src/tests/shammath/matrixTests.cpp | 20 +++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 6a21e68058..59f1bcf07c 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -654,6 +654,34 @@ namespace shammath { } } + /** + * @brief This function transposes a (square) matrix + * @param intput matrix to tranpose + * @param output matrix to store the transposed matrix + */ + template< + class T, + class Extents1, + class Extents2, + class Layout1, + class Layout2, + class Accessor1, + class Accessor2> + inline void mat_transpose( + const std::mdspan &input, + const std::mdspan &output) { + + SHAM_ASSERT(input.extent(1) == input.extent(0)); + SHAM_ASSERT(output.extent(1) == output.extent(0)); + SHAM_ASSERT(input.extent(0) == output.extent(0)); + + for (int i = 0; i < input.extent(0); i++) { + for (int j = 0; j < input.extent(1); j++) { + output(i, j) = input(j, i); + } + } + } + /** * @brief This function performs Cholesky decomposition. From a (real) symmetric, definite-positive square matrix $M$, return a lower triangular matrix $L$ such that diff --git a/src/tests/shammath/matrixTests.cpp b/src/tests/shammath/matrixTests.cpp index 549d6190e1..7f5a7427ad 100644 --- a/src/tests/shammath/matrixTests.cpp +++ b/src/tests/shammath/matrixTests.cpp @@ -492,6 +492,26 @@ TestStart(Unittest, "shammath/matrix::mat_gemv", test_mat_gemv, 1) { REQUIRE_EQUAL(y.data, ex_res.data); } +TestStart(Unittest, "shammath/matrix::mat_transpose", test_transpose, 1) { + shammath::mat A{ + // clang-format off + 1,2,3, + 4,5,6, + 7,8,9 + // clang-format on + }; + shammath::mat B; + shammath::mat ex_res{ + // clang-format off + 1,4,7, + 2,5,8, + 3,6,9 + // clang-format on + }; + shammath::mat_transpose(A.get_mdspan(), B.get_mdspan()); + REQUIRE_EQUAL(B.data, ex_res.data); +} + TestStart(Unittest, "shammath/matrix::Cholesky_decomp", test_Cholesky_decomp, 1) { shammath::mat M{ // clang-format off From e7468f3b9622a36b700f4480c231ffd5b9240a6b Mon Sep 17 00:00:00 2001 From: David Fang Date: Sat, 24 Jan 2026 22:13:13 +0100 Subject: [PATCH 07/33] add vec_update_vals --- src/shammath/include/shammath/matrix_op.hpp | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 59f1bcf07c..3d7c1a9bd4 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -79,6 +79,29 @@ namespace shammath { } } + /** + * @brief Update the elements of a vector according to a user-provided function + * + * @param input The vector to update the elements of + * @param func The function to use to update the elements of the matrix. The + * function must take two arguments, the first being the value of the + * element to update and the second being the index. + * + * @details The function `func` is called for each element of the vector, and + * the value returned by the function is used to update the corresponding + * element of the matrix. + */ + template + inline void vec_update_vals( + const std::mdspan &input, Func &&func) { + + shambase::check_functor_signature(func); + + for (int i = 0; i < input.extent(0); i++) { + func(input(i), i); + } + } + /** * @brief Set the content of a matrix to the identity matrix * From 594076fa32abd0a373d4313e545b24359d3da61f Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 14:18:06 +0100 Subject: [PATCH 08/33] whoops --- src/shammath/include/shammath/matrix_op.hpp | 9 ++++----- src/tests/shammath/matrixTests.cpp | 13 ++++++------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 3d7c1a9bd4..fa11331074 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -678,7 +678,7 @@ namespace shammath { } /** - * @brief This function transposes a (square) matrix + * @brief This function transposes a matrix * @param intput matrix to tranpose * @param output matrix to store the transposed matrix */ @@ -694,12 +694,11 @@ namespace shammath { const std::mdspan &input, const std::mdspan &output) { + SHAM_ASSERT(input.extent(0) == output.extent(1)); SHAM_ASSERT(input.extent(1) == input.extent(0)); - SHAM_ASSERT(output.extent(1) == output.extent(0)); - SHAM_ASSERT(input.extent(0) == output.extent(0)); - for (int i = 0; i < input.extent(0); i++) { - for (int j = 0; j < input.extent(1); j++) { + for (int i = 0; i < output.extent(0); i++) { + for (int j = 0; j < output.extent(1); j++) { output(i, j) = input(j, i); } } diff --git a/src/tests/shammath/matrixTests.cpp b/src/tests/shammath/matrixTests.cpp index 7f5a7427ad..6e64137485 100644 --- a/src/tests/shammath/matrixTests.cpp +++ b/src/tests/shammath/matrixTests.cpp @@ -493,19 +493,18 @@ TestStart(Unittest, "shammath/matrix::mat_gemv", test_mat_gemv, 1) { } TestStart(Unittest, "shammath/matrix::mat_transpose", test_transpose, 1) { - shammath::mat A{ + shammath::mat A{ // clang-format off 1,2,3, 4,5,6, - 7,8,9 // clang-format on }; - shammath::mat B; - shammath::mat ex_res{ + shammath::mat B; + shammath::mat ex_res{ // clang-format off - 1,4,7, - 2,5,8, - 3,6,9 + 1, 4, + 2, 5, + 3, 6, // clang-format on }; shammath::mat_transpose(A.get_mdspan(), B.get_mdspan()); From 9487349dc101c301b73e504165b164b016302674 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 14:19:05 +0100 Subject: [PATCH 09/33] reshape --- src/shammath/include/shammath/matrix_op.hpp | 53 ++++++++++----------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index fa11331074..49d57395fb 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -54,51 +54,49 @@ namespace shammath { } /** - * @brief Update the elements of a matrix according to a user-provided function + * @brief Set the elements of a vector according to a user-provided function * - * @param input The matrix to update the elements of + * @param input The vector to update the elements of * @param func The function to use to update the elements of the matrix. The - * function must take three arguments, the first being the value of the - * element to update, the second being the row index and the third being the - * column index. + * function must take one argument being the index. * - * @details The function `func` is called for each element of the matrix, and - * the value returned by the function is used to update the corresponding - * element of the matrix. + * @details The function `func` is called for each element of the vector, and + * the value returned by the function is used to set the corresponding + * element of the vector. */ template - inline void mat_update_vals( - const std::mdspan &input, Func &&func) { + inline void vec_set_vals(const std::mdspan &input, Func &&func) { - shambase::check_functor_signature(func); + shambase::check_functor_signature(func); for (int i = 0; i < input.extent(0); i++) { - for (int j = 0; j < input.extent(1); j++) { - func(input(i, j), i, j); - } + input(i) = func(i); } } /** - * @brief Update the elements of a vector according to a user-provided function + * @brief Update the elements of a matrix according to a user-provided function * - * @param input The vector to update the elements of + * @param input The matrix to update the elements of * @param func The function to use to update the elements of the matrix. The - * function must take two arguments, the first being the value of the - * element to update and the second being the index. + * function must take three arguments, the first being the value of the + * element to update, the second being the row index and the third being the + * column index. * - * @details The function `func` is called for each element of the vector, and + * @details The function `func` is called for each element of the matrix, and * the value returned by the function is used to update the corresponding * element of the matrix. */ template - inline void vec_update_vals( + inline void mat_update_vals( const std::mdspan &input, Func &&func) { - shambase::check_functor_signature(func); + shambase::check_functor_signature(func); for (int i = 0; i < input.extent(0); i++) { - func(input(i), i); + for (int j = 0; j < input.extent(1); j++) { + func(input(i, j), i, j); + } } } @@ -778,26 +776,25 @@ where $M$ is a (real) symmetric, definite-positive square matrix. SHAM_ASSERT(M.extent(1) == x.extent(0)); SHAM_ASSERT(M.extent(0) == y.extent(0)); - std::vector a_storage(M.extent(0)); + std::vector a(M.extent(0)); std::vector L_storage(M.extent(0) * M.extent(1)); - std::mdspan L{L_storage.data()}; - std::mdspan a{a_storage.data()}; + std::mdspan L{L_storage.data(), M.extent(0), M.extent(1)}; Cholesky_decomp(M, L); for (int i = 0; i < M.extent(0); i++) { T sum = 0.0; for (int k = 0; k < i; k++) { - sum += L(i, k) * a(k); + sum += L(i, k) * a[k]; } - a(i) = (y(i) - sum) / L(i, i); + a[i] = (y(i) - sum) / L(i, i); } for (int i = M.extent(0) - 1; i >= 0; i--) { T sum = 0.0; for (int k = i + 1; k < M.extent(0); k++) { sum += L(k, i) * x(k); } - x(i) = (a(i) - sum) / L(i, i); + x(i) = (a[i] - sum) / L(i, i); } } From 39d00941801cb5cbb931050fd16c5a6895217dfa Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 14:19:42 +0100 Subject: [PATCH 10/33] add dynamic matrices when dimensions are known at runtime --- src/shammath/include/shammath/matrix.hpp | 96 ++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index 33c1f7f39e..22d93fd54a 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -120,6 +120,102 @@ namespace shammath { bool operator==(const vec &other) { return data == other.data; } }; + /** + * @brief Matrix class based on std::vector storage and mdspan + * @tparam T the type of the matrix entries + */ + template + class mat_d { + public: + /// The matrix data + std::vector data; + int rows; + int columns; + + mat_d(int rows, int columns) : rows(rows), columns(columns), data(rows * columns) {} + + /// Get the matrix data as a mdspan + inline constexpr auto get_mdspan() { + return std::mdspan>(data.data(), rows, columns); + } + + /// const overload + inline constexpr auto get_mdspan() const { + return std::mdspan>(data.data(), rows, columns); + } + + inline constexpr auto get_rows_nb() { return get_mdspan().extent(0); } + + inline constexpr auto get_columns_nb() { return get_mdspan().extent(1); } + + /// Access the matrix entry at position (i, j) + inline constexpr T &operator()(int i, int j) { return get_mdspan()(i, j); } + + /// const overload + inline constexpr const T &operator()(int i, int j) const { return get_mdspan()(i, j); } + + /// Check if this matrix is equal to another one + bool operator==(const mat_d &other) const { return data == other.data; } + + inline mat_d &operator+=(const mat_d &other) { +#pragma unroll + for (size_t i = 0; i < get_mdspan().extent(0) * get_mdspan().extent(1); i++) { + data[i] += other.data[i]; + } + return *this; + } + + /// check if this matrix is equal to another one at a given precison + bool equal_at_precision(const mat_d &other, const T precision) const { + bool res = true; + for (auto i = 0; i < get_rows_nb(); i++) { + for (auto j = 0; j < get_columns_nb(); j++) { + if (sham::abs( + data[i * get_rows_nb() + j] - other.data[i * get_columns_nb() + j]) + >= precision) { + res = false; + } + } + } + return res; + } + }; + + /** + * @brief Vector class based on std::array storage and mdspan + * @tparam T the type of the vector entries + * @tparam n the number of entries + */ + template + class vec_d { + public: + /// The vector data + std::vector data; + int size; + + vec_d(int size) : size(size), data(size) {} + + /// Get the vector data as a mdspan + inline constexpr auto get_mdspan() { + return std::mdspan>(data.data(), size); + } + + /// Get the vector data as a mdspan of a matrix with one column + inline constexpr auto get_mdspan_mat_col() { + return std::mdspan>(data.data(), size, 1); + } + + /// Get the vector data as a mdspan of a matrix with one row + inline constexpr auto get_mdspan_mat_row() { + return std::mdspan>(data.data(), 1, size); + } + + /// Access the vector entry at position i + inline constexpr T &operator[](int i) { return get_mdspan()(i); } + + /// Check if this vector is equal to another one + bool operator==(const vec_d &other) { return data == other.data; } + }; } // namespace shammath template From a10b2d7f4f4ab85fceeff3c3ba60b60b9d10a640 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 14:20:22 +0100 Subject: [PATCH 11/33] add least squares --- src/shammath/include/shammath/solve.hpp | 85 +++++++++++++++++++++++++ src/shampylib/src/math/pyshammath.cpp | 22 +++++++ 2 files changed, 107 insertions(+) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 6dc47b8025..e9facae4ef 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -16,6 +16,8 @@ * */ +#include "shammath/matrix.hpp" +#include "shammath/matrix_op.hpp" #include #include @@ -42,4 +44,87 @@ namespace shammath { return xk; } + /** + * @brief This function determines the best fit parameters $\vec p$ for a given function(f(\vec + * (beta), \mathbf(X)) with least squares. + * + * @param f Function (1d values) + * @param X $x$ Data to fit + * @param Y $y$ Data to fit + * @param p0 Initial parameters guessed + * + * @details The Levenberg-Marquardt method is used. Therefore, the number of observations needs + * to be greater than the number of paremeters. + */ + template + std::vector least_squares( + const Lambda &f, + const std::vector &X, + const std::vector &Y, + const std::vector &p0) { + SHAM_ASSERT(X.size() == Y.size()); + + const int params_nb = p0.size(); + const int data_size = X.size(); + + std::vector p = p0; + T mu = 1e-2; // damping parameter + T beta = 0.1; // decay rate + int maxits = 1000; + int it = 0; + T sse = 0.0; + for (int k = 0; k < X.size(); k++) { + T r = Y[k] - f(p, X[k]); + sse += r * r; + }; + while (it < maxits) { + + // Construct the Jaobian (finite differences) + shammath::mat_d J(data_size, params_nb); + mat_set_vals(J.get_mdspan(), [&](auto i, auto j) -> T { + auto p_plus_dpj = p; + T step_scale = (std::abs(p_plus_dpj[j]) < 1e-6) ? 1e-6 : p_plus_dpj[j]; + T dpj = step_scale * 0.001; + p_plus_dpj[j] += dpj; + return (f(p_plus_dpj, X[i]) - f(p, X[i])) / dpj; + }); + + shammath::vec_d R(data_size); + shammath::vec_set_vals(R.get_mdspan(), [&](auto i) -> T { + return Y[i] - f(p, X[i]); + }); + + shammath::mat_d J_T(params_nb, data_size); // Jacobian tranpose + shammath::mat_transpose(J.get_mdspan(), J_T.get_mdspan()); + + shammath::mat_d G(params_nb, params_nb); // left hand side + shammath::mat_prod(J_T.get_mdspan(), J.get_mdspan(), G.get_mdspan()); + shammath::mat_plus_equal_scalar_id(G.get_mdspan(), mu); + + shammath::vec_d d(params_nb); // right hand side + shammath::mat_gemv(1.0, J_T.get_mdspan(), R.get_mdspan(), 0.0, d.get_mdspan()); + + shammath::vec_d delta(params_nb); // increment for p + shammath::Cholesky_solve(G.get_mdspan(), d.get_mdspan(), delta.get_mdspan()); + + std::vector p_trial = p; + for (int i = 0; i < params_nb; i++) { + p_trial[i] += delta.data[i]; + }; + T sse_trial = 0.0; + for (int k = 0; k < X.size(); k++) { + sse_trial += (Y[k] - f(p_trial, X[k])) * (Y[k] - f(p_trial, X[k])); + }; + if (sse_trial > sse) { // Fail -> gradient descent + mu /= beta; + } else { // Not bad -> Gauss-Newton + // it++; + mu *= beta; + p = p_trial; + } + it++; + }; + + return p; + } } // namespace shammath diff --git a/src/shampylib/src/math/pyshammath.cpp b/src/shampylib/src/math/pyshammath.cpp index 7c20728313..901eeca769 100644 --- a/src/shampylib/src/math/pyshammath.cpp +++ b/src/shampylib/src/math/pyshammath.cpp @@ -24,6 +24,7 @@ #include "shammath/matrix.hpp" #include "shammath/matrix_op.hpp" #include "shammath/paving_function.hpp" +#include "shammath/solve.hpp" #include "shammath/symtensor_collections.hpp" #include "shammath/symtensors.hpp" #include "shampylib/math/pyAABB.hpp" @@ -873,4 +874,25 @@ Register_pymod(pysham_mathinit) { x0 : Initial coordinate u0 : Initial value )pbdoc"); + + math_module.def( + "least_squares", + [](const std::function, f64)> &func, + const std::vector &x_data, + const std::vector &y_data, + const std::vector &p0) { + return shammath::least_squares(func, x_data, y_data, p0); + }, + py::kw_only(), + py::arg("func"), + py::arg("x_data"), + py::arg("y_data"), + py::arg("p0"), + R"pbdoc( + Fit data with a given function by least squares method + f: Function (1d values) + X: $x$ Data to fit + Y: $y$ Data to fit + p0: Initial parameters estimated + )pbdoc"); } From b6a165a1004f7e7a2b76f7b142a9cce7b585e80e Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 14:32:47 +0100 Subject: [PATCH 12/33] authorship --- src/shammath/include/shammath/matrix.hpp | 1 + src/shammath/include/shammath/solve.hpp | 1 + 2 files changed, 2 insertions(+) diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index 22d93fd54a..3f9bd5b728 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -11,6 +11,7 @@ /** * @file matrix.hpp + * @author David Fang (david.fang@ikmail.com) * @author Léodasce Sewanou (leodasce.sewanou@ens-lyon.fr) * @author Timothée David--Cléris (tim.shamrock@proton.me) * @author Yann Bernard (yann.bernard@univ-grenoble-alpes.fr) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index e9facae4ef..03035a65c5 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -11,6 +11,7 @@ /** * @file solve.hpp + * @author David Fang (david.fang@ikmail.com) * @author Timothée David--Cléris (tim.shamrock@proton.me) * @brief * From 1c42e5458c8c7b0326538a065b71c92110564d56 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 15:46:02 +0100 Subject: [PATCH 13/33] fixes --- src/shammath/include/shammath/integrator.hpp | 7 +++++-- src/shammath/include/shammath/matrix.hpp | 17 ++++++----------- src/shammath/include/shammath/matrix_op.hpp | 2 +- src/shammath/include/shammath/solve.hpp | 20 ++++++++++++++------ src/shampylib/src/math/pyshammath.cpp | 2 +- src/tests/shammath/matrixTests.cpp | 4 ++-- 6 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/shammath/include/shammath/integrator.hpp b/src/shammath/include/shammath/integrator.hpp index 5fccc0bd49..4872a4696f 100644 --- a/src/shammath/include/shammath/integrator.hpp +++ b/src/shammath/include/shammath/integrator.hpp @@ -62,12 +62,15 @@ namespace shammath { u_prev = u; }; u_prev = u0; + std::vector X_backward, U_backward; for (T x = x0 - step; x > start; x -= step) { u = u_prev - ode(u_prev, x) * step; - X.insert(X.begin(), x); - U.insert(U.begin(), u); + X_backward.push_back(x); + U_backward.push_back(u); u_prev = u; } + X.insert(X.begin(), X_backward.rbegin(), X_backward.rend()); + U.insert(U.begin(), U_backward.rbegin(), U_backward.rend()); return {X, U}; } diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index 3f9bd5b728..2b92f3d28e 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -69,8 +69,8 @@ namespace shammath { /// check if this matrix is equal to another one at a given precison bool equal_at_precision(const mat &other, const T precision) const { bool res = true; - for (auto i = 0; i < m; i++) { - for (auto j = 0; j < n; j++) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { if (sham::abs(data[i * n + j] - other.data[i * n + j]) >= precision) { res = false; } @@ -145,10 +145,6 @@ namespace shammath { return std::mdspan>(data.data(), rows, columns); } - inline constexpr auto get_rows_nb() { return get_mdspan().extent(0); } - - inline constexpr auto get_columns_nb() { return get_mdspan().extent(1); } - /// Access the matrix entry at position (i, j) inline constexpr T &operator()(int i, int j) { return get_mdspan()(i, j); } @@ -169,10 +165,9 @@ namespace shammath { /// check if this matrix is equal to another one at a given precison bool equal_at_precision(const mat_d &other, const T precision) const { bool res = true; - for (auto i = 0; i < get_rows_nb(); i++) { - for (auto j = 0; j < get_columns_nb(); j++) { - if (sham::abs( - data[i * get_rows_nb() + j] - other.data[i * get_columns_nb() + j]) + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + if (sham::abs(data[i * columns + j] - other.data[i * columns + j]) >= precision) { res = false; } @@ -183,7 +178,7 @@ namespace shammath { }; /** - * @brief Vector class based on std::array storage and mdspan + * @brief Vector class based on std::vector storage and mdspan * @tparam T the type of the vector entries * @tparam n the number of entries */ diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 49d57395fb..761fb66a2a 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -693,7 +693,7 @@ namespace shammath { const std::mdspan &output) { SHAM_ASSERT(input.extent(0) == output.extent(1)); - SHAM_ASSERT(input.extent(1) == input.extent(0)); + SHAM_ASSERT(input.extent(1) == output.extent(0)); for (int i = 0; i < output.extent(0); i++) { for (int j = 0; j < output.extent(1); j++) { diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 03035a65c5..e5fefdffea 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -80,14 +80,22 @@ namespace shammath { }; while (it < maxits) { - // Construct the Jaobian (finite differences) + // Construct the Jacobian (finite differences) shammath::mat_d J(data_size, params_nb); + std::vector f_at_p(data_size); + for (int i = 0; i < data_size; i++) { + f_at_p[i] = f(p, X[i]); + } mat_set_vals(J.get_mdspan(), [&](auto i, auto j) -> T { - auto p_plus_dpj = p; - T step_scale = (std::abs(p_plus_dpj[j]) < 1e-6) ? 1e-6 : p_plus_dpj[j]; - T dpj = step_scale * 0.001; - p_plus_dpj[j] += dpj; - return (f(p_plus_dpj, X[i]) - f(p, X[i])) / dpj; + T original_p_j = p[j]; + T step_scale = (std::abs(original_p_j) < 1e-6) ? 1e-6 : original_p_j; + T dpj = step_scale * 0.001; + + p[j] += dpj; + T f_perturbed = f(p, X[i]); + p[j] = original_p_j; // Restore + + return (f_perturbed - f_at_p[i]) / dpj; }); shammath::vec_d R(data_size); diff --git a/src/shampylib/src/math/pyshammath.cpp b/src/shampylib/src/math/pyshammath.cpp index 901eeca769..cdf39fadd1 100644 --- a/src/shampylib/src/math/pyshammath.cpp +++ b/src/shampylib/src/math/pyshammath.cpp @@ -877,7 +877,7 @@ Register_pymod(pysham_mathinit) { math_module.def( "least_squares", - [](const std::function, f64)> &func, + [](const std::function, f64)> &func, const std::vector &x_data, const std::vector &y_data, const std::vector &p0) { diff --git a/src/tests/shammath/matrixTests.cpp b/src/tests/shammath/matrixTests.cpp index 6e64137485..d14a3722fc 100644 --- a/src/tests/shammath/matrixTests.cpp +++ b/src/tests/shammath/matrixTests.cpp @@ -556,8 +556,8 @@ TestStart(Unittest, "shammath/matrix::Cholesky_solve", test_Cholesky_solve, 1) { }; shammath::Cholesky_solve(M.get_mdspan(), y.get_mdspan(), x.get_mdspan()); REQUIRE_EQUAL_CUSTOM_COMP_NAMED("", x.data, ex_res.data, [](const auto &p1, const auto &p2) { - return sycl::pow(p1[0] - p2[0], 2) + sycl::pow(p1[0] - p2[0], 2) - + sycl::pow(p1[0] - p2[0], 2) + return sycl::pow(p1[0] - p2[0], 2) + sycl::pow(p1[1] - p2[1], 2) + + sycl::pow(p1[2] - p2[2], 2) < 1e-9; }); } From 9ce02cd6b1b1ddf967288dd33e814443fdc0f595 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 15:55:04 +0100 Subject: [PATCH 14/33] fix --- src/shampylib/src/math/pyshammath.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shampylib/src/math/pyshammath.cpp b/src/shampylib/src/math/pyshammath.cpp index cdf39fadd1..c63bf011f0 100644 --- a/src/shampylib/src/math/pyshammath.cpp +++ b/src/shampylib/src/math/pyshammath.cpp @@ -877,7 +877,7 @@ Register_pymod(pysham_mathinit) { math_module.def( "least_squares", - [](const std::function, f64)> &func, + [](const std::function &, f64)> &func, const std::vector &x_data, const std::vector &y_data, const std::vector &p0) { From 3e6e72fbd4d5520ca9f5893cfa4ae92894526bb9 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 15:55:27 +0100 Subject: [PATCH 15/33] add tolerance on sse for convergence criteria in least_squares --- src/shammath/include/shammath/solve.hpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index e5fefdffea..ac4c46042c 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -62,7 +62,9 @@ namespace shammath { const Lambda &f, const std::vector &X, const std::vector &Y, - const std::vector &p0) { + const std::vector &p0, + int maxits = 1000, + T tolerence = 1e-6) { SHAM_ASSERT(X.size() == Y.size()); const int params_nb = p0.size(); @@ -71,14 +73,14 @@ namespace shammath { std::vector p = p0; T mu = 1e-2; // damping parameter T beta = 0.1; // decay rate - int maxits = 1000; int it = 0; T sse = 0.0; for (int k = 0; k < X.size(); k++) { T r = Y[k] - f(p, X[k]); sse += r * r; }; - while (it < maxits) { + T sse_trial = 999.0; + while (it < maxits and sham::abs(sse_trial - sse) > tolerence) { // Construct the Jacobian (finite differences) shammath::mat_d J(data_size, params_nb); @@ -120,7 +122,8 @@ namespace shammath { for (int i = 0; i < params_nb; i++) { p_trial[i] += delta.data[i]; }; - T sse_trial = 0.0; + + sse_trial = 0.0; for (int k = 0; k < X.size(); k++) { sse_trial += (Y[k] - f(p_trial, X[k])) * (Y[k] - f(p_trial, X[k])); }; From 247e17de15f12c0b0da4e192e482cc6d54caa138 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 16:08:14 +0100 Subject: [PATCH 16/33] fix --- src/shammath/include/shammath/solve.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index ac4c46042c..4172f60eb4 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -130,7 +130,7 @@ namespace shammath { if (sse_trial > sse) { // Fail -> gradient descent mu /= beta; } else { // Not bad -> Gauss-Newton - // it++; + it++; mu *= beta; p = p_trial; } From 33629bd3280d2627040968e0eadc51ea21309487 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 16:08:38 +0100 Subject: [PATCH 17/33] least_squares now also returns R^2 --- src/shammath/include/shammath/solve.hpp | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 4172f60eb4..5b1aeb53db 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -58,7 +58,7 @@ namespace shammath { * to be greater than the number of paremeters. */ template - std::vector least_squares( + std::tuple, T> least_squares( const Lambda &f, const std::vector &X, const std::vector &Y, @@ -81,6 +81,11 @@ namespace shammath { }; T sse_trial = 999.0; while (it < maxits and sham::abs(sse_trial - sse) > tolerence) { + sse = 0.0; + for (int k = 0; k < X.size(); k++) { + T r = Y[k] - f(p, X[k]); + sse += r * r; + }; // Construct the Jacobian (finite differences) shammath::mat_d J(data_size, params_nb); @@ -134,9 +139,21 @@ namespace shammath { mu *= beta; p = p_trial; } - it++; }; - return p; + T total_sum_squares = 0.0; + T mean_Y = 0.0; + for (int k = 0; k < Y.size(); k++) { + mean_Y += Y[k]; + } + mean_Y /= Y.size(); + for (int k = 0; k < Y.size(); k++) { + total_sum_squares += (Y[k] - mean_Y) * (Y[k] - mean_Y); + } + T R2 = 1 - sse / total_sum_squares; + + shamlog_debug_ln( + "least_squares", "Least squares stopped after", it, "iterations with R^2=", R2); + return {p, R2}; } } // namespace shammath From 002ea619853249fca5538b6a7a1a4adf93209d5c Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 16:14:10 +0100 Subject: [PATCH 18/33] one last thing --- src/shammath/include/shammath/solve.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 5b1aeb53db..91efd9b69f 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -130,7 +130,8 @@ namespace shammath { sse_trial = 0.0; for (int k = 0; k < X.size(); k++) { - sse_trial += (Y[k] - f(p_trial, X[k])) * (Y[k] - f(p_trial, X[k])); + T residual = Y[k] - f(p_trial, X[k]); + sse_trial += residual * residual; }; if (sse_trial > sse) { // Fail -> gradient descent mu /= beta; From ee1090f7988b84e7427c8a8fb99a7e822bf04c58 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 25 Jan 2026 16:36:43 +0100 Subject: [PATCH 19/33] typos --- src/shammath/include/shammath/matrix_op.hpp | 2 +- src/shammath/include/shammath/solve.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 761fb66a2a..cab8dd13d8 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -677,7 +677,7 @@ namespace shammath { /** * @brief This function transposes a matrix - * @param intput matrix to tranpose + * @param input matrix to transpose * @param output matrix to store the transposed matrix */ template< diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 91efd9b69f..60d5917306 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -110,7 +110,7 @@ namespace shammath { return Y[i] - f(p, X[i]); }); - shammath::mat_d J_T(params_nb, data_size); // Jacobian tranpose + shammath::mat_d J_T(params_nb, data_size); // Jacobian transposed shammath::mat_transpose(J.get_mdspan(), J_T.get_mdspan()); shammath::mat_d G(params_nb, params_nb); // left hand side From 5e6c1cad85c77df3e2efe84639f4cc3316a7fec8 Mon Sep 17 00:00:00 2001 From: David Fang Date: Mon, 26 Jan 2026 16:16:13 +0100 Subject: [PATCH 20/33] the number of observations needs to be greater than the number of parameters --- src/shammath/include/shammath/solve.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 60d5917306..48444a222f 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -55,7 +55,7 @@ namespace shammath { * @param p0 Initial parameters guessed * * @details The Levenberg-Marquardt method is used. Therefore, the number of observations needs - * to be greater than the number of paremeters. + * to be greater than the number of parameters. */ template std::tuple, T> least_squares( @@ -66,6 +66,7 @@ namespace shammath { int maxits = 1000, T tolerence = 1e-6) { SHAM_ASSERT(X.size() == Y.size()); + SHAM_ASSERT(X.size() > p0.size()); const int params_nb = p0.size(); const int data_size = X.size(); From 2d7614b0378575d729de1738b39013dc4e690b6a Mon Sep 17 00:00:00 2001 From: David Fang Date: Mon, 26 Jan 2026 17:03:58 +0100 Subject: [PATCH 21/33] doxygen --- src/shammath/include/shammath/matrix.hpp | 6 ++++++ src/shammath/include/shammath/solve.hpp | 25 +++++++++++++++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index 2b92f3d28e..4aeaea0a57 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -130,9 +130,12 @@ namespace shammath { public: /// The matrix data std::vector data; + /// Number of rows int rows; + /// Number of columns int columns; + /// Constructor mat_d(int rows, int columns) : rows(rows), columns(columns), data(rows * columns) {} /// Get the matrix data as a mdspan @@ -154,6 +157,7 @@ namespace shammath { /// Check if this matrix is equal to another one bool operator==(const mat_d &other) const { return data == other.data; } + /// Addition operator inline mat_d &operator+=(const mat_d &other) { #pragma unroll for (size_t i = 0; i < get_mdspan().extent(0) * get_mdspan().extent(1); i++) { @@ -187,8 +191,10 @@ namespace shammath { public: /// The vector data std::vector data; + /// The vector size int size; + /// Constructor vec_d(int size) : size(size), data(size) {} /// Get the vector data as a mdspan diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 48444a222f..e0052e8682 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -46,16 +46,23 @@ namespace shammath { } /** - * @brief This function determines the best fit parameters $\vec p$ for a given function(f(\vec - * (beta), \mathbf(X)) with least squares. + * @brief This function determines the best fit parameters \f$ \vec p \f$ for a given + function \f$ f(\vec p, \mathbf X) \f$ with least squares. * - * @param f Function (1d values) - * @param X $x$ Data to fit - * @param Y $y$ Data to fit - * @param p0 Initial parameters guessed + * @param f Function (1d values) + * @param X Data to fit \f$ x \f$ + * @param Y Data to fit \f$ y \f$ + * @param p0 Initial parameters guessed + * @param maxits Maximum number of iterations in the Levenberg-Marquardt procedure. Default: + 1000 + * @param tolerance Convergence condition in the Levenberg-Marquardt procedure. Default: 1e-6 * * @details The Levenberg-Marquardt method is used. Therefore, the number of observations needs - * to be greater than the number of parameters. + * to be greater than the number of parameters. At every iteration, a new parameters array \f$ + \vec p' \f$ is estimated. The convergence condition is + \f[ + |S(\vec p') - S(\vec p))| < \epsilon + \f] where \f$ S \f$ is the residual sum of squares and \f$ \epsilon \f$ is the tolerance. */ template std::tuple, T> least_squares( @@ -64,7 +71,7 @@ namespace shammath { const std::vector &Y, const std::vector &p0, int maxits = 1000, - T tolerence = 1e-6) { + T tolerance = 1e-6) { SHAM_ASSERT(X.size() == Y.size()); SHAM_ASSERT(X.size() > p0.size()); @@ -81,7 +88,7 @@ namespace shammath { sse += r * r; }; T sse_trial = 999.0; - while (it < maxits and sham::abs(sse_trial - sse) > tolerence) { + while (it < maxits and sham::abs(sse_trial - sse) > tolerance) { sse = 0.0; for (int k = 0; k < X.size(); k++) { T r = Y[k] - f(p, X[k]); From 90d557037875b34e2cdf89ac70ede588f2d4d911 Mon Sep 17 00:00:00 2001 From: David Fang Date: Mon, 26 Jan 2026 17:09:34 +0100 Subject: [PATCH 22/33] remove magic numbers --- src/shammath/include/shammath/solve.hpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index e0052e8682..c73ddfc696 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -102,9 +102,13 @@ namespace shammath { f_at_p[i] = f(p, X[i]); } mat_set_vals(J.get_mdspan(), [&](auto i, auto j) -> T { - T original_p_j = p[j]; - T step_scale = (std::abs(original_p_j) < 1e-6) ? 1e-6 : original_p_j; - T dpj = step_scale * 0.001; + T original_p_j = p[j]; + const T MIN_STEP_SCALE_EPSILON = 1e-6; + T step_scale = (std::abs(original_p_j) < MIN_STEP_SCALE_EPSILON) + ? MIN_STEP_SCALE_EPSILON + : original_p_j; + const T FINITE_DIFF_STEP_FACTOR = 0.001; + T dpj = step_scale * FINITE_DIFF_STEP_FACTOR; p[j] += dpj; T f_perturbed = f(p, X[i]); From e4d071fa19b9e13fabda526a150aef535ec1a818 Mon Sep 17 00:00:00 2001 From: David Fang Date: Mon, 26 Jan 2026 17:18:14 +0100 Subject: [PATCH 23/33] == overload constant --- src/shammath/include/shammath/matrix.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index 4aeaea0a57..4f1fde21cd 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -118,7 +118,7 @@ namespace shammath { inline constexpr T &operator[](int i) { return get_mdspan()(i); } /// Check if this vector is equal to another one - bool operator==(const vec &other) { return data == other.data; } + bool operator==(const vec &other) const { return data == other.data; } }; /** @@ -216,7 +216,7 @@ namespace shammath { inline constexpr T &operator[](int i) { return get_mdspan()(i); } /// Check if this vector is equal to another one - bool operator==(const vec_d &other) { return data == other.data; } + bool operator==(const vec_d &other) const { return data == other.data; } }; } // namespace shammath From 0ebcc1081fa3f2ab89e2a83dfa3da1bc36621a96 Mon Sep 17 00:00:00 2001 From: David Fang Date: Mon, 26 Jan 2026 17:19:00 +0100 Subject: [PATCH 24/33] std::sqrt instead of sycl::sqrt --- src/shammath/include/shammath/matrix_op.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index cab8dd13d8..0da9ee2355 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -733,7 +733,7 @@ namespace shammath { for (int k = 0; k < i; k++) { sum_ik += L(i, k) * L(i, k); } - L(i, i) = sycl::sqrt(M(i, i) - sum_ik); + L(i, i) = std::sqrt(M(i, i) - sum_ik); for (int j = i + 1; j < M.extent(1); j++) { T sum_ikjk = 0.0; for (int k = 0; k < i; k++) { From 49fdfb6f200f9491de5d870e9f2b1dd87fb53bc8 Mon Sep 17 00:00:00 2001 From: David Fang Date: Mon, 26 Jan 2026 17:24:08 +0100 Subject: [PATCH 25/33] last assert --- src/shammath/include/shammath/solve.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index c73ddfc696..2febef6af3 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -73,7 +73,8 @@ namespace shammath { int maxits = 1000, T tolerance = 1e-6) { SHAM_ASSERT(X.size() == Y.size()); - SHAM_ASSERT(X.size() > p0.size()); + SHAM_ASSERT(X.size() >= p0.size()); + SHAM_ASSERT(p0.size() > 0); const int params_nb = p0.size(); const int data_size = X.size(); From ef2f577488ce98560f897189305969470926350b Mon Sep 17 00:00:00 2001 From: David Fang Date: Mon, 26 Jan 2026 17:27:10 +0100 Subject: [PATCH 26/33] revert useless change --- src/shammath/include/shammath/matrix.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index 4f1fde21cd..b70dc6c0e8 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -69,8 +69,8 @@ namespace shammath { /// check if this matrix is equal to another one at a given precison bool equal_at_precision(const mat &other, const T precision) const { bool res = true; - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { + for (auto i = 0; i < m; i++) { + for (auto j = 0; j < n; j++) { if (sham::abs(data[i * n + j] - other.data[i * n + j]) >= precision) { res = false; } @@ -169,8 +169,8 @@ namespace shammath { /// check if this matrix is equal to another one at a given precison bool equal_at_precision(const mat_d &other, const T precision) const { bool res = true; - for (int i = 0; i < rows; i++) { - for (int j = 0; j < columns; j++) { + for (auto i = 0; i < rows; i++) { + for (auto j = 0; j < columns; j++) { if (sham::abs(data[i * columns + j] - other.data[i * columns + j]) >= precision) { res = false; From 087764a585cc240d3b1ec2391d0a688cabfc2aed Mon Sep 17 00:00:00 2001 From: David Fang Date: Tue, 27 Jan 2026 15:26:48 +0100 Subject: [PATCH 27/33] remove last doxygen warnings --- src/shammath/include/shammath/integrator.hpp | 10 +++++----- src/shammath/include/shammath/matrix.hpp | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/shammath/include/shammath/integrator.hpp b/src/shammath/include/shammath/integrator.hpp index 4872a4696f..c71284447a 100644 --- a/src/shammath/include/shammath/integrator.hpp +++ b/src/shammath/include/shammath/integrator.hpp @@ -38,14 +38,14 @@ namespace shammath { * u'(x) &=& f(u,x) \\ * u(x_0) &=& u_0 * \f} - * and will be solved between start and end with step $\mathrm{d}t$. + * and will be solved between start and end with step \f$ \mathrm{d}t \f$. * * @param start Lower bound of integration * @param end Higher bound of integration - * @param step Step of integration $\mathrm{d}t$ - * @param ode Ode function $f$ - * @param x0 Initial coordinate $x_0$ - * @param u0 Initial value $u_0$ + * @param step Step of integration \f$ \mathrm{d}t \f$ + * @param ode Ode function \f$ f \f$ + * @param x0 Initial coordinate \f$ x_0 \f$ + * @param u0 Initial value \f$ u_0 \f$ */ template inline constexpr std::pair, std::vector> euler_ode( diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index b70dc6c0e8..92be9a679e 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -58,6 +58,7 @@ namespace shammath { /// Check if this matrix is equal to another one bool operator==(const mat &other) const { return data == other.data; } + // Addition operator for matrices inline mat &operator+=(const mat &other) { #pragma unroll for (size_t i = 0; i < m * n; i++) { @@ -157,7 +158,7 @@ namespace shammath { /// Check if this matrix is equal to another one bool operator==(const mat_d &other) const { return data == other.data; } - /// Addition operator + /// Addition operator for matrices inline mat_d &operator+=(const mat_d &other) { #pragma unroll for (size_t i = 0; i < get_mdspan().extent(0) * get_mdspan().extent(1); i++) { From 65187cef065d675b49b36e292b9a9758ce94f369 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sat, 14 Feb 2026 22:33:01 +0100 Subject: [PATCH 28/33] fixes --- src/shammath/include/shammath/matrix.hpp | 5 ++--- src/shammath/include/shammath/matrix_op.hpp | 22 ++++++++++----------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/shammath/include/shammath/matrix.hpp b/src/shammath/include/shammath/matrix.hpp index 92be9a679e..3cd555617d 100644 --- a/src/shammath/include/shammath/matrix.hpp +++ b/src/shammath/include/shammath/matrix.hpp @@ -123,7 +123,7 @@ namespace shammath { }; /** - * @brief Matrix class based on std::vector storage and mdspan + * @brief Matrix class with runtime size based on std::vector storage and mdspan * @tparam T the type of the matrix entries */ template @@ -160,7 +160,6 @@ namespace shammath { /// Addition operator for matrices inline mat_d &operator+=(const mat_d &other) { -#pragma unroll for (size_t i = 0; i < get_mdspan().extent(0) * get_mdspan().extent(1); i++) { data[i] += other.data[i]; } @@ -183,7 +182,7 @@ namespace shammath { }; /** - * @brief Vector class based on std::vector storage and mdspan + * @brief Vector class with runtime size based on std::vector storage and mdspan * @tparam T the type of the vector entries * @tparam n the number of entries */ diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 0da9ee2355..6159d5b185 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -733,7 +733,7 @@ namespace shammath { for (int k = 0; k < i; k++) { sum_ik += L(i, k) * L(i, k); } - L(i, i) = std::sqrt(M(i, i) - sum_ik); + L(i, i) = sycl::sqrt(M(i, i) - sum_ik); for (int j = i + 1; j < M.extent(1); j++) { T sum_ikjk = 0.0; for (int k = 0; k < i; k++) { @@ -746,16 +746,16 @@ namespace shammath { } /** -* @brief This function solves a system of linear equations with Cholesky decomposition. The -system must have the form -\f[ - Mx = y -\f] -where $M$ is a (real) symmetric, definite-positive square matrix. -* @param M a square symmetric, definite-positive matrix -* @param y a vector, right hand side of the system -* @param x the ouput vector to store the solution of the system -*/ + * @brief This function solves a system of linear equations with Cholesky decomposition. The + system must have the form + \f[ + Mx = y + \f] + where $M$ is a (real) symmetric, definite-positive square matrix. + * @param M a square symmetric, definite-positive matrix + * @param y a vector, right hand side of the system + * @param x the ouput vector to store the solution of the system + */ template< class T, class Extents1, From fa0932991da87c893af82f70b21c79a2dd421363 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 15 Feb 2026 00:08:43 +0100 Subject: [PATCH 29/33] doxygen typo --- src/shammath/include/shammath/matrix_op.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index 6159d5b185..50910fa144 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -704,7 +704,7 @@ namespace shammath { /** * @brief This function performs Cholesky decomposition. From a (real) symmetric, - definite-positive square matrix $M$, return a lower triangular matrix $L$ such that + definite-positive square matrix \f$ M \f$, return a lower triangular matrix \f$ L \f$ such that \f[ M = L L^T \f] @@ -751,7 +751,7 @@ namespace shammath { \f[ Mx = y \f] - where $M$ is a (real) symmetric, definite-positive square matrix. + where \f$ M \f$ is a (real) symmetric, definite-positive square matrix. * @param M a square symmetric, definite-positive matrix * @param y a vector, right hand side of the system * @param x the ouput vector to store the solution of the system From d6fe7bda403809900dfee3c34cf6f7cdfa661754 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 15 Feb 2026 01:45:33 +0100 Subject: [PATCH 30/33] improve doxygen --- src/shammath/include/shammath/solve.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 2febef6af3..92ae3815ad 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -49,7 +49,8 @@ namespace shammath { * @brief This function determines the best fit parameters \f$ \vec p \f$ for a given function \f$ f(\vec p, \mathbf X) \f$ with least squares. * - * @param f Function (1d values) + * @param f Function (1d values). Must takes a vector (\f$ \vec p \f$) and a scalar (\f$ + x \f$) as parameters. * @param X Data to fit \f$ x \f$ * @param Y Data to fit \f$ y \f$ * @param p0 Initial parameters guessed @@ -57,6 +58,7 @@ namespace shammath { 1000 * @param tolerance Convergence condition in the Levenberg-Marquardt procedure. Default: 1e-6 * + * @return tuple: \f$ (\vec p, R^2) \f$ * @details The Levenberg-Marquardt method is used. Therefore, the number of observations needs * to be greater than the number of parameters. At every iteration, a new parameters array \f$ \vec p' \f$ is estimated. The convergence condition is From e94cf795f6bd4a4293dea686e356fdb42c291532 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 15 Feb 2026 01:46:15 +0100 Subject: [PATCH 31/33] tuple to pair --- src/shammath/include/shammath/solve.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 92ae3815ad..9d3af6a0f0 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -58,7 +58,7 @@ namespace shammath { 1000 * @param tolerance Convergence condition in the Levenberg-Marquardt procedure. Default: 1e-6 * - * @return tuple: \f$ (\vec p, R^2) \f$ + * @return pair: \f$ (\vec p, R^2) \f$ * @details The Levenberg-Marquardt method is used. Therefore, the number of observations needs * to be greater than the number of parameters. At every iteration, a new parameters array \f$ \vec p' \f$ is estimated. The convergence condition is @@ -67,7 +67,7 @@ namespace shammath { \f] where \f$ S \f$ is the residual sum of squares and \f$ \epsilon \f$ is the tolerance. */ template - std::tuple, T> least_squares( + std::pair, T> least_squares( const Lambda &f, const std::vector &X, const std::vector &Y, From 638e26804721748b752f3d03450e013a70c00d39 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 15 Feb 2026 03:00:57 +0100 Subject: [PATCH 32/33] minor improvements --- src/shammath/include/shammath/solve.hpp | 34 +++++++++++++------------ 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/shammath/include/shammath/solve.hpp b/src/shammath/include/shammath/solve.hpp index 9d3af6a0f0..b3d5a94fdf 100644 --- a/src/shammath/include/shammath/solve.hpp +++ b/src/shammath/include/shammath/solve.hpp @@ -73,7 +73,7 @@ namespace shammath { const std::vector &Y, const std::vector &p0, int maxits = 1000, - T tolerance = 1e-6) { + T tolerance = 1e-9) { SHAM_ASSERT(X.size() == Y.size()); SHAM_ASSERT(X.size() >= p0.size()); SHAM_ASSERT(p0.size() > 0); @@ -85,18 +85,19 @@ namespace shammath { T mu = 1e-2; // damping parameter T beta = 0.1; // decay rate int it = 0; - T sse = 0.0; - for (int k = 0; k < X.size(); k++) { - T r = Y[k] - f(p, X[k]); - sse += r * r; + + auto evaluate_sse = [&](const std::vector ¶ms) -> T { + T sse = 0.0; + for (int k = 0; k < data_size; k++) { + T r = Y[k] - f(params, X[k]); + sse += r * r; + } + return sse; }; - T sse_trial = 999.0; + T sse = evaluate_sse(p); + T sse_trial = sse + 2 * tolerance; while (it < maxits and sham::abs(sse_trial - sse) > tolerance) { - sse = 0.0; - for (int k = 0; k < X.size(); k++) { - T r = Y[k] - f(p, X[k]); - sse += r * r; - }; + sse = evaluate_sse(p); // Construct the Jacobian (finite differences) shammath::mat_d J(data_size, params_nb); @@ -105,6 +106,7 @@ namespace shammath { f_at_p[i] = f(p, X[i]); } mat_set_vals(J.get_mdspan(), [&](auto i, auto j) -> T { + // This part can be improved if necessary (p is modified then restored). T original_p_j = p[j]; const T MIN_STEP_SCALE_EPSILON = 1e-6; T step_scale = (std::abs(original_p_j) < MIN_STEP_SCALE_EPSILON) @@ -144,26 +146,26 @@ namespace shammath { }; sse_trial = 0.0; - for (int k = 0; k < X.size(); k++) { + for (int k = 0; k < data_size; k++) { T residual = Y[k] - f(p_trial, X[k]); sse_trial += residual * residual; }; if (sse_trial > sse) { // Fail -> gradient descent mu /= beta; } else { // Not bad -> Gauss-Newton - it++; mu *= beta; p = p_trial; } + it++; }; T total_sum_squares = 0.0; T mean_Y = 0.0; - for (int k = 0; k < Y.size(); k++) { + for (int k = 0; k < data_size; k++) { mean_Y += Y[k]; } - mean_Y /= Y.size(); - for (int k = 0; k < Y.size(); k++) { + mean_Y /= data_size; + for (int k = 0; k < data_size; k++) { total_sum_squares += (Y[k] - mean_Y) * (Y[k] - mean_Y); } T R2 = 1 - sse / total_sum_squares; From 659e8fd819737f511c2d4eb9112e657f42bd21a4 Mon Sep 17 00:00:00 2001 From: David Fang Date: Sun, 15 Feb 2026 03:01:06 +0100 Subject: [PATCH 33/33] add test for least_squares --- src/tests/shammath/matrixTests.cpp | 47 ++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/tests/shammath/matrixTests.cpp b/src/tests/shammath/matrixTests.cpp index d14a3722fc..baa85bc831 100644 --- a/src/tests/shammath/matrixTests.cpp +++ b/src/tests/shammath/matrixTests.cpp @@ -11,6 +11,7 @@ #include "shamcomm/logs.hpp" #include "shammath/matrix.hpp" #include "shammath/matrix_op.hpp" +#include "shammath/solve.hpp" #include "shamtest/details/TestResult.hpp" #include "shamtest/shamtest.hpp" @@ -561,3 +562,49 @@ TestStart(Unittest, "shammath/matrix::Cholesky_solve", test_Cholesky_solve, 1) { < 1e-9; }); } + +// This test uses Eckerle4 from NIST Standard Reference Database [Eckerle, K., NIST (1979)] +// https://www.itl.nist.gov/div898/strd/nls/data/eckerle4.shtml +TestStart(Unittest, "shammath/solve::least_squares", test_least_squares, 1) { + std::vector p0 = {1, 1e1, 5e2}; + std::vector X + = {400.000000e0, 405.000000e0, 410.000000e0, 415.000000e0, 420.000000e0, 425.000000e0, + 430.000000e0, 435.000000e0, 436.500000e0, 438.000000e0, 439.500000e0, 441.000000e0, + 442.500000e0, 444.000000e0, 445.500000e0, 447.000000e0, 448.500000e0, 450.000000e0, + 451.500000e0, 453.000000e0, 454.500000e0, 456.000000e0, 457.500000e0, 459.000000e0, + 460.500000e0, 462.000000e0, 463.500000e0, 465.000000e0, 470.000000e0, 475.000000e0, + 480.000000e0, 485.000000e0, 490.000000e0, 495.000000e0, 500.000000e0}; + + std::vector Y = { + 0.0001575e0, 0.0001699e0, 0.0002350e0, 0.0003102e0, 0.0004917e0, 0.0008710e0, 0.0017418e0, + 0.0046400e0, 0.0065895e0, 0.0097302e0, 0.0149002e0, 0.0237310e0, 0.0401683e0, 0.0712559e0, + 0.1264458e0, 0.2073413e0, 0.2902366e0, 0.3445623e0, 0.3698049e0, 0.3668534e0, 0.3106727e0, + 0.2078154e0, 0.1164354e0, 0.0616764e0, 0.0337200e0, 0.0194023e0, 0.0117831e0, 0.0074357e0, + 0.0022732e0, 0.0008800e0, 0.0004579e0, 0.0002345e0, 0.0001586e0, 0.0001143e0, 0.0000710e0}; + auto ls = shammath::least_squares( + [](std::vector p, f64 x) -> f64 { + f64 b1 = p[0]; + f64 b2 = p[1]; + f64 b3 = p[2]; + return (b1 / b2) * exp(-0.5 * ((x - b3) / b2) * ((x - b3) / b2)); + }, + X, + Y, + p0, + 1000, + 1e-9); + + auto pfit = ls.first; + auto R2 = ls.second; + std::vector res = {pfit[0], pfit[1], pfit[2], R2}; + std::vector ex_res = {1.55, 4.08, 4.5154e2}; + std::vector ex_deviation = {2e-2, 4.7e-2, 4.7e-2, 1e-2}; + + bool test_fit = true; + for (size_t i; i < 4; i++) { + if (sham::abs(res[i] - ex_res[i]) > ex_deviation[i]) { + test_fit = false; + } + }; + REQUIRE(test_fit); +}