@@ -409,7 +409,6 @@ def shift_bottom_reservation_left(self):
409409 _l_reservations , r_reservations = self ._make_reservations ()
410410
411411 # Explore placing right branch in any of the left threads
412- # TODO: extend lifetime of live_tensors and deduplicate tensors
413412 left_concurrent_threads = list (range (self .n_concurrent_threads ))
414413 assert len (left_concurrent_threads ) > 0
415414 all_data = []
@@ -421,6 +420,30 @@ def shift_bottom_reservation_left(self):
421420 right_reservation = reservation2col (resource , bottom_loop_index )
422421 left_reservation = reservation2col (resource , bottom_loop_index , True , thread_i )
423422
423+ for live_tensor in get_live_reservation_cols_with (
424+ df ,
425+ resource = resource ,
426+ nloops = bottom_loop_index ,
427+ thread = thread_i
428+ ):
429+ add_to_col (df , right_reservation , live_tensor )
430+
431+ for live_tensor_in_right in get_live_reservation_cols_with (
432+ df ,
433+ resource = resource ,
434+ nloops = bottom_loop_index ,
435+ thread = DEFAULT_THREAD ,
436+ ):
437+ right_key = col2live_reservation (live_tensor_in_right )
438+ new_live_tensor = live_reservation2col (
439+ resource ,
440+ right_key .tensor ,
441+ bottom_loop_index ,
442+ thread_i
443+ )
444+ add_to_col (df , new_live_tensor , live_tensor_in_right )
445+ df .drop (columns = [live_tensor_in_right ])
446+
424447 for thread_j in left_concurrent_threads :
425448 key = reservation2col (resource ,bottom_loop_index ,True ,thread_j )
426449 if key not in df :
@@ -429,11 +452,6 @@ def shift_bottom_reservation_left(self):
429452 if left_reservation in df :
430453 max_to_col (df , left_reservation , right_reservation )
431454 df .drop (columns = [right_reservation ], inplace = True )
432- # TODO: extend lifetime of live tensors
433- # left_live_tensors = live_tensors2col(resource, bottom_loop_index, thread_i)
434- # assert left_live_tensors in df
435- # left_live_tensors = df[left_live_tensors]
436- # left_live_tensors.apply(lambda left_live: left_live - right_tensors)
437455 else :
438456 df .rename (columns = {right_reservation : left_reservation }, inplace = True )
439457
@@ -444,8 +462,6 @@ def shift_bottom_reservation_left(self):
444462 add_to_col (df , f"Total<SEP>latency<SEP>{ bottom_loop_index } <SEP>{ thread_i } " , "Total<SEP>latency" )
445463 df .drop (columns = ["Total<SEP>latency" ], inplace = True )
446464
447- # for col in get_live_reservation_cols()
448-
449465 if self .track_binding_sequence :
450466 df ["binding_order" ] += BindingOrder ([thread_i ])
451467 assert not set (get_reservation_cols_with (
@@ -537,6 +553,7 @@ def merge_next(
537553 next_shared_loop_index = compatibility_joined .n_loops - 1
538554
539555 self .check_live_reservations (compatibility_left )
556+ self ._remove_dead_reservations (compatibility_joined )
540557
541558 assert compatibility_left .n_loops <= compatibility_right .n_loops
542559 if self ._has_bottom_right ():
@@ -721,6 +738,8 @@ def iter_reservations(reservations_dict):
721738 n_concurrent_threads = self .n_concurrent_threads ,
722739 )
723740
741+ result ._remove_dead_reservations (compatibility_joined )
742+
724743 doubly_counted_reservations = _get_doubly_counted_reservations (
725744 compatibility_left .tensors ,
726745 compatibility_right .tensors ,
@@ -737,10 +756,6 @@ def iter_reservations(reservations_dict):
737756 assert result ._has_bottom_right ()
738757 result .free_to_loop_index (next_shared_loop_index )
739758
740- # TODO: remove unneeded live tensors
741- # TODO: handle shifting live tensors to left threads
742- # TODO: allocate live tensors when shifting to left thread
743-
744759 if not CHECK_CORRECTNESS :
745760 result .limit_capacity (
746761 next_shared_loop_index , ignored_resources = ignored_resources
@@ -944,6 +959,15 @@ def has_reservations(self):
944959 # ============================================================================
945960 # Helper functions
946961 # ============================================================================
962+ def _remove_dead_reservations (self , compatibility : Compatibility ):
963+ live_tensors = oset (tensor .name for tensor in compatibility .tensors )
964+ dropcols = []
965+ for col in get_live_reservation_cols_with (self .data ):
966+ key = col2live_reservation (col )
967+ if key .tensor not in live_tensors :
968+ dropcols .append (col )
969+ self .data .drop (columns = dropcols )
970+
947971 def _create_live_reservation_from_compatibility (self , compatibility : Compatibility ):
948972 for tensor in compatibility .tensors :
949973 col = live_reservation2col (
0 commit comments