Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def __init__(
functools.partial(
_add_prefetch_and_make_iterator,
# We use weakref to avoid a circular reference. The
# _InterleaveDatasetIterator holds a reference to the
# InterleaveDatasetIterator holds a reference to the
# prefetch iterator in `self._prefetch_ds_iter`.
# The call to `_add_prefetch_and_make_iterator` (and the
# partial object) would hold a reference to the
# _InterleaveDatasetIterator. This would prolong its lifetime
# InterleaveDatasetIterator. This would prolong its lifetime
# leading to increased resource usage.
interleave_iterator=weakref.ref(self),
start_prefetch=True,
Expand Down Expand Up @@ -86,6 +86,8 @@ def __init__(
self._exhausted_iterators: list[
tuple[int, dataset.DatasetIterator[T]] | None
] = [None] * self._cycle_length
# Future states used for elastic iterators
self._future_states: dict[int, Any] = {}

@stats.record_next_duration_if_output
@stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
Expand Down Expand Up @@ -142,6 +144,16 @@ def __next__(self) -> T:
self._iterators_in_use_indices[self._next_index_in_cycle] = (
self._next_index_in_datasets
)
# For elastic iterators, we might have a future state saved for this
# dataset iterator from which to resume.
if (
self._next_index_in_datasets in self._future_states
and self._iterators_in_use[self._next_index_in_cycle]
):
future_state = self._future_states.pop(self._next_index_in_datasets)
self._iterators_in_use[self._next_index_in_cycle].set_state(
future_state
)
self._next_index_in_datasets += 1
elif not any(self._iterators_in_use):
raise StopIteration
Expand Down Expand Up @@ -182,13 +194,15 @@ def get_state(self):
int(self._exhausted_iterator_state[i] is not None)
for i in range(self._cycle_length)
]
return {
state = {
"next_index_in_cycle": self._next_index_in_cycle,
"next_index_in_datasets": self._next_index_in_datasets,
"iterators_in_use_indices": self._iterators_in_use_indices.copy(),
"iterators_in_use_states": iterators_in_use_states,
"exhausted": exhausted,
"future_states": self._future_states,
}
return state

def set_state(self, state):
exhausted = state["exhausted"]
Expand Down Expand Up @@ -220,7 +234,9 @@ def set_state(self, state):
interleave_iterator=weakref.ref(self),
start_prefetch=False,
)
iterator.set_state(it_state)
# Only update the iterator state if it is given
if it_state:
iterator.set_state(it_state)
self._iterators_in_use[index_in_cycle] = iterator
else:
self._exhausted_iterator_state[index_in_cycle] = it_state
Expand All @@ -232,6 +248,7 @@ def set_state(self, state):
self._next_index_in_cycle = state["next_index_in_cycle"]
self._next_index_in_datasets = state["next_index_in_datasets"]
self._iterators_in_use_indices = state["iterators_in_use_indices"]
self._future_states = state.get("future_states", {})

def _get_next_index(self) -> int:
if len(self._datasets) == 1:
Expand Down Expand Up @@ -307,6 +324,16 @@ def __str__(self) -> str:
f" cycle_length={self._cycle_length})"
)

def _get_iterator_start_state(self, index: int) -> dict[str, Any]:
  """Returns the initial checkpoint state for the dataset at `index`.

  Builds a scratch iterator over `self._datasets[index]` without starting
  prefetch, reads its freshly-initialized state, and discards the iterator.

  Args:
    index: Position of the dataset in `self._datasets`.

  Returns:
    The state dict a brand-new iterator over that dataset would report.
  """
  scratch_iterator = _add_prefetch_and_make_iterator(
      self._datasets[index],
      # weakref avoids a circular reference back to this iterator.
      weakref.ref(self),
      start_prefetch=False,
  )
  initial_state = scratch_iterator.get_state()
  # Drop the scratch iterator eagerly; it exists only to read its state.
  del scratch_iterator
  return initial_state


def _add_prefetch_and_make_iterator(
ds: dataset.IterDataset[T] | dataset.MapDataset[T],
Expand Down
31 changes: 31 additions & 0 deletions grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,37 @@ def test_set_next_index_with_multiple_datasets(self):
):
dataset.set_next_index(ds_iter, 0)

def test_future_states(self):
  """Checks that an injected future state makes a later dataset resume mid-way."""
  sources = [
      dataset.MapDataset.source([1, 2]).to_iter_dataset(),
      dataset.MapDataset.source([3, 4]).to_iter_dataset(),
  ]
  ds = interleave.InterleaveIterDataset(sources, cycle_length=1)
  ds_iter = ds.__iter__()

  # Checkpoint the interleave iterator before consuming anything.
  state = ds_iter.get_state()

  # Advance a standalone iterator over the second dataset past its first
  # element, so its saved state points just after the value 3.
  second_iter = sources[1].__iter__()
  next(second_iter)  # Consumes 3.
  second_state = second_iter.get_state()

  # Inject that state as the "future state" for dataset index 1.
  state["future_states"] = {1: second_state}

  ds_iter.set_state(state)

  # The first dataset is consumed in full (1, 2); the second resumes from
  # the injected state, so only 4 remains.
  self.assertEqual(next(ds_iter), 1)
  self.assertEqual(next(ds_iter), 2)
  self.assertEqual(next(ds_iter), 4)

  with self.assertRaises(StopIteration):
    next(ds_iter)


if __name__ == "__main__":
absltest.main()
Loading