[Pallas] Don't record block_id in dim_map for hl.grid program_id dimensions #2001

Open

norx1991 wants to merge 5 commits into main from yifeixu/pallas-fix-grid-dim-map

Conversation


norx1991 (Contributor) commented Apr 10, 2026

Background

While investigating when _no_tiling_block_spec_info (the fallback for untileable dims) actually gets triggered, we found that hl.grid() scalar accesses were incorrectly recording block_ids in pallas_tensor_dim_block_ids, causing the fallback to fire unnecessarily.

Summary

  • Fix `_pallas_index_str` to only record dim_map entries for dimensions that emit `:` (BlockSpec-tiled), not for `pl.program_id()` (scalar grid access) dimensions.
  • Previously, `hl.grid()` scalar accesses would record a block_id in `pallas_tensor_dim_block_ids`. `_compute_block_spec_info` would then see this block_id, look up its block size from the config (which is 1 for `hl.grid`), and fail the TPU alignment check (1 % 128 != 0), triggering the `_no_tiling_block_spec_info` fallback.
  • Add a regression test that patches `_no_tiling_block_spec_info` to raise if called during `test_scalar_access_hl_grid`, ensuring this codepath is not triggered for `hl.grid()` dims.
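To make the recording rule concrete, here is a minimal plain-Python sketch; the function and structure names are illustrative assumptions, not Helion's actual `_pallas_index_str` implementation. The point it shows: only dimensions lowered to `:` get a dim_map entry, while `pl.program_id()` dimensions record nothing.

```python
# Hypothetical sketch of the fixed recording rule (names are illustrative,
# not Helion's real _pallas_index_str): only ":"-emitting (BlockSpec-tiled)
# dims record a block_id; scalar grid dims are skipped.

def index_str_and_dim_map(dims):
    """dims: per-tensor-dim specs, ("tile", block_id) or ("grid", pid_axis)."""
    parts = []
    dim_map = {}  # tensor dim -> block_id (cf. pallas_tensor_dim_block_ids)
    for i, (kind, ident) in enumerate(dims):
        if kind == "tile":
            parts.append(":")
            dim_map[i] = ident  # tiled dims still record their block_id
        else:
            # The fix: scalar hl.grid() accesses emit pl.program_id(axis)
            # and record nothing, so no bogus bs=1 entry ever reaches
            # _compute_block_spec_info.
            parts.append(f"pl.program_id({ident})")
    return "[" + ", ".join(parts) + "]", dim_map
```

For example, a tensor indexed as `[pl.program_id(0), :]` would record a block_id only for dim 1.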


Fix _pallas_index_str to only record dim_map entries for dimensions
that emit `:` (BlockSpec-tiled), not for `pl.program_id()` (scalar
grid access) dimensions. Previously, hl.grid() scalar accesses would
record a block_id in pallas_tensor_dim_block_ids, causing
_compute_block_spec_info to attempt tiling with bs=1, which then
failed TPU alignment checks and triggered the _no_tiling_block_spec_info
fallback.
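The failure chain in this commit message can be sketched as a small decision function; the names and the 128-lane constant here are assumptions for illustration, not the actual `_compute_block_spec_info` code.

```python
# Sketch of the pre-fix failure mode: an hl.grid dim that wrongly carries a
# block_id looks up block size 1, fails the TPU lane-alignment check
# (1 % 128 != 0), and falls back to _no_tiling_block_spec_info.

TPU_LANE_SIZE = 128  # assumed alignment requirement, for illustration only

def compute_block_spec(dim_block_id, config_block_sizes):
    if dim_block_id is None:
        # After the fix, hl.grid dims record no block_id and land here.
        return "pl.program_id"
    bs = config_block_sizes[dim_block_id]  # 1 for an hl.grid dim pre-fix
    if bs % TPU_LANE_SIZE != 0:
        return "_no_tiling_block_spec_info"  # the unnecessary fallback
    return f"BlockSpec(bs={bs})"
```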
meta-cla bot added the CLA Signed label Apr 10, 2026
  • Patch `_no_tiling_block_spec_info` to raise if called during `test_scalar_access_hl_grid`, ensuring `hl.grid()` dims don't leak block_ids into `pallas_tensor_dim_block_ids`.
  • Exercise autotuner defaults (no explicit block_sizes) for kernels that are now working on TPU with the corrected `_min_dot_size`.
  • The dim_map fix makes nested `hl.grid` + `hl.tile` work correctly on Pallas, so this test no longer needs the xfail.
  • These kernels don't use `hl.dot`, so `_min_dot_size` has no effect on them. Keep only `test_matmul_default`, which exercises the actual fix.
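The regression-test pattern described above (patch the fallback so it raises if it is ever reached) can be sketched with `unittest.mock`; the module layout below is invented for the sketch and is not Helion's actual structure.

```python
import types
from unittest import mock

# Stand-in for the backend module holding the fallback (layout assumed).
backend = types.SimpleNamespace(
    _no_tiling_block_spec_info=lambda *a, **k: ("no-tiling",),
)

def compile_scalar_grid_kernel():
    # A correct hl.grid() lowering never consults the fallback.
    return "ok"

def test_scalar_access_hl_grid():
    def _fail(*args, **kwargs):
        raise AssertionError("_no_tiling_block_spec_info should not be called")
    # If the code under test ever hits the fallback, the test fails loudly.
    with mock.patch.object(backend, "_no_tiling_block_spec_info", _fail):
        assert compile_scalar_grid_kernel() == "ok"

test_scalar_access_hl_grid()
```

After the `with` block exits, `mock.patch.object` restores the original attribute, so other tests are unaffected.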
norx1991 marked this pull request as ready for review April 10, 2026 22:54
AmesingFlank approved these changes Apr 11, 2026
```python
    )
    torch.testing.assert_close(result, grid_2d_pytorch(args[0], args[1]))

@xfailIfPallas("2D nested grids not working correctly Pallas")
```
AmesingFlank (Contributor):
I had a look at the compiled output for this test before versus after the PR. The main difference is that we changed from tiling the 0th dimension (size 3) with block size 1 to not tiling it at all. This makes the kernel run correctly, because we currently assume that hl.grid() accesses can never be tiled (since their block size is 1), but this assumption is incorrect: for high dims (the size-3 dim in this case), 1 can still be a valid block size.

While I think this fix makes sense for now, the better long-term fix is to recognize whether or not this hl.grid index can be tiled, and generate indexing exprs accordingly. That way we use a lot less VMEM. This is something I'm hoping to address as part of the analysis work we discussed offline.

norx1991 (Contributor, author):

Thanks for the comment! This is a good point.

Currently the kernel code is

```python
def _helion_grid_2d_idx_nested(x, y, out):
    for offset_1 in range(0, 4):
        for offset_2 in range(0, 64, _BLOCK_SIZE_2):
            for offset_3 in range(0, 16, _BLOCK_SIZE_3):
                acc = jnp.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, jnp.float32)
                for offset_4 in range(0, 32, _BLOCK_SIZE_4):
                    acc_copy = acc
                    acc_copy_0 = acc_copy
                    load = x[pl.program_id(0), offset_1,
                             pl.ds(offset_2, _BLOCK_SIZE_2), pl.ds(offset_4, _BLOCK_SIZE_4)]
                    load_1 = y[pl.ds(offset_4, _BLOCK_SIZE_4), pl.ds(offset_3, _BLOCK_SIZE_3)]
                    acc = acc_copy_0 + jnp.matmul(load, load_1, preferred_element_type=jnp.float32)
                v_0 = lax.convert_element_type(acc, jnp.bfloat16)
                out[pl.program_id(0), offset_1,
                    pl.ds(offset_2, _BLOCK_SIZE_2), pl.ds(offset_3, _BLOCK_SIZE_3)] = v_0
```

It seems to me that we are using two mechanisms when handling hl.grid:

  1. a BlockSpec with bs=1 in pl.pallas_call, plus pl.program_id — these two conflict with each other
  2. an outer for loop in the kernel

This PR removes the BlockSpec from (1), so it no longer conflicts with pl.program_id.

IIUC, ideally, we want to use either

  1. a BlockSpec with bs=1 (when tiling is allowed)
  2. pl.program_id (when tiling is not allowed)

So approach (2) as in this PR will still be needed eventually when (1) is not possible due to tiling constraints?

Then when tiling is allowed as in this test, the ideal kernel should be

```python
_BLOCK_SIZE_2 = int(16)  # tile_m
_BLOCK_SIZE_3 = int(16)  # tile_n
_BLOCK_SIZE_4 = int(32)  # tile_k

def _helion_grid_2d_idx_nested(x, y, out):
    # x BlockRef shape: [1, 1, 64, 32]
    # y BlockRef shape: [32, 16]
    # out BlockRef shape: [1, 1, 64, 16]
    for offset_2 in range(0, 64, _BLOCK_SIZE_2):
        for offset_3 in range(0, 16, _BLOCK_SIZE_3):
            acc = jnp.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, jnp.float32)
            for offset_4 in range(0, 32, _BLOCK_SIZE_4):
                load = x[:, :, pl.ds(offset_2, _BLOCK_SIZE_2), pl.ds(offset_4, _BLOCK_SIZE_4)]
                load_1 = y[pl.ds(offset_4, _BLOCK_SIZE_4), pl.ds(offset_3, _BLOCK_SIZE_3)]
                acc = acc + jnp.matmul(load, load_1, preferred_element_type=jnp.float32)
            v_0 = lax.convert_element_type(acc, jnp.bfloat16)
            out[:, :, pl.ds(offset_2, _BLOCK_SIZE_2), pl.ds(offset_3, _BLOCK_SIZE_3)] = v_0

def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_pallas_launcher):
    bi, bj, m, k = x.size()
    k2, n = y.size()
    assert k == k2, f'size mismatch {k} != {k2}'
    out = torch.empty(bi, bj, m, n, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
    _launcher(
        _helion_grid_2d_idx_nested,
        (3, 4),  # grid = (bi, bj)
        x, y, out,
        _output_indices=[2],
        _inplace_indices=[],
        _block_spec_info=[
            ((1, 1, None, None), (0, 1, None, None)),  # x: dim0=bs1/grid0, dim1=bs1/grid1
            ((None, None), (None, None)),              # y: no tiling
            ((1, 1, None, None), (0, 1, None, None)),  # out: dim0=bs1/grid0, dim1=bs1/grid1
        ],
    )
    return out
```
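As a sanity check on the BlockRef shape comments in the kernel above, the convention (assumed from this sketch; not a documented Helion API) is that a dim with a block size is tiled to that size, while a `None` dim is passed through whole:

```python
# Tiny helper (illustrative only): derive a BlockRef shape from a tensor
# shape plus per-dim block sizes, where None means "pass the dim whole".
def block_ref_shape(tensor_shape, block_sizes):
    return [bs if bs is not None else full
            for full, bs in zip(tensor_shape, block_sizes)]

# x is (bi=3, bj=4, m=64, k=32) tiled as (1, 1, None, None):
# -> BlockRef shape [1, 1, 64, 32], matching the comment in the kernel above.
```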

BTW, I do not know why the two hl.grid loops are treated differently... it seems the outer for loop is never needed.
