[Pallas] Don't record block_id in dim_map for hl.grid program_id dimensions#2001
Conversation
Fix `_pallas_index_str` to only record `dim_map` entries for dimensions that emit `:` (BlockSpec-tiled), not for `pl.program_id()` (scalar grid access) dimensions. Previously, `hl.grid()` scalar accesses would record a block_id in `pallas_tensor_dim_block_ids`, causing `_compute_block_spec_info` to attempt tiling with bs=1, which then failed TPU alignment checks and triggered the `_no_tiling_block_spec_info` fallback.
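The recording rule can be sketched in plain Python. This is a hypothetical illustration, not Helion's actual code: `pallas_index_str_sketch` and its `(kind, block_id)` input format are invented names; only the rule itself (record a `dim_map` entry for `:` dims, record nothing for `pl.program_id()` dims) comes from the description above.

```python
# Hypothetical sketch of the dim_map rule this PR enforces: only
# BlockSpec-tiled dims (those that emit ":") record a block_id;
# pl.program_id() dims (hl.grid scalar access) record nothing, so
# block-spec inference never tries to tile them with bs=1.

def pallas_index_str_sketch(dims):
    """dims: list of (kind, block_id), kind is 'tiled' or 'grid'."""
    index_parts = []
    dim_map = {}  # tensor dim -> block_id, consumed later by block-spec inference
    for tensor_dim, (kind, block_id) in enumerate(dims):
        if kind == "tiled":
            # BlockSpec handles this dim; emit ":" and record the mapping.
            index_parts.append(":")
            dim_map[tensor_dim] = block_id
        else:
            # hl.grid() scalar access: index via pl.program_id, record nothing.
            index_parts.append(f"pl.program_id({block_id})")
    return ", ".join(index_parts), dim_map

idx, dim_map = pallas_index_str_sketch([("grid", 0), ("tiled", 2), ("tiled", 3)])
# idx == "pl.program_id(0), :, :", dim_map == {1: 2, 2: 3}
```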
Patch `_no_tiling_block_spec_info` to raise if called during `test_scalar_access_hl_grid`, ensuring `hl.grid()` dims don't leak block_ids into `pallas_tensor_dim_block_ids`.
Exercise autotuner defaults (no explicit block_sizes) for kernels that are now working on TPU with the corrected _min_dot_size.
The dim_map fix makes nested hl.grid + hl.tile work correctly on Pallas, so this test no longer needs the xfail.
These kernels don't use hl.dot, so _min_dot_size has no effect on them. Keep only test_matmul_default which exercises the actual fix.
    )
    torch.testing.assert_close(result, grid_2d_pytorch(args[0], args[1]))

@xfailIfPallas("2D nested grids not working correctly Pallas")
I had a look at the compiled output for this test before versus after the PR; the main difference is that we changed from tiling the 0th dimension (extent 3) with block size 1 to not tiling it at all. This makes the kernel run correctly because we currently assume that hl.grid() accesses can never be tiled, since their block size is 1, but this assumption is incorrect: for some dims (extent 3 in this case), 1 can still be a valid block size.
While I think this fix makes sense for now, I think the better long-term fix is to recognize whether or not a given hl.grid index can be tiled, and generate indexing exprs accordingly. This way we use a lot less VMEM. This is something I'm hoping to address as part of the analysis work we discussed offline.
Thanks for the comment! This is a good point.
Currently the kernel code is:

    def _helion_grid_2d_idx_nested(x, y, out):
        for offset_1 in range(0, 4):
            for offset_2 in range(0, 64, _BLOCK_SIZE_2):
                for offset_3 in range(0, 16, _BLOCK_SIZE_3):
                    acc = jnp.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3],
                        0.0, jnp.float32)
                    for offset_4 in range(0, 32, _BLOCK_SIZE_4):
                        acc_copy = acc
                        acc_copy_0 = acc_copy
                        load = x[pl.program_id(0), offset_1,
                            pl.ds(offset_2, _BLOCK_SIZE_2), pl.ds(offset_4, _BLOCK_SIZE_4)]
                        load_1 = y[pl.ds(offset_4, _BLOCK_SIZE_4),
                            pl.ds(offset_3, _BLOCK_SIZE_3)]
                        acc = acc_copy_0 + jnp.matmul(load, load_1,
                            preferred_element_type=jnp.float32)
                    v_0 = lax.convert_element_type(acc, jnp.bfloat16)
                    out[pl.program_id(0), offset_1, pl.ds(offset_2,
                        _BLOCK_SIZE_2), pl.ds(offset_3, _BLOCK_SIZE_3)] = v_0

It seems to me that we are using two mechanisms when handling hl.grid:

1. a blockspec with `bs=1` in `pl.pallas_call` + `pl.program_id`
2. an outer for loop in the kernel

These two mechanisms are conflicting with each other.
This PR removes blockspec from (1), so it is not conflicting with pl.program_id.
IIUC, ideally we want to use either:

1. a blockspec with `bs=1` (when tiling is allowed), or
2. `pl.program_id` (when tiling is not allowed)

So approach (2) as in this PR will still be needed eventually when (1) is not possible due to tiling constraints?
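The "pick exactly one mechanism per grid dim" idea can be sketched as follows. Everything here is illustrative: `grid_dim_codegen` and its return format are hypothetical, not an existing Helion API; only the two alternatives (a bs=1 BlockSpec entry versus a `pl.program_id` index expression) come from the discussion above.

```python
# Hypothetical sketch: per hl.grid dimension, emit either a bs=1 BlockSpec
# entry or a pl.program_id index expression, never both, so the two
# mechanisms cannot conflict.

def grid_dim_codegen(grid_axis, tiling_allowed):
    if tiling_allowed:
        # (1) the BlockSpec carries the grid mapping; the kernel just uses ":"
        return {"block_spec": (1, grid_axis), "index_expr": ":"}
    # (2) no BlockSpec entry; the kernel indexes with pl.program_id
    return {"block_spec": (None, None),
            "index_expr": f"pl.program_id({grid_axis})"}
```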
Then when tiling is allowed, as in this test, the ideal kernel should be:

    _BLOCK_SIZE_2 = int(16)  # tile_m
    _BLOCK_SIZE_3 = int(16)  # tile_n
    _BLOCK_SIZE_4 = int(32)  # tile_k

    def _helion_grid_2d_idx_nested(x, y, out):
        # x BlockRef shape: [1, 1, 64, 32]
        # y BlockRef shape: [32, 16]
        # out BlockRef shape: [1, 1, 64, 16]
        for offset_2 in range(0, 64, _BLOCK_SIZE_2):
            for offset_3 in range(0, 16, _BLOCK_SIZE_3):
                acc = jnp.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0,
                    jnp.float32)
                for offset_4 in range(0, 32, _BLOCK_SIZE_4):
                    load = x[:, :, pl.ds(offset_2, _BLOCK_SIZE_2),
                        pl.ds(offset_4, _BLOCK_SIZE_4)]
                    load_1 = y[pl.ds(offset_4, _BLOCK_SIZE_4),
                        pl.ds(offset_3, _BLOCK_SIZE_3)]
                    acc = acc + jnp.matmul(load, load_1,
                        preferred_element_type=jnp.float32)
                v_0 = lax.convert_element_type(acc, jnp.bfloat16)
                out[:, :, pl.ds(offset_2, _BLOCK_SIZE_2),
                    pl.ds(offset_3, _BLOCK_SIZE_3)] = v_0

    def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor, *,
            _launcher=_default_pallas_launcher):
        bi, bj, m, k = x.size()
        k2, n = y.size()
        assert k == k2, f'size mismatch {k} != {k2}'
        out = torch.empty(bi, bj, m, n,
            dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
        _launcher(
            _helion_grid_2d_idx_nested,
            (3, 4),  # grid = (bi, bj)
            x, y, out,
            _output_indices=[2],
            _inplace_indices=[],
            _block_spec_info=[
                ((1, 1, None, None), (0, 1, None, None)),  # x: dim0=bs1/grid0, dim1=bs1/grid1
                ((None, None), (None, None)),              # y: no tiling
                ((1, 1, None, None), (0, 1, None, None)),  # out: dim0=bs1/grid0, dim1=bs1/grid1
            ],
        )
        return out

BTW, I do not know why the two hl.grid loops are treated differently... it seems the outer for loop is never needed.
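To put a rough number on the "a lot less VMEM" point, here is a back-of-the-envelope comparison for the x operand. This is a sketch under stated assumptions, not measured VMEM usage: it assumes (per the BlockRef shape comments) that without BlockSpec tiling on the grid dims the kernel keeps the full [3, 4, 64, 32] x array resident, while bs=1 BlockSpecs shrink those two dims to 1, and it assumes bf16 (2 bytes per element).

```python
import math

# Assumed BlockRef shapes for x in the two kernels (bf16, 2 bytes/element):
#   current kernel: no BlockSpec on the grid dims -> full array resident
#   ideal kernel:   bs=1 BlockSpec on both grid dims -> batch dims shrink to 1

def nbytes(shape, itemsize=2):
    return math.prod(shape) * itemsize

current_x = nbytes((3, 4, 64, 32))  # full x array
ideal_x = nbytes((1, 1, 64, 32))    # bs=1 on both hl.grid dims
savings = current_x // ideal_x      # factor of bi * bj = 12
```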
Background

While investigating when `_no_tiling_block_spec_info` (the fallback for untileable dims) actually gets triggered, we found that `hl.grid()` scalar accesses were incorrectly recording block_ids in `pallas_tensor_dim_block_ids`, causing the fallback to fire unnecessarily.

Summary

- Fix `_pallas_index_str` to only record `dim_map` entries for dimensions that emit `:` (BlockSpec-tiled), not for `pl.program_id()` (scalar grid access) dimensions. Previously, `hl.grid()` scalar accesses would record a block_id in `pallas_tensor_dim_block_ids`; `_compute_block_spec_info` would then see this block_id, look up its block size from the config (which is 1 for `hl.grid`), and fail the TPU alignment check (1 % 128 != 0), triggering the `_no_tiling_block_spec_info` fallback.
- Patch `_no_tiling_block_spec_info` to raise if called during `test_scalar_access_hl_grid`, ensuring this codepath is not triggered for `hl.grid()` dims.
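The failure chain described above can be sketched in a few lines. The names here (`TPU_LANE`, `alignment_ok`) are illustrative, not Helion's actual identifiers; only the `1 % 128 != 0` check itself comes from the summary.

```python
# Hypothetical sketch of the failure chain: a block_id recorded for an
# hl.grid dim makes block-spec inference see block size 1, which fails a
# TPU lane-alignment check and forces the no-tiling fallback.

TPU_LANE = 128  # alignment assumed by the check in this sketch

def alignment_ok(block_size):
    return block_size % TPU_LANE == 0

# hl.grid dims have block size 1 in the config:
assert not alignment_ok(1)    # 1 % 128 != 0 -> fallback fires
assert alignment_ok(256)      # a lane-aligned block size passes
```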