Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
from a checkpoint with Orbax.
* Exposes `SharedMemoryArrayMetadata` in a public API as a metadata descriptor
for `SharedMemoryArray`.
* `ParquetIterDataset` can read from multiple file paths, interleaving reads across files.

* Breaking changes:
* Custom implementations of `RandomAccessDataSource` should accept `int`
Expand Down
36 changes: 30 additions & 6 deletions grain/_src/python/dataset/sources/parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.
"""Provides an `IterDataset` for Parquet file format."""

from typing import TypeVar
from typing import Sequence, TypeVar

from etils import epy
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave


# lazy import for pyarrow
Expand All @@ -26,6 +27,9 @@

# Element type produced by the dataset.
T = TypeVar("T")

# A single Parquet file path, or a sequence of paths to interleave.
ParquetDataSourcePath = str | Sequence[str]
# Maximum number of files read in parallel when interleaving multiple paths.
_CYCLE_LENGTH = 16


class _ParquetDatasetIterator(dataset.DatasetIterator[T]):
"""A DatasetIterator for Parquet file format."""
Expand Down Expand Up @@ -84,16 +88,36 @@ def set_state(self, state):
class ParquetIterDataset(dataset.IterDataset[T]):
  """An `IterDataset` reading records from one or more Parquet files.

  With a single path, records are read sequentially from that file. With
  multiple paths, per-file datasets are interleaved with at most
  `_CYCLE_LENGTH` files read in parallel.
  """

  def __init__(
      self,
      path: ParquetDataSourcePath,
      **read_kwargs,
  ):
    """Initializes ParquetIterDataset.

    Args:
      path: A path or sequence of paths to parquet format files. If multiple
        paths are provided, they are interleaved with at most 16 files read in
        parallel.
      **read_kwargs: Keyword arguments to pass to pyarrow.parquet.ParquetFile.
    """
    super().__init__()
    # Normalize a single `str` (or `bytes`) path to a one-element list so the
    # rest of the class handles the single- and multi-file cases uniformly.
    # (A bare string is also a `Sequence[str]`, so it must be checked first.)
    if isinstance(path, (str, bytes)):
      self._paths = [path]
    else:
      self._paths = list(path)
    self._read_kwargs = read_kwargs

  def __iter__(self) -> dataset.DatasetIterator[T]:
    """Returns an iterator over the records of all configured files."""
    if len(self._paths) == 1:
      # Single-file fast path: no interleaving machinery needed.
      return _ParquetDatasetIterator(self._paths[0], **self._read_kwargs)

    # Multi-file path: wrap each file in its own dataset and interleave them.
    per_file_datasets = [
        ParquetIterDataset(p, **self._read_kwargs) for p in self._paths
    ]
    delegate = interleave.InterleaveIterDataset(
        per_file_datasets, cycle_length=_CYCLE_LENGTH
    )
    return delegate.__iter__()

  def set_slice(self, sl: slice, sequential_slice: bool = False):
    """Restricts this dataset to a slice of its file paths (file sharding).

    Args:
      sl: Slice selecting which of the configured paths to keep.
      sequential_slice: Ignored; slicing over files is always positional.
    """
    del sequential_slice
    self._paths = self._paths[sl]
34 changes: 34 additions & 0 deletions grain/_src/python/dataset/sources/parquet_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,44 @@ def test_read_row_group(self):
records = list(dataset)
self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[0]])

def test_read_multiple_files(self):
  """Reading from several files yields the interleaved records."""
  ds = parquet_dataset.ParquetIterDataset(self.filenames)
  actual = [record for record in ds]
  expected = [{"text": text} for text in INTERLEAVED_TEXT]
  self.assertSequenceEqual(actual, expected)

def test_sharding_multi_file(self):
  """`set_slice` restricts reads to the selected shard of files."""
  # NOTE(review): this duplicates `test_set_slice` — consider merging them.
  for index, expected_texts in enumerate(SOME_TEXT[:2]):
    ds = parquet_dataset.ParquetIterDataset(self.filenames)
    ds.set_slice(slice(index, index + 1))  # Keep only the index-th file.
    self.assertSequenceEqual(
        list(ds), [{"text": text} for text in expected_texts]
    )

def test_checkpointing_multi_file(self):
  """Iteration state round-trips through a checkpoint for multi-file input."""
  ds = parquet_dataset.ParquetIterDataset(self.filenames)
  grain.experimental.assert_equal_output_after_checkpoint(ds)

def test_checkpointing(self):
  """Iteration state round-trips through a checkpoint for a single file."""
  ds = parquet_dataset.ParquetIterDataset(self.filenames[0])
  grain.experimental.assert_equal_output_after_checkpoint(ds)

def test_set_slice(self):
  """Slicing the path list selects exactly the corresponding file's records."""
  # NOTE(review): this duplicates `test_sharding_multi_file` — consider
  # merging them.
  for index, expected_texts in enumerate(SOME_TEXT[:2]):
    ds = parquet_dataset.ParquetIterDataset(self.filenames)
    ds.set_slice(slice(index, index + 1))  # Keep only the index-th file.
    self.assertSequenceEqual(
        list(ds), [{"text": text} for text in expected_texts]
    )

def test_sharded_files_and_interleaved_dataset(self):
dataset = grain.MapDataset.source(self.filenames)
dataset = dataset.map(parquet_dataset.ParquetIterDataset)
Expand Down
Loading