[Pallas] Fix TPU min_dot_size for matmul autotuning (#1999)
Merged
AmesingFlank approved these changes on Apr 13, 2026.
Re-land #1731 (reverted in #1740 due to a `test_loops.py` crash, which no longer reproduces on the current codebase).

- Set TPU `_min_dot_size` to `(8, 128, 128)`, matching the Mosaic MXU tile dimensions (M >= 8, K >= 128, N >= 128)
- Restructure the device-type checks in `_min_dot_size` to be explicit (`tpu`, `xpu`, `cuda`) instead of fallback-based
- On Pallas, clamp `min_dot_size` to the tensor dimensions via `size_hint` so blocks don't exceed the tensor size (Pallas BlockSpecs can't handle that, unlike Triton, which masks out-of-bounds accesses)
- Add `test_matmul_default` to exercise the autotuned matmul without hardcoded `block_sizes`
- Add `test_matmul_smaller_than_min_dot_size` and a direct regression test verifying that the `config_spec` min_size constraints are `[8, 128, 128]` on TPU
- Remove the `xfail` for `test_reshape_input_types` on Pallas (unblocked by the clamping fix)
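The changes above can be sketched as follows. This is an illustrative, self-contained version of the explicit device-type dispatch and the Pallas-side clamping; the function names and the non-TPU values are assumptions, not the project's actual implementation.

```python
# Hypothetical sketch of the per-device minimum dot sizes and the
# Pallas clamping described in this PR; the real helpers live in the
# project's backend code and may differ in names and values.

def min_dot_size(device_type: str) -> tuple[int, int, int]:
    # Explicit device-type checks instead of a fallback-based chain.
    if device_type == "tpu":
        # Mosaic MXU tiles require M >= 8, K >= 128, N >= 128.
        return (8, 128, 128)
    if device_type in ("cuda", "xpu"):
        # Illustrative value only (the old generic default).
        return (16, 16, 16)
    return (1, 1, 1)  # conservative default for other devices


def clamp_to_size_hint(
    min_size: tuple[int, int, int], size_hint: tuple[int, int, int]
) -> tuple[int, int, int]:
    # On Pallas, a block must not exceed the tensor dimension: BlockSpecs
    # cannot mask out-of-bounds accesses the way Triton does.
    return tuple(min(m, s) for m, s in zip(min_size, size_hint))
```

For example, a 16x16x16 matmul on TPU would clamp `(8, 128, 128)` down to `(8, 16, 16)`, leaving the dot-level padding to satisfy the MXU requirement.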
**Why correct `min_dot_size` matters for matmul autotuning performance**

`_min_dot_size` feeds into `default_config()`, which generates the initial configuration the autotuner uses as its starting point. It also constrains the autotuner search space via `update_min_block()` in `matmul_ops.py`. With the wrong default of `(16, 16, 16)`, the autotuner explores configs with K=16 or N=16 tiles that get zero-padded up to 128 before `pl.dot`, wasting autotuning time on suboptimal candidates. With `(8, 128, 128)`, the search space is pruned up front so that every trial uses MXU-compatible tiles, improving both autotuning speed and the quality of the resulting configs.
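The pruning effect can be illustrated with a toy filter over candidate block sizes. This is not the project's autotuner, just a sketch showing how the minimum dot size shrinks the search space:

```python
# Illustrative sketch: filtering candidate (M, K, N) block sizes against a
# per-device minimum dot size, as the autotuner search-space constraint does.
import itertools


def prune_configs(candidates, min_dot):
    min_m, min_k, min_n = min_dot
    return [
        (m, k, n)
        for (m, k, n) in candidates
        if m >= min_m and k >= min_k and n >= min_n
    ]


candidates = list(itertools.product([8, 16, 32], [16, 64, 128], [16, 64, 128]))

# With the old (16, 16, 16) minimum, most candidates survive, including
# K=16 / N=16 tiles that would be zero-padded to 128 before pl.dot.
old = prune_configs(candidates, (16, 16, 16))

# With (8, 128, 128), only MXU-compatible tiles remain.
new = prune_configs(candidates, (8, 128, 128))
```

Of the 27 toy candidates, 18 survive the old minimum but only 3 (those with K=128 and N=128) survive the new one, so no autotuning trials are spent on tiles that would just be padded.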
**Clamping for small tensors**

When a tensor dimension is smaller than `min_dot_size` (e.g., a 16x16 matmul with min N=128), `update_min_block()` would force the block size to 128, larger than the tensor itself. Triton handles this via automatic out-of-bounds masking, but Pallas BlockSpecs can't. The fix clamps `min_size` to `size_hint` (the tensor dimension) on the Pallas backend, letting the dot-level padding in `matmul_utils.py` handle the MXU requirement at codegen time instead.
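A minimal sketch of this behavior, assuming a simplified signature (the real `update_min_block()` in `matmul_ops.py` operates on config specs and differs):

```python
# Hedged sketch of the per-dimension clamping described above.
# update_min_block and size_hint are the names used in the PR, but this
# function body is illustrative, not the project's implementation.

def update_min_block(min_size: int, size_hint: int, is_pallas: bool) -> int:
    if is_pallas:
        # Pallas BlockSpecs cannot exceed the tensor dimension, so clamp
        # the minimum block to the actual size; dot-level padding then
        # satisfies the MXU requirement at codegen time.
        return min(min_size, size_hint)
    # Triton masks out-of-bounds accesses, so an oversized block is fine.
    return min_size
```

For the 16x16 example above: on Pallas the min N block of 128 is clamped to 16, while on Triton it stays at 128 and relies on masking.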