
Log space outliers #291

Merged
LennartPurucker merged 5 commits into PriorLabs:main from ClementBourt:log-space-outliers
May 14, 2026

Conversation

@ClementBourt
Contributor

@ClementBourt ClementBourt commented May 12, 2026

Disclaimer

This document was drafted by AI and reviewed by a human.

Linked issue

Closes #289 (Overflow in diffuse density).

What changes

Replace outliers()'s combiner with the log of the arithmetic mean (AM), computed via the logsumexp identity. Three coordinated edits in unsupervised.py, plus a matching consistency update in experiments.py.

unsupervised.py

  • L645 — drop 1.0 / inversion. The regressor branch reads the log-density directly:
    log_pred = -pred["criterion"].forward(logits, y_predict).to(log_p.device).
    The resolved L641–L644 TODO block is removed.
  • L658 — outliers_single_permutation_ returns log_p only.
  • L757–781 — outliers() averages densities via the log-sum-exp identity:
    return torch.logsumexp(torch.stack(log_densities), dim=0) - np.log(actual_n_permutations)
    No nan_to_num clamps; log-densities don't blow up the way exp(log_p) did.
  • Docstrings on outliers(), outliers_pdf, outliers_pmf, and the module-level usage example updated to reflect log-density semantics.
  • outliers_pdf and outliers_pmf rename their local variables pdf/pmf → log_pdf/log_pmf to reflect the log-density semantics.
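The log-space combiner is easy to sketch in isolation. A minimal pure-Python version of the identity described above (names are illustrative, not the actual unsupervised.py code):

```python
import math

def combine_permutations(log_densities):
    """Average per-permutation densities in log space:
    log(mean_k p_k) = logsumexp(log p_1, ..., log p_K) - log(K)."""
    k = len(log_densities)
    m = max(log_densities)
    if m == float("-inf"):  # every permutation assigned zero density
        return float("-inf")
    # Shift by the max so the largest argument to exp is exactly 0.
    s = sum(math.exp(lp - m) for lp in log_densities)
    return m + math.log(s) - math.log(k)

# Two permutation estimates of one row's density, 0.2 and 0.4:
# the combined score is log((0.2 + 0.4) / 2) = log(0.3).
score = combine_permutations([math.log(0.2), math.log(0.4)])
```

The production code does the same thing vectorized, via torch.logsumexp over a stacked tensor of per-permutation log-densities.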

experiments.py

OutlierDetectionUnsupervisedExperiment is updated to match the new score semantics:

  • Rename: self.p → self.log_p, DataFrame column "p" → "log_p", rank column "p_rank" → "log_p_rank".
  • Polarity flip in plot_two(): threshold tests inverted from x > thresh to x < thresh. Under the old 1/pdf combiner, a large output meant a low joint pdf, i.e. an outlier; under log(AM(pdf)), a low output means an outlier — the opposite direction. Default percentiles are inverted to match: outlier_thresh_p 0.98 → 0.02, outlier_thresh_p_1 0.9 → 0.1, with oversampling fractions and legend labels updated accordingly.
  • Drop the [self.data["p"] > 0] quantile filter and replace it with an -inf clamp. The old filter was a defensive measure against the old code's nan_to_num(nan=0.0) artifacts; under log-density semantics it would silently exclude all rows where the AM density is < 1, which is the dominant regime for diffuse predictive PDFs. The replacement clamps -inf rows (zero-pmf categories) to the finite minimum for the quantile computation only, since np.quantile interpolation across -inf produces NaN (from -inf minus -inf). The original log_p series is preserved for bucketing, where x < thresh still classifies -inf rows as Low correctly.
  • Plot simplification: JointGrid (joint scatter + marginal histograms) replaced with a plain scatterplot.
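The clamp-for-quantiles step can be sketched with numpy (a hypothetical helper, not the actual experiments.py code):

```python
import numpy as np

def robust_log_quantile(log_p, q):
    """Quantile of log-densities when some rows are -inf (zero-pmf
    categories). np.quantile interpolating across -inf yields NaN
    (from -inf minus -inf), so clamp -inf to the finite minimum
    for the quantile computation only."""
    log_p = np.asarray(log_p, dtype=float)
    finite_min = log_p[np.isfinite(log_p)].min()
    clamped = np.where(np.isinf(log_p), finite_min, log_p)
    return np.quantile(clamped, q)

log_p = np.array([-5.0, -2.0, -np.inf, -3.0])
thresh = robust_log_quantile(log_p, 0.25)  # -5.0 on this data
is_low = log_p < thresh                    # the -inf row still lands in Low
```

Note that bucketing uses the unclamped series, so the x < thresh comparison correctly classifies -inf rows as Low.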

⚠️ Breaking change — MAJOR semver bump

The return type of outliers(), outliers_pdf(), and outliers_pmf() changes from densities to log-densities. Per the project's semver policy in CONTRIBUTING.md, this is a MAJOR version bump. We recommend the strictly-breaking change because the old return type encoded the bug; consumers can call torch.exp(score) to recover the AM density if they want it.

Public attribute OutlierDetectionUnsupervisedExperiment.p is renamed to log_p, and the DataFrame columns "p" and "p_rank" (in the same class's self.data) are renamed to "log_p" and "log_p_rank". The threshold-test polarity in plot_two() is flipped (x > thresh → x < thresh) — any consumer that was thresholding the old p column directly must invert their comparison in addition to renaming. Same MAJOR-bump concern.

Empirical evidence

Synthetic dataset designed to reproduce overflow (1000 rows, 2 categorical × 9 classes, 5 numerical features as sign(±) * 10^uniform(-15, +15), n_permutations=10, seed=42):

| Metric | Pre-fix | Post-fix |
| --- | --- | --- |
| min(scores) | 9.999e+29 | -2.046e+02 |
| max(scores) | 1.000e+30 | -1.346e+02 |
| unique_values | 2 | 999 |
| frac_at_ceiling (≥ 1e29) | 1.0 | 0.0 |
| flag_rate @ 5th-pct rule | 0.992 | 0.05 |

The post-fix flag_rate=0.05 confirms the 5%-percentile contract is honored. unique_values ≈ N confirms full per-row resolution (no ties from clamp saturation).

Test plan

  • pytest tests/test_unsupervised.py — all 5 pass (2 existing + 3 new)
  • pytest tests/ — full vendor suite passes
  • ruff check src/tabpfn_extensions/unsupervised/ — clean
  • mypy src/tabpfn_extensions/unsupervised/ --python-version=3.10 — clean
  • Empirical pre/post-fix reproduction on the same seed=42 synthetic dataset — see table above

TabPFNUnsupervisedModel.outliers() previously combined per-feature
densities via exp(sum(log(p))) using a 1/pdf trick to keep
intermediates small. On wider datasets this still overflows when
densities are >> 1 (flagged with an existing TODO in the source).

- outliers_single_permutation_() now returns log p directly via
  -criterion.forward(...), dropping the 1/pdf hack and the explicit
  log(pred) / clamp steps.
- outliers() combines permutations via the log-sum-exp identity:
  log(mean(p_k)) = logsumexp(log p_k) - log(K).
- Score semantics: lower scores indicate more likely outliers;
  callable signatures unchanged.
- experiments.py demo adapted to new score semantics: column
  p -> scores, percentile comparisons flipped (< instead of >),
  and quantile computation handles -inf rows.
The values returned by outliers()/outliers_pdf()/outliers_pmf() are
log-densities, not unitless scores. Rename the experiment attribute,
DataFrame columns, and local variables to match what the values are.

- experiments.py: Experiment.scores -> Experiment.log_p; DataFrame
  columns "scores"/"score_rank" -> "log_p"/"log_p_rank"; run() return
  dict key "outlier_scores" -> "log_p".
- unsupervised.py: locals scores_pdf/scores_pmf -> log_pdf/log_pmf
  inside outliers_pdf()/outliers_pmf(); docstrings updated.

Public method names (outliers, outliers_pdf, outliers_pmf) are
unchanged.
@ClementBourt ClementBourt requested a review from a team as a code owner May 12, 2026 10:37
@ClementBourt ClementBourt requested review from psinger-prior and removed request for a team May 12, 2026 10:37
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@CLAassistant

CLAassistant commented May 12, 2026

CLA assistant check
All committers have signed the CLA.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the unsupervised outlier detection logic to operate entirely in log-probability space, utilizing the log-sum-exp trick for numerical stability when averaging densities across permutations. Plotting functions in the experiment module have been updated to reflect these changes, including a shift from JointGrid to standard scatter plots. Review feedback identifies a critical bug in the feature masking logic within outliers_pdf and outliers_pmf that incorrectly indexes dimensions, and notes a breaking API change in the return type of the plot_two method.

Comment thread src/tabpfn_extensions/unsupervised/unsupervised.py
Comment thread src/tabpfn_extensions/unsupervised/experiments.py
@psinger-prior

Hi @ClementBourt - thanks for the PR, could you please merge base branch in before review?

@ClementBourt
Contributor Author

Just merged with base.
My pleasure, learned a lot working on it !

@LennartPurucker LennartPurucker self-requested a review May 13, 2026 14:27
@LennartPurucker
Collaborator

LennartPurucker commented May 13, 2026

Heyho!

Looking into this now, as well, I was assigned to the related issue.

The proposed change seems reasonable to me and is a valid alternative for aggregating outlier detection. While I am not an expert in the field, it should be fine to move to the more stable version of the code (i.e., with the new aggregation). This will be a breaking change, but it is likely for a specific use case in the extensions. So I would be fine to merge as long as the test passes and the example runs fine!

Small note: please make sure to run ruff on the changes to get the linting correct.

@ClementBourt
Contributor Author

Ruff check should pass now.
I am happy to discuss how I got to this workaround specifically, if that is of any interest to you.

@LennartPurucker
Collaborator

Do you have an overview of the differences between the two approaches (e.g., the original vs your proposal)?

What could the implications be for users? In general, all of these ways produced outlier detection, unclear if better or worse, but that might be less important than stability here.

@ClementBourt
Contributor Author

ClementBourt commented May 13, 2026

Core explanation

The approach itself doesn't change from a theoretical standpoint — you still use the chain rule to transform a likelihood estimation into a series of model fittings. This works because TabPFN organically produces calibrated probabilities.

$p(x) = \prod_{i \in \mathcal{C}} \mathrm{pmf}\bigl(x_i \mid x_{<i}\bigr) \cdot \prod_{i \in \mathcal{N}} \mathrm{pdf}\bigl(x_i \mid x_{<i}\bigr)$

$\log p(x) = \sum_{i \in \mathcal{C}} \log \mathrm{pmf}\bigl(x_i \mid x_{<i}\bigr) + \sum_{i \in \mathcal{N}} \log \mathrm{pdf}\bigl(x_i \mid x_{<i}\bigr)$

For stability reasons, different feature permutations are averaged in the density estimation. Let π ∈ S_d denote a permutation of the d features (S_d being the symmetric group), and let p̂_π(x) denote the chain-rule density estimate produced when features are conditioned in the order π. Sampling K permutations π_1, …, π_K uniformly from S_d and averaging:

$\bar{p}(x) = \frac{1}{K} \sum_{k=1}^{K} \hat{p}_{\pi_k}(x)$

The previous implementation computed each individual p̂_π in log space and exponentiated back into linear space. Whenever numerical features were part of the mix, overflow could — and did — happen. The workaround in the code was to flip pdf → 1/pdf. This was accompanied by a TODO mentioning that the best approach would be to stay in log space.

$\log p(x) = \sum_{i \in \mathcal{C}} \log \mathrm{pmf}\bigl(x_i \mid x_{<i}\bigr) + \sum_{i \in \mathcal{N}} \log \frac{1}{\mathrm{pdf}\bigl(x_i \mid x_{<i}\bigr)}$

This resolves the overflow for large pdfs but introduces a new one for very diffuse pdfs (the situation I encountered). This form also comes with another issue: when both the pmf and the pdf are small (both indicating an outlier), their log terms now cancel each other out rather than compounding.
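The cancellation is easy to see with made-up numbers:

```python
import math

# One anomalous row: a categorical feature with tiny pmf and a
# numerical feature with tiny pdf (values are made up).
log_pmf = math.log(1e-6)
log_pdf = math.log(1e-6)

# Staying in log space, the two outlier signals compound:
compounded = log_pmf + log_pdf              # ≈ -27.6, strongly flags the row

# Under the 1/pdf workaround they cancel:
cancelled = log_pmf + math.log(1.0 / 1e-6)  # = log_pmf - log_pdf ≈ 0.0
```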

Long story short from here

Overflow can't be avoided in linear space — we need to stay in log space. The log-sum-exp identity lets us compute log p̄(x) without ever leaving log space, eliminating the overflow risk. This kills two birds with one stone: categorical and numerical features can again be mixed together, since we can restore 1/pdf → pdf.

More details

log-sum-exp trick:

Let ℓ_k := log p̂_{π_k}(x) for the k-th sampled permutation π_k. Then:

$\log \bar{p}(x) = \mathrm{logsumexp}(\ell_1, \ldots, \ell_K) - \log K = \ell_{\max} + \log \sum_{k} \exp\bigl(\ell_k - \ell_{\max}\bigr) - \log K, \quad \ell_{\max} := \max_{k} \ell_k$

The right-hand identity is what makes the computation numerically stable: subtracting ℓmax from every term before exponentiating guarantees the largest input to exp is exactly zero, so no overflow can occur.
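A quick numeric check of the stability claim, as a self-contained illustration:

```python
import math

def logmeanexp(ls):
    """log(mean(exp(l))) via the shifted identity above."""
    m = max(ls)
    return m + math.log(sum(math.exp(l - m) for l in ls)) - math.log(len(ls))

big = [800.0, 801.0]  # exp(800) alone overflows a float64
# The naive form math.log(sum(math.exp(l) for l in big) / 2)
# raises OverflowError; the shifted form never exponentiates
# anything larger than 0:
stable = logmeanexp(big)  # ≈ 800.62
```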

Failed attempt:

I first considered staying in log space by averaging the log-likelihoods log p̂_π directly. But that is equivalent to taking the geometric mean (GM) of the per-permutation densities in linear space, not the arithmetic mean (AM). An ordering of estimated densities under AM is not guaranteed to be preserved under GM, so specific points could be flagged as outliers under one mean and not the other.
Staying with AM seemed more principled, but AM had the overflow issue, and swapping to GM was risky at best — that's where the log-sum-exp trick saves the day.
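A tiny made-up example of the ordering flip between AM and GM:

```python
import math

def am(ps):  # arithmetic mean in linear space
    return sum(ps) / len(ps)

def gm(ps):  # geometric mean = exp(mean of logs)
    return math.exp(sum(math.log(p) for p in ps) / len(ps))

# Two rows, two permutation density estimates each (values made up):
row_a = [0.5, 1e-4]   # one ordering finds row_a very unlikely
row_b = [0.02, 0.02]  # both orderings agree row_b is mildly unlikely

# am(row_a) ≈ 0.25  > am(row_b) = 0.02  -> AM flags row_b as the outlier
# gm(row_a) ≈ 0.007 < gm(row_b) = 0.02  -> GM flags row_a instead
```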

Final nail in the coffin for GM.

Averaging in log-space means we are making the assumption that:

$$\mathbb{E}_{\pi}\bigl[\log \hat{p}_{\pi}(x)\bigr] = \log p_{\mathrm{true}}(x), \quad \forall x$$

This is equivalent to p̂π ≡ ptrue — in words, every chain-rule decomposition of the model agrees with the true joint density.
First corollary: under any imperfect model (pθ ≠ ptrue), the expected value over π must differ from log ptrue on some x in the support of ptrue.
Second corollary: under this assumption it is useless to compute the density for many permutations, as they should all give identical results.

Proof.

$$\mathbb{E}_{x \sim p_{\mathrm{true}}}\left[\mathbb{E}_{\pi}\bigl[\log \hat{p}_{\pi}(x)\bigr]\right] = \mathbb{E}_{x \sim p_{\mathrm{true}}}\left[\log p_{\mathrm{true}}(x)\right] \tag{1}$$

The right-hand side of (1) is the negative entropy of ptrue:

$$\mathbb{E}_{x \sim p_{\mathrm{true}}}\left[\log p_{\mathrm{true}}(x)\right] = -H(p_{\mathrm{true}}) \tag{2}$$

The left-hand side can be written as:

$$\mathbb{E}_{x \sim p_{\mathrm{true}}}\left[\mathbb{E}_{\pi}\bigl[\log \hat{p}_{\pi}(x)\bigr]\right] = \mathbb{E}_{\pi}\left[\mathbb{E}_{x \sim p_{\mathrm{true}}}\left[\log \hat{p}_{\pi}(x)\right]\right] \tag{3}$$

The inner expectation is the negative cross-entropy of ptrue with respect to p̂π:

$$\mathbb{E}_{x \sim p_{\mathrm{true}}}\left[\log \hat{p}_{\pi}(x)\right] = -H(p_{\mathrm{true}}, \hat{p}_{\pi}) \tag{4}$$

Standard cross-entropy decomposition:

$$H(p_{\mathrm{true}}, \hat{p}_{\pi}) = H(p_{\mathrm{true}}) + D_{\mathrm{KL}}\bigl(p_{\mathrm{true}} || \hat{p}_{\pi}\bigr) \tag{5}$$

Substituting (2), (3), (4), (5) into (1):

$$\mathbb{E}_{\pi}\bigl[-H(p_{\mathrm{true}}) - D_{\mathrm{KL}}\bigl(p_{\mathrm{true}} || \hat{p}_{\pi}\bigr)\bigr] = -H(p_{\mathrm{true}})$$

$$-H(p_{\mathrm{true}}) - \mathbb{E}_{\pi}\left[D_{\mathrm{KL}}\bigl(p_{\mathrm{true}} ||\hat{p}_{\pi}\bigr)\right] = -H(p_{\mathrm{true}})$$

$$\mathbb{E}_{\pi}\left[D_{\mathrm{KL}}\bigl(p_{\mathrm{true}} || \hat{p}_{\pi}\bigr)\right] = 0 \tag{6}$$

Since $$D_{\mathrm{KL}}\bigl(p_{\mathrm{true}} || \hat{p}_{\pi}\bigr) \geq 0$$ for every π (non-negativity of KL divergence), (6) forces:

$$D_{\mathrm{KL}}\bigl(p_{\mathrm{true}} || \hat{p}_{\pi}\bigr) = 0 \quad \forall \pi,$$

which holds iff p̂_π ≡ p_true.

@LennartPurucker
Copy link
Copy Markdown
Collaborator

A summary for future reference:

The old implementation converted values back from log space into normal space during averaging, which caused overflow problems for numerical features. A workaround inverted the numerical densities, but that introduced new instability issues and distorted the outlier behavior.

The proposed solution is to keep everything entirely in log space and use the log-sum-exp trick to average permutation estimates safely. This removes overflow problems while restoring the correct probabilistic interpretation and allowing categorical and numerical features to work together properly again.

Which sounds good to me to merge.

Thank you for your contribution and help!

Collaborator

@LennartPurucker LennartPurucker left a comment


LGTM!

@LennartPurucker LennartPurucker enabled auto-merge (squash) May 14, 2026 09:19
@LennartPurucker LennartPurucker merged commit eb751a8 into PriorLabs:main May 14, 2026
7 checks passed