diff --git a/lighthouse/dialects/__init__.py b/lighthouse/dialects/__init__.py index c693b9f..96da25a 100644 --- a/lighthouse/dialects/__init__.py +++ b/lighthouse/dialects/__init__.py @@ -1,10 +1,8 @@ def register_and_load(**kwargs): - from . import smt_ext - from . import transform_ext - from . import transform_smt_ext - from . import transform_tune_ext + from .transform import transform_ext + from .transform import smt_ext + from .transform import tune_ext - smt_ext.register_and_load(**kwargs) transform_ext.register_and_load(**kwargs) - transform_smt_ext.register_and_load(**kwargs) - transform_tune_ext.register_and_load(**kwargs) + smt_ext.register_and_load(**kwargs) + tune_ext.register_and_load(**kwargs) diff --git a/lighthouse/dialects/transform/__init__.py b/lighthouse/dialects/transform/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lighthouse/dialects/transform/smt_ext/__init__.py b/lighthouse/dialects/transform/smt_ext/__init__.py new file mode 100644 index 0000000..b94c6cf --- /dev/null +++ b/lighthouse/dialects/transform/smt_ext/__init__.py @@ -0,0 +1,16 @@ +from .dialect import register_and_load +from .dialect import TransformSMTExtensionDialect +from .dialect import SMTIntValue +from .dialect import assert_ + +from .ops.constrain_params import constrain_params +from .ops.constrain_params import ConstrainParamsOp + +__all__ = [ + "ConstrainParamsOp", + "SMTIntValue", + "TransformSMTExtensionDialect", + "assert_", + "constrain_params", + "register_and_load", +] diff --git a/lighthouse/dialects/smt_ext.py b/lighthouse/dialects/transform/smt_ext/dialect.py similarity index 87% rename from lighthouse/dialects/smt_ext.py rename to lighthouse/dialects/transform/smt_ext/dialect.py index 2678a79..68cd0da 100644 --- a/lighthouse/dialects/smt_ext.py +++ b/lighthouse/dialects/transform/smt_ext/dialect.py @@ -2,16 +2,29 @@ from mlir import ir from mlir.dialects import smt - -__all__ = ["SMTIntValue", "assert_", "register_and_load"] +from 
mlir.dialects import ext def register_and_load(**kwargs): """Register and load the SMTIntValue caster.""" + TransformSMTExtensionDialect.load(**kwargs) SMTIntValue.register_value_caster() +class TransformSMTExtensionDialect(ext.Dialect, name="transform_smt_ext"): + """A Transform Dialect extension for SMT-related operations.""" + + @classmethod + def load(cls, *args, **kwargs): + # Registers the dialect and its op classes and loads the dialect and ops into the context. + super().load(*args, **kwargs) + + for op in cls.operations: + if hasattr(op, "attach_interfaces"): + op.attach_interfaces() + + def assert_(predicate: ir.Value[smt.BoolType] | bool, error_message: str = ""): """Assert normally if a bool else produce an SMT assertion op.""" diff --git a/lighthouse/dialects/transform_smt_ext.py b/lighthouse/dialects/transform/smt_ext/ops/constrain_params.py similarity index 91% rename from lighthouse/dialects/transform_smt_ext.py rename to lighthouse/dialects/transform/smt_ext/ops/constrain_params.py index 534e2d1..f32b033 100644 --- a/lighthouse/dialects/transform_smt_ext.py +++ b/lighthouse/dialects/transform/smt_ext/ops/constrain_params.py @@ -4,32 +4,7 @@ from mlir.dialects import ext, smt, transform from lighthouse.tune import trace - -__all__ = [ - "ConstrainParamsOp", - "TransformSMTExtensionDialect", - "constrain_params", - "register_and_load", -] - - -def register_and_load(**kwargs): - """Register and load the TransformSMTDialectExtension and its operations.""" - - TransformSMTExtensionDialect.load(**kwargs) - - -class TransformSMTExtensionDialect(ext.Dialect, name="transform_smt_ext"): - """A Transform Dialect extension for SMT-related operations.""" - - @classmethod - def load(cls, *args, **kwargs): - # Registers the dialect and its op classes and loads the dialect and ops into the context. 
- super().load(*args, **kwargs) - - for op in cls.operations: - if hasattr(op, "attach_interfaces"): - op.attach_interfaces() +from ..dialect import TransformSMTExtensionDialect class ConstrainParamsOp( diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py new file mode 100644 index 0000000..123b940 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -0,0 +1,16 @@ +from .dialect import register_and_load +from .dialect import TransformExtensionDialect + +from .ops.wrap_in_benching import wrap_in_benching_func +from .ops.get_named_attribute import get_named_attribute +from .ops.param_cmp_eq import param_cmp_eq +from .ops.replace import replace + +__all__ = [ + "TransformExtensionDialect", + "get_named_attribute", + "param_cmp_eq", + "register_and_load", + "replace", + "wrap_in_benching_func", +] diff --git a/lighthouse/dialects/transform/transform_ext/dialect.py b/lighthouse/dialects/transform/transform_ext/dialect.py new file mode 100644 index 0000000..f5d8c23 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/dialect.py @@ -0,0 +1,17 @@ +from mlir.dialects import ext + + +def register_and_load(**kwargs): + TransformExtensionDialect.load(**kwargs) + + +class TransformExtensionDialect(ext.Dialect, name="transform_ext"): + @classmethod + def load(cls, *args, **kwargs): + # Registers the dialect and its op classes and loads the dialect and ops into the context. + super().load(*args, **kwargs) + + # Attach interfaces to just registered/loaded operations. 
+ for op_cls in cls.operations: + if hasattr(op_cls, "attach_interface_impls"): + op_cls.attach_interface_impls() diff --git a/lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py b/lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py new file mode 100644 index 0000000..12f1740 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py @@ -0,0 +1,63 @@ +from mlir import ir +from mlir.dialects import ext, transform +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from ..dialect import TransformExtensionDialect + + +class GetNamedAttributeOp( + TransformExtensionDialect.Operation, name="get_named_attribute" +): + """ + Obtain a `target` op's associated attribute by `attr_name` as a `param`. + + In case `target` resolves to multiple ops, the operation returns a list of + attributes. If any of the resolved `target` ops does not have an attribute + with the name `attr_name`, the operation fails. + """ + + param: ext.Result[transform.AnyParamType[()]] + target: ext.Operand[transform.AnyOpType] + attr_name: ir.StringAttr + + @classmethod + def attach_interface_impls(cls, context=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "GetNamedAttributeOp", + _rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + target_ops = state.get_payload_ops(op.target) + associated_attrs = [] + for target_op in target_ops: + assoc_attr = target_op.attributes.get(op.attr_name.value) + if assoc_attr is None: + return DiagnosedSilenceableFailure.SilenceableFailure + associated_attrs.append(assoc_attr) + results.set_params(op.param, associated_attrs) + return DiagnosedSilenceableFailure.Success + + 
@staticmethod + def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + transform.only_reads_handle(op.op_operands, effects) + transform.produces_handle(op.results, effects) + transform.only_reads_payload(effects) + + +def get_named_attribute( + target: ir.Value, attr_name: str | ir.StringAttr +) -> ir.Value[transform.AnyParamType]: + if not isinstance(attr_name, ir.StringAttr): + attr_name = ir.StringAttr.get(attr_name) + return GetNamedAttributeOp(target=target, attr_name=attr_name).param diff --git a/lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py b/lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py new file mode 100644 index 0000000..9b5d5f5 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py @@ -0,0 +1,53 @@ +from mlir import ir +from mlir.dialects import ext, transform + +from ..dialect import TransformExtensionDialect + + +class ParamCmpEqOp(TransformExtensionDialect.Operation, name="param_cmp_eq"): + """ + Compare the values of the `lhs` and `rhs` parameters for equality. + + The operation succeeds if the values are equal, and fails otherwise. If the + parameters resolve to multiple values, the operation succeeds if all values + are pairwise equal, and fails otherwise. 
+ """ + + lhs: ext.Operand[transform.AnyParamType] + rhs: ext.Operand[transform.AnyParamType] + + @classmethod + def attach_interface_impls(cls, context=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "ParamCmpEqOp", + _rewriter: transform.TransformRewriter, + _results: transform.TransformResults, + state: transform.TransformState, + ) -> transform.DiagnosedSilenceableFailure: + lhs_params = state.get_params(op.lhs) + rhs_params = state.get_params(op.rhs) + if len(lhs_params) != len(rhs_params): + return transform.DiagnosedSilenceableFailure.SilenceableFailure + for lhs_param, rhs_param in zip(lhs_params, rhs_params): + if lhs_param != rhs_param: + return transform.DiagnosedSilenceableFailure.SilenceableFailure + return transform.DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "ParamCmpEqOp") -> bool: + return True + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: "ParamCmpEqOp", effects): + transform.only_reads_handle(op.op_operands, effects) + transform.only_reads_payload(effects) + + +def param_cmp_eq(lhs: ir.Value, rhs: ir.Value): + return ParamCmpEqOp(lhs=lhs, rhs=rhs) diff --git a/lighthouse/dialects/transform/transform_ext/ops/replace.py b/lighthouse/dialects/transform/transform_ext/ops/replace.py new file mode 100644 index 0000000..633339f --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/replace.py @@ -0,0 +1,132 @@ +from typing import Sequence, Optional + +from mlir import ir +from mlir.dialects import ext, transform +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from ..dialect import TransformExtensionDialect + + +class ReplaceOp(TransformExtensionDialect.Operation, name="replace"): + """Replace 
the `target` operation(s) with a new `op_kind` operation. + + If `new_operands` are provided, they are used as operands for the new + operation(s); otherwise, the operands of the `target` operation(s) are + reused. The new op's result types are the same as those of the `target` op. + + NB: This op is mostly an escape hatch for testing and prototyping purposes. + No attempt is made to guarantee that the rewrite is semantics preserving. + """ + + new_op: ext.Result[transform.AnyOpType[()]] + target: ext.Operand[transform.AnyOpType] + op_kind: ir.StringAttr + new_operands: Sequence[ext.Operand[transform.AnyValueType]] + + @classmethod + def attach_interface_impls(cls, ctx=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "ReplaceOp", + rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + target_ops = state.get_payload_ops(op.target) + + # Resolve optional operand handles to payload values. 
+ operand_values_per_handle = [] + for operand_handle in op.new_operands: + operand_values_per_handle.append( + state.get_payload_values(operand_handle) + ) + assert len(operand_values_per_handle[-1]) == len(target_ops), ( + "Expected number of operand values to match number of target ops" + ) + + new_op_name = op.op_kind.value + new_op_attrs = {} + if "new_attrs" in op.attributes: + new_attrs = op.attributes["new_attrs"] + assert isinstance(new_attrs, ir.DictAttr), ( + "Expected new_attrs to be a dictionary attribute" + ) + for named_attr in new_attrs: + new_op_attrs[named_attr.name] = named_attr.attr + + new_ops = [] + for target_idx, target_op in enumerate(target_ops): + if "new_result_types" in op.attributes: + tuple_type = op.attributes["new_result_types"].value + assert isinstance(tuple_type, ir.TupleType), ( + "Expected new_result_types to be a tuple of types" + ) + assert tuple_type.num_types == len(target_op.results), ( + "Expected number of new result types to match number of target op results" + ) + + new_result_types = [ + tuple_type.get_type(i) for i in range(tuple_type.num_types) + ] + else: + new_result_types = [ty.type for ty in target_op.results] + + if operand_values_per_handle: + new_operands = [ + vals[target_idx] for vals in operand_values_per_handle + ] + else: + new_operands = list(target_op.operands) + + with ir.InsertionPoint(target_op): + new_operation = ir.Operation.create( + new_op_name, + results=new_result_types, + operands=new_operands, + attributes=new_op_attrs, + ) + rewriter.replace_op(target_op, new_operation) + new_ops.append(new_operation) + + results.set_ops(op.new_op, new_ops) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "ReplaceOp") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + transform.consumes_handle(op.op_operands[:1], effects) + if new_operands_handles 
:= op.op_operands[1:]: + transform.only_reads_handle(new_operands_handles, effects) + transform.produces_handle(op.results, effects) + transform.modifies_payload(effects) + + +def replace( + target: ir.Value, + op_kind: str | ir.StringAttr, + *new_operands: ir.Value, + new_result_types: Optional[ir.TupleType | Sequence[ir.Type]] = None, + new_attrs=None, +) -> ir.Value: + if not isinstance(op_kind, ir.StringAttr): + op_kind = ir.StringAttr.get(op_kind) + new_op = ReplaceOp(target, op_kind=op_kind, new_operands=new_operands).new_op + if new_result_types: + if not isinstance(new_result_types, ir.TupleType): + new_result_types = ir.TupleType.get_tuple(new_result_types) + new_op.owner.attributes["new_result_types"] = ir.TypeAttr.get(new_result_types) + if new_attrs: + if isinstance(new_attrs, dict): + new_attrs = ir.DictAttr.get(new_attrs) + else: + assert isinstance(new_attrs, ir.DictAttr) + new_op.owner.attributes["new_attrs"] = new_attrs + return new_op diff --git a/lighthouse/dialects/transform/transform_ext/ops/wrap_in_benching.py b/lighthouse/dialects/transform/transform_ext/ops/wrap_in_benching.py new file mode 100644 index 0000000..be6afe1 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/wrap_in_benching.py @@ -0,0 +1,117 @@ +from mlir import ir +from mlir.dialects import ext, transform, func, arith, scf, memref +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.utils.mlir import func_cif +from ..dialect import TransformExtensionDialect + + +class WrapInBenchingFuncOp( + TransformExtensionDialect.Operation, name="wrap_in_benching_func" +): + """Create a function that calls `target` function in a benchmarking loop. + + The new function has the same arguments as `target` plus three additional ones: + - A memref to store the timing results (one element per iteration). + - The number of timed iterations. + - The number of warmup iterations. 
+ """ + + target: ext.Operand[transform.AnyOpType] + bench_func: ext.Result[transform.AnyOpType[()]] + + @classmethod + def attach_interface_impls(cls, context=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) + + @staticmethod + def wrap_in_benching_func(target: func.FuncOp, bench_name: str): + """Create a function that calls `target` in a benchmarking loop. + + Each call to `target` is timed separately, and the times (in seconds) + are stored in a memref that is passed as an additional argument to the + benchmark function. It also takes two additional arguments for the + number of runs and warmup iterations. + """ + + # define rtclock function + f64_t = ir.F64Type.get() + func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((ir.ShapedType.get_dynamic_size(),), f64_t) + index_t = ir.IndexType.get() + args = target.type.inputs + [time_memref_t, index_t, index_t] + + @func_cif(*args, name=bench_name) + def bench(*args): + zero = arith.constant(index_t, 0) + one = arith.constant(index_t, 1) + func_args = list(args[: len(target.type.inputs)]) + times_memref, num_times, num_warmup = args[-3:] + for i in scf.for_(zero, num_warmup, one): + # FIXME(upstream): func.call needs to wrap _overridden_ CallOp. + func.CallOp(target, func_args) + scf.yield_(()) + # TODO: get `num_times` from the `times_memref`. 
+ for i in scf.for_(zero, num_times, one): + tic = func.call((f64_t,), "rtclock", ()) + func.CallOp(target, func_args) + toc = func.call((f64_t,), "rtclock", ()) + time = arith.subf(toc, tic) + memref.store(time, times_memref, [i]) + scf.yield_(()) + + return bench.func_op + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "WrapInBenchingFuncOp", + _rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + targets = state.get_payload_ops(op.target) + if bench_name_attr := op.attributes.get("bench_name"): + bench_name = bench_name_attr.value + if len(targets) != 1: + return DiagnosedSilenceableFailure.SilenceableFailure + else: + bench_name = None + + bench_funcs = [] + for target in targets: + if not isinstance(target, func.FuncOp): + return DiagnosedSilenceableFailure.SilenceableFailure + + with ir.InsertionPoint(target), target.location: + bench_func = WrapInBenchingFuncOp.wrap_in_benching_func( + target, bench_name or f"bench_{target.name.value}" + ) + bench_funcs.append(bench_func) + + results.set_ops(op.bench_func, bench_funcs) + + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "WrapInBenchingFuncOp") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: "WrapInBenchingFuncOp", effects): + transform.only_reads_handle(op.op_operands, effects) + transform.produces_handle(op.results, effects) + transform.modifies_payload(effects) + + +def wrap_in_benching_func( + target: ir.Value[transform.AnyOpType], bench_name: str | None = None +) -> ir.Value[transform.AnyOpType]: + """snake_case wrapper to create a WrapInBenchingFuncOp.""" + op = WrapInBenchingFuncOp(target=target) + if bench_name is not None: + op.attributes["bench_name"] = ir.StringAttr.get(bench_name) + return op.bench_func diff 
--git a/lighthouse/dialects/transform/tune_ext/__init__.py b/lighthouse/dialects/transform/tune_ext/__init__.py new file mode 100644 index 0000000..0d0279b --- /dev/null +++ b/lighthouse/dialects/transform/tune_ext/__init__.py @@ -0,0 +1,9 @@ +from .dialect import register_and_load +from .dialect import KnobValue +from .dialect import knob + +__all__ = [ + "KnobValue", + "knob", + "register_and_load", +] diff --git a/lighthouse/dialects/transform_tune_ext.py b/lighthouse/dialects/transform/tune_ext/dialect.py similarity index 99% rename from lighthouse/dialects/transform_tune_ext.py rename to lighthouse/dialects/transform/tune_ext/dialect.py index ff14f2a..29a687b 100644 --- a/lighthouse/dialects/transform_tune_ext.py +++ b/lighthouse/dialects/transform/tune_ext/dialect.py @@ -11,8 +11,6 @@ from mlir.dialects import transform from mlir.dialects.transform import tune as transform_tune -__all__ = ["KnobValue", "knob"] - def register_and_load(**kwargs): pass # NB: currently nothing to register or load. diff --git a/lighthouse/dialects/transform_ext.py b/lighthouse/dialects/transform_ext.py deleted file mode 100644 index 24d834f..0000000 --- a/lighthouse/dialects/transform_ext.py +++ /dev/null @@ -1,366 +0,0 @@ -from typing import Sequence, Optional - -from mlir import ir -from mlir.dialects import ext, transform, func, arith, scf, memref -from mlir.dialects.transform import DiagnosedSilenceableFailure - -from lighthouse.utils.mlir import func_cif - - -def register_and_load(**kwargs): - TransformExtensionDialect.load(**kwargs) - - -class TransformExtensionDialect(ext.Dialect, name="transform_ext"): - @classmethod - def load(cls, *args, **kwargs): - # Registers the dialect and its op classes and loads the dialect and ops into the context. - super().load(*args, **kwargs) - - # Attach interfaces to just registered/loaded operations. 
- for op_cls in cls.operations: - if hasattr(op_cls, "attach_interface_impls"): - op_cls.attach_interface_impls() - - -class GetNamedAttributeOp( - TransformExtensionDialect.Operation, name="get_named_attribute" -): - """ - Obtain a `target` op's associated attribute by `attr_name` as a `param`. - - In case `target` resolves to multiple ops, the operation returns a list of - attributes. If any of the resolved `target` ops does not have an attribute - with the name `attr_name`, the operation fails. - """ - - param: ext.Result[transform.AnyParamType[()]] - target: ext.Operand[transform.AnyOpType] - attr_name: ir.StringAttr - - @classmethod - def attach_interface_impls(cls, context=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "GetNamedAttributeOp", - _rewriter: transform.TransformRewriter, - results: transform.TransformResults, - state: transform.TransformState, - ) -> DiagnosedSilenceableFailure: - target_ops = state.get_payload_ops(op.target) - associated_attrs = [] - for target_op in target_ops: - assoc_attr = target_op.attributes.get(op.attr_name.value) - if assoc_attr is None: - return DiagnosedSilenceableFailure.SilenceableFailure - associated_attrs.append(assoc_attr) - results.set_params(op.param, associated_attrs) - return DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool: - return False - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: ir.Operation, effects): - transform.only_reads_handle(op.op_operands, effects) - transform.produces_handle(op.results, effects) - transform.only_reads_payload(effects) - - -def get_named_attribute( - target: ir.Value, attr_name: str | ir.StringAttr -) -> 
ir.Value[transform.AnyParamType]: - if not isinstance(attr_name, ir.StringAttr): - attr_name = ir.StringAttr.get(attr_name) - return GetNamedAttributeOp(target=target, attr_name=attr_name).param - - -class ParamCmpEqOp(TransformExtensionDialect.Operation, name="param_cmp_eq"): - """ - Compare the values of the `lhs` and `rhs` parameters for equality. - - The operation succeeds if the values are equal, and fails otherwise. If the - parameters resolve to multiple values, the operation succeeds if all values - are pairwise equal, and fails otherwise. - """ - - lhs: ext.Operand[transform.AnyParamType] - rhs: ext.Operand[transform.AnyParamType] - - @classmethod - def attach_interface_impls(cls, context=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "ParamCmpEqOp", - _rewriter: transform.TransformRewriter, - _results: transform.TransformResults, - state: transform.TransformState, - ) -> transform.DiagnosedSilenceableFailure: - lhs_params = state.get_params(op.lhs) - rhs_params = state.get_params(op.rhs) - if len(lhs_params) != len(rhs_params): - return transform.DiagnosedSilenceableFailure.SilenceableFailure - for lhs_param, rhs_param in zip(lhs_params, rhs_params): - if lhs_param != rhs_param: - return transform.DiagnosedSilenceableFailure.SilenceableFailure - return transform.DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "ParamCmpEqOp") -> bool: - return True - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: "ParamCmpEqOp", effects): - transform.only_reads_handle(op.op_operands, effects) - transform.only_reads_payload(effects) - - -def param_cmp_eq(lhs: ir.Value, rhs: ir.Value): - return ParamCmpEqOp(lhs=lhs, rhs=rhs) - - -class 
ReplaceOp(TransformExtensionDialect.Operation, name="replace"): - """Replace the `target` operation(s) with a new `op_kind` operation. - - If `new_operands` are provided, they are used as operands for the new - operation(s); otherwise, the operands of the `target` operation(s) are - reused. The new op's result types are the same as those of the `target` op. - - NB: This op is mostly an escape hatch for testing and prototyping purposes. - No attempt is made to guarantee that the rewrite is semantics perserving. - """ - - new_op: ext.Result[transform.AnyOpType[()]] - target: ext.Operand[transform.AnyOpType] - op_kind: ir.StringAttr - new_operands: Sequence[ext.Operand[transform.AnyValueType]] - - @classmethod - def attach_interface_impls(cls, ctx=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "ReplaceOp", - rewriter: transform.TransformRewriter, - results: transform.TransformResults, - state: transform.TransformState, - ) -> DiagnosedSilenceableFailure: - target_ops = state.get_payload_ops(op.target) - - # Resolve optional operand handles to payload values. 
- operand_values_per_handle = [] - for operand_handle in op.new_operands: - operand_values_per_handle.append( - state.get_payload_values(operand_handle) - ) - assert len(operand_values_per_handle[-1]) == len(target_ops), ( - "Expected number of operand values to match number of target ops" - ) - - new_op_name = op.op_kind.value - new_op_attrs = {} - if "new_attrs" in op.attributes: - new_attrs = op.attributes["new_attrs"] - assert isinstance(new_attrs, ir.DictAttr), ( - "Expected new_attrs to be a dictionary attribute" - ) - for named_attr in new_attrs: - new_op_attrs[named_attr.name] = named_attr.attr - - new_ops = [] - for target_idx, target_op in enumerate(target_ops): - if "new_result_types" in op.attributes: - tuple_type = op.attributes["new_result_types"].value - assert isinstance(tuple_type, ir.TupleType), ( - "Expected new_result_types to be a tuple of types" - ) - assert tuple_type.num_types == len(target_op.results), ( - "Expected number of new result types to match number of target op results" - ) - - new_result_types = [ - tuple_type.get_type(i) for i in range(tuple_type.num_types) - ] - else: - new_result_types = [ty.type for ty in target_op.results] - - if operand_values_per_handle: - new_operands = [ - vals[target_idx] for vals in operand_values_per_handle - ] - else: - new_operands = list(target_op.operands) - - with ir.InsertionPoint(target_op): - new_operation = ir.Operation.create( - new_op_name, - results=new_result_types, - operands=new_operands, - attributes=new_op_attrs, - ) - rewriter.replace_op(target_op, new_operation) - new_ops.append(new_operation) - - results.set_ops(op.new_op, new_ops) - return DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "ReplaceOp") -> bool: - return False - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: ir.Operation, effects): - transform.consumes_handle(op.op_operands[:1], effects) - if new_operands_handles 
:= op.op_operands[1:]: - transform.only_reads_handle(new_operands_handles, effects) - transform.produces_handle(op.results, effects) - transform.modifies_payload(effects) - - -def replace( - target: ir.Value, - op_kind: str | ir.StringAttr, - *new_operands: ir.Value, - new_result_types: Optional[ir.TupleType | Sequence[ir.Type]] = None, - new_attrs=None, -) -> ir.Value: - if not isinstance(op_kind, ir.StringAttr): - op_kind = ir.StringAttr.get(op_kind) - new_op = ReplaceOp(target, op_kind=op_kind, new_operands=new_operands).new_op - if new_result_types: - if not isinstance(new_result_types, ir.TupleType): - new_result_types = ir.TupleType.get_tuple(new_result_types) - new_op.owner.attributes["new_result_types"] = ir.TypeAttr.get(new_result_types) - if new_attrs: - if isinstance(new_attrs, dict): - new_attrs = ir.DictAttr.get(new_attrs) - else: - assert isinstance(new_attrs, ir.DictAttr) - new_op.owner.attributes["new_attrs"] = new_attrs - return new_op - - -class WrapInBenchingFuncOp( - TransformExtensionDialect.Operation, name="wrap_in_benching_func" -): - """Create a function that calls `target` function in a benchmarking loop. - - The new function has the same arguments as `target` plus three additional ones: - - A memref to store the timing results (one element per iteration). - - The number of timed iterations. - - The number of warmup iterations. - """ - - target: ext.Operand[transform.AnyOpType] - bench_func: ext.Result[transform.AnyOpType[()]] - - @classmethod - def attach_interface_impls(cls, context=None): - cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context) - - @staticmethod - def wrap_in_benching_func(target: func.FuncOp, bench_name: str): - """Create a function that calls `target` in a benchmarking loop. 
- - Each call to `target` is timed separately, and the times (in seconds) - are stored in a memref that is passed as an additional argument to the - benchmark function. It also takes two additional arguments for the - number of runs and warmup iterations. - """ - - # define rtclock function - f64_t = ir.F64Type.get() - func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") - # emit benchmark function - time_memref_t = ir.MemRefType.get((ir.ShapedType.get_dynamic_size(),), f64_t) - index_t = ir.IndexType.get() - args = target.type.inputs + [time_memref_t, index_t, index_t] - - @func_cif(*args, name=bench_name) - def bench(*args): - zero = arith.constant(index_t, 0) - one = arith.constant(index_t, 1) - func_args = list(args[: len(target.type.inputs)]) - times_memref, num_times, num_warmup = args[-3:] - for i in scf.for_(zero, num_warmup, one): - # FIXME(upstream): func.call needs to wrap _overridden_ CallOp. - func.CallOp(target, func_args) - scf.yield_(()) - # TODO: get `num_times` from the `times_memref`. 
- for i in scf.for_(zero, num_times, one): - tic = func.call((f64_t,), "rtclock", ()) - func.CallOp(target, func_args) - toc = func.call((f64_t,), "rtclock", ()) - time = arith.subf(toc, tic) - memref.store(time, times_memref, [i]) - scf.yield_(()) - - return bench.func_op - - class TransformOpInterfaceModel(transform.TransformOpInterface): - @staticmethod - def apply( - op: "WrapInBenchingFuncOp", - _rewriter: transform.TransformRewriter, - results: transform.TransformResults, - state: transform.TransformState, - ) -> DiagnosedSilenceableFailure: - targets = state.get_payload_ops(op.target) - if bench_name_attr := op.attributes.get("bench_name"): - bench_name = bench_name_attr.value - if len(targets) != 1: - return DiagnosedSilenceableFailure.SilenceableFailure - else: - bench_name = None - - bench_funcs = [] - for target in targets: - if not isinstance(target, func.FuncOp): - return DiagnosedSilenceableFailure.SilenceableFailure - - with ir.InsertionPoint(target), target.location: - bench_func = WrapInBenchingFuncOp.wrap_in_benching_func( - target, bench_name or f"bench_{target.name.value}" - ) - bench_funcs.append(bench_func) - - results.set_ops(op.bench_func, bench_funcs) - - return DiagnosedSilenceableFailure.Success - - @staticmethod - def allow_repeated_handle_operands(_op: "WrapInBenchingFuncOp") -> bool: - return False - - class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): - @staticmethod - def get_effects(op: "WrapInBenchingFuncOp", effects): - transform.only_reads_handle(op.op_operands, effects) - transform.produces_handle(op.results, effects) - transform.modifies_payload(effects) - - -def wrap_in_benching_func( - target: ir.Value[transform.AnyOpType], bench_name: str | None = None -) -> ir.Value[transform.AnyOpType]: - """snake_case wrapper to create a WrapInBenchingFuncOp.""" - op = WrapInBenchingFuncOp(target=target) - if bench_name is not None: - op.attributes["bench_name"] = ir.StringAttr.get(bench_name) - return op.bench_func diff 
--git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index e9fa790..f9b18ff 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -13,8 +13,8 @@ match, ) -from lighthouse.dialects import smt_ext, transform_smt_ext as td_smt_ext -from lighthouse.dialects.transform_tune_ext import knob, KnobValue +from lighthouse.dialects.transform import smt_ext +from lighthouse.dialects.transform.tune_ext import knob, KnobValue # hardware constraints DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( @@ -319,7 +319,7 @@ def bundle_xegpu_mlp_schedule( wg_m, wg_n = layer_params["wg_m"], layer_params["wg_n"] sg_m, sg_n = layer_params["sg_m"], layer_params["sg_n"] - @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + @smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) def constrain_wg_sg_and_calc_nb_threads( WG_M: int | smt_ext.SMTIntValue, WG_N: int | smt_ext.SMTIntValue, @@ -414,7 +414,7 @@ def xegpu_wg_annotation_for_mlp_layer( anyvalue = transform.AnyValueType.get() # Calculate with SMT ops in case of symbolic values, normal ints in case of concrete values. - @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + @smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): # NB: Constraint on overall num SG threads already dealt with elsewhere. 
return WG_M // SG_M, WG_N // SG_N @@ -426,7 +426,7 @@ def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): prefetch_tile_a = [prefetch_a_m, prefetch_a_k] prefetch_tile_b = [prefetch_b_k, prefetch_b_n] - @td_smt_ext.constrain_params( + @smt_ext.constrain_params( wg_m, wg_n, sg_m, diff --git a/lighthouse/tune/trace.py b/lighthouse/tune/trace.py index 57df76c..1ec4d05 100644 --- a/lighthouse/tune/trace.py +++ b/lighthouse/tune/trace.py @@ -11,7 +11,7 @@ from mlir.dialects import transform, smt from mlir.dialects.transform import tune as transform_tune -from lighthouse.dialects import transform_smt_ext +from lighthouse.dialects.transform import smt_ext class Node(ABC): @@ -275,7 +275,7 @@ def trace_tune_and_smt_ops(op: ir.Operation, env: Optional[dict] = None) -> dict region_idx_to_result=result_idx_region_idx_to_node[res_idx], ) - case transform_smt_ext.ConstrainParamsOp(): + case smt_ext.ConstrainParamsOp(): # Map the block args in the op's region to the nodes already # associated to the corresponding arguments on the op itself. 
for operand, block_arg in zip(op.operands, op.body.arguments): diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index c3d3017..6739b2a 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -10,7 +10,7 @@ from mlir.execution_engine import ExecutionEngine from mlir.runtime.np_to_memref import get_ranked_memref_descriptor -from lighthouse.dialects import transform_ext +from lighthouse.dialects.transform import transform_ext from lighthouse.schedule import schedule_boilerplate from lighthouse.utils.memref import to_packed_args from lighthouse.utils.mlir import get_mlir_library_path diff --git a/test/dialects/transform_ext.py b/test/dialects/transform_ext.py index 7d14925..c392050 100644 --- a/test/dialects/transform_ext.py +++ b/test/dialects/transform_ext.py @@ -6,7 +6,7 @@ from mlir.dialects import transform, func, arith, index from mlir.dialects.transform import structured as transform_structured import lighthouse.dialects as lh_dialects -from lighthouse.dialects import transform_ext +from lighthouse.dialects.transform import transform_ext def run(f): @@ -70,13 +70,13 @@ def test_param_cmp_eq_op(payload_handle): c42_attr = ir.IntegerAttr.get(ir.IndexType.get(), 42) c42_as_param = transform.param_constant(transform.AnyParamType.get(), c42_attr) - transform_ext.ParamCmpEqOp(constant_value_param, c42_as_param) + transform_ext.param_cmp_eq(constant_value_param, c42_as_param) c0_attr = ir.IntegerAttr.get(ir.IndexType.get(), 0) # CHECK: got here transform.print_(name="got here") # Comparing 42 against 0 — should fail and abort the sequence. c0_as_param = transform.param_constant(transform.AnyParamType.get(), c0_attr) - transform_ext.ParamCmpEqOp(constant_value_param, c0_as_param) + transform_ext.param_cmp_eq(constant_value_param, c0_as_param) # CHECK-NOT: but not here # CHECK: Caught exception: Failed to apply named transform sequence transform.print_(name="but not here")