diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 9db05ba9..a3a299d1 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -52,11 +52,11 @@ def __init__( functools.partial( _add_prefetch_and_make_iterator, # We use weakref to avoid a circular reference. The - # _InterleaveDatasetIterator holds a reference to the + # InterleaveDatasetIterator holds a reference to the # prefetch iterator in `self._prefetch_ds_iter`. # The call to `_add_prefetch_and_make_iterator` (and the # partial object) would hold a reference to the - # _InterleaveDatasetIterator. This would prolong its lifetime + # InterleaveDatasetIterator. This would prolong its lifetime # leading to increased resource usage. interleave_iterator=weakref.ref(self), start_prefetch=True, @@ -86,6 +86,8 @@ def __init__( self._exhausted_iterators: list[ tuple[int, dataset.DatasetIterator[T]] | None ] = [None] * self._cycle_length + # Future states used for elastic iterators + self._future_states: dict[int, Any] = {} @stats.record_next_duration_if_output @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING) @@ -142,6 +144,16 @@ def __next__(self) -> T: self._iterators_in_use_indices[self._next_index_in_cycle] = ( self._next_index_in_datasets ) + # For elastic iterators, we might have a future state saved for this + # dataset iterator from which to resume from. + if ( + self._next_index_in_datasets in self._future_states + and self._iterators_in_use[self._next_index_in_cycle] + ): + future_state = self._future_states.pop(self._next_index_in_datasets) + self._iterators_in_use[self._next_index_in_cycle].set_state( + future_state + ) self._next_index_in_datasets += 1 elif not any(self._iterators_in_use): raise StopIteration @@ -182,13 +194,15 @@ def get_state(self): int(self._exhausted_iterator_state[i] is not None) for i in range(self._cycle_length) ] - return { + state = { "next_index_in_cycle": self._next_index_in_cycle, "next_index_in_datasets": self._next_index_in_datasets, "iterators_in_use_indices": self._iterators_in_use_indices.copy(), "iterators_in_use_states": iterators_in_use_states, "exhausted": exhausted, + "future_states": self._future_states, } + return state def set_state(self, state): exhausted = state["exhausted"] @@ -220,7 +234,9 @@ def set_state(self, state): interleave_iterator=weakref.ref(self), start_prefetch=False, ) - iterator.set_state(it_state) + # Only update the iterator state if it is given + if it_state: + iterator.set_state(it_state) self._iterators_in_use[index_in_cycle] = iterator else: self._exhausted_iterator_state[index_in_cycle] = it_state @@ -232,6 +248,7 @@ def set_state(self, state): self._next_index_in_cycle = state["next_index_in_cycle"] self._next_index_in_datasets = state["next_index_in_datasets"] self._iterators_in_use_indices = state["iterators_in_use_indices"] + self._future_states = state.get("future_states", {}) def _get_next_index(self) -> int: if len(self._datasets) == 1: @@ -307,6 +324,16 @@ def __str__(self) -> str: f" cycle_length={self._cycle_length})" ) + def _get_iterator_start_state(self, index: int) -> dict[str, Any]: + it = _add_prefetch_and_make_iterator( + self._datasets[index], + weakref.ref(self), + start_prefetch=False, + ) + state = it.get_state() + del it + return state + def _add_prefetch_and_make_iterator( ds: dataset.IterDataset[T] | dataset.MapDataset[T], diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index c5d64730..a5ed001a 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -291,6 +291,37 @@ def test_set_next_index_with_multiple_datasets(self): ): dataset.set_next_index(ds_iter, 0) + def test_future_states(self): + datasets = [ + dataset.MapDataset.source([1, 2]).to_iter_dataset(), + dataset.MapDataset.source([3, 4]).to_iter_dataset(), + ] + ds = interleave.InterleaveIterDataset(datasets, cycle_length=1) + ds_iter = ds.__iter__() + + # Initialize the first iterator and get state. + state = ds_iter.get_state() + + # Get state for the second dataset iterator after advancing it. + ds1_iter = datasets[1].__iter__() + next(ds1_iter) # Consumes 3 + ds1_state = ds1_iter.get_state() + + # Inject future state for the second dataset (index 1). + state["future_states"] = {1: ds1_state} + + ds_iter.set_state(state) + + # Consume elements. + # It should yield elements from the first dataset (1, 2) and then + # yield elements from the second dataset starting from the future state (4). + self.assertEqual(next(ds_iter), 1) + self.assertEqual(next(ds_iter), 2) + self.assertEqual(next(ds_iter), 4) + + with self.assertRaises(StopIteration): + next(ds_iter) + if __name__ == "__main__": absltest.main()