diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index ed49bfe1..f3a0695e 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -1502,6 +1502,10 @@ def set_state_fn(state: str): ### END Orbax checkpointing API. + def is_ready(self) -> bool: + """Returns True if the iterator has data ready to be consumed.""" + return True + def start_prefetch(self) -> None: """Starts processing elements in the first asynchronous parent iterator. diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 9db05ba9..2ca1008a 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -38,6 +38,7 @@ def __init__( num_make_iter_threads: int = 1, make_iter_buffer_size: int = 1, iter_buffer_size: int = 1, + allow_reordering: bool = False, ): # `datasets` is allowed to be a lazily evaluated `MapDataset`. We avoid # passing it as `parents` to not trigger evaluation early. 
@@ -46,6 +47,7 @@ def __init__( self._num_make_iter_threads = num_make_iter_threads self._make_iter_buffer_size = make_iter_buffer_size self._iter_buffer_size = iter_buffer_size + self._allow_reordering = allow_reordering self._prefetch_ds_iter = ( dataset.MapDataset.source(datasets) .map( @@ -95,6 +97,17 @@ def __next__(self) -> T: timer = stats.Timer() _ = self._stats # eagerly initialize stats while True: + if ( + self._allow_reordering + and self._iterators_in_use[self._next_index_in_cycle] is not None + ): + for i in range(self._cycle_length): + idx = (self._next_index_in_cycle + i) % self._cycle_length + if iterator := self._iterators_in_use[idx]: + if iterator.is_ready(): + self._next_index_in_cycle = idx + break + if iterator_to_use := self._iterators_in_use[self._next_index_in_cycle]: try: result = iterator_to_use.__next__() @@ -384,6 +397,7 @@ def __init__( num_make_iter_threads: int = 1, make_iter_buffer_size: int = 1, iter_buffer_size: int = 1, + allow_reordering: bool = False, ): """Initializes the InterleaveIterDataset. @@ -405,6 +419,9 @@ def __init__( is 1, with this we'll always keep the next iterator ready in advance. iter_buffer_size: Optional. The number of elements to prefetch from each iterator. Default value is 1. + allow_reordering: Optional. If True, the next element will be taken from + the first iterator that has an element ready. If False, the iterators + will be cycled through in a round-robin fashion. Default value is False. 
""" super().__init__() self._datasets = datasets @@ -412,6 +429,7 @@ def __init__( self._num_make_iter_threads = num_make_iter_threads self._make_iter_buffer_size = make_iter_buffer_size self._iter_buffer_size = iter_buffer_size + self._allow_reordering = allow_reordering def __iter__(self) -> dataset.DatasetIterator[T]: return InterleaveDatasetIterator( @@ -420,6 +438,7 @@ def __iter__(self) -> dataset.DatasetIterator[T]: num_make_iter_threads=self._num_make_iter_threads, make_iter_buffer_size=self._make_iter_buffer_size, iter_buffer_size=self._iter_buffer_size, + allow_reordering=self._allow_reordering, ) def set_slice(self, sl: slice, sequential_slice: bool = False): diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index c5d64730..24ed0dd7 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib + from absl.testing import absltest from absl.testing import flagsaver from absl.testing import parameterized @@ -20,7 +22,7 @@ from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import interleave -from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint +from grain._src.python.testing import experimental import numpy as np @@ -187,7 +189,7 @@ def test_checkpointing_comprehensive(self): for i in range(1, 6) ] ds = interleave.InterleaveIterDataset(ds, cycle_length=5) - assert_equal_output_after_checkpoint(ds) + experimental.assert_equal_output_after_checkpoint(ds) def test_set_state_does_not_recreate_iterators_if_not_needed(self): cycle_length = 5 @@ -291,6 +293,57 @@ def test_set_next_index_with_multiple_datasets(self): ): dataset.set_next_index(ds_iter, 0) + def test_skips_unready_iterator(self): + ds1 = dataset.MapDataset.range(10).to_iter_dataset() + ds2 = dataset.MapDataset.range(10, 20).to_iter_dataset() + ds = interleave.InterleaveIterDataset( + [ds1, ds2], + cycle_length=2, + num_make_iter_threads=2, + make_iter_buffer_size=2, + iter_buffer_size=5, + allow_reordering=True, + ) + it = ds.__iter__() + + # The first cycle is still deterministic because the iterators have to be + # prepared. + self.assertEqual(next(it), 0) + self.assertEqual(next(it), 10) + + # pytype: disable=attribute-error + @contextlib.contextmanager + def _force_disable_iterator(idx): + """Helper context manager to set an iterator to "not ready".""" + orig_it_is_ready = it._iterators_in_use[idx].is_ready + it._iterators_in_use[idx].is_ready = lambda: False + try: + yield + finally: + it._iterators_in_use[idx].is_ready = orig_it_is_ready + + # pytype: enable=attribute-error + + # Force the first iterator to be unready. It should only read from the 2nd + # iterator.
+ with _force_disable_iterator(0): + self.assertEqual(next(it), 11) + self.assertEqual(next(it), 12) + self.assertEqual(next(it), 13) + + # Force the second iterator to be unready. It should only read from the 1st + # iterator. + with _force_disable_iterator(1): + self.assertEqual(next(it), 1) + self.assertEqual(next(it), 2) + self.assertEqual(next(it), 3) + + # Verify we can get the remaining values. + remaining_values = sorted(list(it)) + self.assertEqual( + remaining_values, [4, 5, 6, 7, 8, 9, 14, 15, 16, 17, 18, 19] + ) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index a9cde484..23144913 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -183,6 +183,11 @@ def _threshold_checker(self): raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio, ) + def is_ready(self) -> bool: + return ( + bool(self._buffer) or self._next_returned_index == self._dataset_length + ) + @dataset_stats.record_next_duration_if_output @dataset_stats.trace_input_pipeline_next( stage_category=dataset_stats.IPL_CAT_PREFETCH @@ -547,6 +552,9 @@ def start_prefetch(self): ) self._prefetch_thread.start() + def is_ready(self) -> bool: + return not self._buffer.empty() + @dataset_stats.record_next_duration_if_output @dataset_stats.trace_input_pipeline_next( stage_category=dataset_stats.IPL_CAT_PREFETCH