Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions helion/language/_tracing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,36 @@ def _(state: CodegenState) -> ast.AST:
)


@_decorators.codegen(_mask_to, "metal")
def _(state: CodegenState) -> ast.AST:
    """Generate the Metal-backend expression for ``_mask_to``.

    Reads the tensor (argument 0) and the scalar fill value (argument 1)
    from ``state``. For every dimension size of the tensor that resolves to
    a block id with an associated mask variable, that mask variable is
    collected once. When no mask variable is found the input expression is
    returned untouched; otherwise the result is a conditional expression
    that keeps the input where all collected masks hold and substitutes the
    fill value (cast to the tensor's dtype by the backend) elsewhere.

    Args:
        state: Codegen state giving access to proxy and AST arguments.

    Returns:
        The AST for either the unmodified input or the masked selection.
    """
    source = state.proxy_arg(0)
    assert isinstance(source, torch.Tensor)
    fill_value = state.proxy_arg(1)
    assert isinstance(fill_value, (int, float, bool))

    # Gather distinct mask variable names, preserving dimension order.
    active_masks: list[str] = []
    for dim_size in source.size():
        block_id = CompileEnvironment.current().resolve_block_id(dim_size)
        if block_id is None:
            continue
        mask_name = state.codegen.mask_var(block_id)
        if mask_name is None or mask_name in active_masks:
            continue
        active_masks.append(mask_name)

    # Nothing to mask: hand back the input expression as-is.
    if not active_masks:
        return state.ast_arg(0)

    # Cast the constant so both branches of the conditional share a dtype.
    fill_ast = CompileEnvironment.current().backend.cast_ast(
        expr_from_string(constant_repr(fill_value)),
        source.dtype,
    )
    return expr_from_string(
        "({expr} if {mask} else {other})",
        expr=state.ast_arg(0),
        mask=expr_from_string(" and ".join(active_masks)),
        other=fill_ast,
    )


@_decorators.codegen(_mask_to, "cute")
def _(state: CodegenState) -> ast.AST:
tensor = state.proxy_arg(0)
Expand Down
49 changes: 49 additions & 0 deletions helion/language/memory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,28 @@ def _codegen_cute_store_permute_lane_loops(
)


@_decorators.codegen(store, "metal")
def _(state: CodegenState) -> ast.AST:
    """Generate the Metal-backend AST for a ``store``.

    Per the original author's note, Metal reuses the same
    PointerIndexingStrategy as Triton: the emitted AST is a
    ``tl.store(ptr + offset, val, mask)`` call which the MSL walker later
    translates to Metal.

    Args:
        state: Codegen state; argument 0 is the store target, argument 1 the
            subscript, argument 2 the value, and argument 3 the optional
            extra mask.

    Returns:
        The AST produced by the selected indexing strategy.

    Raises:
        exc.BackendUnsupported: If the store target is not a ``torch.Tensor``.
    """
    target = state.proxy_arg(0)
    index_parts = state.proxy_arg(1)
    assert isinstance(index_parts, (list, tuple))
    stored_value = state.ast_arg(2)
    mask_ast = state.ast_args[3]
    assert mask_ast is None or isinstance(mask_ast, ast.AST)

    if not isinstance(target, torch.Tensor):
        raise exc.BackendUnsupported("metal", f"store target type: {type(target)}")

    fn = state.device_function
    fn.device_store_index += 1
    # The strategy is looked up with the memory-op index as it was *before*
    # this op bumps it.
    op_position = fn.device_memory_op_index
    fn.device_memory_op_index += 1
    indexing = fn.get_indexing_strategy(op_position)
    return indexing.codegen_store(
        state, target, list(index_parts), stored_value, mask_ast
    )


@_decorators.codegen(store, "cute")
def _(state: CodegenState) -> ast.AST:
tensor = state.proxy_arg(0)
Expand Down Expand Up @@ -1377,6 +1399,33 @@ def _(state: CodegenState) -> ast.AST:
return result


@_decorators.codegen(load, "metal")
def _(state: CodegenState) -> ast.AST:
    """Generate the Metal-backend AST for a ``load``.

    Per the original author's note, Metal reuses the same
    PointerIndexingStrategy as Triton: the emitted AST is a
    ``tl.load(ptr + offset, mask, other=0)`` call which the MSL walker later
    translates to Metal.

    Args:
        state: Codegen state; argument 0 is the tensor to read, argument 1
            the subscript, argument 2 the optional extra mask, and argument 3
            (when present) the optional eviction policy.

    Returns:
        The AST produced by the selected indexing strategy.

    Raises:
        exc.BackendUnsupported: If argument 0 is not a ``torch.Tensor``.
    """
    source = state.proxy_arg(0)
    index_parts = state.proxy_arg(1)
    assert isinstance(index_parts, (list, tuple))
    # The AST form of the subscript is only sanity-checked, not otherwise used.
    index_parts_ast = state.ast_args[1]
    assert isinstance(index_parts_ast, (list, tuple))
    mask_ast = state.ast_args[2]
    assert mask_ast is None or isinstance(mask_ast, ast.AST)
    # The eviction policy is optional positionally; treat a missing slot as None.
    if len(state.ast_args) > 3:
        policy_ast = state.ast_args[3]
    else:
        policy_ast = None
    assert policy_ast is None or isinstance(policy_ast, ast.AST)

    if not isinstance(source, torch.Tensor):
        raise exc.BackendUnsupported("metal", f"load tensor type: {type(source)}")

    fn = state.device_function
    fn.device_load_index += 1
    # The strategy is looked up with the memory-op index as it was *before*
    # this op bumps it.
    op_position = fn.device_memory_op_index
    fn.device_memory_op_index += 1
    indexing = fn.get_indexing_strategy(op_position)
    return indexing.codegen_load(
        state, source, list(index_parts), mask_ast, policy_ast
    )


@_decorators.codegen(load, "cute")
def _(state: CodegenState) -> object:
tensor = state.proxy_arg(0)
Expand Down
Loading