diff --git a/examples/interpretability/shap_example.py b/examples/interpretability/shap_example.py
index 8972917f..2b7c382d 100644
--- a/examples/interpretability/shap_example.py
+++ b/examples/interpretability/shap_example.py
@@ -21,19 +21,25 @@
 
 from __future__ import annotations
 
-import numpy as np
 import shap
 from sklearn.datasets import fetch_california_housing
 from sklearn.model_selection import train_test_split
 
 from tabpfn_extensions import TabPFNRegressor
-from tabpfn_extensions.interpretability import shapiq as tabpfn_shapiq
+from tabpfn_extensions.interpretability import (
+    shapiq as tabpfn_shapiq,
+    shapiq_to_shap_explanation,
+)
 
 housing = fetch_california_housing(as_frame=False)
 X, y, feature_names = housing.data, housing.target, list(housing.feature_names)
 
 X_train, X_test, y_train, _ = train_test_split(
-    X, y, train_size=1000, test_size=200, random_state=0,
+    X,
+    y,
+    train_size=1000,
+    test_size=200,
+    random_state=0,
 )
 n_explain = 30
 X_explain = X_test[:n_explain]
@@ -50,24 +56,16 @@
     max_order=1,
 )
 
-# Compute Shapley values for n_explain rows. Each call produces an
-# `InteractionValues` object; we extract the (d,) 1st-order array per row and
-# stack into the (n, d) matrix that shap.Explanation expects.
+# Compute first-order Shapley values for n_explain rows and wrap them in a
+# shap.Explanation. `shapiq_to_shap_explanation` runs one .explain() call per
+# row, stacks the (d,) arrays, keeps the per-row baseline values, and packages
+# everything for the SHAP plotting API. budget=256 (= 2**8) is the exact-Shapley
+# budget for d=8 features.
 print(f"Computing Shapley values for {n_explain} rows...")
-ivs = [explainer.explain(x=X_explain[i], budget=256) for i in range(n_explain)]
-shap_values = np.stack([iv.get_n_order_values(1) for iv in ivs])
-
-# baseline_value is the model's expected output when *every* feature is masked
-# — i.e. the prediction on the empty coalition. We average across rows to get
-# the scalar E[f(X)] that shap.Explanation wants for base_values.
-base_value = float(np.mean([iv.baseline_value for iv in ivs]))
-
-# Wrap shapiq's output in a shap.Explanation so the full shap.plots.* family
-# accepts it directly.
-explanation = shap.Explanation(
-    values=shap_values,
-    base_values=np.full(n_explain, base_value),
-    data=X_explain,
+explanation = shapiq_to_shap_explanation(
+    explainer,
+    X_explain,
+    budget=256,
     feature_names=feature_names,
 )
 
diff --git a/src/tabpfn_extensions/interpretability/README.md b/src/tabpfn_extensions/interpretability/README.md
index efcb9fa6..559c3ba4 100644
--- a/src/tabpfn_extensions/interpretability/README.md
+++ b/src/tabpfn_extensions/interpretability/README.md
@@ -37,11 +37,32 @@ We expose two adapters:
 
 The wrapper warns at construction time if the model isn't configured this way.
 
-For SHAP-style plots (waterfall, beeswarm, summary, dependence) you can use shapiq's
-own visualizations on the returned ``InteractionValues`` object (``iv.plot_force()``,
-``iv.plot_waterfall()``, ``iv.plot_network()``, ``iv.plot_si_graph()``, etc.), or convert
-the values to ``shap.Explanation`` and use ``shap.plots.*``. See
-``examples/interpretability/shapiq_example.py`` and ``shap_example.py`` for both.
+For SHAP-style plots (waterfall, beeswarm, summary, dependence) you have two options:
+
+1. **Use shapiq's own visualizations** on the returned ``InteractionValues`` object:
+   ``iv.plot_force()``, ``iv.plot_waterfall()``, ``iv.plot_network()``,
+   ``iv.plot_si_graph()``, etc.
+
+2. **Use the SHAP library's plotting** via the bridge helper. Run shapiq's
+   ``.explain()`` over a batch of rows and wrap the result in a
+   ``shap.Explanation`` in one call:
+
+   ```python
+   from tabpfn_extensions.interpretability import shapiq_to_shap_explanation
+
+   explanation = shapiq_to_shap_explanation(
+       explainer, X_explain, budget=256, feature_names=feature_names,
+   )
+   shap.plots.waterfall(explanation[0])
+   ```
+
+   ``shapiq_to_shap_explanation`` extracts first-order Shapley values from
+   shapiq's output and wraps them in a ``shap.Explanation``. Requires
+   ``pip install shap`` — kept out of the ``interpretability`` extra by
+   design (shapiq is the runtime dependency; shap is opt-in for plotting).
+
+See ``examples/interpretability/shapiq_example.py`` and ``shap_example.py``
+for both paths.
 
 The ``shapiq`` library and the paper introducing the improved Shapley value
 computation for TabPFN can be cited as follows:
diff --git a/src/tabpfn_extensions/interpretability/__init__.py b/src/tabpfn_extensions/interpretability/__init__.py
index 77aa1c39..80591ab3 100644
--- a/src/tabpfn_extensions/interpretability/__init__.py
+++ b/src/tabpfn_extensions/interpretability/__init__.py
@@ -1,7 +1,8 @@
 try:
-    from . import feature_selection, pdp, shapiq
+    from . import feature_selection, pdp, shap, shapiq
+    from .shap import shapiq_to_shap_explanation
 except ImportError:
     raise ImportError(
         "Please install tabpfn-extensions with the 'interpretability' extra: pip install 'tabpfn-extensions[interpretability]'",
     )
-__all__ = ["feature_selection", "shapiq", "pdp"]
+__all__ = ["feature_selection", "shapiq", "pdp", "shap", "shapiq_to_shap_explanation"]
diff --git a/src/tabpfn_extensions/interpretability/shap.py b/src/tabpfn_extensions/interpretability/shap.py
new file mode 100644
index 00000000..2aea70a7
--- /dev/null
+++ b/src/tabpfn_extensions/interpretability/shap.py
@@ -0,0 +1,80 @@
+# Copyright (c) Prior Labs GmbH 2025.
+# Licensed under the Apache License, Version 2.0
+
+"""Bridge helpers for using the SHAP library's plotting ecosystem with
+Shapley values computed by shapiq.
+
+We use shapiq for the actual Shapley-value computation — it's faster and
+extension-friendly for TabPFN — but the SHAP library's plotting ecosystem
+(``shap.plots.waterfall``, ``beeswarm``, ``summary``, ``dependence``, etc.)
+is mature and widely used. This module bridges the two.
+
+The ``shap`` package is **not** part of the ``interpretability`` extra. Install
+it separately (``pip install shap``) if you want to use these helpers.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import shap
+    from numpy.typing import ArrayLike
+
+
+def shapiq_to_shap_explanation(
+    explainer: Any,
+    X: ArrayLike,
+    *,
+    budget: int,
+    feature_names: list[str] | None = None,
+) -> shap.Explanation:
+    """Compute first-order Shapley values with a shapiq explainer for each
+    row in ``X`` and wrap them in a ``shap.Explanation`` ready for use with
+    ``shap.plots.*`` and ``shap.summary_plot``.
+
+    Mirrors the pattern in ``examples/interpretability/shap_example.py``:
+    one ``.explain(...)`` call per row, stack the first-order arrays into an
+    ``(n, d)`` matrix, collect the per-row baseline values, and pass
+    everything to ``shap.Explanation``.
+
+    Args:
+        explainer: A shapiq explainer — e.g. one returned by
+            ``get_tabpfn_imputation_explainer(..., index="SV", max_order=1)``.
+        X: ``(n, d)`` array of rows to explain.
+        budget: Number of model evaluations shapiq is allowed per row. For
+            small ``d`` and exact Shapley values, pass ``2**d``.
+        feature_names: Optional list of feature name strings (length ``d``).
+            Used by ``shap.plots.*`` for axis labels.
+
+    Returns:
+        A ``shap.Explanation`` with ``values.shape == (n, d)``.
+
+    Notes:
+        Only first-order Shapley values are wrapped. ``shap.Explanation``
+        doesn't represent higher-order interactions; for those, use
+        shapiq's native plots on the ``InteractionValues`` object.
+
+        Requires ``shap`` to be installed (``pip install shap``). It is
+        kept out of the ``interpretability`` extra by design — shapiq is
+        the runtime dependency, shap is opt-in for plotting.
+    """
+    import shap
+
+    X_arr = np.asarray(X)
+    n = len(X_arr)
+    ivs = [explainer.explain(x=X_arr[i], budget=budget) for i in range(n)]
+    values = np.stack([iv.get_n_order_values(1) for iv in ivs])
+    # Pass per-row baselines through unchanged. For the imputation path the
+    # background is fixed and these are all the same value; for the Rundel
+    # remove-and-recontextualize path baselines genuinely vary per row.
+    # shap.Explanation accepts a 1-d (n,) array natively.
+    base_values = np.array([iv.baseline_value for iv in ivs])
+    return shap.Explanation(
+        values=values,
+        base_values=base_values,
+        data=X_arr,
+        feature_names=feature_names,
+    )
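
Below, for quick reference, is a minimal end-to-end sketch of the new bridge path, stitched together from the hunks above. It is a sketch under assumptions, not part of the diff: the model setup is simplified, and the explainer construction uses ``get_tabpfn_imputation_explainer``, which the ``shap.py`` docstring mentions only in passing, so its exact signature (guessed here as estimator plus background data) may differ from the real API. The rest mirrors ``shap_example.py``.

```python
# Usage sketch for the shapiq -> shap.Explanation bridge introduced above.
# Assumption: get_tabpfn_imputation_explainer(model, X_train, ...) is a guess
# at the constructor's signature; only index="SV" and max_order=1 come from
# the shap.py docstring.
import shap
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from tabpfn_extensions import TabPFNRegressor
from tabpfn_extensions.interpretability import (
    shapiq as tabpfn_shapiq,
    shapiq_to_shap_explanation,
)

housing = fetch_california_housing(as_frame=False)
X, y, feature_names = housing.data, housing.target, list(housing.feature_names)
X_train, X_test, y_train, _ = train_test_split(
    X, y, train_size=1000, test_size=200, random_state=0
)

model = TabPFNRegressor()
model.fit(X_train, y_train)

# max_order=1 because the bridge only wraps first-order Shapley values.
explainer = tabpfn_shapiq.get_tabpfn_imputation_explainer(
    model, X_train, index="SV", max_order=1
)

# budget=256 == 2**8 model evaluations per row: exact Shapley values for the
# 8 California-housing features.
explanation = shapiq_to_shap_explanation(
    explainer, X_test[:30], budget=256, feature_names=feature_names
)
shap.plots.waterfall(explanation[0])  # one row's attributions
shap.plots.beeswarm(explanation)      # all 30 explained rows at once
```

Note that ``shap`` must be installed separately (``pip install shap``); the diff deliberately keeps it out of the ``interpretability`` extra, with shapiq as the only runtime dependency.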