Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2400,11 +2400,85 @@ def library_imports(self) -> dict[str, str]:
),
}

def index_type_str(self, index_dtype: torch.dtype) -> str:
    """Return the source-level name of the index type.

    ``index_dtype`` is accepted for interface compatibility but ignored:
    indices are always emitted as ``uint`` for this backend.
    """
    return "uint"

def inline_constexpr(self, name: str, value: str) -> str:
    """Bind ``name`` to ``value`` as a plain assignment in generated source."""
    return name + " = " + value

def cast_expr(self, expr_str: str, dtype_str: str) -> str:
    """Wrap ``expr_str`` in a C++-style ``static_cast`` to ``dtype_str``."""
    return "static_cast<" + dtype_str + ">(" + expr_str + ")"

def program_id_expr(self, dim: int, *, index_dtype: str) -> str:
    """Return the generated expression indexing ``tgid`` along ``dim``.

    ``index_dtype`` is accepted for interface parity with other backends
    and is not used here.
    """
    return "tgid[" + str(dim) + "]"

def grid_index_expr(
    self, offset_var: str, block_size_var: str, dtype: str, *, axis: int
) -> str:
    """Per-axis grid index: the block offset plus the thread id lane.

    ``block_size_var`` and ``dtype`` are accepted for interface
    compatibility but do not participate in the emitted expression.
    """
    return "{} + tid[{}]".format(offset_var, axis)

def force_tile_mask(self) -> bool:
    """Tile masking is always required for this backend."""
    return True

def full_expr(
    self, shape_dims: list[str], value_expr: str, dtype: torch.dtype
) -> str:
    """Emit a ``full``-style fill as a scalar constructor call.

    ``shape_dims`` is ignored: under the per-thread scalar model each
    thread holds a single element, so a fill reduces to casting the
    value to the target type name for ``dtype``.
    """
    type_name = self.dtype_str(dtype)
    return "{}({})".format(type_name, value_expr)

def reshape_expr(self, expr: str, shape: str) -> str:
    """Identity transform: per-thread scalar values carry no shape, so a
    reshape is a no-op at codegen level and ``shape`` is ignored."""
    return expr

def broadcast_to_expr(self, expr: str, shape: str) -> str:
    """Identity transform: broadcasting is a no-op under the per-thread
    scalar model, so ``shape`` is ignored and ``expr`` passes through."""
    return expr
Comment on lines +2429 to +2433
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is right.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal uses per-thread scalar dispatch — each thread processes one element, so there are no block dimensions. Hence, these are identity operations. Tests coming up in #1794 and #1796.

I also have matmul tests in an unpublished PR on top of this current stack. Matmuls will use a cooperative thread model, similar to CuTe.


def zeros_expr(self, shape: str, dtype: str) -> str:
    """Emit a zero literal; ``shape`` and ``dtype`` are ignored under the
    per-thread scalar model."""
    return "0"

def where_expr(self, mask: str, true_val: str, false_val: str) -> str:
    """Emit a parenthesized ternary select on ``mask``."""
    return "({} ? {} : {})".format(mask, true_val, false_val)

def minimum_expr(self, a: str, b: str) -> str:
    """Emit a call to the target language's ``min`` builtin."""
    return "min({}, {})".format(a, b)

def supports_config_key(self, key: str) -> bool:
    """Report whether ``key`` is one of this backend's supported config keys."""
    supported = self._SUPPORTED_CONFIG_KEYS
    return key in supported

def supports_precompile(self) -> bool:
    """Precompilation is not supported by this backend."""
    return False

def autotune(
    self,
    bound_kernel: BoundKernel[Any],
    args: Sequence[object],
    *,
    force: bool = True,
    **kwargs: object,
) -> Config:
    """Skip autotuning and return the kernel spec's default configuration.

    ``args``, ``force``, and any extra keyword arguments are accepted
    for interface compatibility with other backends but are ignored.
    """
    spec = bound_kernel.config_spec
    return spec.default_config()

def transform_host_arg(
    self,
    arg: Argument,
    host_str: str,
    tensor_host_args: list[str],
) -> str:
    """Wrap scalar SymbolArguments as 1-element tensors for buffer passing.

    Non-symbol arguments pass through unchanged. Symbol arguments are
    rewritten into a ``torch.scalar_tensor`` expression (always
    float32), placed on the device of the first tensor host arg when
    one exists, otherwise falling back to the literal ``'mps'`` device.
    """
    from .device_function import SymbolArgument

    if not isinstance(arg, SymbolArgument):
        return host_str
    if tensor_host_args:
        device_expr = tensor_host_args[0] + ".device"
    else:
        device_expr = "'mps'"
    return (
        "torch.scalar_tensor(float(" + host_str + "), "
        "dtype=torch.float32, device=" + device_expr + ")"
    )

def launcher_keyword_args(self, config: Config, *, has_barrier: bool) -> list[str]:
    """Build the launcher keyword-argument strings for this backend.

    Emits a single ``_block_dims=(x, y, z)`` entry taken from the
    current DeviceFunction's maximum thread-block dimensions.
    ``config`` and ``has_barrier`` are accepted for interface
    compatibility but unused here.
    """
    from .device_function import DeviceFunction

    dims = tuple(DeviceFunction.current().codegen.max_thread_block_dims)
    return [
        "_block_dims=({}, {}, {})".format(dims[0], dims[1], dims[2])
    ]
Loading