diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4550ad39b..e3b3fcccc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
     from a checkpoint with Orbax.
   * Exposes `SharedMemoryArrayMetadata` in a public API as a metadata
     descriptor for `SharedMemoryArray`.
+  * `ParquetIterDataset` can read from multiple file paths, interleaving reads.
 
 * Breaking changes:
   * Custom implementations of `RandomAccessDataSource` should accept `int`
diff --git a/grain/_src/python/dataset/sources/parquet_dataset.py b/grain/_src/python/dataset/sources/parquet_dataset.py
index d211fe29e..4d3b565b5 100644
--- a/grain/_src/python/dataset/sources/parquet_dataset.py
+++ b/grain/_src/python/dataset/sources/parquet_dataset.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 """Provides an `IterDataset` for Parquet file format."""
 
-from typing import TypeVar
+from typing import Sequence, TypeVar
 
 from etils import epy
 from grain._src.python.dataset import dataset
+from grain._src.python.dataset.transformations import interleave
 
 
 # lazy import for pyarrow
@@ -26,6 +27,9 @@
 
 T = TypeVar("T")
 
+ParquetDataSourcePath = str | Sequence[str]
+_CYCLE_LENGTH = 16
+
 
 class _ParquetDatasetIterator(dataset.DatasetIterator[T]):
   """A DatasetIterator for Parquet file format."""
@@ -84,16 +88,37 @@ def set_state(self, state):
 class ParquetIterDataset(dataset.IterDataset[T]):
-  """An IterDataset for a parquet format file."""
+  """An IterDataset for one or more Parquet format files."""
 
-  def __init__(self, path: str, **read_kwargs):
+  def __init__(
+      self,
+      path: ParquetDataSourcePath,
+      **read_kwargs,
+  ):
     """Initializes ParquetIterDataset.
 
     Args:
-      path: A path to a parquet format file.
+      path: A path or a sequence of paths to Parquet format files. If multiple
+        paths are given, reads are interleaved with at most 16 files open in
+        parallel.
       **read_kwargs: Keyword arguments to pass to pyarrow.parquet.ParquetFile.
""" super().__init__() - self._path = path + if isinstance(path, (str, bytes)): + self._paths = [path] + else: + self._paths = list(path) self._read_kwargs = read_kwargs - def __iter__(self) -> _ParquetDatasetIterator[T]: - return _ParquetDatasetIterator(self._path, **self._read_kwargs) + def __iter__(self) -> dataset.DatasetIterator[T]: + if len(self._paths) == 1: + return _ParquetDatasetIterator(self._paths[0], **self._read_kwargs) + + datasets = [ParquetIterDataset(p, **self._read_kwargs) for p in self._paths] + delegate = interleave.InterleaveIterDataset( + datasets, cycle_length=_CYCLE_LENGTH + ) + return delegate.__iter__() + + def set_slice(self, sl: slice, sequential_slice: bool = False): + del sequential_slice + self._paths = self._paths[sl] diff --git a/grain/_src/python/dataset/sources/parquet_dataset_test.py b/grain/_src/python/dataset/sources/parquet_dataset_test.py index 7f70cbcb0..8dc40d456 100644 --- a/grain/_src/python/dataset/sources/parquet_dataset_test.py +++ b/grain/_src/python/dataset/sources/parquet_dataset_test.py @@ -82,10 +82,44 @@ def test_read_row_group(self): records = list(dataset) self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[0]]) + def test_read_multiple_files(self): + dataset = parquet_dataset.ParquetIterDataset(self.filenames) + records = list(dataset) + self.assertSequenceEqual(records, [{"text": x} for x in INTERLEAVED_TEXT]) + + def test_sharding_multi_file(self): + # Test sharding across files + dataset = parquet_dataset.ParquetIterDataset(self.filenames) + dataset.set_slice(slice(0, 1)) # Only first file + records = list(dataset) + self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[0]]) + + dataset = parquet_dataset.ParquetIterDataset(self.filenames) + dataset.set_slice(slice(1, 2)) # Only second file + records = list(dataset) + self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[1]]) + + def test_checkpointing_multi_file(self): + dataset = parquet_dataset.ParquetIterDataset(self.filenames) + grain.experimental.assert_equal_output_after_checkpoint(dataset) + def test_checkpointing(self): dataset = parquet_dataset.ParquetIterDataset(self.filenames[0]) grain.experimental.assert_equal_output_after_checkpoint(dataset) + def test_set_slice(self): + # Test slice first file + dataset = parquet_dataset.ParquetIterDataset(self.filenames) + dataset.set_slice(slice(0, 1)) + records = list(dataset) + self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[0]]) + + # Test slice second file + dataset = parquet_dataset.ParquetIterDataset(self.filenames) + dataset.set_slice(slice(1, 2)) + records = list(dataset) + self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[1]]) + def test_sharded_files_and_interleaved_dataset(self): dataset = grain.MapDataset.source(self.filenames) dataset = dataset.map(parquet_dataset.ParquetIterDataset)