From ace53a88deabb4342fa98f9cab41ba15701b8047 Mon Sep 17 00:00:00 2001 From: Ali Nehrani Date: Sat, 30 May 2026 18:35:07 +0900 Subject: [PATCH 1/2] Implement log-sum-exp trick in Gaussian Mixture Model to prevent overflow --- .../linfa-clustering/src/gaussian_mixture/algorithm.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs index 27341a9f5..34f23f1af 100644 --- a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs +++ b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs @@ -341,10 +341,11 @@ impl GaussianMixtureModel { observations: &ArrayBase, ) -> (Array1, Array2) { let weighted_log_prob = self.estimate_weighted_log_prob(observations); - let log_prob_norm = weighted_log_prob - .mapv(|x| x.exp()) - .sum_axis(Axis(1)) - .mapv(|x| x.ln()); + // Log-sum-exp trick to avoid overflow: ln(Σ exp(xᵢ)) = max + ln(Σ exp(xᵢ - max)) + let log_max = weighted_log_prob + .map_axis(Axis(1), |row| row.fold(F::neg_infinity(), |a, &b| a.max(b))); + let shifted = &weighted_log_prob - &log_max.clone().insert_axis(Axis(1)); + let log_prob_norm = shifted.mapv(|x| x.exp()).sum_axis(Axis(1)).mapv(|x| x.ln()) + &log_max; let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1)); (log_prob_norm, log_resp) } From 238b776be0ceccc251b2ab3f5bc71c8af1532ddc Mon Sep 17 00:00:00 2001 From: Ali Nehrani Date: Sat, 30 May 2026 18:57:26 +0900 Subject: [PATCH 2/2] fix lint problem --- algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs index 34f23f1af..4be34a2b0 100644 --- a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs +++ b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs @@ -343,7 +343,7 @@ impl GaussianMixtureModel { let weighted_log_prob = self.estimate_weighted_log_prob(observations); // Log-sum-exp trick to avoid overflow: ln(Σ exp(xᵢ)) = max + ln(Σ exp(xᵢ - max)) let log_max = weighted_log_prob - .map_axis(Axis(1), |row| row.fold(F::neg_infinity(), |a, &b| a.max(b))); + .map_axis(Axis(1), |row| row.fold(F::neg_infinity(), |a, &b| a.max(b))); let shifted = &weighted_log_prob - &log_max.clone().insert_axis(Axis(1)); let log_prob_norm = shifted.mapv(|x| x.exp()).sum_axis(Axis(1)).mapv(|x| x.ln()) + &log_max; let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));