
[metal] Add MSL AST walker for Python-to-C++ translation#1794

Merged
aditvenk merged 1 commit into main from aditvenk/stack/13
Apr 13, 2026

Conversation

@aditvenk
Contributor

@aditvenk aditvenk commented Mar 24, 2026

Stacked PRs:


[metal] Add MSL AST walker for Python-to-C++ translation

Add msl_ast_walker.py which translates Python AST to MSL C++ source.
Handles statement-level translation (assignments, if/for, etc.),
tl.load/tl.store → pointer dereferences, and C++ namespace restoration
(metal.precise.sin → metal::precise::sin).

This is a standalone library module — not yet wired into the backend.
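For illustration, the namespace-restoration step described above can be sketched with the stdlib `ast` module. This is an illustrative walker, not the PR's actual `MslAstWalker`; the helper name `restore_namespace` is hypothetical:

```python
import ast

def restore_namespace(expr: str) -> str:
    """Rewrite a dotted Python attribute chain (metal.precise.sin)
    into a C++ qualified name (metal::precise::sin)."""
    node = ast.parse(expr, mode="eval").body
    parts = []
    # Walk the attribute chain down to its root Name node.
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    assert isinstance(node, ast.Name)
    parts.append(node.id)
    return "::".join(reversed(parts))

print(restore_namespace("metal.precise.sin"))  # metal::precise::sin
```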

aditvenk added a commit that referenced this pull request Mar 24, 2026
End-to-end Metal codegen for copy kernels: a Helion kernel like
``out[tile] = x[tile]`` now compiles to MSL and runs on Apple GPU.

Metal load/store delegates to the same PointerIndexingStrategy as
Triton, producing ``tl.load``/``tl.store`` AST nodes. The MSL walker
(MslAstWalker) then translates this Triton-flavored AST into Metal
C++. This reuse is safe because the Triton AST is never executed —
it is purely an intermediate representation. The walker also handles
``tl.where``/``tl.full``/``tl.cast`` for masked copies via _mask_to.

All kernels use a unified 3D dispatch model (``uint3 tgid`` / ``uint3
tid``) and pass ``_block_dims=(x,y,z)`` to the launcher, matching the
final form needed for matmul. Autotuning is disabled for now.

Tests:
- copy aligned (1024) and non-aligned (1000)
- masked copy via torch.where, aligned and non-aligned
- OOB sentinel checks (size 999 and size 1)
- codegen assertions: mask always present (force_tile_mask)

stack-info: PR: #1794, branch: aditvenk/stack/13
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch from 4c7b81e to 5a843eb on March 24, 2026 00:25
@meta-cla meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Mar 24, 2026
@aditvenk aditvenk marked this pull request as draft March 24, 2026 00:30
aditvenk added a commit that referenced this pull request Mar 24, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch from 5a843eb to 20b7d22 on March 24, 2026 00:30
@aditvenk aditvenk marked this pull request as ready for review March 24, 2026 00:30
@aditvenk aditvenk marked this pull request as draft March 24, 2026 00:34
aditvenk added a commit that referenced this pull request Mar 24, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch from 20b7d22 to 087162c on March 24, 2026 00:34
@aditvenk aditvenk marked this pull request as ready for review March 24, 2026 00:34
@aditvenk
Contributor Author

Example:

@helion.kernel(backend="metal", configs=[helion.Config(block_sizes=[256], num_warps=4)])
def masked_copy(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    for tile in hl.tile(x.size(0)):
        out[tile] = torch.where(mask[tile], x[tile], 0.0)
    return out

is translated to:

kernel void _helion_masked_copy(device bool* mask [[buffer(0)]],
                                device float* x [[buffer(1)]],
                                device float* out [[buffer(2)]],
                                uint3 tgid [[threadgroup_position_in_grid]],
                                uint3 tid [[thread_position_in_threadgroup]]) {
    constexpr int _BLOCK_SIZE_0 = 256;
    auto pid_0 = tgid[0];
    auto offset_0 = (pid_0 * _BLOCK_SIZE_0);
    auto indices_0 = (offset_0 + tid[0]);
    auto mask_0 = (indices_0 < 1000);
    auto load = (mask_0 ? *((mask + (indices_0 * 1))) : (0));
    auto load_1 = (mask_0 ? *((x + (indices_0 * 1))) : (0));
    auto v_0 = ((float)(0.0));
    auto v_1 = select(v_0, load_1, load);
    if (mask_0) {
        *((out + (indices_0 * 1))) = v_1;
    }
}
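One subtlety in the output above: Metal's `select(a, b, c)` returns `b` where `c` is true and `a` otherwise, so `torch.where(mask, x, 0.0)` lowers with the default value in the first slot: `select(v_0, load_1, load)` reads as `select(0.0, x[i], mask[i])`. A tiny Python model of that argument order (`msl_select` is an illustrative name, not a real API):

```python
def msl_select(a, b, c):
    # Mirrors Metal's scalar select(a, b, c): returns c ? b : a.
    return b if c else a

# torch.where(mask, x, 0.0) corresponds to select(0.0, x, mask):
assert msl_select(0.0, 7.0, True) == 7.0   # mask true -> take x
assert msl_select(0.0, 7.0, False) == 0.0  # mask false -> take the default
```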

@aditvenk aditvenk marked this pull request as draft March 24, 2026 03:51
@aditvenk aditvenk marked this pull request as ready for review March 24, 2026 03:51
@aditvenk aditvenk requested review from jansel, malfet and oulgen and removed request for jansel and oulgen March 24, 2026 03:53
@aditvenk aditvenk marked this pull request as draft March 24, 2026 04:32
@aditvenk aditvenk marked this pull request as ready for review March 24, 2026 04:32
@aditvenk aditvenk marked this pull request as draft March 27, 2026 22:34
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch 2 times, most recently from fcbd308 to cfc9c34 on March 27, 2026 22:35
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch from 87cda76 to 0cda00f on March 28, 2026 02:30
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/17 March 28, 2026 02:30
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 02:30
@aditvenk aditvenk marked this pull request as draft March 28, 2026 02:35
@aditvenk aditvenk changed the base branch from aditvenk/stack/17 to main March 28, 2026 02:35
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch from 0cda00f to 6d4065e on March 28, 2026 02:35
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/17 March 28, 2026 02:36
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 02:36
@aditvenk aditvenk marked this pull request as draft March 28, 2026 02:42
@aditvenk aditvenk changed the base branch from aditvenk/stack/17 to main March 28, 2026 02:42
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch from 6d4065e to 71f53e7 on March 28, 2026 02:42
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/17 March 28, 2026 02:43
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 02:43
@aditvenk aditvenk marked this pull request as draft March 28, 2026 02:45
@aditvenk aditvenk changed the base branch from aditvenk/stack/17 to main March 28, 2026 02:45
@aditvenk aditvenk force-pushed the aditvenk/stack/13 branch from 71f53e7 to afdcfb2 on March 28, 2026 02:45
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/17 March 28, 2026 02:45
Contributor

@jansel jansel left a comment

Hrm, so this design still doesn't seem ideal to me. Some alternate ideas:

  1. We could extend Helion to be able to codegen C++ directly and not force things to go via a fake Python AST. This seems the cleanest, though it would require more work.

  2. Alternately, we could split the Python -> C++ translation into its own separate layer -- so you could do something like:

@metal_python_kernel
def my_metal_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
   ...

The @metal_python_kernel decorator would handle:

  1. Translating the fake AST into C++
  2. Compiling that C++
  3. Generating the launcher wrapper that calls the C++ from Python

This might be useful as its own standalone thing. I'm sure there are people out there who would rather write Metal using a Python interface.

It will also be easier to debug, since less is happening in a single step.
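A minimal shape for such a decorator can be sketched as follows; this is purely illustrative, and the helper names (`translate_to_msl`, `compile_msl`, `launch`) are hypothetical stubs, not APIs from this PR:

```python
import ast
import functools
import inspect

# Hypothetical stubs standing in for the real translate/compile/launch
# machinery; they only illustrate the three steps listed above.
def translate_to_msl(tree: ast.Module) -> str:
    fn_def = tree.body[0]
    return f"// MSL for kernel with {len(fn_def.body)} statements"

def compile_msl(msl_source: str) -> str:
    return msl_source  # stand-in for invoking the Metal compiler

def launch(pipeline, args, kwargs):
    return pipeline  # stand-in for dispatching the compiled kernel

def metal_python_kernel(fn):
    """Capture the kernel's Python source at decoration time, translate
    and compile it once, then dispatch every call to the compiled
    artifact via a thin launcher wrapper."""
    tree = ast.parse(inspect.getsource(fn))      # step 1 input: the fake AST
    pipeline = compile_msl(translate_to_msl(tree))  # steps 1 and 2

    @functools.wraps(fn)
    def launcher(*args, **kwargs):
        return launch(pipeline, args, kwargs)    # step 3: launch from Python

    return launcher
```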

@aditvenk
Contributor Author

aditvenk commented Apr 1, 2026


Thanks, this is an interesting idea. Let me play around a bit.

@aditvenk
Contributor Author

aditvenk commented Apr 2, 2026

@jansel -- I prototyped Option 2. The Helion AST is lowered (via MetalOverrides and MetalBackend) into a fake Metal Python AST plus a decorator. The decorator does the last-mile AST walking to convert to C++, compile the MSL, and generate the launcher call.

Here's an example of the fake Metal Python:

  @metal_python_kernel
  def _helion_silu(x, out):
      pid_flat = tgid[0]
      offsets_0 = pid_flat * _BLOCK_SIZE_0 + tid[0]
      indices_0 = offsets_0
      mask_0 = tid[0] < _BLOCK_SIZE_0 and offsets_0 < 1024
      load = tl.load(x + indices_0 * 1, mask_0, other=0)
      v_0 = static_cast < decltype(load) > -load
      v_1 = metal.precise.exp(v_0)
      v_2 = 1
      v_3 = v_1 + v_2
      v_4 = load / v_3
      tl.store(out + indices_0 * 1, v_4, mask_0)

This is closer to "pseudo-code" and can help with understanding/debugging, but I would not consider it a useful way to write Metal kernels from Python. (PS: The tl.load/tl.store calls come from PointerIndexingStrategy, which currently hardcodes them.)
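For illustration only, the tl.load rewrite can be sketched at the expression level with the stdlib `ast` module (Python 3.9+ for `ast.unparse`); `lower_tl_load` is a hypothetical helper, not the PR's walker, which operates on full AST `Call` nodes:

```python
import ast

def lower_tl_load(call_src: str) -> str:
    """Sketch: turn tl.load(ptr_expr, mask_expr, other=0) into the
    masked pointer dereference seen in the generated MSL above."""
    call = ast.parse(call_src, mode="eval").body
    assert isinstance(call, ast.Call)
    ptr = ast.unparse(call.args[0])
    mask = ast.unparse(call.args[1])
    other = ast.unparse(call.keywords[0].value) if call.keywords else "0"
    return f"({mask} ? *(({ptr})) : ({other}))"

print(lower_tl_load("tl.load(x + indices_0 * 1, mask_0, other=0)"))
# -> (mask_0 ? *((x + indices_0 * 1)) : (0))
```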

The decorator internally uses the AST walker in this PR to lower to pure C++.

Wdyt about this approach?

@jansel
Contributor

jansel commented Apr 8, 2026

@aditvenk let's go with that option; it seems cleaner.

@jansel
Contributor

jansel commented Apr 9, 2026

Re-request review once you want me to look at this again.

@aditvenk
Contributor Author

aditvenk commented Apr 9, 2026

Re-request review once you want me to look at this again.

Absolutely. I am still refining the metal_jit decorator (follow-up PR) -- will re-request review once I am happy with it.

@aditvenk
Contributor Author

aditvenk commented Apr 9, 2026

@jansel -- ready for re-review.


Labels

CLA Signed


2 participants