fix(gmm): use log-sum-exp in estimate_log_prob_resp to prevent overflow #443
Open
schliffen wants to merge 1 commit into
e3fe6dd to 58a03d4
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##           master    #443    +/-   ##
=======================================
  Coverage   77.56%   77.56%
=======================================
  Files         106      106
  Lines        7585     7585
=======================================
  Hits         5883     5883
  Misses       1702     1702
…ow (rust-ml#442)

The previous implementation computed the per-sample log-normaliser as:

    ln(Σ exp(xᵢ))

where xᵢ are the weighted log-probabilities for each mixture component. For tight Gaussians (small covariance, observation near the mean) xᵢ can exceed ~709, causing f64::exp() to overflow to inf. The subsequent ln(inf) produces an infinite log_prob_norm, and every log_resp becomes -inf, which collapses all responsibilities to NaN / zero.

Fix: replace with the log-sum-exp (LSE) trick:

    max + ln(Σ exp(xᵢ − max))

where max = max_i(xᵢ) is computed independently for each sample row. After the shift, the largest exponent is 0, so exp() never exceeds 1.

Proof of equivalence:

    max + ln(Σ exp(xᵢ − max))
      = max + ln(Σ exp(xᵢ) · exp(−max))        [exp(a−b) = exp(a)·exp(−b)]
      = max + ln(exp(−max) · Σ exp(xᵢ))        [factor out constant exp(−max)]
      = max + ln(exp(−max)) + ln(Σ exp(xᵢ))    [ln(a·b) = ln(a)+ln(b)]
      = max + (−max) + ln(Σ exp(xᵢ))           [ln(exp(−max)) = −max]
      = ln(Σ exp(xᵢ)) ∎

Both the original and the new code operate per-row (Axis(1)); the only change is the intermediate numerical range.

Adds a regression test (test_large_weighted_log_prob_numerical_stability) that constructs a 70-feature, 1-component GMM with covariance 1e-10·I, producing a log-Gaussian probability of ~741 at the cluster mean, well above the f64 overflow threshold of ~709.78. The test asserts that log_prob_norm is finite and that responsibilities sum to 1.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
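The failure mode described in the commit message can be reproduced in a few lines of plain Rust. This is a minimal sketch rather than the crate's code: `naive_log_sum_exp` and `naive_responsibilities` are hypothetical helpers, and a slice stands in for one ndarray row of weighted log-probabilities.

```rust
/// Naive log-normaliser ln(Σ exp(xᵢ)) for one sample row.
/// Hypothetical helper for illustration only.
fn naive_log_sum_exp(row: &[f64]) -> f64 {
    row.iter().map(|x| x.exp()).sum::<f64>().ln()
}

/// Responsibilities exp(xᵢ - log_prob_norm) computed with the naive normaliser.
/// Hypothetical helper for illustration only.
fn naive_responsibilities(row: &[f64]) -> Vec<f64> {
    let log_prob_norm = naive_log_sum_exp(row);
    // If log_prob_norm is +inf, every finite xᵢ gives log_resp = -inf,
    // and exp(-inf) is 0 for every component.
    row.iter().map(|x| (x - log_prob_norm).exp()).collect()
}
```

With moderate inputs such as `[1.0, 2.0, 3.0]` the responsibilities sum to 1 as expected. With `[741.0, 700.0]` the first `exp()` overflows, the normaliser becomes infinite, and every responsibility is 0.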
58a03d4 to 5407c0b
Fixes #442
Problem
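`estimate_log_prob_resp` computed the per-sample log-normaliser as:

    ln(Σ exp(xᵢ))

For tight Gaussians (small covariance, observation near the mean) the weighted log-probabilities xᵢ can exceed ~709, causing `f64::exp()` to overflow to `inf`. The subsequent `ln(inf)` produces an infinite `log_prob_norm`, and every `log_resp` becomes `-inf`, collapsing all responsibilities to NaN / zero and destabilising EM.

Fix

Replace with the log-sum-exp trick:

    max + ln(Σ exp(xᵢ − max))

where `max = max_i(xᵢ)` is computed independently per sample row via `map_axis(Axis(1), ...)`. After the shift the largest exponent is 0, so `exp()` never exceeds 1.

Proof of equivalence:

    max + ln(Σ exp(xᵢ − max))
      = max + ln(Σ exp(xᵢ) · exp(−max))        [exp(a−b) = exp(a)·exp(−b)]
      = max + ln(exp(−max) · Σ exp(xᵢ))        [factor out constant exp(−max)]
      = max + ln(exp(−max)) + ln(Σ exp(xᵢ))    [ln(a·b) = ln(a)+ln(b)]
      = max + (−max) + ln(Σ exp(xᵢ))           [ln(exp(−max)) = −max]
      = ln(Σ exp(xᵢ)) ∎

Both the old and new code operate per-row (`Axis(1)`); only the intermediate numerical range changes.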
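For readers who want to see the shifted form in isolation, here is a minimal sketch in plain Rust. It is not the code from this PR: `log_sum_exp` and `estimate_log_prob_resp_sketch` are hypothetical names, `Vec<Vec<f64>>` stands in for the ndarray matrix (one row per sample, one column per component), and the early return for a non-finite maximum is an extra guard assumed here, not something the PR description mentions.

```rust
/// Numerically stable ln(Σ exp(xᵢ)) for one sample row.
/// Hypothetical helper for illustration only.
fn log_sum_exp(row: &[f64]) -> f64 {
    let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        // Assumed guard: an empty or all -inf row (or a +inf entry) would
        // otherwise evaluate inf - inf = NaN in the shift below.
        return max;
    }
    // After the shift the largest exponent is 0, so exp() never exceeds 1.
    max + row.iter().map(|x| (x - max).exp()).sum::<f64>().ln()
}

/// Returns (log_prob_norm, log_resp) for a matrix of weighted log-probabilities.
/// Hypothetical sketch; the real function operates on ndarray views.
fn estimate_log_prob_resp_sketch(
    weighted_log_prob: &[Vec<f64>],
) -> (Vec<f64>, Vec<Vec<f64>>) {
    let log_prob_norm: Vec<f64> = weighted_log_prob
        .iter()
        .map(|row| log_sum_exp(row))
        .collect();
    let log_resp: Vec<Vec<f64>> = weighted_log_prob
        .iter()
        .zip(&log_prob_norm)
        .map(|(row, norm)| row.iter().map(|x| x - norm).collect::<Vec<f64>>())
        .collect();
    (log_prob_norm, log_resp)
}
```

Calling `estimate_log_prob_resp_sketch` on a row such as `[741.0, 740.0]` yields a finite normaliser, and exponentiating the returned `log_resp` row gives responsibilities that sum to 1.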
Regression test
`test_large_weighted_log_prob_numerical_stability` builds a 70-feature, 1-component GMM with covariance `1e-10 · I`. A point at the cluster mean produces a Gaussian log-density of ~741, above ~709.78, the largest argument for which `f64::exp()` stays finite. The test asserts that `log_prob_norm` is finite and that responsibilities sum to 1.
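The ~741 figure can be checked by hand: for an isotropic Gaussian with covariance σ²·I in d dimensions, the log-density at the mean is −(d/2)·(ln(2π) + ln(σ²)). The sketch below evaluates that expression; `log_density_at_mean` is a hypothetical helper written for this check and is not part of the PR's test.

```rust
use std::f64::consts::PI;

/// Log-density of N(mean, variance·I) evaluated at its own mean,
/// in `dim` dimensions. Hypothetical helper for illustration only.
fn log_density_at_mean(dim: usize, variance: f64) -> f64 {
    let d = dim as f64;
    // The quadratic term vanishes at the mean, leaving only the
    // normalisation constant: -(d/2)·ln(2π) - (d/2)·ln(variance).
    -0.5 * d * ((2.0 * PI).ln() + variance.ln())
}
```

For `dim = 70` and `variance = 1e-10` this evaluates to roughly 741.58, which exceeds `f64::MAX.ln()` (about 709.78), so calling `exp()` on it yields `inf`.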