examples/interpretability/shap_example.py (38 changes: 18 additions & 20 deletions)

@@ -21,19 +21,25 @@

from __future__ import annotations

-import numpy as np
import shap
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from tabpfn_extensions import TabPFNRegressor
-from tabpfn_extensions.interpretability import shapiq as tabpfn_shapiq
+from tabpfn_extensions.interpretability import (
+    shapiq as tabpfn_shapiq,
+    shapiq_to_shap_explanation,
+)

housing = fetch_california_housing(as_frame=False)
X, y, feature_names = housing.data, housing.target, list(housing.feature_names)

X_train, X_test, y_train, _ = train_test_split(
-    X, y, train_size=1000, test_size=200, random_state=0,
+    X,
+    y,
+    train_size=1000,
+    test_size=200,
+    random_state=0,
)
n_explain = 30
X_explain = X_test[:n_explain]
@@ -50,24 +56,16 @@
    max_order=1,
)

-# Compute Shapley values for n_explain rows. Each call produces an
-# `InteractionValues` object; we extract the (d,) 1st-order array per row and
-# stack into the (n, d) matrix that shap.Explanation expects.
+# Compute first-order Shapley values for n_explain rows and wrap them in a
+# shap.Explanation. `shapiq_to_shap_explanation` runs one .explain() call per
+# row, stacks the (d,) arrays, collects the per-row baselines, and packages
+# everything for the SHAP plotting API. budget=256 = 2**8 is the exact-Shapley
+# budget for d=8 features.
print(f"Computing Shapley values for {n_explain} rows...")
-ivs = [explainer.explain(x=X_explain[i], budget=256) for i in range(n_explain)]
-shap_values = np.stack([iv.get_n_order_values(1) for iv in ivs])
-
-# baseline_value is the model's expected output when *every* feature is masked
-# — i.e. the prediction on the empty coalition. We average across rows to get
-# the scalar E[f(X)] that shap.Explanation wants for base_values.
-base_value = float(np.mean([iv.baseline_value for iv in ivs]))
-
-# Wrap shapiq's output in a shap.Explanation so the full shap.plots.* family
-# accepts it directly.
-explanation = shap.Explanation(
-    values=shap_values,
-    base_values=np.full(n_explain, base_value),
-    data=X_explain,
+explanation = shapiq_to_shap_explanation(
+    explainer,
+    X_explain,
+    budget=256,
    feature_names=feature_names,
)

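The rest of the example (the plotting calls) is unchanged and collapsed in this view. For reference, a minimal sketch of how the returned `explanation` feeds SHAP's plotting family; this assumes `shap` is installed, and relies only on the fact that `shap.plots.waterfall` and `shap.plots.beeswarm` accept `shap.Explanation` objects:

```python
import shap

# Single-row view: waterfall of the first explained row's Shapley values.
shap.plots.waterfall(explanation[0])

# Batch view: beeswarm summary across all n_explain explained rows.
shap.plots.beeswarm(explanation)
```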
src/tabpfn_extensions/interpretability/README.md (31 changes: 26 additions & 5 deletions)

@@ -37,11 +37,32 @@ We expose two adapters:

The wrapper warns at construction time if the model isn't configured this way.

-For SHAP-style plots (waterfall, beeswarm, summary, dependence) you can use shapiq's
-own visualizations on the returned ``InteractionValues`` object (``iv.plot_force()``,
-``iv.plot_waterfall()``, ``iv.plot_network()``, ``iv.plot_si_graph()``, etc.), or convert
-the values to ``shap.Explanation`` and use ``shap.plots.*``. See
-``examples/interpretability/shapiq_example.py`` and ``shap_example.py`` for both.
+For SHAP-style plots (waterfall, beeswarm, summary, dependence) you have two options:
+
+1. **Use shapiq's own visualizations** on the returned ``InteractionValues`` object:
+   ``iv.plot_force()``, ``iv.plot_waterfall()``, ``iv.plot_network()``,
+   ``iv.plot_si_graph()``, etc.
+
+2. **Use the SHAP library's plotting** via the bridge helper. Run shapiq's
+   ``.explain()`` over a batch of rows and wrap the result in a
+   ``shap.Explanation`` in one call:
+
+   ```python
+   from tabpfn_extensions.interpretability import shapiq_to_shap_explanation
+
+   explanation = shapiq_to_shap_explanation(
+       explainer, X_explain, budget=256, feature_names=feature_names,
+   )
+   shap.plots.waterfall(explanation[0])
+   ```
+
+``shapiq_to_shap_explanation`` extracts first-order Shapley values from
+shapiq's output and wraps them in a ``shap.Explanation``. Requires
+``pip install shap`` — kept out of the ``interpretability`` extra by
+design (shapiq is the runtime dependency; shap is opt-in for plotting).
+
+See ``examples/interpretability/shapiq_example.py`` and ``shap_example.py``
+for both paths.

The ``shapiq`` library and the paper introducing the improved Shapley value computation
for TabPFN can be cited as follows:
src/tabpfn_extensions/interpretability/__init__.py (5 changes: 3 additions & 2 deletions)

@@ -1,7 +1,8 @@
try:
-    from . import feature_selection, pdp, shapiq
+    from . import feature_selection, pdp, shap, shapiq
+    from .shap import shapiq_to_shap_explanation
except ImportError:
    raise ImportError(
        "Please install tabpfn-extensions with the 'interpretability' extra: pip install 'tabpfn-extensions[interpretability]'",
    )
-__all__ = ["feature_selection", "shapiq", "pdp"]
+__all__ = ["feature_selection", "shapiq", "pdp", "shap", "shapiq_to_shap_explanation"]
src/tabpfn_extensions/interpretability/shap.py (80 changes: 80 additions & 0 deletions, new file)

@@ -0,0 +1,80 @@
# Copyright (c) Prior Labs GmbH 2025.
# Licensed under the Apache License, Version 2.0

"""Bridge helpers for using the SHAP library's plotting ecosystem with
Shapley values computed by shapiq.

We use shapiq for the actual Shapley-value computation — it's faster and
extension-friendly for TabPFN — but the SHAP library's plotting ecosystem
(``shap.plots.waterfall``, ``beeswarm``, ``summary``, ``dependence``, etc.)
is mature and widely used. This module bridges the two.

The ``shap`` package is **not** part of the ``interpretability`` extra. Install
it separately (``pip install shap``) if you want to use these helpers.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    import shap
    from numpy.typing import ArrayLike


def shapiq_to_shap_explanation(
    explainer: Any,
    X: ArrayLike,
    *,
    budget: int,
    feature_names: list[str] | None = None,
) -> shap.Explanation:
"""Compute first-order Shapley values with a shapiq explainer for each
row in ``X`` and wrap them in a ``shap.Explanation`` ready for use with
``shap.plots.*`` and ``shap.summary_plot``.

Mirrors the pattern in ``examples/interpretability/shap_example.py``:
one ``.explain(...)`` call per row, stack the first-order arrays into an
``(n, d)`` matrix, average baseline values, and pass everything to
``shap.Explanation``.

Args:
explainer: A shapiq explainer — e.g. one returned by
``get_tabpfn_imputation_explainer(..., index="SV", max_order=1)``.
X: ``(n, d)`` array of rows to explain.
budget: Number of model evaluations shapiq is allowed per row. For
small ``d`` and exact Shapley values, pass ``2**d``.
feature_names: Optional list of feature name strings (length ``d``).
Used by ``shap.plots.*`` for axis labels.

Returns:
A ``shap.Explanation`` with ``values.shape == (n, d)``.

Notes:
Only first-order Shapley values are wrapped. ``shap.Explanation``
doesn't represent higher-order interactions; for those, use
shapiq's native plots on the ``InteractionValues`` object.

Requires ``shap`` to be installed (``pip install shap``). It is
kept out of the ``interpretability`` extra by design — shapiq is
the runtime dependency, shap is opt-in for plotting.
"""
    import shap

    X_arr = np.asarray(X)
    n = len(X_arr)
    ivs = [explainer.explain(x=X_arr[i], budget=budget) for i in range(n)]
    values = np.stack([iv.get_n_order_values(1) for iv in ivs])
    # Pass per-row baselines through unchanged. For the imputation path the
    # background is fixed and these are all the same value; for the Rundel
    # remove-and-recontextualize path baselines genuinely vary per row.
    # shap.Explanation accepts a 1-d (n,) array natively.
    base_values = np.array([iv.baseline_value for iv in ivs])
    return shap.Explanation(
        values=values,
        base_values=base_values,
        data=X_arr,
        feature_names=feature_names,
    )
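Putting the pieces together, a minimal end-to-end sketch of the intended workflow. The data loading, `budget=256`, and helper call mirror the example diff above; the constructor arguments of `get_tabpfn_imputation_explainer` are an assumption (the docstring confirms only `index="SV"` and `max_order=1`), as are the default `TabPFNRegressor()` settings:

```python
import shap
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from tabpfn_extensions import TabPFNRegressor
from tabpfn_extensions.interpretability import (
    shapiq as tabpfn_shapiq,
    shapiq_to_shap_explanation,
)

# Same data setup as examples/interpretability/shap_example.py.
housing = fetch_california_housing(as_frame=False)
X, y, feature_names = housing.data, housing.target, list(housing.feature_names)
X_train, X_test, y_train, _ = train_test_split(
    X, y, train_size=1000, test_size=200, random_state=0,
)

model = TabPFNRegressor()  # assumption: default settings suffice here
model.fit(X_train, y_train)

# Assumed signature (model plus background data as the first arguments); the
# docstring confirms only index="SV" and max_order=1 for first-order values.
explainer = tabpfn_shapiq.get_tabpfn_imputation_explainer(
    model,
    X_train,
    index="SV",
    max_order=1,
)

explanation = shapiq_to_shap_explanation(
    explainer,
    X_test[:30],
    budget=256,  # 2**8: exact Shapley for the 8 California-housing features
    feature_names=feature_names,
)
shap.plots.waterfall(explanation[0])
```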