
[Pallas] Add non-DMA fori_loop fallback for DMA-unaligned inner blocks#1969

Merged
norx1991 merged 5 commits into pytorch:main from thcmbs:pallas-matmul-split-k-2
Apr 17, 2026

Conversation

@thcmbs
Collaborator

@thcmbs thcmbs commented Apr 7, 2026

When an inner hl.tile block shape violates TPU DMA alignment (last dim % 128 != 0 or second-to-last % 8 != 0), the existing fori_loop codegen fails because pltpu.make_async_copy rejects unaligned refs. Matmul split-K with N=64 hits exactly this case, so today no loop type works:

  • default: can't handle symbolic (traced) loop bounds produced by nested hl.tile.
  • emit_pipeline / fori_loop (DMA): reject unaligned inner blocks.

Changes:

  • fori_loop non-DMA path: when _check_dma_alignment fails, skip pltpu.make_async_copy and instead emit offset assignments that index the outer BlockSpec refs directly via pl.ds(begin + j*step, size).
  • Autotuner routing: for kernels with symbolic bounds, set fori_loop as the default choice (small inner dims typically can't satisfy DMA alignment, so emit_pipeline usually loses at benchmark time).
  • memory_ops: teach load/store lowering to recognise ForiLoopState with use_dma=False and route through the pl.ds() path.

Together with #1966, this makes examples/matmul_split_k.py work end-to-end on TPU with autotuning (only for split_k == 1).
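For intuition, the alignment predicate and fallback routing described above can be sketched in pure Python. This is an illustrative model only: the function name `check_dma_alignment` echoes the `_check_dma_alignment` mentioned above but is not the actual helion internal, and the real codegen operates on refs, not shape tuples.

```python
# Hypothetical sketch of the DMA-alignment rule described in this PR; the
# real check lives in helion's Pallas codegen and may differ in detail.

def check_dma_alignment(block_shape: tuple[int, ...]) -> bool:
    """Return True if an inner block shape satisfies the TPU DMA rule used
    here: last dim % 128 == 0 and, for 2D+ blocks, second-to-last % 8 == 0."""
    if block_shape[-1] % 128 != 0:
        return False
    if len(block_shape) >= 2 and block_shape[-2] % 8 != 0:
        return False
    return True

# Matmul split-K with N=64: the last dim (64) is not a multiple of 128, so
# the DMA path (pltpu.make_async_copy) is rejected and codegen falls back
# to direct pl.ds() indexing of the outer BlockSpec refs.
print(check_dma_alignment((8, 64)))   # -> False: non-DMA pl.ds path
print(check_dma_alignment((8, 128)))  # -> True: DMA path
```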

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Apr 7, 2026
@thcmbs force-pushed the pallas-matmul-split-k-2 branch 2 times, most recently from bbb0ff2 to 540944e on April 8, 2026 07:07
@thcmbs thcmbs marked this pull request as ready for review April 8, 2026 15:09
Comment thread helion/autotuner/config_spec.py Outdated
if self.has_pallas_symbolic_bounds:
    choices = tuple(c for c in choices if c != "default")
    config.setdefault("pallas_loop_type", choices[0])
    # "default" uses Python range() which can't handle traced bounds;
Contributor

The original logic already excludes "default" when there are symbolic bounds. Do I understand correctly that the motivation here is to exclude "emit_pipeline" as well? If so, consider updating the comment to make that explicit.

Collaborator Author

Thanks for having a look; the comment was unclear! I updated it.

Here the goal is simply to update the default choice.

  • Before: exclude "default" from all choices, then take the first remaining element as the default -> emit_pipeline
  • After: explicitly set "fori_loop" as the default

We don't remove emit_pipeline from the potential configs (the relevant code for that is below, where we simply re-order them).
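The before/after behavior can be illustrated with plain dicts standing in for the real config objects (illustrative only; the choice names follow the snippet quoted above):

```python
# Illustrative plain-Python model of the default-choice change.
choices = ("default", "emit_pipeline", "fori_loop")

# Before: drop "default", then take the first remaining element as default.
filtered = tuple(c for c in choices if c != "default")
config_before: dict[str, str] = {}
config_before.setdefault("pallas_loop_type", filtered[0])  # -> "emit_pipeline"

# After: explicitly prefer "fori_loop" as the default; emit_pipeline stays
# in the tunable choices and is only re-ordered, not removed.
config_after: dict[str, str] = {}
config_after.setdefault("pallas_loop_type", "fori_loop")

print(config_before["pallas_loop_type"])  # -> emit_pipeline
print(config_after["pallas_loop_type"])   # -> fori_loop
```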

PTAL

Contributor

For my own understanding, with this PR, does the emit_pipeline strategy work with traced loop bounds, or is it only fori_loop that works?

Collaborator Author

Here we only unlock fori_loop when the shapes are not friendly to pipelining (i.e., last dim 64 in the example).

Contributor

Here we only unlock fori_loop

If we're compiling a kernel where only fori_loop works and emit_pipeline doesn't, then I think we should explicitly exclude emit_pipeline from the config choices, instead of merely setting fori_loop as the default, no?

Collaborator Author

I realize the title of the PR is outdated and might be confusing!

I should clarify: has_pallas_symbolic_bounds and "inner block is DMA-unaligned" are orthogonal conditions.

  • emit_pipeline fails specifically on unaligned inner blocks (last dim not divisible by 128, etc.), regardless of whether bounds are symbolic.
  • fori_loop now handles both the aligned (DMA) and unaligned (no-DMA pl.ds() slicing) paths.

So for a kernel with symbolic bounds and DMA-aligned inner blocks, emit_pipeline still works fine. The autotuner will naturally discard emit_pipeline at benchmark time if it fails for a given kernel. That said, you're right that it's a wasted effort in the current scenario.
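For intuition on the unaligned path: `pl.ds(begin + j * step, size)` is a dynamic slice, and its effect can be modeled with ordinary Python slicing. This is a semantic sketch, not the Pallas API; the helper name is hypothetical.

```python
# Semantic sketch of the non-DMA fallback: instead of DMA-copying an
# unaligned inner block into VMEM scratch, index the outer ref directly
# with a dynamic slice at offset begin + j * step.

def dynamic_slice(ref: list[int], begin: int, j: int, step: int, size: int) -> list[int]:
    """Model pl.ds(begin + j * step, size) applied to a 1D ref."""
    start = begin + j * step
    return ref[start:start + size]

ref = list(range(32))  # stand-in for an HBM-backed outer BlockSpec ref
# Iteration j=2 of the fori_loop reads a block of width 8 at offset 4 + 2*8.
print(dynamic_slice(ref, begin=4, j=2, step=8, size=8))
# -> [20, 21, 22, 23, 24, 25, 26, 27]
```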

As a follow-up, we could add a has_pallas_dma_unaligned flag at loop registration that allows us to discard emit_pipeline from the choices?

Contributor

Yes, I guess we can follow what the if self.has_pallas_symbolic_bounds branch does. It seems to be the most natural way.

Collaborator Author

Done, added a TODO at the config-choices site referencing this discussion. The clean fix is a has_pallas_dma_unaligned flag set at loop registration (like has_pallas_symbolic_bounds), but that needs plumbing VMEM shape info earlier in the pipeline, since the alignment check currently happens at codegen time in _check_dma_alignment(). Will do as a follow-up.

@thcmbs force-pushed the pallas-matmul-split-k-2 branch 3 times, most recently from 7c88625 to 1e92c95 on April 10, 2026 07:41
@thcmbs thcmbs requested review from AmesingFlank and norx1991 April 10, 2026 14:27
"""State for fori_loop-based loops on TPU (Pallas backend).

Uses jax.lax.fori_loop with pltpu.make_async_copy for manual DMA control.
When ``use_dma=False``, skips DMA and accesses HBM refs directly via
Contributor

For my own education: even in cases where we can use make_async_copy, could this pl.ds path be faster? I wonder if this should be a tunable config.

Collaborator Author

@thcmbs Apr 13, 2026

That's a good point! I was envisioning this as a fallback so that we can emit something, but it could eventually be faster in some cases. So it could be an autotuner dimension, but perhaps as a follow-up? And with supporting benchmarks?

Contributor

Yes, let's treat it as a follow-up. As long as we find a case where it is faster, we can make it tunable.

BTW, the fori loop is currently doing a synchronous copy, so we need to fix that before we can really compare.

Collaborator Author

SGTM!

Comment thread helion/language/_tracing_ops.py Outdated

DMA requires last dim % 128 == 0 and second-to-last dim % 8 == 0
for 2D+ tensors. Unlike outer BlockSpecs, emit_pipeline/fori_loop
inner DMA does NOT have a ``block == tensor_dim`` exception.
Contributor

How about the 128 * (32 / bitwidth(dtype)) part? Does it depend on the data type? Also, I wonder if any documentation is available that we could link here.

Collaborator Author

AFAIK: our proposed check (last dim % 128, second-to-last % 8) is correct for bf16 sublanes, overly conservative for f32 sublanes (which have no constraint), and too lenient for 1D (should be % 1024). No public docs exist yet for DMA constraints; I believe it's very much under development, and the rules are quite dynamic...

Note these differ from BlockSpec constraints, where 1D is dtype-dependent: 128 * (32 / bitwidth).

So as a follow-up, we'd need to plumb the dtype through and refine/update the rule as libtpu improves.
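A dtype-aware version of the check could eventually look like the sketch below. This is hedged: it encodes the rules as understood in this thread (bf16 sublane % 8, no f32 sublane constraint, 1D DMA % 1024, BlockSpec 1D at 128 * (32 / bitwidth)), which may change as libtpu evolves, and both function names are hypothetical.

```python
# Hypothetical refinement of the alignment check once dtype is plumbed
# through the pipeline; rules per this thread, subject to change.

def dma_alignment_ok(shape: tuple[int, ...], dtype_bits: int) -> bool:
    if len(shape) == 1:
        # 1D DMA reportedly needs % 1024 (stricter than the 2D last-dim rule).
        return shape[-1] % 1024 == 0
    if shape[-1] % 128 != 0:
        return False
    # Sublane rule: % 8 matches bf16; f32 reportedly has no sublane
    # constraint, so applying % 8 there would be overly conservative.
    if dtype_bits == 16 and shape[-2] % 8 != 0:
        return False
    return True

def blockspec_1d_min(dtype_bits: int) -> int:
    """BlockSpec (not DMA) 1D constraint: 128 * (32 / bitwidth)."""
    return 128 * (32 // dtype_bits)

print(blockspec_1d_min(16))  # -> 256 for bf16
```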

Contributor

Sounds good. I was hoping these rules would be clear from the beginning... Can you also add a dtype-related comment, just like the "Unlike outer BlockSpecs, ..." one?

Collaborator Author

Done!

@thcmbs force-pushed the pallas-matmul-split-k-2 branch from 957ddaf to 3904913 on April 13, 2026 08:00
@thcmbs force-pushed the pallas-matmul-split-k-2 branch from 3904913 to 1f95162 on April 13, 2026 13:24
@thcmbs force-pushed the pallas-matmul-split-k-2 branch from 1f95162 to f37dddf on April 13, 2026 13:28
@thcmbs thcmbs requested a review from norx1991 April 13, 2026 17:33
@AmesingFlank
Contributor

Together with #1966 , this makes examples/matmul_split_k.py work end-to-end on TPU with autotuning.

@thcmbs Out of curiosity, does this PR allow us to enable test_matmul_split_k in test/test_examples.py?

Comment thread helion/language/_tracing_ops.py Outdated
# When using DMA: register VMEM scratch buffers + DMA semaphores
tensor_to_vmem: dict[str, str] = {}
tensor_to_sem: dict[str, str] = {}
if use_dma:
Contributor

This if should be merged with the above?

Collaborator Author

Good point, done.

@norx1991
Contributor

Let's update the title of this PR? Also, is there already a test exercising the new pl.ds path?

@thcmbs force-pushed the pallas-matmul-split-k-2 branch from 9dc754b to e81370b on April 15, 2026 03:08
@thcmbs changed the title from "[Pallas] Fix traced loop bounds and autotuner routing for nested tiling" to "[Pallas] Add non-DMA fori_loop fallback for DMA-unaligned inner blocks" on Apr 15, 2026
@thcmbs force-pushed the pallas-matmul-split-k-2 branch from e81370b to e862feb on April 15, 2026 03:23
@thcmbs
Collaborator Author

thcmbs commented Apr 15, 2026

Together with #1966 , this makes examples/matmul_split_k.py work end-to-end on TPU with autotuning.

@thcmbs Out of curiosity, does this PR allow us to enable test_matmul_split_k in test/test_examples.py?

Unfortunately, not yet (I edited the PR description to make it clearer). examples/matmul_split_k.py can now run because it defaults to split_k=1, and autotuning allows split_k == 1.
In test_examples.py, we force split_k=8, which requires another PR (WIP, about atomic add) to work correctly!

Let's update the title of this PR? Also, is there already a test exercising the new pl.ds path?

Both done, PTAL.

@thcmbs thcmbs requested a review from norx1991 April 15, 2026 14:22
@norx1991
Contributor

Let's update this branch and I will merge it.

# Conflicts:
#	helion/language/memory_ops.py
#	test/test_pallas.py
@thcmbs
Collaborator Author

thcmbs commented Apr 17, 2026

Let's update this branch and I will merge it.

Done! Can you please retry the CI? I think it's a transient or infra issue.

@norx1991 norx1991 merged commit b5f6f60 into pytorch:main Apr 17, 2026
37 of 47 checks passed