David Holzmüller
For me, on some (larger?) datasets (but also <= 10K train samples), I get the attached error (TabPFNClassifier), tested on torch 2.4 and torch 2.5.
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/classifier.py", line 533, in predict_proba
for output, config in self.executor_.iter_outputs(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/inference.py", line 192, in iter_outputs
output = self.model(
^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 413, in forward
return self._forward(x, y, style=style, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 625, in _forward
encoder_out = self.transformer_encoder(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 74, in forward
x = layer(x, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 449, in forward
state = sublayer(state)
^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 334, in attn_between_features
return self.self_attn_between_features(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 355, in forward
output: torch.Tensor = self._compute(
^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/memory.py", line 100, in method_
return x + method(self, x, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 504, in _compute
attention_head_outputs = MultiHeadAttention.compute_attention_heads(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 710, in compute_attention_heads
attention_head_outputs = torch.nn.functional.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid configuration argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Jingang
Hi David, the CUDA error "invalid configuration argument" can occur when the batch size passed to scaled dot-product attention is too large with flash attention or memory-efficient SDPA enabled. See the SDPA documentation:
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
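To make the failure mode concrete: attention is computed independently for each batch element, so splitting an oversized batch into chunks and concatenating the outputs gives identical results while keeping any single kernel launch small. Below is a minimal NumPy sketch of that idea (a stand-in for `torch.nn.functional.scaled_dot_product_attention`; the function names `sdpa` and `chunked_sdpa` are hypothetical, not TabPFN or PyTorch API):

```python
import numpy as np

def sdpa(q, k, v):
    """Reference scaled dot-product attention: softmax(q k^T / sqrt(d)) v."""
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def chunked_sdpa(q, k, v, chunk_size=4):
    """Run attention on batch chunks and concatenate.

    Each batch element attends only within itself, so chunking the
    leading (batch) dimension reproduces the un-chunked result exactly
    while bounding the work per call.
    """
    outs = [
        sdpa(q[i:i + chunk_size], k[i:i + chunk_size], v[i:i + chunk_size])
        for i in range(0, q.shape[0], chunk_size)
    ]
    return np.concatenate(outs, axis=0)
```

The same chunking principle applies inside PyTorch: if a single SDPA call exceeds what the fused kernels can launch, slicing the batch dimension before the call avoids the invalid launch configuration.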
David Holzmüller
I was hoping that TabPFN's interface would automatically batch the predict() step when there are too many samples, but apparently that is not the case.
Copying a discussion from discord:
It seems to make sense to batch when the test set is too big, and what "too big" means may depend on several factors, including whether we're using flash attention and memory-efficient SDPA.
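Until such batching lands in the library itself, a user-side workaround is to slice the test set before calling `predict_proba` and stack the per-chunk outputs. A minimal sketch (the helper name `predict_proba_batched` is hypothetical; it works with any scikit-learn-style estimator exposing `predict_proba`, such as `TabPFNClassifier`):

```python
import numpy as np

def predict_proba_batched(model, X, batch_size=1024):
    """Call model.predict_proba on slices of X and stack the results.

    A user-side workaround, not TabPFN's own API: each forward pass
    sees at most batch_size test samples, which keeps the attention
    kernels below the size that triggers the CUDA
    "invalid configuration argument" error.
    """
    parts = [
        model.predict_proba(X[i:i + batch_size])
        for i in range(0, len(X), batch_size)
    ]
    return np.vstack(parts)
```

Note that this only batches the test samples; the fitted training context is reused unchanged for every chunk, so the stacked probabilities match what a single oversized call would have produced.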