[Pallas] Fix TPU min_dot_size for matmul autotuning#1999

Merged
norx1991 merged 1 commit into main from yifeixu/pallas-tpu-min-dot-size on Apr 14, 2026
Conversation

@norx1991 (Contributor) commented on Apr 9, 2026

Summary

  • Re-land [pallas-tpu] fix default configs for TPU examples #1731 (reverted in #1740 due to a test_loops.py crash that no longer reproduces on the current codebase)
  • Set TPU _min_dot_size to (8, 128, 128) matching the Mosaic MXU tile dimensions (M>=8, K>=128, N>=128)
  • Restructure device-type checks in _min_dot_size to be explicit (tpu, xpu, cuda) instead of fallback-based
  • Add test_matmul_default to exercise autotuned matmul without hardcoded block_sizes
  • Clamp min_dot_size to tensor dimensions on Pallas backend to prevent blocks exceeding tensor size
  • Add test_matmul_smaller_than_min_dot_size to verify matmul works when dimensions are smaller than min_dot_size
  • Add direct regression test verifying config_spec min_size constraints are [8, 128, 128] on TPU
  • Remove xfail for test_reshape_input_types on Pallas (unblocked by the clamping fix)

Why correct min_dot_size matters for matmul autotuning performance

_min_dot_size feeds into default_config(), which generates the initial configuration that the autotuner uses as its starting point. It also constrains the autotuner search space via update_min_block() in matmul_ops.py. With the wrong default (16, 16, 16), the autotuner explores configs with K=16 or N=16 tiles that get zero-padded up to 128 before pl.dot, wasting autotuning time on suboptimal candidates.

With (8, 128, 128), the search space is pruned upfront so every trial uses MXU-compatible tiles, improving both autotuning speed and the quality of the resulting configs.

Clamping for small tensors

When a tensor dimension is smaller than min_dot_size (e.g., a 16x16 matmul with min N=128), update_min_block() would force the block size to 128 — larger than the tensor. Triton handles this via automatic out-of-bounds masking, but Pallas BlockSpecs can't. The fix clamps min_size to size_hint (the tensor dimension) on the Pallas backend, letting the dot-level padding in matmul_utils.py handle the MXU requirement at codegen time instead.
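The clamping logic amounts to a per-dimension `min` against the tensor shape. This is a sketch under the assumptions stated in the PR, not the actual Pallas backend code; `clamp_min_size` is a hypothetical helper name.

```python
def clamp_min_size(min_dot_size, shape):
    """Clamp per-dimension minimum block sizes to the tensor dimensions.

    Hypothetical sketch of the Pallas-side fix: a block must never exceed
    the tensor itself, since Pallas BlockSpecs cannot mask out-of-bounds
    accesses the way Triton does. MXU padding is then handled at dot
    codegen time instead.
    """
    return tuple(min(m, s) for m, s in zip(min_dot_size, shape))

# A 16x16 matmul (M=16, K=16, N=16) with TPU minimums (8, 128, 128):
print(clamp_min_size((8, 128, 128), (16, 16, 16)))  # -> (8, 16, 16)
```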

@meta-cla bot added the CLA Signed label on Apr 9, 2026
@norx1991 norx1991 changed the title [Pallas] Fix default configs for TPU examples [Pallas] Fix TPU min_dot_size for matmul autotuning Apr 10, 2026
@norx1991 norx1991 marked this pull request as ready for review April 13, 2026 18:20
@norx1991 norx1991 force-pushed the yifeixu/pallas-tpu-min-dot-size branch from f614dce to adad399 on April 13, 2026 18:21
Re-land #1731 (reverted in #1740 due to test_loops.py crash, which
no longer reproduces on the current codebase).

- Set TPU _min_dot_size to (8, 128, 128) matching the Mosaic MXU tile
  dimensions (M>=8, K>=128, N>=128)
- Restructure device-type checks in _min_dot_size to be explicit
  (tpu, xpu, cuda) instead of fallback-based
- On Pallas, clamp min_dot_size to tensor dimensions via size_hint so
  blocks don't exceed the tensor size (Pallas BlockSpecs can't handle
  that, unlike Triton which masks out-of-bounds accesses)
- Add test_matmul_default to exercise autotuned matmul
- Add test_matmul_smaller_than_min_dot_size and direct regression test
  verifying config_spec min_size constraints are [8, 128, 128] on TPU
- Remove xfail for test_reshape_input_types on Pallas (unblocked by
  the clamping fix)
@norx1991 norx1991 force-pushed the yifeixu/pallas-tpu-min-dot-size branch from adad399 to d42909f on April 14, 2026 05:24
@norx1991 norx1991 merged commit 74ccae3 into main Apr 14, 2026
46 of 50 checks passed
@norx1991 norx1991 deleted the yifeixu/pallas-tpu-min-dot-size branch April 20, 2026 22:16