From a49e278f9aa7c759f5679ab3c3385bda0ea800ae Mon Sep 17 00:00:00 2001 From: Marco Inacio Date: Thu, 14 May 2026 08:00:21 +0100 Subject: [PATCH] Add ensemble_batch_size for single-device inference On devices with large amounts of RAM like Strix Halo, this can greatly speed up results --- changelog/906.added.md | 1 + src/tabpfn/base.py | 24 +- src/tabpfn/classifier.py | 35 ++- src/tabpfn/finetuning/finetuned_regressor.py | 9 +- src/tabpfn/inference.py | 287 ++++++++++++++++++- src/tabpfn/regressor.py | 42 ++- tests/test_classifier_interface.py | 42 +++ tests/test_inference.py | 232 ++++++++++++++- 8 files changed, 638 insertions(+), 34 deletions(-) create mode 100644 changelog/906.added.md diff --git a/changelog/906.added.md b/changelog/906.added.md new file mode 100644 index 000000000..62d52a3fb --- /dev/null +++ b/changelog/906.added.md @@ -0,0 +1 @@ +Add `ensemble_batch_size` to batch compatible ensemble members during single-device inference, reducing the number of forward passes required when using many estimators. diff --git a/src/tabpfn/base.py b/src/tabpfn/base.py index 0803f0b35..a4ed599ec 100644 --- a/src/tabpfn/base.py +++ b/src/tabpfn/base.py @@ -286,6 +286,7 @@ def create_inference_engine( # noqa: PLR0913 forced_inference_dtype_: torch.dtype | None, memory_saving_mode: MemorySavingMode, use_autocast_: bool, + ensemble_batch_size: int | None = None, inference_mode: bool = True, ) -> InferenceEngine: """Create the appropriate TabPFN inference engine based on `fit_mode`. @@ -307,9 +308,14 @@ def create_inference_engine( # noqa: PLR0913 forced_inference_dtype_: If not None, the forced dtype for inference. memory_saving_mode: GPU/CPU memory saving settings. use_autocast_: Whether we use torch.autocast for inference. + ensemble_batch_size: Maximum number of compatible ensemble members to batch + together during single-device low_memory or fit_preprocessors inference. inference_mode: Whether to use torch.inference_mode (set False if backprop is needed) """ + if ensemble_batch_size is not None and ensemble_batch_size < 1: + raise ValueError("ensemble_batch_size must be at least 1 or None.") + if fit_mode == "low_memory": return InferenceEngineOnDemand( X_train=X_train, @@ -320,6 +326,7 @@ def create_inference_engine( # noqa: PLR0913 dtype_byte_size=byte_size, force_inference_dtype=forced_inference_dtype_, save_peak_mem=memory_saving_mode, + ensemble_batch_size=ensemble_batch_size, ) if fit_mode == "fit_preprocessors": return InferenceEngineCachePreprocessing( @@ -332,6 +339,7 @@ def create_inference_engine( # noqa: PLR0913 force_inference_dtype=forced_inference_dtype_, save_peak_mem=memory_saving_mode, inference_mode=inference_mode, + ensemble_batch_size=ensemble_batch_size, ) if fit_mode == "fit_with_cache": # Use explicit KV cache engine for models that support it (e.g. 
v3), @@ -500,9 +508,21 @@ def get_embeddings( ): # Cast output to Any to allow dict-like access output_dict = typing.cast("dict[str, torch.Tensor]", output) - embed = output_dict[selected_data].squeeze(1) + embed = output_dict[selected_data] + + if isinstance(config, list): + assert embed.ndim == 3 + for batch_index, batch_config in enumerate(config): + assert isinstance( + batch_config, + (ClassifierEnsembleConfig, RegressorEnsembleConfig), + ) + embeddings.append(embed[:, batch_index].cpu().numpy()) + continue + assert isinstance(config, (ClassifierEnsembleConfig, RegressorEnsembleConfig)) + embed = embed.squeeze(1) assert embed.ndim == 2 - embeddings.append(embed.squeeze().cpu().numpy()) + embeddings.append(embed.cpu().numpy()) return np.array(embeddings) diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index 157559184..6db4fe96b 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -202,6 +202,7 @@ def __init__( # noqa: PLR0913 self, *, n_estimators: int = 8, + ensemble_batch_size: int | None = None, categorical_features_indices: Sequence[int] | None = None, softmax_temperature: float = 0.9, balance_probabilities: bool = False, @@ -250,6 +251,15 @@ def __init__( # noqa: PLR0913 If provided, we might ignore some of the suggestion to better fit the data seen during pre-training. + ensemble_batch_size: + Batch compatible ensemble members together during single-device + prediction. This reduces the number of forward passes needed for + `n_estimators > 1` in `low_memory` or `fit_preprocessors` mode. + + - If `None`, estimators are evaluated one-by-one. + - If an int, up to that many compatible ensemble members are evaluated + in one forward pass on a single device. + !!! note The indices are 0-based and should represent the data passed to `.fit()`. If the data changes between the initializations of the @@ -456,6 +466,7 @@ class in Fine-Tuning. 
The fit_from_preprocessed() function sets this """ super().__init__() self.n_estimators = n_estimators + self.ensemble_batch_size = ensemble_batch_size self.categorical_features_indices = categorical_features_indices self.softmax_temperature = softmax_temperature self.balance_probabilities = balance_probabilities @@ -807,6 +818,7 @@ def fit(self, X: XType, y: YType) -> Self: forced_inference_dtype_=self.forced_inference_dtype_, memory_saving_mode=self.memory_saving_mode, use_autocast_=self.use_autocast_, + ensemble_batch_size=self.ensemble_batch_size, inference_mode=True, ) @@ -1453,7 +1465,7 @@ def forward( # noqa: C901, PLR0912 if original_ndim == 2: # Shape is [Nsamples, NClasses] -> [Nsamples, 1, NClasses] processed_output = output.unsqueeze(1) - config_list = [config] + config_list = config if isinstance(config, list) else [config] elif original_ndim == 3: # Shape is [Nsamples, batch_size, NClasses] processed_output = output @@ -1484,7 +1496,7 @@ def forward( # noqa: C901, PLR0912 output_batch.append(processed_output[:, i, use_perm]) - outputs.append(torch.stack(output_batch, dim=1)) + outputs.extend(output_batch) # --- Post-processing --- stacked_outputs = torch.stack(outputs) @@ -1501,11 +1513,20 @@ def forward( # noqa: C901, PLR0912 if output.ndim > 2 and use_inference_mode: output = output.squeeze(1) if not return_raw_logits else output.squeeze(2) - if not use_inference_mode: - # This case is primarily for fine-tuning where NLLLoss expects [B, C, N] - if output.ndim == 2: # was likely [N, C] - output = output.unsqueeze(0) # [1, N, C] - output = output.transpose(0, 1).transpose(1, 2) + if not use_inference_mode and return_raw_logits: + # Fine-tuning consumes raw logits as [Q, B, E, L]. The batched + # engine currently emits [E, Q, L] after selecting B=1. + if output.ndim == 3: + output = output.permute(1, 0, 2).unsqueeze(1) + elif output.ndim == 4: + output = output.permute(1, 2, 0, 3) + elif not use_inference_mode: + # This case is primarily for fine-tuning where NLLLoss expects + # [B, C, N]. The batched engine emits averaged [N, C]. + if output.ndim == 2: + output = output.T.unsqueeze(0) + elif output.ndim == 3: + output = output.permute(1, 2, 0) return output diff --git a/src/tabpfn/finetuning/finetuned_regressor.py b/src/tabpfn/finetuning/finetuned_regressor.py index 34acc40d8..bbb995a55 100644 --- a/src/tabpfn/finetuning/finetuned_regressor.py +++ b/src/tabpfn/finetuning/finetuned_regressor.py @@ -440,10 +440,11 @@ def _forward_with_loss(self, batch: RegressorBatch) -> torch.Tensor: # type: ig bardist_loss_fn = self._bardist_loss _, per_estim_logits, _ = self._training_forward(X_query_batch) - # per_estim_logits is a list (per estimator) of tensors with shape [Q, B(=1), L] - - # shape suffix: Q=n_queries, B=batch(=1), E=estimators, L=logits - logits_QBEL = torch.stack(per_estim_logits, dim=2) + # per_estim_logits is a list (per estimator) of tensors. The batched + # inference path selects the only supported meta-batch item, so each + # tensor is [Q, L]. Restore B=1 for the shared loss layout. 
+ logits_QLE = torch.stack(per_estim_logits, dim=2) + logits_QBEL = logits_QLE.permute(0, 2, 1).unsqueeze(1) Q, B, E, L = logits_QBEL.shape num_bars = bardist_loss_fn.num_bars diff --git a/src/tabpfn/inference.py b/src/tabpfn/inference.py index 8597369af..6225ece32 100644 --- a/src/tabpfn/inference.py +++ b/src/tabpfn/inference.py @@ -66,6 +66,14 @@ def __iter__(self) -> _TimedIterator[_T]: return self +@dataclasses.dataclass +class _PreparedEnsembleForward: + X_full: torch.Tensor + y_train: torch.Tensor + categorical_inds: list[int] + config: EnsembleConfig + + def _model_expectes_task_type_arg(model: Architecture) -> bool: """Check if the model's forward function expects a task_type argument. @@ -340,6 +348,7 @@ def __init__( dtype_byte_size: int, force_inference_dtype: torch.dtype | None, save_peak_mem: MemorySavingMode, + ensemble_batch_size: int | None = None, ) -> None: """Initialize the on-demand inference engine. @@ -354,6 +363,8 @@ def __init__( dtype_byte_size: The byte size of the dtype. force_inference_dtype: The dtype to force inference to. save_peak_mem: Whether to save peak memory usage. + ensemble_batch_size: Maximum number of compatible ensemble members to + batch together during single-device inference. """ super().__init__( model_caches=[_PerDeviceModelCache(model) for model in models], @@ -365,6 +376,7 @@ def __init__( self.X_train = X_train self.y_train = y_train self.ensemble_preprocessor = ensemble_preprocessor + self.ensemble_batch_size = ensemble_batch_size self.to(devices, self.force_inference_dtype, self.dtype_byte_size) @@ -391,6 +403,21 @@ def iter_outputs( for model_cache in self.model_caches: model_cache.set_dtype(self.force_inference_dtype) + if ( + len(devices) == 1 + and self.ensemble_batch_size is not None + and self.ensemble_batch_size > 1 + ): + yield from self._iter_outputs_batched_single_device( + X, + autocast=autocast, + task_type=task_type, + only_return_standard_out=only_return_standard_out, + save_peak_mem=save_peak_mem, + device=devices[0], + ) + return + ensemble_members_iterator = ( self.ensemble_preprocessor.fit_transform_ensemble_members_iterator( X_train=self.X_train, @@ -427,6 +454,124 @@ def iter_outputs( timed_outputs.elapsed_seconds ) + def _iter_outputs_batched_single_device( + self, + X: np.ndarray, + *, + autocast: bool, + task_type: str, + only_return_standard_out: bool, + save_peak_mem: bool, + device: torch.device, + ) -> Iterator[tuple[torch.Tensor | dict, list[EnsembleConfig]]]: + assert self.ensemble_batch_size is not None + forward_time = 0.0 + open_batches: list[list[_PreparedEnsembleForward]] = [] + + def _prepare_member( + ensemble_member: TabPFNEnsembleMember, + ) -> _PreparedEnsembleForward: + X_full, y_train = _prepare_model_inputs( + device, + self.force_inference_dtype, + ensemble_member.X_train, + ensemble_member.transform_X_test(X), + ensemble_member.y_train, + ) + X_full, feature_schema = _maybe_run_gpu_preprocessing( + X_full, + gpu_preprocessor=ensemble_member.gpu_preprocessor, + num_train_rows=ensemble_member.X_train.shape[0], + feature_schema=ensemble_member.feature_schema, + ) + return _PreparedEnsembleForward( + X_full=X_full, + y_train=y_train, + categorical_inds=feature_schema.indices_for( + FeatureModality.CATEGORICAL + ), + config=ensemble_member.config, + ) + + def _is_compatible_with_batch( + reference: _PreparedEnsembleForward, + candidate: _PreparedEnsembleForward, + ) -> bool: + return ( + candidate.config._model_index == reference.config._model_index + and candidate.X_full.shape == 
reference.X_full.shape + and candidate.y_train.shape == reference.y_train.shape + ) + + def _forward_batch( + batch: list[_PreparedEnsembleForward], + ) -> tuple[torch.Tensor | dict, list[EnsembleConfig]]: + nonlocal forward_time + model_index = batch[0].config._model_index + model = self.model_caches[model_index].get(device) + X_full = torch.cat([prepared.X_full for prepared in batch], dim=1) + y_train = torch.stack([prepared.y_train for prepared in batch], dim=1) + batched_cat_ix = [prepared.categorical_inds for prepared in batch] + + performance_options = model.get_default_performance_options() + performance_options = dataclasses.replace( + performance_options, + save_peak_memory_factor=DEFAULT_SAVE_PEAK_MEMORY_FACTOR + if save_peak_mem + else None, + ) + + kwargs = {} + if _model_expectes_task_type_arg(model): + kwargs["task_type"] = task_type + + forward_start = time.perf_counter() + with get_autocast_context(device, enabled=autocast), torch.inference_mode(): + output = model( + X_full, + y_train, + only_return_standard_out=only_return_standard_out, + categorical_inds=batched_cat_ix, + performance_options=performance_options, + **kwargs, + ) + forward_time += time.perf_counter() - forward_start + + return ( + _move_and_squeeze_output(output, device), + [prepared.config for prepared in batch], + ) + + ensemble_members_iterator = ( + self.ensemble_preprocessor.fit_transform_ensemble_members_iterator( + X_train=self.X_train, + y_train=self.y_train, + parallel_mode="in-order", + ) + ) + for ensemble_member in ensemble_members_iterator: + prepare_start = time.perf_counter() + prepared = _prepare_member(ensemble_member) + forward_time += time.perf_counter() - prepare_start + + for batch in open_batches: + if _is_compatible_with_batch(batch[0], prepared): + batch.append(prepared) + break + else: + open_batches.append([prepared]) + + while open_batches and ( + len(open_batches[0]) == self.ensemble_batch_size + or len(open_batches) > self.ensemble_batch_size + ): + yield _forward_batch(open_batches.pop(0)) + + while open_batches: + yield _forward_batch(open_batches.pop(0)) + + self._speed_metrics["predict_model_forward_seconds"] = forward_time + def _call_model( # noqa: PLR0913 self, *, @@ -611,7 +756,7 @@ class InferenceEngineCachePreprocessing(MultiDeviceInferenceEngine): forward pass through the model which is currently done sequentially. """ - def __init__( + def __init__( # noqa: PLR0913 self, X_train: np.ndarray | torch.Tensor, y_train: np.ndarray | torch.Tensor, @@ -624,6 +769,7 @@ def __init__( save_peak_mem: MemorySavingMode, inference_mode: bool, no_preprocessing: bool = False, + ensemble_batch_size: int | None = None, ) -> None: """Initialize the cache preprocessing inference engine. @@ -642,6 +788,8 @@ def __init__( (this is quicker but disables backpropagation) no_preprocessing: If True, skip preprocessing on test data. Used for differentiability. + ensemble_batch_size: Maximum number of compatible ensemble members to + batch together during single-device inference. 
""" super().__init__( model_caches=[_PerDeviceModelCache(model) for model in models], @@ -652,6 +800,7 @@ def __init__( self.inference_mode = inference_mode self.no_preprocessing = no_preprocessing + self.ensemble_batch_size = ensemble_batch_size self.X_train_shape_before_preprocessing = X_train.shape fit_preprocess_start = time.perf_counter() @@ -675,7 +824,7 @@ def iter_outputs( autocast: bool, task_type: str, only_return_standard_out: bool = True, - ) -> Iterator[tuple[torch.Tensor | dict, EnsembleConfig]]: + ) -> Iterator[tuple[torch.Tensor | dict, EnsembleConfig | list[EnsembleConfig]]]: devices = self.get_devices() if self.force_inference_dtype is not None: @@ -693,6 +842,21 @@ def iter_outputs( else: save_peak_mem = False + if ( + len(devices) == 1 + and self.ensemble_batch_size is not None + and self.ensemble_batch_size > 1 + ): + yield from self._iter_outputs_batched_single_device( + X, + autocast=autocast, + task_type=task_type, + only_return_standard_out=only_return_standard_out, + save_peak_mem=save_peak_mem, + device=devices[0], + ) + return + def _transform_X_test( ensemble_member: TabPFNEnsembleMember, ) -> np.ndarray | torch.Tensor: @@ -726,6 +890,125 @@ def _transform_X_test( timed_outputs.elapsed_seconds ) + def _iter_outputs_batched_single_device( # noqa: C901 + self, + X: np.ndarray | torch.Tensor, + *, + autocast: bool, + task_type: str, + only_return_standard_out: bool, + save_peak_mem: bool, + device: torch.device, + ) -> Iterator[tuple[torch.Tensor | dict, list[EnsembleConfig]]]: + assert self.ensemble_batch_size is not None + forward_time = 0.0 + open_batches: list[list[_PreparedEnsembleForward]] = [] + + def _transform_X_test( + ensemble_member: TabPFNEnsembleMember, + ) -> np.ndarray | torch.Tensor: + return X if self.no_preprocessing else ensemble_member.transform_X_test(X) + + def _prepare_member( + ensemble_member: TabPFNEnsembleMember, + ) -> _PreparedEnsembleForward: + X_full, y_train = _prepare_model_inputs( + device, + self.force_inference_dtype, + ensemble_member.X_train, + _transform_X_test(ensemble_member), + ensemble_member.y_train, + ) + X_full, feature_schema = _maybe_run_gpu_preprocessing( + X_full, + gpu_preprocessor=ensemble_member.gpu_preprocessor, + num_train_rows=ensemble_member.X_train.shape[0], + feature_schema=ensemble_member.feature_schema, + ) + return _PreparedEnsembleForward( + X_full=X_full, + y_train=y_train, + categorical_inds=feature_schema.indices_for( + FeatureModality.CATEGORICAL + ), + config=ensemble_member.config, + ) + + def _is_compatible_with_batch( + reference: _PreparedEnsembleForward, + candidate: _PreparedEnsembleForward, + ) -> bool: + return ( + candidate.config._model_index == reference.config._model_index + and candidate.X_full.shape == reference.X_full.shape + and candidate.y_train.shape == reference.y_train.shape + ) + + def _forward_batch( + batch: list[_PreparedEnsembleForward], + ) -> tuple[torch.Tensor | dict, list[EnsembleConfig]]: + nonlocal forward_time + model_index = batch[0].config._model_index + model = self.model_caches[model_index].get(device) + X_full = torch.cat([prepared.X_full for prepared in batch], dim=1) + y_train = torch.stack([prepared.y_train for prepared in batch], dim=1) + batched_cat_ix = [prepared.categorical_inds for prepared in batch] + + performance_options = model.get_default_performance_options() + performance_options = dataclasses.replace( + performance_options, + save_peak_memory_factor=DEFAULT_SAVE_PEAK_MEMORY_FACTOR + if save_peak_mem + else None, + ) + + kwargs = {} + if 
_model_expectes_task_type_arg(model): + kwargs["task_type"] = task_type + + forward_start = time.perf_counter() + with ( + get_autocast_context(device, enabled=autocast), + torch.inference_mode(self.inference_mode), + ): + output = model( + X_full, + y_train, + only_return_standard_out=only_return_standard_out, + categorical_inds=batched_cat_ix, + performance_options=performance_options, + **kwargs, + ) + forward_time += time.perf_counter() - forward_start + + return ( + _move_and_squeeze_output(output, device), + [prepared.config for prepared in batch], + ) + + for ensemble_member in self.ensemble_members: + prepare_start = time.perf_counter() + prepared = _prepare_member(ensemble_member) + forward_time += time.perf_counter() - prepare_start + + for batch in open_batches: + if _is_compatible_with_batch(batch[0], prepared): + batch.append(prepared) + break + else: + open_batches.append([prepared]) + + while open_batches and ( + len(open_batches[0]) == self.ensemble_batch_size + or len(open_batches) > self.ensemble_batch_size + ): + yield _forward_batch(open_batches.pop(0)) + + while open_batches: + yield _forward_batch(open_batches.pop(0)) + + self._speed_metrics["predict_model_forward_seconds"] = forward_time + def _call_model( # noqa: PLR0913 self, *, diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 58e492ee7..7c5d19d77 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -211,6 +211,7 @@ def __init__( # noqa: PLR0913 self, *, n_estimators: int = 8, + ensemble_batch_size: int | None = None, categorical_features_indices: Sequence[int] | None = None, softmax_temperature: float = 0.9, average_before_softmax: bool = False, @@ -256,6 +257,15 @@ def __init__( # noqa: PLR0913 If provided, we might ignore some of the suggestion to better fit the data seen during pre-training. + ensemble_batch_size: + Batch compatible ensemble members together during single-device + prediction. This reduces the number of forward passes needed for + `n_estimators > 1` in `low_memory` or `fit_preprocessors` mode. + + - If `None`, estimators are evaluated one-by-one. + - If an int, up to that many compatible ensemble members are evaluated + in one forward pass on a single device. + !!! note The indices are 0-based and should represent the data passed to `.fit()`. If the data changes between the initializations of the @@ -439,6 +449,7 @@ class in Fine-Tuning. 
The fit_from_preprocessed() function sets this """ super().__init__() self.n_estimators = n_estimators + self.ensemble_batch_size = ensemble_batch_size self.categorical_features_indices = categorical_features_indices self.softmax_temperature = softmax_temperature self.average_before_softmax = average_before_softmax @@ -857,6 +868,7 @@ def fit(self, X: XType, y: YType) -> Self: forced_inference_dtype_=self.forced_inference_dtype_, memory_saving_mode=self.memory_saving_mode, use_autocast_=self.use_autocast_, + ensemble_batch_size=self.ensemble_batch_size, # TODO: Standard fit usually uses inference_mode=True, before it was enabled ) @@ -1090,14 +1102,18 @@ def _iter_forward_executor( if self.softmax_temperature != 1: output = output / self.softmax_temperature # noqa: PLW2901 - # BSz.= 1 Scenario, the same as normal predict() function - # Handled by first if-statement - config_for_ensemble = config - if isinstance(config, list) and len(config) == 1: - single_config = config[0] - config_for_ensemble = single_config + config_list = config if isinstance(config, list) else [config] + output_batch = output.unsqueeze(1) if output.ndim == 2 else output + + if output_batch.ndim != 3 or output_batch.shape[1] != len(config_list): + raise ValueError( + "Unexpected regression output/config batch shape combination." + ) + + for batch_index, config_for_ensemble in enumerate(config_list): + if not isinstance(config_for_ensemble, RegressorEnsembleConfig): + raise ValueError("Unexpected config format for regression output.") - if isinstance(config_for_ensemble, RegressorEnsembleConfig): borders_t: np.ndarray logit_cancel_mask: np.ndarray | None descending_borders: bool @@ -1126,15 +1142,11 @@ def _iter_forward_executor( if descending_borders: borders_t = borders_t.flip(-1) # type: ignore + batch_output = output_batch[:, batch_index] if logit_cancel_mask is not None: - output = output.clone() # noqa: PLW2901 - output[..., logit_cancel_mask] = float("-inf") - yield borders_t, output - else: - raise ValueError( - "Unexpected config format " - "and Batch prediction is not supported yet!" 
- ) + batch_output = batch_output.clone() + batch_output[..., logit_cancel_mask] = float("-inf") + yield borders_t, batch_output def forward( self, diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index 37e702aaf..69597146c 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -373,6 +373,48 @@ def test_predict_raw_logits( ) +def test_predict_outputs_match_with_ensemble_batching( + X_y: tuple[np.ndarray, np.ndarray], +) -> None: + X, y = X_y + y = y.astype(np.int64) + + sequential = TabPFNClassifier( + n_estimators=5, + fit_mode="fit_preprocessors", + random_state=42, + ) + batched = TabPFNClassifier( + n_estimators=5, + ensemble_batch_size=2, + fit_mode="fit_preprocessors", + random_state=42, + ) + + sequential.fit(X, y) + batched.fit(X, y) + + np.testing.assert_allclose( + sequential.predict_proba(X), + batched.predict_proba(X), + atol=1e-5, + rtol=1e-5, + ) + np.testing.assert_allclose( + sequential.predict_logits(X), + batched.predict_logits(X), + atol=1e-5, + rtol=1e-5, + ) + np.testing.assert_allclose( + sequential.predict_raw_logits(X), + batched.predict_raw_logits(X), + atol=1e-3, + rtol=1e-5, + ) + np.testing.assert_array_equal(sequential.predict(X), batched.predict(X)) + + def test_multiple_models_predict_different_logits(X_y: tuple[np.ndarray, np.ndarray]): """Tests the predict_raw_logits method.""" X, y = X_y diff --git a/tests/test_inference.py b/tests/test_inference.py index b273a47df..406d77482 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -2,6 +2,7 @@ from __future__ import annotations +from copy import deepcopy from typing import Literal, overload from typing_extensions import override @@ -28,6 +29,7 @@ def __init__(self) -> None: super().__init__() self.parameter = torch.nn.Parameter(torch.tensor(1.0)) self.received_task_type: str | None = None + self.seen_batch_sizes: list[int] = [] @overload def forward( @@ -68,10 +70,12 @@ def forward( assert isinstance(x, Tensor) assert isinstance(y, Tensor) self.received_task_type = task_type + self.seen_batch_sizes.append(x.shape[1]) n_train_test, _, _ = x.shape - n_train, _ = y.shape + n_train = y.shape[0] test_rows = n_train_test - n_train - return x.sum(-2, keepdim=True).sum(-1, keepdim=True).reshape(-1, test_rows) + train_summary = x[:n_train].mean(0, keepdim=True).sum(-1, keepdim=True) + return x[-test_rows:].sum(-1, keepdim=True) + train_summary @property def ninp(self) -> int: @@ -111,9 +115,10 @@ def forward( assert isinstance(x, Tensor) assert isinstance(y, Tensor) n_train_test, _, _ = x.shape - n_train, _ = y.shape + n_train = y.shape[0] test_rows = n_train_test - n_train - return x.sum(-2, keepdim=True).sum(-1, keepdim=True).reshape(-1, test_rows) + train_summary = x[:n_train].mean(0, keepdim=True).sum(-1, keepdim=True) + return x[-test_rows:].sum(-1, keepdim=True) + train_summary @property def ninp(self) -> int: @@ -234,6 +239,225 @@ def get_outputs( ) +def test__cache_preprocessing__single_device_ensemble_batching() -> None: + rng = default_rng(seed=0) + n_train = 100 + n_features = 4 + n_classes = 3 + X_train = rng.standard_normal(size=(n_train, n_features)) + y_train = rng.integers(low=0, high=n_classes - 1, size=(n_train, 1)) + X_test = rng.standard_normal(size=(2, n_features)) + + sequential_model = _TestModel() + batched_model = _TestModel() + ensemble_configs = _create_test_ensemble_configs( + n_configs=5, + n_classes=n_classes, + num_models=1, + ) + + def _make_engine( + model: _TestModel, + *, + ensemble_batch_size: 
int | None, + ) -> InferenceEngineCachePreprocessing: + ensemble_preprocessor = TabPFNEnsemblePreprocessor( + configs=ensemble_configs, + n_samples=X_train.shape[0], + feature_schema=FeatureSchema.from_only_categorical_indices([], n_features), + random_state=default_rng(seed=0), + n_preprocessing_jobs=1, + ) + return InferenceEngineCachePreprocessing( + X_train, + y_train, + ensemble_preprocessor=ensemble_preprocessor, + models=[model], + devices=[torch.device("cpu")], + dtype_byte_size=4, + force_inference_dtype=None, + save_peak_mem=True, + inference_mode=True, + ensemble_batch_size=ensemble_batch_size, + ) + + sequential_engine = _make_engine(sequential_model, ensemble_batch_size=None) + sequential_outputs = list( + sequential_engine.iter_outputs(X_test, autocast=False, task_type="multiclass") + ) + + batched_engine = _make_engine(batched_model, ensemble_batch_size=2) + batched_outputs = list( + batched_engine.iter_outputs(X_test, autocast=False, task_type="multiclass") + ) + + assert sequential_model.seen_batch_sizes == [1, 1, 1, 1, 1] + assert batched_model.seen_batch_sizes == [2, 2, 1] + + flattened_batched_outputs: list[tuple[Tensor, EnsembleConfig]] = [] + for output, configs in batched_outputs: + assert isinstance(output, Tensor) + assert isinstance(configs, list) + output_batch = output.unsqueeze(1) if output.ndim == 2 else output + assert output_batch.shape[1] == len(configs) + for batch_index, config in enumerate(configs): + flattened_batched_outputs.append((output_batch[:, batch_index], config)) + + assert len(sequential_outputs) == len(flattened_batched_outputs) + for batched_output, batched_config in flattened_batched_outputs: + sequential_output = _find_seq_output(batched_config, sequential_outputs) + assert isinstance(sequential_output, Tensor) + assert torch.allclose(sequential_output, batched_output) + + +def test__on_demand__single_device_ensemble_batching() -> None: + rng = default_rng(seed=0) + n_train = 100 + n_features = 4 + n_classes = 3 + X_train = rng.standard_normal(size=(n_train, n_features)) + y_train = rng.integers(low=0, high=n_classes - 1, size=(n_train, 1)) + X_test = rng.standard_normal(size=(2, n_features)) + + sequential_model = _TestModel() + batched_model = _TestModel() + ensemble_configs = _create_test_ensemble_configs( + n_configs=5, + n_classes=n_classes, + num_models=1, + ) + + def _make_engine( + model: _TestModel, + *, + ensemble_batch_size: int | None, + ) -> InferenceEngineOnDemand: + ensemble_preprocessor = TabPFNEnsemblePreprocessor( + configs=ensemble_configs, + n_samples=X_train.shape[0], + feature_schema=FeatureSchema.from_only_categorical_indices([], n_features), + random_state=default_rng(seed=0), + n_preprocessing_jobs=1, + ) + return InferenceEngineOnDemand( + X_train, + y_train, + ensemble_preprocessor=ensemble_preprocessor, + models=[model], + devices=[torch.device("cpu")], + dtype_byte_size=4, + force_inference_dtype=None, + save_peak_mem=True, + ensemble_batch_size=ensemble_batch_size, + ) + + sequential_engine = _make_engine(sequential_model, ensemble_batch_size=None) + sequential_outputs = list( + sequential_engine.iter_outputs(X_test, autocast=False, task_type="multiclass") + ) + + batched_engine = _make_engine(batched_model, ensemble_batch_size=2) + batched_outputs = list( + batched_engine.iter_outputs(X_test, autocast=False, task_type="multiclass") + ) + + assert sequential_model.seen_batch_sizes == [1, 1, 1, 1, 1] + assert batched_model.seen_batch_sizes == [2, 2, 1] + + flattened_batched_outputs: list[tuple[Tensor, 
EnsembleConfig]] = [] + for output, configs in batched_outputs: + assert isinstance(output, Tensor) + assert isinstance(configs, list) + output_batch = output.unsqueeze(1) if output.ndim == 2 else output + assert output_batch.shape[1] == len(configs) + for batch_index, config in enumerate(configs): + flattened_batched_outputs.append((output_batch[:, batch_index], config)) + + assert len(sequential_outputs) == len(flattened_batched_outputs) + for batched_output, batched_config in flattened_batched_outputs: + sequential_output = _find_seq_output(batched_config, sequential_outputs) + assert isinstance(sequential_output, Tensor) + assert torch.allclose(sequential_output, batched_output) + + +def test__cache_preprocessing__ensemble_batching_groups_alternating_configs() -> None: + rng = default_rng(seed=0) + n_train = 100 + n_features = 4 + n_classes = 3 + X_train = rng.standard_normal(size=(n_train, n_features)) + y_train = rng.integers(low=0, high=n_classes - 1, size=(n_train, 1)) + X_test = rng.standard_normal(size=(2, n_features)) + + sequential_models = [_TestModel(), _TestModel()] + batched_models = [_TestModel(), _TestModel()] + base_config = _create_test_ensemble_configs( + n_configs=6, + n_classes=n_classes, + num_models=1, + )[0] + ensemble_configs = [deepcopy(base_config) for _ in range(6)] + for index, config in enumerate(ensemble_configs): + config._model_index = index % 2 + config.feature_shift_count = index + + def _make_engine( + models: list[_TestModel], + *, + ensemble_batch_size: int | None, + ) -> InferenceEngineCachePreprocessing: + ensemble_preprocessor = TabPFNEnsemblePreprocessor( + configs=ensemble_configs, + n_samples=X_train.shape[0], + feature_schema=FeatureSchema.from_only_categorical_indices([], n_features), + random_state=default_rng(seed=0), + n_preprocessing_jobs=1, + ) + return InferenceEngineCachePreprocessing( + X_train, + y_train, + ensemble_preprocessor=ensemble_preprocessor, + models=models, + devices=[torch.device("cpu")], + dtype_byte_size=4, + force_inference_dtype=None, + save_peak_mem=True, + inference_mode=True, + ensemble_batch_size=ensemble_batch_size, + ) + + sequential_engine = _make_engine(sequential_models, ensemble_batch_size=None) + sequential_outputs = list( + sequential_engine.iter_outputs(X_test, autocast=False, task_type="multiclass") + ) + + batched_engine = _make_engine(batched_models, ensemble_batch_size=3) + batched_outputs = list( + batched_engine.iter_outputs(X_test, autocast=False, task_type="multiclass") + ) + + assert [model.seen_batch_sizes for model in sequential_models] == [ + [1, 1, 1], + [1, 1, 1], + ] + assert [model.seen_batch_sizes for model in batched_models] == [[3], [3]] + + flattened_batched_outputs: list[tuple[Tensor, EnsembleConfig]] = [] + for output, configs in batched_outputs: + assert isinstance(output, Tensor) + assert isinstance(configs, list) + output_batch = output.unsqueeze(1) if output.ndim == 2 else output + assert output_batch.shape[1] == len(configs) + for batch_index, config in enumerate(configs): + flattened_batched_outputs.append((output_batch[:, batch_index], config)) + + assert len(sequential_outputs) == len(flattened_batched_outputs) + for batched_output, batched_config in flattened_batched_outputs: + sequential_output = _find_seq_output(batched_config, sequential_outputs) + assert isinstance(sequential_output, Tensor) + assert torch.allclose(sequential_output, batched_output) + + def test__on_demand__result_equal_in_serial_and_in_parallel() -> None: rng = default_rng(seed=0) n_train = 100
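A minimal usage sketch of the new option (illustrative only; it assumes nothing beyond the public `TabPFNClassifier` constructor arguments documented in this diff, with `fit_mode="fit_preprocessors"` on a single device):

    import numpy as np
    from tabpfn import TabPFNClassifier

    rng = np.random.default_rng(0)
    X = rng.standard_normal((200, 8))
    y = rng.integers(0, 3, size=200)  # 3 classes

    # With ensemble_batch_size=4, up to four compatible ensemble members are
    # evaluated in a single forward pass instead of one-by-one; passing None
    # keeps the previous sequential behaviour.
    clf = TabPFNClassifier(
        n_estimators=8,
        ensemble_batch_size=4,
        fit_mode="fit_preprocessors",
    )
    clf.fit(X, y)
    proba = clf.predict_proba(X)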