
[Pallas] Exclude output-only tensors from Pallas pallas_call inputs to improve performance #1849

Merged
norx1991 merged 1 commit into main from yifeixu/pallas-vmem-fix on Apr 21, 2026

Conversation

@norx1991
Contributor

@norx1991 norx1991 commented Mar 27, 2026

Summary

When a Helion kernel allocates an output tensor that is only written to (never read), e.g.:

out = torch.empty_like(x)
for tile in hl.tile(x.size()):
    out[tile] = torch.exp(x[tile])
return out

the Pallas backend previously passed out as both a pallas_call input (for donation via input_output_aliases) and output. This triggered OpSplitMode::kSplitBoth in torch_tpu (pallas_py.cc:110), which splits the XLA graph before and after the custom kernel. This forced the torch.empty_like() allocation — which would normally be elided or fused by XLA — to materialize as a separate device op (empty.1 broadcast), adding ~127us overhead per kernel call at large tensor sizes.

Root cause investigation

We investigated the empty.1 overhead by isolating each factor:

  1. kSplitBoth alone doesn't insert empty.1: Passing a pre-existing tensor with input_output_aliases (without allocating a new one per call) shows no empty.1 in xprof traces — kSplitBoth just splits the graph.

  2. torch.empty() alone doesn't create a device op: Without aliasing, torch.empty() is either fused away or never materialized on device — no empty.1 appears.

  3. torch.empty() + kSplitBoth together create empty.1: When torch.empty() is called per iteration AND input_output_aliases is non-empty, kSplitBoth forces the graph to materialize before the kernel, which forces the empty() to execute as a separate empty.1 broadcast op on the device.

The overhead scales linearly with tensor size: ~3us at 4MB, ~127us at 400MB.

Fix

This PR excludes output-only tensors from pallas_call inputs entirely. Since output-only tensors are never read by the kernel, they don't need to be donated — pallas_call returns new buffers for them, and set_() swaps the pre-allocated tensor's storage with the result (zero-copy).

With no output-only tensors in inputs, input_output_aliases becomes empty, kSplitBoth is not triggered, and the torch.empty_like() allocation is never forced to materialize on device.

Builds on #1984, which already eliminated VMEM pressure for output-only tensors via HBM in_specs. This PR goes further by removing them from inputs entirely.
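The storage swap at the heart of the fix can be sketched with plain torch ops. This is a minimal sketch, not the launcher codegen itself; `run_kernel` is a hypothetical stand-in for the pallas_call launcher, which now returns a freshly allocated result buffer instead of writing into a donated one:

```python
import torch

def run_kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the pallas_call launcher: like
    # pallas_call with no donated inputs, it returns a freshly
    # allocated result buffer.
    return torch.exp(x)

x = torch.randn(1024)
out = torch.empty_like(x)  # output-only: written by the kernel, never read

result = run_kernel(x)
out.set_(result)  # swap out's storage for the result buffer instead of copying

assert out.data_ptr() == result.data_ptr()  # zero-copy: shared storage
```

After set_(), the pre-allocated tensor simply adopts the result's storage, and the original empty_like allocation is dropped without any device copy.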

Benchmark Results

exp_fwd kernel on TPU v7, N=104,857,600. See #1773 for benchmark methodology.

Implementation              | Wall (ms) | Throughput (ms) | Device (ms, xprof)
torch.exp                   | 0.508     | 0.281           | 0.259
helion exp_fwd (g=50)       | 0.507     | 0.293           | 0.260
jax.numpy.exp               | 0.445     | 0.274           | 0.263
pallas exp jax (g=50)       | 0.453     | 0.271           | 0.260
pallas exp torch_tpu (g=50) | 0.523     | 0.280           | 0.260

Device time for helion exp_fwd is 0.260 ms — no empty.1 broadcast op. Matches native torch.exp (0.259 ms) and pure JAX Pallas (0.260 ms).

Benchmark script: https://gist.github.com/norx1991/091794b532b2a86650befc449f868c68

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 27, 2026
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch 3 times, most recently from d6cc98c to f8a28ad on March 27, 2026 23:14
@norx1991 norx1991 changed the title from "Exclude output-only tensors from Pallas pallas_call inputs to save VMEM" to "[Pallas] Exclude output-only tensors from Pallas pallas_call inputs to save VMEM and improve performance" on Mar 27, 2026
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch from f8a28ad to 52412bb on April 2, 2026 20:58
Collaborator

@thcmbs thcmbs left a comment

Thanks for sharing this! Makes sense to me. Do we have a larger set of benchmarks we can run to make sure this does not introduce regressions, or to see whether it improves other examples?

Comment thread helion/_compiler/backend.py Outdated

from .ast_read_writes import ReadWrites

def _empty_allocated_vars(body: list[_ast.stmt]) -> set[str]:
Collaborator

Fme: will this catch tuples made of empty tensors? Should we cover it? (There is no correctness issue afaict, simply wondering about a potential missed opportunity)

Contributor Author

Thank you! This is a good point. I think we should cover it. Let me try finalizing this PR and see if it is easy to add. Otherwise, it can be a follow-up.
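As context for the tuple question above, the kind of allocation analysis under discussion can be sketched with the stdlib ast module. This is a hypothetical, simplified version, not the actual _empty_allocated_vars implementation (which also cross-checks reads via ReadWrites); it deliberately handles only direct `name = <empty call>` assignments, so tuples of empty tensors are the gap the thread points out:

```python
import ast

_EMPTY_FNS = {"empty", "empty_like", "new_empty"}

def empty_allocated_vars(source: str) -> set[str]:
    """Collect names assigned directly from empty/empty_like/new_empty calls.

    Simplified sketch: tuple targets (e.g. `a, b = make_outputs()`) are
    not caught, which is the missed opportunity raised in this thread.
    """
    found: set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            fn = node.value.func
            # torch.empty_like(x) -> Attribute; empty_like(x) -> Name
            name = fn.attr if isinstance(fn, ast.Attribute) else getattr(fn, "id", None)
            if name in _EMPTY_FNS:
                for tgt in node.targets:
                    if isinstance(tgt, ast.Name):
                        found.add(tgt.id)
    return found

src = "out = torch.empty_like(x)\ny = torch.exp(x)\n"
print(empty_allocated_vars(src))  # -> {'out'}
```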

@norx1991
Contributor Author

norx1991 commented Apr 7, 2026

Thanks for sharing this! Makes sense to me. Do we have a larger set of benchmarks we can run to make sure this does not introduce regressions, or to see whether it improves other examples?

@thcmbs Thanks for taking a look! Yeah that's what I am trying to get in #1913. We currently have 9 examples that can run through autotuning. Let me patch this and see if this provides meaningful improvement for other examples.

@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch 2 times, most recently from f42a4f1 to 5b53b6c Compare April 7, 2026 23:04
@norx1991 norx1991 marked this pull request as ready for review April 7, 2026 23:09
@norx1991
Contributor Author

norx1991 commented Apr 8, 2026

Did some initial testing by running autotune in #1913.

Kernel           | Without vmem-fix | With vmem-fix | Delta
exp              | 0.62x            | 0.78x         | +26%
add              | 0.62x            | 0.79x         | +27%
softmax_two_pass | 0.62x            | 0.79x         | +27%
welford          | 0.71x            | 0.60x         | -15%
attention        | 0.82x            | 0.80x         | -2%
bmm              | 0.66x            | 0.74x         | +12%
geglu            | 0.17x            | 0.16x         | ~same
low_mem_dropout  | 0.60x            | 0.61x         | ~same
swiglu           | 0.16x            | 0.17x         | ~same

Comment thread helion/_compiler/backend.py Outdated
Comment thread helion/runtime/__init__.py Outdated
registry["cpu"] = lambda: _get_tpu_info_impl(ChipVersion.TPU_7X, 1)


def _pallas_invoke_and_copy_back(
Contributor

Does this mean that even though we're not duplicating VMEM, we end up creating a new HBM copy of the result, which we copy back into our original tensor?

For a helion kernel like

@helion.kernel
def fn(input):
    output = empty_like(input)
    for t in hl.tile(N):
        output[t] = f(input[t])
    return output

IIUC, Prior to this PR, we were generating something like this

def _helion_fn(input, output):
    output[:] = f(input[:])

def reordered_kernel(input, output, output_alias):
    output_alias[...] = output[...]
    _helion_fn(input, output_alias)

def fn(input):
   output = empty_like(input)
   pallas_call(reordered_kernel, input, output, input_output_aliases={1:2} )
   return output

And with this PR, we generate something like

def _helion_fn(input, output):
    output[:] = f(input[:])

def reordered_kernel(input, output):
    _helion_fn(input, output)

def fn(input):
   output = empty_like(input)
   tmp_result = pallas_call(reordered_kernel, input,)
   output = tmp_result # copy_back
   return output

But given that you've already analyzed which are the output-only tensors, perhaps we can avoid the additional HBM copy as well, and just do

def _helion_fn(input, output):
    output[:] = f(input[:])

def reordered_kernel(input, output):
    _helion_fn(input, output)

def fn(input):
   output = pallas_call(reordered_kernel, input,)
   return output

?

Contributor Author

This is a very good point! For now the launcher has to copy the pallas_call result back into it. Eliminating this would require non-trivial codegen change. Let's think about it as a follow-up optimization.

Contributor

@norx1991 One idea to try, should be a simple change: after identifying the tensors which are "output only", lets still donate them to the kernel via input_output_aliases, but, for the input argument, lets change its memory space to HBM, and lets skip the output_alias[...] = output[...] step, here. If this works, this should allow us to duplicate neither VMEM nor HBM. Would you mind giving this a try?

Contributor Author

@norx1991 norx1991 Apr 8, 2026

Sure thing, I can give it a try. One clarification is probably needed first: since torch_tpu splits the XLA graph before AND after the custom kernel whenever input_output_aliases is non-empty (inserting the empty.1 broadcast op that adds overhead), wouldn't still using input_output_aliases defeat the purpose of the device-time improvement?

Contributor

@AmesingFlank AmesingFlank Apr 8, 2026

I see, thanks for the clarification. In that case perhaps we should take a two-step approach:

  1. this PR: for output-only indices, donate the tensor, but mark it as HBM to save VMEM
  2. follow-up: avoid creating the output tensor manually, instead take the return value from JAX and convert that to a tensor (as discussed above).

Would that be a reasonable plan?

Contributor Author

Thank you! Step 1 seems like a good starting point. Created #1984 for this. Will investigate on Step 2.

Contributor Author

Hi @AmesingFlank ,

I quickly tried the Step 2 approach in #1998 (with the help of Claude Code, of course). It is still a draft, but it gives an idea of the scope: the change touches many places, and we need to deal with the mock launcher as well.

I later found that torch.Tensor.set_ can also avoid the HBM copy, and the change is simpler, so I think taking the approach in this PR and using torch.Tensor.set_ is the simpler path forward. WDYT?

Contributor Author

Switched to use set_ and rebased.

Contributor Author

Hi @AmesingFlank , this PR and #2022 are ready for review. Thanks!

@norx1991 norx1991 requested a review from AmesingFlank April 8, 2026 06:32
@norx1991 norx1991 marked this pull request as draft April 8, 2026 17:55
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch 4 times, most recently from 43ff6c4 to 2e5e2ed on April 10, 2026 21:09
@norx1991 norx1991 changed the title from "[Pallas] Exclude output-only tensors from Pallas pallas_call inputs to save VMEM and improve performance" to "[Pallas] Exclude output-only tensors from Pallas pallas_call inputs to improve performance" on Apr 10, 2026
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch 5 times, most recently from bc8ed3d to c945401 on April 10, 2026 21:33
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch from c945401 to e8f8863 on April 10, 2026 21:36
@norx1991 norx1991 marked this pull request as ready for review April 10, 2026 21:40
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch from e8f8863 to 557db9d on April 13, 2026 17:23
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch from 557db9d to aa6ec67 on April 16, 2026 17:32
Output-only tensors (allocated with empty/empty_like/new_empty and
never read by the kernel) are excluded from pallas_call inputs.
Uses set_() to swap the pre-allocated tensor's storage with the
pallas_call result (zero-copy).

This eliminates the OpSplitMode::kSplitBoth graph split in torch_tpu
(empty input_output_aliases for output-only kernels), removing the
~127 us empty.1 broadcast overhead.
@norx1991 norx1991 force-pushed the yifeixu/pallas-vmem-fix branch from aa6ec67 to 18f1e06 on April 20, 2026 17:57
@norx1991 norx1991 merged commit e1eaf5c into main Apr 21, 2026
22 checks passed
@norx1991 norx1991 deleted the yifeixu/pallas-vmem-fix branch April 21, 2026 20:23

Labels

CLA Signed This label is managed by the Meta Open Source bot.

3 participants