Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,11 +1201,18 @@ def adjust_block_size_constraints(
tensor_ndim,
)

for i, spec in enumerate(block_specs):
for spec in block_specs:
if not isinstance(spec, BlockSizeSpec):
continue
bid = spec.block_ids[0]
dfe = min_dim_from_end.get(bid, ndim - 1 - i)
dfe = min_dim_from_end.get(bid)
if dfe is None:
# No tensor dim mapping was found for this block id, so skip the
# alignment constraints. The old positional fallback
# (ndim - 1 - i) assumed that spec ordering matches tensor dim
# ordering, which is not always true (e.g. hl.tile([M, N, B])
# where B is the first tensor dim but the last spec).
continue
if dfe == 0:
tndim = min_tensor_ndim.get(bid, ndim)
alignment = tiling_1d if tndim <= 1 else 128
Expand Down
37 changes: 37 additions & 0 deletions helion/language/memory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,19 @@ def _pallas_index_str(
tile_with_offset_info = _get_tile_with_offset_info(idx, state, i)
if tile_with_offset_info is not None:
block_id = tile_with_offset_info.block_id
scalar_block_id = _scalar_begin_block_id(idx, state)
if scalar_block_id is not None:
# Scalar .begin index with block_size=1 — emit literal 0 to
# collapse this dimension, mirroring Triton's scalar SymInt
# handling (indexing_strategy.py:1024-1033). Record in
# dim_map so the BlockSpec tiles this dim; the kernel
# receives a size-1 slice and index 0 eliminates it.
# Don't increment out_pos — the dim is collapsed from the
# output, so subsequent None positions stay correct.
parts.append("0")
dim_map.setdefault(tensor_dim, scalar_block_id)
tensor_dim += 1
continue
if block_id is not None:
offset_expr = ""
if tile_with_offset_info is not None:
Expand Down Expand Up @@ -293,6 +306,30 @@ def _resolve_block_id(
return None


def _scalar_begin_block_id(idx: object, state: CodegenState) -> int | None:
    """Identify a ``tile.begin`` subscript whose block size is configured as 1.

    Such an index acts as a scalar subscript, so the caller can collapse the
    corresponding dimension instead of tiling it.

    Args:
        idx: The subscript element to inspect.
        state: Codegen state; only ``state.device_function.config`` is read,
            to resolve the configured block size.

    Returns:
        The block id recorded on the ``TileBeginOrigin`` when *idx* is a
        ``torch.SymInt`` whose expression is a bare ``sympy.Symbol`` registered
        in the current ``HostFunction``'s ``expr_to_origin`` with a
        ``TileBeginOrigin`` origin and the block size resolved from the config
        equals 1. ``None`` in every other case.
    """
    if not isinstance(idx, torch.SymInt):
        return None
    # Imported lazily, and only once we know a SymInt is being inspected.
    import sympy

    symbol = _symint_expr(idx)
    if not isinstance(symbol, sympy.Symbol):
        # Compound expressions are never treated as a plain tile.begin.
        return None

    info = HostFunction.current().expr_to_origin.get(symbol)
    origin = None if info is None else info.origin
    if not isinstance(origin, TileBeginOrigin):
        return None

    tile_block_id = origin.block_id
    configured_size = CompileEnvironment.current().block_sizes[
        tile_block_id
    ].from_config(state.device_function.config)
    # Only a block size of exactly 1 qualifies as a scalar index.
    return tile_block_id if configured_size == 1 else None


def _pallas_ds_expr(state: CodegenState, block_id: int, tile_offset: str) -> str:
"""Return a ``pl.ds(offset, block_size)`` expression for *block_id*, offset by *tile_offset*"""
offset = state.codegen.offset_var(block_id)
Expand Down
1 change: 0 additions & 1 deletion test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
expected[42, 79] = x[42, 79]
torch.testing.assert_close(result, expected)

@xfailIfPallas("Result mismatch due to incorrect tiling")
def test_scalar_index_transpose(self) -> None:
"""Scalar .begin index should collapse the dimension.

Expand Down
Loading