diff --git a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs index 27341a9f5..759176fc8 100644 --- a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs +++ b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs @@ -341,10 +341,13 @@ impl GaussianMixtureModel { observations: &ArrayBase, ) -> (Array1, Array2) { let weighted_log_prob = self.estimate_weighted_log_prob(observations); - let log_prob_norm = weighted_log_prob - .mapv(|x| x.exp()) - .sum_axis(Axis(1)) - .mapv(|x| x.ln()); + // Log-sum-exp trick: shift by per-row max before exponentiating to prevent + // overflow when weighted log-probabilities are large (e.g. tight Gaussians). + // Mathematically: ln(Σ exp(xᵢ)) = max + ln(Σ exp(xᵢ - max)) + let log_max = weighted_log_prob + .map_axis(Axis(1), |row| row.fold(F::neg_infinity(), |a, &b| a.max(b))); + let shifted = &weighted_log_prob - &log_max.clone().insert_axis(Axis(1)); + let log_prob_norm = shifted.mapv(|x| x.exp()).sum_axis(Axis(1)).mapv(|x| x.ln()) + &log_max; let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1)); (log_prob_norm, log_resp) } @@ -778,4 +781,63 @@ mod tests { let ones = ndarray::Array1::ones(n_samples); assert_abs_diff_eq!(row_sums, ones, epsilon = 1e-6); } + + // Regression test for issue #442: naive exp-sum-ln overflows for large weighted log-probs. + // + // With 70 features and covariance = 1e-10 * I, a point exactly at the cluster mean produces + // a log-gaussian-probability of roughly 741, which causes f64::exp() to overflow (threshold + // ~709.78). The log-sum-exp fix must keep log_prob_norm finite and responsibilities summing + // to 1. 
+    #[test]
+    fn test_large_weighted_log_prob_numerical_stability() {
+        use ndarray::Array2;
+
+        const N_FEATURES: usize = 70;
+        const EPS: f64 = 1e-10; // tight covariance → log-prob at mean ≈ 741 >> 709
+
+        // Build a single-component, N_FEATURES-dimensional GMM with covariance = EPS * I.
+        // The model is assembled field by field using the precision helpers of the parent module.
+        let cov_2d = Array2::<f64>::eye(N_FEATURES) * EPS;
+        // insert_axis prepends the component axis: shape (1, N_FEATURES, N_FEATURES)
+        let covariances = cov_2d.insert_axis(Axis(0));
+
+        let precisions_chol = GaussianMixtureModel::compute_precisions_cholesky_full(&covariances)
+            .expect("Cholesky should succeed for positive-definite covariance");
+        let precisions = GaussianMixtureModel::compute_precisions_full(&precisions_chol);
+
+        let gmm = GaussianMixtureModel {
+            covar_type: GmmCovarType::Full,
+            weights: array![1.0_f64],
+            means: Array2::<f64>::zeros((1, N_FEATURES)),
+            covariances,
+            precisions,
+            precisions_chol,
+        };
+
+        // Three observations all sitting at the cluster mean → worst-case large log-prob.
+        let observations = Array2::<f64>::zeros((3, N_FEATURES));
+        let (log_prob_norm, log_resp) = gmm.estimate_log_prob_resp(&observations.view());
+
+        // Every log_prob_norm entry must be finite; a naive exp-sum-ln yields +inf here.
+        for (i, &v) in log_prob_norm.iter().enumerate() {
+            assert!(
+                v.is_finite(),
+                "log_prob_norm[{}] is not finite ({}): log-sum-exp fix missing or broken",
+                i,
+                v
+            );
+        }
+
+        // Responsibilities (exp(log_resp)) must sum to 1 for every sample.
+        let resp = log_resp.mapv(f64::exp);
+        let row_sums = resp.sum_axis(Axis(1));
+        for (i, &s) in row_sums.iter().enumerate() {
+            assert!(
+                (s - 1.0).abs() < 1e-9,
+                "responsibilities for sample {} sum to {}, expected 1.0",
+                i,
+                s
+            );
+        }
+    }
 }