From 302a972ca1c4e69b657ac90a69fe98b88cef7663 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 16 Mar 2026 16:18:15 -0700 Subject: [PATCH 1/5] [transform_ext] Move populate_pattern op to dialects dir Introduces op `transform_ext.populate_pattern TARGET_OP_KIND PAT_NAME PRIORITY` where patterns can be registered on `PopulatePatternOp` via its `name_to_pattern_rewrite` class member. Fixes #80. --- examples/workload/example.py | 3 ++ examples/workload/example_mlir.py | 3 ++ examples/xegpu/matmul.py | 3 ++ examples/xegpu/mlp.py | 3 ++ lighthouse/dialects/__init__.py | 4 ++ lighthouse/dialects/transform_ext.py | 60 +++++++++++++++++++++++++ lighthouse/schedule/pattern_schedule.py | 54 ---------------------- lighthouse/workload/runner.py | 29 +++++++----- pyproject.toml | 2 +- 9 files changed, 96 insertions(+), 65 deletions(-) create mode 100644 lighthouse/dialects/__init__.py create mode 100644 lighthouse/dialects/transform_ext.py delete mode 100644 lighthouse/schedule/pattern_schedule.py diff --git a/examples/workload/example.py b/examples/workload/example.py index d9eb7f2a..c540cc9a 100644 --- a/examples/workload/example.py +++ b/examples/workload/example.py @@ -18,6 +18,7 @@ from mlir.dialects import transform from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.pipeline.helper import match from lighthouse.pipeline.opt import PassBundles, apply_bundle @@ -159,6 +160,8 @@ def schedule_modules( if __name__ == "__main__": with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = ElementwiseSum(400, 400) print(" Dump kernel ".center(60, "-")) diff --git a/examples/workload/example_mlir.py b/examples/workload/example_mlir.py index 376993e2..2126d34b 100644 --- a/examples/workload/example_mlir.py +++ b/examples/workload/example_mlir.py @@ -21,6 +21,7 @@ from mlir.dialects import func, linalg, arith, memref from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import execute, benchmark import lighthouse.utils as lh_utils @@ -195,6 +196,8 @@ def payload_module(self): if __name__ == "__main__": with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = ElementwiseSumMLIRAlloc(400, 400) print(" Dump kernel ".center(60, "-")) diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index 72bb5c6d..a5395b18 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -19,6 +19,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -360,6 +361,8 @@ def parse_cli(): c_type = "f32" with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUMatMul( M=M, N=N, diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index 59d26ccc..d8f4caa6 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -24,6 +24,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -375,6 +376,8 @@ def parse_cli(): identity_weights = args.check_result with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUMLP( batch_size=args.batch_size, input_size=args.input_size, diff --git a/lighthouse/dialects/__init__.py b/lighthouse/dialects/__init__.py new file mode 100644 index 00000000..ea42a530 --- /dev/null +++ b/lighthouse/dialects/__init__.py @@ -0,0 +1,4 @@ +def register_and_load(): + from . import transform_ext + + transform_ext.register_and_load() diff --git a/lighthouse/dialects/transform_ext.py b/lighthouse/dialects/transform_ext.py new file mode 100644 index 00000000..a3e9ca60 --- /dev/null +++ b/lighthouse/dialects/transform_ext.py @@ -0,0 +1,60 @@ +from mlir import rewrite, ir +from mlir.dialects import ext, transform + + +def register_and_load(context=None): + TransformExtensionDialect.load() + + +class TransformExtensionDialect(ext.Dialect, name="transform_ext"): + @classmethod + def load(cls, *args, **kwargs): + super().load(*args, **kwargs) + for op_cls in cls.operations: + if hasattr(op_cls, "attach_interface_impls"): + op_cls.attach_interface_impls() + + +class PopulatePatternOp(TransformExtensionDialect.Operation, name="populate_pattern"): + """An operation to populate a pattern set with a specific pattern. + + To be used in the region of `transform.apply_patterns`.""" + + op_kind: ir.StringAttr + pattern_name: ir.StringAttr + priority: ir.IntegerAttr + + # A mapping from pattern names to their corresponding rewrite functions. + # This should be populated by the users of this operation. In effect serves + # as a registry for rewrite patterns that can be applied by this operation. + name_to_rewrite_pattern = {} + + @classmethod + def attach_interface_impls(cls, context=None): + cls.PatternDescriptorOpInterfaceModel.attach( + cls.OPERATION_NAME, context=context + ) + + class PatternDescriptorOpInterfaceModel(transform.PatternDescriptorOpInterface): + @staticmethod + def populate_patterns( + op: "PopulatePatternOp", + patternset: rewrite.RewritePatternSet, + ) -> None: + patternset.add( + op.op_kind.value, + op.name_to_rewrite_pattern[op.pattern_name.value], + benefit=op.priority.value, + ) + + +def populate_pattern( + op_kind: str, pattern_name: str, priority: int +) -> PopulatePatternOp: + """Camelcase constructor for PopulatePatternOp.""" + priority_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), priority) + return PopulatePatternOp( + op_kind=ir.StringAttr.get(op_kind), + pattern_name=ir.StringAttr.get(pattern_name), + priority=priority_attr, + ) diff --git a/lighthouse/schedule/pattern_schedule.py b/lighthouse/schedule/pattern_schedule.py deleted file mode 100644 index 1ea6709e..00000000 --- a/lighthouse/schedule/pattern_schedule.py +++ /dev/null @@ -1,54 +0,0 @@ -from mlir import rewrite, ir -from mlir.dialects import ext, transform -from lighthouse.schedule.utils import schedule_boilerplate - - -@ext.register_dialect -class PatternDialect(ext.Dialect, name="lighthouse"): - pass - - -def rewrite_pattern(patterns: dict, pname: str): - """Return a rewrite pattern class that can be registered with MLIR. - The patterns dict should map op names to their corresponding match and rewrite functions.""" - - @ext.register_operation(PatternDialect, replace=True) - class RewritePattern(PatternDialect.Operation, name=pname): - @classmethod - def attach_interface_impls(cls, ctx=None): - cls.PatternDescriptorOpInterfaceFallbackModel.attach( - cls.OPERATION_NAME, context=ctx - ) - - class PatternDescriptorOpInterfaceFallbackModel( - transform.PatternDescriptorOpInterface - ): - @staticmethod - def populate_patterns( - op: "RewritePattern", - patternset: rewrite.RewritePatternSet, - ) -> None: - for op_name, match_and_rewrite in patterns.items(): - patternset.add(op_name, match_and_rewrite, benefit=1) - - return RewritePattern - - -def pattern_rewrite_schedule(patterns: dict, pname: str = "rewrite_pattern"): - """Return a transform module that applies the given rewrite patterns. - patterns: dict mapping op names to match-and-rewrite functions. - pname: name for the generated rewrite pattern operation.""" - - rw_pattern = rewrite_pattern(patterns, pname) - PatternDialect.load(register=False, reload=False) - rw_pattern.attach_interface_impls() - - with schedule_boilerplate() as (schedule, named_seq): - apply_patterns_op = transform.ApplyPatternsOp(named_seq.bodyTarget) - with ir.InsertionPoint(apply_patterns_op.patterns): - rw_pattern() - transform.yield_([named_seq.bodyTarget]) - named_seq.verify() - - schedule.body.operations[0].verify() - return schedule diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index 70bcd3f0..1954edb7 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -5,10 +5,11 @@ import numpy as np import os from mlir import ir -from mlir.dialects import func, arith, scf, memref +from mlir.dialects import func, arith, scf, memref, transform from mlir.execution_engine import ExecutionEngine from mlir.runtime.np_to_memref import get_ranked_memref_descriptor -from lighthouse.schedule.pattern_schedule import pattern_rewrite_schedule +from lighthouse.dialects.transform_ext import PopulatePatternOp, populate_pattern +from lighthouse.schedule.utils import schedule_boilerplate from lighthouse.utils.mlir import func_cif, get_mlir_library_path from lighthouse.utils.memref import to_packed_args from lighthouse.workload import Workload @@ -121,16 +122,24 @@ def bench(*args): def get_bench_wrapper_schedule(workload: Workload): - return pattern_rewrite_schedule( - { - "func.func": bench_wrapper_pattern( - workload.payload_function_name, - workload.benchmark_function_name, - ) - }, - "add_bench_pattern", + PopulatePatternOp.name_to_rewrite_pattern["bench_wrapper"] = bench_wrapper_pattern( + workload.payload_function_name, + workload.benchmark_function_name, ) + with schedule_boilerplate() as (schedule, named_seq): + apply_patterns_op = transform.apply_patterns(named_seq.bodyTarget) + with ir.InsertionPoint(apply_patterns_op.patterns): + populate_pattern( + op_kind="func.func", + pattern_name="bench_wrapper", + priority=1, + ) + transform.yield_([named_seq.bodyTarget]) + + schedule.body.operations[0].verify() + return schedule + def benchmark( workload: Workload, diff --git a/pyproject.toml b/pyproject.toml index 7f963e04..9c79ac48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "lighthouse" dynamic = ["version"] requires-python = ">=3.10,<3.13" # Bounds are due to torch-mlir's packaging dependencies = [ - "mlir-python-bindings==20260315+69780be1d" + "mlir-python-bindings==20260316+f46a51538" ] [dependency-groups] From 3fdefb9e0bb07b78380df89c0b61f35721febdde Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 16 Mar 2026 16:31:18 -0700 Subject: [PATCH 2/5] Register op as well for MPI example --- examples/feed-forward-mpi/feed-forward-mpi.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/feed-forward-mpi/feed-forward-mpi.py b/examples/feed-forward-mpi/feed-forward-mpi.py index 1373bf04..66bf078f 100644 --- a/examples/feed-forward-mpi/feed-forward-mpi.py +++ b/examples/feed-forward-mpi/feed-forward-mpi.py @@ -26,6 +26,7 @@ as_ctype, ) +from lighthouse import dialects as lh_dialects from lighthouse.utils.memref import ( to_ctype as memref_to_ctype, deallocate_memrefs_on_exit, @@ -412,6 +413,8 @@ def schedule_modules( R = MPI.COMM_WORLD.Get_rank() with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = DistFF(args, P, R) # execute(wload, verbose=args.verbose) From b18dad11e624571bd7d8304d521b156d3ce5cf68 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 16 Mar 2026 16:34:17 -0700 Subject: [PATCH 3/5] Change arg order --- lighthouse/dialects/transform_ext.py | 4 ++-- lighthouse/workload/runner.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lighthouse/dialects/transform_ext.py b/lighthouse/dialects/transform_ext.py index a3e9ca60..fe730056 100644 --- a/lighthouse/dialects/transform_ext.py +++ b/lighthouse/dialects/transform_ext.py @@ -20,8 +20,8 @@ class PopulatePatternOp(TransformExtensionDialect.Operation, name="populate_patt To be used in the region of `transform.apply_patterns`.""" - op_kind: ir.StringAttr pattern_name: ir.StringAttr + op_kind: ir.StringAttr priority: ir.IntegerAttr # A mapping from pattern names to their corresponding rewrite functions. @@ -49,7 +49,7 @@ def populate_patterns( def populate_pattern( - op_kind: str, pattern_name: str, priority: int + pattern_name: str, op_kind: str, priority: int ) -> PopulatePatternOp: """Camelcase constructor for PopulatePatternOp.""" priority_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), priority) diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index 1954edb7..dec5036a 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -131,8 +131,8 @@ def get_bench_wrapper_schedule(workload: Workload): apply_patterns_op = transform.apply_patterns(named_seq.bodyTarget) with ir.InsertionPoint(apply_patterns_op.patterns): populate_pattern( - op_kind="func.func", pattern_name="bench_wrapper", + op_kind="func.func", priority=1, ) transform.yield_([named_seq.bodyTarget]) From d395035ea5c5dd0538f05f9a21a2ef321587afc7 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Thu, 19 Mar 2026 08:52:43 -0700 Subject: [PATCH 4/5] Get rid of rewrite pattern and just make into a normal transform op --- lighthouse/dialects/transform_ext.py | 145 ++++++++++++++++++++------- lighthouse/workload/__init__.py | 2 - lighthouse/workload/runner.py | 85 +++------------- 3 files changed, 122 insertions(+), 110 deletions(-) diff --git a/lighthouse/dialects/transform_ext.py b/lighthouse/dialects/transform_ext.py index fe730056..f9b358e1 100644 --- a/lighthouse/dialects/transform_ext.py +++ b/lighthouse/dialects/transform_ext.py @@ -1,5 +1,8 @@ -from mlir import rewrite, ir -from mlir.dialects import ext, transform +from mlir import ir +from mlir.dialects import ext, transform, func, arith, scf, memref +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.utils.mlir import func_cif def register_and_load(context=None): @@ -15,46 +18,112 @@ def load(cls, *args, **kwargs): op_cls.attach_interface_impls() -class PopulatePatternOp(TransformExtensionDialect.Operation, name="populate_pattern"): - """An operation to populate a pattern set with a specific pattern. - - To be used in the region of `transform.apply_patterns`.""" +class WrapInBenchingFuncOp( + TransformExtensionDialect.Operation, name="wrap_in_benching_func" +): + """Create a function that calls `target` function in a benchmarking loop. - pattern_name: ir.StringAttr - op_kind: ir.StringAttr - priority: ir.IntegerAttr + The new function has the same arguments as `target` plus three additional ones: + - A memref to store the timing results (one element per iteration). + - The number of timed iterations. + - The number of warmup iterations. + """ - # A mapping from pattern names to their corresponding rewrite functions. - # This should be populated by the users of this operation. In effect serves - # as a registry for rewrite patterns that can be applied by this operation. - name_to_rewrite_pattern = {} + target: ext.Operand[transform.AnyOpType] + bench_func: ext.Result[transform.AnyOpType[()]] @classmethod def attach_interface_impls(cls, context=None): - cls.PatternDescriptorOpInterfaceModel.attach( - cls.OPERATION_NAME, context=context - ) + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + + @staticmethod + def wrap_in_benching_func(target: func.FuncOp, bench_name: str): + """Create a function that calls `target` in a benchmarking loop. + + Each call to `target` is timed separately, and the times (in seconds) + are stored in a memref that is passed as an additional argument to the + benchmark function. It also takes two additional arguments for the + number of runs and warmup iterations. + """ + + # define rtclock function + f64_t = ir.F64Type.get() + func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((ir.ShapedType.get_dynamic_size(),), f64_t) + index_t = ir.IndexType.get() + args = target.type.inputs + [time_memref_t, index_t, index_t] - class PatternDescriptorOpInterfaceModel(transform.PatternDescriptorOpInterface): + @func_cif(*args, name=bench_name) + def bench(*args): + zero = arith.constant(index_t, 0) + one = arith.constant(index_t, 1) + func_args = list(args[: len(target.type.inputs)]) + times_memref, num_times, num_warmup = args[-3:] + for i in scf.for_(zero, num_warmup, one): + # FIXME(upstream): func.call needs to wrap _overridden_ CallOp. + func.CallOp(target, func_args) + scf.yield_(()) + # TODO: get `num_times` from the `times_memref`. + for i in scf.for_(zero, num_times, one): + tic = func.call((f64_t,), "rtclock", ()) + func.CallOp(target, func_args) + toc = func.call((f64_t,), "rtclock", ()) + time = arith.subf(toc, tic) + memref.store(time, times_memref, [i]) + scf.yield_(()) + + return bench.func_op + + class TransformOpInterfaceModel(transform.TransformOpInterface): @staticmethod - def populate_patterns( - op: "PopulatePatternOp", - patternset: rewrite.RewritePatternSet, - ) -> None: - patternset.add( - op.op_kind.value, - op.name_to_rewrite_pattern[op.pattern_name.value], - benefit=op.priority.value, - ) - - -def populate_pattern( - pattern_name: str, op_kind: str, priority: int -) -> PopulatePatternOp: - """Camelcase constructor for PopulatePatternOp.""" - priority_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), priority) - return PopulatePatternOp( - op_kind=ir.StringAttr.get(op_kind), - pattern_name=ir.StringAttr.get(pattern_name), - priority=priority_attr, - ) + def apply( + op: "WrapInBenchingFuncOp", + _rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + targets = state.get_payload_ops(op.target) + if bench_name_attr := op.attributes.get("bench_name"): + bench_name = bench_name_attr.value + if len(targets) != 1: + return DiagnosedSilenceableFailure.SilenceableFailure + else: + bench_name = None + + bench_funcs = [] + for target in targets: + if not isinstance(target, func.FuncOp): + return DiagnosedSilenceableFailure.SilenceableFailure + + with ir.InsertionPoint(target), target.location: + bench_func = WrapInBenchingFuncOp.wrap_in_benching_func( + target, bench_name or f"bench_{target.name.value}" + ) + bench_funcs.append(bench_func) + + results.set_ops(op.bench_func, bench_funcs) + + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "WrapInBenchingFuncOp") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: "WrapInBenchingFuncOp", effects): + transform.only_reads_handle(op.op_operands, effects) + transform.produces_handle(op.results, effects) + transform.modifies_payload(effects) + + +def wrap_in_benching_func( + target: ir.Value[transform.AnyOpType], bench_name: str | None = None +) -> ir.Value[transform.AnyOpType]: + """snake_case wrapper to create a WrapInBenchingFuncOp.""" + op = WrapInBenchingFuncOp(target=target) + if bench_name is not None: + op.attributes["bench_name"] = ir.StringAttr.get(bench_name) + return op.bench_func diff --git a/lighthouse/workload/__init__.py b/lighthouse/workload/__init__.py index 44d8b3e2..d5435261 100644 --- a/lighthouse/workload/__init__.py +++ b/lighthouse/workload/__init__.py @@ -2,13 +2,11 @@ from .runner import ( execute, benchmark, - bench_wrapper_pattern, get_bench_wrapper_schedule, ) __all__ = [ "Workload", - "bench_wrapper_pattern", "benchmark", "execute", "get_bench_wrapper_schedule", diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index dec5036a..8798e068 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -5,12 +5,14 @@ import numpy as np import os from mlir import ir -from mlir.dialects import func, arith, scf, memref, transform +from mlir.dialects import transform +from mlir.dialects.transform import structured from mlir.execution_engine import ExecutionEngine from mlir.runtime.np_to_memref import get_ranked_memref_descriptor -from lighthouse.dialects.transform_ext import PopulatePatternOp, populate_pattern + +from lighthouse.dialects import transform_ext from lighthouse.schedule.utils import schedule_boilerplate -from lighthouse.utils.mlir import func_cif, get_mlir_library_path +from lighthouse.utils.mlir import get_mlir_library_path from lighthouse.utils.memref import to_packed_args from lighthouse.workload import Workload from typing import Optional @@ -67,75 +69,18 @@ def execute( raise ValueError("Benchmark verification failed.") -def bench_wrapper_pattern(funcname: str, benchname=None): - """Returns a rewrite pattern that matches a function named `funcname` and clones it - as a new function with name given by `benchname` (default: "bench_" + funcname). - The new function is a benchmark wrapper that calls the payload function and times it. - Every function call is timed separately. Returns the times (seconds) in a memref, - which is passed as an additional argument to the benchmark function. - It also takes two additional arguments for the number of runs and warmup iterations. - """ - marker = "__bench_wrapped__" - if benchname is None: - benchname = f"bench_{funcname}" - - def match_and_rewrite(op, rewriter): - if op.name.value != funcname: - return True # Failed match, return truthy value - if marker in op.attributes: - return True # Already wrapped, skip - payload_arguments = op.type.inputs - - with rewriter.ip, op.location: - # define rtclock function - f64_t = ir.F64Type.get() - func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") - # emit benchmark function - time_memref_t = ir.MemRefType.get( - (ir.ShapedType.get_dynamic_size(),), f64_t - ) - index_t = ir.IndexType.get() - args = payload_arguments + [time_memref_t, index_t, index_t] - - @func_cif(*args, name=benchname) - def bench(*args): - index_t = ir.IndexType.get() - zero = arith.constant(index_t, 0) - one = arith.constant(index_t, 1) - for i in scf.for_(zero, args[-1], one): - # FIXME(upstream): func.call is broken for this use case? - func.CallOp(op, list(args[: len(payload_arguments)])) - scf.yield_(()) - for i in scf.for_(zero, args[-2], one): - tic = func.call((f64_t,), "rtclock", ()) - func.CallOp(op, list(args[: len(payload_arguments)])) - toc = func.call((f64_t,), "rtclock", ()) - time = arith.subf(toc, tic) - memref.store(time, args[-3], [i]) - scf.yield_(()) - - # Mark original function as wrapped - op.attributes[marker] = ir.UnitAttr.get() - return None # Success - - return match_and_rewrite - - def get_bench_wrapper_schedule(workload: Workload): - PopulatePatternOp.name_to_rewrite_pattern["bench_wrapper"] = bench_wrapper_pattern( - workload.payload_function_name, - workload.benchmark_function_name, - ) - with schedule_boilerplate() as (schedule, named_seq): - apply_patterns_op = transform.apply_patterns(named_seq.bodyTarget) - with ir.InsertionPoint(apply_patterns_op.patterns): - populate_pattern( - pattern_name="bench_wrapper", - op_kind="func.func", - priority=1, - ) - transform.yield_([named_seq.bodyTarget]) + named_func = structured.structured_match( + transform.AnyOpType.get(), + target=named_seq.bodyTarget, + ops={"func.func"}, + op_attrs={"sym_name": ir.StringAttr.get(workload.payload_function_name)}, + ) + bench_func = transform_ext.wrap_in_benching_func( + named_func, bench_name=workload.benchmark_function_name + ) + transform.yield_([bench_func]) schedule.body.operations[0].verify() return schedule From 6b36c7006a94964b5482b48ee0ddd3ca978cd184 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Thu, 19 Mar 2026 09:26:54 -0700 Subject: [PATCH 5/5] Fixes --- examples/cpu/x86/matmul.py | 6 ++++-- lighthouse/workload/runner.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 0e3c0e47..2f87bce1 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -24,8 +24,8 @@ from mlir.dialects.transform import vector from mlir.dialects.transform import tensor -from lighthouse.workload import benchmark -from lighthouse.workload import get_bench_wrapper_schedule +from lighthouse import dialects as lh_dialects +from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.numpy import numpy_to_mlir_type from lighthouse.pipeline.helper import apply_registered_pass import lighthouse.utils as lh_utils @@ -358,6 +358,8 @@ def parse_cli(): args = parse_cli() with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + match args.dtype: case "f32": in_dtype = np.float32 diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index 8798e068..7c83af72 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -11,7 +11,7 @@ from mlir.runtime.np_to_memref import get_ranked_memref_descriptor from lighthouse.dialects import transform_ext -from lighthouse.schedule.utils import schedule_boilerplate +from lighthouse.schedule import schedule_boilerplate from lighthouse.utils.mlir import get_mlir_library_path from lighthouse.utils.memref import to_packed_args from lighthouse.workload import Workload