Skip to content

Commit 6a341b8

Browse files
committed
[WIP] Fix max bug
1 parent 4591ecd commit 6a341b8

2 files changed

Lines changed: 50 additions & 92 deletions

File tree

accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py

Lines changed: 49 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from typing import Any, Callable, Iterable
88

9+
import numpy as np
910
import sympy
1011

1112
from accelforge.frontend.mapping import Nested, TilePattern
@@ -92,6 +93,18 @@ def reduce_precision(data: pd.DataFrame) -> pd.DataFrame:
9293

9394

9495
class PmappingDataframe:
96+
def _assert_invariants_before_and_after(f):
    """
    Decorator that runs the reservation invariant checks around a method.

    The instance the wrapped method is called on must provide
    ``_assert_reservation_includes_live_tensors``,
    ``_assert_consistent_left_right_reservations`` and
    ``_assert_reservation_inclusivity``. All three are called, in that
    order, before ``f`` runs and again after ``f`` returns.

    If ``f`` (or any pre-call check) raises, the exception propagates
    unchanged and the post-call checks are not run.

    Parameters
    ----------
    f:
        The method to wrap. It is invoked as ``f(self, *args, **kwargs)``.

    Returns
    -------
    The wrapping function, which forwards every argument to ``f`` and
    returns whatever ``f`` returned.
    """
    # Function-scope import so this definition does not depend on the
    # file's top-level import block.
    import functools

    def _assert_invariants(self):
        # Single place that lists the invariants, so the "before" and
        # "after" passes cannot drift apart.
        self._assert_reservation_includes_live_tensors()
        self._assert_consistent_left_right_reservations()
        self._assert_reservation_inclusivity()

    # Copy f's __name__/__doc__/__module__ onto the wrapper; without this
    # every decorated method reports itself as "wrapped" in tracebacks and
    # to any decorator stacked on top of this one.
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        _assert_invariants(self)
        result = f(self, *args, **kwargs)
        _assert_invariants(self)
        return result

    return wrapped
107+
95108
def __init__(
96109
self,
97110
data: pd.DataFrame,
@@ -230,6 +243,7 @@ def clear_fused_loop_symbols(self):
230243
self.make_pareto()
231244

232245
@error_check_wrapper
246+
@_assert_invariants_before_and_after
233247
def free_to_loop_index(self, loop_index: int) -> bool:
234248
"""
235249
A B
@@ -251,8 +265,6 @@ def free_to_loop_index(self, loop_index: int) -> bool:
251265
if loop_index < -1:
252266
raise ValueError("loop_index must be >= -1")
253267

254-
self._check_reservation_inclusivity()
255-
256268
# We keep reservations under loop_index, which is index loop_index+1
257269
reservation_max_index = loop_index + 1
258270

@@ -263,7 +275,6 @@ def free_to_loop_index(self, loop_index: int) -> bool:
263275

264276
if not self._has_left_reservations():
265277
self._move_reservations_to_index(reservation_max_index)
266-
self._check_reservation_inclusivity()
267278
return True
268279

269280
updated = False
@@ -276,7 +287,6 @@ def free_to_loop_index(self, loop_index: int) -> bool:
276287
assert self._has_bottom_right()
277288
assert self.get_max_loop_index() == reservation_max_index
278289

279-
self._check_reservation_inclusivity()
280290
return updated
281291

282292
def get_max_loop_index(self):
@@ -307,6 +317,7 @@ def get_min_loop_index(self):
307317
)
308318

309319
@error_check_wrapper
320+
@_assert_invariants_before_and_after
310321
def merge_next(
311322
self,
312323
right: "PmappingDataframe",
@@ -343,12 +354,10 @@ def merge_next(
343354

344355
right.free_to_loop_index(compatibility_right.n_loops-1)
345356
right.check_live_reservations(compatibility_right.tensors & compatibility_joined.tensors)
346-
right._check_reservation_inclusivity()
347357

348358
assert not right._has_left_reservations()
349359
assert self.get_max_loop_index()-1 <= compatibility_left.n_loops, f"max reservation index is {self.get_max_loop_index()} and compatibility has {compatibility_left.n_loops} loops"
350360
assert right.get_max_loop_index()-1 <= compatibility_right.n_loops
351-
self.check_consistent_left_right_reservation()
352361

353362
self._remove_dead_reservations(compatibility_joined)
354363

@@ -357,7 +366,6 @@ def merge_next(
357366
self._shift_bottom_reservation_left()
358367

359368
self.check_live_reservations(compatibility_joined.tensors & compatibility_left.tensors)
360-
self._check_reservation_inclusivity()
361369

362370
shared_tensor_names = (
363371
compatibility_left.tensor_names & compatibility_right.tensor_names
@@ -454,7 +462,7 @@ def check_match(la: Loop, lb: Loop, param: str):
454462
if not (live_col in df and right_live_col in df):
455463
continue
456464
adjustment_val = (
457-
max(df[live_col], df[right_live_col]) # what it should be
465+
np.maximum(df[live_col], df[right_live_col]) # what it should be
458466
-
459467
(df[live_col] + df[right_live_col]) # what the following logic otherwise would count
460468
)
@@ -570,16 +578,13 @@ def iter_reservations(reservations_dict):
570578

571579
result._remove_dead_reservations(compatibility_joined)
572580

573-
result._check_reservation_inclusivity()
574581
result.check_live_reservations(compatibility_joined.tensors)
575582

576583
if CHECK_CORRECTNESS:
577584
result.check_above_subset_below(live_tensors)
578-
result.check_reservations(live_tensors)
579585

580586
result.free_to_loop_index(next_shared_loop_index)
581587

582-
result.check_consistent_left_right_reservation()
583588
assert result._has_bottom_right()
584589

585590
if not CHECK_CORRECTNESS:
@@ -784,42 +789,6 @@ def _get_reservation_or_parent(
784789
level -= 1
785790
return None
786791

787-
# def _free_reservations_of_resource(
788-
# self,
789-
# resource: str,
790-
# tensor_indices_to_free: Iterable[tuple[str, int]],
791-
# ):
792-
# """
793-
# For every `(tensor, nloops)` in `tensor_indices_to_free`, reduces all
794-
# right reservations of `resource` at `index >= nloops` and left reservations
795-
# at `index > nloops` by the size of the tensor as recorded in
796-
# `self.data[tensor2col(tensor)]`.
797-
# """
798-
# targets = defaultdict(int)
799-
# for tensor_name, to_free_nloops in tensor_indices_to_free:
800-
# size = self.data[tensor2col(tensor_name)]
801-
# for col in get_reservation_cols_with(self.data, name=resource):
802-
# key = col2reservation(col)
803-
# if (
804-
# (key.is_left and key.nloops > to_free_nloops)
805-
# or
806-
# (key.is_right and key.nloops >= to_free_nloops)
807-
# ):
808-
# targets[key.nloops, col] -= size
809-
810-
# # Now apply the allocations. Sort so we go from top to bottom in case
811-
# # there are maxes that propagate down.
812-
# for (_, target), size in sorted(
813-
# targets.items(), key=lambda x: x[0], reverse=True
814-
# ):
815-
# assert target in self.data
816-
# add_to_col(self.data, target, size)
817-
# # Assert all reservations are >= 0
818-
# try:
819-
# assert (self.data[target] >= 0).all(), f"Negative reservation: {target}"
820-
# except:
821-
# breakpoint()
822-
823792
@staticmethod
824793
def concat(
825794
paretos: list["PmappingDataframe"], skip_pareto: bool = False
@@ -850,6 +819,7 @@ def concat(
850819
)
851820
return p
852821

822+
@_assert_invariants_before_and_after
853823
def _shift_bottom_reservation_left(self):
854824
"""
855825
Shifts the bottom reservation from right to left.
@@ -964,7 +934,6 @@ def _shift_bottom_reservation_left(self):
964934
self._data = pd.concat(all_data, ignore_index=True)
965935

966936
assert not self._has_bottom_right()
967-
self._check_reservation_inclusivity()
968937

969938
def _move_reservations_to_index(self, loop_index):
970939
dropcols = []
@@ -975,8 +944,8 @@ def _move_reservations_to_index(self, loop_index):
975944
target = reservation2col(key.name, loop_index)
976945
max_to_col(self.data, target, col)
977946
self.data.drop(columns=dropcols, inplace=True)
978-
self._check_reservation_inclusivity()
979947

948+
@_assert_invariants_before_and_after
980949
def _consolidate_bottom_split(self):
981950
"""
982951
Consolidate bottom split.
@@ -989,7 +958,6 @@ def _consolidate_bottom_split(self):
989958
F E
990959
"""
991960
assert not self._has_bottom_right()
992-
self._check_reservation_inclusivity()
993961

994962
bottom_index = self.get_max_loop_index()
995963
target_index = bottom_index-1
@@ -1040,7 +1008,6 @@ def _consolidate_bottom_split(self):
10401008
drop_columns.append(f"Total<SEP>latency<SEP>{bottom_index}<SEP>{thread_i}")
10411009
self.data.drop(columns=drop_columns, inplace=True)
10421010

1043-
self._check_reservation_inclusivity()
10441011
assert self._has_bottom_right()
10451012

10461013
def _remove_dead_reservations(self, compatibility: Compatibility):
@@ -1162,6 +1129,18 @@ def filter_rows(
11621129
def __len__(self) -> int:
11631130
return len(self._data)
11641131

1132+
def check_live_reservations(self, tensors: oset[TensorReservation]):
    """
    Verify that every tensor in `tensors` has a live-reservation column.

    For each tensor, `get_live_reservation_cols_with` is queried on
    `self.data` with the tensor's `resource_name` and `name`. If the query
    yields no columns, a `RuntimeError` is raised whose message lists every
    column currently in `self.data`.
    """
    for tensor in tensors:
        matching_cols = get_live_reservation_cols_with(
            self.data,
            name=tensor.resource_name,
            tensor=tensor.name
        )
        # Only existence matters, so stop at the first match.
        if any(True for _ in matching_cols):
            continue

        # Dump all column names to make the missing reservation easy to spot.
        colnames = "".join(f" {c}\n" for c in self.data.columns)
        raise RuntimeError(f"missing live reservation for {tensor}. columns:\n" + colnames)
1143+
11651144
def _assert_no_duplicate_cols(self):
11661145
if len(list(_get_duplicates(self._data.columns))) > 0:
11671146
raise ValueError(
@@ -1170,39 +1149,28 @@ def _assert_no_duplicate_cols(self):
11701149
" " + list(_get_duplicates(self._data.columns))
11711150
)
11721151

1173-
def check_live_reservations(self, tensors: oset[TensorReservation]):
1174-
for tensor in tensors:
1175-
for col in get_live_reservation_cols_with(
1152+
def _assert_reservation_includes_live_tensors(self):
1153+
for col in get_live_reservation_cols_with(self.data):
1154+
live_res_key = col2live_reservation(col)
1155+
if (self.data[col] > self.data[tensor2col(live_res_key.tensor)]).any():
1156+
raise RuntimeError(f"reservation for live tensor larger than tensor size for {col}")
1157+
1158+
for res_col in get_reservation_cols_with(
11761159
self.data,
1177-
name=tensor.resource_name,
1178-
tensor=tensor.name,
1160+
name=live_res_key.name,
1161+
thread=live_res_key.thread,
11791162
):
1180-
if (self.data[col] > self.data[tensor2col(tensor.name)]).any():
1181-
raise RuntimeError(f"reservation for live tensor larger than tensor size for {col}")
1182-
1183-
live_res_key = col2live_reservation(col)
1184-
for res_col in get_reservation_cols_with(
1185-
self.data,
1186-
name=tensor.resource_name,
1187-
thread=live_res_key.thread,
1163+
res_key = col2reservation(res_col)
1164+
if (
1165+
res_key.is_right and res_key.nloops < live_res_key.nloops
1166+
or
1167+
res_key.is_left and res_key.nloops != live_res_key.nloops
11881168
):
1189-
res_key = col2reservation(res_col)
1190-
if (
1191-
res_key.is_right and res_key.nloops < live_res_key.nloops
1192-
or
1193-
res_key.is_left and res_key.nloops != live_res_key.nloops
1194-
):
1195-
continue
1196-
if (self.data[res_col] < self.data[col]).any():
1197-
raise RuntimeError(f"reservation smaller than reservation for live tensor {col}")
1198-
break
1199-
else:
1200-
colnames = ""
1201-
for c in self.data.columns:
1202-
colnames += f" {c}\n"
1203-
raise RuntimeError(f"missing live reservation for {tensor}. columns:\n" + colnames)
1169+
continue
1170+
if (self.data[res_col] < self.data[col]).any():
1171+
raise RuntimeError(f"reservation smaller than reservation for live tensor {col}")
12041172

1205-
def _check_reservation_inclusivity(self):
1173+
def _assert_reservation_inclusivity(self):
12061174
_, r_reservations = self._make_reservations()
12071175
for resource in r_reservations:
12081176
for higher_col in get_reservation_cols_with(self.data, name=resource, is_left=False):
@@ -1214,7 +1182,7 @@ def _check_reservation_inclusivity(self):
12141182
if (self.data[lower_col] < self.data[higher_col]).any():
12151183
raise ValueError("Lower reservation smaller than higher reservation")
12161184

1217-
def check_consistent_left_right_reservation(self):
1185+
def _assert_consistent_left_right_reservations(self):
12181186
if self._has_right_latency():
12191187
assert self._has_bottom_right_reservations() or not self._has_right_reservations()
12201188
return

tests/test_mapper.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,7 @@ def test_tpuv4i_gpt3(self):
4343
)
4444
spec.mapper.metrics = Metrics.ENERGY | Metrics.LATENCY
4545
spec.mapper.n_concurrent_threads = 2
46-
import pickle
47-
# spec.mapper._only_output_pmapping_with_index = {"Q": 0, "K": 0, "QK": 0}
48-
# mappings = spec.map_workload_to_arch(einsum_names=["I", "Q", "K", "QK"])
49-
50-
# pmappings = af.mapper.FFM.make_pmappings(spec, einsum_names=["I", "Q", "K", "QK"])
51-
# with open("tmp.pkl", "wb") as f:
52-
# pickle.dump(pmappings, f)
53-
54-
with open("tmp.pkl", "rb") as f:
55-
pmappings = pickle.load(f)
56-
af.mapper.FFM.join_pmappings(pmappings, spec.mapper.metrics, require_all_einsums=False)
46+
mappings = spec.map_workload_to_arch()
5747

5848
class ActionChecker(unittest.TestCase):
5949
def _check_memory_actions_exist(self, spec, memory_names, result):

0 commit comments

Comments
 (0)