From 84e1a8a7709f2469f844681f532f69b51bd2046d Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 20 Mar 2026 13:54:30 +0100 Subject: [PATCH] [dialects][transform] Reorganize transform dialect extensions Refactors transform extension dialects into separate submodules and places their definitions into separate files. The refactor aligns lighthouse dialect structure closer to upstream bindings mirroring its nesting. That also makes the op discovery easier and allows for cleaner addition of new extensions. When possible only the snake case op wrappers are exposed to promote their use which simplifies reuse across code and IR forms. --- lighthouse/dialects/__init__.py | 12 +- lighthouse/dialects/transform/__init__.py | 0 .../dialects/transform/smt_ext/__init__.py | 16 + .../smt_ext/dialect.py} | 17 +- .../smt_ext/ops/constrain_params.py} | 27 +- .../transform/transform_ext/__init__.py | 16 + .../transform/transform_ext/dialect.py | 17 + .../transform_ext/ops/get_named_attribute.py | 63 +++ .../transform_ext/ops/param_cmp_eq.py | 53 +++ .../transform/transform_ext/ops/replace.py | 132 +++++++ .../transform_ext/ops/wrap_in_benching.py | 117 ++++++ .../dialects/transform/tune_ext/__init__.py | 9 + .../tune_ext/dialect.py} | 2 - lighthouse/dialects/transform_ext.py | 366 ------------------ lighthouse/schedule/xegpu/mlp_schedule.py | 10 +- lighthouse/tune/trace.py | 4 +- lighthouse/workload/runner.py | 2 +- test/dialects/transform_ext.py | 6 +- 18 files changed, 455 insertions(+), 414 deletions(-) create mode 100644 lighthouse/dialects/transform/__init__.py create mode 100644 lighthouse/dialects/transform/smt_ext/__init__.py rename lighthouse/dialects/{smt_ext.py => transform/smt_ext/dialect.py} (87%) rename lighthouse/dialects/{transform_smt_ext.py => transform/smt_ext/ops/constrain_params.py} (91%) create mode 100644 lighthouse/dialects/transform/transform_ext/__init__.py create mode 100644 lighthouse/dialects/transform/transform_ext/dialect.py create mode 100644 lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py create mode 100644 lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py create mode 100644 lighthouse/dialects/transform/transform_ext/ops/replace.py create mode 100644 lighthouse/dialects/transform/transform_ext/ops/wrap_in_benching.py create mode 100644 lighthouse/dialects/transform/tune_ext/__init__.py rename lighthouse/dialects/{transform_tune_ext.py => transform/tune_ext/dialect.py} (99%) delete mode 100644 lighthouse/dialects/transform_ext.py diff --git a/lighthouse/dialects/__init__.py b/lighthouse/dialects/__init__.py index c693b9f..96da25a 100644 --- a/lighthouse/dialects/__init__.py +++ b/lighthouse/dialects/__init__.py @@ -1,10 +1,8 @@ def register_and_load(**kwargs): - from . import smt_ext - from . import transform_ext - from . import transform_smt_ext - from . import transform_tune_ext + from .transform import transform_ext + from .transform import smt_ext + from .transform import tune_ext - smt_ext.register_and_load(**kwargs) transform_ext.register_and_load(**kwargs) - transform_smt_ext.register_and_load(**kwargs) - transform_tune_ext.register_and_load(**kwargs) + smt_ext.register_and_load(**kwargs) + tune_ext.register_and_load(**kwargs) diff --git a/lighthouse/dialects/transform/__init__.py b/lighthouse/dialects/transform/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lighthouse/dialects/transform/smt_ext/__init__.py b/lighthouse/dialects/transform/smt_ext/__init__.py new file mode 100644 index 0000000..b94c6cf --- /dev/null +++ b/lighthouse/dialects/transform/smt_ext/__init__.py @@ -0,0 +1,16 @@ +from .dialect import register_and_load +from .dialect import TransformSMTExtensionDialect +from .dialect import SMTIntValue +from .dialect import assert_ + +from .ops.constrain_params import constrain_params +from .ops.constrain_params import ConstrainParamsOp + +__all__ = [ + "ConstrainParamsOp", + "SMTIntValue", + "TransformSMTExtensionDialect", + "assert_", + "constrain_params", + "register_and_load", +] diff --git a/lighthouse/dialects/smt_ext.py b/lighthouse/dialects/transform/smt_ext/dialect.py similarity index 87% rename from lighthouse/dialects/smt_ext.py rename to lighthouse/dialects/transform/smt_ext/dialect.py index 2678a79..68cd0da 100644 --- a/lighthouse/dialects/smt_ext.py +++ b/lighthouse/dialects/transform/smt_ext/dialect.py @@ -2,16 +2,29 @@ from mlir import ir from mlir.dialects import smt - -__all__ = ["SMTIntValue", "assert_", "register_and_load"] +from mlir.dialects import ext def register_and_load(**kwargs): """Register and load the SMTIntValue caster.""" + TransformSMTExtensionDialect.load(**kwargs) SMTIntValue.register_value_caster() +class TransformSMTExtensionDialect(ext.Dialect, name="transform_smt_ext"): + """A Transform Dialect extension for SMT-related operations.""" + + @classmethod + def load(cls, *args, **kwargs): + # Registers the dialect and its op classes and loads the dialect and ops into the context. + super().load(*args, **kwargs) + + for op in cls.operations: + if hasattr(op, "attach_interfaces"): + op.attach_interfaces() + + def assert_(predicate: ir.Value[smt.BoolType] | bool, error_message: str = ""): """Assert normally if a bool else produce an SMT assertion op.""" diff --git a/lighthouse/dialects/transform_smt_ext.py b/lighthouse/dialects/transform/smt_ext/ops/constrain_params.py similarity index 91% rename from lighthouse/dialects/transform_smt_ext.py rename to lighthouse/dialects/transform/smt_ext/ops/constrain_params.py index 534e2d1..f32b033 100644 --- a/lighthouse/dialects/transform_smt_ext.py +++ b/lighthouse/dialects/transform/smt_ext/ops/constrain_params.py @@ -4,32 +4,7 @@ from mlir.dialects import ext, smt, transform from lighthouse.tune import trace - -__all__ = [ - "ConstrainParamsOp", - "TransformSMTExtensionDialect", - "constrain_params", - "register_and_load", -] - - -def register_and_load(**kwargs): - """Register and load the TransformSMTDialectExtension and its operations.""" - - TransformSMTExtensionDialect.load(**kwargs) - - -class TransformSMTExtensionDialect(ext.Dialect, name="transform_smt_ext"): - """A Transform Dialect extension for SMT-related operations.""" - - @classmethod - def load(cls, *args, **kwargs): - # Registers the dialect and its op classes and loads the dialect and ops into the context. - super().load(*args, **kwargs) - - for op in cls.operations: - if hasattr(op, "attach_interfaces"): - op.attach_interfaces() +from ..dialect import TransformSMTExtensionDialect class ConstrainParamsOp( diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py new file mode 100644 index 0000000..123b940 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -0,0 +1,16 @@ +from .dialect import register_and_load +from .dialect import TransformExtensionDialect + +from .ops.wrap_in_benching import wrap_in_benching_func +from .ops.get_named_attribute import get_named_attribute +from .ops.param_cmp_eq import param_cmp_eq +from .ops.replace import replace + +__all__ = [ + "TransformExtensionDialect", + "get_named_attribute", + "param_cmp_eq", + "register_and_load", + "replace", + "wrap_in_benching_func", +] diff --git a/lighthouse/dialects/transform/transform_ext/dialect.py b/lighthouse/dialects/transform/transform_ext/dialect.py new file mode 100644 index 0000000..f5d8c23 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/dialect.py @@ -0,0 +1,17 @@ +from mlir.dialects import ext + + +def register_and_load(**kwargs): + TransformExtensionDialect.load(**kwargs) + + +class TransformExtensionDialect(ext.Dialect, name="transform_ext"): + @classmethod + def load(cls, *args, **kwargs): + # Registers the dialect and its op classes and loads the dialect and ops into the context. + super().load(*args, **kwargs) + + # Attach interfaces to just registered/loaded operations. + for op_cls in cls.operations: + if hasattr(op_cls, "attach_interface_impls"): + op_cls.attach_interface_impls() diff --git a/lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py b/lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py new file mode 100644 index 0000000..12f1740 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py @@ -0,0 +1,63 @@ +from mlir import ir +from mlir.dialects import ext, transform +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from ..dialect import TransformExtensionDialect + + +class GetNamedAttributeOp( + TransformExtensionDialect.Operation, name="get_named_attribute" +): + """ + Obtain a `target` op's associated attribute by `attr_name` as a `param`. + + In case `target` resolves to multiple ops, the operation returns a list of + attributes. If any of the resolved `target` ops does not have an attribute + with the name `attr_name`, the operation fails. + """ + + param: ext.Result[transform.AnyParamType[()]] + target: ext.Operand[transform.AnyOpType] + attr_name: ir.StringAttr + + @classmethod + def attach_interface_impls(cls, context=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "GetNamedAttributeOp", + _rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + target_ops = state.get_payload_ops(op.target) + associated_attrs = [] + for target_op in target_ops: + assoc_attr = target_op.attributes.get(op.attr_name.value) + if assoc_attr is None: + return DiagnosedSilenceableFailure.SilenceableFailure + associated_attrs.append(assoc_attr) + results.set_params(op.param, associated_attrs) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + transform.only_reads_handle(op.op_operands, effects) + transform.produces_handle(op.results, effects) + transform.only_reads_payload(effects) + + +def get_named_attribute( + target: ir.Value, attr_name: str | ir.StringAttr +) -> ir.Value[transform.AnyParamType]: + if not isinstance(attr_name, ir.StringAttr): + attr_name = ir.StringAttr.get(attr_name) + return GetNamedAttributeOp(target=target, attr_name=attr_name).param diff --git a/lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py b/lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py new file mode 100644 index 0000000..9b5d5f5 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py @@ -0,0 +1,53 @@ +from mlir import ir +from mlir.dialects import ext, transform + +from ..dialect import TransformExtensionDialect + + +class ParamCmpEqOp(TransformExtensionDialect.Operation, name="param_cmp_eq"): + """ + Compare the values of the `lhs` and `rhs` parameters for equality. + + The operation succeeds if the values are equal, and fails otherwise. If the + parameters resolve to multiple values, the operation succeeds if all values + are pairwise equal, and fails otherwise. + """ + + lhs: ext.Operand[transform.AnyParamType] + rhs: ext.Operand[transform.AnyParamType] + + @classmethod + def attach_interface_impls(cls, context=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "ParamCmpEqOp", + _rewriter: transform.TransformRewriter, + _results: transform.TransformResults, + state: transform.TransformState, + ) -> transform.DiagnosedSilenceableFailure: + lhs_params = state.get_params(op.lhs) + rhs_params = state.get_params(op.rhs) + if len(lhs_params) != len(rhs_params): + return transform.DiagnosedSilenceableFailure.SilenceableFailure + for lhs_param, rhs_param in zip(lhs_params, rhs_params): + if lhs_param != rhs_param: + return transform.DiagnosedSilenceableFailure.SilenceableFailure + return transform.DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "ParamCmpEqOp") -> bool: + return True + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: "ParamCmpEqOp", effects): + transform.only_reads_handle(op.op_operands, effects) + transform.only_reads_payload(effects) + + +def param_cmp_eq(lhs: ir.Value, rhs: ir.Value): + return ParamCmpEqOp(lhs=lhs, rhs=rhs) diff --git a/lighthouse/dialects/transform/transform_ext/ops/replace.py b/lighthouse/dialects/transform/transform_ext/ops/replace.py new file mode 100644 index 0000000..633339f --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/replace.py @@ -0,0 +1,132 @@ +from typing import Sequence, Optional + +from mlir import ir +from mlir.dialects import ext, transform +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from ..dialect import TransformExtensionDialect + + +class ReplaceOp(TransformExtensionDialect.Operation, name="replace"): + """Replace the `target` operation(s) with a new `op_kind` operation. + + If `new_operands` are provided, they are used as operands for the new + operation(s); otherwise, the operands of the `target` operation(s) are + reused. The new op's result types are the same as those of the `target` op. + + NB: This op is mostly an escape hatch for testing and prototyping purposes. + No attempt is made to guarantee that the rewrite is semantics perserving. + """ + + new_op: ext.Result[transform.AnyOpType[()]] + target: ext.Operand[transform.AnyOpType] + op_kind: ir.StringAttr + new_operands: Sequence[ext.Operand[transform.AnyValueType]] + + @classmethod + def attach_interface_impls(cls, ctx=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "ReplaceOp", + rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + target_ops = state.get_payload_ops(op.target) + + # Resolve optional operand handles to payload values. + operand_values_per_handle = [] + for operand_handle in op.new_operands: + operand_values_per_handle.append( + state.get_payload_values(operand_handle) + ) + assert len(operand_values_per_handle[-1]) == len(target_ops), ( + "Expected number of operand values to match number of target ops" + ) + + new_op_name = op.op_kind.value + new_op_attrs = {} + if "new_attrs" in op.attributes: + new_attrs = op.attributes["new_attrs"] + assert isinstance(new_attrs, ir.DictAttr), ( + "Expected new_attrs to be a dictionary attribute" + ) + for named_attr in new_attrs: + new_op_attrs[named_attr.name] = named_attr.attr + + new_ops = [] + for target_idx, target_op in enumerate(target_ops): + if "new_result_types" in op.attributes: + tuple_type = op.attributes["new_result_types"].value + assert isinstance(tuple_type, ir.TupleType), ( + "Expected new_result_types to be a tuple of types" + ) + assert tuple_type.num_types == len(target_op.results), ( + "Expected number of new result types to match number of target op results" + ) + + new_result_types = [ + tuple_type.get_type(i) for i in range(tuple_type.num_types) + ] + else: + new_result_types = [ty.type for ty in target_op.results] + + if operand_values_per_handle: + new_operands = [ + vals[target_idx] for vals in operand_values_per_handle + ] + else: + new_operands = list(target_op.operands) + + with ir.InsertionPoint(target_op): + new_operation = ir.Operation.create( + new_op_name, + results=new_result_types, + operands=new_operands, + attributes=new_op_attrs, + ) + rewriter.replace_op(target_op, new_operation) + new_ops.append(new_operation) + + results.set_ops(op.new_op, new_ops) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "ReplaceOp") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + transform.consumes_handle(op.op_operands[:1], effects) + if new_operands_handles := op.op_operands[1:]: + transform.only_reads_handle(new_operands_handles, effects) + transform.produces_handle(op.results, effects) + transform.modifies_payload(effects) + + +def replace( + target: ir.Value, + op_kind: str | ir.StringAttr, + *new_operands: ir.Value, + new_result_types: Optional[ir.TupleType | Sequence[ir.Type]] = None, + new_attrs=None, +) -> ir.Value: + if not isinstance(op_kind, ir.StringAttr): + op_kind = ir.StringAttr.get(op_kind) + new_op = ReplaceOp(target, op_kind=op_kind, new_operands=new_operands).new_op + if new_result_types: + if not isinstance(new_result_types, ir.TupleType): + new_result_types = ir.TupleType.get_tuple(new_result_types) + new_op.owner.attributes["new_result_types"] = ir.TypeAttr.get(new_result_types) + if new_attrs: + if isinstance(new_attrs, dict): + new_attrs = ir.DictAttr.get(new_attrs) + else: + assert isinstance(new_attrs, ir.DictAttr) + new_op.owner.attributes["new_attrs"] = new_attrs + return new_op diff --git a/lighthouse/dialects/transform/transform_ext/ops/wrap_in_benching.py b/lighthouse/dialects/transform/transform_ext/ops/wrap_in_benching.py new file mode 100644 index 0000000..be6afe1 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/wrap_in_benching.py @@ -0,0 +1,117 @@ +from mlir import ir +from mlir.dialects import ext, transform, func, arith, scf, memref +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.utils.mlir import func_cif +from ..dialect import TransformExtensionDialect + + +class WrapInBenchingFuncOp( + TransformExtensionDialect.Operation, name="wrap_in_benching_func" +): + """Create a function that calls `target` function in a benchmarking loop. + + The new function has the same arguments as `target` plus three additional ones: + - A memref to store the timing results (one element per iteration). + - The number of timed iterations. + - The number of warmup iterations. + """ + + target: ext.Operand[transform.AnyOpType] + bench_func: ext.Result[transform.AnyOpType[()]] + + @classmethod + def attach_interface_impls(cls, context=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + + @staticmethod + def wrap_in_benching_func(target: func.FuncOp, bench_name: str): + """Create a function that calls `target` in a benchmarking loop. + + Each call to `target` is timed separately, and the times (in seconds) + are stored in a memref that is passed as an additional argument to the + benchmark function. It also takes two additional arguments for the + number of runs and warmup iterations. + """ + + # define rtclock function + f64_t = ir.F64Type.get() + func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((ir.ShapedType.get_dynamic_size(),), f64_t) + index_t = ir.IndexType.get() + args = target.type.inputs + [time_memref_t, index_t, index_t] + + @func_cif(*args, name=bench_name) + def bench(*args): + zero = arith.constant(index_t, 0) + one = arith.constant(index_t, 1) + func_args = list(args[: len(target.type.inputs)]) + times_memref, num_times, num_warmup = args[-3:] + for i in scf.for_(zero, num_warmup, one): + # FIXME(upstream): func.call needs to wrap _overridden_ CallOp. + func.CallOp(target, func_args) + scf.yield_(()) + # TODO: get `num_times` from the `times_memref`. + for i in scf.for_(zero, num_times, one): + tic = func.call((f64_t,), "rtclock", ()) + func.CallOp(target, func_args) + toc = func.call((f64_t,), "rtclock", ()) + time = arith.subf(toc, tic) + memref.store(time, times_memref, [i]) + scf.yield_(()) + + return bench.func_op + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "WrapInBenchingFuncOp", + _rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + targets = state.get_payload_ops(op.target) + if bench_name_attr := op.attributes.get("bench_name"): + bench_name = bench_name_attr.value + if len(targets) != 1: + return DiagnosedSilenceableFailure.SilenceableFailure + else: + bench_name = None + + bench_funcs = [] + for target in targets: + if not isinstance(target, func.FuncOp): + return DiagnosedSilenceableFailure.SilenceableFailure + + with ir.InsertionPoint(target), target.location: + bench_func = WrapInBenchingFuncOp.wrap_in_benching_func( + target, bench_name or f"bench_{target.name.value}" + ) + bench_funcs.append(bench_func) + + results.set_ops(op.bench_func, bench_funcs) + + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "WrapInBenchingFuncOp") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: "WrapInBenchingFuncOp", effects): + transform.only_reads_handle(op.op_operands, effects) + transform.produces_handle(op.results, effects) + transform.modifies_payload(effects) + + +def wrap_in_benching_func( + target: ir.Value[transform.AnyOpType], bench_name: str | None = None +) -> ir.Value[transform.AnyOpType]: + """snake_case wrapper to create a WrapInBenchingFuncOp.""" + op = WrapInBenchingFuncOp(target=target) + if bench_name is not None: + op.attributes["bench_name"] = ir.StringAttr.get(bench_name) + return op.bench_func diff --git a/lighthouse/dialects/transform/tune_ext/__init__.py b/lighthouse/dialects/transform/tune_ext/__init__.py new file mode 100644 index 0000000..0d0279b --- /dev/null +++ b/lighthouse/dialects/transform/tune_ext/__init__.py @@ -0,0 +1,9 @@ +from .dialect import register_and_load +from .dialect import KnobValue +from .dialect import knob + +__all__ = [ + "KnobValue", + "knob", + "register_and_load", +] diff --git a/lighthouse/dialects/transform_tune_ext.py b/lighthouse/dialects/transform/tune_ext/dialect.py similarity index 99% rename from lighthouse/dialects/transform_tune_ext.py rename to lighthouse/dialects/transform/tune_ext/dialect.py index ff14f2a..29a687b 100644 --- a/lighthouse/dialects/transform_tune_ext.py +++ b/lighthouse/dialects/transform/tune_ext/dialect.py @@ -11,8 +11,6 @@ from mlir.dialects import transform from mlir.dialects.transform import tune as transform_tune -__all__ = ["KnobValue", "knob"] - def register_and_load(**kwargs): pass # NB: currently nothing to register or load. diff --git a/lighthouse/dialects/transform_ext.py b/lighthouse/dialects/transform_ext.py deleted file mode 100644 index 24d834f..0000000 --- a/lighthouse/dialects/transform_ext.py +++ /dev/null @@ -1,366 +0,0 @@ -from typing import Sequence, Optional - -from mlir import ir -from mlir.dialects import ext, transform, func, arith, scf, memref -from mlir.dialects.transform import DiagnosedSilenceableFailure - -from lighthouse.utils.mlir import func_cif - - -def register_and_load(**kwargs): - TransformExtensionDialect.load(**kwargs) - - -class TransformExtensionDialect(ext.Dialect, name="transform_ext"): - @classmethod - def load(cls, *args, **kwargs): - # Registers the dialect and its op classes and loads the dialect and ops into the context. - super().load(*args, **kwargs) - - # Attach interfaces to just registered/loaded operations. - for op_cls in cls.operations: - if hasattr(op_cls, "attach_interface_impls"): - op_cls.attach_interface_impls() - - -class GetNamedAttributeOp( - TransformExtensionDialect.Operation, name="get_named_attribute" -): - """ - Obtain a `target` op's associated attribute by `attr_name` as a `param`. - - In case `target` resolves to multiple ops, the operation returns a list of - attributes. If any of the resolved `target` ops does not have an attribute - with the name `attr_name`, the operation fails. - """ - - param: ext.Result[transform.AnyParamType[()]] - target: ext.Operand[transform.AnyOpType] - attr_name: ir.StringAttr - - @classmethod - def attach_interface_impls(cls, context=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "GetNamedAttributeOp", - _rewriter: transform.TransformRewriter, - results: transform.TransformResults, - state: transform.TransformState, - ) -> DiagnosedSilenceableFailure: - target_ops = state.get_payload_ops(op.target) - associated_attrs = [] - for target_op in target_ops: - assoc_attr = target_op.attributes.get(op.attr_name.value) - if assoc_attr is None: - return DiagnosedSilenceableFailure.SilenceableFailure - associated_attrs.append(assoc_attr) - results.set_params(op.param, associated_attrs) - return DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool: - return False - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: ir.Operation, effects): - transform.only_reads_handle(op.op_operands, effects) - transform.produces_handle(op.results, effects) - transform.only_reads_payload(effects) - - -def get_named_attribute( - target: ir.Value, attr_name: str | ir.StringAttr -) -> ir.Value[transform.AnyParamType]: - if not isinstance(attr_name, ir.StringAttr): - attr_name = ir.StringAttr.get(attr_name) - return GetNamedAttributeOp(target=target, attr_name=attr_name).param - - -class ParamCmpEqOp(TransformExtensionDialect.Operation, name="param_cmp_eq"): - """ - Compare the values of the `lhs` and `rhs` parameters for equality. - - The operation succeeds if the values are equal, and fails otherwise. If the - parameters resolve to multiple values, the operation succeeds if all values - are pairwise equal, and fails otherwise. - """ - - lhs: ext.Operand[transform.AnyParamType] - rhs: ext.Operand[transform.AnyParamType] - - @classmethod - def attach_interface_impls(cls, context=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "ParamCmpEqOp", - _rewriter: transform.TransformRewriter, - _results: transform.TransformResults, - state: transform.TransformState, - ) -> transform.DiagnosedSilenceableFailure: - lhs_params = state.get_params(op.lhs) - rhs_params = state.get_params(op.rhs) - if len(lhs_params) != len(rhs_params): - return transform.DiagnosedSilenceableFailure.SilenceableFailure - for lhs_param, rhs_param in zip(lhs_params, rhs_params): - if lhs_param != rhs_param: - return transform.DiagnosedSilenceableFailure.SilenceableFailure - return transform.DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "ParamCmpEqOp") -> bool: - return True - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: "ParamCmpEqOp", effects): - transform.only_reads_handle(op.op_operands, effects) - transform.only_reads_payload(effects) - - -def param_cmp_eq(lhs: ir.Value, rhs: ir.Value): - return ParamCmpEqOp(lhs=lhs, rhs=rhs) - - -class ReplaceOp(TransformExtensionDialect.Operation, name="replace"): - """Replace the `target` operation(s) with a new `op_kind` operation. - - If `new_operands` are provided, they are used as operands for the new - operation(s); otherwise, the operands of the `target` operation(s) are - reused. The new op's result types are the same as those of the `target` op. - - NB: This op is mostly an escape hatch for testing and prototyping purposes. - No attempt is made to guarantee that the rewrite is semantics perserving. - """ - - new_op: ext.Result[transform.AnyOpType[()]] - target: ext.Operand[transform.AnyOpType] - op_kind: ir.StringAttr - new_operands: Sequence[ext.Operand[transform.AnyValueType]] - - @classmethod - def attach_interface_impls(cls, ctx=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "ReplaceOp", - rewriter: transform.TransformRewriter, - results: transform.TransformResults, - state: transform.TransformState, - ) -> DiagnosedSilenceableFailure: - target_ops = state.get_payload_ops(op.target) - - # Resolve optional operand handles to payload values. - operand_values_per_handle = [] - for operand_handle in op.new_operands: - operand_values_per_handle.append( - state.get_payload_values(operand_handle) - ) - assert len(operand_values_per_handle[-1]) == len(target_ops), ( - "Expected number of operand values to match number of target ops" - ) - - new_op_name = op.op_kind.value - new_op_attrs = {} - if "new_attrs" in op.attributes: - new_attrs = op.attributes["new_attrs"] - assert isinstance(new_attrs, ir.DictAttr), ( - "Expected new_attrs to be a dictionary attribute" - ) - for named_attr in new_attrs: - new_op_attrs[named_attr.name] = named_attr.attr - - new_ops = [] - for target_idx, target_op in enumerate(target_ops): - if "new_result_types" in op.attributes: - tuple_type = op.attributes["new_result_types"].value - assert isinstance(tuple_type, ir.TupleType), ( - "Expected new_result_types to be a tuple of types" - ) - assert tuple_type.num_types == len(target_op.results), ( - "Expected number of new result types to match number of target op results" - ) - - new_result_types = [ - tuple_type.get_type(i) for i in range(tuple_type.num_types) - ] - else: - new_result_types = [ty.type for ty in target_op.results] - - if operand_values_per_handle: - new_operands = [ - vals[target_idx] for vals in operand_values_per_handle - ] - else: - new_operands = list(target_op.operands) - - with ir.InsertionPoint(target_op): - new_operation = ir.Operation.create( - new_op_name, - results=new_result_types, - operands=new_operands, - attributes=new_op_attrs, - ) - rewriter.replace_op(target_op, new_operation) - new_ops.append(new_operation) - - results.set_ops(op.new_op, new_ops) - return DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "ReplaceOp") -> bool: - return False - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: ir.Operation, effects): - transform.consumes_handle(op.op_operands[:1], effects) - if new_operands_handles := op.op_operands[1:]: - transform.only_reads_handle(new_operands_handles, effects) - transform.produces_handle(op.results, effects) - transform.modifies_payload(effects) - - -def replace( - target: ir.Value, - op_kind: str | ir.StringAttr, - *new_operands: ir.Value, - new_result_types: Optional[ir.TupleType | Sequence[ir.Type]] = None, - new_attrs=None, -) -> ir.Value: - if not isinstance(op_kind, ir.StringAttr): - op_kind = ir.StringAttr.get(op_kind) - new_op = ReplaceOp(target, op_kind=op_kind, new_operands=new_operands).new_op - if new_result_types: - if not isinstance(new_result_types, ir.TupleType): - new_result_types = ir.TupleType.get_tuple(new_result_types) - new_op.owner.attributes["new_result_types"] = ir.TypeAttr.get(new_result_types) - if new_attrs: - if isinstance(new_attrs, dict): - new_attrs = ir.DictAttr.get(new_attrs) - else: - assert isinstance(new_attrs, ir.DictAttr) - new_op.owner.attributes["new_attrs"] = new_attrs - return new_op - - -class WrapInBenchingFuncOp( - TransformExtensionDialect.Operation, name="wrap_in_benching_func" -): - """Create a function that calls `target` function in a benchmarking loop. - - The new function has the same arguments as `target` plus three additional ones: - - A memref to store the timing results (one element per iteration). - - The number of timed iterations. - - The number of warmup iterations. - """ - - target: ext.Operand[transform.AnyOpType] - bench_func: ext.Result[transform.AnyOpType[()]] - - @classmethod - def attach_interface_impls(cls, context=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - - @staticmethod - def wrap_in_benching_func(target: func.FuncOp, bench_name: str): - """Create a function that calls `target` in a benchmarking loop. - - Each call to `target` is timed separately, and the times (in seconds) - are stored in a memref that is passed as an additional argument to the - benchmark function. It also takes two additional arguments for the - number of runs and warmup iterations. - """ - - # define rtclock function - f64_t = ir.F64Type.get() - func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") - # emit benchmark function - time_memref_t = ir.MemRefType.get((ir.ShapedType.get_dynamic_size(),), f64_t) - index_t = ir.IndexType.get() - args = target.type.inputs + [time_memref_t, index_t, index_t] - - @func_cif(*args, name=bench_name) - def bench(*args): - zero = arith.constant(index_t, 0) - one = arith.constant(index_t, 1) - func_args = list(args[: len(target.type.inputs)]) - times_memref, num_times, num_warmup = args[-3:] - for i in scf.for_(zero, num_warmup, one): - # FIXME(upstream): func.call needs to wrap _overridden_ CallOp. - func.CallOp(target, func_args) - scf.yield_(()) - # TODO: get `num_times` from the `times_memref`. - for i in scf.for_(zero, num_times, one): - tic = func.call((f64_t,), "rtclock", ()) - func.CallOp(target, func_args) - toc = func.call((f64_t,), "rtclock", ()) - time = arith.subf(toc, tic) - memref.store(time, times_memref, [i]) - scf.yield_(()) - - return bench.func_op - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "WrapInBenchingFuncOp", - _rewriter: transform.TransformRewriter, - results: transform.TransformResults, - state: transform.TransformState, - ) -> DiagnosedSilenceableFailure: - targets = state.get_payload_ops(op.target) - if bench_name_attr := op.attributes.get("bench_name"): - bench_name = bench_name_attr.value - if len(targets) != 1: - return DiagnosedSilenceableFailure.SilenceableFailure - else: - bench_name = None - - bench_funcs = [] - for target in targets: - if not isinstance(target, func.FuncOp): - return DiagnosedSilenceableFailure.SilenceableFailure - - with ir.InsertionPoint(target), target.location: - bench_func = WrapInBenchingFuncOp.wrap_in_benching_func( - target, bench_name or f"bench_{target.name.value}" - ) - bench_funcs.append(bench_func) - - results.set_ops(op.bench_func, bench_funcs) - - return DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "WrapInBenchingFuncOp") -> bool: - return False - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: "WrapInBenchingFuncOp", effects): - transform.only_reads_handle(op.op_operands, effects) - transform.produces_handle(op.results, effects) - transform.modifies_payload(effects) - - -def wrap_in_benching_func( - target: ir.Value[transform.AnyOpType], bench_name: str | None = None -) -> ir.Value[transform.AnyOpType]: - """snake_case wrapper to create a WrapInBenchingFuncOp.""" - op = WrapInBenchingFuncOp(target=target) - if bench_name is not None: - op.attributes["bench_name"] = ir.StringAttr.get(bench_name) - return op.bench_func diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index e9fa790..f9b18ff 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -13,8 +13,8 @@ match, ) -from lighthouse.dialects import smt_ext, transform_smt_ext as td_smt_ext -from lighthouse.dialects.transform_tune_ext import knob, KnobValue +from lighthouse.dialects.transform import smt_ext +from lighthouse.dialects.transform.tune_ext import knob, KnobValue # hardware constraints DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( @@ -319,7 +319,7 @@ def bundle_xegpu_mlp_schedule( wg_m, wg_n = layer_params["wg_m"], layer_params["wg_n"] sg_m, sg_n = layer_params["sg_m"], layer_params["sg_n"] - @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + @smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) def constrain_wg_sg_and_calc_nb_threads( WG_M: int | smt_ext.SMTIntValue, WG_N: int | smt_ext.SMTIntValue, @@ -414,7 +414,7 @@ def xegpu_wg_annotation_for_mlp_layer( anyvalue = transform.AnyValueType.get() # Calculate with SMT ops in case of symbolic values, normal ints in case of concrete values. - @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + @smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): # NB: Constraint on overall num SG threads already dealt with elsewhere. return WG_M // SG_M, WG_N // SG_N @@ -426,7 +426,7 @@ def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): prefetch_tile_a = [prefetch_a_m, prefetch_a_k] prefetch_tile_b = [prefetch_b_k, prefetch_b_n] - @td_smt_ext.constrain_params( + @smt_ext.constrain_params( wg_m, wg_n, sg_m, diff --git a/lighthouse/tune/trace.py b/lighthouse/tune/trace.py index 57df76c..1ec4d05 100644 --- a/lighthouse/tune/trace.py +++ b/lighthouse/tune/trace.py @@ -11,7 +11,7 @@ from mlir.dialects import transform, smt from mlir.dialects.transform import tune as transform_tune -from lighthouse.dialects import transform_smt_ext +from lighthouse.dialects.transform import smt_ext class Node(ABC): @@ -275,7 +275,7 @@ def trace_tune_and_smt_ops(op: ir.Operation, env: Optional[dict] = None) -> dict region_idx_to_result=result_idx_region_idx_to_node[res_idx], ) - case transform_smt_ext.ConstrainParamsOp(): + case smt_ext.ConstrainParamsOp(): # Map the block args in the op's region to the nodes already # associated to the corresponding arguments on the op itself. for operand, block_arg in zip(op.operands, op.body.arguments): diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index c3d3017..6739b2a 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -10,7 +10,7 @@ from mlir.execution_engine import ExecutionEngine from mlir.runtime.np_to_memref import get_ranked_memref_descriptor -from lighthouse.dialects import transform_ext +from lighthouse.dialects.transform import transform_ext from lighthouse.schedule import schedule_boilerplate from lighthouse.utils.memref import to_packed_args from lighthouse.utils.mlir import get_mlir_library_path diff --git a/test/dialects/transform_ext.py b/test/dialects/transform_ext.py index 7d14925..c392050 100644 --- a/test/dialects/transform_ext.py +++ b/test/dialects/transform_ext.py @@ -6,7 +6,7 @@ from mlir.dialects import transform, func, arith, index from mlir.dialects.transform import structured as transform_structured import lighthouse.dialects as lh_dialects -from lighthouse.dialects import transform_ext +from lighthouse.dialects.transform import transform_ext def run(f): @@ -70,13 +70,13 @@ def test_param_cmp_eq_op(payload_handle): c42_attr = ir.IntegerAttr.get(ir.IndexType.get(), 42) c42_as_param = transform.param_constant(transform.AnyParamType.get(), c42_attr) - transform_ext.ParamCmpEqOp(constant_value_param, c42_as_param) + transform_ext.param_cmp_eq(constant_value_param, c42_as_param) c0_attr = ir.IntegerAttr.get(ir.IndexType.get(), 0) # CHECK: got here transform.print_(name="got here") # Comparing 42 against 0 — should fail and abort the sequence. c0_as_param = transform.param_constant(transform.AnyParamType.get(), c0_attr) - transform_ext.ParamCmpEqOp(constant_value_param, c0_as_param) + transform_ext.param_cmp_eq(constant_value_param, c0_as_param) # CHECK-NOT: but not here # CHECK: Caught exception: Failed to apply named transform sequence transform.print_(name="but not here")