Skip to content

Commit be34155

Browse files
committed
[WIP] Untested live tensor forwarding
1 parent 2c81dcb commit be34155

2 files changed

Lines changed: 39 additions & 12 deletions

File tree

accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,6 @@ def shift_bottom_reservation_left(self):
409409
_l_reservations, r_reservations = self._make_reservations()
410410

411411
# Explore placing right branch in any of the left threads
412-
# TODO: extend lifetime of live_tensors and deduplicate tensors
413412
left_concurrent_threads = list(range(self.n_concurrent_threads))
414413
assert len(left_concurrent_threads) > 0
415414
all_data = []
@@ -421,6 +420,30 @@ def shift_bottom_reservation_left(self):
421420
right_reservation = reservation2col(resource, bottom_loop_index)
422421
left_reservation = reservation2col(resource, bottom_loop_index, True, thread_i)
423422

423+
for live_tensor in get_live_reservation_cols_with(
424+
df,
425+
resource=resource,
426+
nloops=bottom_loop_index,
427+
thread=thread_i
428+
):
429+
add_to_col(df, right_reservation, live_tensor)
430+
431+
for live_tensor_in_right in get_live_reservation_cols_with(
432+
df,
433+
resource=resource,
434+
nloops=bottom_loop_index,
435+
thread=DEFAULT_THREAD,
436+
):
437+
right_key = col2live_reservation(live_tensor_in_right)
438+
new_live_tensor = live_reservation2col(
439+
resource,
440+
right_key.tensor,
441+
bottom_loop_index,
442+
thread_i
443+
)
444+
add_to_col(df, new_live_tensor, live_tensor_in_right)
445+
df.drop(columns=[live_tensor_in_right])
446+
424447
for thread_j in left_concurrent_threads:
425448
key = reservation2col(resource,bottom_loop_index,True,thread_j)
426449
if key not in df:
@@ -429,11 +452,6 @@ def shift_bottom_reservation_left(self):
429452
if left_reservation in df:
430453
max_to_col(df, left_reservation, right_reservation)
431454
df.drop(columns=[right_reservation], inplace=True)
432-
# TODO: extend lifetime of live tensors
433-
# left_live_tensors = live_tensors2col(resource, bottom_loop_index, thread_i)
434-
# assert left_live_tensors in df
435-
# left_live_tensors = df[left_live_tensors]
436-
# left_live_tensors.apply(lambda left_live: left_live - right_tensors)
437455
else:
438456
df.rename(columns={right_reservation: left_reservation}, inplace=True)
439457

@@ -444,8 +462,6 @@ def shift_bottom_reservation_left(self):
444462
add_to_col(df, f"Total<SEP>latency<SEP>{bottom_loop_index}<SEP>{thread_i}", "Total<SEP>latency")
445463
df.drop(columns=["Total<SEP>latency"], inplace=True)
446464

447-
# for col in get_live_reservation_cols()
448-
449465
if self.track_binding_sequence:
450466
df["binding_order"] += BindingOrder([thread_i])
451467
assert not set(get_reservation_cols_with(
@@ -537,6 +553,7 @@ def merge_next(
537553
next_shared_loop_index = compatibility_joined.n_loops - 1
538554

539555
self.check_live_reservations(compatibility_left)
556+
self._remove_dead_reservations(compatibility_joined)
540557

541558
assert compatibility_left.n_loops <= compatibility_right.n_loops
542559
if self._has_bottom_right():
@@ -721,6 +738,8 @@ def iter_reservations(reservations_dict):
721738
n_concurrent_threads=self.n_concurrent_threads,
722739
)
723740

741+
result._remove_dead_reservations(compatibility_joined)
742+
724743
doubly_counted_reservations = _get_doubly_counted_reservations(
725744
compatibility_left.tensors,
726745
compatibility_right.tensors,
@@ -737,10 +756,6 @@ def iter_reservations(reservations_dict):
737756
assert result._has_bottom_right()
738757
result.free_to_loop_index(next_shared_loop_index)
739758

740-
# TODO: remove unneeded live tensors
741-
# TODO: handle shifting live tensors to left threads
742-
# TODO: allocate live tensors when shifting to left thread
743-
744759
if not CHECK_CORRECTNESS:
745760
result.limit_capacity(
746761
next_shared_loop_index, ignored_resources=ignored_resources
@@ -944,6 +959,15 @@ def has_reservations(self):
944959
# ============================================================================
945960
# Helper functions
946961
# ============================================================================
962+
def _remove_dead_reservations(self, compatibility: Compatibility):
963+
live_tensors = oset(tensor.name for tensor in compatibility.tensors)
964+
dropcols = []
965+
for col in get_live_reservation_cols_with(self.data):
966+
key = col2live_reservation(col)
967+
if key.tensor not in live_tensors:
968+
dropcols.append(col)
969+
self.data.drop(columns=dropcols)
970+
947971
def _create_live_reservation_from_compatibility(self, compatibility: Compatibility):
948972
for tensor in compatibility.tensors:
949973
col = live_reservation2col(

accelforge/mapper/FFM/_pareto_df/df_convention.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ class LiveReservationKey(NamedTuple):
152152
nloops: int
153153
thread: int
154154

155+
def get_live_reservation_cols_with(df, **kwargs):
156+
yield from _filter(df, col2live_reservation, kwargs)
157+
155158
def is_live_reservation_col(col):
156159
return LIVE in col
157160

0 commit comments

Comments
 (0)