[metal] Add MSL AST walker for Python-to-C++ translation #1794
End-to-end Metal codegen for copy kernels: a Helion kernel like ``out[tile] = x[tile]`` now compiles to MSL and runs on Apple GPU.

Metal load/store delegates to the same PointerIndexingStrategy as Triton, producing ``tl.load``/``tl.store`` AST nodes. The MSL walker (MslAstWalker) then translates this Triton-flavored AST into Metal C++. This reuse is safe because the Triton AST is never executed; it is purely an intermediate representation. The walker also handles ``tl.where``/``tl.full``/``tl.cast`` for masked copies via _mask_to.

All kernels use a unified 3D dispatch model (``uint3 tgid`` / ``uint3 tid``) and pass ``_block_dims=(x,y,z)`` to the launcher, matching the final form needed for matmul. Autotuning is disabled for now.

Tests:
- copy aligned (1024) and non-aligned (1000)
- masked copy via torch.where, aligned and non-aligned
- OOB sentinel checks (size 999 and size 1)
- codegen assertions: mask always present (force_tile_mask)

stack-info: PR: #1794, branch: aditvenk/stack/13
Example:
    @helion.kernel(backend="metal", configs=[helion.Config(block_sizes=[256], num_warps=4)])
    def masked_copy(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x)
        for tile in hl.tile(x.size(0)):
            out[tile] = torch.where(mask[tile], x[tile], 0.0)
        return out
is translated to:
    kernel void _helion_masked_copy(device bool* mask [[buffer(0)]],
                                    device float* x [[buffer(1)]],
                                    device float* out [[buffer(2)]],
                                    uint3 tgid [[threadgroup_position_in_grid]],
                                    uint3 tid [[thread_position_in_threadgroup]]) {
        constexpr int _BLOCK_SIZE_0 = 256;
        auto pid_0 = tgid[0];
        auto offset_0 = (pid_0 * _BLOCK_SIZE_0);
        auto indices_0 = (offset_0 + tid[0]);
        auto mask_0 = (indices_0 < 1000);
        auto load = (mask_0 ? *((mask + (indices_0 * 1))) : (0));
        auto load_1 = (mask_0 ? *((x + (indices_0 * 1))) : (0));
        auto v_0 = ((float)(0.0));
        auto v_1 = select(v_0, load_1, load);
        if (mask_0) {
            *((out + (indices_0 * 1))) = v_1;
        }
    }
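To make the mapping in the example above concrete, here is a minimal sketch of how a masked `tl.load` / `tl.store` in a Python AST could be turned into the ternary and guarded-store forms shown in the generated MSL. It is not the PR's MslAstWalker: the helper names `expr_to_msl` and `stmt_to_msl` are hypothetical, and the call shapes `tl.load(ptr, mask, other=...)` and `tl.store(ptr, value, mask)` are assumptions about what the intermediate AST looks like.

```python
import ast

# Only the binary operators needed for simple pointer arithmetic.
_BINOPS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*"}


def _is_call_to(node, dotted):
    """True if `node` is a call whose callee unparses to `dotted` (e.g. 'tl.load')."""
    return isinstance(node, ast.Call) and ast.unparse(node.func) == dotted


def expr_to_msl(node):
    """Translate a tiny expression subset to MSL source text (illustrative only)."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return repr(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in _BINOPS:
        op = _BINOPS[type(node.op)]
        return f"({expr_to_msl(node.left)} {op} {expr_to_msl(node.right)})"
    if _is_call_to(node, "tl.load"):
        # Assumed shape: tl.load(ptr, mask, other=<fill value>)
        ptr, mask = (expr_to_msl(a) for a in node.args)
        other = {k.arg: k.value for k in node.keywords}["other"]
        return f"({mask} ? *({ptr}) : ({expr_to_msl(other)}))"
    raise NotImplementedError(ast.dump(node))


def stmt_to_msl(source):
    """Translate one assignment or one tl.store(...) statement."""
    (stmt,) = ast.parse(source).body
    if isinstance(stmt, ast.Assign):
        (target,) = stmt.targets
        return f"auto {target.id} = {expr_to_msl(stmt.value)};"
    if isinstance(stmt, ast.Expr) and _is_call_to(stmt.value, "tl.store"):
        # Assumed shape: tl.store(ptr, value, mask)
        ptr, value, mask = (expr_to_msl(a) for a in stmt.value.args)
        return f"if ({mask}) {{ *({ptr}) = {value}; }}"
    raise NotImplementedError(ast.dump(stmt))


print(stmt_to_msl("load_1 = tl.load(x + indices_0 * 1, mask_0, other=0)"))
print(stmt_to_msl("tl.store(out + indices_0 * 1, v_1, mask_0)"))
```

The Python input is only parsed, never executed, which is why `tl` does not need to be importable here.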
jansel left a comment:
Hrm, so this design still doesn't seem ideal to me. Some alternate ideas:
- We could extend Helion to codegen C++ directly rather than forcing everything through a fake Python AST. This seems the cleanest, though it would require more work.
- Alternatively, we could split the Python -> C++ translation into its own separate layer, so you could do something like:
    @metal_python_kernel
    def my_metal_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        ...
  The @metal_python_kernel decorator would handle:
  - Translating the fake AST into C++
  - Compiling that C++
  - Generating the launcher wrapper that calls the C++ from Python
  This might be useful as its own standalone thing. I'm sure there are people out there who would rather write Metal using a Python interface.
  It will also be easier to debug, since less is happening in a single step.
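One way to picture the three decorator responsibilities listed above is as separate, swappable stages. The sketch below is purely illustrative: `build_kernel`, `make_metal_python_kernel`, and `KernelSketch` are hypothetical names, and the translate, compile, and launcher stages are passed in as plain callables so the example runs without a Metal toolchain.

```python
import ast
import inspect
import textwrap
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class KernelSketch:
    name: str
    msl_source: str               # kept around so it can be inspected when debugging
    launch: Callable[..., Any]    # Python-callable wrapper around the compiled kernel


def build_kernel(name, py_source, translate, compile_msl, make_launcher):
    """Run the three stages in order, keeping each intermediate result."""
    tree = ast.parse(textwrap.dedent(py_source))   # parse only; never executed
    msl_source = translate(tree)                   # 1. fake Python AST -> C++ text
    compiled = compile_msl(msl_source)             # 2. compile that C++
    launch = make_launcher(name, compiled)         # 3. wrapper callable from Python
    return KernelSketch(name, msl_source, launch)


def make_metal_python_kernel(translate, compile_msl, make_launcher):
    """Return a decorator bound to concrete stage implementations (hypothetical)."""
    def decorator(fn):
        source = inspect.getsource(fn)  # requires fn to be defined in a source file
        return build_kernel(fn.__name__, source, translate, compile_msl, make_launcher)
    return decorator


# Stand-in stages so the sketch runs anywhere; real ones would call the
# AST walker, the Metal compiler, and a GPU dispatch helper.
def fake_translate(tree):
    return f"kernel void {tree.body[0].name}() {{}}"

def fake_compile(msl_source):
    return {"source": msl_source}

def fake_launcher(name, compiled):
    return lambda *args: (name, len(args))


kernel = build_kernel(
    "my_metal_kernel",
    "def my_metal_kernel(a, b):\n    ...\n",
    fake_translate, fake_compile, fake_launcher,
)
print(kernel.msl_source)
```

Because each stage's output is stored, a failure can be localized to translation, compilation, or launch rather than to one opaque step.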
Thanks, this is an interesting idea. Let me play around a bit.
@jansel -- I prototyped Option 2. The Helion AST is lowered (via MetalOverrides and MetalBackend) into a fake Metal Python AST plus a decorator. The decorator does the last-mile AST walking to convert to C++, compile the MSL, and generate the launcher call. Here's an example of the fake Metal Python:
This is closer to "pseudo-code" and can help with understanding/debugging, but I would not consider it a useful way to write Metal kernels from Python. (PS: The tl.load/tl.store calls come from PointerIndexingStrategy, which currently hardcodes tl.load/tl.store.) The decorator internally uses the AST walker in this PR to lower to pure C++. Wdyt about this approach?
@aditvenk let's go with that option; that seems cleaner.
Re-request review once you want me to look at this again.
Absolutely. I am still refining the metal_jit decorator (follow-up PR) -- will re-request review once I am happy with it.
Add msl_ast_walker.py which translates Python AST to MSL C++ source. Handles statement-level translation (assignments, if/for, etc.), tl.load/tl.store → pointer dereferences, and C++ namespace restoration (metal.precise.sin → metal::precise::sin). This is a standalone library module, not yet wired into the backend.
stack-info: PR: #1794, branch: aditvenk/stack/13
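As a rough illustration of the statement-level translation and namespace restoration described above, the following sketch uses `ast.NodeVisitor` to emit C++ text. It is not the code in msl_ast_walker.py: the class name `NamespaceWalkerSketch` is hypothetical, and treating only dotted names rooted at `metal` as namespaces, or only `range(stop)` loops as translatable, are simplifying assumptions.

```python
import ast


class NamespaceWalkerSketch(ast.NodeVisitor):
    """Emit MSL-like C++ text for a very small Python subset (illustrative)."""

    def generic_visit(self, node):
        # Fail loudly on anything outside the supported subset.
        raise NotImplementedError(type(node).__name__)

    def visit_Name(self, node):
        return node.id

    def visit_Constant(self, node):
        if type(node.value) not in (int, float):
            raise NotImplementedError(repr(node.value))
        return repr(node.value)

    def visit_Attribute(self, node):
        # Flatten a.b.c into ['a', 'b', 'c'].
        parts, cur = [], node
        while isinstance(cur, ast.Attribute):
            parts.append(cur.attr)
            cur = cur.value
        if not isinstance(cur, ast.Name):
            raise NotImplementedError(ast.dump(node))
        parts.append(cur.id)
        parts.reverse()
        # Assumption: chains rooted at `metal` are C++ namespaces; others are members.
        sep = "::" if parts[0] == "metal" else "."
        return sep.join(parts)

    def visit_Call(self, node):
        args = ", ".join(self.visit(a) for a in node.args)
        return f"{self.visit(node.func)}({args})"

    def visit_Assign(self, node):
        (target,) = node.targets
        return f"auto {self.visit(target)} = {self.visit(node.value)};"

    def visit_For(self, node):
        it = node.iter
        is_simple_range = (
            isinstance(it, ast.Call)
            and isinstance(it.func, ast.Name)
            and it.func.id == "range"
            and len(it.args) == 1
        )
        if not is_simple_range or node.orelse:
            raise NotImplementedError("only `for i in range(stop)` is sketched")
        var, stop = self.visit(node.target), self.visit(it.args[0])
        body = " ".join(self.visit(s) for s in node.body)
        return f"for (int {var} = 0; {var} < {stop}; ++{var}) {{ {body} }}"


def translate(source):
    walker = NamespaceWalkerSketch()
    return "\n".join(walker.visit(stmt) for stmt in ast.parse(source).body)


print(translate("y = metal.precise.sin(x)"))
```

Raising in `generic_visit` means unsupported constructs surface as errors at translation time instead of silently producing incomplete C++.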
@jansel -- ready for re-review.