
TabPFNRegressor lacks differentiable_input support that TabPFNClassifier has #922

@lujiazho

Description

Describe the workflow you want to enable

Summary

TabPFNClassifier accepts differentiable_input=True and exposes
fit_with_differentiable_input(X, y) so a downstream loss can backprop
through the model into upstream torch modules feeding X. The
constructor of TabPFNRegressor accepts the same differentiable_input
argument, but fit() raises a ValueError and there is no
fit_with_differentiable_input counterpart.

src/tabpfn/regressor.py:790-793 (current main):

if self.differentiable_input:
    raise ValueError(
        "Differentiable input is not supported for regressors yet."
    )

Minimal repro

import torch, torch.nn as nn
from tabpfn import TabPFNRegressor

device = "cuda" if torch.cuda.is_available() else "cpu"
linear = nn.Linear(8, 8).to(device)
X = linear(torch.randn(30, 8, device=device))
y = torch.randn(30, device=device)

reg = TabPFNRegressor(
    n_estimators=1, ignore_pretraining_limits=True,
    device=device, differentiable_input=True,
)
reg.fit(X, y)
# ValueError: Differentiable input is not supported for regressors yet.

The classifier-side equivalent works:

import torch, torch.nn as nn
import torch.nn.functional as F
from tabpfn import TabPFNClassifier

device = "cuda" if torch.cuda.is_available() else "cpu"
linear = nn.Linear(8, 8).to(device)
X = linear(torch.randn(30, 8, device=device))
y = torch.randint(0, 10, (30,), device=device, dtype=torch.long)

clf = TabPFNClassifier(
    n_estimators=1, ignore_pretraining_limits=True,
    device=device, differentiable_input=True,
)
clf.fit_with_differentiable_input(X, y)
logits = clf.forward(X, use_inference_mode=True, return_logits=True)
loss = F.cross_entropy(logits, y)
loss.backward()
print(linear.weight.grad)              # finite, non-zero: gradient reaches `linear`

Why we need this

Several common use cases need a differentiable head:

  • Prompt tuning over learned support embeddings.
  • ICL adapter training where a feature encoder feeds into TabPFN and the
    loss on TabPFN's prediction must update the encoder (a schematic loop is
    sketched after this list).
  • MoE / multi-head architectures that want TabPFN as one of several
    differentiable heads.
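
On the classifier side these loops already work end to end. For instance, the
adapter case reduces to ordinary optimization over the upstream encoder; a
schematic sketch reusing the names from the classifier repro above (the data,
loop length, and learning rate are placeholders):

opt = torch.optim.Adam(linear.parameters(), lr=1e-3)
X_raw = torch.randn(30, 8, device=device)   # fixed raw support set
for _ in range(10):
    opt.zero_grad()
    X_enc = linear(X_raw)                   # differentiable encoding
    clf.fit_with_differentiable_input(X_enc, y)
    logits = clf.forward(X_enc, use_inference_mode=True, return_logits=True)
    F.cross_entropy(logits, y).backward()   # gradient flows into `linear`
    opt.step()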

Currently the only way to use TabPFN regression in these settings is to
bypass the official executor and rebuild
InferenceEngineBatchedNoPreprocessing from scratch on every forward pass.
Native differentiable_input support would remove that workaround and bring
the regressor to parity with the classifier.

Environment

  • Python 3.12, torch 2.x, tabpfn main (current 7.1.1)
  • Reproduced against PriorLabs/TabPFN main @ 275e06a (current HEAD).
  • Reproducible on both CPU and CUDA.

Describe your proposed solution

Proposed fix

Mirror the classifier-side prompt-tuning path on the regressor:

  1. Add _initialize_for_differentiable_input(X, y, rng).
  2. Add fit_with_differentiable_input(X, y) building an
    InferenceEngineCachePreprocessing with inference_mode=False.
  3. Gate use_inference_mode on differentiable_input inside
    _iter_forward_executor (parallel to the classifier's existing
    actual_inference_mode = use_inference_mode and not self.differentiable_input;
    sketched after this list).
  4. Replace the ValueError in fit() with one that points users at the
    new method.
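
For step 3, the gate is a near copy of the classifier line quoted above; a
sketch of the intended shape inside the regressor's _iter_forward_executor
(the call site is illustrative, existing forward logic elided):

# Keep autograd enabled when the estimator was constructed with
# differentiable_input=True; otherwise preserve today's behavior.
actual_inference_mode = use_inference_mode and not self.differentiable_input
with torch.inference_mode(actual_inference_mode):
    ...  # existing per-estimator forward logic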

I have a working patch on top of main and am happy to open a PR.
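
With such a patch, regressor-side usage should mirror the classifier repro
above. A sketch reusing X, y, and linear from the minimal repro;
fit_with_differentiable_input and the forward call are the proposed, not yet
existing, API, and the forward output is assumed here to be a differentiable
point prediction (the exact return type for regression is a design question):

import torch.nn.functional as F

reg = TabPFNRegressor(
    n_estimators=1, ignore_pretraining_limits=True,
    device=device, differentiable_input=True,
)
reg.fit_with_differentiable_input(X, y)         # proposed counterpart
pred = reg.forward(X, use_inference_mode=True)  # assumed: point predictions
loss = F.mse_loss(pred, y)
loss.backward()
print(linear.weight.grad)                       # expected: finite, non-zero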

Describe alternatives you've considered, if relevant

No response

Additional context

No response

Impact

High (Major improvement)
