Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,13 @@ impl<F: Float> GaussianMixtureModel<F> {
observations: &ArrayBase<D, Ix2>,
) -> (Array1<F>, Array2<F>) {
let weighted_log_prob = self.estimate_weighted_log_prob(observations);
let log_prob_norm = weighted_log_prob
.mapv(|x| x.exp())
.sum_axis(Axis(1))
.mapv(|x| x.ln());
// Log-sum-exp trick: shift by per-row max before exponentiating to prevent
// overflow when weighted log-probabilities are large (e.g. tight Gaussians).
// Mathematically: ln(Σ exp(xᵢ)) = max + ln(Σ exp(xᵢ - max))
let log_max = weighted_log_prob
.map_axis(Axis(1), |row| row.fold(F::neg_infinity(), |a, &b| a.max(b)));
let shifted = &weighted_log_prob - &log_max.clone().insert_axis(Axis(1));
let log_prob_norm = shifted.mapv(|x| x.exp()).sum_axis(Axis(1)).mapv(|x| x.ln()) + &log_max;
let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));
(log_prob_norm, log_resp)
}
Expand Down Expand Up @@ -778,4 +781,63 @@ mod tests {
let ones = ndarray::Array1::ones(n_samples);
assert_abs_diff_eq!(row_sums, ones, epsilon = 1e-6);
}

// Guards against the overflow reported in issue #442: normalising with a plain
// ln(sum(exp(x))) breaks down once the weighted log-probabilities get large.
//
// Scenario: a single component in 70 dimensions with covariance 1e-10 * I.
// Evaluated at its own mean, the Gaussian log-density is about 741, which is
// above the largest argument f64::exp() can handle without overflowing
// (~709.78). A working log-sum-exp must still return a finite log_prob_norm
// and responsibilities that add up to 1 for each sample.
#[test]
fn test_large_weighted_log_prob_numerical_stability() {
    use ndarray::Array2;

    const N_FEATURES: usize = 70;
    // Tight covariance scale: drives the log-density at the mean to ~741 (> 709).
    const EPS: f64 = 1e-10;
    const N_SAMPLES: usize = 3;
    const SUM_TOLERANCE: f64 = 1e-9;

    // Covariance stack of shape (1, N_FEATURES, N_FEATURES): one EPS-scaled identity
    // matrix, with the leading component axis added by insert_axis.
    let covariances = (Array2::<f64>::eye(N_FEATURES) * EPS).insert_axis(Axis(0));

    // Derive the Cholesky factors and precisions through the model's own helpers,
    // which this test module can reach.
    let precisions_chol = GaussianMixtureModel::compute_precisions_cholesky_full(&covariances)
        .expect("Cholesky should succeed for positive-definite covariance");
    let precisions = GaussianMixtureModel::compute_precisions_full(&precisions_chol);

    // One full-covariance component centred at the origin carrying all the weight.
    let model = GaussianMixtureModel {
        covar_type: GmmCovarType::Full,
        weights: array![1.0_f64],
        means: Array2::zeros((1, N_FEATURES)),
        covariances,
        precisions,
        precisions_chol,
    };

    // Every sample sits exactly on the component mean, the worst case for a large
    // log-probability.
    let at_mean = Array2::zeros((N_SAMPLES, N_FEATURES));
    let (log_prob_norm, log_resp) = model.estimate_log_prob_resp(&at_mean.view());

    // No entry of log_prob_norm may be infinite or NaN.
    log_prob_norm.iter().enumerate().for_each(|(idx, &value)| {
        assert!(
            value.is_finite(),
            "log_prob_norm[{}] is not finite ({}): log-sum-exp fix missing or broken",
            idx,
            value
        );
    });

    // exp(log_resp) gives the responsibilities; each row has to total 1.
    let row_totals = log_resp.mapv(f64::exp).sum_axis(Axis(1));
    for (idx, total) in row_totals.iter().copied().enumerate() {
        assert!(
            (total - 1.0).abs() < SUM_TOLERANCE,
            "responsibilities for sample {} sum to {}, expected 1.0",
            idx,
            total
        );
    }
}
}
Loading