
[metal] Auto-cap threadgroup size via CuteND/CuteFlattenedTileStrategy#1855

Draft
aditvenk wants to merge 1 commit into aditvenk/stack/15 from aditvenk/stack/21

Conversation

@aditvenk
Contributor

@aditvenk aditvenk commented Mar 27, 2026

Stacked PRs:


[metal] Auto-cap threadgroup size via CuteND/CuteFlattenedTileStrategy

Metal has a hard 1024-thread-per-threadgroup limit. Kernels with
block_size products exceeding 1024 (e.g. 1D [2048], 2D [64, 64],
3D [16, 16, 16]) crash at runtime. Users shouldn't need to know
about hardware limits when picking block sizes.

This commit switches the Metal backend from the base NDTileStrategy /
FlattenedTileStrategy to CuteNDTileStrategy / CuteFlattenedTileStrategy,
and adds auto-capping logic so that when block_size products exceed
1024, num_threads is reduced and each thread processes multiple elements
via lane loops (elements_per_thread = block_size / num_threads).
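The capping rule described above can be sketched as follows. `cap_threadgroup` and the constant are illustrative names, not the actual MetalBackend API, and the sketch assumes the block-size product is a multiple of the cap so the division is exact:

```python
import math

# Metal's hard per-threadgroup limit (hardcoded here for illustration;
# a real backend might query the device instead).
MAX_THREADS_PER_THREADGROUP = 1024

def cap_threadgroup(block_sizes):
    """Return (num_threads, elements_per_thread) for a tile's block sizes."""
    total = math.prod(block_sizes)
    num_threads = min(total, MAX_THREADS_PER_THREADGROUP)
    # When capped, each thread covers several elements via a lane loop.
    elements_per_thread = total // num_threads
    return num_threads, elements_per_thread
```

All three failing configurations from the problem statement map onto 1024 threads with 2 or 4 elements per lane, while sub-limit configurations pass through unchanged.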

Why Cute strategies are the right fit for Metal:

Helion has two families of tile strategies. The base strategies
(NDTileStrategy / FlattenedTileStrategy) target Triton and Pallas,
where each thread holds a *vector* of elements -- index expressions
use tl.arange(0, block_size) / jnp.arange(0, block_size) to
produce a tile of offsets per program. The Cute strategies
(CuteNDTileStrategy / CuteFlattenedTileStrategy) target backends
where each thread is a *scalar lane* -- index expressions use
thread_idx()[axis] to produce a single offset per thread.
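The difference can be sketched as two string-building helpers (illustrative only; the real strategies emit backend AST nodes, and these function names are assumptions):

```python
def base_index_expr(offset_var, block_size):
    # Base strategies: one program holds a vector of offsets at once.
    return f"{offset_var} + tl.arange(0, {block_size})"

def cute_index_expr(offset_var, axis):
    # Cute strategies: each thread computes a single scalar offset.
    return f"{offset_var} + thread_idx()[{axis}]"
```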

Metal's execution model matches CuTe's scalar-thread model. Both
dispatch one scalar element per thread for elementwise work, and
both rely on cooperative hardware primitives (Metal's simdgroup MPP /
CuTe's warpgroup MMA) for matmul -- where the hardware internally
distributes tile elements across threads and the strategy must not
generate per-thread indices.

CuteNDTileStrategy supports both patterns: num_threads + lane
loops for scalar dispatch (when block_size > thread_count, each
thread iterates over multiple elements), and mma_mode to suppress
thread-index generation for cooperative matmul. The base
NDTileStrategy can express neither. By adopting Cute strategies now,
matmul support can later add mma_mode=True without changing
strategy selection.
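To make the lane-loop pattern concrete, here is a small model of which offsets one scalar lane covers when block_size exceeds the thread count. The thread-strided interleaving shown is one common layout choice and an assumption here, not necessarily the strategy's actual indexing:

```python
def lane_offsets(tid, num_threads, elements_per_thread):
    # Lane `lane` of thread `tid` handles offset tid + lane * num_threads,
    # so the threadgroup as a whole covers every offset in the block.
    return [tid + lane * num_threads for lane in range(elements_per_thread)]
```

With num_threads = 1024 and block_size = 2048, thread 0 would handle offsets 0 and 1024, thread 1 offsets 1 and 1025, and so on.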

The lane_index_expr and lane_offset_expr backend hooks make
the generated index expressions backend-agnostic: CuTe emits
cutlass.Int32(cute.arch.thread_idx()[axis]), Metal emits
tid[axis].
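The hook pattern amounts to each backend overriding a small expression-producing method. A minimal sketch, assuming string-valued hooks and simplified class shapes:

```python
class CuteBackend:
    def lane_index_expr(self, axis):
        # CuTe: scalar thread index wrapped in a 32-bit integer.
        return f"cutlass.Int32(cute.arch.thread_idx()[{axis}])"

class MetalBackend:
    def lane_index_expr(self, axis):
        # Metal: index into the kernel's threadgroup-position argument.
        return f"tid[{axis}]"
```

The strategy asks the backend for the expression and never spells out backend syntax itself, which is what keeps the Cute strategies reusable across CuTe and Metal.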

Changes:

Backend (backend.py):

  • MetalBackend: override lane_index_expr/lane_offset_expr with
    tid[axis] exprs; add arange_expr(), loop_index_expr(),
    thread_in_tile_mask_expr() for Cute strategy fallback paths;
    add create_loop_strategy() with auto-capping; add
    build_launcher_args(); add num_threads to supported config keys

MSL walker (msl_ast_walker.py):

  • Add ast.For handling to emit MSL for-loops from lane loop range()
    iterators
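The shape of that translation can be sketched with the stdlib ast module. This minimal version handles only the single-argument range() form that lane loops produce, and `emit_msl_for` is a hypothetical name, not the walker's actual method:

```python
import ast

def emit_msl_for(node):
    """Render a `for i in range(n)` node as an MSL for-loop header."""
    call = node.iter
    assert isinstance(call, ast.Call) and call.func.id == "range"
    assert len(call.args) == 1, "only range(stop) supported in this sketch"
    var = node.target.id
    stop = ast.unparse(call.args[0])
    return f"for (int {var} = 0; {var} < {stop}; ++{var})"
```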

Tests (test_metal.py):

  • Add multi-dimensional elementwise kernels and tests (2D, 3D)
  • Add large_block_add kernel with block_sizes=[2048]
  • Add TestMetalMultiDim and TestMetalLargeBlock

Tests (test_grid.py):

  • Skip hl.grid begin/end tests for Metal (CuteNDTileStrategy
    regression — to be fixed separately)

aditvenk added a commit that referenced this pull request Mar 27, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/15 branch from e847271 to 41f5e1b Compare March 27, 2026 22:35
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from b34e44a to 1f7ea47 Compare March 27, 2026 22:35
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 27, 2026
@aditvenk aditvenk marked this pull request as draft March 27, 2026 22:48
@aditvenk aditvenk changed the base branch from aditvenk/stack/15 to main March 27, 2026 22:48
aditvenk added a commit that referenced this pull request Mar 27, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from 1f7ea47 to 83c4c6c Compare March 27, 2026 22:48
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/15 March 27, 2026 22:48
@aditvenk aditvenk marked this pull request as ready for review March 27, 2026 22:48
@aditvenk aditvenk marked this pull request as draft March 28, 2026 00:12
@aditvenk aditvenk changed the base branch from aditvenk/stack/15 to main March 28, 2026 00:12
aditvenk added a commit that referenced this pull request Mar 28, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from 83c4c6c to 33f28d7 Compare March 28, 2026 00:12
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/15 March 28, 2026 00:14
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 00:14
@aditvenk aditvenk marked this pull request as draft March 28, 2026 01:24
@aditvenk aditvenk changed the base branch from aditvenk/stack/15 to main March 28, 2026 01:24
aditvenk added a commit that referenced this pull request Mar 28, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from 33f28d7 to 3c5e056 Compare March 28, 2026 01:24
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/15 March 28, 2026 01:26
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 01:26
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from a2fe84f to 2200607 Compare March 28, 2026 02:42
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/15 March 28, 2026 02:43
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 02:43
@aditvenk aditvenk marked this pull request as draft March 28, 2026 02:45
@aditvenk aditvenk changed the base branch from aditvenk/stack/15 to main March 28, 2026 02:45
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from 2200607 to e24d8e5 Compare March 28, 2026 02:45
aditvenk added a commit that referenced this pull request Mar 28, 2026
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/15 March 28, 2026 02:45
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 02:46
@aditvenk aditvenk marked this pull request as draft March 28, 2026 04:44
@aditvenk aditvenk changed the base branch from aditvenk/stack/15 to main March 28, 2026 04:44
aditvenk added a commit that referenced this pull request Mar 28, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from e24d8e5 to a6e859b Compare March 28, 2026 04:44
aditvenk added a commit that referenced this pull request Mar 28, 2026
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from a6e859b to 063ded2 Compare March 28, 2026 04:45
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/15 March 28, 2026 04:45
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 04:45
@aditvenk aditvenk marked this pull request as draft March 28, 2026 04:47
@aditvenk aditvenk changed the base branch from aditvenk/stack/15 to main March 28, 2026 04:47
@aditvenk aditvenk force-pushed the aditvenk/stack/21 branch from 063ded2 to 40f75f9 Compare March 28, 2026 04:47
aditvenk added a commit that referenced this pull request Mar 28, 2026
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/15 March 28, 2026 04:48
@aditvenk aditvenk marked this pull request as ready for review March 28, 2026 04:48
Contributor

@jansel jansel left a comment


Is there test coverage for this? It would be much better if each PR included tests rather than stacking the tests.

@aditvenk
Contributor Author

Is there test coverage for this? It would be much better if each PR included tests rather than stacking the tests.

Sorry, this one was not yet ready for review. I will amend it with tests.

@jansel
Contributor

jansel commented Apr 14, 2026

Convert to draft if you don't want people to look at PRs
