[Pallas] Add non-DMA fori_loop fallback for DMA-unaligned inner blocks #1969
norx1991 merged 5 commits into pytorch:main
Conversation
Force-pushed bbb0ff2 to 540944e
```python
if self.has_pallas_symbolic_bounds:
    choices = tuple(c for c in choices if c != "default")
    config.setdefault("pallas_loop_type", choices[0])
    # "default" uses Python range() which can't handle traced bounds;
```
The original logic already excludes "default" if there are symbolic bounds. Do I understand correctly that the motivation here is to exclude "emit_pipeline" as well? If so, consider updating the comment to make that explicit.
Thanks for having a look, the comment was unclear! I updated it.
Here the goal is simply to update the default choice.
- Before: exclude "default" from the choices, then take the first remaining element, which is emit_pipeline
- After: explicitly set "fori_loop" as default
We don't remove emit_pipeline from the potential configs (the relevant code for that is below, where we simply re-order them).
PTAL
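To make the before/after concrete, here is a minimal sketch of the two behaviors described in this thread; the `choices` tuple and the helper names are illustrative stand-ins, not the PR's actual code:

```python
# Hypothetical stand-in for the autotuner's loop-type choices.
CHOICES = ("default", "emit_pipeline", "fori_loop")

def pick_default_before(choices):
    # Before: drop "default", then take the first remaining element,
    # which happens to be "emit_pipeline".
    filtered = tuple(c for c in choices if c != "default")
    return filtered[0]

def pick_default_after(choices):
    # After: explicitly set "fori_loop" as the default choice.
    return "fori_loop"
```

Note the "after" version only changes the default; "emit_pipeline" is still left in the tunable choices, as the reply above explains.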
For my own understanding, with this PR, does the emit_pipeline strategy work with traced loop bounds, or is it only fori_loop that works?
Here we only unlock fori_loop when the shapes are not friendly to pipelining (i.e., last dim 64 in the example).
> Here we only unlock fori_loop
If we're compiling a kernel where only fori_loop works and emit_pipeline doesn't, then I think we should explicitly exclude emit_pipeline from the config choices, instead of merely setting fori_loop as the default, no?
I realize the title of the PR is outdated and might be confusing!
I should clarify: has_pallas_symbolic_bounds and "inner block is DMA-unaligned" are orthogonal conditions.
- emit_pipeline fails specifically on unaligned inner blocks (last dim not divisible by 128, etc.), regardless of whether bounds are symbolic.
- fori_loop now handles both the aligned (DMA) and unaligned (no-DMA pl.ds() slicing) paths.
So for a kernel with symbolic bounds and DMA-aligned inner blocks, emit_pipeline still works fine. The autotuner will naturally discard emit_pipeline at benchmark time if it fails for a given kernel. That said, you're right that it's a wasted effort in the current scenario.
As a follow-up, we could add a has_pallas_dma_unaligned flag set at loop registration that lets us discard emit_pipeline from the choices?
Yes, I guess we can follow what the `if self.has_pallas_symbolic_bounds` branch does. It seems to be the most natural way.
Done, added a TODO at the config-choices site referencing this discussion. The clean fix is a has_pallas_dma_unaligned flag set at loop registration (like has_pallas_symbolic_bounds), but that needs plumbing VMEM shape info earlier in the pipeline, since the alignment check currently happens at codegen time in _check_dma_alignment(). Will do as a follow-up.
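A rough sketch of what the follow-up flag could look like; the function and flag names here (`pick_loop_choices`, `has_dma_unaligned`) are illustrative only, mirroring the existing symbolic-bounds filter:

```python
def pick_loop_choices(choices, has_symbolic_bounds, has_dma_unaligned):
    """Hypothetical choice filter, modeled on has_pallas_symbolic_bounds."""
    if has_symbolic_bounds:
        # "default" uses Python range(), which can't handle traced bounds.
        choices = tuple(c for c in choices if c != "default")
    if has_dma_unaligned:
        # emit_pipeline rejects DMA-unaligned inner blocks outright.
        choices = tuple(c for c in choices if c != "emit_pipeline")
    return choices
```

With both flags set, only "fori_loop" survives, which matches the failure mode discussed above.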
Force-pushed 7c88625 to 1e92c95
| """State for fori_loop-based loops on TPU (Pallas backend). | ||
|
|
||
| Uses jax.lax.fori_loop with pltpu.make_async_copy for manual DMA control. | ||
| When ``use_dma=False``, skips DMA and accesses HBM refs directly via |
For my education: even in the case where we can use make_async_copy, could this pl.ds path potentially be faster? I wonder if this should be a tunable config.
That's a good point! I was envisioning this as a fallback so that we can always emit something, but it could eventually be faster in some cases. So it could be an autotuner dimension, but perhaps as a follow-up? And with supporting benchmarks?
Yes, let's treat it as a follow-up. As long as we find a case where it is faster, we can make it tunable.
BTW, the fori_loop currently does a sync copy, so we need to fix that before we can really compare.
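If use_dma does become an autotuner dimension as discussed, the choice enumeration might look roughly like this; the config schema and the rule that emit_pipeline always manages its own DMA are assumptions for illustration:

```python
from itertools import product

def enumerate_configs(loop_types=("emit_pipeline", "fori_loop")):
    """Hypothetical config enumeration with use_dma as a tunable axis."""
    configs = []
    for loop_type, use_dma in product(loop_types, (True, False)):
        if loop_type == "emit_pipeline" and not use_dma:
            continue  # assumed: emit_pipeline always pipelines via DMA
        configs.append({"pallas_loop_type": loop_type, "use_dma": use_dma})
    return configs
```

The autotuner would then benchmark the fori_loop DMA and pl.ds variants against each other instead of hard-wiring the fallback.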
```
DMA requires last dim % 128 == 0 and second-to-last dim % 8 == 0
for 2D+ tensors. Unlike outer BlockSpecs, emit_pipeline/fori_loop
inner DMA does NOT have a ``block == tensor_dim`` exception.
```
How about the 128 * (32 / bitwidth(dtype)) part? Does the rule depend on the data type? Also, I wonder if any documentation is available that we can link here.
AFAIK: our proposed check (last dim % 128, second-to-last % 8) is correct for bf16 sublanes, overly conservative for f32 sublanes (which have no constraint), and too lenient for 1D (which should be % 1024). No public docs exist yet for DMA constraints; I believe they are very much under development, and the rules are quite dynamic...
Note these differ from BlockSpec constraints, where the 1D rule is dtype-dependent: 128 * (32 / bitwidth).
So as a follow-up, we'd need to plumb the dtype through and refine/update the rule as libtpu improves.
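For concreteness, the rule of thumb from this thread can be sketched as plain Python; the function names and the 1D % 1024 branch are illustrative, since the comment above notes the real rules are still in flux:

```python
def inner_block_dma_aligned(shape):
    """Hypothetical inner-block DMA check: last dim % 128 == 0 and
    second-to-last dim % 8 == 0 for 2D+ tensors; unlike outer
    BlockSpecs, no ``block == tensor_dim`` exception applies."""
    if len(shape) == 1:
        return shape[-1] % 1024 == 0  # 1D rule per the discussion above
    return shape[-1] % 128 == 0 and shape[-2] % 8 == 0

def blockspec_1d_multiple(dtype_bitwidth):
    # Outer BlockSpec 1D constraint is dtype-dependent: 128 * (32 / bitwidth).
    return 128 * (32 // dtype_bitwidth)
```

As noted, this is overly conservative for f32 sublanes, so any real implementation would want the dtype plumbed through.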
Sounds good. I was hoping these rules would be clear from the beginning... Can you also add a dtype-related comment, just like the "Unlike outer BlockSpecs, ..." one?
Force-pushed 957ddaf to 3904913
Force-pushed 3904913 to 1f95162
Force-pushed 1f95162 to f37dddf
```python
# When using DMA: register VMEM scratch buffers + DMA semaphores
tensor_to_vmem: dict[str, str] = {}
tensor_to_sem: dict[str, str] = {}
if use_dma:
```
Should this if be merged with the one above?
Let's update the title of this PR? Also, is there already a test exercising the new pl.ds path?
Force-pushed 9dc754b to e81370b
Force-pushed e81370b to e862feb
Unfortunately, not yet (I edited the PR description to make it clearer).
Both done, PTAL.

Let's update this branch and I will merge it.
# Conflicts:
#	helion/language/memory_ops.py
#	test/test_pallas.py
Done! Can you please retry the CI? I think it's transient, or an infra issue?
When an inner hl.tile block shape violates TPU DMA alignment (last dim % 128 != 0 or second-to-last % 8 != 0), the existing fori_loop codegen fails because pltpu.make_async_copy rejects unaligned refs. Matmul split-K with N=64 hits exactly this case, so today no loop type works.
Changes:
Together with #1966, this makes examples/matmul_split_k.py work end-to-end on TPU with autotuning (only for split_k == 1).
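The N=64 failure mode from the description above can be illustrated with a quick check; the (8, 64) block shape is an assumed example, not taken from the actual kernel:

```python
# Illustration: a split-K inner block whose last dim is N == 64 fails the
# DMA alignment rule, so pltpu.make_async_copy would reject the ref and
# the fori_loop path must fall back to direct pl.ds() slicing.
block_shape = (8, 64)  # second-to-last dim 8 is fine; last dim 64 is not
dma_ok = block_shape[-1] % 128 == 0 and block_shape[-2] % 8 == 0
print(dma_ok)  # False: 64 % 128 != 0
```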