diff --git a/CHANGELOG.md b/CHANGELOG.md
index 718f4f97..7f1589ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+## [0.4.0] - 2026-05-11
+
+### Added
+- KV-cache support across extensions where it materially helps (improved with TabPFN-3 — earlier `tabpfn` versions degrade gracefully with a `UserWarning`):
+  - New `tabpfn_extensions.utils.warn_if_no_kv_cache(model, *, context)` helper that warns when a TabPFN model isn't configured with `fit_mode="fit_with_cache"` or when `executor_.keep_cache_on_device` is `False` post-fit. Recognises the `tabpfn-client` backend and emits an alternative recommendation instead of misleading guidance (#284).
+  - `pval_crt.tabpfn_crt`: new `use_kv_cache: bool = True` keyword argument. Enabling it sets `fit_mode="fit_with_cache"` on the internal predictive model and `keep_cache_on_device=True` after fit, with a graceful fallback when the installed `tabpfn` doesn't support the cache (#284).
+  - `interpretability/pdp.partial_dependence_plots`: warns via `warn_if_no_kv_cache` when handed a TabPFN estimator (#284).
+  - `interpretability/get_tabpfn_imputation_explainer`: same warn helper applied; defaults `imputer="baseline"` (one forward pass per coalition, ~50× faster than the previous `marginal` imputer with no faithfulness loss in practice) (#283).
+- `examples/interpretability/{shap,shapiq}_example.py` rewritten around California housing (regression, d=8, exact budget 2^8 = 256) and both engage the KV cache (#283).
+
+### Changed
+- `ManyClassClassifier(estimator=TabPFNClassifier(), ...).fit(X, y)` no longer requires an explicit `alphabet_size`. The wrapper now reads `MAX_NUMBER_OF_CLASSES` via the base estimator's `get_inference_config()` when available (v3 → 160). For older `tabpfn` releases that don't have `get_inference_config()`, it falls back to the historical hardcoded default of 10 if the base estimator's class lives in a `tabpfn`-prefixed module; non-TabPFN estimators still get the explicit-alphabet `ValueError`. The previous `estimator.max_num_classes_` fallback has been dead since this repo's initial commit — removed (#282).
+- `README.md` overhauled: dropped the workflow mermaid graph, removed RF-PFN / large-datasets references, retagged deprecated extensions, and reworded the many-class entry so it no longer hardcodes "10 classes" (#284).
+- `examples/README.md` refreshed to match the on-disk tree; `hpo/` and `phe/` example directories explicitly tagged as deprecated (#284).
+- `interpretability/README.md` updated to drop the SHAP section and document the KV-cache + baseline-imputer expectations (#283).
+
+### Deprecated
+- `AutoTabPFNClassifier` and `AutoTabPFNRegressor` (`post_hoc_ensembles/sklearn_interface.py`). Construction now emits a `DeprecationWarning`; a banner comment marks the module deprecated. Scheduled for removal in a future release (#284).
+- `TunedTabPFNClassifier` and `TunedTabPFNRegressor` (`hpo/tuned_tabpfn.py`). Same treatment (#284).
+
+### Removed
+- The `rf_pfn` package (`RandomForestTabPFN*` / `DecisionTreeTabPFN*`), its tests (`tests/test_rf_pfn.py`, `tests/test_dt_pfn.py`), the `rf_pfn = []` optional-dependency extra, and the `examples/rf_pfn/` and `examples/large_datasets/` example directories (#284).
+- The `sklearn_ensembles` package (`tabpfn_extensions.sklearn_ensembles`) and its module-level imports. No external callers in the repo (#284).
+- The legacy `interpretability/shap.py` adapter and the `shap` runtime dependency. The `shapiq`-based path supersedes it, including for plotting (`shap.Explanation` wrappers still work with `shap.plots.*`) (#283).
+- `interpretability/experiments.py` — dead code. No callers anywhere in the repo, and internally referenced a `tabpfn.scripts.estimator.interpretability` path that no longer exists in current `tabpfn` (#284).
+- The `dt_pfn` `model_type` branch from `hpo/search_space.py` / `hpo/tuned_tabpfn.py` and the matching test arm; depended on the removed `rf_pfn` package (#284).
+
+### Fixed
+- KV-cache wiring degrades gracefully when the installed `tabpfn` doesn't support it:
+  - The PDP example (`examples/interpretability/pdp_example.py`) wraps construct + fit in a `try/except (TypeError, ValueError, NotImplementedError)`, so it works on both older local `tabpfn` (which raises `ValueError` / `NotImplementedError` at fit time) and the `tabpfn-client` backend (which raises `TypeError` on the constructor kwarg). The fallback emits a `UserWarning` recommending an upgrade (#284).
+  - `pval_crt.tabpfn_crt` is local-only by design (it raises `ImportError` at import time if `tabpfn` isn't installed). Its fallback catches `ValueError` / `NotImplementedError` from older local `tabpfn`; it doesn't need to catch `TypeError` because the client backend can't reach this code path (#284).
+
+### Notes
+- `tabpfn` dependency pin remains `>=7.0.0` for this release. KV-cache examples need TabPFN-3 to actually engage the cache; on earlier `tabpfn` they emit a `UserWarning` and run without the speedup.
+- TabEBM remains broken against any released `tabpfn` ≥ 7.x because `tabpfn.config` was removed upstream. This is a pre-existing latent bug, tracked separately ([RES-1541](https://linear.app/priorlabs/issue/RES-1541)).
+
 ## [0.3.0] - 2026-04-24
 
 ### Added
diff --git a/pyproject.toml b/pyproject.toml
index 36b4a6cb..3cf8453b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "tabpfn-extensions"
-version = "0.3.0"
+version = "0.4.0"
 dependencies = [
     "torch>=2.1",
     "pandas>=2.2.2",
diff --git a/src/tabpfn_extensions/utils.py b/src/tabpfn_extensions/utils.py
index 66f46875..f637fd4a 100644
--- a/src/tabpfn_extensions/utils.py
+++ b/src/tabpfn_extensions/utils.py
@@ -57,8 +57,11 @@ def warn_if_no_kv_cache(model: Any, *, context: str = "This operation") -> None:
     # tabpfn-client doesn't expose fit_mode and doesn't (yet) support the KV
     # cache — recommend switching to the local tabpfn package instead of
     # firing the generic "set fit_mode='fit_with_cache'" message that would
-    # TypeError on the client.
-    if "tabpfn_client" in str(getattr(model, "__class__", type(model)).__module__):
+    # TypeError on the client. Walk the MRO because tabpfn-extensions wraps
+    # the client base classes in tabpfn_extensions.utils, so the *immediate*
+    # class' __module__ is "tabpfn_extensions.utils" and won't match.
+    mro_modules = (getattr(cls, "__module__", "") for cls in type(model).__mro__)
+    if any("tabpfn_client" in m for m in mro_modules):
         warnings.warn(
             f"{context} would benefit substantially from the KV cache, but "
             "the tabpfn-client backend does not currently support it. "