Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/cpu/x86/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from mlir.dialects.transform import vector
from mlir.dialects.transform import tensor

from lighthouse.workload import benchmark
from lighthouse.workload import get_bench_wrapper_schedule
from lighthouse import dialects as lh_dialects
from lighthouse.workload import benchmark, get_bench_wrapper_schedule
from lighthouse.utils.numpy import numpy_to_mlir_type
from lighthouse.pipeline.helper import apply_registered_pass
import lighthouse.utils as lh_utils
Expand Down Expand Up @@ -358,6 +358,8 @@ def parse_cli():
args = parse_cli()

with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

match args.dtype:
case "f32":
in_dtype = np.float32
Expand Down
3 changes: 3 additions & 0 deletions examples/feed-forward-mpi/feed-forward-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
as_ctype,
)

from lighthouse import dialects as lh_dialects
from lighthouse.utils.memref import (
to_ctype as memref_to_ctype,
deallocate_memrefs_on_exit,
Expand Down Expand Up @@ -412,6 +413,8 @@ def schedule_modules(
R = MPI.COMM_WORLD.Get_rank()

with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

wload = DistFF(args, P, R)

# execute(wload, verbose=args.verbose)
Expand Down
3 changes: 3 additions & 0 deletions examples/workload/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mlir.dialects import transform
from mlir.execution_engine import ExecutionEngine

from lighthouse import dialects as lh_dialects
from lighthouse.pipeline.helper import match
from lighthouse.pipeline.opt import PassBundles, apply_bundle

Expand Down Expand Up @@ -159,6 +160,8 @@ def schedule_modules(

if __name__ == "__main__":
with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

wload = ElementwiseSum(400, 400)

print(" Dump kernel ".center(60, "-"))
Expand Down
3 changes: 3 additions & 0 deletions examples/workload/example_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mlir.dialects import func, linalg, arith, memref
from mlir.execution_engine import ExecutionEngine

from lighthouse import dialects as lh_dialects
from lighthouse.workload import execute, benchmark
import lighthouse.utils as lh_utils

Expand Down Expand Up @@ -195,6 +196,8 @@ def payload_module(self):

if __name__ == "__main__":
with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

wload = ElementwiseSumMLIRAlloc(400, 400)

print(" Dump kernel ".center(60, "-"))
Expand Down
3 changes: 3 additions & 0 deletions examples/xegpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mlir import ir
from mlir.execution_engine import ExecutionEngine

from lighthouse import dialects as lh_dialects
from lighthouse.workload import benchmark, get_bench_wrapper_schedule
from lighthouse.utils.memref import to_ctype as memref_to_ctype
from lighthouse.utils.numpy import numpy_to_ctype
Expand Down Expand Up @@ -360,6 +361,8 @@ def parse_cli():
c_type = "f32"

with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

wload = XeGPUMatMul(
M=M,
N=N,
Expand Down
3 changes: 3 additions & 0 deletions examples/xegpu/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from mlir import ir
from mlir.execution_engine import ExecutionEngine

from lighthouse import dialects as lh_dialects
from lighthouse.workload import benchmark, get_bench_wrapper_schedule
from lighthouse.utils.memref import to_ctype as memref_to_ctype
from lighthouse.utils.numpy import numpy_to_ctype
Expand Down Expand Up @@ -375,6 +376,8 @@ def parse_cli():
identity_weights = args.check_result

with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

wload = XeGPUMLP(
batch_size=args.batch_size,
input_size=args.input_size,
Expand Down
4 changes: 4 additions & 0 deletions lighthouse/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def register_and_load():
    """Load all Lighthouse dialect extensions (currently only `transform_ext`).

    The import is deliberately local so the extension module is only pulled
    in when dialect loading is actually requested.
    """
    from . import transform_ext as _transform_ext

    _transform_ext.register_and_load()
129 changes: 129 additions & 0 deletions lighthouse/dialects/transform_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from mlir import ir
from mlir.dialects import ext, transform, func, arith, scf, memref
from mlir.dialects.transform import DiagnosedSilenceableFailure

from lighthouse.utils.mlir import func_cif


def register_and_load(context=None):
    """Load the Lighthouse transform extension dialect.

    NOTE(review): `context` is accepted but never used — the dialect is
    loaded via `TransformExtensionDialect.load()` with no arguments, i.e.
    into the implicit/default context. Confirm whether `context` should be
    forwarded to `load`.
    """
    TransformExtensionDialect.load()


class TransformExtensionDialect(ext.Dialect, name="transform_ext"):
    """Dialect container for Lighthouse's custom transform extension ops."""

    @classmethod
    def load(cls, *args, **kwargs):
        """Load the dialect, then attach interface models to every operation
        class that provides an ``attach_interface_impls`` hook."""
        super().load(*args, **kwargs)
        for operation_cls in cls.operations:
            attach = getattr(operation_cls, "attach_interface_impls", None)
            if attach is not None:
                attach()


class WrapInBenchingFuncOp(
    TransformExtensionDialect.Operation, name="wrap_in_benching_func"
):
    """Create a function that calls `target` function in a benchmarking loop.

    The new function has the same arguments as `target` plus three additional ones:
    - A memref to store the timing results (one element per iteration).
    - The number of timed iterations.
    - The number of warmup iterations.
    """

    # Handle to the payload function op(s) to wrap.
    target: ext.Operand[transform.AnyOpType]
    # Handle to the generated benchmark function(s).
    bench_func: ext.Result[transform.AnyOpType[()]]

    @classmethod
    def attach_interface_impls(cls, context=None):
        # Register the Python-defined interface models for this op so the
        # transform interpreter can apply it and query its memory effects.
        cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)
        cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)

    @staticmethod
    def wrap_in_benching_func(target: func.FuncOp, bench_name: str):
        """Create a function that calls `target` in a benchmarking loop.

        Each call to `target` is timed separately, and the times (in seconds)
        are stored in a memref that is passed as an additional argument to the
        benchmark function. It also takes two additional arguments for the
        number of runs and warmup iterations.
        """

        # define rtclock function
        # NOTE(review): this declares `rtclock` unconditionally on every call;
        # if invoked more than once at the same insertion point (multiple
        # targets), duplicate declarations may be emitted — confirm symbol
        # uniqueness is handled upstream.
        f64_t = ir.F64Type.get()
        func.FuncOp("rtclock", ((), (f64_t,)), visibility="private")
        # emit benchmark function
        time_memref_t = ir.MemRefType.get((ir.ShapedType.get_dynamic_size(),), f64_t)
        index_t = ir.IndexType.get()
        # Signature = target's inputs + (times memref, num runs, num warmup).
        args = target.type.inputs + [time_memref_t, index_t, index_t]

        @func_cif(*args, name=bench_name)
        def bench(*args):
            zero = arith.constant(index_t, 0)
            one = arith.constant(index_t, 1)
            # Forward only the original target arguments to the callee.
            func_args = list(args[: len(target.type.inputs)])
            times_memref, num_times, num_warmup = args[-3:]
            # Untimed warmup loop.
            for i in scf.for_(zero, num_warmup, one):
                # FIXME(upstream): func.call needs to wrap _overridden_ CallOp.
                func.CallOp(target, func_args)
                scf.yield_(())
            # TODO: get `num_times` from the `times_memref`.
            # Timed loop: one wall-clock sample per call, stored at index i.
            for i in scf.for_(zero, num_times, one):
                tic = func.call((f64_t,), "rtclock", ())
                func.CallOp(target, func_args)
                toc = func.call((f64_t,), "rtclock", ())
                time = arith.subf(toc, tic)
                memref.store(time, times_memref, [i])
                scf.yield_(())

        return bench.func_op

    class TransformOpInterfaceModel(transform.TransformOpInterface):
        """Defines how the transform interpreter applies this op."""

        @staticmethod
        def apply(
            op: "WrapInBenchingFuncOp",
            _rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            targets = state.get_payload_ops(op.target)
            # An explicit bench name is only valid for a single target,
            # since all generated functions would share that one symbol.
            if bench_name_attr := op.attributes.get("bench_name"):
                bench_name = bench_name_attr.value
                if len(targets) != 1:
                    return DiagnosedSilenceableFailure.SilenceableFailure
            else:
                bench_name = None

            bench_funcs = []
            for target in targets:
                # Only `func.func` payload ops can be wrapped.
                if not isinstance(target, func.FuncOp):
                    return DiagnosedSilenceableFailure.SilenceableFailure

                # Insert the benchmark function immediately before the target.
                with ir.InsertionPoint(target), target.location:
                    bench_func = WrapInBenchingFuncOp.wrap_in_benching_func(
                        target, bench_name or f"bench_{target.name.value}"
                    )
                bench_funcs.append(bench_func)

            results.set_ops(op.bench_func, bench_funcs)

            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: "WrapInBenchingFuncOp") -> bool:
            # Each payload function should be wrapped at most once.
            return False

    class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
        """Declares the op's effects on its handles and the payload IR."""

        @staticmethod
        def get_effects(op: "WrapInBenchingFuncOp", effects):
            transform.only_reads_handle(op.op_operands, effects)
            transform.produces_handle(op.results, effects)
            transform.modifies_payload(effects)


def wrap_in_benching_func(
    target: ir.Value[transform.AnyOpType], bench_name: str | None = None
) -> ir.Value[transform.AnyOpType]:
    """Build a `wrap_in_benching_func` op and return its result handle.

    When `bench_name` is given, it is recorded as a string attribute on the
    op; otherwise the benchmark function name is derived from the target.
    """
    wrap_op = WrapInBenchingFuncOp(target=target)
    if bench_name is None:
        return wrap_op.bench_func
    wrap_op.attributes["bench_name"] = ir.StringAttr.get(bench_name)
    return wrap_op.bench_func
54 changes: 0 additions & 54 deletions lighthouse/schedule/pattern_schedule.py

This file was deleted.

2 changes: 0 additions & 2 deletions lighthouse/workload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from .runner import (
execute,
benchmark,
bench_wrapper_pattern,
get_bench_wrapper_schedule,
)

__all__ = [
"Workload",
"bench_wrapper_pattern",
"benchmark",
"execute",
"get_bench_wrapper_schedule",
Expand Down
Loading
Loading