Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions helion/_compiler/indexing_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def codegen_load(
if indexing.needs_broadcast():
output_size = SubscriptIndexing.compute_shape(fake_tensor, subscript, state)
shape_str = state.tile_strategy.shape_str(output_size)
load_expr = expr_from_string(
f"tl.broadcast_to({{load_expr}}, {shape_str})", load_expr=load_expr
)
backend = CompileEnvironment.current().backend
broadcast = backend.broadcast_to_expr("{load_expr}", shape_str)
load_expr = expr_from_string(broadcast, load_expr=load_expr)

return load_expr

Expand Down Expand Up @@ -241,24 +241,22 @@ def codegen_store(

# If pointer is scalar but output_size has dimensions, reshape value to scalar.
# Skip reshaping for scalar constants which don't have shape.
backend = CompileEnvironment.current().backend
if (
not pointer_has_block_dims
and output_size
and not isinstance(value, ast.Constant)
):
# Pointer is scalar but value may have shape - squeeze to scalar
value = expr_from_string(
"tl.reshape({value}, [])",
value=value,
)
reshape = backend.reshape_expr("{value}", "[]")
value = expr_from_string(reshape, value=value)

offset_expr = indexing.index_expr
# If dimensions need broadcasting for store, broadcast the pointer
if indexing.needs_broadcast():
shape_str = state.tile_strategy.shape_str(output_size)
offset_expr = expr_from_string(
f"tl.broadcast_to({{offset}}, {shape_str})", offset=offset_expr
)
broadcast = backend.broadcast_to_expr("{offset}", shape_str)
offset_expr = expr_from_string(broadcast, offset=offset_expr)

return expr_from_string(
f"tl.store({name} + {{offset}}, {{value}}, {{mask}})",
Expand Down Expand Up @@ -1019,7 +1017,9 @@ def handle_broadcast_tensor(
if mask := state.codegen.mask_var(block_idx):
mask_values.setdefault(f"({mask}){expand}")
else:
index_values.append(f"tl.zeros([1], {dtype}){expand}")
index_values.append(
f"{env.backend.zeros_expr('[1]', dtype)}{expand}"
)
output_idx += 1
k_index += 1
elif isinstance(k, torch.Tensor):
Expand Down Expand Up @@ -1098,7 +1098,7 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
index_expr.append(f"{idx} * {stride}")
if not index_expr:
shape_str = state.tile_strategy.shape_str(per_dim.output_size)
index_expr.append(f"tl.zeros({shape_str}, {dtype})")
index_expr.append(env.backend.zeros_expr(shape_str, dtype))
return SubscriptIndexing(
expr_from_string("+".join(index_expr)),
per_dim.mask_expr,
Expand Down
Loading