Skip to content

fix(gmm): use log-sum-exp in estimate_log_prob_resp to prevent overflow#443

Open
schliffen wants to merge 1 commit into
rust-ml:masterfrom
schliffen:fix/issue-442-log-sum-exp
Open

fix(gmm): use log-sum-exp in estimate_log_prob_resp to prevent overflow#443
schliffen wants to merge 1 commit into
rust-ml:masterfrom
schliffen:fix/issue-442-log-sum-exp

Conversation

@schliffen
Copy link
Copy Markdown

Fixes #442

Problem

estimate_log_prob_resp computed the per-sample log-normaliser as:

ln(Σ exp(xᵢ))

For tight Gaussians (small covariance, observation near the mean) the weighted
log-probabilities xᵢ can exceed ~709, causing f64::exp() to overflow to inf.
The subsequent ln(inf) produces an infinite log_prob_norm, and every log_resp
becomes -inf, collapsing all responsibilities to NaN / zero and destabilising EM.

Fix

Replace with the log-sum-exp trick:

max + ln(Σ exp(xᵢ − max))

where max = max_i(xᵢ) is computed independently per sample row via
map_axis(Axis(1), ...). After the shift the largest exponent is 0, so
exp() never exceeds 1.

Proof of equivalence:

max + ln(Σ exp(xᵢ − max))
= max + ln(Σ exp(xᵢ) · exp(−max))     [exp(a−b) = exp(a)·exp(−b)]
= max + ln(exp(−max) · Σ exp(xᵢ))     [factor out constant exp(−max)]
= max + ln(exp(−max)) + ln(Σ exp(xᵢ)) [ln(a·b) = ln(a)+ln(b)]
= max + (−max) + ln(Σ exp(xᵢ))        [ln(exp(−max)) = −max]
= ln(Σ exp(xᵢ))                        ∎

Both the old and new code operate per-row (Axis(1)); only the intermediate
numerical range changes.

Regression test

test_large_weighted_log_prob_numerical_stability builds a 70-feature,
1-component GMM with covariance 1e-10 · I. A point at the cluster mean
produces a log-Gaussian probability of ~741 — above the f64 overflow
threshold of ~709.78. The test asserts log_prob_norm is finite and
responsibilities sum to 1.

@schliffen schliffen force-pushed the fix/issue-442-log-sum-exp branch from e3fe6dd to 58a03d4 Compare May 28, 2026 00:45
@codecov
Copy link
Copy Markdown

codecov Bot commented May 28, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 77.56%. Comparing base (60382a3) to head (5407c0b).

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #443   +/-   ##
=======================================
  Coverage   77.56%   77.56%           
=======================================
  Files         106      106           
  Lines        7585     7585           
=======================================
  Hits         5883     5883           
  Misses       1702     1702           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

…ow (rust-ml#442)

The previous implementation computed the per-sample log-normaliser as:

    ln(Σ exp(xᵢ))

where xᵢ are the weighted log-probabilities for each mixture component.
For tight Gaussians (small covariance, observation near the mean) xᵢ can
exceed ~709, causing f64::exp() to overflow to inf.  The subsequent ln(inf)
produces an infinite log_prob_norm, and every log_resp becomes -inf, which
collapses all responsibilities to NaN / zero.

Fix: replace with the log-sum-exp (LSE) trick —

    max + ln(Σ exp(xᵢ − max))

where max = max_i(xᵢ) is computed independently for each sample row.
After the shift, the largest exponent is 0, so exp() never exceeds 1.

Proof of equivalence:

    max + ln(Σ exp(xᵢ − max))
    = max + ln(Σ exp(xᵢ) · exp(−max))   [exp(a−b) = exp(a)·exp(−b)]
    = max + ln(exp(−max) · Σ exp(xᵢ))   [factor out constant exp(−max)]
    = max + ln(exp(−max)) + ln(Σ exp(xᵢ)) [ln(a·b) = ln(a)+ln(b)]
    = max + (−max) + ln(Σ exp(xᵢ))      [ln(exp(−max)) = −max]
    = ln(Σ exp(xᵢ))                      ∎

Both the original and the new code operate per-row (Axis(1)); the only
change is the intermediate numerical range.

Adds a regression test (test_large_weighted_log_prob_numerical_stability)
that constructs a 70-feature, 1-component GMM with covariance 1e-10·I,
producing a log-Gaussian probability of ~741 at the cluster mean — well
above the f64 overflow threshold of ~709.78.  The test asserts that
log_prob_norm is finite and that responsibilities sum to 1.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@schliffen schliffen force-pushed the fix/issue-442-log-sum-exp branch from 58a03d4 to 5407c0b Compare May 28, 2026 00:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Numerical overflow in linfa-clustering GMM E-step: estimate_log_prob_resp uses unstable exp-sum-ln normalization

1 participant