From be25b4dfa6deb062682324f5c7ef249d04c74902 Mon Sep 17 00:00:00 2001 From: Phani Aenugula Date: Tue, 3 Feb 2026 09:21:46 -0800 Subject: [PATCH] Add `ElasticIterDatasetIterator` to handle scaling up and down between checkpoints. * Allows users to keep their pipelines elastic and restore from a checkpoint with variable amount of shards * Dataset and Iterator class in one to allow changing sharding configuration * Add dedicated checkpoint handler for saving/restoring from Orbax Limitations * Does not guarantee determinism between scaling * The limit of parallelism is the number of shards PiperOrigin-RevId: 864910611 --- CHANGELOG.md | 1 + grain/_src/python/checkpoint/BUILD | 25 ++ .../python/checkpoint/elastic_checkpoint.py | 139 ++++++++ .../checkpoint/elastic_checkpoint_test.py | 122 +++++++ grain/_src/python/checkpoint/handler.py | 10 +- grain/_src/python/dataset/BUILD | 1 + grain/_src/python/dataset/elastic_iterator.py | 300 ++++++++++++++++-- .../python/dataset/elastic_iterator_test.py | 170 ++++++++-- .../dataset/transformations/interleave.py | 40 ++- grain/experimental.py | 5 +- 10 files changed, 738 insertions(+), 75 deletions(-) create mode 100644 grain/_src/python/checkpoint/elastic_checkpoint.py create mode 100644 grain/_src/python/checkpoint/elastic_checkpoint_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e3b3fcccc..30f438e1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change * Exposes `SharedMemoryArrayMetadata` in a public API as a metadata descriptor for `SharedMemoryArray`. * `ParquetIterDataset` can read from multiple string paths interleaving reads. + * Add `ElasticIterDatasetIterator` for scaling up and down the number of shards between checkpoints. * Breaking changes: * Custom implementations of `RandomAccessDataSource` should accept `int` diff --git a/grain/_src/python/checkpoint/BUILD b/grain/_src/python/checkpoint/BUILD index a160c9c46..efd871ab7 100644 --- a/grain/_src/python/checkpoint/BUILD +++ b/grain/_src/python/checkpoint/BUILD @@ -17,9 +17,34 @@ py_library( srcs = ["handler.py"], srcs_version = "PY3", deps = [ + ":elastic_checkpoint", "//grain/_src/core:sharding", "//grain/_src/python:data_loader", "//grain/_src/python/dataset", + "//grain/_src/python/dataset:elastic_iterator", + "@pypi//etils:pkg", + ], +) + +py_library( + name = "elastic_checkpoint", + srcs = ["elastic_checkpoint.py"], + srcs_version = "PY3", + deps = [ + "//grain/_src/python/dataset:elastic_iterator", + "@pypi//etils:pkg", + ], +) + +py_test( + name = "elastic_checkpoint_test", + srcs = ["elastic_checkpoint_test.py"], + srcs_version = "PY3", + deps = [ + ":elastic_checkpoint", + "//grain/_src/core:sharding", + "//grain/_src/python/dataset:elastic_iterator", + "@abseil-py//absl/testing:absltest", "@pypi//etils:pkg", ], ) diff --git a/grain/_src/python/checkpoint/elastic_checkpoint.py b/grain/_src/python/checkpoint/elastic_checkpoint.py new file mode 100644 index 000000000..3e5ad7a53 --- /dev/null +++ b/grain/_src/python/checkpoint/elastic_checkpoint.py @@ -0,0 +1,139 @@ +"""This module provides checkpointing logic for ElasticIterDatasetIterator.""" + +import dataclasses +import json +from typing import Any, Optional, Sequence + +from etils import epath +from grain._src.python.dataset import elastic_iterator + + +def _find_shard_file( + directory: epath.Path, + shard_index: int, + total_num_shards: int, +) -> epath.Path: + """Finds all files matching 'shard_state_*.json' in the directory.""" + all_files = list(directory.iterdir()) + pattern = f"shard_state_{shard_index}-of-{total_num_shards}.json" + found_files = [f for f in all_files if f.name.endswith(pattern)] + if not found_files: + raise ValueError( + f"No shard state files found in {directory} for shard {shard_index} of" + f" {total_num_shards}" + ) + if len(found_files) > 1: + raise ValueError( + f"Multiple shard state files found in {directory} for shard" + f" {shard_index} of {total_num_shards}" + ) + return found_files[0] + + +def save_elastic_iterator( + directory: epath.Path, + item: elastic_iterator.ElasticIterDatasetIterator, +): + """Saves the given iterator to the checkpoint in `directory`.""" + state = item.get_state() + ds_iterator_states = state["ds_iterator_states"] + num_dataset_shards = item.num_dataset_shards + for idx, host_iterator_state in ds_iterator_states.items(): + shard_state = json.dumps(host_iterator_state, indent=4) + filename = directory / f"shard_state_{idx}-of-{num_dataset_shards}.json" + filename.write_text(shard_state) + + +def restore_elastic_iterator( + directory: epath.Path, + item: elastic_iterator.ElasticIterDatasetIterator, +): + """Restores the given iterator from the checkpoint in `directory`.""" + num_dataset_shards = item.num_dataset_shards + shard_index = item.shard_options.shard_index + shard_count = item.shard_options.shard_count + iterator_states = {} + while shard_index < num_dataset_shards: + filename = _find_shard_file(directory, shard_index, num_dataset_shards) + state = filename.read_text() + state = json.loads(state) + iterator_states[shard_index] = state + shard_index += shard_count + item.set_state({"ds_iterator_states": iterator_states}) + + +class ElasticCheckpointHandler: + """Orbax CheckpointHandler for PyGrain iterators.""" + + def save( + self, + directory: epath.Path, + item: Optional[ + elastic_iterator.ElasticIterDatasetIterator + | Sequence[elastic_iterator.ElasticIterDatasetIterator] + ] = None, + args: Any = None, + ): + """Saves the given iterator to the checkpoint in `directory`.""" + items = item or args.item + if isinstance(items, elastic_iterator.ElasticIterDatasetIterator): + items = [items] + for iterator in items: + save_elastic_iterator(directory, iterator) + + def restore( + self, + directory: epath.Path, + item: Optional[ + elastic_iterator.ElasticIterDatasetIterator + | Sequence[elastic_iterator.ElasticIterDatasetIterator] + ] = None, + args: Any = None, + ) -> Any: + """Restores the given iterator from the checkpoint in `directory`.""" + items = item or args.item + if isinstance(items, elastic_iterator.ElasticIterDatasetIterator): + items = [items] + for iterator in items: + restore_elastic_iterator(directory, iterator) + return items + + # Required by interface but not supported by PyGrain checkpoints. + def structure(self, directory: epath.Path) -> Any: + del directory + return None + + # Required by interface. + + def metadata(self, directory: epath.Path) -> Optional[Any]: + del directory + return None + + def finalize(self, directory: epath.Path): + pass + + def close(self): + pass + + @classmethod + def typestr(cls): + return f"{cls.__module__}.{cls.__qualname__}" + + +try: + # Register the handler to be used with the new checkpointing API if Orbax is + # present. + import orbax.checkpoint as ocp # pylint:disable=g-import-not-at-top # pytype:disable=import-error + + @ocp.args.register_with_handler(ElasticCheckpointHandler, for_save=True) # pytype:disable=wrong-arg-types + @dataclasses.dataclass + class ElasticCheckpointSave(ocp.args.CheckpointArgs): + item: Any + + @ocp.args.register_with_handler(ElasticCheckpointHandler, for_restore=True) # pytype:disable=wrong-arg-types + @dataclasses.dataclass + class ElasticCheckpointRestore(ocp.args.CheckpointArgs): + item: Any + +except (ImportError, TypeError, AttributeError): + pass diff --git a/grain/_src/python/checkpoint/elastic_checkpoint_test.py b/grain/_src/python/checkpoint/elastic_checkpoint_test.py new file mode 100644 index 000000000..334c1fdae --- /dev/null +++ b/grain/_src/python/checkpoint/elastic_checkpoint_test.py @@ -0,0 +1,122 @@ +"""Tests for elastic checkpoint.""" + +import json + +from etils import epath +from grain._src.core import sharding +from grain._src.python.checkpoint import elastic_checkpoint +from grain._src.python.dataset import elastic_iterator + +from absl.testing import absltest + + +class MockElasticIterDatasetIterator( + elastic_iterator.ElasticIterDatasetIterator +): + + def __init__(self, shard_options, total_num_shards, states=None): + self._shard_options = shard_options + self._num_dataset_shards = total_num_shards + self._states = states if states is not None else {} + self.updated_states = {} + + def get_state(self): + return { + "ds_iterator_states": self._states, + } + + def set_state(self, state): + for k, v in state["ds_iterator_states"].items(): + self.updated_states[k] = v + + +class ElasticCheckpointTest(absltest.TestCase): + + def test_save_and_restore_elastic_iterator(self): + temp_dir = epath.Path(self.create_tempdir().full_path) + shard_options = sharding.ShardOptions(shard_index=0, shard_count=1) + states = { + 0: {"val": 0}, + 1: {"val": 1}, + } + iterator = MockElasticIterDatasetIterator( + shard_options=shard_options, total_num_shards=2, states=states + ) + elastic_checkpoint.save_elastic_iterator(temp_dir, iterator) + + file0 = temp_dir / "shard_state_0-of-2.json" + self.assertTrue(file0.exists()) + self.assertEqual( + file0.read_text(), + json.dumps({"val": 0}, indent=4), + ) + file1 = temp_dir / "shard_state_1-of-2.json" + self.assertTrue(file1.exists()) + self.assertEqual( + file1.read_text(), + json.dumps({"val": 1}, indent=4), + ) + + iterator_to_restore = MockElasticIterDatasetIterator( + shard_options=shard_options, total_num_shards=2 + ) + elastic_checkpoint.restore_elastic_iterator(temp_dir, iterator_to_restore) + self.assertEqual( + iterator_to_restore.updated_states, + { + 0: {"val": 0}, + 1: {"val": 1}, + }, + ) + + def test_restore_elastic_iterator_with_multiple_processes(self): + temp_dir = epath.Path(self.create_tempdir().full_path) + # Process 0 + shard_options_0 = sharding.ShardOptions(shard_index=0, shard_count=2) + states = { + 0: {"val": 0}, + 1: {"val": 1}, + 2: {"val": 2}, + } + iterator_0 = MockElasticIterDatasetIterator( + shard_options=shard_options_0, total_num_shards=3, states=states + ) + # In reality save_elastic_iterator will be called in each process, but + # get_state() should return all states, so we only need to call it once + # to create checkpoint files. + elastic_checkpoint.save_elastic_iterator(temp_dir, iterator_0) + + # Check files are written + self.assertTrue((temp_dir / "shard_state_0-of-3.json").exists()) + self.assertTrue((temp_dir / "shard_state_1-of-3.json").exists()) + self.assertTrue((temp_dir / "shard_state_2-of-3.json").exists()) + + # Restore for process 0, responsible for shards 0 and 2. + iterator_to_restore_0 = MockElasticIterDatasetIterator( + shard_options=shard_options_0, total_num_shards=3 + ) + elastic_checkpoint.restore_elastic_iterator(temp_dir, iterator_to_restore_0) + self.assertEqual( + iterator_to_restore_0.updated_states, + { + 0: {"val": 0}, + 2: {"val": 2}, + }, + ) + + # Restore for process 1, responsible for shard 1. + shard_options_1 = sharding.ShardOptions(shard_index=1, shard_count=2) + iterator_to_restore_1 = MockElasticIterDatasetIterator( + shard_options=shard_options_1, total_num_shards=3 + ) + elastic_checkpoint.restore_elastic_iterator(temp_dir, iterator_to_restore_1) + self.assertEqual( + iterator_to_restore_1.updated_states, + { + 1: {"val": 1}, + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/grain/_src/python/checkpoint/handler.py b/grain/_src/python/checkpoint/handler.py index 8eeb2daed..997374389 100644 --- a/grain/_src/python/checkpoint/handler.py +++ b/grain/_src/python/checkpoint/handler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """This module provides a PyGrain CheckpointHandler for integration with Orbax.""" + import dataclasses import json from typing import Any, Optional, TypeVar @@ -19,7 +20,9 @@ from etils import epath from grain._src.core import sharding from grain._src.python import data_loader +from grain._src.python.checkpoint import elastic_checkpoint from grain._src.python.dataset import dataset +from grain._src.python.dataset import elastic_iterator IteratorType = TypeVar( "IteratorType", data_loader.DataLoaderIterator, dataset.DatasetIterator @@ -41,6 +44,9 @@ def save( """Saves the given iterator to the checkpoint in `directory`.""" item = item or args.item # pytype:disable=attribute-error if isinstance(item, dataset.DatasetIterator): + if isinstance(item, elastic_iterator.ElasticIterDatasetIterator): + elastic_checkpoint.save_elastic_iterator(directory, item) + return state = json.dumps(item.get_state(), indent=4) else: state = item.get_state().decode() @@ -56,6 +62,9 @@ def restore( ) -> IteratorType: """Restores the given iterator from the checkpoint in `directory`.""" item = item or args.item # pytype:disable=attribute-error + if isinstance(item, elastic_iterator.ElasticIterDatasetIterator): + elastic_checkpoint.restore_elastic_iterator(directory, item) + return item process_index, process_count = sharding.get_process_index_and_count() filename = directory / f"process_{process_index}-of-{process_count}.json" if not filename.exists(): @@ -106,6 +115,5 @@ class CheckpointSave(ocp.args.CheckpointArgs): class CheckpointRestore(ocp.args.CheckpointArgs): item: Any - except (ImportError, TypeError, AttributeError): pass diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index 8fcd0bbcd..307b6c70d 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -181,6 +181,7 @@ py_test( ":elastic_iterator", "//grain/_src/core:sharding", "//grain/_src/python:options", + "//grain/_src/python/checkpoint:elastic_checkpoint", "//grain/_src/python/testing:experimental", "@abseil-py//absl/testing:absltest", "@abseil-py//absl/testing:parameterized", diff --git a/grain/_src/python/dataset/elastic_iterator.py b/grain/_src/python/dataset/elastic_iterator.py index a7140c172..4146733b5 100644 --- a/grain/_src/python/dataset/elastic_iterator.py +++ b/grain/_src/python/dataset/elastic_iterator.py @@ -13,8 +13,9 @@ # limitations under the License. """Iterator supporting changes in the number of hosts (dataset shards).""" +import copy import functools -from typing import Any +from typing import Any, TypeVar, cast from grain._src.core import sharding from grain._src.python import options @@ -22,53 +23,205 @@ from grain._src.python.dataset.transformations import ( filter as filter_dataset, ) +from grain._src.python.dataset.transformations import interleave + +T = TypeVar("T") _GLOBAL_NEXT_INDEX_STATE_KEY = "global_next_index" -class ElasticIterator(dataset.DatasetIterator): - """Iterator supporting recovery from a checkpoint after changes in sharding. +class ElasticIterDatasetIterator(dataset.DatasetIterator): + """Elastic iterator for InterleaveIterDatasets.""" - The input dataset is expected to be unbatched and unsharded. In order to - provide elasticity guarantee this iterator includes both, batching and - sharding. The iterator supports elastic re-configuration by having each - shard produce the same exact checkpoint (while producing different data) as - long as they are advanced the same number of steps. + def __init__( + self, + interleave_dataset: interleave.InterleaveIterDataset, + shard_options: sharding.ShardOptions, + global_batch_size: int, + drop_remainder: bool, + read_options: options.ReadOptions, + multiprocessing_options: options.MultiprocessingOptions | None = None, + ): + # We must set the slice on the original dataset so that the interleave + # iterator is created with the correct (sliced) datasets. + self._ds: interleave.InterleaveIterDataset = copy.deepcopy( + interleave_dataset + ) + self._num_dataset_shards = len(interleave_dataset._datasets) # pylint: disable=protected-access + self._ds.set_slice( + slice(shard_options.shard_index, None, shard_options.shard_count) + ) + self._num_host_shards = len(self._ds._datasets) # pylint: disable=protected-access + self._cycle_length = self._ds._cycle_length # pylint: disable=protected-access - State of any shard can be used to restore the state of all of the shards after - changes in sharding and global batch size. + self._global_batch_size = global_batch_size + self._drop_remainder = drop_remainder + self._shard_options = shard_options + self._read_options = read_options + self._multiprocessing_options = multiprocessing_options - This iterator explicitly disallows many-to-one transformations without - a fixed ratio, like `filter` and generic `IterDataset` transformations. - """ + # These will be initialized when the iterator is created. + self._iterator_started = False + self._is_batched = False + self._closed = False + + @property + def num_dataset_shards(self) -> int: + return self._num_dataset_shards + + @property + def num_host_shards(self) -> int: + return self._num_host_shards + + @property + def shard_options(self) -> sharding.ShardOptions: + return self._shard_options + + def close(self): + if self._closed: + return + self._closed = True + if "_iterator" in self.__dict__: + self._iterator.close() + + @functools.cached_property + def _iterator(self) -> dataset.DatasetIterator: + ds = self._ds + self._iterator_started = True + if self._global_batch_size > 0: + ds = ds.batch( + self._global_batch_size, drop_remainder=self._drop_remainder + ) + self._is_batched = True + if self._multiprocessing_options: + self._prefetch_wrapped = True + # ds = ds.mp_prefetch(self._multiprocessing_options) + return ds.__iter__() + + def __next__(self) -> Any: + return next(self._iterator) + + def get_state(self): + state = self._iterator.get_state() + ds_iterator_states = {} + + indices = state["iterators_in_use_indices"] + states = state["iterators_in_use_states"] + exhausted = state["exhausted"] + next_index_in_datasets = state["next_index_in_datasets"] + if self._is_batched: + interleave_iter = cast(interleave.InterleaveDatasetIterator, self._iterator._parent) # pylint: disable=protected-access + else: + interleave_iter = cast( + interleave.InterleaveDatasetIterator, self._iterator + ) + for i in range(self._num_host_shards): + shard_index = ( + i * self._shard_options.shard_count + self._shard_options.shard_index + ) + # If the current shard index is greater than or equal to the next + # index in datasets, it means the current shard has not yet started + # to be iterated on. + if i >= next_index_in_datasets: + ds_iterator_states[shard_index] = { + "exhausted": 0, + "state": interleave_iter._get_iterator_start_state(i), # pylint: disable=protected-access + } + elif i not in indices: + # These shards are exhausted but should still create a state to maintain + # static state spec shapes. + ds_iterator_states[shard_index] = { + "exhausted": 1, + "state": interleave_iter._get_iterator_start_state(i), # pylint: disable=protected-access + } + + for index, state, is_exhausted in zip(indices, states, exhausted): + # These shards are currently being iterated on. + shard_index = ( + index * self._shard_options.shard_count + + self._shard_options.shard_index + ) + ds_iterator_states[shard_index] = { + "exhausted": is_exhausted, + "state": state, + } + + return { + "ds_iterator_states": ds_iterator_states, + } + + def set_state(self, state): + """Sets state by reconstructing the state for the underlying interleave.""" + ds_iterator_states = state["ds_iterator_states"] + active_states = [] + + for shard_index, shard_state in sorted(ds_iterator_states.items()): + # Check if this state belongs to the current shard. + if ( + shard_index - self._shard_options.shard_index + ) % self._shard_options.shard_count == 0: + slice_index = shard_index // self._shard_options.shard_count + if not shard_state["exhausted"]: + active_states.append((slice_index, shard_state["state"])) + + iterators_in_use_indices = [] + iterators_in_use_states = [] + exhausted = [] + count = 0 + future_states = {} + for ind, state in active_states: + if count < self._cycle_length: + iterators_in_use_indices.append(ind) + iterators_in_use_states.append(state) + exhausted.append(0) + count += 1 + elif state: + # If a state exists for this iterator add it to future states + future_states[ind] = state + next_index_in_datasets = max(iterators_in_use_indices) + 1 + while count < self._cycle_length: + iterators_in_use_indices.append(next_index_in_datasets) + iterators_in_use_states.append(None) + exhausted.append(1) + count += 1 + + new_state = { + "next_index_in_cycle": 0, + "next_index_in_datasets": next_index_in_datasets, + "iterators_in_use_indices": iterators_in_use_indices, + "iterators_in_use_states": iterators_in_use_states, + "exhausted": exhausted, + "future_states": future_states, + } + if "_iterator" in self.__dict__: + self.__dict__["_iterator"].close() + self.__dict__.pop("_iterator", None) + self._iterator.set_state(new_state) + + +class _ElasticMapDatasetIterator(dataset.DatasetIterator): + """Iterator for MapDatasets in ElasticIterator.""" def __init__( self, ds: dataset.MapDataset, - global_batch_size: int, shard_options: sharding.ShardOptions, - *, + global_batch_size: int, + drop_remainder: bool, read_options: options.ReadOptions = options.ReadOptions(), multiprocessing_options: options.MultiprocessingOptions | None = None, ): - super().__init__() - to_check = [ds] - while to_check: - next_ds = to_check.pop() - if isinstance(next_ds, filter_dataset.FilterMapDataset): - raise ValueError( - "ElasticIterator does not support `filter` transformation." - ) - to_check.extend(next_ds.parents) self._ds = ds - self._global_batch_size = global_batch_size self._shard_options = shard_options - self._global_next_index = 0 + self._global_batch_size = global_batch_size + self._drop_remainder = drop_remainder self._read_options = read_options self._multiprocessing_options = multiprocessing_options + self._global_next_index = 0 + self._closed = False @functools.cached_property - def _iterator(self) -> dataset.DatasetIterator: + def _iterator(self): ds = self._ds[ self._global_next_index + self._shard_options.shard_index :: self._shard_options.shard_count @@ -83,13 +236,10 @@ def _iterator(self) -> dataset.DatasetIterator: ) ds = ds.batch(host_batch_size, drop_remainder=True) ds = ds.to_iter_dataset(read_options=self._read_options) - if self._multiprocessing_options is not None: + if self._multiprocessing_options: ds = ds.mp_prefetch(self._multiprocessing_options) return ds.__iter__() - def __iter__(self) -> dataset.DatasetIterator: - return self - def __next__(self) -> Any: result = next(self._iterator) self._global_next_index += self._global_batch_size @@ -100,7 +250,91 @@ def get_state(self) -> dict[str, Any]: _GLOBAL_NEXT_INDEX_STATE_KEY: self._global_next_index, } - def set_state(self, state: dict[str, Any]): + def close(self): + if self._closed: + return + self._closed = True + if "_iterator" in self.__dict__: + self._iterator.close() + + def set_state(self, state): self._global_next_index = state[_GLOBAL_NEXT_INDEX_STATE_KEY] - # Reset the iterator if it was already created. + if "_iterator" in self.__dict__: + self.__dict__["_iterator"].close() self.__dict__.pop("_iterator", None) + + +class ElasticIterDataset(dataset.IterDataset): + """Iterator supporting recovery from a checkpoint after changes in sharding. + + The input dataset is expected to be unbatched and unsharded. In order to + provide elasticity guarantee this iterator includes both, batching and + sharding. The iterator supports elastic re-configuration by having each + shard produce the same exact checkpoint (while producing different data) as + long as they are advanced the same number of steps. + + State of any shard can be used to restore the state of all of the shards after + changes in sharding and global batch size. + + This iterator explicitly disallows many-to-one transformations without + a fixed ratio, like `filter` and generic `IterDataset` transformations. + """ + + def __init__( + self, + parents: dataset.MapDataset | dataset.IterDataset, + shard_options: sharding.ShardOptions, + *, + read_options: options.ReadOptions = options.ReadOptions(), + multiprocessing_options: options.MultiprocessingOptions | None = None, + drop_remainder: bool = False, + global_batch_size: int = 0, + ): + super().__init__() + self.num_dataset_shards = 0 + self._ds = parents + if isinstance(parents, dataset.IterDataset): + if not isinstance(parents, interleave.InterleaveIterDataset): + raise ValueError( + "ElasticIterator only supports sliceable InterleaveIterDataset" + ) + self.num_dataset_shards = len(parents._datasets) # pylint: disable=protected-access + else: + to_check = [parents] + while to_check: + next_ds = to_check.pop() + if isinstance(next_ds, filter_dataset.FilterMapDataset): + raise ValueError( + "ElasticIterDataset does not support `filter` transformation." + ) + to_check.extend(next_ds.parents) + + self._shard_options = shard_options + self._global_batch_size = global_batch_size + self._drop_remainder = drop_remainder + self._read_options = read_options + self._multiprocessing_options = multiprocessing_options + + @property + def shard_options(self) -> sharding.ShardOptions: + return self._shard_options + + def __iter__(self) -> dataset.DatasetIterator: + if isinstance(self._ds, interleave.InterleaveIterDataset): + return ElasticIterDatasetIterator( + self._ds, + self._shard_options, + self._global_batch_size, + self._drop_remainder, + self._read_options, + self._multiprocessing_options, + ) + else: + return _ElasticMapDatasetIterator( + self._ds, + self._shard_options, + self._global_batch_size, + self._drop_remainder, + self._read_options, + self._multiprocessing_options, + ) diff --git a/grain/_src/python/dataset/elastic_iterator_test.py b/grain/_src/python/dataset/elastic_iterator_test.py index 1c4261f09..d2af41e96 100644 --- a/grain/_src/python/dataset/elastic_iterator_test.py +++ b/grain/_src/python/dataset/elastic_iterator_test.py @@ -11,21 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import platform from absl.testing import absltest from absl.testing import parameterized from grain._src.core import sharding import multiprocessing as mp from grain._src.python import options +from grain._src.python.checkpoint import elastic_checkpoint from grain._src.python.dataset import dataset from grain._src.python.dataset import elastic_iterator +from grain._src.python.dataset.transformations import interleave import grain._src.python.testing.experimental as test_util import numpy as np @absltest.skipIf(platform.system() == "Windows", "Skipped under bazel.") -class ElasticIteratorTest(parameterized.TestCase): +class ElasticMapDataset(parameterized.TestCase): @parameterized.parameters( dict( @@ -58,12 +59,12 @@ def test_produces_correct_elements( ): ds = dataset.MapDataset.range(10).map(lambda x: x + 1) actual = list( - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - global_batch_size, shard_options, + global_batch_size=global_batch_size, multiprocessing_options=multiprocessing_options, - ) + ).__iter__() ) np.testing.assert_equal( actual, expected, err_msg=f"actual: {actual}, expected: {expected}" @@ -71,17 +72,19 @@ def test_produces_correct_elements( def test_checkpointing(self): ds = dataset.MapDataset.range(100).map(lambda x: x * 2).shuffle(42) - it = elastic_iterator.ElasticIterator(ds, 5, sharding.NoSharding()) + it = elastic_iterator.ElasticIterDataset( + ds, sharding.NoSharding(), global_batch_size=5 + ).__iter__() test_util.assert_equal_output_after_checkpoint(it) def test_checkpointing_with_multiprocessing(self): ds = dataset.MapDataset.range(5).map(lambda x: x * 2).shuffle(42) - it = elastic_iterator.ElasticIterator( + it = elastic_iterator.ElasticIterDataset( ds, - 2, sharding.NoSharding(), + global_batch_size=2, multiprocessing_options=options.MultiprocessingOptions(2), - ) + ).__iter__() test_util.assert_equal_output_after_checkpoint(it) def _elastic_resize_test_base( @@ -116,22 +119,22 @@ def test_elastic_downsize(self): # Create iterators over 32 hosts with per-host batch size 2. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 64, sharding.ShardOptions(shard_index=i, shard_count=32), - ) + global_batch_size=64, + ).__iter__() for i in range(32) ] # Create new iterators over 16 hosts with per-host batch size 2. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 32, sharding.ShardOptions(shard_index=i, shard_count=16), - ) + global_batch_size=32, + ).__iter__() for i in range(16) ] @@ -147,28 +150,28 @@ def test_elastic_downsize_with_multiprocessing(self): # Create iterators over 8 hosts with per-host batch size 32. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 256, sharding.ShardOptions(shard_index=i, shard_count=8), + global_batch_size=256, multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(8) ] # Create new iterators over 4 hosts with per-host batch size 32. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 128, sharding.ShardOptions(shard_index=i, shard_count=4), + global_batch_size=128, multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(4) ] @@ -184,22 +187,22 @@ def test_elastic_upsize(self): # Create iterators over 8 hosts with per-host batch size 16. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 128, sharding.ShardOptions(shard_index=i, shard_count=8), - ) + global_batch_size=128, + ).__iter__() for i in range(8) ] # Create new iterators over 64 hosts with per-host batch size 2. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 128, sharding.ShardOptions(shard_index=i, shard_count=64), - ) + global_batch_size=128, + ).__iter__() for i in range(64) ] @@ -215,28 +218,28 @@ def test_elastic_upsize_with_multiprocessing(self): # Create iterators over 4 hosts with per-host batch size 16. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 64, sharding.ShardOptions(shard_index=i, shard_count=4), + global_batch_size=64, multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(4) ] # Create new iterators over 6 hosts with per-host batch size 16. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, - 96, sharding.ShardOptions(shard_index=i, shard_count=6), + global_batch_size=96, multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(6) ] @@ -249,9 +252,106 @@ def test_filter_raises_error(self): ds = ds.filter(lambda x: x % 2 == 0) with self.assertRaisesRegex( ValueError, - "ElasticIterator does not support `filter` transformation.", + "ElasticIterDataset does not support `filter` transformation.", ): - elastic_iterator.ElasticIterator(ds, 5, sharding.NoSharding()) + elastic_iterator.ElasticIterDataset( + ds, sharding.NoSharding(), global_batch_size=5 + ).__iter__() + + +class ElasticIterDataset(parameterized.TestCase): + + @parameterized.parameters( + dict( + shard_options=sharding.NoSharding(), + global_batch_size=1, + expected=list(range(15)), + ), + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=1), + global_batch_size=1, + expected=list(range(15)), + ), + dict( + shard_options=sharding.NoSharding(), + global_batch_size=3, + # Data is interleaved with cycle length 3. + expected=[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]], + ), + ) + def test_no_sharding_produces_correct_elements( + self, shard_options, global_batch_size, expected + ): + ds = [ + # 3 shards, each with 5 elements. + dataset.MapDataset.range(i * 5, (i + 1) * 5).to_iter_dataset() + for i in range(3) + ] + interleave_ds = interleave.InterleaveIterDataset( + ds, cycle_length=global_batch_size + ) + it = elastic_iterator.ElasticIterDataset( + interleave_ds, + shard_options=shard_options, + global_batch_size=global_batch_size, + ).__iter__() + actual = list(it) + self.assertLen(actual, len(expected)) + for actual_batch, expected_batch in zip(actual, expected): + np.testing.assert_equal(actual_batch, expected_batch) + + @parameterized.parameters( + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=2), + global_batch_size=1, + expected=[0, 2, 4, 6, 8], + ), + dict( + shard_options=sharding.ShardOptions(shard_index=1, shard_count=2), + global_batch_size=1, + expected=[1, 3, 5, 7, 9], + ), + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=2), + global_batch_size=2, + expected=[[0, 2], [4, 6], [8]], + ), + ) + def test_sharding_produces_correct_elements( + self, shard_options, global_batch_size, expected + ): + ds = [ + # 4 shards, 0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7] + dataset.MapDataset.range(i, 10, 4).to_iter_dataset() + for i in range(4) + ] + # Use cycle_length=2 as in the original test. + interleave_ds = interleave.InterleaveIterDataset(ds, cycle_length=2) + it = elastic_iterator.ElasticIterDataset( + interleave_ds, + shard_options=shard_options, + global_batch_size=global_batch_size, + ).__iter__() + actual = list(it) + self.assertLen(actual, len(expected)) + for actual_batch, expected_batch in zip(actual, expected): + np.testing.assert_equal(actual_batch, expected_batch) + + def test_checkpointing_no_change(self): + ds = [ + dataset.MapDataset.range(i, 100, 25).to_iter_dataset() + for i in range(25) + ] + global_batch_size = 2 + interleave_ds = interleave.InterleaveIterDataset( + ds, cycle_length=global_batch_size + ) + it = elastic_iterator.ElasticIterDataset( + interleave_ds, + shard_options=sharding.ShardOptions(shard_index=2, shard_count=4), + global_batch_size=global_batch_size, + ).__iter__() + test_util.assert_equal_output_after_checkpoint(it) if __name__ == "__main__": diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 9a387ae07..ac7827d1a 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -33,7 +33,10 @@ class InterleaveDatasetIterator(dataset.DatasetIterator[T]): def __init__( self, - datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]], + datasets: ( + Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]] + | Sequence[dataset.DatasetIterator[T]] + ), cycle_length: int, num_make_iter_threads: int = 1, make_iter_buffer_size: int = 1, @@ -52,11 +55,11 @@ def __init__( functools.partial( _add_prefetch_and_make_iterator, # We use weakref to avoid a circular reference. The - # _InterleaveDatasetIterator holds a reference to the + # InterleaveDatasetIterator holds a reference to the # prefetch iterator in `self._prefetch_ds_iter`. # The call to `_add_prefetch_and_make_iterator` (and the # partial object) would hold a reference to the - # _InterleaveDatasetIterator. This would prolong its lifetime + # InterleaveDatasetIterator. This would prolong its lifetime # leading to increased resource usage. interleave_iterator=weakref.ref(self), start_prefetch=True, @@ -86,6 +89,8 @@ def __init__( self._exhausted_iterators: list[ tuple[int, dataset.DatasetIterator[T]] | None ] = [None] * self._cycle_length + # Future states used for elastic iterators + self._future_states: dict[int, Any] = {} @stats.record_next_duration_if_output @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING) @@ -142,6 +147,16 @@ def __next__(self) -> T: self._iterators_in_use_indices[self._next_index_in_cycle] = ( self._next_index_in_datasets ) + # For elastic iterators, we might have a future state saved for this + # dataset iterator from which to resume from. + if ( + self._next_index_in_datasets in self._future_states + and self._iterators_in_use[self._next_index_in_cycle] + ): + future_state = self._future_states.pop(self._next_index_in_datasets) + self._iterators_in_use[self._next_index_in_cycle].set_state( + future_state + ) self._next_index_in_datasets += 1 elif not any(self._iterators_in_use): raise StopIteration @@ -182,13 +197,15 @@ def get_state(self): int(self._exhausted_iterator_state[i] is not None) for i in range(self._cycle_length) ] - return { + state = { "next_index_in_cycle": self._next_index_in_cycle, "next_index_in_datasets": self._next_index_in_datasets, "iterators_in_use_indices": self._iterators_in_use_indices.copy(), "iterators_in_use_states": iterators_in_use_states, "exhausted": exhausted, + "future_states": self._future_states, } + return state def set_state(self, state): exhausted = state["exhausted"] @@ -220,7 +237,9 @@ def set_state(self, state): interleave_iterator=weakref.ref(self), start_prefetch=False, ) - iterator.set_state(it_state) + # Only update the iterator state if it is given + if it_state: + iterator.set_state(it_state) self._iterators_in_use[index_in_cycle] = iterator else: self._exhausted_iterator_state[index_in_cycle] = it_state @@ -232,6 +251,7 @@ def set_state(self, state): self._next_index_in_cycle = state["next_index_in_cycle"] self._next_index_in_datasets = state["next_index_in_datasets"] self._iterators_in_use_indices = state["iterators_in_use_indices"] + self._future_states = state.get("future_states", {}) def _get_next_index(self) -> int: if len(self._datasets) == 1: @@ -307,6 +327,16 @@ def __str__(self) -> str: f" cycle_length={self._cycle_length})" ) + def _get_iterator_start_state(self, index: int) -> dict[str, Any]: + it = _add_prefetch_and_make_iterator( + self._datasets[index], + weakref.ref(self), + start_prefetch=False, + ) + state = it.get_state() + del it + return state + def _add_prefetch_and_make_iterator( ds: dataset.IterDataset[T] | dataset.MapDataset[T], diff --git a/grain/experimental.py b/grain/experimental.py index 297c40b56..7537b4aa8 100644 --- a/grain/experimental.py +++ b/grain/experimental.py @@ -32,7 +32,10 @@ apply_transformations, WithOptionsIterDataset, ) -from grain._src.python.dataset.elastic_iterator import ElasticIterator +from grain._src.python.dataset.elastic_iterator import ( + ElasticIterDatasetIterator, + ElasticIterDataset, +) from grain._src.python.dataset.sources.parquet_dataset import ParquetIterDataset from grain._src.python.dataset.sources.tfrecord_dataset import TFRecordIterDataset