Support differentiable_input on TabPFNRegressor#923
Conversation
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
There was a problem hiding this comment.
Code Review
This pull request introduces support for differentiable inputs in the TabPFNRegressor, enabling gradients to flow from a loss back to upstream torch modules. Key changes include the addition of fit_with_differentiable_input, a specialized initialization path for torch tensors, and logic to bypass standard non-differentiable preprocessing. Review feedback identified several issues: the fit_with_differentiable_input method incorrectly skips normalization and validation on subsequent calls, and the use of .item() and .detach() on normalization parameters breaks the gradient flow for target scaling. Suggestions were also made to use list comprehensions for feature schema initialization to avoid shared references and to improve robustness for single-sample or zero-variance inputs.
There was a problem hiding this comment.
Pull request overview
This PR adds first-class support for differentiable_input=True on TabPFNRegressor by introducing a dedicated fit_with_differentiable_input() pathway (mirroring the existing classifier capability) so gradients can flow from a downstream loss back into upstream torch modules that produce X.
Changes:
- Added a differentiable initialization + fitting path for
TabPFNRegressorusingInferenceEngineCachePreprocessing(inference_mode=False)and a differentiable preprocessor config. - Updated
fit()to raise a more actionableValueErrorwhendifferentiable_input=True, pointing users tofit_with_differentiable_input. - Added new tests verifying gradient flow, helpful error messaging, and categorical-feature rejection for the differentiable path.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
src/tabpfn/regressor.py |
Adds _initialize_for_differentiable_input, fit_with_differentiable_input, and inference-mode gating to keep autograd enabled. |
tests/test_regressor_interface.py |
Adds regression interface tests covering differentiable-input behavior and gradients. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Address gemini-code-assist review on PR PriorLabs#923: the second fit call previously skipped re-normalising y, leaving y_train_mean_, y_train_std_, raw_space_bardist_ stuck on the first fit's stats — silently miscaling predictions when the new target distribution differed. Split _initialize_for_differentiable_input into: - _initialize_for_differentiable_input: first-call-only setup (categorical check, feature schema, ensemble configs). Cached in self.ensemble_configs_. - _refresh_targets_for_differentiable_input: per-call setup (validate_dataset_size, z-normalise y, rebuild raw_space_bardist_, update n_train_samples_). Runs on every fit. fit_with_differentiable_input's else branch now calls the per-call helper so subsequent fits track the current target distribution while still reusing the loaded model and ensemble configs. Add test__fit_with_differentiable_input__second_call_refreshes_target_stats that fits twice with very different y distributions and checks y_train_mean_, y_train_std_, and raw_space_bardist_.borders all move.
Fixes the medium-severity comments raised on the differentiable_input regressor path: 1. Feature instances per column: replace `[Feature(...)] * n_features` with a list comprehension so each column has its own dataclass and a later in-place update on one column does not leak across all columns. 2. y stats numerical robustness: switch `y_float.std()` (PyTorch's default `correction=1`, which differs from `np.std` and returns NaN for N=1) to `clamp(y_float.std(correction=0), min=1e-20)`. This matches the standard `fit()` path's `np.std` semantics and stays finite for single-sample input. 3. Constant-target guard: a constant y collapses the bardist borders to a single point and trips `FullSupportBarDistribution`'s strictly-increasing assertion. `fit()` short-circuits this with `is_constant_target_`; the differentiable path has no analogue, so reject up front with a clear ValueError pointing users at `fit()`. 4. Sequential preprocessing for diff input: force `n_preprocessing_jobs=1` inside `fit_with_differentiable_input`. When X carries an autograd graph, joblib's process-boundary pickling breaks the graph; sequential execution preserves it. The detach-then-`.item()` of `y_train_mean_/std_` is intentional and not changed: `raw_space_bardist_` is a frozen lookup buffer that should not hold a y-grad graph; users wanting fully differentiable target scaling should z-normalise y externally so mean/std become constants here. Documented inline. New tests: - feature_schema_columns_are_independent: catches the alias bug. - std_matches_population_definition: locks in `np.std` semantics. - constant_target_rejected: locks in the explicit guard. - single_sample_y_does_not_nan: confirms N=1 hits the guard cleanly rather than producing NaN deep in the bardist. All 9 differentiable_input tests pass on CPU and CUDA.
|
This seems like a really useful addition for consistency between the classifier and regressor implementations. I also appreciated the clear breakdown of the motivation and changes in the PR description. |
Mirrors the classifier-side prompt-tuning path so gradients can flow from
a downstream loss back through TabPFNRegressor to upstream torch modules
feeding X (and y, when it carries grads). Previously, TabPFNRegressor.fit
raised ValueError("Differentiable input is not supported for regressors
yet.") and there was no fit_with_differentiable_input.
What this changes:
- _initialize_for_differentiable_input(X, y, rng): minimal preprocessing
that uses PreprocessorConfig("none", differentiable=True), z-normalises
y as a torch op (preserves grads), and rebuilds raw_space_bardist_ in
the caller's target scale. Polynomial features are forced to "no" since
the polynomial step relies on sklearn StandardScaler on numpy.
- fit_with_differentiable_input(X, y): mirrors the classifier method;
builds an InferenceEngineCachePreprocessing with inference_mode=False.
- _iter_forward_executor: gates use_inference_mode on differentiable_input
so a user calling forward(X, use_inference_mode=True) after
fit_with_differentiable_input still gets gradients (parallel to the
classifier's existing actual_inference_mode gate).
- fit() now raises a clearer ValueError pointing users to the new method
when differentiable_input=True, instead of silently converting torch
tensors to numpy.
Tests:
- end-to-end gradient-flow test (CPU + CUDA): a loss computed from
forward output produces a finite, non-zero gradient on an upstream
nn.Linear's weight.
- guard tests for fit() with differentiable_input=True and for
categorical features under the differentiable path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Address gemini-code-assist review on PR PriorLabs#923: the second fit call previously skipped re-normalising y, leaving y_train_mean_, y_train_std_, raw_space_bardist_ stuck on the first fit's stats — silently miscaling predictions when the new target distribution differed. Split _initialize_for_differentiable_input into: - _initialize_for_differentiable_input: first-call-only setup (categorical check, feature schema, ensemble configs). Cached in self.ensemble_configs_. - _refresh_targets_for_differentiable_input: per-call setup (validate_dataset_size, z-normalise y, rebuild raw_space_bardist_, update n_train_samples_). Runs on every fit. fit_with_differentiable_input's else branch now calls the per-call helper so subsequent fits track the current target distribution while still reusing the loaded model and ensemble configs. Add test__fit_with_differentiable_input__second_call_refreshes_target_stats that fits twice with very different y distributions and checks y_train_mean_, y_train_std_, and raw_space_bardist_.borders all move.
Fixes the medium-severity comments raised on the differentiable_input regressor path: 1. Feature instances per column: replace `[Feature(...)] * n_features` with a list comprehension so each column has its own dataclass and a later in-place update on one column does not leak across all columns. 2. y stats numerical robustness: switch `y_float.std()` (PyTorch's default `correction=1`, which differs from `np.std` and returns NaN for N=1) to `clamp(y_float.std(correction=0), min=1e-20)`. This matches the standard `fit()` path's `np.std` semantics and stays finite for single-sample input. 3. Constant-target guard: a constant y collapses the bardist borders to a single point and trips `FullSupportBarDistribution`'s strictly-increasing assertion. `fit()` short-circuits this with `is_constant_target_`; the differentiable path has no analogue, so reject up front with a clear ValueError pointing users at `fit()`. 4. Sequential preprocessing for diff input: force `n_preprocessing_jobs=1` inside `fit_with_differentiable_input`. When X carries an autograd graph, joblib's process-boundary pickling breaks the graph; sequential execution preserves it. The detach-then-`.item()` of `y_train_mean_/std_` is intentional and not changed: `raw_space_bardist_` is a frozen lookup buffer that should not hold a y-grad graph; users wanting fully differentiable target scaling should z-normalise y externally so mean/std become constants here. Documented inline. New tests: - feature_schema_columns_are_independent: catches the alias bug. - std_matches_population_definition: locks in `np.std` semantics. - constant_target_rejected: locks in the explicit guard. - single_sample_y_does_not_nan: confirms N=1 hits the guard cleanly rather than producing NaN deep in the bardist. All 9 differentiable_input tests pass on CPU and CUDA.
|
Hi @klemens-floege, just wanted to gently follow up on this PR. Would appreciate a review when you have time! I think this would make the regressor/classifier APIs much more consistent for differentiable workflows. Thanks! |
|
@lujiazho thank you very much for opening a PR into our repository. I will not be able to review the PR before tmw afternoon, I apologize for the delay! |
klemens-floege
left a comment
There was a problem hiding this comment.
Sorry for the delay we just released V3. Thanks for the PR -> the symmetric API is the right shape + allows prompt tuning etc.
Two things before this lands:
Blocking:
- initialize_for_differentiable_input never sets self.n_estimators (uses self.n_estimators without the underscore). forward() / predict() then crash on tqdm(total=self.n_estimators_, …). Locally, test__fit_with_differentiable_input__grad_flows_to_upstream_module fails on both cpu and mps with AttributeError: 'TabPFNRegressor' object has no attribute 'n_estimators_' — contrary to the "9/9 pass" in the description. Fix: add self.n_estimators_ = len(ensemble_configs) in the first-call helper, and mirror classifier.py:946 in the else branch.
A more high level comment, you are PR is adding a lot of code, we should make sure the reduce code duplication as much as possible. Here some starting pointers, but will take a closer look once the tests pass:
- Extract self._rebuild_raw_space_bardist() — same three lines appear in the standard fit path (line ~860) and the new diff path.
- Extract a shared _build_cache_preprocessing_executor(...) — the bottom 25 lines of fit_with_differentiable_input are 90% the standard executor build; deltas are n_preprocessing_jobs and inference_mode.
- Inline _refresh_targets_for_differentiable_input into fit_with_differentiable_input. It has two callers and the split makes the lifecycle harder to follow than it needs to be.
- The categorical-features guard is duplicated verbatim with classifier.py:638 644 — push to a shared helper.
- Consolidate the three "bad input raises ValueError" tests into one parametrized test.
Once the tests pass I'm happy to re-review.
The differentiable-input fit path on TabPFNRegressor never set self.n_estimators_, so forward() / predict() crashed on tqdm(total=...) with AttributeError. Two call sites were missing the assignment: 1. _initialize_for_differentiable_input now sets n_estimators_ via scale_n_estimators_for_feature_coverage, mirroring classifier.py:650. 2. fit_with_differentiable_input's else branch (subsequent fits) now re-asserts n_estimators_ from cached ensemble configs, mirroring classifier.py:948. The stale assert len(...) == self.n_estimators (missing underscore) is fixed at the same time.
Per klemens-floege review on PR PriorLabs#923. No behaviour change — same differentiable-input semantics, just less code duplication. - Share the categorical-features guard. New reject_categoricals_for_differentiable_input() in base.py replaces the identical inline checks in TabPFNClassifier and TabPFNRegressor. - Extract _rebuild_raw_space_bardist() on TabPFNRegressor. The same three-line construction (borders * std + mean as a FullSupportBarDistribution) appears in the standard fit path and the differentiable path; the helper detaches borders unconditionally so the buffer never holds a y autograd graph (no-op for the standard path). - Extract _build_ensemble_preprocessor_and_executor() on TabPFNRegressor. The two paths' executor-build blocks now share one method; deltas are only n_preprocessing_jobs (1 in the differentiable path so the autograd graph survives joblib's process-boundary pickling) and inference_mode (False under differentiable input). - Inline _refresh_targets_for_differentiable_input back into fit_with_differentiable_input. Lifecycle is clearer with the y-target validation, normalisation, and bardist rebuild laid out linearly after the first-call / cached-state branch. - Consolidate three bad-input ValueError tests into one pytest.parametrize covering categorical_features, constant_target, and single_sample cases.
|
@klemens-floege Thanks a lot for the detailed review! I’ve addressed the blocking issue by setting self.n_estimators_ properly in the differentiable-input initialization path and mirrored the classifier logic. I also tried reducing duplication based on your suggestions. The fixes/refactors are included in the latest commits. I’d really appreciate another look when you have time! |
Issue
Closes #922 and #702
Motivation and Context
TabPFNClassifieralready supportsdifferentiable_input=Trueviafit_with_differentiable_input, allowing a downstream loss to backpropthrough the model into upstream torch modules.
TabPFNRegressorexposesthe same
differentiable_inputconstructor argument butfit()raisesValueError("Differentiable input is not supported for regressors yet."),and there is no
fit_with_differentiable_inputcounterpart. Therefore, theregressor cannot be used for prompt tuning, ICL adapter training, or any
setting where gradients must flow through the regression head.
This PR mirrors the classifier-side path on the regressor so the
two estimators have symmetric APIs.
What changes:
_initialize_for_differentiable_input(X, y, rng): minimal,differentiable preprocessing using
PreprocessorConfig("none", differentiable=True). z-normalisesyas a torch op so grads survive,rebuilds
raw_space_bardist_in the caller's target scale, and forcespolynomial_features="no"since the polynomial step relies on anumpy-only sklearn
StandardScaler(the regressor's runtime configdefaults to a non-zero value, so this had to be explicit).
fit_with_differentiable_input(X, y): parallel toTabPFNClassifier.fit_with_differentiable_input. Builds anInferenceEngineCachePreprocessingwithinference_mode=False._iter_forward_executor: gatesuse_inference_modeondifferentiable_input(parallel to classifier line 1459) so a usercalling
forward(X, use_inference_mode=True)afterfit_with_differentiable_inputstill gets gradients.fit(): now raises a clearerValueErrorpointing atfit_with_differentiable_inputwhendifferentiable_input=True,instead of silently failing once the numpy-only path hits a torch
tensor.
Public API Changes
Adds one new public method on
TabPFNRegressor:fit_with_differentiable_input(X: torch.Tensor, y: torch.Tensor) -> SelfTightens one error message:
TabPFNRegressor.fit(...)withdifferentiable_input=Truenow raises aValueErrorwhose messagepoints users at
fit_with_differentiable_input(instead of the previousgeneric "not supported" message). Same exception type, more actionable
text.
No breaking changes to existing call sites.
How Has This Been Tested?
New tests in
tests/test_regressor_interface.py:test__fit_with_differentiable_input__grad_flows_to_upstream_module[cpu|cuda]end-to-end:
nn.Linear → TabPFNRegressor → MSE loss → backward()produces a finite, non-zero gradient on the upstream
Linear'sweight. Runs on both CPU and CUDA.
test__fit__differentiable_input_true__raises_helpful_errorguard that calling
.fit()withdifferentiable_input=Trueraisesa
ValueErrorreferencingfit_with_differentiable_input.test__fit_with_differentiable_input__categorical_features_rejectedguard that the differentiable path rejects categorical features.
test__fit_with_differentiable_input__second_call_refreshes_target_statsfits twice with very different y distributions and asserts
y_train_mean_,y_train_std_, andraw_space_bardist_.bordersallmove on the second fit. Added in response to the gemini-code-assist
review pointing out that the original
elsebranch reused staletarget stats.
Full local results:
tests/test_regressor_interface.py: 155 passed, 1 pre-existingfailure (
test_onnx_exportable_cpu, which fails identically onunmodified
main, unrelated to this change).tests/test_classifier_interface.py+test_finetuning_regressor.py+test_finetuning_classifier.py: 246 passed, 1 pre-existing ONNXfailure. No problems caused by this PR.
Checklist
changelog/README.md), or "no changelog needed" label requested.