
Commit ea13d00

addressing review comments
1 parent 1d4f137 commit ea13d00

7 files changed

Lines changed: 58 additions & 82 deletions


examples/feed-forward-mpi/README.md

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ This example shows how MLIR's sharding infrastructure can be used to distribute
 
 Currently, only the lower part of the sharding pipeline is used: `shard-partition`, `convert-shard-to-mpi`, and lowering to LLVM. Therefore, the ingress MLIR is fully annotated.
 
-The example implements a single feed-fowrad layer, following a 1D/2D weight-stationary partition strategy as described in figures 2a and 2b of https://arxiv.org/pdf/2211.05102.
+The example implements a single feed-forward layer, following a 1D/2D weight-stationary partition strategy as described in figures 2a and 2b of https://arxiv.org/pdf/2211.05102.
 
 ## Prerequisites

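For orientation, the lower pipeline named in the README maps onto a transform schedule built with the helpers this commit introduces. A minimal sketch follows; anchoring `convert-shard-to-mpi` on `func.func` (rather than the enclosing module) and the bare context setup are assumptions, not the example's exact driver code:

from mlir import ir
from mlir.dialects import transform
from lighthouse.pipeline.helper import apply_registered_pass, match
from lighthouse.schedule.utils import schedule_boilerplate

# Sketch only: pass anchoring and context setup are assumptions.
with ir.Context(), ir.Location.unknown():
    with schedule_boilerplate() as (schedule, named_sequence):
        func = match(named_sequence.bodyTarget, ops={"func.func"})
        func = apply_registered_pass(func, "shard-partition")
        func = apply_registered_pass(func, "convert-shard-to-mpi")
        transform.YieldOp()
    print(schedule)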
examples/feed-forward-mpi/feed-forward-mpi.py

Lines changed: 21 additions & 51 deletions
@@ -32,6 +32,7 @@
 )
 from lighthouse.pipeline.helper import apply_registered_pass, match
 from lighthouse.workload import Workload, benchmark, get_bench_wrapper_schedule
+from lighthouse.schedule.utils import schedule_boilerplate
 from lighthouse.schedule.x86 import tile_and_vector_matmul
 from ff_weight_stationary import generate_ff_payload
 
@@ -290,24 +291,8 @@ def find_factors(n):
 
         return mod
 
-    def schedule_modules(
-        self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
-    ) -> list[ir.Module]:
-        """Generate two schedules: one that deals with sharding propagation, partition, and MPI.
-        Another one for all the rest"""
-        pre_schedule = ir.Module.create()
-        pre_schedule.operation.attributes["transform.with_named_sequence"] = (
-            ir.UnitAttr.get()
-        )
-        with ir.InsertionPoint(pre_schedule.body):
-            named_sequence = transform.named_sequence(
-                "__transform_pre",
-                [transform.AnyOpType.get()],
-                [],
-                arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
-            )
-        with ir.InsertionPoint(named_sequence.body):
-            anytype = transform.AnyOpType.get()
+    def get_shard_schedule(self):
+        with schedule_boilerplate() as (schedule, named_sequence):
             func = match(named_sequence.bodyTarget, ops={"func.func"})
             func = apply_registered_pass(
                 func,
@@ -317,7 +302,6 @@ def schedule_modules(
             if self.verbose > 0:
                 transform.PrintOp(target=func)
             func = apply_registered_pass(func, "shard-partition")
-            func = apply_registered_pass(func, "canonicalize")
             if self.verbose > 0:
                 transform.PrintOp(target=func)
             func = apply_registered_pass(func, "shard-simplify")
@@ -329,36 +313,17 @@
                 transform.PrintOp(target=func)
             func = apply_registered_pass(func, "tosa-to-linalg")
             transform.YieldOp()
-        func = None
-
-        bench_schedule = get_bench_wrapper_schedule(self)
-
-        tile_schedule = tile_and_vector_matmul.create(self.tile_size)
+        return schedule
 
-        main_schedule = ir.Module.create()
-        main_schedule.operation.attributes["transform.with_named_sequence"] = (
-            ir.UnitAttr.get()
-        )
-        with ir.InsertionPoint(main_schedule.body):
-            named_sequence = transform.named_sequence(
-                "__transform_main",
-                [transform.AnyOpType.get()],
-                [],
-                arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
-            )
-        with ir.InsertionPoint(named_sequence.body):
+    def get_lower_schedule(self):
+        with schedule_boilerplate() as (schedule, named_sequence):
             anytype = transform.AnyOpType.get()
             func = match(named_sequence.bodyTarget, ops={"func.func"})
             mod = transform.get_parent_op(
                 anytype, func, op_name="builtin.module", deduplicate=True
             )
             mod = apply_registered_pass(mod, "linalg-generalize-named-ops")
-            mod = apply_registered_pass(mod, "canonicalize")
             mod = apply_registered_pass(mod, "linalg-fuse-elementwise-ops")
-            mod = apply_registered_pass(mod, "arith-expand")
-            mod = apply_registered_pass(mod, "memref-expand")
-            mod = apply_registered_pass(mod, "empty-tensor-to-alloc-tensor")
-            mod = apply_registered_pass(mod, "canonicalize")
             identity_layout = LayoutMapOption.IdentityLayoutMap
             mod = OneShotBufferizeOp(
                 mod,
@@ -371,27 +336,18 @@
                 "drop-equivalent-buffer-results",
                 options={"modify-public-functions": True},
             )
-            mod = apply_registered_pass(mod, "expand-realloc")
-            mod = apply_registered_pass(mod, "canonicalize")
             mod = apply_registered_pass(mod, "buffer-deallocation-simplification")
             mod = apply_registered_pass(mod, "bufferization-lower-deallocations")
             mod = apply_registered_pass(mod, "cse")
             mod = apply_registered_pass(mod, "canonicalize")
-            mod = apply_registered_pass(mod, "convert-bufferization-to-memref")
             mod = apply_registered_pass(mod, "convert-linalg-to-parallel-loops")
             mod = apply_registered_pass(mod, "scf-parallel-loop-fusion")
             mod = apply_registered_pass(mod, "canonicalize")
-            mod = apply_registered_pass(mod, "fold-memref-alias-ops")
             mod = apply_registered_pass(mod, "expand-strided-metadata")
-            mod = apply_registered_pass(mod, "convert-math-to-funcs")
             mod = apply_registered_pass(mod, "lower-affine")
             mod = apply_registered_pass(mod, "convert-vector-to-scf")
             mod = apply_registered_pass(mod, "convert-scf-to-cf")
             mod = apply_registered_pass(mod, "symbol-dce")
-            mod = apply_registered_pass(mod, "finalize-memref-to-llvm")
-            mod = apply_registered_pass(mod, "convert-math-to-llvm")
-            mod = apply_registered_pass(mod, "convert-math-to-libm")
-            mod = apply_registered_pass(mod, "convert-func-to-llvm")
             mod = apply_registered_pass(mod, "convert-vector-to-llvm")
             mod = apply_registered_pass(mod, "canonicalize")
             mod = apply_registered_pass(mod, "convert-to-llvm")
@@ -400,8 +356,22 @@
             if self.verbose > 1:
                 transform.PrintOp(target=mod)
             transform.YieldOp()
+        return schedule
 
-        return [pre_schedule, bench_schedule, tile_schedule, main_schedule]
+    def schedule_modules(
+        self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
+    ) -> list[ir.Module]:
+        """Generate schedules:
+        - sharding propagation, partition, and MPI, tosa-to-linalg
+        - adding benchmark wrapper
+        - tile_and_vector
+        - all the rest"""
+        return [
+            self.get_shard_schedule(),
+            get_bench_wrapper_schedule(self),
+            tile_and_vector_matmul.create(self.tile_size),
+            self.get_lower_schedule(),
+        ]
 
 
 if __name__ == "__main__":

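The net effect in this file: the monolithic schedule_modules is split into get_shard_schedule and get_lower_schedule, the shared module and named-sequence setup moves into the new schedule_boilerplate helper, and schedule_modules reduces to composing the four schedules listed in its docstring.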
examples/feed-forward-mpi/ff_weight_stationary.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def _(a, b, c, r):
         res = bufferization.materialize_in_destination(
             t_mk,
             sd_res,
-            sd_r,  # , restrict=True, writable=True
+            sd_r,
         )
         return shard.shard(res, sh_act, annotate_for_users=True)
 

lighthouse/schedule/pattern_schedule.py

Lines changed: 1 addition & 17 deletions
@@ -1,7 +1,6 @@
-from contextlib import contextmanager
 from mlir import rewrite, ir
 from mlir.dialects import ext, transform
-from mlir.dialects.transform import AnyOpType
+from lighthouse.schedule.utils import schedule_boilerplate
 
 
 @ext.register_dialect
@@ -35,21 +34,6 @@ def populate_patterns(
     return RewritePattern
 
 
-@contextmanager
-def schedule_boilerplate():
-    schedule = ir.Module.create()
-    schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
-    with ir.InsertionPoint(schedule.body):
-        named_sequence = transform.NamedSequenceOp(
-            "__transform_main",
-            [AnyOpType.get()],
-            [AnyOpType.get()],
-            arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
-        )
-    with ir.InsertionPoint(named_sequence.body):
-        yield schedule, named_sequence
-
-
 def pattern_rewrite_schedule(patterns: dict, pname: str = "rewrite_pattern"):
     """Return a transform module that applies the given rewrite patterns.
     patterns: dict mapping op names to match-and-rewrite functions.

lighthouse/schedule/utils.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from contextlib import contextmanager
+from mlir import ir
+from mlir.dialects import transform
+
+
+@contextmanager
+def schedule_boilerplate():
+    schedule = ir.Module.create()
+    schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+    with ir.InsertionPoint(schedule.body):
+        named_sequence = transform.NamedSequenceOp(
+            "__transform_main",
+            [transform.AnyOpType.get()],
+            [transform.AnyOpType.get()],
+            arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
+        )
+    with ir.InsertionPoint(named_sequence.body):
+        yield schedule, named_sequence

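The helper yields with the insertion point already inside the named sequence, so callers only emit the body ops (see get_shard_schedule above). Applying a finished schedule to a payload module could then look like the following sketch, assuming the upstream mlir.dialects.transform.interpreter binding:

from mlir import ir
from mlir.dialects.transform import interpreter


def apply_schedule(payload: ir.Module, schedule: ir.Module) -> None:
    # Run the schedule's "__transform_main" named sequence over the payload;
    # the sequence is the first op in the schedule module's body.
    interpreter.apply_named_sequence(payload, schedule.body.operations[0], schedule)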
lighthouse/workload/runner.py

Lines changed: 7 additions & 11 deletions
@@ -66,21 +66,17 @@ def execute(
         raise ValueError("Benchmark verification failed.")
 
 
-def bench_wrapper_pattern(funcname: str, get_bench_name=None):
+def bench_wrapper_pattern(funcname: str, benchname=None):
     """Returns a rewrite pattern that matches a function named `funcname` and clones it
-    as a new function with name given by `get_bench_name(funcname)` (default: "bench_" + funcname).
+    as a new function with name given by `benchname` (default: "bench_" + funcname).
     The new function is a benchmark wrapper that calls the payload function and times it.
     Every function call is timed separately. Returns the times (seconds) in a memref,
     which is passed as an additional argument to the benchmark function.
     It also takes two additional arguments for the number of runs and warmup iterations.
     """
-    marker = "__wrapped__"
-    if get_bench_name is None:
-
-        def default_bench_name(name):
-            return f"bench_{name}"
-
-        get_bench_name = default_bench_name
+    marker = "__bench_wrapped__"
+    if benchname is None:
+        benchname = f"bench_{funcname}"
 
     def match_and_rewrite(op, rewriter):
         if op.name.value != funcname:
@@ -100,7 +96,7 @@ def match_and_rewrite(op, rewriter):
         index_t = ir.IndexType.get()
         args = payload_arguments + [time_memref_t, index_t, index_t]
 
-        @func_cif(*args, name=get_bench_name(funcname))
+        @func_cif(*args, name=benchname)
         def bench(*args):
             index_t = ir.IndexType.get()
             zero = arith.constant(index_t, 0)
@@ -129,7 +125,7 @@ def get_bench_wrapper_schedule(workload: Workload):
         {
             "func.func": bench_wrapper_pattern(
                 workload.payload_function_name,
-                lambda name: workload.benchmark_function_name,
+                workload.benchmark_function_name,
             )
         },
         "add_bench_pattern",

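With the callback replaced by a plain string, call sites simplify accordingly; a sketch with illustrative names, assuming bench_wrapper_pattern is imported from this module:

from lighthouse.workload.runner import bench_wrapper_pattern

# Explicit wrapper name:
pattern = bench_wrapper_pattern("payload", benchname="my_bench")
# Default naming, now derived directly from the payload name ("bench_payload"):
pattern = bench_wrapper_pattern("payload")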
lighthouse/workload/workload.py

Lines changed: 9 additions & 1 deletion
@@ -68,7 +68,15 @@ def lower_payload(
         schedule_modules = self.schedule_modules(
             stop_at_stage=dump_payload, parameters=schedule_parameters
         )
-        assert isinstance(schedule_modules, list)
+        if not isinstance(schedule_modules, list):
+            raise TypeError(
+                f"schedule_modules() must return a list of ir.Module instances, "
+                f"got {type(schedule_modules).__name__}"
+            )
+        if not schedule_modules:
+            raise ValueError(
+                "schedule_modules() must return at least one schedule module."
+            )
         if not dump_payload or dump_payload != "initial":
             for schedule_module in schedule_modules:
                 # apply schedule on payload module

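For Workload subclasses, the new checks mean schedule_modules must return a non-empty list of modules. A minimal conforming override, sketched with the method names from the feed-forward example above (the class name is illustrative, and a real subclass implements the rest of the Workload interface):

from typing import Optional

from mlir import ir

from lighthouse.workload import Workload


class FeedForwardLike(Workload):
    def schedule_modules(
        self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
    ) -> list[ir.Module]:
        # Any non-empty list of transform schedule modules passes the checks.
        return [self.get_shard_schedule(), self.get_lower_schedule()]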