thcmbs left a comment:
Thanks for sharing this! Makes sense to me. Do we have a larger set of benchmark we can run to make sure this does not introduce regressions / or whether it improves other examples?
Review comment on:

```python
from .ast_read_writes import ReadWrites

def _empty_allocated_vars(body: list[_ast.stmt]) -> set[str]:
```
Will this catch tuples made of empty tensors? Should we cover it? (There is no correctness issue AFAICT; simply wondering about a potential missed opportunity.)
Thank you! This is a good point. I think we should cover it. Let me try finalizing this PR and see if it is easy to add. Otherwise, it can be a follow-up.
@thcmbs Thanks for taking a look! Yeah, that's what I'm trying to get in #1913. We currently have 9 examples that can run through autotuning. Let me patch this and see if it provides a meaningful improvement for the other examples.
Did some initial testing by running autotune in #1913.
Review comment on:

```python
registry["cpu"] = lambda: _get_tpu_info_impl(ChipVersion.TPU_7X, 1)

def _pallas_invoke_and_copy_back(
```
Does this mean that even though we're not duplicating VMEMs, we end up creating a new HBM copy of the result, which we copy back into our original tensor?

For a Helion kernel like

```python
@helion.kernel
def fn(input):
    output = empty_like(input)
    for t in hl.tile(N):
        output[t] = f(input[t])
    return output
```

IIUC, prior to this PR, we were generating something like this:

```python
def _helion_fn(input, output):
    output[:] = f(input[:])

def reordered_kernel(input, output, output_alias):
    output_alias[...] = output[...]
    _helion_fn(input, output_alias)

def fn(input):
    output = empty_like(input)
    pallas_call(reordered_kernel, input, output, input_output_aliases={1: 2})
    return output
```

And with this PR, we generate something like:

```python
def _helion_fn(input, output):
    output[:] = f(input[:])

def reordered_kernel(input, output):
    _helion_fn(input, output)

def fn(input):
    output = empty_like(input)
    tmp_result = pallas_call(reordered_kernel, input)
    output = tmp_result  # copy_back
    return output
```

But given that you've already analyzed which tensors are output-only, perhaps we can avoid the additional HBM copy as well, and just do:

```python
def _helion_fn(input, output):
    output[:] = f(input[:])

def reordered_kernel(input, output):
    _helion_fn(input, output)

def fn(input):
    output = pallas_call(reordered_kernel, input)
    return output
```

?
This is a very good point! For now, the launcher has to copy the pallas_call result back into the pre-allocated tensor. Eliminating this would require a non-trivial codegen change. Let's treat it as a follow-up optimization.
@norx1991 One idea to try, which should be a simple change: after identifying the tensors that are "output only", let's still donate them to the kernel via `input_output_aliases`, but change the input argument's memory space to HBM and skip the `output_alias[...] = output[...]` step here. If this works, it should allow us to duplicate neither VMEM nor HBM. Would you mind giving this a try?
Sure thing, I can give it a try. One clarification first: since torch_tpu splits the XLA graph before AND after the custom kernel when `input_output_aliases` is non-empty, inserting an `empty.1` broadcast op that adds overhead, wouldn't still using `input_output_aliases` defeat the purpose of the device-time improvement?
I see, thanks for the clarification. In that case, perhaps we should take a two-step approach:
- this PR: for output-only indices, donate the tensor, but mark it as HBM to save VMEM
- follow-up: avoid creating the output tensor manually; instead, take the return value from JAX and convert it to a tensor (as discussed above).

Would that be a reasonable plan?
Thank you! Step 1 seems like a good starting point. Created #1984 for it. Will investigate Step 2.
Hi @AmesingFlank,
I quickly tried the Step 2 approach in #1998 (with the help of Claude Code, of course). It is still a draft, but it gives an idea of the scope: the change touches many places, and we need to handle the mock launcher as well.
I later found that `torch.Tensor.set_` can also avoid the HBM copy with a much simpler change, so I think keeping this PR's approach and using `torch.Tensor.set_` is the simpler path forward. WDYT?
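To illustrate the trade-off with a stand-in class (not torch itself): `torch.Tensor.set_` rebinds a tensor to another tensor's storage without copying, whereas a copy-back duplicates the buffer. A pure-Python sketch:

```python
class FakeTensor:
    """Stand-in for torch.Tensor, tracking storage identity only."""

    def __init__(self, storage):
        self.storage = storage

    def copy_(self, other):
        # copy-back: duplicates the buffer (extra HBM traffic on TPU)
        self.storage = list(other.storage)

    def set_(self, other):
        # set_: share the other tensor's storage (zero-copy)
        self.storage = other.storage


pallas_result = FakeTensor([1, 2, 3])

out_copy = FakeTensor([0, 0, 0])
out_copy.copy_(pallas_result)
print(out_copy.storage is pallas_result.storage)  # False: separate buffer

out_set = FakeTensor([0, 0, 0])
out_set.set_(pallas_result)
print(out_set.storage is pallas_result.storage)   # True: shared storage
```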
Switched to use set_ and rebased.
Hi @AmesingFlank, this PR and #2022 are ready for review. Thanks!
Output-only tensors (allocated with empty/empty_like/new_empty and never read by the kernel) are excluded from pallas_call inputs. Uses set_() to swap the pre-allocated tensor's storage with the pallas_call result (zero-copy). This eliminates the OpSplitMode::kSplitBoth graph split in torch_tpu (empty input_output_aliases for output-only kernels), removing the ~127 us empty.1 broadcast overhead.
Summary

When a Helion kernel allocates an output tensor that is only written to (never read), the Pallas backend previously passed that tensor as both a pallas_call input (for donation via `input_output_aliases`) and an output. This triggered `OpSplitMode::kSplitBoth` in torch_tpu (pallas_py.cc:110), which splits the XLA graph before and after the custom kernel. This forced the `torch.empty_like()` allocation, which would normally be elided or fused by XLA, to materialize as a separate device op (an `empty.1` broadcast), adding ~127 us of overhead per kernel call at large tensor sizes.

Root cause investigation

We investigated the `empty.1` overhead by isolating each factor:

- `kSplitBoth` alone doesn't insert `empty.1`: passing a pre-existing tensor with `input_output_aliases` (without allocating a new one per call) shows no `empty.1` in xprof traces; `kSplitBoth` just splits the graph.
- `torch.empty()` alone doesn't create a device op: without aliasing, `torch.empty()` is either fused away or never materialized on device, and no `empty.1` appears.
- `torch.empty()` + `kSplitBoth` together create `empty.1`: when `torch.empty()` is called per iteration AND `input_output_aliases` is non-empty, `kSplitBoth` forces the graph to materialize before the kernel, which forces the `empty()` to execute as a separate `empty.1` broadcast op on the device.

The overhead scales linearly with tensor size: ~3 us at 4 MB, ~127 us at 400 MB.
Fix

This PR excludes output-only tensors from pallas_call inputs entirely. Since output-only tensors are never read by the kernel, they don't need to be donated: pallas_call returns new buffers for them, and `set_()` swaps the pre-allocated tensor's storage with the result (zero-copy).

With no output-only tensors in the inputs, `input_output_aliases` becomes empty, `kSplitBoth` is not triggered, and the `torch.empty_like()` allocation is never forced to materialize on device.

Builds on #1984, which already eliminated VMEM pressure for output-only tensors via HBM `in_specs`. This PR goes further by removing them from the inputs entirely.
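A minimal sketch of the partitioning this fix describes; the function name and data shapes are hypothetical, not the PR's actual code:

```python
def partition_kernel_tensors(tensors, read_names):
    """Split kernel tensors into pallas_call inputs vs output-only results.

    tensors: list of (name, is_empty_alloc) pairs; read_names: names the
    kernel body reads. Output-only tensors (empty-allocated, never read)
    are excluded from the inputs, so input_output_aliases stays empty.
    """
    inputs, output_only = [], []
    for name, is_empty_alloc in tensors:
        if is_empty_alloc and name not in read_names:
            output_only.append(name)  # never read: not donated, not passed in
        else:
            inputs.append(name)
    return inputs, output_only

ins, outs = partition_kernel_tensors(
    [("x", False), ("out", True)], read_names={"x"}
)
print(ins, outs)  # ['x'] ['out']
```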
Benchmark Results

`exp_fwd` kernel on TPU v7, N=104,857,600. See #1773 for benchmark methodology.

Device time for the Helion `exp_fwd` is 0.260 ms, with no `empty.1` broadcast op. This matches native `torch.exp` (0.259 ms) and pure JAX Pallas (0.260 ms).

Benchmark script: https://gist.github.com/norx1991/091794b532b2a86650befc449f868c68