From f46b650f3d4fcbbaf3e5c69f6f5d01fe41ee8b8f Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Tue, 14 Apr 2026 13:23:38 -0700 Subject: [PATCH 1/3] [Pallas] Exclude output-only tensors from pallas_call inputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Output-only tensors (allocated with empty/empty_like/new_empty and never read by the kernel) are excluded from pallas_call inputs. No allocation is needed — the output tensor position in launcher args is replaced with a None placeholder, and shape/dtype info is passed via _output_only_shapes. The launcher returns new tensors from pallas_call, and the codegen captures the return value. This eliminates: - OpSplitMode::kSplitBoth graph split (empty input_output_aliases) - The ~127 us empty.1 broadcast overhead - Double HBM allocation (no pre-allocated tensor needed) --- helion/_compiler/backend.py | 27 +- helion/_compiler/device_function.py | 22 +- helion/_compiler/generate_ast.py | 20 ++ helion/runtime/__init__.py | 372 +++++++++++++++++++--------- 4 files changed, 321 insertions(+), 120 deletions(-) diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py index 6bba27140..b709bd48f 100644 --- a/helion/_compiler/backend.py +++ b/helion/_compiler/backend.py @@ -1527,8 +1527,33 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: output_indices.append(i) inplace_indices.append(i) - launcher_args = [*args, f"_output_indices={output_indices}"] + # Identify output-only args: remove from positional args and emit + # shape/dtype literals so the runtime can build out_shape without + # needing the actual tensors. 
+ output_only_set = set(output_indices) - set(inplace_indices) + output_only_names: list[str] = [] + output_only_shape_exprs: list[str] = [] + if sorted_args is not None: + for i in output_indices: + if i in output_only_set: + arg = sorted_args[i] + assert isinstance(arg, TensorArg) + output_only_names.append(arg.host_str()) + shape = tuple(arg.fake_value.shape) + dtype = arg.fake_value.dtype + output_only_shape_exprs.append(f"({shape}, {dtype})") + self._output_only_names = output_only_names + # Replace output-only tensors with None placeholders to keep + # indices aligned with _output_indices. + oo_name_set = set(output_only_names) + placeholder_args = ["None" if a in oo_name_set else a for a in args] + + launcher_args = [*placeholder_args, f"_output_indices={output_indices}"] launcher_args.append(f"_inplace_indices={inplace_indices}") + if output_only_shape_exprs: + launcher_args.append( + f"_output_only_shapes=[{', '.join(output_only_shape_exprs)}]" + ) if has_rng_ops: launcher_args.insert(-1, "_rng_seed_buffer") diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 827481454..65810d2d8 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -824,11 +824,25 @@ def codegen_function_call(self) -> ast.AST: has_barrier=env.has_barrier, sorted_args=arg_objects, ) - # TODO(jansel): we should run CSE this statement - call_statement = statement_from_string( - f"_launcher({self.name}, {{call_grid_expr}}, {', '.join(call_args)})", - call_grid_expr=call_grid_expr, + # Check if the backend wants to capture return values for output-only tensors. 
+ output_only_names = getattr(backend, "_output_only_names", []) + launcher_call = ( + f"_launcher({self.name}, {{call_grid_expr}}, {', '.join(call_args)})" ) + if output_only_names: + if len(output_only_names) == 1: + assign_target = output_only_names[0] + else: + assign_target = ", ".join(output_only_names) + call_statement = statement_from_string( + f"{assign_target} = {launcher_call}", + call_grid_expr=call_grid_expr, + ) + else: + call_statement = statement_from_string( + launcher_call, + call_grid_expr=call_grid_expr, + ) assert isinstance(call_statement, ExtendedAST) # Mark the kernel call we can find it in codegen_precompile_def call_statement._is_kernel_call = True diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 8040342cd..61108c51e 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -721,6 +721,26 @@ def generate_ast( kernel_def = codegen.device_function.codegen_function_def() codegen.host_dead_code_elimination() + # Remove dead allocations for output-only tensors. + # These are assignments like `out = torch.empty_like(x)` where + # `out` is overwritten by the launcher return before any read. 
+ output_only_names = getattr( + CompileEnvironment.current().backend, "_output_only_names", [] + ) + if output_only_names: + oo_set = set(output_only_names) + codegen.host_statements = [ + stmt + for stmt in codegen.host_statements + if not ( + isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and stmt.targets[0].id in oo_set + and not getattr(stmt, "_is_kernel_call", False) + ) + ] + # Inject RNG seed buffer creation if needed rng_statements = ( codegen.get_rng_seed_buffer_statements() diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 4c8025126..c26067999 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -2,6 +2,7 @@ from contextlib import suppress import contextvars +import hashlib import linecache import sys from typing import Any @@ -239,6 +240,14 @@ def index_map( ] +class _ShapeProxy: + """Lightweight proxy with .ndim and .shape for block spec construction.""" + + def __init__(self, shape: tuple[int, ...]) -> None: + self.shape = shape + self.ndim = len(shape) + + def _pallas_build_block_specs( pl: object, jnp: object, @@ -249,49 +258,67 @@ def _pallas_build_block_specs( output_indices: list[int], block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, - _output_only_set: set[int] | None = None, + output_only_indices: list[int] | None = None, + _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, ) -> tuple[list[object] | None, object | None]: """Build ``in_specs`` and ``out_specs`` for ``pl.pallas_call``. - Output-only tensors (in ``_output_only_set``) get HBM in_specs - to avoid VMEM pressure while keeping VMEM out_specs for writes. + ``block_spec_info`` is indexed by position among *all* tensor args. + ``output_only_indices`` lists tensor positions excluded from + ``tensor_arg_indices``; they are merged back to compute the mapping. 
+ Output-only tensors are not in ``args``; their shapes come from + ``_output_only_shapes``. """ if block_spec_info is None or len(grid) == 0: return None, None - output_only_set = _output_only_set or set() + all_positions = sorted(set(tensor_arg_indices) | set(output_only_indices or [])) + all_arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(all_positions)} + + # Build a lookup for output-only positions -> shape proxy + oo_set = set(output_only_indices or []) + oo_shape_map: dict[int, _ShapeProxy] = {} + if _output_only_shapes: + oo_iter = iter(_output_only_shapes) + for idx in sorted(oo_set): + shape, _dtype = next(oo_iter) + oo_shape_map[idx] = _ShapeProxy(shape) + + def _get_tensor_or_proxy(idx: int) -> torch.Tensor | _ShapeProxy: + if idx in oo_shape_map: + return oo_shape_map[idx] + return cast("torch.Tensor", args[idx]) in_specs = [] - for tensor_pos, idx in enumerate(tensor_arg_indices): - t = args[idx] - assert isinstance(t, torch.Tensor) - if idx in output_only_set: - in_specs.append(pl.BlockSpec(memory_space=pl.ANY)) # type: ignore[union-attr] - else: - should_use_smem = tensor_pos in (_smem_arg_indices or []) - in_specs.append( - _pallas_make_block_spec( - pl, jnp, pltpu, t, block_spec_info[tensor_pos], should_use_smem - ) - ) + for idx in tensor_arg_indices: + t = _get_tensor_or_proxy(idx) + tensor_pos = all_arg_to_tensor_pos[idx] + should_use_smem = tensor_pos in (_smem_arg_indices or []) + in_specs.append( + _pallas_make_block_spec( + pl, + jnp, + pltpu, + t, + block_spec_info[tensor_pos], + should_use_smem, + ) # pyrefly: ignore[bad-argument-type] + ) - arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} out_specs_list = [] for idx in output_indices: - t = args[idx] - assert isinstance(t, torch.Tensor) - # Output-only tensors keep VMEM out_specs so the kernel can write - # to them; only their in_specs use HBM to avoid VMEM pressure. 
- should_use_smem = arg_to_tensor_pos[idx] in (_smem_arg_indices or []) + t = _get_tensor_or_proxy(idx) + tensor_pos = all_arg_to_tensor_pos[idx] + should_use_smem = tensor_pos in (_smem_arg_indices or []) out_specs_list.append( _pallas_make_block_spec( pl, jnp, pltpu, t, - block_spec_info[arg_to_tensor_pos[idx]], + block_spec_info[tensor_pos], should_use_smem, - ) + ) # pyrefly: ignore[bad-argument-type] ) out_specs = out_specs_list if len(out_specs_list) > 1 else out_specs_list[0] @@ -308,6 +335,8 @@ def _pallas_build_pipeline_specs( output_indices: list[int], block_spec_info: _BlockSpecInfo | None, pipeline_arg_indices: list[int] | None, + output_only_indices: list[int] | None = None, + _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, ) -> tuple[list[object], object]: """Build in/out specs for pipeline launchers. @@ -315,17 +344,30 @@ def _pallas_build_pipeline_specs( All other tensors get proper BlockSpecs for automatic VMEM prefetch. """ pipeline_set = set(pipeline_arg_indices or []) - arg_to_tpos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} + all_positions = sorted(set(tensor_arg_indices) | set(output_only_indices or [])) + arg_to_tpos = {orig: tpos for tpos, orig in enumerate(all_positions)} + + # Build shape proxy lookup for output-only positions + oo_set = set(output_only_indices or []) + oo_shape_map: dict[int, _ShapeProxy] = {} + if _output_only_shapes: + oo_iter = iter(_output_only_shapes) + for idx in sorted(oo_set): + shape, _dtype = next(oo_iter) + oo_shape_map[idx] = _ShapeProxy(shape) def _spec_for(idx: int) -> object: if idx in pipeline_set: return pl.BlockSpec(memory_space=pltpu.HBM) # type: ignore[union-attr] if block_spec_info is not None: - t = args[idx] - assert isinstance(t, torch.Tensor) + t = oo_shape_map[idx] if idx in oo_shape_map else args[idx] return _pallas_make_block_spec( - pl, jnp, pltpu, t, block_spec_info[arg_to_tpos[idx]] - ) + pl, + jnp, + pltpu, + t, + 
block_spec_info[arg_to_tpos[idx]], + ) # pyrefly: ignore[bad-argument-type] return pl.BlockSpec(memory_space=pl.ANY) # type: ignore[union-attr] in_specs = [_spec_for(idx) for idx in tensor_arg_indices] @@ -362,25 +404,28 @@ def _pallas_check_dtypes(args: tuple[object, ...]) -> None: def _pallas_prepare_args( args: tuple[object, ...], _output_indices: list[int], + _inplace_indices: list[int] | None = None, + _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, ) -> tuple[ - set[int], + list[int], list[int], dict[int, object], int, dict[int, int], - list[object], set[int], tuple[object, ...], ]: """Extract and organize tensor/non-tensor args for Pallas launchers. + When output-only tensors are excluded from ``args``, they are passed + separately via ``_output_only_shapes`` for shape/dtype information. + Returns a tuple of: - - output_set: set of output arg positions - - tensor_arg_indices: positions of tensor args + - tensor_arg_indices: positions of tensor args passed as pallas_call inputs + - output_only_indices: original positions of output-only tensors - non_tensor_args: mapping of non-tensor arg positions to values - - n_tensor_inputs: count of tensor args + - n_tensor_inputs: count of tensor inputs (excl. output-only) - arg_to_tensor_pos: mapping from original position to tensor-only position - - outputs: list of output tensors - inplace_positions: positions that are both input and output - out_shapes: JAX placeholders for output shapes """ @@ -396,25 +441,47 @@ def _pallas_prepare_args( placeholder_fn = jax_placeholder output_set = set(_output_indices) + inplace_set = set(_inplace_indices) if _inplace_indices is not None else output_set + output_only = output_set - inplace_set + + # tensor_arg_indices: positions in args that are tensors (inputs only). + # Output-only positions have None placeholders in args. 
tensor_arg_indices = [ i for i in range(len(args)) if isinstance(args[i], torch.Tensor) ] + output_only_indices = sorted(output_only) + non_tensor_args: dict[int, object] = { - i: args[i] for i in range(len(args)) if not isinstance(args[i], torch.Tensor) + i: args[i] + for i in range(len(args)) + if not isinstance(args[i], torch.Tensor) and args[i] is not None } n_tensor_inputs = len(tensor_arg_indices) arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} - outputs = [args[i] for i in _output_indices] inplace_positions = output_set & set(tensor_arg_indices) - out_shapes = tuple(placeholder_fn(out) for out in outputs) # type: ignore[arg-type] + + # Build out_shapes for pallas_call. Inplace outputs are in args; + # output-only outputs use _output_only_shapes (shape, dtype) tuples. + import jax # type: ignore[import-not-found] + from torch._inductor.runtime.runtime_utils import torch_dtype_to_jax_runtime + + oo_iter = iter(_output_only_shapes or []) + out_shapes_list: list[object] = [] + for orig_pos in _output_indices: + if orig_pos in output_only: + shape, dtype = next(oo_iter) + jax_dtype = torch_dtype_to_jax_runtime(dtype) # pyrefly: ignore[bad-argument-type] + out_shapes_list.append(jax.ShapeDtypeStruct(shape, jax_dtype)) + else: + out_shapes_list.append(placeholder_fn(args[orig_pos])) # type: ignore[arg-type] + out_shapes = tuple(out_shapes_list) return ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, ) @@ -441,16 +508,15 @@ def _pallas_make_reordered_kernel( reordered args. *skip_inplace_copy* is a set of original-arg positions for which the - initial ``out_ref[...] = in_ref[...]`` copy should be skipped. This is - needed for outputs backed by HBM refs (``pl.ANY``) where direct - load/store is not allowed. Outputs with VMEM BlockSpecs still get the - copy so that ``input_output_aliases`` correctly preloads their contents. 
+ initial ``out_ref[...] = in_ref[...]`` copy should be skipped. Used by + pipeline/fori launchers for pipeline-body tensors backed by HBM refs + where direct load/store is not allowed. """ _skip_copy = skip_inplace_copy or set() def reordered_kernel(*refs: object) -> None: - n_kernel_params = len(args) - original_order: list[object] = [None] * n_kernel_params + _n_params = len(args) + original_order: list[object] = [None] * _n_params for tensor_pos, orig_pos in enumerate(tensor_arg_indices): original_order[orig_pos] = refs[tensor_pos] for orig_pos, value in non_tensor_args.items(): @@ -490,11 +556,19 @@ def _pallas_build_callable( """ def _make_interpret_callable() -> _PallasInterpretCallable: - output_tensor_positions = [ - arg_to_tensor_pos[orig_pos] for orig_pos in _output_indices + # Map (out_idx in _output_indices) -> tensor_pos for inplace outputs. + # out_idx must match jax_results ordering (all outputs), not filtered. + inplace_output_mapping = [ + (out_idx, arg_to_tensor_pos[orig_pos]) + for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos ] - callable_obj = _PallasInterpretCallable(jit_fn, output_tensor_positions) - setattr(pallas_kernel, cache_attr, (grid, callable_obj, tensor_arg_indices)) + callable_obj = _PallasInterpretCallable(jit_fn, inplace_output_mapping) + setattr( + pallas_kernel, + cache_attr, + (grid, callable_obj, tensor_arg_indices, arg_to_tensor_pos), + ) return callable_obj if _pallas_interpret_flag(): @@ -509,7 +583,8 @@ def _make_interpret_callable() -> _PallasInterpretCallable: call_aliases: dict[int, int] = {} for out_idx, orig_pos in enumerate(_output_indices): - call_aliases[arg_to_tensor_pos[orig_pos]] = out_idx + if orig_pos in arg_to_tensor_pos: + call_aliases[arg_to_tensor_pos[orig_pos]] = out_idx jax_callable = JaxCallable( name=kernel_name, @@ -517,38 +592,54 @@ def _make_interpret_callable() -> _PallasInterpretCallable: trace_key=f"{kernel_name}_{id(pallas_kernel)}_{grid}{trace_key_suffix}", 
input_output_aliases=call_aliases, ) - setattr(pallas_kernel, cache_attr, (grid, jax_callable, tensor_arg_indices)) + setattr( + pallas_kernel, + cache_attr, + (grid, jax_callable, tensor_arg_indices, arg_to_tensor_pos), + ) return jax_callable class _PallasInterpretCallable: """Thin wrapper that converts torch tensors <-> JAX arrays for interpret mode. - ``pallas_call`` with ``input_output_aliases`` returns new JAX arrays for the - outputs. This wrapper copies those results back into the original torch - output tensors (identified by ``output_tensor_positions``). + In interpret mode, ``pallas_call`` runs on CPU and returns JAX arrays. + This wrapper: + 1. Converts input torch tensors to JAX arrays + 2. Runs the pallas_call function + 3. For inplace outputs (donated tensors): copies JAX results back into + the original torch tensors via ``copy_()`` + 4. Returns raw JAX results so ``_pallas_invoke_and_return`` can + handle output-only tensors (which are not in the input list) + + ``inplace_output_mapping`` maps each inplace output to its JAX result: + a list of ``(out_idx, tensor_pos)`` where ``out_idx`` indexes into + ``jax_results`` and ``tensor_pos`` indexes into ``input_tensors``. """ def __init__( self, jit_fn: object, - output_tensor_positions: list[int], + inplace_output_mapping: list[tuple[int, int]], ) -> None: self._jit_fn = jit_fn - self._output_tensor_positions = output_tensor_positions + self._inplace_output_mapping = inplace_output_mapping - def __call__(self, *input_tensors: torch.Tensor) -> None: + def __call__(self, *input_tensors: torch.Tensor) -> tuple[object, ...]: jax_inputs = [_torch_to_jax(t) for t in input_tensors] jax_results = self._jit_fn(*jax_inputs) # type: ignore[operator] if not isinstance(jax_results, (tuple, list)): jax_results = (jax_results,) - # Write results back into the original output tensors. - for out_idx, tensor_pos in enumerate(self._output_tensor_positions): + # Write inplace results back into the original output tensors. 
+ for out_idx, tensor_pos in self._inplace_output_mapping: out_tensor = input_tensors[tensor_pos] result_data = _jax_to_torch( jax_results[out_idx], device=out_tensor.device, dtype=out_tensor.dtype ) out_tensor.copy_(result_data) + # Return JAX results so output-only tensors can be handled + # by _pallas_invoke_and_return. + return tuple(jax_results) def _pallas_interpret_flag() -> bool: @@ -579,16 +670,62 @@ def _ensure_cpu_tpu_info() -> None: registry["cpu"] = lambda: _get_tpu_info_impl(ChipVersion.TPU_7X, 1) +def _pallas_invoke_and_return( + jax_callable: object, + args: tuple[object, ...], + tensor_arg_indices: list[int], + arg_to_tensor_pos: dict[int, int], + _output_indices: list[int], +) -> object: + """Run the JaxCallable and return output-only results. + + Output-only tensors (those not in ``arg_to_tensor_pos``) are not passed + as pallas_call inputs, so the JaxCallable returns new buffers for them. + Returns a single tensor, a tuple of tensors, or None. + """ + input_tensors = [ + cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices + ] + results = jax_callable(*input_tensors) # type: ignore[operator] + if results is None: + return None + if not isinstance(results, (tuple, list)): + results = (results,) + output_only_results: list[torch.Tensor] = [] + for out_idx, orig_pos in enumerate(_output_indices): + if orig_pos not in arg_to_tensor_pos: + result = results[out_idx] + if not isinstance(result, torch.Tensor): + # Interpret mode: pallas_call returns JAX arrays, convert to torch. + # On TPU, JaxCallable returns torch tensors directly. 
+ import numpy as np + + jax_dtype = result.dtype # type: ignore[union-attr] + torch_dtype = torch.from_numpy(np.empty(0, dtype=jax_dtype)).dtype + result = _jax_to_torch( + result, + device=torch.device("cpu"), + dtype=torch_dtype, + ) + output_only_results.append(result) + if not output_only_results: + return None + if len(output_only_results) == 1: + return output_only_results[0] + return tuple(output_only_results) + + def default_pallas_launcher( pallas_kernel: object, grid: tuple[int, ...], *args: object, _output_indices: list[int] | None = None, _inplace_indices: list[int] | None = None, + _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, **kwargs: object, -) -> None: +) -> object: """Default launcher for Pallas kernels on TPU (or CPU with interpret=True). Uses ``JaxCallable`` from ``torch_tpu`` to compile and run the Pallas @@ -598,7 +735,8 @@ def default_pallas_launcher( buffers (zero-copy on TPU). Output-only tensors (in ``_output_indices`` but not in ``_inplace_indices``) - get HBM in_specs to avoid VMEM pressure while still being donated. + are excluded from pallas_call inputs to save VMEM. Their results are + returned as torch tensors. """ if _output_indices is None: _output_indices = [] @@ -607,30 +745,23 @@ def default_pallas_launcher( cache = getattr(pallas_kernel, "_pallas_cache", None) if cache is not None and cache[0] == grid: - _, jax_callable, tensor_arg_indices = cache + _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, - ) = _pallas_prepare_args(args, _output_indices) - - # Derive output-only set: outputs not in _inplace_indices. 
- inplace_set = ( - set(_inplace_indices) - if _inplace_indices is not None - else set(_output_indices) + ) = _pallas_prepare_args( + args, _output_indices, _inplace_indices, _output_only_shapes ) - output_only_set = set(_output_indices) - inplace_set in_specs, out_specs = _pallas_build_block_specs( pl, @@ -642,7 +773,8 @@ def default_pallas_launcher( _output_indices, _block_spec_info, _smem_arg_indices, - output_only_set, + output_only_indices, + _output_only_shapes, ) reordered_kernel = _pallas_make_reordered_kernel( @@ -655,7 +787,6 @@ def default_pallas_launcher( inplace_positions, arg_to_tensor_pos, _smem_arg_indices=_smem_arg_indices, - skip_inplace_copy=output_only_set, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] @@ -663,6 +794,7 @@ def default_pallas_launcher( pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos } pallas_call_kwargs: dict[str, object] = { @@ -691,10 +823,9 @@ def default_pallas_launcher( cache_attr="_pallas_cache", ) - input_tensors = [ - cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices - ] - jax_callable(*input_tensors) # type: ignore[operator] + return _pallas_invoke_and_return( + jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + ) def default_pallas_pipeline_launcher( @@ -702,11 +833,13 @@ def default_pallas_pipeline_launcher( grid: tuple[int, ...], *args: object, _output_indices: list[int] | None = None, + _inplace_indices: list[int] | None = None, + _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str]] | None = None, _pipeline_arg_indices: list[int] | None = None, **kwargs: object, -) -> None: +) -> object: """Launcher for Pallas kernels using PrefetchScalarGridSpec with scratch memory. Used when ``pallas_loop_type='emit_pipeline'``. 
Pipeline-body tensors @@ -722,22 +855,23 @@ def default_pallas_pipeline_launcher( cache = getattr(pallas_kernel, "_pallas_pipeline_cache", None) if cache is not None and cache[0] == grid: - _, jax_callable, tensor_arg_indices = cache + _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, - ) = _pallas_prepare_args(args, _output_indices) + ) = _pallas_prepare_args( + args, _output_indices, _inplace_indices, _output_only_shapes + ) # Build scratch shapes for VMEM _jnp_dtype_map: dict[str, object] = { @@ -775,6 +909,8 @@ def default_pallas_pipeline_launcher( _output_indices, _block_spec_info, _pipeline_arg_indices, + output_only_indices, + _output_only_shapes, ) _pipeline_set = set(_pipeline_arg_indices or []) @@ -796,6 +932,7 @@ def default_pallas_pipeline_launcher( pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos } grid_spec = pltpu.PrefetchScalarGridSpec( @@ -833,10 +970,9 @@ def default_pallas_pipeline_launcher( trace_key_suffix="_pipeline", ) - input_tensors = [ - cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices - ] - jax_callable(*input_tensors) # type: ignore[operator] + return _pallas_invoke_and_return( + jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + ) def default_pallas_fori_launcher( @@ -844,10 +980,12 @@ def default_pallas_fori_launcher( grid: tuple[int, ...], *args: object, _output_indices: list[int] | None = None, + _inplace_indices: list[int] | None = None, + _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str | 
None, str]] | None = None, **kwargs: object, -) -> None: +) -> object: """Launcher for Pallas kernels using fori_loop with manual DMA. Used when ``pallas_loop_type="fori_loop"``. Passes all tensors as @@ -865,22 +1003,23 @@ def default_pallas_fori_launcher( cache = getattr(pallas_kernel, "_pallas_fori_cache", None) if cache is not None and cache[0] == grid: - _, jax_callable, tensor_arg_indices = cache + _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, - ) = _pallas_prepare_args(args, _output_indices) + ) = _pallas_prepare_args( + args, _output_indices, _inplace_indices, _output_only_shapes + ) # Build scratch shapes: VMEM buffers + DMA semaphores _jnp_dtype_map: dict[str, object] = { @@ -917,6 +1056,8 @@ def default_pallas_fori_launcher( _output_indices, _block_spec_info, _fori_pipeline_indices, # type: ignore[arg-type] + output_only_indices, + _output_only_shapes, ) _fori_pipeline_set = set(_fori_pipeline_indices or []) # type: ignore[arg-type] @@ -938,6 +1079,7 @@ def default_pallas_fori_launcher( pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos } grid_spec = pltpu.PrefetchScalarGridSpec( @@ -975,10 +1117,9 @@ def default_pallas_fori_launcher( trace_key_suffix="_fori", ) - input_tensors = [ - cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices - ] - jax_callable(*input_tensors) # type: ignore[operator] + return _pallas_invoke_and_return( + jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + ) def _torch_to_jax(t: torch.Tensor) -> object: @@ -1225,20 +1366,16 @@ def default_metal_launcher( metal_kernel: object, grid: tuple[int, ...], *args: 
object, - _block_dims: tuple[int, int, int] = (256, 1, 1), + _block_size: int = 256, **kwargs: object, ) -> None: """Default launcher for Metal kernels on Apple MPS devices. - The ``metal_kernel`` is a ``@metal_jit`` decorated function that - translates its Python AST body to MSL and compiles it via - ``torch.mps.compile_shader`` on each call. - This launcher dispatches the compiled kernel with the given grid and - threadgroup dimensions. + Compiles MSL source via ``torch.mps.compile_shader()`` and dispatches + using the compiled library. Caches the compiled library on the kernel + object to avoid recompilation on subsequent calls. - Uses a 3D threadgroup dispatch model: ``_block_dims`` specifies the - threadgroup size as ``(x, y, z)``. The grid specifies the number of - threadgroups per dimension. + Only 1D grids are currently supported. """ kwargs.pop("num_warps", None) kwargs.pop("num_stages", None) @@ -1247,15 +1384,20 @@ def default_metal_launcher( "metal", f"unexpected launcher kwargs: {sorted(kwargs)}" ) - lib, kernel_name = metal_kernel(*args) # type: ignore[operator] + assert len(grid) == 1, ( + f"Metal launcher only supports 1D grids, got {len(grid)}D: {grid}" + ) + + msl_source, kernel_name = metal_kernel() # type: ignore[operator] + source_hash = hashlib.sha256(msl_source.encode()).digest() + cache = getattr(metal_kernel, "_metal_cache", None) + if cache is not None and cache[0] == source_hash: + lib = cache[1] + else: + lib = torch.mps.compile_shader(msl_source) # type: ignore[attr-defined] + metal_kernel._metal_cache = (source_hash, lib) # type: ignore[attr-defined] tensor_args = [a for a in args if isinstance(a, torch.Tensor)] dispatch_fn = getattr(lib, kernel_name) - bx, by, bz = _block_dims - # Pad grid to 3D - gx = grid[0] if len(grid) > 0 else 1 - gy = grid[1] if len(grid) > 1 else 1 - gz = grid[2] if len(grid) > 2 else 1 - total_threads = (gx * bx, gy * by, gz * bz) - group_size = (bx, by, bz) - dispatch_fn(*tensor_args, threads=total_threads, 
group_size=group_size) + total_threads = grid[0] * _block_size + dispatch_fn(*tensor_args, threads=total_threads, group_size=_block_size) From 35c0472f20983d63a9d905d1c7a96227f03a1ef0 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Wed, 15 Apr 2026 10:25:56 -0700 Subject: [PATCH 2/3] [Pallas] Filter output-only tensors from args instead of using None placeholders The previous approach passed None placeholders in place of output-only tensors to keep positional indices aligned. This broke because: 1. The runtime used isinstance(args[i], torch.Tensor) to find tensor args, so None placeholders were silently skipped, causing tensor_arg_indices to use args-based positions rather than kernel parameter positions. Block specs, reordered_kernel, and input_output_aliases all index by kernel parameter position, leading to mismatches when output-only tensors appear before input tensors. 2. The reordered_kernel used len(args) to determine the number of kernel parameters, but with Nones removed by the isinstance check, this count was wrong for kernels with output-only tensors. 3. _pallas_build_block_specs and _pallas_build_pipeline_specs accessed args[idx] to get tensor shapes, but with None placeholders at output-only positions, this returned None instead of a tensor. The fix: - Codegen now filters out output-only tensors entirely and passes _n_kernel_params so the runtime knows the original parameter count. - _pallas_prepare_args maps between reduced-args positions and kernel positions, ensuring tensor_arg_indices, non_tensor_args, and arg_to_tensor_pos all use kernel positions consistently. - Block spec builders now receive input_tensors + arg_to_tensor_pos instead of the raw args tuple, eliminating direct positional access. - input_tensors is built outside the cache branch so it's always available for _pallas_invoke_and_return on cache hits. 
--- helion/_compiler/backend.py | 11 +-- helion/runtime/__init__.py | 129 ++++++++++++++++++++++++++---------- 2 files changed, 101 insertions(+), 39 deletions(-) diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py index b709bd48f..dd3067c15 100644 --- a/helion/_compiler/backend.py +++ b/helion/_compiler/backend.py @@ -1546,14 +1546,15 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: # Replace output-only tensors with None placeholders to keep # indices aligned with _output_indices. oo_name_set = set(output_only_names) - placeholder_args = ["None" if a in oo_name_set else a for a in args] + filtered_args = [a for a in args if a not in oo_name_set] - launcher_args = [*placeholder_args, f"_output_indices={output_indices}"] + launcher_args = [*filtered_args, f"_output_indices={output_indices}"] launcher_args.append(f"_inplace_indices={inplace_indices}") if output_only_shape_exprs: - launcher_args.append( - f"_output_only_shapes=[{', '.join(output_only_shape_exprs)}]" - ) + launcher_args.extend([ + f"_output_only_shapes=[{', '.join(output_only_shape_exprs)}]", + f"_n_kernel_params={len(args)}", + ]) if has_rng_ops: launcher_args.insert(-1, "_rng_seed_buffer") diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index c26067999..95a05d9b8 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -253,9 +253,10 @@ def _pallas_build_block_specs( jnp: object, pltpu: object, grid: tuple[int, ...], - args: tuple[object, ...], + input_tensors: list[torch.Tensor], tensor_arg_indices: list[int], output_indices: list[int], + arg_to_tensor_pos: dict[int, int], block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, output_only_indices: list[int] | None = None, @@ -264,10 +265,8 @@ def _pallas_build_block_specs( """Build ``in_specs`` and ``out_specs`` for ``pl.pallas_call``. ``block_spec_info`` is indexed by position among *all* tensor args. 
- ``output_only_indices`` lists tensor positions excluded from - ``tensor_arg_indices``; they are merged back to compute the mapping. - Output-only tensors are not in ``args``; their shapes come from - ``_output_only_shapes``. + ``input_tensors`` contains actual tensors for inputs; output-only + shapes come from ``_output_only_shapes``. """ if block_spec_info is None or len(grid) == 0: return None, None @@ -287,7 +286,7 @@ def _pallas_build_block_specs( def _get_tensor_or_proxy(idx: int) -> torch.Tensor | _ShapeProxy: if idx in oo_shape_map: return oo_shape_map[idx] - return cast("torch.Tensor", args[idx]) + return input_tensors[arg_to_tensor_pos[idx]] in_specs = [] for idx in tensor_arg_indices: @@ -330,9 +329,10 @@ def _pallas_build_pipeline_specs( jnp: object, pltpu: object, grid: tuple[int, ...], - args: tuple[object, ...], + input_tensors: list[torch.Tensor], tensor_arg_indices: list[int], output_indices: list[int], + arg_to_tensor_pos: dict[int, int], block_spec_info: _BlockSpecInfo | None, pipeline_arg_indices: list[int] | None, output_only_indices: list[int] | None = None, @@ -360,7 +360,11 @@ def _spec_for(idx: int) -> object: if idx in pipeline_set: return pl.BlockSpec(memory_space=pltpu.HBM) # type: ignore[union-attr] if block_spec_info is not None: - t = oo_shape_map[idx] if idx in oo_shape_map else args[idx] + t = ( + oo_shape_map[idx] + if idx in oo_shape_map + else input_tensors[arg_to_tensor_pos[idx]] + ) return _pallas_make_block_spec( pl, jnp, @@ -409,6 +413,7 @@ def _pallas_prepare_args( ) -> tuple[ list[int], list[int], + list[torch.Tensor], dict[int, object], int, dict[int, int], @@ -421,12 +426,13 @@ def _pallas_prepare_args( separately via ``_output_only_shapes`` for shape/dtype information. 
Returns a tuple of: - - tensor_arg_indices: positions of tensor args passed as pallas_call inputs - - output_only_indices: original positions of output-only tensors - - non_tensor_args: mapping of non-tensor arg positions to values + - tensor_arg_indices: kernel parameter positions of tensor inputs + - output_only_indices: kernel parameter positions of output-only tensors + - input_tensors: actual tensor objects for pallas_call inputs + - non_tensor_args: mapping of kernel positions to non-tensor values - n_tensor_inputs: count of tensor inputs (excl. output-only) - - arg_to_tensor_pos: mapping from original position to tensor-only position - - inplace_positions: positions that are both input and output + - arg_to_tensor_pos: mapping from kernel position to tensor-only position + - inplace_positions: kernel positions that are both input and output - out_shapes: JAX placeholders for output shapes """ from .settings import is_pallas_interpret @@ -444,17 +450,32 @@ def _pallas_prepare_args( inplace_set = set(_inplace_indices) if _inplace_indices is not None else output_set output_only = output_set - inplace_set - # tensor_arg_indices: positions in args that are tensors (inputs only). - # Output-only positions have None placeholders in args. - tensor_arg_indices = [ - i for i in range(len(args)) if isinstance(args[i], torch.Tensor) - ] + # Output-only tensors are removed from args. Build mappings between + # reduced args positions and original kernel parameter positions. output_only_indices = sorted(output_only) + def _args_pos_to_kernel_pos(args_pos: int) -> int: + """Map a position in reduced args to original kernel parameter position.""" + kernel_pos = args_pos + for oo_pos in output_only_indices: + if oo_pos <= kernel_pos: + kernel_pos += 1 + return kernel_pos + + # tensor_arg_indices uses KERNEL positions (for block specs, reordered kernel). + # kernel_to_args maps kernel positions back to args positions (for data access). 
+ kernel_to_args: dict[int, int] = {} + tensor_arg_indices = [] + for i in range(len(args)): + if isinstance(args[i], torch.Tensor): + kpos = _args_pos_to_kernel_pos(i) + tensor_arg_indices.append(kpos) + kernel_to_args[kpos] = i + non_tensor_args: dict[int, object] = { - i: args[i] + _args_pos_to_kernel_pos(i): args[i] for i in range(len(args)) - if not isinstance(args[i], torch.Tensor) and args[i] is not None + if not isinstance(args[i], torch.Tensor) } n_tensor_inputs = len(tensor_arg_indices) arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} @@ -470,15 +491,24 @@ def _pallas_prepare_args( for orig_pos in _output_indices: if orig_pos in output_only: shape, dtype = next(oo_iter) - jax_dtype = torch_dtype_to_jax_runtime(dtype) # pyrefly: ignore[bad-argument-type] + jax_dtype = torch_dtype_to_jax_runtime( + dtype + ) # pyrefly: ignore[bad-argument-type] out_shapes_list.append(jax.ShapeDtypeStruct(shape, jax_dtype)) else: - out_shapes_list.append(placeholder_fn(args[orig_pos])) # type: ignore[arg-type] + out_shapes_list.append(placeholder_fn(args[kernel_to_args[orig_pos]])) # type: ignore[arg-type] out_shapes = tuple(out_shapes_list) + # Build input_tensors from args using kernel_to_args mapping. + input_tensors = [ + cast("torch.Tensor", args[kernel_to_args[kpos]]).contiguous() + for kpos in tensor_arg_indices + ] + return ( tensor_arg_indices, output_only_indices, + input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, @@ -499,6 +529,7 @@ def _pallas_make_reordered_kernel( n_extra_refs: int = 0, skip_inplace_copy: set[int] | None = None, _smem_arg_indices: list[int] | None = None, + n_kernel_params: int | None = None, ) -> object: """Create a wrapper kernel that reorders pallas_call refs to the original arg order. 
@@ -515,7 +546,7 @@ def _pallas_make_reordered_kernel( _skip_copy = skip_inplace_copy or set() def reordered_kernel(*refs: object) -> None: - _n_params = len(args) + _n_params = n_kernel_params if n_kernel_params is not None else len(args) original_order: list[object] = [None] * _n_params for tensor_pos, orig_pos in enumerate(tensor_arg_indices): original_order[orig_pos] = refs[tensor_pos] @@ -672,8 +703,7 @@ def _ensure_cpu_tpu_info() -> None: def _pallas_invoke_and_return( jax_callable: object, - args: tuple[object, ...], - tensor_arg_indices: list[int], + input_tensors: list[torch.Tensor], arg_to_tensor_pos: dict[int, int], _output_indices: list[int], ) -> object: @@ -683,9 +713,6 @@ def _pallas_invoke_and_return( as pallas_call inputs, so the JaxCallable returns new buffers for them. Returns a single tensor, a tuple of tensors, or None. """ - input_tensors = [ - cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices - ] results = jax_callable(*input_tensors) # type: ignore[operator] if results is None: return None @@ -722,6 +749,7 @@ def default_pallas_launcher( _output_indices: list[int] | None = None, _inplace_indices: list[int] | None = None, _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, + _n_kernel_params: int | None = None, _block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, **kwargs: object, @@ -754,6 +782,7 @@ def default_pallas_launcher( ( tensor_arg_indices, output_only_indices, + input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, @@ -768,9 +797,10 @@ def default_pallas_launcher( jnp, pltpu, grid, - args, + input_tensors, tensor_arg_indices, _output_indices, + arg_to_tensor_pos, _block_spec_info, _smem_arg_indices, output_only_indices, @@ -787,6 +817,7 @@ def default_pallas_launcher( inplace_positions, arg_to_tensor_pos, _smem_arg_indices=_smem_arg_indices, + n_kernel_params=_n_kernel_params, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else 
out_shapes[0] @@ -823,8 +854,16 @@ def default_pallas_launcher( cache_attr="_pallas_cache", ) + # Build input_tensors from args (works for both cache hit and miss). + # tensor_arg_indices order matches the order of tensors in args. + input_tensors = [ + cast("torch.Tensor", a).contiguous() + for a in args + if isinstance(a, torch.Tensor) + ] + return _pallas_invoke_and_return( - jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + jax_callable, input_tensors, arg_to_tensor_pos, _output_indices ) @@ -835,6 +874,7 @@ def default_pallas_pipeline_launcher( _output_indices: list[int] | None = None, _inplace_indices: list[int] | None = None, _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, + _n_kernel_params: int | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str]] | None = None, _pipeline_arg_indices: list[int] | None = None, @@ -864,6 +904,7 @@ def default_pallas_pipeline_launcher( ( tensor_arg_indices, output_only_indices, + input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, @@ -904,9 +945,10 @@ def default_pallas_pipeline_launcher( jnp, pltpu, grid, - args, + input_tensors, tensor_arg_indices, _output_indices, + arg_to_tensor_pos, _block_spec_info, _pipeline_arg_indices, output_only_indices, @@ -925,6 +967,7 @@ def default_pallas_pipeline_launcher( arg_to_tensor_pos, n_extra_refs=len(scratch_shapes), skip_inplace_copy=_pipeline_set, + n_kernel_params=_n_kernel_params, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] @@ -970,8 +1013,15 @@ def default_pallas_pipeline_launcher( trace_key_suffix="_pipeline", ) + # Build input_tensors from args (works for both cache hit and miss). 
+ input_tensors = [ + cast("torch.Tensor", a).contiguous() + for a in args + if isinstance(a, torch.Tensor) + ] + return _pallas_invoke_and_return( - jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + jax_callable, input_tensors, arg_to_tensor_pos, _output_indices ) @@ -982,6 +1032,7 @@ def default_pallas_fori_launcher( _output_indices: list[int] | None = None, _inplace_indices: list[int] | None = None, _output_only_shapes: list[tuple[tuple[int, ...], object]] | None = None, + _n_kernel_params: int | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None, **kwargs: object, @@ -1012,6 +1063,7 @@ def default_pallas_fori_launcher( ( tensor_arg_indices, output_only_indices, + input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, @@ -1051,9 +1103,10 @@ def default_pallas_fori_launcher( jnp, pltpu, grid, - args, + input_tensors, tensor_arg_indices, _output_indices, + arg_to_tensor_pos, _block_spec_info, _fori_pipeline_indices, # type: ignore[arg-type] output_only_indices, @@ -1072,6 +1125,7 @@ def default_pallas_fori_launcher( arg_to_tensor_pos, n_extra_refs=len(scratch_shapes), skip_inplace_copy=_fori_pipeline_set, + n_kernel_params=_n_kernel_params, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] @@ -1117,8 +1171,15 @@ def default_pallas_fori_launcher( trace_key_suffix="_fori", ) + # Build input_tensors from args (works for both cache hit and miss). 
+ input_tensors = [ + cast("torch.Tensor", a).contiguous() + for a in args + if isinstance(a, torch.Tensor) + ] + return _pallas_invoke_and_return( - jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + jax_callable, input_tensors, arg_to_tensor_pos, _output_indices ) From 5b647a168f7c5e19f32cc9c50d10ef8a7dfa5c5b Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Wed, 15 Apr 2026 10:48:13 -0700 Subject: [PATCH 3/3] [Pallas] Add kernel-position index shifting and multi-input output-only test The runtime now maps between reduced-args positions and kernel parameter positions using _args_pos_to_kernel_pos, rather than assuming positions align directly. This is defensive: currently function-argument tensors are always traced before body-created tensors so output-only tensors end up last, but the shifting protects against future ordering changes. Also adds test_output_only_mid_position exercising the output-only flow with two inputs and one output-only tensor. Additionally, codegen now skips the output-only optimization when any output-only shape contains a torch.SymInt (symbolic dims cannot be emitted as compile-time literals) and merges those indices into _inplace_indices instead, and _pallas_prepare_args now returns kernel_to_args rather than input_tensors so that each launcher builds input_tensors itself.
+ if any(isinstance(s, torch.SymInt) for s in shape): + output_only_names.clear() + output_only_shape_exprs.clear() + inplace_indices = sorted(set(inplace_indices) | output_only_set) + break + output_only_names.append(arg.host_str()) dtype = arg.fake_value.dtype output_only_shape_exprs.append(f"({shape}, {dtype})") self._output_only_names = output_only_names @@ -1551,10 +1560,12 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: launcher_args = [*filtered_args, f"_output_indices={output_indices}"] launcher_args.append(f"_inplace_indices={inplace_indices}") if output_only_shape_exprs: - launcher_args.extend([ - f"_output_only_shapes=[{', '.join(output_only_shape_exprs)}]", - f"_n_kernel_params={len(args)}", - ]) + launcher_args.extend( + [ + f"_output_only_shapes=[{', '.join(output_only_shape_exprs)}]", + f"_n_kernel_params={len(args)}", + ] + ) if has_rng_ops: launcher_args.insert(-1, "_rng_seed_buffer") diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 95a05d9b8..326420f5a 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -298,10 +298,10 @@ def _get_tensor_or_proxy(idx: int) -> torch.Tensor | _ShapeProxy: pl, jnp, pltpu, - t, + t, # pyrefly: ignore[bad-argument-type] block_spec_info[tensor_pos], should_use_smem, - ) # pyrefly: ignore[bad-argument-type] + ) ) out_specs_list = [] @@ -314,10 +314,10 @@ def _get_tensor_or_proxy(idx: int) -> torch.Tensor | _ShapeProxy: pl, jnp, pltpu, - t, + t, # pyrefly: ignore[bad-argument-type] block_spec_info[tensor_pos], should_use_smem, - ) # pyrefly: ignore[bad-argument-type] + ) ) out_specs = out_specs_list if len(out_specs_list) > 1 else out_specs_list[0] @@ -369,9 +369,9 @@ def _spec_for(idx: int) -> object: pl, jnp, pltpu, - t, + t, # pyrefly: ignore[bad-argument-type] block_spec_info[arg_to_tpos[idx]], - ) # pyrefly: ignore[bad-argument-type] + ) return pl.BlockSpec(memory_space=pl.ANY) # type: ignore[union-attr] in_specs = [_spec_for(idx) for idx in 
tensor_arg_indices] @@ -413,10 +413,10 @@ def _pallas_prepare_args( ) -> tuple[ list[int], list[int], - list[torch.Tensor], dict[int, object], int, dict[int, int], + dict[int, int], set[int], tuple[object, ...], ]: @@ -428,10 +428,10 @@ def _pallas_prepare_args( Returns a tuple of: - tensor_arg_indices: kernel parameter positions of tensor inputs - output_only_indices: kernel parameter positions of output-only tensors - - input_tensors: actual tensor objects for pallas_call inputs - non_tensor_args: mapping of kernel positions to non-tensor values - n_tensor_inputs: count of tensor inputs (excl. output-only) - arg_to_tensor_pos: mapping from kernel position to tensor-only position + - kernel_to_args: mapping from kernel position to args position - inplace_positions: kernel positions that are both input and output - out_shapes: JAX placeholders for output shapes """ @@ -450,12 +450,12 @@ def _pallas_prepare_args( inplace_set = set(_inplace_indices) if _inplace_indices is not None else output_set output_only = output_set - inplace_set - # Output-only tensors are removed from args. Build mappings between - # reduced args positions and original kernel parameter positions. + # Output-only tensors are removed from args at codegen time. Build + # mappings between reduced-args positions and kernel parameter positions. 
output_only_indices = sorted(output_only) def _args_pos_to_kernel_pos(args_pos: int) -> int: - """Map a position in reduced args to original kernel parameter position.""" + """Map a position in reduced args to the original kernel parameter position.""" kernel_pos = args_pos for oo_pos in output_only_indices: if oo_pos <= kernel_pos: @@ -492,26 +492,20 @@ def _args_pos_to_kernel_pos(args_pos: int) -> int: if orig_pos in output_only: shape, dtype = next(oo_iter) jax_dtype = torch_dtype_to_jax_runtime( - dtype - ) # pyrefly: ignore[bad-argument-type] + dtype # pyrefly: ignore[bad-argument-type] + ) out_shapes_list.append(jax.ShapeDtypeStruct(shape, jax_dtype)) else: out_shapes_list.append(placeholder_fn(args[kernel_to_args[orig_pos]])) # type: ignore[arg-type] out_shapes = tuple(out_shapes_list) - # Build input_tensors from args using kernel_to_args mapping. - input_tensors = [ - cast("torch.Tensor", args[kernel_to_args[kpos]]).contiguous() - for kpos in tensor_arg_indices - ] - return ( tensor_arg_indices, output_only_indices, - input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, + kernel_to_args, inplace_positions, out_shapes, ) @@ -782,16 +776,21 @@ def default_pallas_launcher( ( tensor_arg_indices, output_only_indices, - input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, + kernel_to_args, inplace_positions, out_shapes, ) = _pallas_prepare_args( args, _output_indices, _inplace_indices, _output_only_shapes ) + input_tensors = [ + cast("torch.Tensor", args[kernel_to_args[kpos]]).contiguous() + for kpos in tensor_arg_indices + ] + in_specs, out_specs = _pallas_build_block_specs( pl, jnp, @@ -904,10 +903,10 @@ def default_pallas_pipeline_launcher( ( tensor_arg_indices, output_only_indices, - input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, + kernel_to_args, inplace_positions, out_shapes, ) = _pallas_prepare_args( @@ -940,6 +939,11 @@ def default_pallas_pipeline_launcher( pltpu.VMEM(shape, jnp_dtype) # pyrefly: 
ignore[bad-argument-type] ) + input_tensors = [ + cast("torch.Tensor", args[kernel_to_args[kpos]]).contiguous() + for kpos in tensor_arg_indices + ] + in_specs_list, out_specs = _pallas_build_pipeline_specs( pl, jnp, @@ -1063,10 +1067,10 @@ def default_pallas_fori_launcher( ( tensor_arg_indices, output_only_indices, - input_tensors, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, + kernel_to_args, inplace_positions, out_shapes, ) = _pallas_prepare_args( @@ -1095,6 +1099,11 @@ def default_pallas_fori_launcher( pltpu.VMEM(shape, jnp_dtype) # pyrefly: ignore[bad-argument-type] ) + input_tensors = [ + cast("torch.Tensor", args[kernel_to_args[kpos]]).contiguous() + for kpos in tensor_arg_indices + ] + # Build in_specs/out_specs: proper BlockSpecs for outer grid dims, # HBM refs for tensors used in the fori_loop body (DMA handles tiling). _fori_pipeline_indices = kwargs.get("_pipeline_arg_indices") diff --git a/test/test_pallas.py b/test/test_pallas.py index 4f4d3bbf4..b22be2a4c 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -807,6 +807,28 @@ def test_int64_tensor_raises(self) -> None: with self.assertRaises(TypeError, msg="does not support"): code_and_output(add_kernel, (x, y), block_size=128) + def test_output_only_mid_position(self) -> None: + """Output-only kernel with two inputs and one output-only tensor. + + Verifies that the index shifting in _pallas_prepare_args works + correctly when output-only tensors are filtered from args. + Currently, function-argument tensors are always traced before + body-created tensors, so the output ends up last; the shifting + is defensive for potential future ordering changes. 
+ """ + + @helion.kernel(backend="pallas", static_shapes=True) + def mid_output_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + for tile in hl.tile(x.size()): + out[tile] = x[tile] + y[tile] + return out + + x = torch.randn(1024, device=DEVICE, dtype=torch.float32) + y = torch.randn(1024, device=DEVICE, dtype=torch.float32) + code, result = code_and_output(mid_output_kernel, (x, y), block_sizes=[1024]) + torch.testing.assert_close(result, x + y) + def test_fori_loop_multidim(self) -> None: """Test fori_loop with a 2D inner loop (nested iteration).""" args = (