66
77from typing import Any , Callable , Iterable
88
9+ import numpy as np
910import sympy
1011
1112from accelforge .frontend .mapping import Nested , TilePattern
@@ -92,6 +93,18 @@ def reduce_precision(data: pd.DataFrame) -> pd.DataFrame:
9293
9394
9495class PmappingDataframe :
96+ def _assert_invariants_before_and_after (f ):
97+ def wrapped (self , * args , ** kwargs ):
98+ self ._assert_reservation_includes_live_tensors ()
99+ self ._assert_consistent_left_right_reservations ()
100+ self ._assert_reservation_inclusivity ()
101+ result = f (self , * args , ** kwargs )
102+ self ._assert_reservation_includes_live_tensors ()
103+ self ._assert_consistent_left_right_reservations ()
104+ self ._assert_reservation_inclusivity ()
105+ return result
106+ return wrapped
107+
95108 def __init__ (
96109 self ,
97110 data : pd .DataFrame ,
@@ -230,6 +243,7 @@ def clear_fused_loop_symbols(self):
230243 self .make_pareto ()
231244
232245 @error_check_wrapper
246+ @_assert_invariants_before_and_after
233247 def free_to_loop_index (self , loop_index : int ) -> bool :
234248 """
235249 A B
@@ -251,8 +265,6 @@ def free_to_loop_index(self, loop_index: int) -> bool:
251265 if loop_index < - 1 :
252266 raise ValueError ("loop_index must be >= -1" )
253267
254- self ._check_reservation_inclusivity ()
255-
256268 # We keep reservations under loop_index, which is index loop_index+1
257269 reservation_max_index = loop_index + 1
258270
@@ -263,7 +275,6 @@ def free_to_loop_index(self, loop_index: int) -> bool:
263275
264276 if not self ._has_left_reservations ():
265277 self ._move_reservations_to_index (reservation_max_index )
266- self ._check_reservation_inclusivity ()
267278 return True
268279
269280 updated = False
@@ -276,7 +287,6 @@ def free_to_loop_index(self, loop_index: int) -> bool:
276287 assert self ._has_bottom_right ()
277288 assert self .get_max_loop_index () == reservation_max_index
278289
279- self ._check_reservation_inclusivity ()
280290 return updated
281291
282292 def get_max_loop_index (self ):
@@ -307,6 +317,7 @@ def get_min_loop_index(self):
307317 )
308318
309319 @error_check_wrapper
320+ @_assert_invariants_before_and_after
310321 def merge_next (
311322 self ,
312323 right : "PmappingDataframe" ,
@@ -343,12 +354,10 @@ def merge_next(
343354
344355 right .free_to_loop_index (compatibility_right .n_loops - 1 )
345356 right .check_live_reservations (compatibility_right .tensors & compatibility_joined .tensors )
346- right ._check_reservation_inclusivity ()
347357
348358 assert not right ._has_left_reservations ()
349359 assert self .get_max_loop_index ()- 1 <= compatibility_left .n_loops , f"max reservation index is { self .get_max_loop_index ()} and compatibility has { compatibility_left .n_loops } loops"
350360 assert right .get_max_loop_index ()- 1 <= compatibility_right .n_loops
351- self .check_consistent_left_right_reservation ()
352361
353362 self ._remove_dead_reservations (compatibility_joined )
354363
@@ -357,7 +366,6 @@ def merge_next(
357366 self ._shift_bottom_reservation_left ()
358367
359368 self .check_live_reservations (compatibility_joined .tensors & compatibility_left .tensors )
360- self ._check_reservation_inclusivity ()
361369
362370 shared_tensor_names = (
363371 compatibility_left .tensor_names & compatibility_right .tensor_names
@@ -454,7 +462,7 @@ def check_match(la: Loop, lb: Loop, param: str):
454462 if not (live_col in df and right_live_col in df ):
455463 continue
456464 adjustment_val = (
457- max (df [live_col ], df [right_live_col ]) # what it should be
465+ np . maximum (df [live_col ], df [right_live_col ]) # what it should be
458466 -
459467 (df [live_col ] + df [right_live_col ]) # what the following logic otherwise would count
460468 )
@@ -570,16 +578,13 @@ def iter_reservations(reservations_dict):
570578
571579 result ._remove_dead_reservations (compatibility_joined )
572580
573- result ._check_reservation_inclusivity ()
574581 result .check_live_reservations (compatibility_joined .tensors )
575582
576583 if CHECK_CORRECTNESS :
577584 result .check_above_subset_below (live_tensors )
578- result .check_reservations (live_tensors )
579585
580586 result .free_to_loop_index (next_shared_loop_index )
581587
582- result .check_consistent_left_right_reservation ()
583588 assert result ._has_bottom_right ()
584589
585590 if not CHECK_CORRECTNESS :
@@ -784,42 +789,6 @@ def _get_reservation_or_parent(
784789 level -= 1
785790 return None
786791
787- # def _free_reservations_of_resource(
788- # self,
789- # resource: str,
790- # tensor_indices_to_free: Iterable[tuple[str, int]],
791- # ):
792- # """
793- # For every `(tensor, nloops)` in `tensor_indices_to_free`, reduces all
794- # right reservations of `resource` at `index >= nloops` and left reservations
795- # at `index > nloops` by the size of the tensor as recorded in
796- # `self.data[tensor2col(tensor)]`.
797- # """
798- # targets = defaultdict(int)
799- # for tensor_name, to_free_nloops in tensor_indices_to_free:
800- # size = self.data[tensor2col(tensor_name)]
801- # for col in get_reservation_cols_with(self.data, name=resource):
802- # key = col2reservation(col)
803- # if (
804- # (key.is_left and key.nloops > to_free_nloops)
805- # or
806- # (key.is_right and key.nloops >= to_free_nloops)
807- # ):
808- # targets[key.nloops, col] -= size
809-
810- # # Now apply the allocations. Sort so we go from top to bottom in case
811- # # there are maxes that propagate down.
812- # for (_, target), size in sorted(
813- # targets.items(), key=lambda x: x[0], reverse=True
814- # ):
815- # assert target in self.data
816- # add_to_col(self.data, target, size)
817- # # Assert all reservations are >= 0
818- # try:
819- # assert (self.data[target] >= 0).all(), f"Negative reservation: {target}"
820- # except:
821- # breakpoint()
822-
823792 @staticmethod
824793 def concat (
825794 paretos : list ["PmappingDataframe" ], skip_pareto : bool = False
@@ -850,6 +819,7 @@ def concat(
850819 )
851820 return p
852821
822+ @_assert_invariants_before_and_after
853823 def _shift_bottom_reservation_left (self ):
854824 """
855825 Shifts the bottom reservation from right to left.
@@ -964,7 +934,6 @@ def _shift_bottom_reservation_left(self):
964934 self ._data = pd .concat (all_data , ignore_index = True )
965935
966936 assert not self ._has_bottom_right ()
967- self ._check_reservation_inclusivity ()
968937
969938 def _move_reservations_to_index (self , loop_index ):
970939 dropcols = []
@@ -975,8 +944,8 @@ def _move_reservations_to_index(self, loop_index):
975944 target = reservation2col (key .name , loop_index )
976945 max_to_col (self .data , target , col )
977946 self .data .drop (columns = dropcols , inplace = True )
978- self ._check_reservation_inclusivity ()
979947
948+ @_assert_invariants_before_and_after
980949 def _consolidate_bottom_split (self ):
981950 """
982951 Consolidate bottom split.
@@ -989,7 +958,6 @@ def _consolidate_bottom_split(self):
989958 F E
990959 """
991960 assert not self ._has_bottom_right ()
992- self ._check_reservation_inclusivity ()
993961
994962 bottom_index = self .get_max_loop_index ()
995963 target_index = bottom_index - 1
@@ -1040,7 +1008,6 @@ def _consolidate_bottom_split(self):
10401008 drop_columns .append (f"Total<SEP>latency<SEP>{ bottom_index } <SEP>{ thread_i } " )
10411009 self .data .drop (columns = drop_columns , inplace = True )
10421010
1043- self ._check_reservation_inclusivity ()
10441011 assert self ._has_bottom_right ()
10451012
10461013 def _remove_dead_reservations (self , compatibility : Compatibility ):
@@ -1162,6 +1129,18 @@ def filter_rows(
11621129 def __len__ (self ) -> int :
11631130 return len (self ._data )
11641131
1132+ def check_live_reservations (self , tensors : oset [TensorReservation ]):
1133+ for tensor in tensors :
1134+ if len (list (get_live_reservation_cols_with (
1135+ self .data ,
1136+ name = tensor .resource_name ,
1137+ tensor = tensor .name
1138+ ))) == 0 :
1139+ colnames = ""
1140+ for c in self .data .columns :
1141+ colnames += f" { c } \n "
1142+ raise RuntimeError (f"missing live reservation for { tensor } . columns:\n " + colnames )
1143+
11651144 def _assert_no_duplicate_cols (self ):
11661145 if len (list (_get_duplicates (self ._data .columns ))) > 0 :
11671146 raise ValueError (
@@ -1170,39 +1149,28 @@ def _assert_no_duplicate_cols(self):
11701149 " " + list (_get_duplicates (self ._data .columns ))
11711150 )
11721151
1173- def check_live_reservations (self , tensors : oset [TensorReservation ]):
1174- for tensor in tensors :
1175- for col in get_live_reservation_cols_with (
1152+ def _assert_reservation_includes_live_tensors (self ):
1153+ for col in get_live_reservation_cols_with (self .data ):
1154+ live_res_key = col2live_reservation (col )
1155+ if (self .data [col ] > self .data [tensor2col (live_res_key .tensor )]).any ():
1156+ raise RuntimeError (f"reservation for live tensor larger than tensor size for { col } " )
1157+
1158+ for res_col in get_reservation_cols_with (
11761159 self .data ,
1177- name = tensor . resource_name ,
1178- tensor = tensor . name ,
1160+ name = live_res_key . name ,
1161+ thread = live_res_key . thread ,
11791162 ):
1180- if (self .data [col ] > self .data [tensor2col (tensor .name )]).any ():
1181- raise RuntimeError (f"reservation for live tensor larger than tensor size for { col } " )
1182-
1183- live_res_key = col2live_reservation (col )
1184- for res_col in get_reservation_cols_with (
1185- self .data ,
1186- name = tensor .resource_name ,
1187- thread = live_res_key .thread ,
1163+ res_key = col2reservation (res_col )
1164+ if (
1165+ res_key .is_right and res_key .nloops < live_res_key .nloops
1166+ or
1167+ res_key .is_left and res_key .nloops != live_res_key .nloops
11881168 ):
1189- res_key = col2reservation (res_col )
1190- if (
1191- res_key .is_right and res_key .nloops < live_res_key .nloops
1192- or
1193- res_key .is_left and res_key .nloops != live_res_key .nloops
1194- ):
1195- continue
1196- if (self .data [res_col ] < self .data [col ]).any ():
1197- raise RuntimeError (f"reservation smaller than reservation for live tensor { col } " )
1198- break
1199- else :
1200- colnames = ""
1201- for c in self .data .columns :
1202- colnames += f" { c } \n "
1203- raise RuntimeError (f"missing live reservation for { tensor } . columns:\n " + colnames )
1169+ continue
1170+ if (self .data [res_col ] < self .data [col ]).any ():
1171+ raise RuntimeError (f"reservation smaller than reservation for live tensor { col } " )
12041172
1205- def _check_reservation_inclusivity (self ):
1173+ def _assert_reservation_inclusivity (self ):
12061174 _ , r_reservations = self ._make_reservations ()
12071175 for resource in r_reservations :
12081176 for higher_col in get_reservation_cols_with (self .data , name = resource , is_left = False ):
@@ -1214,7 +1182,7 @@ def _check_reservation_inclusivity(self):
12141182 if (self .data [lower_col ] < self .data [higher_col ]).any ():
12151183 raise ValueError ("Lower reservation smaller than higher reservation" )
12161184
1217- def check_consistent_left_right_reservation (self ):
1185+ def _assert_consistent_left_right_reservations (self ):
12181186 if self ._has_right_latency ():
12191187 assert self ._has_bottom_right_reservations () or not self ._has_right_reservations ()
12201188 return
0 commit comments