diff --git a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs index 27341a9f5..4be34a2b0 100644 --- a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs +++ b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs @@ -341,10 +341,11 @@ impl GaussianMixtureModel { observations: &ArrayBase, ) -> (Array1, Array2) { let weighted_log_prob = self.estimate_weighted_log_prob(observations); - let log_prob_norm = weighted_log_prob - .mapv(|x| x.exp()) - .sum_axis(Axis(1)) - .mapv(|x| x.ln()); + // Log-sum-exp trick to avoid overflow: ln(Σ exp(xᵢ)) = max + ln(Σ exp(xᵢ - max)) + let log_max = weighted_log_prob + .map_axis(Axis(1), |row| row.fold(F::neg_infinity(), |a, &b| a.max(b))); + let shifted = &weighted_log_prob - &log_max.clone().insert_axis(Axis(1)); + let log_prob_norm = shifted.mapv(|x| x.exp()).sum_axis(Axis(1)).mapv(|x| x.ln()) + &log_max; let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1)); (log_prob_norm, log_resp) }