1 change: 1 addition & 0 deletions changelog/906.added.md
@@ -0,0 +1 @@
Add `ensemble_batch_size` to batch compatible ensemble members during single-device inference, reducing the number of forward passes required when using many estimators.
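
A minimal usage sketch of the new parameter, assuming the `TabPFNClassifier` signature added in `src/tabpfn/classifier.py` below; the dataset and the concrete batch size are illustrative only:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# With ensemble_batch_size=4, the 8 estimators should need roughly two forward
# passes per predict call (when the members are compatible) instead of eight.
clf = TabPFNClassifier(n_estimators=8, ensemble_batch_size=4)
clf.fit(X_train, y_train)
proba = clf.predict_proba(X_test)
```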
24 changes: 22 additions & 2 deletions src/tabpfn/base.py
@@ -286,6 +286,7 @@ def create_inference_engine( # noqa: PLR0913
forced_inference_dtype_: torch.dtype | None,
memory_saving_mode: MemorySavingMode,
use_autocast_: bool,
ensemble_batch_size: int | None = None,
inference_mode: bool = True,
) -> InferenceEngine:
"""Create the appropriate TabPFN inference engine based on `fit_mode`.
@@ -307,9 +308,14 @@ def create_inference_engine( # noqa: PLR0913
forced_inference_dtype_: If not None, the forced dtype for inference.
memory_saving_mode: GPU/CPU memory saving settings.
use_autocast_: Whether we use torch.autocast for inference.
ensemble_batch_size: Maximum number of compatible ensemble members to batch
together during single-device low_memory or fit_preprocessors inference.
inference_mode: Whether to use torch.inference_mode (set False if
backprop is needed)
"""
if ensemble_batch_size is not None and ensemble_batch_size < 1:
raise ValueError("ensemble_batch_size must be at least 1 or None.")

if fit_mode == "low_memory":
return InferenceEngineOnDemand(
X_train=X_train,
Expand All @@ -320,6 +326,7 @@ def create_inference_engine( # noqa: PLR0913
dtype_byte_size=byte_size,
force_inference_dtype=forced_inference_dtype_,
save_peak_mem=memory_saving_mode,
ensemble_batch_size=ensemble_batch_size,
)
if fit_mode == "fit_preprocessors":
return InferenceEngineCachePreprocessing(
Expand All @@ -332,6 +339,7 @@ def create_inference_engine( # noqa: PLR0913
force_inference_dtype=forced_inference_dtype_,
save_peak_mem=memory_saving_mode,
inference_mode=inference_mode,
ensemble_batch_size=ensemble_batch_size,
)
if fit_mode == "fit_with_cache":
# Use explicit KV cache engine for models that support it (e.g. v3),
@@ -500,9 +508,21 @@ def get_embeddings(
):
# Cast output to Any to allow dict-like access
output_dict = typing.cast("dict[str, torch.Tensor]", output)
embed = output_dict[selected_data].squeeze(1)
embed = output_dict[selected_data]

if isinstance(config, list):
assert embed.ndim == 3
for batch_index, batch_config in enumerate(config):
assert isinstance(
batch_config,
(ClassifierEnsembleConfig, RegressorEnsembleConfig),
)
embeddings.append(embed[:, batch_index].cpu().numpy())
continue

assert isinstance(config, (ClassifierEnsembleConfig, RegressorEnsembleConfig))
embed = embed.squeeze(1)
assert embed.ndim == 2
embeddings.append(embed.squeeze().cpu().numpy())
embeddings.append(embed.cpu().numpy())

return np.array(embeddings)
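
A shape-only sketch of the two branches in `get_embeddings` above; the layout `[n_samples, n_batched_members, embed_dim]` is inferred from the asserts in this diff and should be treated as an assumption:

```python
import torch

embeddings = []

# Batched case: one forward pass produced embeddings for 3 compatible members.
embed_batched = torch.randn(100, 3, 192)  # [N, B, D]
for batch_index in range(embed_batched.shape[1]):
    embeddings.append(embed_batched[:, batch_index].cpu().numpy())  # [N, D] each

# Unbatched case: a single member with a singleton batch axis.
embed_single = torch.randn(100, 1, 192)  # [N, 1, D]
embeddings.append(embed_single.squeeze(1).cpu().numpy())  # [N, D]

# One entry per ensemble member, each [N, D].
assert all(e.shape == (100, 192) for e in embeddings)
```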
35 changes: 28 additions & 7 deletions src/tabpfn/classifier.py
@@ -202,6 +202,7 @@ def __init__( # noqa: PLR0913
self,
*,
n_estimators: int = 8,
ensemble_batch_size: int | None = None,
categorical_features_indices: Sequence[int] | None = None,
softmax_temperature: float = 0.9,
balance_probabilities: bool = False,
@@ -250,6 +251,15 @@ def __init__( # noqa: PLR0913
If provided, we might ignore some of the suggestions to better fit the
data seen during pre-training.

ensemble_batch_size:
Batch compatible ensemble members together during single-device
prediction. This reduces the number of forward passes needed for
`n_estimators > 1` in `low_memory` or `fit_preprocessors` mode.

- If `None`, estimators are evaluated one-by-one.
- If an int, up to that many compatible ensemble members are evaluated
in one forward pass on a single device.

!!! note
The indices are 0-based and should represent the data passed to
`.fit()`. If the data changes between the initializations of the
@@ -456,6 +466,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
"""
super().__init__()
self.n_estimators = n_estimators
self.ensemble_batch_size = ensemble_batch_size
self.categorical_features_indices = categorical_features_indices
self.softmax_temperature = softmax_temperature
self.balance_probabilities = balance_probabilities
@@ -807,6 +818,7 @@ def fit(self, X: XType, y: YType) -> Self:
forced_inference_dtype_=self.forced_inference_dtype_,
memory_saving_mode=self.memory_saving_mode,
use_autocast_=self.use_autocast_,
ensemble_batch_size=self.ensemble_batch_size,
inference_mode=True,
)

@@ -1453,7 +1465,7 @@ def forward( # noqa: C901, PLR0912
if original_ndim == 2:
# Shape is [Nsamples, NClasses] -> [Nsamples, 1, NClasses]
processed_output = output.unsqueeze(1)
config_list = [config]
config_list = config if isinstance(config, list) else [config]
elif original_ndim == 3:
# Shape is [Nsamples, batch_size, NClasses]
processed_output = output
@@ -1484,7 +1496,7 @@ def forward( # noqa: C901, PLR0912

output_batch.append(processed_output[:, i, use_perm])

outputs.append(torch.stack(output_batch, dim=1))
outputs.extend(output_batch)

# --- Post-processing ---
stacked_outputs = torch.stack(outputs)
@@ -1501,11 +1513,20 @@ def forward( # noqa: C901, PLR0912
if output.ndim > 2 and use_inference_mode:
output = output.squeeze(1) if not return_raw_logits else output.squeeze(2)

if not use_inference_mode:
# This case is primarily for fine-tuning where NLLLoss expects [B, C, N]
if output.ndim == 2: # was likely [N, C]
output = output.unsqueeze(0) # [1, N, C]
output = output.transpose(0, 1).transpose(1, 2)
if not use_inference_mode and return_raw_logits:
# Fine-tuning consumes raw logits as [Q, B, E, L]. The batched
# engine currently emits [E, Q, L] after selecting B=1.
if output.ndim == 3:
output = output.permute(1, 0, 2).unsqueeze(1)
elif output.ndim == 4:
output = output.permute(1, 2, 0, 3)
elif not use_inference_mode:
# This case is primarily for fine-tuning where NLLLoss expects
# [B, C, N]. The batched engine emits averaged [N, C].
if output.ndim == 2:
output = output.T.unsqueeze(0)
elif output.ndim == 3:
output = output.permute(1, 2, 0)

return output

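A shape sketch of the fine-tuning reshapes added to `forward()` above (E = estimators, Q/N = query samples, B = meta-batch, L/C = logits/classes); the concrete sizes are illustrative and only the permutes mirror the diff:

```python
import torch

# Raw-logits path: the batched engine emits [E, Q, L]; fine-tuning consumes [Q, B, E, L].
raw = torch.randn(8, 50, 10)             # [E, Q, L]
raw = raw.permute(1, 0, 2).unsqueeze(1)  # -> [Q, 1, E, L]
assert raw.shape == (50, 1, 8, 10)

# Averaged-probabilities path: the engine emits [N, C]; NLLLoss expects [B, C, N].
avg = torch.randn(50, 10)                # [N, C]
avg = avg.T.unsqueeze(0)                 # -> [1, C, N]
assert avg.shape == (1, 10, 50)
```
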
9 changes: 5 additions & 4 deletions src/tabpfn/finetuning/finetuned_regressor.py
@@ -440,10 +440,11 @@ def _forward_with_loss(self, batch: RegressorBatch) -> torch.Tensor: # type: ig
bardist_loss_fn = self._bardist_loss

_, per_estim_logits, _ = self._training_forward(X_query_batch)
# per_estim_logits is a list (per estimator) of tensors with shape [Q, B(=1), L]

# shape suffix: Q=n_queries, B=batch(=1), E=estimators, L=logits
logits_QBEL = torch.stack(per_estim_logits, dim=2)
# per_estim_logits is a list (per estimator) of tensors. The batched
# inference path selects the only supported meta-batch item, so each
# tensor is [Q, L]. Restore B=1 for the shared loss layout.
logits_QLE = torch.stack(per_estim_logits, dim=2)
logits_QBEL = logits_QLE.permute(0, 2, 1).unsqueeze(1)

Q, B, E, L = logits_QBEL.shape
num_bars = bardist_loss_fn.num_bars
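A shape sketch of the logits restacking in `_forward_with_loss` above (Q = query rows, L = bar-distribution logits, E = estimators; the sizes are illustrative):

```python
import torch

Q, L, E = 50, 100, 4
per_estim_logits = [torch.randn(Q, L) for _ in range(E)]  # one [Q, L] tensor per estimator

logits_QLE = torch.stack(per_estim_logits, dim=2)       # [Q, L, E]
logits_QBEL = logits_QLE.permute(0, 2, 1).unsqueeze(1)  # [Q, 1, E, L], restoring B=1
assert logits_QBEL.shape == (Q, 1, E, L)
```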