[metal] Auto-cap threadgroup size via CuteND/CuteFlattenedTileStrategy #1855
Draft
aditvenk wants to merge 1 commit into aditvenk/stack/15 from aditvenk/stack/21
Conversation
aditvenk added a commit that referenced this pull request Mar 27, 2026
Metal has a hard 1024-thread-per-threadgroup limit. Kernels with block_size products exceeding 1024 (e.g. 1D [2048], 2D [64, 64], 3D [16, 16, 16]) crash at runtime. Users shouldn't need to know about hardware limits when picking block sizes.

This commit switches the Metal backend from the base NDTileStrategy / FlattenedTileStrategy to CuteNDTileStrategy / CuteFlattenedTileStrategy, and adds auto-capping logic so that when block_size products exceed 1024, num_threads is reduced and each thread processes multiple elements via lane loops (elements_per_thread = block_size / num_threads).

Why Cute strategies are the right fit for Metal:

Helion has two families of tile strategies. The base strategies (NDTileStrategy / FlattenedTileStrategy) target Triton and Pallas, where each thread holds a *vector* of elements -- index expressions use tl.arange(0, block_size) / jnp.arange(0, block_size) to produce a tile of offsets per program. The Cute strategies (CuteNDTileStrategy / CuteFlattenedTileStrategy) target backends where each thread is a *scalar lane* -- index expressions use thread_idx()[axis] to produce a single offset per thread.

Metal's execution model matches CuTe's scalar-thread model. Both dispatch one scalar element per thread for elementwise work, and both rely on cooperative hardware primitives (Metal's simdgroup MPP / CuTe's warpgroup MMA) for matmul -- where the hardware internally distributes tile elements across threads and the strategy must not generate per-thread indices.

CuteNDTileStrategy supports both patterns: num_threads + lane loops for scalar dispatch (when block_size > thread_count, each thread iterates over multiple elements), and mma_mode to suppress thread-index generation for cooperative matmul. The base NDTileStrategy can express neither. By adopting Cute strategies now, matmul support can later add mma_mode=True without changing strategy selection.

The lane_index_expr and lane_offset_expr backend hooks make the generated index expressions backend-agnostic: CuTe emits cutlass.Int32(cute.arch.thread_idx()[axis]), Metal emits tid[axis].

Changes:

Backend (backend.py):
- MetalBackend: override lane_index_expr/lane_offset_expr with tid[axis] exprs; add arange_expr(), loop_index_expr(), thread_in_tile_mask_expr() for Cute strategy fallback paths; add create_loop_strategy() with auto-capping; add build_launcher_args(); add num_threads to supported config keys

MSL walker (msl_ast_walker.py):
- Add ast.For handling to emit MSL for-loops from lane loop range() iterators

Tests (test_metal.py):
- Bump 2D block_sizes to [64, 64] (4096 threads, auto-capped)
- Bump 3D block_sizes to [16, 16, 16] (4096 threads, auto-capped)
- Add large_block_add kernel with block_sizes=[2048]
- Add TestMetalLargeBlock: correctness + codegen lane loop assertion

stack-info: PR: #1855, branch: aditvenk/stack/21
This was referenced Mar 27, 2026
This was referenced Apr 9, 2026
jansel (Contributor) requested changes Apr 13, 2026:
Is there test coverage for this? It would be much better if each PR included tests rather than stacking the tests.
Author:
Sorry, this one was not yet ready for review. I will amend it with tests.
Contributor:
Convert to draft if you don't want people to look at PRs.
Stacked PRs:
[metal] Auto-cap threadgroup size via CuteND/CuteFlattenedTileStrategy
Metal has a hard 1024-thread-per-threadgroup limit. Kernels with
block_size products exceeding 1024 (e.g. 1D [2048], 2D [64, 64],
3D [16, 16, 16]) crash at runtime. Users shouldn't need to know
about hardware limits when picking block sizes.
This commit switches the Metal backend from the base NDTileStrategy /
FlattenedTileStrategy to CuteNDTileStrategy / CuteFlattenedTileStrategy,
and adds auto-capping logic so that when block_size products exceed
1024, num_threads is reduced and each thread processes multiple elements
via lane loops (elements_per_thread = block_size / num_threads).
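The capping rule described above can be sketched as follows. This is a
minimal illustration only; MAX_THREADS and cap_num_threads are hypothetical
names, not Helion's actual API:

```python
MAX_THREADS = 1024  # Metal's hard per-threadgroup thread limit

def cap_num_threads(block_sizes: list[int]) -> tuple[int, int]:
    """Return (num_threads, elements_per_thread) for a tile.

    If the block-size product exceeds the hardware limit, halve the
    thread count until it fits; each thread then covers the remaining
    elements via a lane loop.
    """
    total = 1
    for b in block_sizes:
        total *= b
    num_threads = total
    while num_threads > MAX_THREADS:
        # block sizes are powers of two, so halving divides evenly
        num_threads //= 2
    return num_threads, total // num_threads
```

For the examples above: [2048] caps to 1024 threads with 2 elements per
thread, and both [64, 64] and [16, 16, 16] cap to 1024 threads with 4
elements per thread.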
Why Cute strategies are the right fit for Metal:
Helion has two families of tile strategies. The base strategies
(NDTileStrategy / FlattenedTileStrategy) target Triton and Pallas,
where each thread holds a vector of elements -- index expressions
use tl.arange(0, block_size) / jnp.arange(0, block_size) to
produce a tile of offsets per program. The Cute strategies
(CuteNDTileStrategy / CuteFlattenedTileStrategy) target backends
where each thread is a scalar lane -- index expressions use
thread_idx()[axis] to produce a single offset per thread.
Metal's execution model matches CuTe's scalar-thread model. Both
dispatch one scalar element per thread for elementwise work, and
both rely on cooperative hardware primitives (Metal's simdgroup MPP /
CuTe's warpgroup MMA) for matmul -- where the hardware internally
distributes tile elements across threads and the strategy must not
generate per-thread indices.
CuteNDTileStrategy supports both patterns: num_threads + lane
loops for scalar dispatch (when block_size > thread_count, each
thread iterates over multiple elements), and mma_mode to suppress
thread-index generation for cooperative matmul. The base
NDTileStrategy can express neither. By adopting Cute strategies now,
matmul support can later add mma_mode=True without changing
strategy selection.
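Under a capped configuration, each thread's lane loop covers multiple tile
elements. A minimal model of one plausible strided layout (the exact
indexing Helion generates may differ; tile_offsets is an illustrative
helper):

```python
def tile_offsets(tid: int, num_threads: int, elements_per_thread: int) -> list[int]:
    """Offsets within one tile covered by scalar-lane thread `tid`.

    Strided layout: lane `k` of thread `tid` touches tid + k * num_threads,
    so all threads together cover the full tile exactly once.
    """
    return [tid + lane * num_threads for lane in range(elements_per_thread)]
```

With num_threads=4 and elements_per_thread=2, thread 0 covers offsets
[0, 4], and the four threads together cover offsets 0 through 7.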
The lane_index_expr and lane_offset_expr backend hooks make
the generated index expressions backend-agnostic: CuTe emits
cutlass.Int32(cute.arch.thread_idx()[axis]), Metal emits
tid[axis].
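The hook contract described above can be sketched like this. The class and
method shapes here are illustrative assumptions, not Helion's real code;
only the emitted strings come from the description above:

```python
class CuteBackend:
    def lane_index_expr(self, axis: int) -> str:
        # CuTe DSL: scalar thread index along `axis`, widened to Int32
        return f"cutlass.Int32(cute.arch.thread_idx()[{axis}])"

class MetalBackend:
    def lane_index_expr(self, axis: int) -> str:
        # MSL: threadgroup-local thread position
        return f"tid[{axis}]"
```

The Cute strategies call the hook without knowing which backend is active,
so the same strategy code generates valid indices for both targets.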
Changes:

Backend (backend.py):
- MetalBackend: override lane_index_expr/lane_offset_expr with
  tid[axis] exprs; add arange_expr(), loop_index_expr(),
  thread_in_tile_mask_expr() for Cute strategy fallback paths;
  add create_loop_strategy() with auto-capping; add
  build_launcher_args(); add num_threads to supported config keys

MSL walker (msl_ast_walker.py):
- Add ast.For handling to emit MSL for-loops from lane loop range()
  iterators

Tests (test_metal.py):
- Add multi-dimensional elementwise kernels and tests (2D, 3D)
- Add large_block_add kernel with block_sizes=[2048]
- Add TestMetalMultiDim and TestMetalLargeBlock

Tests (test_grid.py):
- Skip hl.grid begin/end tests for Metal (CuteNDTileStrategy
  regression — to be fixed separately)

stack-info: PR: #1855, branch: aditvenk/stack/21
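The MSL-walker change above (ast.For handling) can be illustrated with a toy
translation of a single-argument range() loop header into an MSL for-loop
header; msl_for_header is a hypothetical helper, not the real walker:

```python
import ast

def msl_for_header(node: ast.For) -> str:
    """Translate `for i in range(n):` into an MSL for-loop header."""
    assert isinstance(node.target, ast.Name)
    call = node.iter
    assert isinstance(call, ast.Call) and getattr(call.func, "id", None) == "range"
    bound = ast.unparse(call.args[0])  # single-argument range(n) only
    var = node.target.id
    return f"for (int {var} = 0; {var} < {bound}; ++{var}) {{"

# A lane loop like `for lane in range(4): ...` becomes a C-style MSL loop.
tree = ast.parse("for lane in range(4):\n    pass")
header = msl_for_header(tree.body[0])
```

The real walker also has to emit the loop body and handle the strided index
math inside it; this sketch only shows the header translation.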