Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2743,6 +2743,11 @@ def grid_index_expr(
def force_tile_mask(self) -> bool:
return True

def inductor_op_overrides(self) -> InductorOpOverrides:
from .metal.metal_overrides import MetalOverrides

return MetalOverrides()

def full_expr(
self, shape_dims: list[str], value_expr: str, dtype: torch.dtype
) -> str:
Expand All @@ -2759,7 +2764,8 @@ def zeros_expr(self, shape: str, dtype: str) -> str:
return "0"

def where_expr(self, mask: str, true_val: str, false_val: str) -> str:
return f"({mask} ? {true_val} : {false_val})"
# Must be valid Python for expr_from_string; walker converts to C++ ternary
return f"({true_val} if {mask} else {false_val})"

def minimum_expr(self, a: str, b: str) -> str:
return f"min({a}, {b})"
Expand Down
4 changes: 4 additions & 0 deletions helion/_compiler/inductor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,10 @@ def _default(
result_str = _unpack_opsvalue(
getattr(self.parent_handler, name)(*args, **kwargs)
)
# C++ namespace syntax (::) is not valid Python. Replace with dot
# notation so expr_from_string can parse it as attribute access.
if CompileEnvironment.current().backend_name == "metal" and "::" in result_str:
result_str = result_str.replace("::", ".")
return self._lift(expr_from_string(result_str))

def to_dtype(
Expand Down
38 changes: 38 additions & 0 deletions helion/_compiler/metal/metal_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""MetalOverrides — thin subclass of Inductor's MPS MetalOverrides.

Reuses Inductor's MetalOverrides for all expression generation (math ops,
casts, comparisons, etc.). The C++ namespace syntax (``::`` in expressions
like ``metal::precise::sin(x)``) is handled by replacing ``::`` with ``.``
before Python AST parsing, then converting back in the MSL walker.

Overrides:
- ``_special_unary`` / ``_special_binary``: skip the ``V.kernel.headers``
dependency (Helion includes the c10/metal headers unconditionally).
- ``where``: emit Python ternary (``a if cond else b``) instead of C++
ternary (``cond ? a : b``) so it can be parsed as Python AST.
"""

from __future__ import annotations

from torch._inductor.codegen.mps import MetalOverrides as _InductorMetalOverrides


class MetalOverrides(_InductorMetalOverrides):
"""Helion Metal op overrides.

Inherits all expression generation from Inductor's MetalOverrides.
"""

@staticmethod
def where(a: object, b: object, c: object) -> str:
# Inductor emits C++ ternary (cond ? a : b) which isn't valid Python.
# Use Python ternary instead; the walker converts it to C++ ternary.
return f"({b} if {a} else {c})"

def _special_unary(self, a: object, name: str) -> str:
# Skip V.kernel.headers.add() — Helion includes c10/metal headers
# unconditionally in the MSL preamble.
return f"c10::metal::{name}({a})"

def _special_binary(self, a: object, b: object, name: str) -> str:
return f"c10::metal::{name}({a}, {b})"
Loading