diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py index 8801bf393..72d600bc9 100644 --- a/helion/_compiler/backend.py +++ b/helion/_compiler/backend.py @@ -1436,8 +1436,8 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: output_indices: list[int] = [] # Indices of output tensors that are also read by the kernel # (inplace-mutated params or body-created tensors the kernel reads). - # These must use VMEM BlockSpecs. Output-only tensors (written but - # never read) get HBM in_specs to avoid VMEM pressure. + # Output-only tensors (written but never read) are excluded from + # pallas_call inputs and returned by the launcher instead. inplace_indices: list[int] = [] if sorted_args is not None: env = CompileEnvironment.current() @@ -1455,9 +1455,9 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: # # Only tensors allocated with torch.empty/empty_like/new_empty can be # output-only — their initial values are undefined, so it's safe - # to use HBM BlockSpecs. Tensors allocated with torch.zeros_like, - # torch.full, etc. have meaningful initial values that must be - # preserved via VMEM BlockSpecs. + # to skip passing them as pallas_call inputs. Tensors allocated + # with torch.zeros_like, torch.full, etc. have meaningful initial + # values that must be preserved via input_output_aliases (donation). empty_vars = _empty_allocated_vars(host_fn.body) kernel_reads: set[str] = set() for stmt in host_fn.body: @@ -1481,6 +1481,16 @@ def _empty_allocated_vars(body: list[ast.stmt]) -> set[str]: output_indices.append(i) inplace_indices.append(i) + # Identify output-only arg names for codegen to capture return values. + output_only_set = set(output_indices) - set(inplace_indices) + output_only_names: list[str] = [] + if sorted_args is not None: + for i in output_indices: + if i in output_only_set: + output_only_names.append(sorted_args[i].host_str()) + # Store for codegen_function_call to use. + self._output_only_names = output_only_names + launcher_args = [*args, f"_output_indices={output_indices}"] launcher_args.append(f"_inplace_indices={inplace_indices}") diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index e0bd2c99a..b6feb953c 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -791,11 +791,39 @@ def codegen_function_call(self) -> ast.AST: has_barrier=env.has_barrier, sorted_args=arg_objects, ) - # TODO(jansel): we should run CSE this statement - call_statement = statement_from_string( - f"_launcher({self.name}, {{call_grid_expr}}, {', '.join(call_args)})", - call_grid_expr=call_grid_expr, + # Check if the backend wants to capture return values for output-only tensors. + output_only_names = getattr(backend, "_output_only_names", []) + launcher_call = ( + f"_launcher({self.name}, {{call_grid_expr}}, {', '.join(call_args)})" ) + if output_only_names: + if len(output_only_names) == 1: + assign_target = output_only_names[0] + else: + assign_target = ", ".join(output_only_names) + # Assign launcher result to output-only vars. The `if` guard + # handles custom launchers that return None (e.g. fake-tensor + # inference mocks). + if len(output_only_names) == 1: + fallback = output_only_names[0] + else: + fallback = f"({', '.join(output_only_names)})" + # Emit two statements: the launcher call, then the conditional assign. + launcher_stmt = statement_from_string( + f"_launcher_result = {launcher_call}", + call_grid_expr=call_grid_expr, + ) + assert isinstance(launcher_stmt, ExtendedAST) + launcher_stmt._is_kernel_call = True + self.codegen.host_statements.append(launcher_stmt) + call_statement = statement_from_string( + f"{assign_target} = _launcher_result if _launcher_result is not None else {fallback}", + ) + else: + call_statement = statement_from_string( + launcher_call, + call_grid_expr=call_grid_expr, + ) assert isinstance(call_statement, ExtendedAST) # Mark the kernel call we can find it in codegen_precompile_def call_statement._is_kernel_call = True diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 64e337cee..83346c3c0 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -250,47 +250,45 @@ def _pallas_build_block_specs( output_indices: list[int], block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, - _output_only_set: set[int] | None = None, + output_only_indices: list[int] | None = None, ) -> tuple[list[object] | None, object | None]: """Build ``in_specs`` and ``out_specs`` for ``pl.pallas_call``. - Output-only tensors (in ``_output_only_set``) get HBM in_specs - to avoid VMEM pressure while keeping VMEM out_specs for writes. + ``block_spec_info`` is indexed by position among *all* tensor args. + ``output_only_indices`` lists tensor positions excluded from + ``tensor_arg_indices``; they are merged back to compute the mapping. """ if block_spec_info is None or len(grid) == 0: return None, None - output_only_set = _output_only_set or set() + all_positions = sorted(set(tensor_arg_indices) | set(output_only_indices or [])) + all_arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(all_positions)} in_specs = [] - for tensor_pos, idx in enumerate(tensor_arg_indices): + for idx in tensor_arg_indices: t = args[idx] assert isinstance(t, torch.Tensor) - if idx in output_only_set: - in_specs.append(pl.BlockSpec(memory_space=pl.ANY)) # type: ignore[union-attr] - else: - should_use_smem = tensor_pos in (_smem_arg_indices or []) - in_specs.append( - _pallas_make_block_spec( - pl, jnp, pltpu, t, block_spec_info[tensor_pos], should_use_smem - ) + tensor_pos = all_arg_to_tensor_pos[idx] + should_use_smem = tensor_pos in (_smem_arg_indices or []) + in_specs.append( + _pallas_make_block_spec( + pl, jnp, pltpu, t, block_spec_info[tensor_pos], should_use_smem ) + ) - arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} out_specs_list = [] for idx in output_indices: t = args[idx] assert isinstance(t, torch.Tensor) - # Output-only tensors keep VMEM out_specs so the kernel can write - # to them; only their in_specs use HBM to avoid VMEM pressure. - should_use_smem = arg_to_tensor_pos[idx] in (_smem_arg_indices or []) + tensor_pos = all_arg_to_tensor_pos[idx] + should_use_smem = tensor_pos in (_smem_arg_indices or []) out_specs_list.append( _pallas_make_block_spec( pl, jnp, pltpu, t, - block_spec_info[arg_to_tensor_pos[idx]], + block_spec_info[tensor_pos], should_use_smem, ) ) @@ -309,6 +307,7 @@ def _pallas_build_pipeline_specs( output_indices: list[int], block_spec_info: _BlockSpecInfo | None, pipeline_arg_indices: list[int] | None, + output_only_indices: list[int] | None = None, ) -> tuple[list[object], object]: """Build in/out specs for pipeline launchers. @@ -316,7 +315,8 @@ def _pallas_build_pipeline_specs( All other tensors get proper BlockSpecs for automatic VMEM prefetch. """ pipeline_set = set(pipeline_arg_indices or []) - arg_to_tpos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} + all_positions = sorted(set(tensor_arg_indices) | set(output_only_indices or [])) + arg_to_tpos = {orig: tpos for tpos, orig in enumerate(all_positions)} def _spec_for(idx: int) -> object: if idx in pipeline_set: @@ -351,25 +351,24 @@ def _jax_placeholder_for_tensor(t: torch.Tensor) -> object: def _pallas_prepare_args( args: tuple[object, ...], _output_indices: list[int], + _inplace_indices: list[int] | None = None, ) -> tuple[ - set[int], + list[int], list[int], dict[int, object], int, dict[int, int], - list[object], set[int], tuple[object, ...], ]: """Extract and organize tensor/non-tensor args for Pallas launchers. Returns a tuple of: - - output_set: set of output arg positions - - tensor_arg_indices: positions of tensor args + - tensor_arg_indices: positions of tensor args passed as pallas_call inputs + - output_only_indices: positions of output-only tensors (excluded from inputs) - non_tensor_args: mapping of non-tensor arg positions to values - - n_tensor_inputs: count of tensor args + - n_tensor_inputs: count of tensor inputs (excl. output-only) - arg_to_tensor_pos: mapping from original position to tensor-only position - - outputs: list of output tensors - inplace_positions: positions that are both input and output - out_shapes: JAX placeholders for output shapes """ @@ -385,25 +384,29 @@ def _pallas_prepare_args( placeholder_fn = jax_placeholder output_set = set(_output_indices) - tensor_arg_indices = [ + inplace_set = set(_inplace_indices) if _inplace_indices is not None else output_set + output_only = output_set - inplace_set + + all_tensor_positions = [ i for i in range(len(args)) if isinstance(args[i], torch.Tensor) ] + output_only_indices = [i for i in all_tensor_positions if i in output_only] + tensor_arg_indices = [i for i in all_tensor_positions if i not in output_only] + non_tensor_args: dict[int, object] = { i: args[i] for i in range(len(args)) if not isinstance(args[i], torch.Tensor) } n_tensor_inputs = len(tensor_arg_indices) arg_to_tensor_pos = {orig: tpos for tpos, orig in enumerate(tensor_arg_indices)} - outputs = [args[i] for i in _output_indices] inplace_positions = output_set & set(tensor_arg_indices) - out_shapes = tuple(placeholder_fn(out) for out in outputs) # type: ignore[arg-type] + out_shapes = tuple(placeholder_fn(args[i]) for i in _output_indices) # type: ignore[arg-type] return ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, ) @@ -430,10 +433,9 @@ def _pallas_make_reordered_kernel( reordered args. *skip_inplace_copy* is a set of original-arg positions for which the - initial ``out_ref[...] = in_ref[...]`` copy should be skipped. This is - needed for outputs backed by HBM refs (``pl.ANY``) where direct - load/store is not allowed. Outputs with VMEM BlockSpecs still get the - copy so that ``input_output_aliases`` correctly preloads their contents. + initial ``out_ref[...] = in_ref[...]`` copy should be skipped. Used by + pipeline/fori launchers for pipeline-body tensors backed by HBM refs + where direct load/store is not allowed. """ _skip_copy = skip_inplace_copy or set() @@ -479,11 +481,19 @@ def _pallas_build_callable( """ def _make_interpret_callable() -> _PallasInterpretCallable: - output_tensor_positions = [ - arg_to_tensor_pos[orig_pos] for orig_pos in _output_indices + # Map (out_idx in _output_indices) -> tensor_pos for inplace outputs. + # out_idx must match jax_results ordering (all outputs), not filtered. + inplace_output_mapping = [ + (out_idx, arg_to_tensor_pos[orig_pos]) + for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos ] - callable_obj = _PallasInterpretCallable(jit_fn, output_tensor_positions) - setattr(pallas_kernel, cache_attr, (grid, callable_obj, tensor_arg_indices)) + callable_obj = _PallasInterpretCallable(jit_fn, inplace_output_mapping) + setattr( + pallas_kernel, + cache_attr, + (grid, callable_obj, tensor_arg_indices, arg_to_tensor_pos), + ) return callable_obj if _pallas_interpret_flag(): @@ -498,7 +508,8 @@ def _make_interpret_callable() -> _PallasInterpretCallable: call_aliases: dict[int, int] = {} for out_idx, orig_pos in enumerate(_output_indices): - call_aliases[arg_to_tensor_pos[orig_pos]] = out_idx + if orig_pos in arg_to_tensor_pos: + call_aliases[arg_to_tensor_pos[orig_pos]] = out_idx jax_callable = JaxCallable( name=kernel_name, @@ -506,7 +517,11 @@ def _make_interpret_callable() -> _PallasInterpretCallable: trace_key=f"{kernel_name}_{id(pallas_kernel)}_{grid}{trace_key_suffix}", input_output_aliases=call_aliases, ) - setattr(pallas_kernel, cache_attr, (grid, jax_callable, tensor_arg_indices)) + setattr( + pallas_kernel, + cache_attr, + (grid, jax_callable, tensor_arg_indices, arg_to_tensor_pos), + ) return jax_callable @@ -514,30 +529,38 @@ class _PallasInterpretCallable: """Thin wrapper that converts torch tensors <-> JAX arrays for interpret mode. ``pallas_call`` with ``input_output_aliases`` returns new JAX arrays for the - outputs. This wrapper copies those results back into the original torch - output tensors (identified by ``output_tensor_positions``). + outputs. This wrapper copies inplace results back into the original torch + output tensors and returns all JAX results for output-only handling. + + ``inplace_output_mapping`` is a list of ``(out_idx, tensor_pos)`` tuples + where ``out_idx`` is the position in ``jax_results`` (matching + ``_output_indices`` order) and ``tensor_pos`` is the position in + ``input_tensors``. """ def __init__( self, jit_fn: object, - output_tensor_positions: list[int], + inplace_output_mapping: list[tuple[int, int]], ) -> None: self._jit_fn = jit_fn - self._output_tensor_positions = output_tensor_positions + self._inplace_output_mapping = inplace_output_mapping - def __call__(self, *input_tensors: torch.Tensor) -> None: + def __call__(self, *input_tensors: torch.Tensor) -> tuple[object, ...]: jax_inputs = [_torch_to_jax(t) for t in input_tensors] jax_results = self._jit_fn(*jax_inputs) # type: ignore[operator] if not isinstance(jax_results, (tuple, list)): jax_results = (jax_results,) - # Write results back into the original output tensors. - for out_idx, tensor_pos in enumerate(self._output_tensor_positions): + # Write inplace results back into the original output tensors. + for out_idx, tensor_pos in self._inplace_output_mapping: out_tensor = input_tensors[tensor_pos] result_data = _jax_to_torch( jax_results[out_idx], device=out_tensor.device, dtype=out_tensor.dtype ) out_tensor.copy_(result_data) + # Return JAX results so output-only tensors can be handled + # by _pallas_invoke_and_copy_back. + return tuple(jax_results) def _pallas_interpret_flag() -> bool: @@ -568,6 +591,46 @@ def _ensure_cpu_tpu_info() -> None: registry["cpu"] = lambda: _get_tpu_info_impl(ChipVersion.TPU_7X, 1) +def _pallas_invoke_and_copy_back( + jax_callable: object, + args: tuple[object, ...], + tensor_arg_indices: list[int], + arg_to_tensor_pos: dict[int, int], + _output_indices: list[int], +) -> object: + """Run the JaxCallable and return output-only results. + + Output-only tensors (those not in ``arg_to_tensor_pos``) are not passed + as pallas_call inputs, so the JaxCallable returns them. + Returns a single tensor, a tuple of tensors, or None. + """ + input_tensors = [ + cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices + ] + results = jax_callable(*input_tensors) # type: ignore[operator] + if results is None: + return None + if not isinstance(results, (tuple, list)): + results = (results,) + output_only_results: list[torch.Tensor] = [] + for out_idx, orig_pos in enumerate(_output_indices): + if orig_pos not in arg_to_tensor_pos: + result = results[out_idx] + if not isinstance(result, torch.Tensor): + out_tensor = cast("torch.Tensor", args[orig_pos]) + result = _jax_to_torch( + result, + device=out_tensor.device, + dtype=out_tensor.dtype, + ) + output_only_results.append(result) + if not output_only_results: + return None + if len(output_only_results) == 1: + return output_only_results[0] + return tuple(output_only_results) + + def default_pallas_launcher( pallas_kernel: object, grid: tuple[int, ...], @@ -577,7 +640,7 @@ def default_pallas_launcher( _block_spec_info: _BlockSpecInfo | None = None, _smem_arg_indices: list[int] | None = None, **kwargs: object, -) -> None: +) -> object: """Default launcher for Pallas kernels on TPU (or CPU with interpret=True). Uses ``JaxCallable`` from ``torch_tpu`` to compile and run the Pallas @@ -587,37 +650,29 @@ def default_pallas_launcher( buffers (zero-copy on TPU). Output-only tensors (in ``_output_indices`` but not in ``_inplace_indices``) - get HBM in_specs to avoid VMEM pressure while still being donated. + are excluded from pallas_call inputs to save VMEM. Their results are + returned as torch tensors. """ if _output_indices is None: _output_indices = [] cache = getattr(pallas_kernel, "_pallas_cache", None) if cache is not None and cache[0] == grid: - _, jax_callable, tensor_arg_indices = cache + _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, - ) = _pallas_prepare_args(args, _output_indices) - - # Derive output-only set: outputs not in _inplace_indices. - inplace_set = ( - set(_inplace_indices) - if _inplace_indices is not None - else set(_output_indices) - ) - output_only_set = set(_output_indices) - inplace_set + ) = _pallas_prepare_args(args, _output_indices, _inplace_indices) in_specs, out_specs = _pallas_build_block_specs( pl, @@ -629,7 +684,7 @@ def default_pallas_launcher( _output_indices, _block_spec_info, _smem_arg_indices, - output_only_set, + output_only_indices, ) reordered_kernel = _pallas_make_reordered_kernel( @@ -642,7 +697,6 @@ def default_pallas_launcher( inplace_positions, arg_to_tensor_pos, _smem_arg_indices=_smem_arg_indices, - skip_inplace_copy=output_only_set, ) out_shape_arg = out_shapes if len(out_shapes) > 1 else out_shapes[0] @@ -650,6 +704,7 @@ def default_pallas_launcher( pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos } pallas_call_kwargs: dict[str, object] = { @@ -678,10 +733,9 @@ def default_pallas_launcher( cache_attr="_pallas_cache", ) - input_tensors = [ - cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices - ] - jax_callable(*input_tensors) # type: ignore[operator] + return _pallas_invoke_and_copy_back( + jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + ) def default_pallas_pipeline_launcher( @@ -689,11 +743,12 @@ def default_pallas_pipeline_launcher( grid: tuple[int, ...], *args: object, _output_indices: list[int] | None = None, + _inplace_indices: list[int] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str]] | None = None, _pipeline_arg_indices: list[int] | None = None, **kwargs: object, -) -> None: +) -> object: """Launcher for Pallas kernels using PrefetchScalarGridSpec with scratch memory. Used when ``pallas_loop_type='emit_pipeline'``. Pipeline-body tensors @@ -707,22 +762,21 @@ def default_pallas_pipeline_launcher( cache = getattr(pallas_kernel, "_pallas_pipeline_cache", None) if cache is not None and cache[0] == grid: - _, jax_callable, tensor_arg_indices = cache + _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, - ) = _pallas_prepare_args(args, _output_indices) + ) = _pallas_prepare_args(args, _output_indices, _inplace_indices) # Build scratch shapes for VMEM _jnp_dtype_map: dict[str, object] = { @@ -760,6 +814,7 @@ def default_pallas_pipeline_launcher( _output_indices, _block_spec_info, _pipeline_arg_indices, + output_only_indices, ) _pipeline_set = set(_pipeline_arg_indices or []) @@ -781,6 +836,7 @@ def default_pallas_pipeline_launcher( pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos } grid_spec = pltpu.PrefetchScalarGridSpec( @@ -818,10 +874,9 @@ def default_pallas_pipeline_launcher( trace_key_suffix="_pipeline", ) - input_tensors = [ - cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices - ] - jax_callable(*input_tensors) # type: ignore[operator] + return _pallas_invoke_and_copy_back( + jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + ) def default_pallas_fori_launcher( @@ -829,10 +884,11 @@ def default_pallas_fori_launcher( grid: tuple[int, ...], *args: object, _output_indices: list[int] | None = None, + _inplace_indices: list[int] | None = None, _block_spec_info: _BlockSpecInfo | None = None, _scratch_shapes: list[tuple[tuple[int, ...], str | None, str]] | None = None, **kwargs: object, -) -> None: +) -> object: """Launcher for Pallas kernels using fori_loop with manual DMA. Used when ``pallas_loop_type="fori_loop"``. Passes all tensors as @@ -848,22 +904,21 @@ def default_pallas_fori_launcher( cache = getattr(pallas_kernel, "_pallas_fori_cache", None) if cache is not None and cache[0] == grid: - _, jax_callable, tensor_arg_indices = cache + _, jax_callable, tensor_arg_indices, arg_to_tensor_pos = cache else: from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp ( - output_set, tensor_arg_indices, + output_only_indices, non_tensor_args, n_tensor_inputs, arg_to_tensor_pos, - outputs, inplace_positions, out_shapes, - ) = _pallas_prepare_args(args, _output_indices) + ) = _pallas_prepare_args(args, _output_indices, _inplace_indices) # Build scratch shapes: VMEM buffers + DMA semaphores _jnp_dtype_map: dict[str, object] = { @@ -900,6 +955,7 @@ def default_pallas_fori_launcher( _output_indices, _block_spec_info, _fori_pipeline_indices, # type: ignore[arg-type] + output_only_indices, ) _fori_pipeline_set = set(_fori_pipeline_indices or []) # type: ignore[arg-type] @@ -921,6 +977,7 @@ def default_pallas_fori_launcher( pallas_aliases = { arg_to_tensor_pos[orig_pos]: out_idx for out_idx, orig_pos in enumerate(_output_indices) + if orig_pos in arg_to_tensor_pos } grid_spec = pltpu.PrefetchScalarGridSpec( @@ -958,10 +1015,9 @@ def default_pallas_fori_launcher( trace_key_suffix="_fori", ) - input_tensors = [ - cast("torch.Tensor", args[i]).contiguous() for i in tensor_arg_indices - ] - jax_callable(*input_tensors) # type: ignore[operator] + return _pallas_invoke_and_copy_back( + jax_callable, args, tensor_arg_indices, arg_to_tensor_pos, _output_indices + ) def _torch_to_jax(t: torch.Tensor) -> object: diff --git a/test/test_pallas.py b/test/test_pallas.py index b502445fb..b20187314 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -730,6 +730,8 @@ def test_output_only_not_inplace(self) -> None: code, result = code_and_output(pallas_relu, (x,), block_sizes=[1024]) torch.testing.assert_close(result, torch.relu(x)) self.assertIn("_inplace_indices=[]", code) + # Codegen should capture the launcher return value + self.assertIn("_launcher_result = _launcher(", code) def test_new_empty_output_only(self) -> None: """new_empty allocations should also be recognized as output-only.""" @@ -745,6 +747,7 @@ def new_empty_relu(x: torch.Tensor) -> torch.Tensor: code, result = code_and_output(new_empty_relu, (x,), block_sizes=[1024]) torch.testing.assert_close(result, torch.relu(x)) self.assertIn("_inplace_indices=[]", code) + self.assertIn("_launcher_result = _launcher(", code) def test_mixed_inplace_and_output_only(self) -> None: """Kernel with both an inplace-mutated input and an output-only tensor. @@ -767,6 +770,8 @@ def inplace_and_output(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected_out) # x is inplace-mutated (index 0), out is output-only (not in inplace) self.assertIn("_inplace_indices=[0]", code) + # out is output-only so codegen captures the return value + self.assertIn("_launcher_result = _launcher(", code) def test_empty_like_read_stays_inplace(self) -> None: """An empty_like output that is also read should stay in _inplace_indices.""" @@ -784,6 +789,30 @@ def read_write_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, x + 1.0) # out is read after write, so it must be in _inplace_indices self.assertIn("_inplace_indices=[1]", code) + # No return capture — out is donated, not returned by launcher + self.assertNotIn("_launcher_result = _launcher(", code) + + def test_multiple_output_only(self) -> None: + """Kernel returning two output-only tensors.""" + + @helion.kernel(backend="pallas", static_shapes=True) + def two_outputs(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + out1 = torch.empty_like(x) + out2 = torch.empty_like(x) + for tile in hl.tile(x.size()): + out1[tile] = x[tile] + 1.0 + out2[tile] = x[tile] * 2.0 + return out1, out2 + + x = torch.randn(1024, device=DEVICE, dtype=torch.float32) + code, (result1, result2) = code_and_output( + two_outputs, (x,), block_sizes=[1024] + ) + torch.testing.assert_close(result1, x + 1.0) + torch.testing.assert_close(result2, x * 2.0) + self.assertIn("_inplace_indices=[]", code) + # Codegen should capture both return values + self.assertIn("out1, out2 = _launcher_result", code) if __name__ == "__main__":