[metal] Reuse Inductor's MetalOverrides for MSL expression emission#1853
[metal] Reuse Inductor's MetalOverrides for MSL expression emission#1853
Conversation
b3e1da8 to
e8fe296
Compare
e8fe296 to
8ef44cd
Compare
8ef44cd to
6bd2521
Compare
6bd2521 to
73bed39
Compare
2c95aa5 to
20ba09a
Compare
20ba09a to
849c5c3
Compare
849c5c3 to
53876fa
Compare
jansel
left a comment
There was a problem hiding this comment.
Can you reuse the version of this in inductor:
https://github.com/pytorch/pytorch/blob/47fa87a257f72a30c2028e1200c1673ed36125b8/torch/_inductor/codegen/mps.py#L185
Will take a look at it. From skimming it seems like it pulls in Inductor dependencies (V.kernel, CSEVariable) and also emits MSL differently. |
Refactored to use Inductor's override for 99%. A couple of methods need to be handled slightly differently because Helion parses override strings as Python AST. While this works completely for Triton/Cute/Pallas because those are valid python syntax. For Metal, since it is C++, some things like |
Subclass Inductor's torch._inductor.codegen.mps.MetalOverrides instead of reimplementing ~170 lines of math/cast/comparison overrides. This brings NaN-safe math (c10::metal::max/min), IEEE-compliant precision (metal::precise::*), and type-safe casts (decltype) for free. The C++ namespace syntax (::) in Inductor's expression output is replaced with . before Python AST parsing, then restored to :: by the MSL walker. Override where() to use Python ternary (parseable by AST), and _special_unary/_special_binary to skip V.kernel.headers dependency. stack-info: PR: #1853, branch: aditvenk/stack/19
Stacked PRs:
[metal] Reuse Inductor's MetalOverrides for MSL expression emission
Subclass Inductor's torch._inductor.codegen.mps.MetalOverrides instead
of reimplementing ~170 lines of math/cast/comparison overrides. This
brings NaN-safe math (c10::metal::max/min), IEEE-compliant precision
(metal::precise::*), and type-safe casts (decltype) for free.
The C++ namespace syntax (::) in Inductor's expression output is
replaced with . before Python AST parsing, then restored to :: by the
MSL walker. Override where() to use Python ternary (parseable by AST),
and _special_unary/_special_binary to skip V.kernel.headers dependency.