3232)
3333from lighthouse .pipeline .helper import apply_registered_pass , match
3434from lighthouse .workload import Workload , benchmark , get_bench_wrapper_schedule
35+ from lighthouse .schedule .utils import schedule_boilerplate
3536from lighthouse .schedule .x86 import tile_and_vector_matmul
3637from ff_weight_stationary import generate_ff_payload
3738
@@ -290,24 +291,8 @@ def find_factors(n):
290291
291292 return mod
292293
293- def schedule_modules (
294- self , stop_at_stage : Optional [str ] = None , parameters : Optional [dict ] = None
295- ) -> list [ir .Module ]:
296- """Generate two schedules: one that deals with sharding propagation, partition, and MPI.
297- Another one for all the rest"""
298- pre_schedule = ir .Module .create ()
299- pre_schedule .operation .attributes ["transform.with_named_sequence" ] = (
300- ir .UnitAttr .get ()
301- )
302- with ir .InsertionPoint (pre_schedule .body ):
303- named_sequence = transform .named_sequence (
304- "__transform_pre" ,
305- [transform .AnyOpType .get ()],
306- [],
307- arg_attrs = [{"transform.readonly" : ir .UnitAttr .get ()}],
308- )
309- with ir .InsertionPoint (named_sequence .body ):
310- anytype = transform .AnyOpType .get ()
294+ def get_shard_schedule (self ):
295+ with schedule_boilerplate () as (schedule , named_sequence ):
311296 func = match (named_sequence .bodyTarget , ops = {"func.func" })
312297 func = apply_registered_pass (
313298 func ,
@@ -317,7 +302,6 @@ def schedule_modules(
317302 if self .verbose > 0 :
318303 transform .PrintOp (target = func )
319304 func = apply_registered_pass (func , "shard-partition" )
320- func = apply_registered_pass (func , "canonicalize" )
321305 if self .verbose > 0 :
322306 transform .PrintOp (target = func )
323307 func = apply_registered_pass (func , "shard-simplify" )
@@ -329,36 +313,17 @@ def schedule_modules(
329313 transform .PrintOp (target = func )
330314 func = apply_registered_pass (func , "tosa-to-linalg" )
331315 transform .YieldOp ()
332- func = None
333-
334- bench_schedule = get_bench_wrapper_schedule (self )
335-
336- tile_schedule = tile_and_vector_matmul .create (self .tile_size )
316+ return schedule
337317
338- main_schedule = ir .Module .create ()
339- main_schedule .operation .attributes ["transform.with_named_sequence" ] = (
340- ir .UnitAttr .get ()
341- )
342- with ir .InsertionPoint (main_schedule .body ):
343- named_sequence = transform .named_sequence (
344- "__transform_main" ,
345- [transform .AnyOpType .get ()],
346- [],
347- arg_attrs = [{"transform.readonly" : ir .UnitAttr .get ()}],
348- )
349- with ir .InsertionPoint (named_sequence .body ):
318+ def get_lower_schedule (self ):
319+ with schedule_boilerplate () as (schedule , named_sequence ):
350320 anytype = transform .AnyOpType .get ()
351321 func = match (named_sequence .bodyTarget , ops = {"func.func" })
352322 mod = transform .get_parent_op (
353323 anytype , func , op_name = "builtin.module" , deduplicate = True
354324 )
355325 mod = apply_registered_pass (mod , "linalg-generalize-named-ops" )
356- mod = apply_registered_pass (mod , "canonicalize" )
357326 mod = apply_registered_pass (mod , "linalg-fuse-elementwise-ops" )
358- mod = apply_registered_pass (mod , "arith-expand" )
359- mod = apply_registered_pass (mod , "memref-expand" )
360- mod = apply_registered_pass (mod , "empty-tensor-to-alloc-tensor" )
361- mod = apply_registered_pass (mod , "canonicalize" )
362327 identity_layout = LayoutMapOption .IdentityLayoutMap
363328 mod = OneShotBufferizeOp (
364329 mod ,
@@ -371,27 +336,18 @@ def schedule_modules(
371336 "drop-equivalent-buffer-results" ,
372337 options = {"modify-public-functions" : True },
373338 )
374- mod = apply_registered_pass (mod , "expand-realloc" )
375- mod = apply_registered_pass (mod , "canonicalize" )
376339 mod = apply_registered_pass (mod , "buffer-deallocation-simplification" )
377340 mod = apply_registered_pass (mod , "bufferization-lower-deallocations" )
378341 mod = apply_registered_pass (mod , "cse" )
379342 mod = apply_registered_pass (mod , "canonicalize" )
380- mod = apply_registered_pass (mod , "convert-bufferization-to-memref" )
381343 mod = apply_registered_pass (mod , "convert-linalg-to-parallel-loops" )
382344 mod = apply_registered_pass (mod , "scf-parallel-loop-fusion" )
383345 mod = apply_registered_pass (mod , "canonicalize" )
384- mod = apply_registered_pass (mod , "fold-memref-alias-ops" )
385346 mod = apply_registered_pass (mod , "expand-strided-metadata" )
386- mod = apply_registered_pass (mod , "convert-math-to-funcs" )
387347 mod = apply_registered_pass (mod , "lower-affine" )
388348 mod = apply_registered_pass (mod , "convert-vector-to-scf" )
389349 mod = apply_registered_pass (mod , "convert-scf-to-cf" )
390350 mod = apply_registered_pass (mod , "symbol-dce" )
391- mod = apply_registered_pass (mod , "finalize-memref-to-llvm" )
392- mod = apply_registered_pass (mod , "convert-math-to-llvm" )
393- mod = apply_registered_pass (mod , "convert-math-to-libm" )
394- mod = apply_registered_pass (mod , "convert-func-to-llvm" )
395351 mod = apply_registered_pass (mod , "convert-vector-to-llvm" )
396352 mod = apply_registered_pass (mod , "canonicalize" )
397353 mod = apply_registered_pass (mod , "convert-to-llvm" )
@@ -400,8 +356,22 @@ def schedule_modules(
400356 if self .verbose > 1 :
401357 transform .PrintOp (target = mod )
402358 transform .YieldOp ()
359+ return schedule
403360
404- return [pre_schedule , bench_schedule , tile_schedule , main_schedule ]
361+ def schedule_modules (
362+ self , stop_at_stage : Optional [str ] = None , parameters : Optional [dict ] = None
363+ ) -> list [ir .Module ]:
364+ """Generate schedules:
365+ - sharding propagation, partition, and MPI, tosa-to-linalg
366+ - adding benchmark wrapper
367+ - tile_and_vector
368+ - all the rest"""
369+ return [
370+ self .get_shard_schedule (),
371+ get_bench_wrapper_schedule (self ),
372+ tile_and_vector_matmul .create (self .tile_size ),
373+ self .get_lower_schedule (),
374+ ]
405375
406376
407377if __name__ == "__main__" :
0 commit comments