diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 2f87bce..840589c 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -27,7 +27,6 @@ from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.numpy import numpy_to_mlir_type -from lighthouse.pipeline.helper import apply_registered_pass import lighthouse.utils as lh_utils from lighthouse import schedule as lh_schedule import lighthouse.schedule.x86 as lh_schedule_x86 @@ -135,7 +134,7 @@ def payload(A, B, C): return mod - def schedule_modules( + def pipeline( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None, @@ -267,24 +266,19 @@ def schedule_modules( return scheds # Lower to LLVM. - with lh_schedule.schedule_boilerplate() as (sched, named_seq): - target = named_seq.bodyTarget - target = apply_registered_pass(target, "convert-linalg-to-loops") - target = apply_registered_pass(target, "fold-memref-alias-ops") - target = apply_registered_pass(target, "expand-strided-metadata") - target = apply_registered_pass(target, "canonicalize") - target = apply_registered_pass(target, "convert-vector-to-scf") - target = apply_registered_pass(target, "lower-affine") - target = apply_registered_pass(target, "convert-scf-to-cf") - target = apply_registered_pass(target, "convert-vector-to-llvm") - target = apply_registered_pass(target, "convert-to-llvm") - target = apply_registered_pass(target, "reconcile-unrealized-casts") - lh_transform.cleanup(target) - - transform.yield_() - scheds.append(sched) - - return scheds + return scheds + [ + "convert-linalg-to-loops", + "fold-memref-alias-ops", + "expand-strided-metadata", + "canonicalize", + "convert-vector-to-scf", + "lower-affine", + "convert-scf-to-cf", + "convert-vector-to-llvm", + "convert-to-llvm", + "reconcile-unrealized-casts", + "CleanupBundle", + ] def parse_cli(): diff --git a/examples/feed-forward-mpi/feed-forward-mpi.py b/examples/feed-forward-mpi/feed-forward-mpi.py index 4a32315..3500e06 100644 --- a/examples/feed-forward-mpi/feed-forward-mpi.py +++ b/examples/feed-forward-mpi/feed-forward-mpi.py @@ -17,8 +17,6 @@ from mlir import ir from mlir.dialects import transform -from mlir.dialects.transform.bufferization import OneShotBufferizeOp -from mlir.dialects.bufferization import LayoutMapOption from mlir.execution_engine import ExecutionEngine from mlir.runtime.np_to_memref import ( ranked_memref_to_numpy, @@ -292,54 +290,18 @@ def find_factors(n): return mod - def get_shard_schedule(self): - with schedule_boilerplate() as (schedule, named_sequence): - func = match(named_sequence.bodyTarget, ops={"func.func"}) - func = apply_registered_pass( - func, - "sharding-propagation", - options={"traversal": "forward-backward"}, - ) - if self.verbose > 0: - transform.PrintOp(target=func) - func = apply_registered_pass(func, "shard-partition") - if self.verbose > 0: - transform.PrintOp(target=func) - func = apply_registered_pass(func, "shard-simplify") - if self.verbose > 0: - transform.PrintOp(target=func) - func = apply_registered_pass(func, "convert-shard-to-mpi") - func = apply_registered_pass(func, "canonicalize") - if self.verbose > 0: - transform.PrintOp(target=func) - func = apply_registered_pass(func, "tosa-to-linalg") - transform.YieldOp() - return schedule + def pipeline( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> list[ir.Module | str]: + """Generate schedules: + - sharding propagation, partition, and MPI + - tosa-to-linalg + - adding benchmark wrapper + - tile_and_vector + - all the rest""" - def get_bufferize_schedule(self): + # Create passes to inject deallocations. Don't do this for dealloc_2d, though. with schedule_boilerplate() as (schedule, named_sequence): - anytype = transform.AnyOpType.get() - func = match(named_sequence.bodyTarget, ops={"func.func"}) - mod = transform.get_parent_op( - anytype, func, op_name="builtin.module", deduplicate=True - ) - mod = apply_registered_pass(mod, "linalg-generalize-named-ops") - mod = apply_registered_pass(mod, "linalg-fuse-elementwise-ops") - identity_layout = LayoutMapOption.IdentityLayoutMap - mod = apply_registered_pass(mod, "eliminate-empty-tensors") - mod = OneShotBufferizeOp( - mod, - allow_return_allocs_from_loops=False, - bufferize_function_boundaries=True, - function_boundary_type_conversion=identity_layout, - ) - mod = apply_registered_pass( - mod, - "drop-equivalent-buffer-results", - options={"modify-public-functions": True}, - ) - - # Run passes to inject deallocations. Don't do this for dealloc_2d, though. for fname in [ self.benchmark_function_name, self.payload_function_name, @@ -351,56 +313,41 @@ def get_bufferize_schedule(self): "alloc_wout", ]: func = match( - mod, + named_sequence.bodyTarget, ops={"func.func"}, op_attrs={"sym_name": ir.StringAttr.get(fname)}, ) func = apply_registered_pass(func, "buffer-deallocation-pipeline") - mod = transform.get_parent_op( - anytype, func, op_name="builtin.module", deduplicate=True - ) - transform.YieldOp() - return schedule - - def get_lower_schedule(self): - with schedule_boilerplate() as (schedule, named_sequence): - anytype = transform.AnyOpType.get() - func = match(named_sequence.bodyTarget, ops={"func.func"}) - mod = transform.get_parent_op( - anytype, func, op_name="builtin.module", deduplicate=True - ) - mod = apply_registered_pass(mod, "convert-linalg-to-parallel-loops") - mod = apply_registered_pass(mod, "scf-parallel-loop-fusion") - mod = apply_registered_pass(mod, "canonicalize") - mod = apply_registered_pass(mod, "expand-strided-metadata") - mod = apply_registered_pass(mod, "lower-affine") - mod = apply_registered_pass(mod, "convert-vector-to-scf") - mod = apply_registered_pass(mod, "convert-scf-to-cf") - mod = apply_registered_pass(mod, "symbol-dce") - mod = apply_registered_pass(mod, "convert-vector-to-llvm") - mod = apply_registered_pass(mod, "canonicalize") - mod = apply_registered_pass(mod, "convert-to-llvm") - mod = apply_registered_pass(mod, "reconcile-unrealized-casts") - mod = apply_registered_pass(mod, "cse") - if self.verbose > 1: - transform.PrintOp(target=mod) + transform.PrintOp(target=func) transform.YieldOp() - return schedule - def schedule_modules( - self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None - ) -> list[ir.Module]: - """Generate schedules: - - sharding propagation, partition, and MPI, tosa-to-linalg - - adding benchmark wrapper - - tile_and_vector - - all the rest""" return [ - self.get_shard_schedule(), + "func.func(sharding-propagation{traversal=forward-backward})", + "func.func(shard-partition)", + "func.func(shard-simplify)", + "func.func(convert-shard-to-mpi)", + "canonicalize", + "func.func(tosa-to-linalg)", get_bench_wrapper_schedule(self), tile_and_vector_matmul.create(self.tile_size), - self.get_bufferize_schedule(), - self.get_lower_schedule(), + "linalg-generalize-named-ops", + "eliminate-empty-tensors", + "one-shot-bufferize{bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}", + "drop-equivalent-buffer-results{modify-public-functions=1}", + schedule, + "convert-linalg-to-parallel-loops", + "scf-parallel-loop-fusion", + "canonicalize", + "expand-strided-metadata", + "lower-affine", + "convert-vector-to-scf", + "convert-scf-to-cf", + "symbol-dce", + "convert-vector-to-llvm", + "canonicalize", + "convert-to-llvm", + "reconcile-unrealized-casts", + "cse", ] diff --git a/examples/workload/example.py b/examples/workload/example.py index c540cc9..d7884be 100644 --- a/examples/workload/example.py +++ b/examples/workload/example.py @@ -15,12 +15,9 @@ from mlir import ir from mlir.runtime.np_to_memref import get_ranked_memref_descriptor from mlir.dialects import func, linalg, bufferization -from mlir.dialects import transform from mlir.execution_engine import ExecutionEngine from lighthouse import dialects as lh_dialects -from lighthouse.pipeline.helper import match -from lighthouse.pipeline.opt import PassBundles, apply_bundle from lighthouse.workload import Workload, execute, benchmark, get_bench_wrapper_schedule @@ -121,41 +118,18 @@ def payload(A, B, C): return mod - def schedule_modules( + def pipeline( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None - ) -> ir.Module: - schedule_module = ir.Module.create() - schedule_module.operation.attributes["transform.with_named_sequence"] = ( - ir.UnitAttr.get() - ) - with ir.InsertionPoint(schedule_module.body): - named_sequence = transform.named_sequence( - "__transform_main", - [transform.AnyOpType.get()], - [], - arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], - ) - with ir.InsertionPoint(named_sequence.body): - anytype = transform.AnyOpType.get() - func = match(named_sequence.bodyTarget, ops={"func.func"}) - mod = transform.get_parent_op( - anytype, - func, - op_name="builtin.module", - deduplicate=True, - ) - mod = apply_bundle(mod, PassBundles["BufferizationBundle"]) - mod = apply_bundle(mod, PassBundles["MLIRLoweringBundle"]) - mod = apply_bundle(mod, PassBundles["CleanupBundle"]) - - if stop_at_stage == "bufferized": - transform.YieldOp() - return [schedule_module] - - mod = apply_bundle(mod, PassBundles["LLVMLoweringBundle"]) - transform.YieldOp() - - return [get_bench_wrapper_schedule(self), schedule_module] + ) -> list[str]: + pline = [ + get_bench_wrapper_schedule(self), + "BufferizationBundle", + "MLIRLoweringBundle", + "CleanupBundle", + ] + if stop_at_stage == "bufferized": + return pline + return pline + ["LLVMLoweringBundle"] if __name__ == "__main__": diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index a5395b1..1b363c9 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -194,7 +194,7 @@ def payload_module(self) -> ir.Module: ) return mod - def schedule_modules( + def pipeline( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> list[ir.Module]: return [ diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index d8f4caa..a879750 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -256,7 +256,7 @@ def payload_module(self) -> ir.Module: ) return mod - def schedule_modules( + def pipeline( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> list[ir.Module]: return [ diff --git a/lighthouse/pipeline/opt.py b/lighthouse/pipeline/opt.py index 127d754..c75b71f 100644 --- a/lighthouse/pipeline/opt.py +++ b/lighthouse/pipeline/opt.py @@ -1,6 +1,5 @@ from abc import abstractmethod import importlib -from enum import Enum from pathlib import Path import os import re @@ -101,7 +100,7 @@ def __str__(self) -> str: }, ), Pass("drop-equivalent-buffer-results"), - # This last pass only works with the pass manager, not schedules. + # This last pass fails if the payload contains explicit deallocs, which is the case for some of our examples. # Pass("buffer-deallocation-pipeline"), ], # Lowers most dialects to basic control flow. @@ -144,40 +143,62 @@ class Transform: to be easily passed to a TransformStage. Arguments: - * filename: the file that will be imported into a schedule (mlir or python) + * source: either a filename (str) for a file that will be imported into + a schedule (mlir or python), or a ready ir.Module. - In the filename, the arguments ([...]) will define: + When source is a filename, the arguments ([...]) will define: * gen: function name in case of a python file, what name to look for to get the MLIR module - * seq: the named sequence to look for. FIXME: This is not implemented yet. - Empty will pick the first. - In the filename, the options ({...}) will be stored as a dict + When source is a filename, the options ({...}) will be stored as a dict and can be passed to the gen function """ - class Type(Enum): - MLIR = 1 - Python = 2 - - def __init__(self, filename: str): + def __init__(self, source: str | ir.Module): + self._module = None + if isinstance(source, ir.Module): + self._module = source + self._filename = None + self._options = None + return # First, eliminate arguments and options - filename, args, self.options = parse_args_and_opts(filename) - if filename.endswith(".mlir"): - self.type = Transform.Type.MLIR - elif filename.endswith(".py"): - self.type = Transform.Type.Python + self._filename, args, self._options = parse_args_and_opts(source) + if not self._filename.endswith(".mlir") and not self._filename.endswith(".py"): + raise ValueError(f"Unsupported transform file type: {self._filename}") + self._generator = args.get("gen", "create_schedule") + + def module(self, context: ir.Context) -> ir.Module: + """Create and return the MLIR module for this transform.""" + if self._module is not None: + if self._module.context != context: + raise ValueError("Module context does not match the provided context.") + return self._module + if self._filename.endswith(".mlir"): + return import_mlir_module(self._filename, context) + elif self._filename.endswith(".py"): + module_name = Path(os.path.basename(self._filename)).stem + spec = importlib.util.spec_from_file_location(module_name, self._filename) + py_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(py_module) + if not hasattr(py_module, self._generator): + raise ValueError( + f"Transform module '{self._filename}' does not define a '{self._generator}' generator function." + ) + gen = getattr(py_module, self._generator) + with context, ir.Location.unknown(): + return gen(self._options) else: - raise ValueError(f"Unsupported transform file type: {filename}") - self.filename = filename - self.generator = args["gen"] if "gen" in args else "create_schedule" - self.sequence = args["seq"] if "seq" in args else "" + raise ValueError(f"Unsupported transform type: {self._filename}") def __str__(self) -> str: """serialize name + filename for debugging purposes""" - if not self.options: - return self.name - return f"{self.filename}{{{self.options}}}" + if self._filename is None: + return ( + str(self._module.body.operations[0].name) if self._module else "" + ) + if not self._options: + return self._filename + return f"{self._filename}{{{self._options}}}" class Stage: @@ -230,35 +251,12 @@ class TransformStage(Stage): MLIR_ATTRIBUTE = "transform.with_named_sequence" def __init__(self, transform: Transform, context: ir.Context): - if transform.type == Transform.Type.MLIR: - # For MLIR transforms, we expect the file to define an MLIR transform sequence - # that we can import and apply to the module. This will be checked below. - self.module = import_mlir_module(transform.filename, context) - elif transform.type == Transform.Type.Python: - # For Python transforms, we expect the file to define a function - # that takes an MLIR module and returns a transformed MLIR module. - module_name = Path(os.path.basename(transform.filename)).stem - spec = importlib.util.spec_from_file_location( - module_name, transform.filename - ) - transform_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(transform_module) - if not hasattr(transform_module, transform.generator): - raise ValueError( - f"Transform module '{transform.filename}' does not define a '{transform.generator}' generator function." - ) - self.generator = getattr(transform_module, transform.generator) - - # Run the function with the dictionary as the options that will create the named sequence. - with context, ir.Location.unknown(): - self.module = self.generator(transform.options) - else: - raise ValueError(f"Unsupported transform type: {transform.type}") + self.module = transform.module(context) # Check if the imported module contains at least one schedule if TransformStage.MLIR_ATTRIBUTE not in self.module.operation.attributes: raise ValueError( - f"Transform module {transform.filename} does not define a {TransformStage.MLIR_ATTRIBUTE} attribute." + f"Transform module {transform} does not define a {TransformStage.MLIR_ATTRIBUTE} attribute." ) # Assume the first (or only) sequence. @@ -284,57 +282,68 @@ class Driver: Calling reset() will clear the pipeline and the module, allowing for a new pipeline to be constructed and run on a new module. """ - def __init__(self, filename: str, stages: list[str] = []): + def __init__(self, payload: str | ir.Module = None, stages: list[str | Stage] = []): # The context is shared across the entire pipeline, and is used to create the PassManager and Transform Schedules. # The module is owned by the Driver to encapsulate its use through the pipeline. # It is returned at the end of run() after being transformed by the stages in the pipeline. - self.context = ir.Context() + self.context = ( + payload.context if isinstance(payload, ir.Module) else ir.Context() + ) self.module = None - if filename: - self.import_payload(filename) + if payload is not None: + self.import_payload(payload) self.pipeline: list[Stage] = [] self.pipeline_fixed = False self.bundles = PassBundles if stages: self.add_stages(stages) - def import_payload(self, path: str) -> None: - """Import the payload module and set it as the current module in the pipeline.""" + def import_payload(self, payload: str | ir.Module) -> None: + """Import the payload and set it as the current module in the pipeline. Accepts a file path or a ready ir.Module.""" if self.module is not None: raise ValueError("Module already imported. Reset to start again.") - self.module = import_mlir_module(path, self.context) + if isinstance(payload, ir.Module): + if payload.context != self.context: + raise ValueError("Payload context does not match the Driver context.") + self.module = payload + else: + self.module = import_mlir_module(payload, self.context) - def add_stage(self, stage_name: str) -> None: + def add_stage(self, stage: str | Stage) -> None: + """Add a stage to the pipeline. Accepts a ready Stage object or a string (pass name, bundle name, or file path).""" if self.pipeline_fixed: raise ValueError("Pipeline is fixed. Reset to start again.") + if isinstance(stage, Stage): + self.pipeline.append(stage) + return + # Stages can contain arguments and options, clean up for os checks - filename = remove_args_and_opts(stage_name) + filename = remove_args_and_opts(stage) - if stage_name in self.bundles: + if stage in self.bundles: # Pass Bundle - self.pipeline.append(PassStage(self.bundles[stage_name], self.context)) + self.pipeline.append(PassStage(self.bundles[stage], self.context)) elif os.path.exists(filename): # Transform or YAML if filename.endswith(".mlir") or filename.endswith(".py"): - self.pipeline.append( - TransformStage(Transform(stage_name), self.context) - ) + self.pipeline.append(TransformStage(Transform(stage), self.context)) elif filename.endswith(".yaml"): - desc = PipelineDescriptor(stage_name) + desc = PipelineDescriptor(stage) for s in desc.get_stages(): self.add_stage(s) else: _, ext = os.path.splitext(filename) - raise ValueError(f"Unknown file type '{ext}' for stage '{stage_name}'.") + raise ValueError(f"Unknown file type '{ext}' for stage '{stage}'.") else: # Assume random strings represent a single pass # Will crash later if the pass name is not registered. - self.pipeline.append(PassStage([Pass(stage_name)], self.context)) + self.pipeline.append(PassStage([Pass(stage)], self.context)) - def add_stages(self, stages: list[str]) -> None: + def add_stages(self, stages: list[str | Stage]) -> None: + """Add multiple stages to the pipeline. Each element can be a ready Stage object or a string.""" for s in stages: self.add_stage(s) diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index 7c83af7..8e455eb 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -94,13 +94,8 @@ def benchmark( check_correctness: bool = True, verbose: int = 0, ) -> np.ndarray: - # get original payload module - payload_module = workload.payload_module() - - # Lower payload with one or more schedules - schedule_modules = workload.schedule_modules(parameters=schedule_parameters) - for schedule_module in schedule_modules: - schedule_module.body.operations[0].apply(payload_module) + # Lower payload with pipeline + payload_module = workload.lower_payload(schedule_parameters=schedule_parameters) # get execution engine, rtclock requires mlir_c_runner libs = workload.shared_libs() diff --git a/lighthouse/workload/workload.py b/lighthouse/workload/workload.py index 2d3ac0e..0feaedc 100644 --- a/lighthouse/workload/workload.py +++ b/lighthouse/workload/workload.py @@ -10,6 +10,8 @@ from contextlib import contextmanager from typing import Optional +from lighthouse.pipeline.opt import Driver, Stage, Transform, TransformStage + class Workload(ABC): """ @@ -37,13 +39,16 @@ def payload_module(self) -> ir.Module: pass @abstractmethod - def schedule_modules( + def pipeline( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None, - ) -> list[ir.Module]: + ) -> list[str | Stage | ir.Module]: """ - Generate one or more MLIR modules containing the transform schedules. + Return a list of pipeline stages for lowering the payload. + + Each element can be a ready Stage object, an ir.Module (transform + schedule), a Stage, or a string (pass name, bundle name, or file path). The `stop_at_stage` argument can be used to interrupt lowering at a desired IR level for debugging purposes. @@ -57,37 +62,36 @@ def lower_payload( schedule_parameters: Optional[dict] = None, ) -> ir.Module: """ - Apply transform schedules to the payload module. + Apply the pipeline stages to the payload module using the Driver. Optionally dumps the payload IR at the desired level and/or the - transform schedules to stdout. + pipeline stages to stdout. Returns the lowered payload module. """ payload_module = self.payload_module() - schedule_modules = self.schedule_modules( + stages = self.pipeline( stop_at_stage=dump_payload, parameters=schedule_parameters ) - if not isinstance(schedule_modules, list): + if not isinstance(stages, list): raise TypeError( - f"schedule_modules() must return a list of ir.Module instances, " - f"got {type(schedule_modules).__name__}" - ) - if not schedule_modules: - raise ValueError( - "schedule_modules() must return at least one schedule module." + f"pipeline() must return a list, got {type(stages).__name__}" ) - if not dump_payload or dump_payload != "initial": - for schedule_module in schedule_modules: - # apply schedule on payload module - named_seq = schedule_module.body.operations[0] - named_seq.apply(payload_module) - if dump_payload: + if dump_payload and dump_payload == "initial": print(payload_module) + return payload_module + stages = [ + TransformStage(Transform(s), s.context) if isinstance(s, ir.Module) else s + for s in stages + ] + driver = Driver(payload_module, stages) if dump_schedule: - for schedule_module in schedule_modules: - print(schedule_module) - return payload_module + for stage in driver.pipeline: + print(stage) + module = driver.run() + if dump_payload: + print(module) + return module @abstractmethod @contextmanager