diff --git a/changelog/923.fixed.md b/changelog/923.fixed.md new file mode 100644 index 000000000..09d861805 --- /dev/null +++ b/changelog/923.fixed.md @@ -0,0 +1 @@ +Add `TabPFNRegressor.fit_with_differentiable_input(X, y)` so gradients can flow from a downstream loss back through the regressor into upstream torch modules feeding `X` (and `y`, when it carries grads). Mirrors the existing classifier-side path — previously `TabPFNRegressor.fit` raised `ValueError("Differentiable input is not supported for regressors yet.")` and there was no differentiable counterpart. diff --git a/src/tabpfn/base.py b/src/tabpfn/base.py index 01cd6ef32..fc34f1cf7 100644 --- a/src/tabpfn/base.py +++ b/src/tabpfn/base.py @@ -370,6 +370,24 @@ def create_inference_engine( # noqa: PLR0913 raise ValueError(f"Invalid fit_mode: {fit_mode}") +def reject_categoricals_for_differentiable_input( + categorical_features_indices: Sequence[int] | None, +) -> None: + """Reject categorical features in the differentiable-input fit path. + + The differentiable path uses an identity preprocessor (no + ordinal-encoding step), so categorical columns have no valid handling + and would corrupt the prompt-tuning signal. + """ + if ( + categorical_features_indices is not None + and len(categorical_features_indices) > 0 + ): + raise ValueError( + "Categorical features are not supported for differentiable input." + ) + + def initialize_model_variables_helper( calling_instance: TabPFNRegressor | TabPFNClassifier, model_type: Literal["regressor", "classifier"], diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index 21b21d33d..f3dc426ad 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -41,6 +41,7 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + reject_categoricals_for_differentiable_input, ) from tabpfn.constants import ( PROBABILITY_EPSILON_ROUND_ZERO, @@ -635,13 +636,7 @@ def _initialize_for_differentiable_input( ) # Minimal preprocessing for prompt tuning - if ( - self.categorical_features_indices is not None - and len(self.categorical_features_indices) > 0 - ): - raise ValueError( - "Categorical features are not supported for differentiable input." 
- ) + reject_categoricals_for_differentiable_input(self.categorical_features_indices) n_features = X.shape[1] features = [Feature(name=None, modality=FeatureModality.NUMERICAL)] * n_features self.inferred_feature_schema_ = FeatureSchema(features=features) diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index ae5daddf3..bd3fed727 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -48,13 +48,17 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + reject_categoricals_for_differentiable_input, ) from tabpfn.constants import ( REGRESSION_CONSTANT_TARGET_BORDER_EPSILON, ModelVersion, ) from tabpfn.errors import TabPFNValidationError, handle_oom_errors -from tabpfn.inference import InferenceEngine, InferenceEngineBatchedNoPreprocessing +from tabpfn.inference import ( + InferenceEngine, + InferenceEngineBatchedNoPreprocessing, +) from tabpfn.model_loading import ( ModelSource, load_fitted_tabpfn_model, @@ -65,12 +69,13 @@ from tabpfn.preprocessing import ( EnsembleConfig, FeatureSubsamplingMethod, + PreprocessorConfig, RegressorEnsembleConfig, clean_data, generate_regression_ensemble_configs, ) from tabpfn.preprocessing.clean import fix_dtypes, process_text_na_dataframe -from tabpfn.preprocessing.datamodel import FeatureModality, FeatureSchema +from tabpfn.preprocessing.datamodel import Feature, FeatureModality, FeatureSchema from tabpfn.preprocessing.ensemble import ( TabPFNEnsemblePreprocessor, scale_n_estimators_for_feature_coverage, @@ -83,12 +88,14 @@ DevicesSpecification, convert_batch_of_cat_ix_to_schema, infer_random_state, + remove_non_differentiable_preprocessing_from_models, transform_borders_one, translate_probs_across_borders, ) from tabpfn.validation import ( ensure_compatible_fit_inputs, ensure_compatible_predict_input_sklearn, + validate_dataset_size, ) if TYPE_CHECKING: @@ -640,6 +647,127 @@ def _initialize_model_variables(self) -> int: """ return initialize_model_variables_helper(self, self.estimator_type) + def _rebuild_raw_space_bardist(self) -> None: + """Rebuild ``raw_space_bardist_`` from current ``y_train_mean_``/std_. + + Detaches the znorm-space borders so the rebuilt buffer never holds a + y autograd graph — required for the differentiable-input path and a + no-op for the standard path. Both ``y_train_mean_`` and + ``y_train_std_`` must already be set as Python floats. + """ + borders = self.znorm_space_bardist_.borders.detach() + self.raw_space_bardist_ = FullSupportBarDistribution( + borders * self.y_train_std_ + self.y_train_mean_, + ).float() + + def _build_ensemble_preprocessor_and_executor( + self, + *, + X: Any, + y: Any, + ensemble_configs: list[RegressorEnsembleConfig], + static_seed: int, + byte_size: int, + n_preprocessing_jobs: int, + inference_mode: bool, + ) -> None: + """Build ``self.ensemble_preprocessor_`` and ``self.executor_``. + + Shared between the standard fit path and the differentiable-input + path. The two paths differ only in ``n_preprocessing_jobs`` + (forced to 1 in the differentiable path so the autograd graph on + ``X`` survives joblib's process-boundary pickling) and + ``inference_mode`` (False under differentiable input so backprop + works through the executor). + """ + self.ensemble_preprocessor_ = TabPFNEnsemblePreprocessor( + configs=ensemble_configs, + n_samples=X.shape[0], + feature_schema=self.inferred_feature_schema_, + # Use static_seed so we're independent of any random generation + # inside the initialize functions above. 
+ random_state=static_seed, + n_preprocessing_jobs=n_preprocessing_jobs, + keep_fitted_cache=(self.fit_mode == "fit_with_cache"), + enable_gpu_preprocessing=self.inference_config_.ENABLE_GPU_PREPROCESSING, + feature_subsampling_method=FeatureSubsamplingMethod( + self.inference_config_.FEATURE_SUBSAMPLING_METHOD + ), + constant_feature_count=self.inference_config_.FEATURE_SUBSAMPLING_CONSTANT_FEATURE_COUNT, + subsample_samples=self.inference_config_.SUBSAMPLE_SAMPLES, + importance_top_k_count=self.inference_config_.FEATURE_SUBSAMPLING_IMPORTANCE_TOP_K_COUNT, + X_train=X, + y_train=y, + task_type=self.estimator_type, + ) + self.executor_ = create_inference_engine( + fit_mode=self.fit_mode, + X_train=X, + y_train=y, + ensemble_preprocessor=self.ensemble_preprocessor_, + models=self.models_, + devices_=self.devices_, + byte_size=byte_size, + forced_inference_dtype_=self.forced_inference_dtype_, + memory_saving_mode=self.memory_saving_mode, + use_autocast_=self.use_autocast_, + inference_mode=inference_mode, + ) + + def _initialize_for_differentiable_input( + self, + X: torch.Tensor, + rng: np.random.Generator, + ) -> tuple[list[RegressorEnsembleConfig], torch.Tensor]: + """First-call setup for the differentiable path. + + Mirrors the classifier-side helper so that gradients can flow from a + loss back to upstream torch modules feeding ``X`` (and optionally + ``y``). Skips the standard numpy preprocessing path and uses a + differentiable identity preprocessor. y-target normalization happens + every call inside ``fit_with_differentiable_input``; this helper is + only for the cached feature-schema and ensemble-config setup. + """ + # Minimal preprocessing for prompt tuning: no categorical features, + # all-numerical schema, identity preprocessor that preserves grads. + reject_categoricals_for_differentiable_input(self.categorical_features_indices) + n_features = X.shape[1] + # One Feature instance per column — list multiplication would share + # the same dataclass and any later in-place update would leak across + # columns. + features = [ + Feature(name=None, modality=FeatureModality.NUMERICAL) + for _ in range(n_features) + ] + self.inferred_feature_schema_ = FeatureSchema(features=features) + self.n_features_in_ = n_features + + preprocessor_configs = [PreprocessorConfig("none", differentiable=True)] + self.n_estimators_ = scale_n_estimators_for_feature_coverage( + n_estimators=self.n_estimators, + n_total_features=n_features, + preprocessor_configs=preprocessor_configs, + ) + # Polynomial features go through sklearn StandardScaler on numpy and + # are not differentiable; force "no" regardless of the runtime default + # (the regressor config defaults to a non-zero value). 
+ ensemble_configs = generate_regression_ensemble_configs( + num_estimators=self.n_estimators_, + add_fingerprint_feature=self.inference_config_.FINGERPRINT_FEATURE, + feature_shift_decoder=self.inference_config_.FEATURE_SHIFT_METHOD, + polynomial_features="no", + preprocessor_configs=preprocessor_configs, + target_transforms=[None], + random_state=rng, + num_models=len(self.models_), + outlier_removal_std=self.inference_config_.get_resolved_outlier_removal_std( + estimator_type=self.estimator_type + ), + ) + assert len(ensemble_configs) == self.n_estimators_ + + return ensemble_configs, X + def _initialize_dataset_preprocessing( self, X: XType, @@ -793,6 +921,103 @@ def fit_from_preprocessed( return self + @track_model_call(model_method="fit", param_names=["X", "y"]) + def fit_with_differentiable_input(self, X: torch.Tensor, y: torch.Tensor) -> Self: + """Fit the model with differentiable input. + + Mirror of ``TabPFNClassifier.fit_with_differentiable_input``. Lets + gradients flow from a downstream loss back through ``X`` (and ``y``, + if it carries grads) into upstream torch modules. Use this instead + of ``fit`` when ``differentiable_input=True``. + + Args: + X: The input data as a torch tensor. + y: The target variable as a torch tensor. + + Returns: + self + """ + if self.fit_mode != "fit_preprocessors": + logging.warning( + "The model was not in 'fit_preprocessors' mode. " + "Automatically switching to 'fit_preprocessors' mode for differentiable" + " input." + ) + self.fit_mode = "fit_preprocessors" + + static_seed, rng = infer_random_state(self.random_state) + + is_first_fit_call = not hasattr(self, "models_") + if is_first_fit_call: + byte_size = self._initialize_model_variables() + ensemble_configs, X = self._initialize_for_differentiable_input( + X=X, rng=rng + ) + self.ensemble_configs_ = ensemble_configs # Store for prompt tuning reuse + remove_non_differentiable_preprocessing_from_models(models=self.models_) + else: + _, _, byte_size = determine_precision( + self.inference_precision, self.devices_ + ) + ensemble_configs = self.ensemble_configs_ # Reuse from first fit + # Mirror classifier.py: re-assert n_estimators_ from cached + # configs so a subsequent call after pickling restores it. + self.n_estimators_ = len(ensemble_configs) + + # Refresh target stats and rebuild the raw-space bardist on every + # call so they track the current fit data; cached state is only the + # model load, feature schema, and ensemble configs above. + validate_dataset_size( + X=X, + y=y, + max_num_samples=self.inference_config_.MAX_NUMBER_OF_SAMPLES, + max_num_features=self.inference_config_.MAX_NUMBER_OF_FEATURES, + devices=self.devices_, + ignore_pretraining_limits=self.ignore_pretraining_limits, + ) + self.n_train_samples_ = int(X.shape[0]) + + y_float = ( + y.float() + if isinstance(y, torch.Tensor) + else torch.as_tensor(y, dtype=torch.float32) + ) + y_mean = y_float.mean() + # Match the standard fit's np.std (population std, ddof=0). torch.std + # defaults to correction=1 and returns NaN for N=1; clamp keeps the + # divisor non-zero. The constant-target guard below catches the + # remaining bardist-collapse case. + y_std = torch.clamp(y_float.std(correction=0), min=1e-20) + if y_std.detach().item() <= 1e-12: + raise ValueError( + "Constant or near-constant target (std≈0) is not supported " + "by fit_with_differentiable_input; there is no signal to " + "predict differentiably. Use fit() for constant-target data." 
+ ) + # Detach when storing as Python floats — raw_space_bardist_ is a + # frozen lookup and must not hold a y autograd graph. Users who need + # fully differentiable target scaling should z-normalise y themselves + # before calling so the mean/std are constants here. + self.y_train_mean_ = y_mean.detach().item() + self.y_train_std_ = y_std.detach().item() + y = (y_float - y_mean) / y_std + self._rebuild_raw_space_bardist() + + # Force sequential preprocessing: with differentiable input X carries + # an autograd graph that does not survive joblib's process-boundary + # pickling. Sequential execution preserves the graph in-process. + self._build_ensemble_preprocessor_and_executor( + X=X, + y=y, + ensemble_configs=ensemble_configs, + static_seed=static_seed, + byte_size=byte_size, + n_preprocessing_jobs=1, + inference_mode=False, + ) + + return self + @config_context(transform_output="default") # type: ignore @track_model_call(model_method="fit", param_names=["X", "y"]) def fit(self, X: XType, y: YType) -> Self: @@ -807,7 +1032,8 @@ def fit(self, X: XType, y: YType) -> Self: """ if self.differentiable_input: raise ValueError( - "Differentiable input is not supported for regressors yet." + "differentiable_input=True requires fit_with_differentiable_input " + "with torch tensor X and y, not fit()." ) if self.fit_mode == "batched": @@ -859,43 +1085,17 @@ def fit(self, X: XType, y: YType) -> Self: self.y_train_std_ = std.item() + 1e-20 self.y_train_mean_ = mean.item() y = (y - self.y_train_mean_) / self.y_train_std_ - self.raw_space_bardist_ = FullSupportBarDistribution( - self.znorm_space_bardist_.borders * self.y_train_std_ + self.y_train_mean_, - ).float() + self._rebuild_raw_space_bardist() - ensemble_preprocessor = TabPFNEnsemblePreprocessor( - configs=ensemble_configs, - n_samples=X.shape[0], - feature_schema=self.inferred_feature_schema_, - # Note: we use the static_seed so we're independent of the random generation - # inside the initialize function above - random_state=static_seed, - n_preprocessing_jobs=self.n_preprocessing_jobs, - keep_fitted_cache=(self.fit_mode == "fit_with_cache"), - enable_gpu_preprocessing=self.inference_config_.ENABLE_GPU_PREPROCESSING, - feature_subsampling_method=FeatureSubsamplingMethod( - self.inference_config_.FEATURE_SUBSAMPLING_METHOD - ), - constant_feature_count=self.inference_config_.FEATURE_SUBSAMPLING_CONSTANT_FEATURE_COUNT, - subsample_samples=self.inference_config_.SUBSAMPLE_SAMPLES, - importance_top_k_count=self.inference_config_.FEATURE_SUBSAMPLING_IMPORTANCE_TOP_K_COUNT, - X_train=X, - y_train=y, - task_type=self.estimator_type, - ) - - self.executor_ = create_inference_engine( - fit_mode=self.fit_mode, - X_train=X, - y_train=y, - ensemble_preprocessor=ensemble_preprocessor, - models=self.models_, - devices_=self.devices_, + self._build_ensemble_preprocessor_and_executor( + X=X, + y=y, + ensemble_configs=ensemble_configs, + static_seed=static_seed, byte_size=byte_size, - forced_inference_dtype_=self.forced_inference_dtype_, - memory_saving_mode=self.memory_saving_mode, - use_autocast_=self.use_autocast_, + n_preprocessing_jobs=self.n_preprocessing_jobs, # TODO: Standard fit usually uses inference_mode=True, before it was enabled + inference_mode=True, ) return self @@ -1122,8 +1322,11 @@ def _iter_forward_executor( check_is_fitted(self) # Ensure torch.inference_mode is OFF to allow gradients if self.fit_mode in ["fit_preprocessors", "batched"]: - # only these two modes support this option - 
self.executor_.use_torch_inference_mode(use_inference=use_inference_mode) + # only these two modes support this option. + # Don't enable inference mode when differentiable_input=True (prompt + # tuning) to allow gradients to flow through. + actual_inference_mode = use_inference_mode and not self.differentiable_input + self.executor_.use_torch_inference_mode(use_inference=actual_inference_mode) std_borders = self.znorm_space_bardist_.borders.cpu().numpy() for output, config in self.executor_.iter_outputs( X, autocast=self.use_autocast_, task_type="regression" diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index 0e3c046a0..54bd44450 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -976,3 +976,209 @@ def test__create_default_for_version__passes_through_overrides() -> None: assert estimator.n_estimators == 16 assert estimator.softmax_temperature == 0.9 + + +# --------------------------------------------------------------------------- +# differentiable_input +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("device", devices) +def test__fit_with_differentiable_input__grad_flows_to_upstream_module( + device: str, +) -> None: + """End-to-end: a loss computed from forward(use_inference_mode=True) after + fit_with_differentiable_input must produce a non-zero, finite gradient on + an upstream torch module's weights. + """ + torch.manual_seed(0) + D, N_train, N_test = 8, 30, 10 + linear = nn.Linear(D, D).to(device) + + X_train = linear(torch.randn(N_train, D, device=device)) + X_test = linear(torch.randn(N_test, D, device=device)) + y_train = torch.randn(N_train, device=device) + y_test = torch.randn(N_test, device=device) + + reg = TabPFNRegressor( + n_estimators=1, + ignore_pretraining_limits=True, + device=device, + differentiable_input=True, + ) + reg.fit_with_differentiable_input(X_train, y_train) + + averaged_logits, _outputs, borders = reg.forward(X_test, use_inference_mode=True) + + # averaged_logits is [N_borders, N_samples] after the transpose in + # forward(); reduce to a scalar per sample via softmax over bin centers. + per_sample_logits = averaged_logits.transpose(0, 1) # [N_test, N_borders] + border_t = torch.as_tensor( + borders[0], + device=per_sample_logits.device, + dtype=per_sample_logits.dtype, + ) + n_logits = per_sample_logits.shape[-1] + if border_t.numel() == n_logits + 1: + bin_centers = (border_t[:-1] + border_t[1:]) / 2.0 + else: + bin_centers = border_t + probs = torch.softmax(per_sample_logits.float(), dim=-1) + pred_z = (probs * bin_centers).sum(dim=-1) + pred = pred_z * float(reg.y_train_std_) + float(reg.y_train_mean_) + + loss = torch.nn.functional.mse_loss(pred.float(), y_test.float()) + assert loss.requires_grad + loss.backward() + + grad = linear.weight.grad + assert grad is not None, "gradient did not reach upstream nn.Linear" + assert torch.isfinite(grad).all(), "gradient contained NaN/Inf" + assert grad.norm().item() > 0, "gradient norm is zero — graph was detached" + + +def test__fit__differentiable_input_true__raises_helpful_error() -> None: + """Calling .fit() (instead of fit_with_differentiable_input) when + differentiable_input=True must raise a clear error pointing users to the + correct API rather than silently running a non-differentiable path. 
+ """ + reg = TabPFNRegressor( + n_estimators=1, + ignore_pretraining_limits=True, + device="cpu", + differentiable_input=True, + ) + X = np.random.default_rng(0).standard_normal((20, 4)).astype(np.float32) + y = np.random.default_rng(0).standard_normal(20).astype(np.float32) + with pytest.raises(ValueError, match="fit_with_differentiable_input"): + reg.fit(X, y) + + +@pytest.mark.parametrize( + ("case_id", "extra_kwargs", "X", "y", "match"), + [ + # The differentiable path uses an identity preprocessor and has no + # ordinal-encoding step, so categorical columns have no valid handling. + ( + "categorical_features", + {"categorical_features_indices": [0]}, + torch.randn(20, 4), + torch.randn(20), + "Categorical features", + ), + # Constant target collapses the bardist borders to a single point. + ( + "constant_target", + {}, + torch.randn(5, 4), + torch.full((5,), 3.14), + "Constant or near-constant target", + ), + # torch.std defaults to correction=1 and returns NaN for N=1; our path + # uses correction=0 so std collapses to 0 and trips the constant-target + # guard instead of a downstream NaN. + ( + "single_sample", + {}, + torch.randn(1, 4), + torch.tensor([2.0]), + "Constant or near-constant target", + ), + ], +) +def test__fit_with_differentiable_input__bad_input_raises_value_error( + case_id: str, + extra_kwargs: dict[str, object], + X: torch.Tensor, + y: torch.Tensor, + match: str, +) -> None: + """Bad inputs to the differentiable fit path must raise ValueError with a + clear message rather than producing NaNs or crashing downstream. + """ + del case_id # Only used for parametrize ids. + reg = TabPFNRegressor( + n_estimators=1, + ignore_pretraining_limits=True, + device="cpu", + differentiable_input=True, + **extra_kwargs, # type: ignore[arg-type] + ) + with pytest.raises(ValueError, match=match): + reg.fit_with_differentiable_input(X, y) + + +def test__fit_with_differentiable_input__std_matches_population_definition() -> None: + """The differentiable path's y_train_std_ should match np.std (population + std, ddof=0), not torch's default sample std (correction=1), so it lines + up with the standard fit() path. + """ + reg = TabPFNRegressor( + n_estimators=1, + ignore_pretraining_limits=True, + device="cpu", + differentiable_input=True, + ) + X = torch.randn(20, 4) + y_np = np.random.default_rng(0).standard_normal(20).astype(np.float32) + y = torch.from_numpy(y_np) + reg.fit_with_differentiable_input(X, y) + expected = float(np.std(y_np)) # ddof=0 + assert abs(reg.y_train_std_ - expected) < 1e-5, ( + f"y_train_std_ should equal np.std(y) (population std); " + f"got {reg.y_train_std_}, expected {expected}" + ) + + +def test__fit_with_differentiable_input__feature_schema_cols_independent() -> None: + """Each column's Feature must be a distinct instance — list multiplication + `[Feature(...)] * n` would alias all columns to one mutable dataclass. + """ + reg = TabPFNRegressor( + n_estimators=1, + ignore_pretraining_limits=True, + device="cpu", + differentiable_input=True, + ) + X = torch.randn(10, 4) + y = torch.randn(10) + reg.fit_with_differentiable_input(X, y) + feats = reg.inferred_feature_schema_.features + assert len(feats) == 4 + # Distinct instances, not aliases. 
+ ids = {id(f) for f in feats} + assert len(ids) == 4, "feature columns share the same Feature instance" + + +def test__fit_with_differentiable_input__second_call_refreshes_target_stats() -> None: + """A second call with different y must update y_train_mean_/std_ and the + raw_space_bardist_; only the model load and ensemble configs are cached. + """ + torch.manual_seed(0) + reg = TabPFNRegressor( + n_estimators=1, + ignore_pretraining_limits=True, + device="cpu", + differentiable_input=True, + ) + X1 = torch.randn(20, 4) + y1 = torch.randn(20) * 10.0 + 100.0 # mean ~100, std ~10 + reg.fit_with_differentiable_input(X1, y1) + mean1, std1 = reg.y_train_mean_, reg.y_train_std_ + bardist_borders1 = reg.raw_space_bardist_.borders.clone() + + X2 = torch.randn(20, 4) + y2 = torch.randn(20) * 0.5 - 5.0 # mean ~-5, std ~0.5 + reg.fit_with_differentiable_input(X2, y2) + mean2, std2 = reg.y_train_mean_, reg.y_train_std_ + + assert abs(mean2 - mean1) > 1.0, ( + f"y_train_mean_ should reflect new y; got {mean1} -> {mean2}" + ) + assert abs(std2 - std1) > 1.0, ( + f"y_train_std_ should reflect new y; got {std1} -> {std2}" + ) + # raw_space_bardist_ borders are derived from y stats; they must move. + assert not torch.allclose(reg.raw_space_bardist_.borders, bardist_borders1), ( + "raw_space_bardist_ must be rebuilt to the new target scale" + )
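
Usage sketch (reviewer note, not part of the diff): the gradient-flow test above already exercises the intended prompt-tuning loop, so here is the same flow condensed into a runnable sketch. It only relies on the surface this PR adds or touches (`fit_with_differentiable_input`, `forward(..., use_inference_mode=True)`, `y_train_mean_`/`y_train_std_`); the optimizer loop, the `encoder` module, and all shapes and hyperparameters are illustrative assumptions rather than part of the change.

```python
import torch
from torch import nn

from tabpfn import TabPFNRegressor

torch.manual_seed(0)
D = 8
encoder = nn.Linear(D, D)  # upstream module whose weights we want to tune
opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)

X_train_raw, y_train = torch.randn(64, D), torch.randn(64)
X_val_raw, y_val = torch.randn(32, D), torch.randn(32)

reg = TabPFNRegressor(
    n_estimators=1,
    differentiable_input=True,
    ignore_pretraining_limits=True,
    device="cpu",
)

for _ in range(5):
    opt.zero_grad()
    # Re-fit every step so the cached training context carries the current
    # encoder graph (backward() frees the previous one).
    reg.fit_with_differentiable_input(encoder(X_train_raw), y_train)
    averaged_logits, _outputs, borders = reg.forward(
        encoder(X_val_raw), use_inference_mode=True
    )
    # averaged_logits is [N_borders, N_samples]; reduce each sample to a point
    # prediction via a softmax-weighted mean over bin centers, as in the test.
    logits = averaged_logits.transpose(0, 1).float()
    b = torch.as_tensor(borders[0], dtype=logits.dtype)
    centers = (b[:-1] + b[1:]) / 2.0 if b.numel() == logits.shape[-1] + 1 else b
    pred_z = (torch.softmax(logits, dim=-1) * centers).sum(dim=-1)
    pred = pred_z * float(reg.y_train_std_) + float(reg.y_train_mean_)
    loss = torch.nn.functional.mse_loss(pred, y_val)
    loss.backward()  # gradients reach encoder.weight through both X tensors
    opt.step()
```

Re-fitting inside the loop matters because `loss.backward()` frees the graph held by the cached training context; the second-call test above confirms repeated fits refresh the target statistics while reusing the ensemble configs.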
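
Target-scaling contract (reviewer note, not part of the diff): the new path deliberately matches the standard `fit()` statistics, i.e. population std rather than torch's default sample std, and `_rebuild_raw_space_bardist` is just an affine map of the znorm-space borders. A small sketch; `znorm_borders` below is a made-up stand-in for the model's real `znorm_space_bardist_.borders`.

```python
import numpy as np
import torch

y = torch.tensor([1.0, 2.0, 4.0, 8.0])

std_pop = y.std(correction=0)  # what this PR uses; equals np.std (ddof=0)
std_sample = y.std()           # torch default (correction=1); intentionally NOT used
assert abs(std_pop.item() - float(np.std(y.numpy()))) < 1e-5
assert std_sample.item() > std_pop.item()

mean, std = y.mean().item(), std_pop.item()
znorm_borders = torch.linspace(-3.0, 3.0, steps=11)  # stand-in borders
raw_borders = znorm_borders * std + mean  # what _rebuild_raw_space_bardist builds
# Predictions made in z-space map back the same way: pred = pred_z * std + mean.
```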
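
Why the comprehension instead of `[Feature(...)] * n` (reviewer note, not part of the diff): list multiplication repeats one object, so a later in-place edit to any column's record would show up in every column. `Cell` below is a hypothetical stand-in for `Feature`, just to show the aliasing the schema test guards against.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Cell:  # hypothetical stand-in for a mutable per-column record
    name: str | None = None


aliased = [Cell()] * 3  # three references to ONE instance
aliased[0].name = "age"
assert all(c.name == "age" for c in aliased)

independent = [Cell() for _ in range(3)]  # what the PR does instead
independent[0].name = "age"
assert [c.name for c in independent] == ["age", None, None]
```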