Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions examples/interpretability/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,28 @@
X, y = data.data, data.target
feature_names = data.feature_names

# Initialize model
clf = TabPFNClassifier(n_estimators=3)
# Initialize model. Single estimator keeps the runtime manageable — feature
# selection runs many TabPFN fits per round.
clf = TabPFNClassifier(n_estimators=1)

# Feature selection
sfs = interpretability.feature_selection.feature_selection(
# Feature selection. With verbose=True (the default) the wrapper prints the
# baseline CV score on all features, the per-round picks, and the selected
# names + CV score on the subset. The same numbers are also available on
# the returned FeatureSelectionResult for programmatic use.
result = interpretability.feature_selection.feature_selection(
estimator=clf,
X=X,
y=y,
n_features_to_select=5, # How many features to select
feature_names=feature_names,
n_features_to_select=4,
feature_names=list(feature_names),
)

# Print selected features
selected_features = [
feature_names[i] for i in range(len(feature_names)) if sfs.get_support()[i]
]
print("\nSelected features:")
for feature in selected_features:
print(f"- {feature}")
# `result.selected_names` is populated because we passed `feature_names`.
# `result.selector.transform(X)` would project to just those columns;
# `result.support_mask` / `result.selected_indices` are also available.
print("\nProgrammatic summary:")
print(f"Selected features: {result.selected_names}")
print(
f"CV score before / after: "
f"{result.baseline_score_mean:.4f} -> {result.selected_score_mean:.4f}"
)
10 changes: 9 additions & 1 deletion src/tabpfn_extensions/interpretability/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
try:
from . import feature_selection, pdp, shap, shapiq
from .feature_selection import FeatureSelectionResult
from .shap import shapiq_to_shap_explanation
except ImportError:
raise ImportError(
"Please install tabpfn-extensions with the 'interpretability' extra: pip install 'tabpfn-extensions[interpretability]'",
)
__all__ = ["feature_selection", "shapiq", "pdp", "shap", "shapiq_to_shap_explanation"]
__all__ = [
"feature_selection",
"shapiq",
"pdp",
"shap",
"FeatureSelectionResult",
"shapiq_to_shap_explanation",
]
Loading
Loading