Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions lighthouse/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
def register_and_load(**kwargs):
from . import smt_ext
from . import transform_ext
from . import transform_smt_ext
from . import transform_tune_ext
from .transform import transform_ext
from .transform import smt_ext
from .transform import tune_ext

smt_ext.register_and_load(**kwargs)
transform_ext.register_and_load(**kwargs)
transform_smt_ext.register_and_load(**kwargs)
transform_tune_ext.register_and_load(**kwargs)
smt_ext.register_and_load(**kwargs)
tune_ext.register_and_load(**kwargs)
Empty file.
16 changes: 16 additions & 0 deletions lighthouse/dialects/transform/smt_ext/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .dialect import register_and_load
from .dialect import TransformSMTExtensionDialect
from .dialect import SMTIntValue
from .dialect import assert_

from .ops.constrain_params import constrain_params
from .ops.constrain_params import ConstrainParamsOp

__all__ = [
"ConstrainParamsOp",
"SMTIntValue",
"TransformSMTExtensionDialect",
"assert_",
"constrain_params",
"register_and_load",
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code in ‎lighthouse/dialects/smt_ext.py‎ is only related to the SMT dialect itself. That is, those helpers are/could be useful anywhere you create SMT ops. At the moment they are used just inside sections of (schedule) code which generate regions which contain solely smt-dialect ops.

Could you please move/keep these in a smt_ext dialect, that's not treated as a namespace of transform_ext?

Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,29 @@

from mlir import ir
from mlir.dialects import smt

__all__ = ["SMTIntValue", "assert_", "register_and_load"]
from mlir.dialects import ext


def register_and_load(**kwargs):
"""Register and load the SMTIntValue caster."""

TransformSMTExtensionDialect.load(**kwargs)
SMTIntValue.register_value_caster()


class TransformSMTExtensionDialect(ext.Dialect, name="transform_smt_ext"):
"""A Transform Dialect extension for SMT-related operations."""

@classmethod
def load(cls, *args, **kwargs):
# Registers the dialect and its op classes and loads the dialect and ops into the context.
super().load(*args, **kwargs)

for op in cls.operations:
if hasattr(op, "attach_interfaces"):
op.attach_interfaces()


def assert_(predicate: ir.Value[smt.BoolType] | bool, error_message: str = ""):
"""Assert normally if a bool else produce an SMT assertion op."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,7 @@
from mlir.dialects import ext, smt, transform

from lighthouse.tune import trace

__all__ = [
"ConstrainParamsOp",
"TransformSMTExtensionDialect",
"constrain_params",
"register_and_load",
]


def register_and_load(**kwargs):
"""Register and load the TransformSMTDialectExtension and its operations."""

TransformSMTExtensionDialect.load(**kwargs)


class TransformSMTExtensionDialect(ext.Dialect, name="transform_smt_ext"):
"""A Transform Dialect extension for SMT-related operations."""

@classmethod
def load(cls, *args, **kwargs):
# Registers the dialect and its op classes and loads the dialect and ops into the context.
super().load(*args, **kwargs)

for op in cls.operations:
if hasattr(op, "attach_interfaces"):
op.attach_interfaces()
from ..dialect import TransformSMTExtensionDialect


class ConstrainParamsOp(
Expand Down
16 changes: 16 additions & 0 deletions lighthouse/dialects/transform/transform_ext/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .dialect import register_and_load
from .dialect import TransformExtensionDialect

from .ops.wrap_in_benching import wrap_in_benching_func
from .ops.get_named_attribute import get_named_attribute
from .ops.param_cmp_eq import param_cmp_eq
from .ops.replace import replace

__all__ = [
"TransformExtensionDialect",
"get_named_attribute",
"param_cmp_eq",
"register_and_load",
"replace",
"wrap_in_benching_func",
]
17 changes: 17 additions & 0 deletions lighthouse/dialects/transform/transform_ext/dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from mlir.dialects import ext


def register_and_load(**kwargs):
TransformExtensionDialect.load(**kwargs)


class TransformExtensionDialect(ext.Dialect, name="transform_ext"):
@classmethod
def load(cls, *args, **kwargs):
# Registers the dialect and its op classes and loads the dialect and ops into the context.
super().load(*args, **kwargs)

# Attach interfaces to just registered/loaded operations.
for op_cls in cls.operations:
if hasattr(op_cls, "attach_interface_impls"):
op_cls.attach_interface_impls()
Comment on lines +8 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this already seems to be boilerplate, maybe we can have our own DialectBase class somewhere.

As the point of this PR is to move code around, we could do it in a follow up.

Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from mlir import ir
from mlir.dialects import ext, transform
from mlir.dialects.transform import DiagnosedSilenceableFailure

from ..dialect import TransformExtensionDialect


class GetNamedAttributeOp(
TransformExtensionDialect.Operation, name="get_named_attribute"
):
"""
Obtain a `target` op's associated attribute by `attr_name` as a `param`.

In case `target` resolves to multiple ops, the operation returns a list of
attributes. If any of the resolved `target` ops does not have an attribute
with the name `attr_name`, the operation fails.
"""

param: ext.Result[transform.AnyParamType[()]]
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr

@classmethod
def attach_interface_impls(cls, context=None):
cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)
cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)

class TransformOpInterfaceModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: "GetNamedAttributeOp",
_rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
target_ops = state.get_payload_ops(op.target)
associated_attrs = []
for target_op in target_ops:
assoc_attr = target_op.attributes.get(op.attr_name.value)
if assoc_attr is None:
return DiagnosedSilenceableFailure.SilenceableFailure
associated_attrs.append(assoc_attr)
results.set_params(op.param, associated_attrs)
return DiagnosedSilenceableFailure.Success

@staticmethod
def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
return False

class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
@staticmethod
def get_effects(op: ir.Operation, effects):
transform.only_reads_handle(op.op_operands, effects)
transform.produces_handle(op.results, effects)
transform.only_reads_payload(effects)


def get_named_attribute(
target: ir.Value, attr_name: str | ir.StringAttr
) -> ir.Value[transform.AnyParamType]:
if not isinstance(attr_name, ir.StringAttr):
attr_name = ir.StringAttr.get(attr_name)
return GetNamedAttributeOp(target=target, attr_name=attr_name).param
53 changes: 53 additions & 0 deletions lighthouse/dialects/transform/transform_ext/ops/param_cmp_eq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from mlir import ir
from mlir.dialects import ext, transform

from ..dialect import TransformExtensionDialect


class ParamCmpEqOp(TransformExtensionDialect.Operation, name="param_cmp_eq"):
"""
Compare the values of the `lhs` and `rhs` parameters for equality.

The operation succeeds if the values are equal, and fails otherwise. If the
parameters resolve to multiple values, the operation succeeds if all values
are pairwise equal, and fails otherwise.
"""

lhs: ext.Operand[transform.AnyParamType]
rhs: ext.Operand[transform.AnyParamType]

@classmethod
def attach_interface_impls(cls, context=None):
cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)
cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)

class TransformOpInterfaceModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: "ParamCmpEqOp",
_rewriter: transform.TransformRewriter,
_results: transform.TransformResults,
state: transform.TransformState,
) -> transform.DiagnosedSilenceableFailure:
lhs_params = state.get_params(op.lhs)
rhs_params = state.get_params(op.rhs)
if len(lhs_params) != len(rhs_params):
return transform.DiagnosedSilenceableFailure.SilenceableFailure
for lhs_param, rhs_param in zip(lhs_params, rhs_params):
if lhs_param != rhs_param:
return transform.DiagnosedSilenceableFailure.SilenceableFailure
return transform.DiagnosedSilenceableFailure.Success

@staticmethod
def allow_repeated_handle_operands(_op: "ParamCmpEqOp") -> bool:
return True

class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
@staticmethod
def get_effects(op: "ParamCmpEqOp", effects):
transform.only_reads_handle(op.op_operands, effects)
transform.only_reads_payload(effects)


def param_cmp_eq(lhs: ir.Value, rhs: ir.Value):
return ParamCmpEqOp(lhs=lhs, rhs=rhs)
132 changes: 132 additions & 0 deletions lighthouse/dialects/transform/transform_ext/ops/replace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import Sequence, Optional

from mlir import ir
from mlir.dialects import ext, transform
from mlir.dialects.transform import DiagnosedSilenceableFailure

from ..dialect import TransformExtensionDialect


class ReplaceOp(TransformExtensionDialect.Operation, name="replace"):
"""Replace the `target` operation(s) with a new `op_kind` operation.

If `new_operands` are provided, they are used as operands for the new
operation(s); otherwise, the operands of the `target` operation(s) are
reused. The new op's result types are the same as those of the `target` op.

NB: This op is mostly an escape hatch for testing and prototyping purposes.
No attempt is made to guarantee that the rewrite is semantics perserving.
"""

new_op: ext.Result[transform.AnyOpType[()]]
target: ext.Operand[transform.AnyOpType]
op_kind: ir.StringAttr
new_operands: Sequence[ext.Operand[transform.AnyValueType]]

@classmethod
def attach_interface_impls(cls, ctx=None):
cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)

class TransformOpInterfaceModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: "ReplaceOp",
rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
target_ops = state.get_payload_ops(op.target)

# Resolve optional operand handles to payload values.
operand_values_per_handle = []
for operand_handle in op.new_operands:
operand_values_per_handle.append(
state.get_payload_values(operand_handle)
)
assert len(operand_values_per_handle[-1]) == len(target_ops), (
"Expected number of operand values to match number of target ops"
)

new_op_name = op.op_kind.value
new_op_attrs = {}
if "new_attrs" in op.attributes:
new_attrs = op.attributes["new_attrs"]
assert isinstance(new_attrs, ir.DictAttr), (
"Expected new_attrs to be a dictionary attribute"
)
for named_attr in new_attrs:
new_op_attrs[named_attr.name] = named_attr.attr

new_ops = []
for target_idx, target_op in enumerate(target_ops):
if "new_result_types" in op.attributes:
tuple_type = op.attributes["new_result_types"].value
assert isinstance(tuple_type, ir.TupleType), (
"Expected new_result_types to be a tuple of types"
)
assert tuple_type.num_types == len(target_op.results), (
"Expected number of new result types to match number of target op results"
)

new_result_types = [
tuple_type.get_type(i) for i in range(tuple_type.num_types)
]
else:
new_result_types = [ty.type for ty in target_op.results]

if operand_values_per_handle:
new_operands = [
vals[target_idx] for vals in operand_values_per_handle
]
else:
new_operands = list(target_op.operands)

with ir.InsertionPoint(target_op):
new_operation = ir.Operation.create(
new_op_name,
results=new_result_types,
operands=new_operands,
attributes=new_op_attrs,
)
rewriter.replace_op(target_op, new_operation)
new_ops.append(new_operation)

results.set_ops(op.new_op, new_ops)
return DiagnosedSilenceableFailure.Success

@staticmethod
def allow_repeated_handle_operands(_op: "ReplaceOp") -> bool:
return False

class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
@staticmethod
def get_effects(op: ir.Operation, effects):
transform.consumes_handle(op.op_operands[:1], effects)
if new_operands_handles := op.op_operands[1:]:
transform.only_reads_handle(new_operands_handles, effects)
transform.produces_handle(op.results, effects)
transform.modifies_payload(effects)


def replace(
target: ir.Value,
op_kind: str | ir.StringAttr,
*new_operands: ir.Value,
new_result_types: Optional[ir.TupleType | Sequence[ir.Type]] = None,
new_attrs=None,
) -> ir.Value:
if not isinstance(op_kind, ir.StringAttr):
op_kind = ir.StringAttr.get(op_kind)
new_op = ReplaceOp(target, op_kind=op_kind, new_operands=new_operands).new_op
if new_result_types:
if not isinstance(new_result_types, ir.TupleType):
new_result_types = ir.TupleType.get_tuple(new_result_types)
new_op.owner.attributes["new_result_types"] = ir.TypeAttr.get(new_result_types)
if new_attrs:
if isinstance(new_attrs, dict):
new_attrs = ir.DictAttr.get(new_attrs)
else:
assert isinstance(new_attrs, ir.DictAttr)
new_op.owner.attributes["new_attrs"] = new_attrs
return new_op
Loading
Loading