diff --git a/src/utils/linear_algebra.rs b/src/utils/linear_algebra.rs index d95cfee..35934dc 100644 --- a/src/utils/linear_algebra.rs +++ b/src/utils/linear_algebra.rs @@ -1,9 +1,12 @@ +use num::{Float, Zero}; use rlst::dense::linalg::{lu::MatrixLu, null_space::Method}; pub use rlst::prelude::*; use serde::Deserialize; use crate::rsrs::rsrs_factors::null_and_extract::{ExtractOptions, IdOptions}; +type Real = ::Real; + fn solve_svd< Item: RlstScalar + MatrixPseudoInverse, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> @@ -90,6 +93,29 @@ pub fn add_diagonal( } } +fn normal_equation_scale(normal: &DynamicArray) -> Real { + let shape = normal.shape(); + let view = normal.r(); + let mut scale = Real::::zero(); + + for i in 0..shape[0] { + scale = scale.max(view[[i, i]].re()); + } + + scale +} + +fn normal_equation_regularization( + normal: &DynamicArray, + tol_lstq: Real, +) -> Real { + let tol_abs = Float::abs(tol_lstq); + let tol_sq = tol_abs * tol_abs; + let factor: Real = Float::max(Real::::epsilon(), tol_sq); + let scale: Real = normal_equation_scale::(normal); + factor * scale +} + impl< 'a, Item: RlstScalar + MatrixLu, @@ -115,7 +141,8 @@ impl< arr.r(), ::zero(), ); - add_diagonal(&mut normal, tol_lstq); //Regularisation + let regularization = normal_equation_regularization::(&normal, tol_lstq); + add_diagonal(&mut normal, regularization); let lu = ::into_lu_alloc(normal).unwrap(); Self { arr, normal: lu } @@ -211,7 +238,8 @@ where MatrixLuDecomposition, { pub fn solve(mut self, tol_lstq: ::Real) -> DynamicArray { - add_diagonal(&mut self.normal, tol_lstq); + let regularization = normal_equation_regularization::(&self.normal, tol_lstq); + add_diagonal(&mut self.normal, regularization); let lu = ::into_lu_alloc(self.normal).unwrap(); let _ = as MatrixLuDecomposition>::solve_mat( &lu,