Skip to content

[metal] Add elementwise kernel tests#1796

Closed
aditvenk wants to merge 5 commits intoaditvenk/stack/24from
aditvenk/stack/14
Closed

[metal] Add elementwise kernel tests#1796
aditvenk wants to merge 5 commits intoaditvenk/stack/24from
aditvenk/stack/14

Conversation

@aditvenk
Copy link
Copy Markdown
Contributor

@aditvenk aditvenk commented Mar 24, 2026

Stacked PRs:


[metal] Add elementwise kernel tests

Add tests for arithmetic (add/sub/mul/div/neg), scalar args (saxpy),
activations (relu/silu/gelu_approx), math ops (exp/log/sqrt/abs/
sin+cos/clamp), dtypes (float16/bfloat16/int32), bounds masking,
and >1D tensors (2D aligned + non-aligned, 3D).

aditvenk added a commit that referenced this pull request Mar 24, 2026
Extend MslAstWalker to handle the Triton AST patterns produced by
elementwise operations:
- libdevice.func / tl_math.func → MSL intrinsics (exp via exp2 + ln2)
- triton_helpers.maximum/minimum → max/min
- tl.* catch-all: strip _rn suffix, inline sigmoid, pass through others
- ast.Pow → sqrt(x) for **0.5, pow(x, y) otherwise
- Strip Triton broadcasting subscripts ([:, None]) as no-ops

Tests cover arithmetic (add/sub/mul/div/neg), scalar args (saxpy),
activations (relu/silu/gelu_approx), math ops (exp/log/sqrt/abs/
sin+cos/clamp), dtypes (float16/bfloat16/int32), bounds masking,
and >1D tensors (2D aligned + non-aligned, 3D).

stack-info: PR: #1796, branch: aditvenk/stack/14
@aditvenk aditvenk force-pushed the aditvenk/stack/14 branch from 7c675d9 to 5b3c02a Compare March 24, 2026 03:51
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 24, 2026
@aditvenk aditvenk requested review from jansel, malfet and oulgen March 24, 2026 03:54
@aditvenk aditvenk marked this pull request as draft March 24, 2026 04:32
@aditvenk aditvenk changed the base branch from aditvenk/stack/13 to main March 24, 2026 04:32
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/13 March 24, 2026 04:32
@aditvenk aditvenk marked this pull request as ready for review March 24, 2026 04:32
@aditvenk aditvenk marked this pull request as draft March 27, 2026 22:34
@aditvenk aditvenk changed the base branch from aditvenk/stack/13 to main March 27, 2026 22:35
@aditvenk aditvenk force-pushed the aditvenk/stack/14 branch 2 times, most recently from fe08c0b to ccbd95d Compare March 27, 2026 22:35
@aditvenk aditvenk changed the title [metal] Add support for elementwise kernels [metal] Add elementwise kernel tests Mar 27, 2026
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/13 March 27, 2026 22:35
@aditvenk aditvenk marked this pull request as ready for review March 27, 2026 22:36
@aditvenk aditvenk marked this pull request as draft March 27, 2026 22:47
@aditvenk aditvenk changed the base branch from aditvenk/stack/13 to main March 27, 2026 22:47
@aditvenk aditvenk force-pushed the aditvenk/stack/14 branch from ccbd95d to ba050a8 Compare March 27, 2026 22:48
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/13 March 27, 2026 22:48
@aditvenk aditvenk marked this pull request as ready for review March 27, 2026 22:48
@aditvenk aditvenk marked this pull request as draft March 28, 2026 02:42
@aditvenk aditvenk changed the base branch from aditvenk/stack/13 to main March 28, 2026 02:42
@aditvenk aditvenk force-pushed the aditvenk/stack/14 branch from 7c7f971 to 805f6d6 Compare March 28, 2026 02:42
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/13 March 28, 2026 02:43
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 02:43
@aditvenk aditvenk marked this pull request as draft March 28, 2026 02:45
@aditvenk aditvenk changed the base branch from aditvenk/stack/13 to main March 28, 2026 02:45
@aditvenk aditvenk force-pushed the aditvenk/stack/14 branch from 805f6d6 to 2d8ac76 Compare March 28, 2026 02:45
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/13 March 28, 2026 02:45
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 02:46
Add tests for arithmetic (add/sub/mul/div/neg), scalar args (saxpy),
activations (relu/silu/gelu_approx), math ops (exp/log/sqrt/abs/
sin+cos/clamp), dtypes (float16/bfloat16/int32), bounds masking,
and >1D tensors (2D aligned + non-aligned, 3D).

stack-info: PR: #1796, branch: aditvenk/stack/14
@aditvenk aditvenk marked this pull request as draft March 28, 2026 04:44
@aditvenk aditvenk changed the base branch from aditvenk/stack/13 to main March 28, 2026 04:44
@aditvenk aditvenk force-pushed the aditvenk/stack/14 branch 2 times, most recently from e5a154d to 7cf5efd Compare March 28, 2026 04:45
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/24 March 28, 2026 04:45
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 04:45
@aditvenk aditvenk marked this pull request as draft March 28, 2026 04:47
@aditvenk aditvenk changed the base branch from aditvenk/stack/24 to main March 28, 2026 04:47
@aditvenk aditvenk force-pushed the aditvenk/stack/14 branch from 7cf5efd to 3e36ec4 Compare March 28, 2026 04:47
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/24 March 28, 2026 04:48
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 04:48
aditvenk added 4 commits April 8, 2026 21:30
Add msl_ast_walker.py which translates Python AST to MSL C++ source.
Handles statement-level translation (assignments, if/for, etc.),
tl.load/tl.store → pointer dereferences, and C++ namespace restoration
(metal.precise.sin → metal::precise::sin).

This is a standalone library module — not yet wired into the backend.

stack-info: PR: #1794, branch: aditvenk/stack/13
Add metal_jit decorator that JIT-compiles a Python function to an MSL
Metal shader on first call. The decorator:
1. Parses the decorated function's source to recover the Python AST
2. Calls _generate_msl to translate the AST body to MSL C++ source
3. Compiles the MSL via torch.mps.compile_shader
4. Caches the compiled library for subsequent calls

Metadata (tensor arg dtypes, block sizes) is passed as decorator
arguments by Helion's codegen: @metal_jit(args=[...], block_sizes=[...])

stack-info: PR: #1991, branch: aditvenk/stack/25
- MetalBackend.function_decorator returns "metal_jit"
- Add Backend.function_decorator_expr hook; MetalBackend overrides it
  to serialize arg metadata and block sizes into the decorator call
- device_function.py calls backend.function_decorator_expr(self)
- Launcher simplified: metal_jit returns compiled lib directly,
  no more source hashing or compile_shader in the launcher
- 3D threadgroup dispatch model with _block_dims

stack-info: PR: #1992, branch: aditvenk/stack/26
Add a macOS Metal entry to matrix.json (macos-m2-26, cpu runtime,
metal backend) and update test.yml to handle macOS:
- Gate apt-get on runner.os == 'Linux'
- Add cpu runtime branch for PyTorch install
- Add MPS availability check for metal backend
- Run test/test_metal.py without xdist for metal

stack-info: PR: #1862, branch: aditvenk/stack/24
@aditvenk
Copy link
Copy Markdown
Contributor Author

aditvenk commented Apr 9, 2026

Combined into #1992

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant