From bdf085f99f456f93874b4ad289c7dc9aee7225ad Mon Sep 17 00:00:00 2001 From: AndreSlavescu Date: Wed, 18 Mar 2026 02:01:11 -0400 Subject: [PATCH 1/3] mlp fused kernel + compiler improvements --- benchmarks/mlp.py | 188 +++++++++++++++++++++++++++++++++ kernels/mlp.py | 53 ++++++++++ metile/codegen/msl_emitter.py | 164 +++++++++++++++++++++++----- metile/compiler/lowering.py | 101 +++++++++++++++--- metile/frontend/autotune.py | 2 +- metile/frontend/kernel.py | 22 ++++ metile/runtime/metal_device.py | 5 + 7 files changed, 492 insertions(+), 43 deletions(-) create mode 100644 benchmarks/mlp.py create mode 100644 kernels/mlp.py diff --git a/benchmarks/mlp.py b/benchmarks/mlp.py new file mode 100644 index 0000000..8fc98a4 --- /dev/null +++ b/benchmarks/mlp.py @@ -0,0 +1,188 @@ +"""Fused MLP benchmark: meTile fused GEMM+activation vs MLX compile.""" + +import sys +import time +from pathlib import Path + +_root = str(Path(__file__).resolve().parent.parent) +sys.path.insert(0, _root) +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +import mlx.core as mx +import numpy as np +from benchutils import bench_interleaved + +import metile +from kernels.gemm import matmul +from kernels.mlp import matmul_gelu, matmul_silu +from metile.runtime.metal_device import MetalDevice + +# --- Autotune configs --- + +GEMM_CONFIGS = [ + metile.Config(BLOCK_M=64, BLOCK_N=64, BLOCK_K=64, WM=2, WN=2, K_UNROLL=1), + metile.Config(BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, WM=2, WN=2, K_UNROLL=1), + metile.Config(BLOCK_M=128, BLOCK_N=64, BLOCK_K=64, WM=2, WN=4, K_UNROLL=1), + metile.Config(BLOCK_M=128, BLOCK_N=64, BLOCK_K=128, WM=2, WN=4, K_UNROLL=1), + metile.Config(BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, WM=4, WN=4, K_UNROLL=1), + metile.Config(BLOCK_M=128, BLOCK_N=128, BLOCK_K=128, WM=4, WN=4, K_UNROLL=1), +] + +autotuned_gelu = metile.autotune(configs=GEMM_CONFIGS, key=["M", "N", "K"], verbose=True)( + matmul_gelu +) +autotuned_silu = metile.autotune(configs=GEMM_CONFIGS, key=["M", "N", "K"], verbose=True)( + matmul_silu +) +autotuned_matmul = metile.autotune(configs=GEMM_CONFIGS, key=["M", "N", "K"], verbose=True)(matmul) + +COOLDOWN = 3.0 + +COL_SIZE = 20 +COL_T = 12 + + +def _print_table(title, rows): + print(f"\n {title}") + hdr = f" {'size':>{COL_SIZE}} {'metile (ms)':>{COL_T}} {'MLX (ms)':>{COL_T}}" + print(hdr) + print(" " + "-" * (len(hdr) - 4)) + for size_str, dt_mtile, dt_mlx in rows: + print(f" {size_str:>{COL_SIZE}} {dt_mtile:>{COL_T}.2f} {dt_mlx:>{COL_T}.2f}") + + +def _gelu_ref(x): + """GELU (sigmoid approx) matching meTile: x / (1 + exp(-1.702 * x)).""" + return x / (1.0 + mx.exp(-1.702 * x)) + + +def _silu_ref(x): + """SiLU: x / (1 + exp(-x)).""" + return x / (1.0 + mx.exp(-x)) + + +def main(): + # Fused GEMM+activation sizes (batch*seq, hidden, intermediate) + fused_sizes = [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + # LLM-typical shapes + (32, 4096, 4096), + (128, 4096, 4096), + (512, 4096, 4096), + (1024, 4096, 4096), + ] + + # Full MLP sizes: (batch*seq, model_dim, ffn_dim) + mlp_sizes = [ + (128, 1024, 4096), + (256, 2048, 8192), + (512, 4096, 4096), + (1024, 4096, 4096), + (32, 4096, 4096), + ] + + if len(sys.argv) > 1 and sys.argv[1] == "--silu": + act = "silu" + else: + act = "gelu" + + autotuned_act = autotuned_gelu if act == "gelu" else autotuned_silu + act_fn = _gelu_ref if act == "gelu" else _silu_ref + + dev = MetalDevice.get() + + # --- Fused GEMM+activation --- + + print(f"=== Fused GEMM+{act.upper()} (autotuned) ===\n") + + rows = [] + for M, N, K in fused_sizes: + A_np = np.random.randn(M, K).astype(np.float32) + B_np = np.random.randn(K, N).astype(np.float32) + A_buf = metile.Buffer(data=A_np.ravel()) + B_buf = metile.Buffer(data=B_np.ravel()) + C_buf = metile.Buffer.zeros((M * N,)) + + def grid_fn(cfg, M=M, N=N): + return (metile.cdiv(M, cfg["BLOCK_M"]), metile.cdiv(N, cfg["BLOCK_N"])) + + dispatch = autotuned_act[grid_fn].prepare(A_buf, B_buf, C_buf, M, N, K) + dev.sync() + + A_mx, B_mx = mx.array(A_np), mx.array(B_np) + + @mx.compile + def mlx_fused(a, b): + return act_fn(a @ b) + + mx.eval(mlx_fused(A_mx, B_mx)) + + def mlx_fn(a=A_mx, b=B_mx): + mx.eval(mlx_fused(a, b)) + + time.sleep(COOLDOWN) + dt_mtile, dt_mlx = bench_interleaved(dispatch, mlx_fn, dev.sync) + rows.append((f"{M}x{N}x{K}", dt_mtile * 1000, dt_mlx * 1000)) + + _print_table(f"matmul_{act} (metile fused vs MLX compile)", rows) + print() + + # --- Full MLP pipeline --- + + print(f"=== Full MLP: {act.upper()}(x @ W1) @ W2 (autotuned) ===\n") + + rows = [] + for M, D, H in mlp_sizes: + X_np = np.random.randn(M, D).astype(np.float32) + W1_np = np.random.randn(D, H).astype(np.float32) + W2_np = np.random.randn(H, D).astype(np.float32) + + X_buf = metile.Buffer(data=X_np.ravel()) + W1_buf = metile.Buffer(data=W1_np.ravel()) + W2_buf = metile.Buffer(data=W2_np.ravel()) + H_buf = metile.Buffer.zeros((M * H,)) + Y_buf = metile.Buffer.zeros((M * D,)) + + def grid_up(cfg, M=M, H=H): + return (metile.cdiv(M, cfg["BLOCK_M"]), metile.cdiv(H, cfg["BLOCK_N"])) + + dispatch_up = autotuned_act[grid_up].prepare(X_buf, W1_buf, H_buf, M, H, D) + dev.sync() + + def grid_down(cfg, M=M, D=D): + return (metile.cdiv(M, cfg["BLOCK_M"]), metile.cdiv(D, cfg["BLOCK_N"])) + + dispatch_down = autotuned_matmul[grid_down].prepare(H_buf, W2_buf, Y_buf, M, D, H) + dev.sync() + + def mtile_mlp(up=dispatch_up, down=dispatch_down): + up() + down() + + X_mx = mx.array(X_np) + W1_mx = mx.array(W1_np) + W2_mx = mx.array(W2_np) + + @mx.compile + def mlx_mlp(x, w1, w2): + return act_fn(x @ w1) @ w2 + + mx.eval(mlx_mlp(X_mx, W1_mx, W2_mx)) + + def mlx_fn(x=X_mx, w1=W1_mx, w2=W2_mx): + mx.eval(mlx_mlp(x, w1, w2)) + + time.sleep(COOLDOWN) + dt_mtile, dt_mlx = bench_interleaved(mtile_mlp, mlx_fn, dev.sync) + rows.append((f"{M}x{D}x{H}", dt_mtile * 1000, dt_mlx * 1000)) + + _print_table(f"MLP {act} (metile fused vs MLX compile)", rows) + print() + + +if __name__ == "__main__": + main() diff --git a/kernels/mlp.py b/kernels/mlp.py new file mode 100644 index 0000000..93b137f --- /dev/null +++ b/kernels/mlp.py @@ -0,0 +1,53 @@ +import metile + + +@metile.kernel +def matmul_gelu( + A, + B, + C, + M, + N, + K, + BLOCK_M: metile.constexpr, + BLOCK_N: metile.constexpr, + BLOCK_K: metile.constexpr, +): + """ + Fused GEMM + GELU epilogue: C = GELU(A @ B) + """ + pid_m = metile.program_id(0) + pid_n = metile.program_id(1) + acc = metile.zeros((BLOCK_M, BLOCK_N), dtype="f32") + for k in metile.tile_range(0, K, BLOCK_K): + a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K)) + b = metile.tile_load(B, k, pid_n * BLOCK_N, N, (BLOCK_K, BLOCK_N)) + acc = metile.dot(a, b, acc) + acc = acc / (1.0 + metile.exp(0.0 - 1.702 * acc)) + metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N)) + + +@metile.kernel +def matmul_silu( + A, + B, + C, + M, + N, + K, + BLOCK_M: metile.constexpr, + BLOCK_N: metile.constexpr, + BLOCK_K: metile.constexpr, +): + """ + Fused GEMM + SiLU epilogue: C = SiLU(A @ B) + """ + pid_m = metile.program_id(0) + pid_n = metile.program_id(1) + acc = metile.zeros((BLOCK_M, BLOCK_N), dtype="f32") + for k in metile.tile_range(0, K, BLOCK_K): + a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K)) + b = metile.tile_load(B, k, pid_n * BLOCK_N, N, (BLOCK_K, BLOCK_N)) + acc = metile.dot(a, b, acc) + acc = acc / (1.0 + metile.exp(0.0 - acc)) + metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N)) diff --git a/metile/codegen/msl_emitter.py b/metile/codegen/msl_emitter.py index 1efe075..6094cb8 100644 --- a/metile/codegen/msl_emitter.py +++ b/metile/codegen/msl_emitter.py @@ -34,6 +34,77 @@ "tanh": "tanh", } +_BINOP_SYMBOLS_EPILOGUE = { + "add": "+", + "sub": "-", + "mul": "*", + "div": "/", +} + + +def _format_float_literal(v: float) -> str: + """Format a float constant for MSL.""" + s = f"{v}f" + if "." not in s and "e" not in s.lower(): + s = f"{v}.0f" + return s + + +def _emit_epilogue_chain(operations: list, elem_expr: str, lines: list, pad: str): + """Emit a chain of element-wise epilogue ops on a single element. + + Handles both simple (relu, unary, scale) and compound (binop with + constants, binop referencing original accumulator) epilogue patterns. + Operates on elem_expr (e.g. "acc[0][0].thread_elements()[0]" or "ct[i]"). + """ + # Check if the chain needs save_orig / binop_orig + has_chain = any(e[0] in ("save_orig", "binop", "binop_orig") for e in operations) + + if has_chain: + # Use temporaries for the chain + lines.append(f"{pad}{{") + lines.append(f"{pad} float _v = {elem_expr};") + has_orig = any(e[0] == "save_orig" for e in operations) + if has_orig: + lines.append(f"{pad} float _orig = _v;") + for epi in operations: + if epi[0] == "save_orig": + continue + elif epi[0] == "relu": + lines.append(f"{pad} _v = max(_v, 0.0f);") + elif epi[0] == "unary": + fn = _UNARY_MSL.get(epi[1], epi[1]) + lines.append(f"{pad} _v = {fn}(_v);") + elif epi[0] == "scale": + lines.append(f"{pad} _v *= _scale;") + elif epi[0] == "binop": + _, op_name, const_side, const_val = epi + sym = _BINOP_SYMBOLS_EPILOGUE.get(op_name, "+") + lit = _format_float_literal(const_val) + if const_side == "lhs": + lines.append(f"{pad} _v = {lit} {sym} _v;") + else: + lines.append(f"{pad} _v = _v {sym} {lit};") + elif epi[0] == "binop_orig": + _, op_name, orig_side = epi + sym = _BINOP_SYMBOLS_EPILOGUE.get(op_name, "+") + if orig_side == "lhs": + lines.append(f"{pad} _v = _orig {sym} _v;") + else: + lines.append(f"{pad} _v = _v {sym} _orig;") + lines.append(f"{pad} {elem_expr} = _v;") + lines.append(f"{pad}}}") + else: + # Simple ops — apply directly (backward compatible) + for epi in operations: + if epi[0] == "relu": + lines.append(f"{pad}{elem_expr} = max({elem_expr}, 0.0f);") + elif epi[0] == "unary": + fn = _UNARY_MSL.get(epi[1], epi[1]) + lines.append(f"{pad}{elem_expr} = {fn}({elem_expr});") + elif epi[0] == "scale": + lines.append(f"{pad}{elem_expr} *= _scale;") + def emit(func: mir.MFunction) -> str: """Generate MSL source code from a Metal IR function.""" @@ -85,9 +156,16 @@ def _emit_tensor_ops_kernel(func: mir.MFunction) -> str: lines.append(params_str) lines.append(") {") + # Check if preemptive mode (needs bounds guards for OOB simdgroups) + _preemptive = any(isinstance(op, mir.MMatmul2dSetup) and not op.cooperative for op in func.ops) + # Emit body by walking ops for op in func.ops: - _emit_gemm_op(op, lines, indent=1, func=func) + if _preemptive and isinstance(op, mir.MCoopTensorStore): + op._needs_bounds_guard = True + if _preemptive and isinstance(op, mir.MCoopTensorEpilogue): + op._needs_bounds_guard = True + _emit_gemm_op(op, lines, indent=1, func=func, _tensor_ops_preemptive=_preemptive) lines.append("}") return "\n".join(lines) @@ -235,7 +313,12 @@ def _uses_op_type(ops: list[mir.MOp], op_type) -> bool: def _emit_gemm_op( - op: mir.MOp, lines: list[str], indent: int, func: mir.MFunction, has_swizzle: bool = False + op: mir.MOp, + lines: list[str], + indent: int, + func: mir.MFunction, + has_swizzle: bool = False, + _tensor_ops_preemptive: bool = False, ): # Skip ops folded to constants by the fold pass if ( @@ -366,7 +449,10 @@ def _emit_gemm_op( _emit_cooperative_load(op, lines, indent, func) elif isinstance(op, mir.MForLoop): - _emit_for_loop(op, lines, indent, func, has_swizzle) + if _tensor_ops_preemptive and op.iv_name in ("k", "k0"): + _emit_for_loop_guarded(op, lines, indent, func) + else: + _emit_for_loop(op, lines, indent, func, has_swizzle) elif isinstance(op, mir.MSimdgroupRoleBlock): sgid_name = _val_name_gemm(op.sgid, func) @@ -659,23 +745,9 @@ def _emit_acc_elem_apply(op, lines, indent, func): ) for mi in range(op.num_8m): for ni in range(op.num_8n): - for epi in op.operations: - if epi[0] == "relu": - for e in (0, 1): - lines.append( - f"{pad}{acc}[{mi}][{ni}].thread_elements()[{e}] = " - f"max({acc}[{mi}][{ni}].thread_elements()[{e}], 0.0f);" - ) - elif epi[0] == "unary": - fn = epi[1] - for e in (0, 1): - lines.append( - f"{pad}{acc}[{mi}][{ni}].thread_elements()[{e}] = " - f"{fn}({acc}[{mi}][{ni}].thread_elements()[{e}]);" - ) - elif epi[0] == "scale": - for e in (0, 1): - lines.append(f"{pad}{acc}[{mi}][{ni}].thread_elements()[{e}] *= _scale;") + for e in (0, 1): + elem = f"{acc}[{mi}][{ni}].thread_elements()[{e}]" + _emit_epilogue_chain(op.operations, elem, lines, pad) def _emit_tensor_view_decl(op, lines, indent, func): @@ -744,6 +816,8 @@ def _emit_matmul2d_setup(op, lines, indent, func): lines.append(f"{pad}const uint sg_col = sgid % {WN}u;") lines.append(f"{pad}const uint tile_row = pid_m * {BM}u + sg_row * {SM}u;") lines.append(f"{pad}const uint tile_col = pid_n * {BN}u + sg_col * {SN}u;") + # Guard: skip OOB simdgroups when M or N < BLOCK_M or BLOCK_N + lines.append(f"{pad}const bool _valid_tile = (tile_row < uint(M)) && (tile_col < uint(N));") lines.append("") desc_bk = min(32, BK) if op.use_separated else BK @@ -827,26 +901,30 @@ def _emit_coop_tensor_epilogue(op, lines, indent): """Emit element-wise epilogue on cooperative_tensor.""" pad = " " * indent ct = op.ct_name + needs_guard = getattr(op, "_needs_bounds_guard", False) + if needs_guard: + lines.append(f"{pad}if (_valid_tile) {{") + indent += 1 + pad = " " * indent lines.append(f"{pad}// Fused epilogue on cooperative_tensor registers") lines.append(f"{pad}#pragma clang loop unroll(full)") lines.append(f"{pad}for (uint16_t i = 0; i < {ct}.get_capacity(); ++i) {{") lines.append(f"{pad} if ({ct}.is_valid_element(i)) {{") - for epi in op.operations: - if epi[0] == "relu": - lines.append(f"{pad} {ct}[i] = max({ct}[i], 0.0f);") - elif epi[0] == "unary": - lines.append(f"{pad} {ct}[i] = {epi[1]}({ct}[i]);") - elif epi[0] == "scale": - lines.append(f"{pad} {ct}[i] *= _scale;") + _emit_epilogue_chain(op.operations, f"{ct}[i]", lines, f"{pad} ") lines.append(f"{pad} }}") lines.append(f"{pad}}}") + if needs_guard: + lines.append(f"{' ' * (indent - 1)}}}") lines.append("") def _emit_coop_tensor_store(op, lines, indent): """Emit cooperative_tensor store to output slice.""" pad = " " * indent - lines.append(f"{pad}{op.ct_name}.store({op.output_slice});") + if getattr(op, "_needs_bounds_guard", False): + lines.append(f"{pad}if (_valid_tile) {op.ct_name}.store({op.output_slice});") + else: + lines.append(f"{pad}{op.ct_name}.store({op.output_slice});") def _emit_persistent_grab( @@ -1088,6 +1166,36 @@ def _emit_specialized_db_k_loop( lines.append(f"{pad}}}") +def _emit_for_loop_guarded(op: mir.MForLoop, lines: list[str], indent: int, func: mir.MFunction): + """Emit a tensor_ops K-loop with _valid_tile bounds guard. + + Barriers remain outside the guard (all threads must reach them), + while loads/compute/inner loops are wrapped in if (_valid_tile). + """ + pad = " " * indent + end = _val_name_gemm(op.end, func) if isinstance(op.end, mir.MValue) else str(op.end) + lines.append( + f"{pad}for (int {op.iv_name} = {op.start}; {op.iv_name} < {end}; {op.iv_name} += {op.step}) {{" + ) + + # Separate barrier ops from compute ops + barriers = [b for b in op.body if isinstance(b, mir.MBarrier)] + compute_ops = [b for b in op.body if not isinstance(b, mir.MBarrier)] + + # Emit barriers first (outside guard — all threads must participate) + for b_op in barriers: + _emit_gemm_op(b_op, lines, indent + 1, func) + + # Emit compute inside guard + if compute_ops: + lines.append(f"{pad} if (_valid_tile) {{") + for body_op in compute_ops: + _emit_gemm_op(body_op, lines, indent + 2, func, _tensor_ops_preemptive=True) + lines.append(f"{pad} }}") + + lines.append(f"{pad}}}") + + def _emit_for_loop( op: mir.MForLoop, lines: list[str], indent: int, func: mir.MFunction, has_swizzle: bool = False ): diff --git a/metile/compiler/lowering.py b/metile/compiler/lowering.py index c4eb188..fa6c3e4 100644 --- a/metile/compiler/lowering.py +++ b/metile/compiler/lowering.py @@ -25,7 +25,15 @@ def lower(func: tir.Function) -> mir.MFunction: dtype = p.type.dtype break if MetalDevice.get().supports_tensor_ops and dtype == "f32": - return _lower_tensor_ops_gemm(func) + # tensor_ops matmul2d requires SM,SN <= 32 for valid descriptor + constexprs = func.constexprs + BM = constexprs.get("BLOCK_M", 128) + BN = constexprs.get("BLOCK_N", 64) + WM = constexprs.get("WM", 2) + WN = constexprs.get("WN", 2) + SM, SN = BM // WM, BN // WN + if SM <= 32 and SN <= 32: + return _lower_tensor_ops_gemm(func) return _lower_gemm(func) ctx = _ElementwiseLoweringContext(func) return ctx.lower() @@ -874,14 +882,19 @@ def _make_coop_load( def _detect_epilogue(ops: list) -> list[tuple]: """Detect element-wise epilogue ops between GEMM dot loop and tile store. - Finds ops that produce TileType results (operating on the accumulator) - between the ForRange containing Dot and the TileStore. Distinguishes - these from offset computations (which produce ScalarType/I32). - - Supported patterns: - - Select with Compare(gt, acc, 0) → ReLU: ("relu",) - - Unary(fn, acc) → element-wise math: ("unary", fn_name) - - BinOp(mul, acc, scalar) → scale: ("scale",) + Traces the chain of element-wise ops applied to the accumulator after + the GEMM loop and before the tile store. Handles arbitrary compositions + of unary, binary-with-constant, and binary-with-original-accumulator ops. + + Returns a list of epilogue tuples: + - ("relu",) — max(val, 0) + - ("unary", fn_name) — fn(val) + - ("scale",) — val *= _scale (non-constant scalar) + - ("binop", op, "lhs"|"rhs", float) — binary op with a constant + "lhs"/"rhs" indicates which side the CONSTANT is on + - ("binop_orig", op, "lhs"|"rhs") — binary op referencing original acc + "lhs"/"rhs" indicates which side the ORIGINAL acc is on + - ("save_orig",) — prepended when binop_orig is used """ for_idx = store_idx = None for i, op in enumerate(ops): @@ -893,30 +906,90 @@ def _detect_epilogue(ops: list) -> list[tuple]: if for_idx is None or store_idx is None or store_idx <= for_idx + 1: return [] + # Find the accumulator value name (last Dot result in the loop body) + acc_name = _find_dot_result_name(ops[for_idx].body) + if acc_name is None: + return [] + epilogue = [] + chain_name = acc_name + needs_orig = False + for op in ops[for_idx + 1 : store_idx]: if not hasattr(op, "result") or op.result is None: continue rt = op.result.type if not isinstance(rt, TileType): continue - # This op produces a TileType result → epilogue op + if isinstance(op, tir.Select): - # ReLU: where(acc > 0, acc, 0) cond_op = op.condition.defining_op if cond_op and isinstance(cond_op, tir.Compare) and cond_op.predicate == "gt": epilogue.append(("relu",)) else: - # Generic clamp/select — treat as relu-like for now epilogue.append(("relu",)) + chain_name = op.result.name + elif isinstance(op, tir.Unary): epilogue.append(("unary", op.op)) - elif isinstance(op, tir.BinOp) and op.op == "mul": - epilogue.append(("scale",)) + chain_name = op.result.name + + elif isinstance(op, tir.BinOp): + lhs_tile = isinstance(op.lhs.type, TileType) + rhs_tile = isinstance(op.rhs.type, TileType) + + if lhs_tile and not rhs_tile: + # chain OP scalar_const + const_val = _extract_constant(op.rhs) + if const_val is not None: + epilogue.append(("binop", op.op, "rhs", const_val)) + else: + epilogue.append(("scale",)) + elif rhs_tile and not lhs_tile: + # scalar_const OP chain + const_val = _extract_constant(op.lhs) + if const_val is not None: + epilogue.append(("binop", op.op, "lhs", const_val)) + else: + return [] # non-constant scalar on lhs, can't fuse + elif lhs_tile and rhs_tile: + # Both TileType: one must be original acc, other is chain + lhs_is_orig = op.lhs.name == acc_name and op.lhs.name != chain_name + rhs_is_orig = op.rhs.name == acc_name and op.rhs.name != chain_name + if lhs_is_orig: + epilogue.append(("binop_orig", op.op, "lhs")) + needs_orig = True + elif rhs_is_orig: + epilogue.append(("binop_orig", op.op, "rhs")) + needs_orig = True + else: + return [] # can't fuse: two non-acc tile operands + else: + return [] + chain_name = op.result.name + + if needs_orig: + epilogue.insert(0, ("save_orig",)) return epilogue +def _find_dot_result_name(body_ops: list) -> str | None: + """Find the name of the last Dot result in a loop body.""" + name = None + for op in body_ops: + if isinstance(op, tir.Dot) and op.result: + name = op.result.name + return name + + +def _extract_constant(val) -> float | None: + """Extract a numeric literal from a Value, or None.""" + if val.defining_op and isinstance(val.defining_op, tir.Constant): + return float(val.defining_op.value) + return None + + def _lower_tensor_ops_gemm(func: tir.Function) -> mir.MFunction: """Lower a GEMM to Metal 4 tensor_ops matmul2d. diff --git a/metile/frontend/autotune.py b/metile/frontend/autotune.py index f5897a4..ca8c98f 100644 --- a/metile/frontend/autotune.py +++ b/metile/frontend/autotune.py @@ -150,7 +150,7 @@ def _bench(self, config, args, kwargs, dev): t0 = time.perf_counter() for _ in range(at.rep): launcher(*args, **merged) - dev.sync() + dev.sync() return (time.perf_counter() - t0) / at.rep diff --git a/metile/frontend/kernel.py b/metile/frontend/kernel.py index 4098b34..25233e9 100644 --- a/metile/frontend/kernel.py +++ b/metile/frontend/kernel.py @@ -23,6 +23,7 @@ vectorize_loads, ) from metile.frontend.tracing import TracingContext, TracingProxy, constexpr +from metile.ir import metal_ir as mir from metile.ir import tile_ir as tir from metile.ir.types import I32, PtrType, ScalarType from metile.runtime.buffer import MtileBuffer @@ -33,6 +34,24 @@ # Scalar buffer cache: (value, format_char) -> metal_buffer _scalar_buffer_cache: dict = {} +_ELEM_SIZES = {"float": 4, "half": 2, "int": 4, "uint": 4} + + +def _validate_threadgroup_memory(metal_ir: mir.MFunction): + """Raise RuntimeError if threadgroup memory exceeds hardware limit.""" + total_bytes = 0 + for op in metal_ir.ops: + if isinstance(op, mir.MThreadgroupAlloc): + total_bytes += op.size * _ELEM_SIZES.get(op.elem_type, 4) + if total_bytes == 0: + return + limit = MetalDevice.get().max_threadgroup_memory + if total_bytes > limit: + raise RuntimeError( + f"Kernel '{metal_ir.name}' requires {total_bytes} bytes threadgroup memory " + f"but device limit is {limit} bytes. Reduce tile sizes." + ) + def _dump(path: str, content: str): """Write debug output to a file, creating directories as needed.""" @@ -411,6 +430,9 @@ def _compile(self, args, constexprs: dict, param_names: list[str]) -> CompiledKe os.path.join(_debug_dir, "metal_ir", f"{metal_ir.name}.post_opt.txt"), ir_text ) + # Validate threadgroup memory fits within hardware limit + _validate_threadgroup_memory(metal_ir) + # Step 4: Generate MSL msl_source = emit(metal_ir) diff --git a/metile/runtime/metal_device.py b/metile/runtime/metal_device.py index b64f8b7..aef4326 100644 --- a/metile/runtime/metal_device.py +++ b/metile/runtime/metal_device.py @@ -326,6 +326,11 @@ def has_metal_compiler(self) -> bool: except (subprocess.TimeoutExpired, FileNotFoundError): return False + @cached_property + def max_threadgroup_memory(self) -> int: + """Max threadgroup memory in bytes (MTLDevice.maxThreadgroupMemoryLength).""" + return _send_uint64(self.device, "maxThreadgroupMemoryLength") + @cached_property def supports_tensor_ops(self) -> bool: """Check if device supports Metal 4 tensor_ops (M5+ and Xcode required).""" From 5863f594249df86d144fe14173a81f792c857d75 Mon Sep 17 00:00:00 2001 From: AndreSlavescu Date: Wed, 18 Mar 2026 02:22:13 -0400 Subject: [PATCH 2/3] remove header --- benchmarks/mlp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmarks/mlp.py b/benchmarks/mlp.py index 8fc98a4..f24bf2d 100644 --- a/benchmarks/mlp.py +++ b/benchmarks/mlp.py @@ -1,5 +1,3 @@ -"""Fused MLP benchmark: meTile fused GEMM+activation vs MLX compile.""" - import sys import time from pathlib import Path From 5427dd12aca0335cff0c5c697a02ccfe878c07f4 Mon Sep 17 00:00:00 2001 From: AndreSlavescu Date: Wed, 18 Mar 2026 02:39:10 -0400 Subject: [PATCH 3/3] non-constant scalar coercion + max / min epilogue emission --- metile/codegen/msl_emitter.py | 26 ++++++++++++++++---------- metile/compiler/lowering.py | 4 +++- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/metile/codegen/msl_emitter.py b/metile/codegen/msl_emitter.py index 6094cb8..c8329a1 100644 --- a/metile/codegen/msl_emitter.py +++ b/metile/codegen/msl_emitter.py @@ -79,19 +79,25 @@ def _emit_epilogue_chain(operations: list, elem_expr: str, lines: list, pad: str lines.append(f"{pad} _v *= _scale;") elif epi[0] == "binop": _, op_name, const_side, const_val = epi - sym = _BINOP_SYMBOLS_EPILOGUE.get(op_name, "+") lit = _format_float_literal(const_val) - if const_side == "lhs": - lines.append(f"{pad} _v = {lit} {sym} _v;") - else: - lines.append(f"{pad} _v = _v {sym} {lit};") + if op_name in _BINOP_SYMBOLS_EPILOGUE: + sym = _BINOP_SYMBOLS_EPILOGUE[op_name] + if const_side == "lhs": + lines.append(f"{pad} _v = {lit} {sym} _v;") + else: + lines.append(f"{pad} _v = _v {sym} {lit};") + elif op_name in ("max", "min"): + lines.append(f"{pad} _v = {op_name}(_v, {lit});") elif epi[0] == "binop_orig": _, op_name, orig_side = epi - sym = _BINOP_SYMBOLS_EPILOGUE.get(op_name, "+") - if orig_side == "lhs": - lines.append(f"{pad} _v = _orig {sym} _v;") - else: - lines.append(f"{pad} _v = _v {sym} _orig;") + if op_name in _BINOP_SYMBOLS_EPILOGUE: + sym = _BINOP_SYMBOLS_EPILOGUE[op_name] + if orig_side == "lhs": + lines.append(f"{pad} _v = _orig {sym} _v;") + else: + lines.append(f"{pad} _v = _v {sym} _orig;") + elif op_name in ("max", "min"): + lines.append(f"{pad} _v = {op_name}(_v, _orig);") lines.append(f"{pad} {elem_expr} = _v;") lines.append(f"{pad}}}") else: diff --git a/metile/compiler/lowering.py b/metile/compiler/lowering.py index fa6c3e4..44cc5b1 100644 --- a/metile/compiler/lowering.py +++ b/metile/compiler/lowering.py @@ -943,8 +943,10 @@ def _detect_epilogue(ops: list) -> list[tuple]: const_val = _extract_constant(op.rhs) if const_val is not None: epilogue.append(("binop", op.op, "rhs", const_val)) - else: + elif op.op == "mul": epilogue.append(("scale",)) + else: + return [] # non-constant scalar for non-mul op, can't fuse elif rhs_tile and not lhs_tile: # scalar_const OP chain const_val = _extract_constant(op.lhs)