
Batch predictions when test set is large #125

@LeoGrin

Description

Copying a discussion from Discord:

David Holzmüller
On some (larger?) datasets (but also with <= 10K train samples), I get the error below from TabPFNClassifier; tested on torch 2.4 and torch 2.5.
 File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/classifier.py", line 533, in predict_proba
    for output, config in self.executor_.iter_outputs(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/inference.py", line 192, in iter_outputs
    output = self.model(
             ^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 413, in forward
    return self._forward(x, y, style=style, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 625, in _forward
    encoder_out = self.transformer_encoder(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 74, in forward
    x = layer(x, **kwargs)
        ^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 449, in forward
    state = sublayer(state)
            ^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 334, in attn_between_features
    return self.self_attn_between_features(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 355, in forward
    output: torch.Tensor = self._compute(
                           ^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/memory.py", line 100, in method_
    return x + method(self, x, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 504, in _compute
    attention_head_outputs = MultiHeadAttention.compute_attention_heads(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 710, in compute_attention_heads
    attention_head_outputs = torch.nn.functional.scaled_dot_product_attention(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid configuration argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Jingang
Hi David, the CUDA error "invalid configuration argument" can occur when the batch size is too large with the flash-attention and memory-efficient SDPA backends enabled.
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

David Holzmüller
I was hoping that TabPFN's interface would automatically batch the predict() step if there are too many samples, but apparently that is not the case.

It seems sensible to batch predictions when the test set is too large, and what counts as "too large" likely depends on several factors, including whether flash attention and memory-efficient SDPA are enabled.
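Until batching lands in the library, a user-side workaround is to chunk the test set manually and stack the results. A minimal sketch, assuming `model` is any fitted classifier with a scikit-learn-style `predict_proba` (e.g. a fitted TabPFNClassifier); `predict_proba_batched` and `batch_size` are hypothetical names, not TabPFN API:

```python
import numpy as np

def predict_proba_batched(model, X_test, batch_size=10_000):
    """Call model.predict_proba on fixed-size chunks of X_test and
    vertically stack the per-chunk probability arrays.

    batch_size is a hypothetical knob: small enough to stay under the
    SDPA kernel limits, large enough to amortize per-call overhead.
    """
    outputs = []
    for start in range(0, len(X_test), batch_size):
        chunk = X_test[start:start + batch_size]
        outputs.append(model.predict_proba(chunk))
    return np.vstack(outputs)
```

The right default for `batch_size` would presumably depend on the active SDPA backend and available GPU memory, which is why doing this inside the library (where those are known) would be preferable.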
