Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/matrix.json
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@
"pytorch-version": "pytorch-nightly",
"alias": "b200",
"backend": "cute"
},
{
"runner": "macos-26-xlarge",
"python-version": "3.12",
"ref-eager": false,
"image": "",
"runtime-version": "cpu",
"container-options": "",
"pytorch-version": "pytorch-2.9",
"alias": "m2-metal",
"backend": "metal"
}
]
}
9 changes: 8 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ jobs:
uses: actions/checkout@v6

- name: Install system dependencies
if: runner.os == 'Linux'
run: |
set -eux
SUDO=$(command -v sudo 2>/dev/null || true)
Expand Down Expand Up @@ -234,6 +235,12 @@ jobs:
print(f'All {n} devices healthy')
"

- name: MPS Availability Check
if: matrix.backend == 'metal'
run: |
source .venv/bin/activate
python -c "import torch; assert torch.backends.mps.is_available(), 'MPS not available'"

- name: Run Tests
run: |
set -o pipefail
Expand All @@ -251,7 +258,7 @@ jobs:
if [[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]]; then
TEST_PATH="test/test_examples_dist.py"
EXTRA_FLAGS="-rs"
elif [[ "${{ matrix.alias }}" == "tpu" ]]; then
elif [[ "${{ matrix.alias }}" == "tpu" || "$HELION_BACKEND" == "metal" ]]; then
TEST_PATH="."
EXTRA_FLAGS="--ignore=test/test_examples_dist.py"
PARALLEL=""
Expand Down
7 changes: 7 additions & 0 deletions helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def _init_tpu_device() -> bool:
DEVICE = torch.device("xpu")
elif _has_mtia_runtime():
DEVICE = torch.device("mtia")
elif _get_backend() == "metal" and torch.backends.mps.is_available():
DEVICE = torch.device("mps")
else:
DEVICE = torch.device("cuda")

Expand Down Expand Up @@ -289,6 +291,11 @@ def skipIfTileIR(reason: str) -> Callable[[Callable], Callable]:
return skipIfFn(lambda: _get_backend() == "tileir", reason)


def skipIfMetal(reason: str) -> Callable[[Callable], Callable]:
    """Return a decorator that skips the wrapped test when the Metal backend is active."""

    def _metal_selected() -> bool:
        # Evaluated lazily by skipIfFn at test time, mirroring the other skipIf* helpers.
        return _get_backend() == "metal"

    return skipIfFn(_metal_selected, reason)


def skipIfPallas(reason: str) -> Callable[[Callable], Callable]:
"""Skip test if running with pallas"""
# Defers check to test execution time to avoid CUDA init during pytest-xdist collection.
Expand Down
6 changes: 6 additions & 0 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfMetal
from helion._testing import skipUnlessTensorDescriptor
from helion._testing import xfailIfPallas
import helion.language as hl
Expand Down Expand Up @@ -120,6 +121,7 @@ def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
)
torch.testing.assert_close(result, grid_2d_pytorch(args[0], args[1]))

@skipIfMetal("aten.addmm not yet registered for Metal backend")
def test_grid_2d_idx_nested(self):
@helion.kernel(static_shapes=True)
def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -281,6 +283,7 @@ def tile_begin_end_pytorch(x: torch.Tensor) -> torch.Tensor:
code, result = code_and_output(tile_begin_end, (x,), block_size=4)
torch.testing.assert_close(result, tile_begin_end_pytorch(x))

@skipIfMetal("Metal does not support loop_index_expr for grid loops")
def test_range_as_grid_basic(self):
"""Test that range() works as an alias for hl.grid() in device code."""

Expand All @@ -301,6 +304,7 @@ def range_kernel(x: torch.Tensor) -> torch.Tensor:
code, result = code_and_output(range_kernel, (x,))
torch.testing.assert_close(result, expected)

@skipIfMetal("Metal does not support loop_index_expr for grid loops")
def test_range_with_begin_end(self):
"""Test that range(begin, end) works as alias for hl.grid(begin, end)."""

Expand All @@ -321,6 +325,7 @@ def range_begin_end_kernel(x: torch.Tensor) -> torch.Tensor:
code, result = code_and_output(range_begin_end_kernel, (x,))
torch.testing.assert_close(result, expected)

@skipIfMetal("Metal does not support loop_index_expr for grid loops")
def test_range_with_step(self):
"""Test that range(begin, end, step) works as alias for hl.grid(begin, end, step)."""

Expand All @@ -343,6 +348,7 @@ def range_step_kernel(x: torch.Tensor) -> torch.Tensor:
code, result = code_and_output(range_step_kernel, (x,))
torch.testing.assert_close(result, expected)

@skipIfMetal("Metal does not support loop_index_expr for grid loops")
def test_range_with_tensor_size(self):
"""Test that range(tensor.size(dim)) works with dynamic tensor dimensions."""

Expand Down
Loading