Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,10 @@ def set_state_fn(state: str):

### END Orbax checkpointing API.

def is_ready(self) -> bool:
  """Whether this iterator has an element ready for immediate consumption.

  The base implementation is optimistic and always reports readiness;
  asynchronous/buffered subclasses override this to reflect their buffers.
  """
  return True

def start_prefetch(self) -> None:
"""Starts processing elements in the first asynchronous parent iterator.

Expand Down
19 changes: 19 additions & 0 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
num_make_iter_threads: int = 1,
make_iter_buffer_size: int = 1,
iter_buffer_size: int = 1,
allow_reordering: bool = False,
):
# `datasets` is allowed to be a lazily evaluated `MapDataset`. We avoid
# passing it as `parents` to not trigger evaluation early.
Expand All @@ -46,6 +47,7 @@ def __init__(
self._num_make_iter_threads = num_make_iter_threads
self._make_iter_buffer_size = make_iter_buffer_size
self._iter_buffer_size = iter_buffer_size
self._allow_reordering = allow_reordering
self._prefetch_ds_iter = (
dataset.MapDataset.source(datasets)
.map(
Expand Down Expand Up @@ -95,6 +97,17 @@ def __next__(self) -> T:
timer = stats.Timer()
_ = self._stats # eagerly initialize stats
while True:
if (
self._allow_reordering
and self._iterators_in_use[self._next_index_in_cycle] is not None
):
for i in range(self._cycle_length):
idx = (self._next_index_in_cycle + i) % self._cycle_length
if iterator := self._iterators_in_use[idx]:
if iterator.is_ready():
self._next_index_in_cycle = idx
break

if iterator_to_use := self._iterators_in_use[self._next_index_in_cycle]:
try:
result = iterator_to_use.__next__()
Expand Down Expand Up @@ -384,6 +397,7 @@ def __init__(
num_make_iter_threads: int = 1,
make_iter_buffer_size: int = 1,
iter_buffer_size: int = 1,
allow_reordering: bool = False,
):
"""Initializes the InterleaveIterDataset.

Expand All @@ -405,13 +419,17 @@ def __init__(
is 1, with this we'll always keep the next iterator ready in advance.
iter_buffer_size: Optional. The number of elements to prefetch from each
iterator. Default value is 1.
allow_reordering: Optional. If True, the next element will be taken from
the first iterator that has an element ready. If False, the iterators
will be cycled through in a round-robin fashion. Default value is False.
"""
super().__init__()
self._datasets = datasets
self._cycle_length = cycle_length
self._num_make_iter_threads = num_make_iter_threads
self._make_iter_buffer_size = make_iter_buffer_size
self._iter_buffer_size = iter_buffer_size
self._allow_reordering = allow_reordering

def __iter__(self) -> dataset.DatasetIterator[T]:
return InterleaveDatasetIterator(
Expand All @@ -420,6 +438,7 @@ def __iter__(self) -> dataset.DatasetIterator[T]:
num_make_iter_threads=self._num_make_iter_threads,
make_iter_buffer_size=self._make_iter_buffer_size,
iter_buffer_size=self._iter_buffer_size,
allow_reordering=self._allow_reordering,
)

def set_slice(self, sl: slice, sequential_slice: bool = False):
Expand Down
57 changes: 55 additions & 2 deletions grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib

from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
Expand All @@ -20,7 +22,7 @@
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave
from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
from grain._src.python.testing import experimental
import numpy as np


Expand Down Expand Up @@ -187,7 +189,7 @@ def test_checkpointing_comprehensive(self):
for i in range(1, 6)
]
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
assert_equal_output_after_checkpoint(ds)
experimental.assert_equal_output_after_checkpoint(ds)

def test_set_state_does_not_recreate_iterators_if_not_needed(self):
cycle_length = 5
Expand Down Expand Up @@ -291,6 +293,57 @@ def test_set_next_index_with_multiple_datasets(self):
):
dataset.set_next_index(ds_iter, 0)

def test_skips_unready_iterator(self):
  """Verifies that allow_reordering pulls elements from ready iterators only."""
  ds1 = dataset.MapDataset.range(10).to_iter_dataset()
  ds2 = dataset.MapDataset.range(10, 20).to_iter_dataset()
  ds = interleave.InterleaveIterDataset(
      [ds1, ds2],
      cycle_length=2,
      num_make_iter_threads=2,
      make_iter_buffer_size=2,
      iter_buffer_size=5,
      allow_reordering=True,
  )
  it = ds.__iter__()

  # The first cycle is still deterministic because the iterators have to be
  # prepared.
  self.assertEqual(next(it), 0)
  self.assertEqual(next(it), 10)

  # pytype: disable=attribute-error
  @contextlib.contextmanager
  def _force_disable_iterator(idx):
    """Helper method to set an iterator to "not ready"."""
    # Patch is_ready on the live child iterator; restored on context exit so
    # later assertions see the real readiness state again.
    orig_it_is_ready = it._iterators_in_use[idx].is_ready
    it._iterators_in_use[idx].is_ready = lambda: False
    try:
      yield
    finally:
      it._iterators_in_use[idx].is_ready = orig_it_is_ready

  # pytype: enable=attribute-error

  # Force the first iterator to be unready. It should only read from the 2nd
  # iterator.
  with _force_disable_iterator(0):
    self.assertEqual(next(it), 11)
    self.assertEqual(next(it), 12)
    self.assertEqual(next(it), 13)

  # Force the second iterator to be unready. It should only read from the 1st
  # iterator.
  with _force_disable_iterator(1):
    self.assertEqual(next(it), 1)
    self.assertEqual(next(it), 2)
    self.assertEqual(next(it), 3)

  # Verify we can get the remaining values.
  remaining_values = sorted(list(it))
  self.assertEqual(
      remaining_values, [4, 5, 6, 7, 8, 9, 14, 15, 16, 17, 18, 19]
  )


if __name__ == "__main__":
absltest.main()
8 changes: 8 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ def _threshold_checker(self):
raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio,
)

def is_ready(self) -> bool:
  """Returns True if an element is buffered or the iterator is exhausted.

  An exhausted iterator reports ready so callers probing for data do not
  wait on it indefinitely.
  """
  exhausted = self._next_returned_index == self._dataset_length
  return exhausted or bool(self._buffer)

@dataset_stats.record_next_duration_if_output
@dataset_stats.trace_input_pipeline_next(
stage_category=dataset_stats.IPL_CAT_PREFETCH
Expand Down Expand Up @@ -547,6 +552,9 @@ def start_prefetch(self):
)
self._prefetch_thread.start()

def is_ready(self) -> bool:
  """Returns True if the prefetch buffer currently holds at least one element."""
  buffer_is_empty = self._buffer.empty()
  return not buffer_is_empty

@dataset_stats.record_next_duration_if_output
@dataset_stats.trace_input_pipeline_next(
stage_category=dataset_stats.IPL_CAT_PREFETCH
Expand Down
Loading