Skip to content

Commit 0cda00f

Browse files
committed
[metal] Add MSL walker, 3D launcher, and copy kernel tests
Add the MSL AST walker that converts Python AST to MSL C++ source, MetalBackend.post_process_function_def to wire it into codegen, update the launcher to 3D threadgroup dispatch, and add copy kernel tests. stack-info: PR: #1794, branch: aditvenk/stack/13
1 parent 487dc06 commit 0cda00f

5 files changed

Lines changed: 639 additions & 8 deletions

File tree

helion/_compiler/backend.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2499,3 +2499,38 @@ def launcher_keyword_args(self, config: Config, *, has_barrier: bool) -> list[st
24992499

25002500
dims = tuple(DeviceFunction.current().codegen.max_thread_block_dims)
25012501
return [f"_block_dims=({dims[0]}, {dims[1]}, {dims[2]})"]
2502+
2503+
# --- MSL code generation ---
2504+
2505+
def post_process_function_def(
2506+
self, stmts: list[ast.stmt], device_fn: DeviceFunction
2507+
) -> list[ast.stmt]:
2508+
"""Post-process the standard codegen output into an MSL-returning function.
2509+
2510+
Extracts the body from the FunctionDef, creates an ``MslAstWalker``
2511+
to generate MSL, and replaces the function with a zero-arg function
2512+
that returns ``(msl_source, kernel_name)``.
2513+
"""
2514+
import ast as _ast
2515+
2516+
from .ast_extension import create
2517+
from .ast_extension import create_arguments
2518+
from .ast_extension import statement_from_string
2519+
from .metal.msl_ast_walker import MslAstWalker
2520+
2521+
fn_def = next(s for s in stmts if isinstance(s, _ast.FunctionDef))
2522+
2523+
kernel_name = device_fn.name
2524+
walker = MslAstWalker(device_fn, fn_def.body)
2525+
msl_source = walker.generate()
2526+
2527+
fn_body = statement_from_string(f"return ({msl_source!r}, {kernel_name!r})")
2528+
msl_fn = create(
2529+
_ast.FunctionDef,
2530+
name=kernel_name,
2531+
args=create_arguments([]),
2532+
body=[fn_body],
2533+
decorator_list=[],
2534+
type_params=[],
2535+
)
2536+
return [msl_fn]

helion/_compiler/metal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from __future__ import annotations

0 commit comments

Comments
 (0)