diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 0e3c0e4..2f87bce 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -24,8 +24,8 @@ from mlir.dialects.transform import vector from mlir.dialects.transform import tensor -from lighthouse.workload import benchmark -from lighthouse.workload import get_bench_wrapper_schedule +from lighthouse import dialects as lh_dialects +from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.numpy import numpy_to_mlir_type from lighthouse.pipeline.helper import apply_registered_pass import lighthouse.utils as lh_utils @@ -358,6 +358,8 @@ def parse_cli(): args = parse_cli() with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + match args.dtype: case "f32": in_dtype = np.float32 diff --git a/examples/feed-forward-mpi/feed-forward-mpi.py b/examples/feed-forward-mpi/feed-forward-mpi.py index ec52c48..4a32315 100644 --- a/examples/feed-forward-mpi/feed-forward-mpi.py +++ b/examples/feed-forward-mpi/feed-forward-mpi.py @@ -26,6 +26,7 @@ as_ctype, ) +from lighthouse import dialects as lh_dialects from lighthouse.utils.memref import ( to_ctype as memref_to_ctype, deallocate_memrefs_on_exit, @@ -412,6 +413,8 @@ def schedule_modules( R = MPI.COMM_WORLD.Get_rank() with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = DistFF(args, P, R) # execute(wload, verbose=args.verbose) diff --git a/examples/workload/example.py b/examples/workload/example.py index d9eb7f2..c540cc9 100644 --- a/examples/workload/example.py +++ b/examples/workload/example.py @@ -18,6 +18,7 @@ from mlir.dialects import transform from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.pipeline.helper import match from lighthouse.pipeline.opt import PassBundles, apply_bundle @@ -159,6 +160,8 @@ def schedule_modules( if __name__ == "__main__": with ir.Context(), 
ir.Location.unknown(): + lh_dialects.register_and_load() + wload = ElementwiseSum(400, 400) print(" Dump kernel ".center(60, "-")) diff --git a/examples/workload/example_mlir.py b/examples/workload/example_mlir.py index 376993e..2126d34 100644 --- a/examples/workload/example_mlir.py +++ b/examples/workload/example_mlir.py @@ -21,6 +21,7 @@ from mlir.dialects import func, linalg, arith, memref from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import execute, benchmark import lighthouse.utils as lh_utils @@ -195,6 +196,8 @@ def payload_module(self): if __name__ == "__main__": with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = ElementwiseSumMLIRAlloc(400, 400) print(" Dump kernel ".center(60, "-")) diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index 72bb5c6..a5395b1 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -19,6 +19,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -360,6 +361,8 @@ def parse_cli(): c_type = "f32" with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUMatMul( M=M, N=N, diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index 59d26cc..d8f4caa 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -24,6 +24,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -375,6 +376,8 @@ def parse_cli(): identity_weights = args.check_result with 
ir.Context(), ir.Location.unknown():
+        lh_dialects.register_and_load()
+
         wload = XeGPUMLP(
             batch_size=args.batch_size,
             input_size=args.input_size,
diff --git a/lighthouse/dialects/__init__.py b/lighthouse/dialects/__init__.py
new file mode 100644
index 0000000..ea42a53
--- /dev/null
+++ b/lighthouse/dialects/__init__.py
@@ -0,0 +1,5 @@
+def register_and_load():
+    """Load Lighthouse's dialect extensions (currently only `transform_ext`)."""
+    from . import transform_ext
+
+    transform_ext.register_and_load()
diff --git a/lighthouse/dialects/transform_ext.py b/lighthouse/dialects/transform_ext.py
new file mode 100644
index 0000000..f9b358e
--- /dev/null
+++ b/lighthouse/dialects/transform_ext.py
@@ -0,0 +1,176 @@
+from mlir import ir
+from mlir.dialects import ext, transform, func, arith, scf, memref
+from mlir.dialects.transform import DiagnosedSilenceableFailure
+
+from lighthouse.utils.mlir import func_cif
+
+
+def register_and_load(context=None):
+    """Load the `transform_ext` dialect together with its operations.
+
+    `context` is accepted but not used yet: `load()` is called without
+    arguments.
+    """
+    TransformExtensionDialect.load()
+
+
+class TransformExtensionDialect(ext.Dialect, name="transform_ext"):
+    """The `transform_ext` dialect holding Lighthouse's transform ops."""
+
+    @classmethod
+    def load(cls, *args, **kwargs):
+        """Load the dialect, then attach the interface models of its ops.
+
+        All arguments are forwarded to the parent class' `load`. Afterwards
+        `attach_interface_impls()` is called on every operation class in
+        `cls.operations` that defines it.
+        """
+        super().load(*args, **kwargs)
+        for op_cls in cls.operations:
+            if hasattr(op_cls, "attach_interface_impls"):
+                op_cls.attach_interface_impls()
+
+
+class WrapInBenchingFuncOp(
+    TransformExtensionDialect.Operation, name="wrap_in_benching_func"
+):
+    """Create a function that calls `target` function in a benchmarking loop.
+
+    The new function has the same arguments as `target` plus three additional ones:
+    - A memref to store the timing results (one element per iteration).
+    - The number of timed iterations.
+    - The number of warmup iterations.
+
+    An optional `bench_name` string attribute names the new function; it is
+    only accepted when `target` maps to exactly one payload op. Without it,
+    every wrapper is named `bench_<name of the wrapped function>`.
+    """
+
+    target: ext.Operand[transform.AnyOpType]
+    bench_func: ext.Result[transform.AnyOpType[()]]
+
+    @classmethod
+    def attach_interface_impls(cls, context=None):
+        """Attach the transform and memory-effects interface models."""
+        cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)
+        cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)
+
+    @staticmethod
+    def wrap_in_benching_func(target: func.FuncOp, bench_name: str):
+        """Create a function that calls `target` in a benchmarking loop.
+
+        Each call to `target` is timed separately, and the times (in seconds)
+        are stored in a memref that is passed as an additional argument to the
+        benchmark function. It also takes two additional arguments for the
+        number of runs and warmup iterations.
+
+        Ops are created at the current insertion point. The private `rtclock`
+        declaration is only emitted when the symbol table enclosing `target`
+        has no `rtclock` yet, so wrapping several functions of one module
+        does not redefine the symbol.
+
+        Returns the `func_op` of the newly built benchmark function.
+        """
+
+        f64_t = ir.F64Type.get()
+        # Declare `rtclock` unless the enclosing symbol table already has it.
+        # The table is built anew on each call, so a declaration emitted while
+        # wrapping an earlier target is found here.
+        parent = target.operation.parent
+        if parent is None or "rtclock" not in ir.SymbolTable(parent):
+            func.FuncOp("rtclock", ((), (f64_t,)), visibility="private")
+        # emit benchmark function
+        time_memref_t = ir.MemRefType.get((ir.ShapedType.get_dynamic_size(),), f64_t)
+        index_t = ir.IndexType.get()
+        args = target.type.inputs + [time_memref_t, index_t, index_t]
+
+        @func_cif(*args, name=bench_name)
+        def bench(*args):
+            zero = arith.constant(index_t, 0)
+            one = arith.constant(index_t, 1)
+            func_args = list(args[: len(target.type.inputs)])
+            times_memref, num_times, num_warmup = args[-3:]
+            # Warmup calls are not timed.
+            for _ in scf.for_(zero, num_warmup, one):
+                # FIXME(upstream): func.call needs to wrap _overridden_ CallOp.
+                func.CallOp(target, func_args)
+                scf.yield_(())
+            # TODO: get `num_times` from the `times_memref`.
+            for i in scf.for_(zero, num_times, one):
+                tic = func.call((f64_t,), "rtclock", ())
+                func.CallOp(target, func_args)
+                toc = func.call((f64_t,), "rtclock", ())
+                elapsed = arith.subf(toc, tic)
+                memref.store(elapsed, times_memref, [i])
+                scf.yield_(())
+
+        return bench.func_op
+
+    class TransformOpInterfaceModel(transform.TransformOpInterface):
+        """Implements the transform op interface for `WrapInBenchingFuncOp`."""
+
+        @staticmethod
+        def apply(
+            op: "WrapInBenchingFuncOp",
+            _rewriter: transform.TransformRewriter,
+            results: transform.TransformResults,
+            state: transform.TransformState,
+        ) -> DiagnosedSilenceableFailure:
+            """Create one benchmark function per payload op of `op.target`.
+
+            Returns a silenceable failure, before creating any op, when
+            `bench_name` is set but `target` does not map to exactly one
+            payload op, or when a payload op is not a `func.FuncOp`.
+            """
+            targets = list(state.get_payload_ops(op.target))
+            if bench_name_attr := op.attributes.get("bench_name"):
+                bench_name = bench_name_attr.value
+                if len(targets) != 1:
+                    return DiagnosedSilenceableFailure.SilenceableFailure
+            else:
+                bench_name = None
+
+            # Check all targets first: failing halfway through the loop below
+            # would leave wrappers for the earlier targets in the payload.
+            if not all(isinstance(target, func.FuncOp) for target in targets):
+                return DiagnosedSilenceableFailure.SilenceableFailure
+
+            bench_funcs = []
+            for target in targets:
+                with ir.InsertionPoint(target), target.location:
+                    bench_func = WrapInBenchingFuncOp.wrap_in_benching_func(
+                        target, bench_name or f"bench_{target.name.value}"
+                    )
+                bench_funcs.append(bench_func)
+
+            results.set_ops(op.bench_func, bench_funcs)
+
+            return DiagnosedSilenceableFailure.Success
+
+        @staticmethod
+        def allow_repeated_handle_operands(_op: "WrapInBenchingFuncOp") -> bool:
+            return False
+
+    class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
+        """Implements the memory-effects interface for `WrapInBenchingFuncOp`."""
+
+        @staticmethod
+        def get_effects(op: "WrapInBenchingFuncOp", effects):
+            """Mark operands as read-only handles, results as produced, payload as modified."""
+            transform.only_reads_handle(op.op_operands, effects)
+            transform.produces_handle(op.results, effects)
+            transform.modifies_payload(effects)
+
+
+def wrap_in_benching_func(
+    target: ir.Value[transform.AnyOpType], bench_name: str | None = None
+) -> ir.Value[transform.AnyOpType]:
+    """snake_case wrapper to create a WrapInBenchingFuncOp.
+
+    When `bench_name` is given it is stored as the op's `bench_name` string
+    attribute. Returns the op's `bench_func` result.
+    """
+    op = WrapInBenchingFuncOp(target=target)
+    if bench_name is not None:
+        op.attributes["bench_name"] = ir.StringAttr.get(bench_name)
+    return op.bench_func
diff
--git a/lighthouse/schedule/pattern_schedule.py b/lighthouse/schedule/pattern_schedule.py deleted file mode 100644 index 6216211..0000000 --- a/lighthouse/schedule/pattern_schedule.py +++ /dev/null @@ -1,54 +0,0 @@ -from mlir import rewrite, ir -from mlir.dialects import ext, transform -from lighthouse.schedule import schedule_boilerplate - - -@ext.register_dialect -class PatternDialect(ext.Dialect, name="lighthouse"): - pass - - -def rewrite_pattern(patterns: dict, pname: str): - """Return a rewrite pattern class that can be registered with MLIR. - The patterns dict should map op names to their corresponding match and rewrite functions.""" - - @ext.register_operation(PatternDialect, replace=True) - class RewritePattern(PatternDialect.Operation, name=pname): - @classmethod - def attach_interface_impls(cls, ctx=None): - cls.PatternDescriptorOpInterfaceFallbackModel.attach( - cls.OPERATION_NAME, context=ctx - ) - - class PatternDescriptorOpInterfaceFallbackModel( - transform.PatternDescriptorOpInterface - ): - @staticmethod - def populate_patterns( - op: "RewritePattern", - patternset: rewrite.RewritePatternSet, - ) -> None: - for op_name, match_and_rewrite in patterns.items(): - patternset.add(op_name, match_and_rewrite, benefit=1) - - return RewritePattern - - -def pattern_rewrite_schedule(patterns: dict, pname: str = "rewrite_pattern"): - """Return a transform module that applies the given rewrite patterns. - patterns: dict mapping op names to match-and-rewrite functions. 
-    pname: name for the generated rewrite pattern operation."""
-
-    rw_pattern = rewrite_pattern(patterns, pname)
-    PatternDialect.load(register=False, reload=False)
-    rw_pattern.attach_interface_impls()
-
-    with schedule_boilerplate() as (schedule, named_seq):
-        apply_patterns_op = transform.ApplyPatternsOp(named_seq.bodyTarget)
-        with ir.InsertionPoint(apply_patterns_op.patterns):
-            rw_pattern()
-        transform.yield_([named_seq.bodyTarget])
-        named_seq.verify()
-
-    schedule.body.operations[0].verify()
-    return schedule
diff --git a/lighthouse/workload/__init__.py b/lighthouse/workload/__init__.py
index 44d8b3e..d543526 100644
--- a/lighthouse/workload/__init__.py
+++ b/lighthouse/workload/__init__.py
@@ -2,13 +2,11 @@
 from .runner import (
     execute,
     benchmark,
-    bench_wrapper_pattern,
     get_bench_wrapper_schedule,
 )
 
 __all__ = [
     "Workload",
-    "bench_wrapper_pattern",
     "benchmark",
     "execute",
     "get_bench_wrapper_schedule",
diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py
index 70bcd3f..7c83af7 100644
--- a/lighthouse/workload/runner.py
+++ b/lighthouse/workload/runner.py
@@ -5,11 +5,14 @@
 import numpy as np
 import os
 from mlir import ir
-from mlir.dialects import func, arith, scf, memref
+from mlir.dialects import transform
+from mlir.dialects.transform import structured
 from mlir.execution_engine import ExecutionEngine
 from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
-from lighthouse.schedule.pattern_schedule import pattern_rewrite_schedule
-from lighthouse.utils.mlir import func_cif, get_mlir_library_path
+
+from lighthouse.dialects import transform_ext
+from lighthouse.schedule import schedule_boilerplate
+from lighthouse.utils.mlir import get_mlir_library_path
 from lighthouse.utils.memref import to_packed_args
 from lighthouse.workload import Workload
 from typing import Optional
@@ -66,70 +69,33 @@ def execute(
         raise ValueError("Benchmark verification failed.")
 
 
-def bench_wrapper_pattern(funcname: str, benchname=None):
-    """Returns a rewrite pattern that matches a function named `funcname` and clones it
-    as a new function with name given by `benchname` (default: "bench_" + funcname).
-    The new function is a benchmark wrapper that calls the payload function and times it.
-    Every function call is timed separately. Returns the times (seconds) in a memref,
-    which is passed as an additional argument to the benchmark function.
-    It also takes two additional arguments for the number of runs and warmup iterations.
-    """
-    marker = "__bench_wrapped__"
-    if benchname is None:
-        benchname = f"bench_{funcname}"
-
-    def match_and_rewrite(op, rewriter):
-        if op.name.value != funcname:
-            return True  # Failed match, return truthy value
-        if marker in op.attributes:
-            return True  # Already wrapped, skip
-        payload_arguments = op.type.inputs
-
-        with rewriter.ip, op.location:
-            # define rtclock function
-            f64_t = ir.F64Type.get()
-            func.FuncOp("rtclock", ((), (f64_t,)), visibility="private")
-            # emit benchmark function
-            time_memref_t = ir.MemRefType.get(
-                (ir.ShapedType.get_dynamic_size(),), f64_t
-            )
-            index_t = ir.IndexType.get()
-            args = payload_arguments + [time_memref_t, index_t, index_t]
-
-            @func_cif(*args, name=benchname)
-            def bench(*args):
-                index_t = ir.IndexType.get()
-                zero = arith.constant(index_t, 0)
-                one = arith.constant(index_t, 1)
-                for i in scf.for_(zero, args[-1], one):
-                    # FIXME(upstream): func.call is broken for this use case?
-                    func.CallOp(op, list(args[: len(payload_arguments)]))
-                    scf.yield_(())
-                for i in scf.for_(zero, args[-2], one):
-                    tic = func.call((f64_t,), "rtclock", ())
-                    func.CallOp(op, list(args[: len(payload_arguments)]))
-                    toc = func.call((f64_t,), "rtclock", ())
-                    time = arith.subf(toc, tic)
-                    memref.store(time, args[-3], [i])
-                    scf.yield_(())
-
-        # Mark original function as wrapped
-        op.attributes[marker] = ir.UnitAttr.get()
-        return None  # Success
-
-    return match_and_rewrite
-
-
 def get_bench_wrapper_schedule(workload: Workload):
-    return pattern_rewrite_schedule(
-        {
-            "func.func": bench_wrapper_pattern(
-                workload.payload_function_name,
-                workload.benchmark_function_name,
-            )
-        },
-        "add_bench_pattern",
-    )
+    """Build a transform schedule that adds a benchmark wrapper for `workload`.
+
+    The schedule matches the `func.func` whose `sym_name` equals
+    `workload.payload_function_name` and applies `wrap_in_benching_func` to
+    it, naming the wrapper `workload.benchmark_function_name`.
+    """
+    with schedule_boilerplate() as (schedule, named_seq):
+        # Select the payload function by its symbol name.
+        sym_name_filter = {
+            "sym_name": ir.StringAttr.get(workload.payload_function_name)
+        }
+        payload_func = structured.structured_match(
+            transform.AnyOpType.get(),
+            target=named_seq.bodyTarget,
+            ops={"func.func"},
+            op_attrs=sym_name_filter,
+        )
+        # Wrap it and yield the handle to the newly created function.
+        wrapper = transform_ext.wrap_in_benching_func(
+            payload_func, bench_name=workload.benchmark_function_name
+        )
+        transform.yield_([wrapper])
+
+    top_level_op = schedule.body.operations[0]
+    top_level_op.verify()
+    return schedule
 
 
 def benchmark(
diff --git a/pyproject.toml b/pyproject.toml
index 060b5d9..944992b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@
 name = "lighthouse"
 dynamic = ["version"]
 requires-python = ">=3.10,<3.13"  # Bounds are due to torch-mlir's packaging
 dependencies = [
-    "mlir-python-bindings==20260315+69780be1d",
+    "mlir-python-bindings==20260316+f46a51538",
     "pyyaml>=6.0",
 ]