Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions grain/_src/python/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,12 @@ def test_execution_summary_with_no_logging(self):

class GetElementSpecTest(parameterized.TestCase):

def test_get_element_spec_from_source(self):
ds = dataset.MapDataset.source(range(10))
spec = dataset.get_element_spec(ds)
self.assertEqual(spec.shape, ())
self.assertEqual(spec.dtype, np.int64)

def test_get_element_spec_from_map_dataset(self):
ds = dataset.MapDataset.range(10)
spec = dataset.get_element_spec(ds)
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ py_test(
srcs_version = "PY3",
deps = [
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"@abseil-py//absl/testing:absltest",
"@pypi//numpy:pkg",
],
Expand Down
18 changes: 18 additions & 0 deletions grain/_src/python/dataset/transformations/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from __future__ import annotations

from collections.abc import Mapping
import contextlib # pylint: disable=unused-import
import functools
import time
Expand Down Expand Up @@ -132,6 +133,23 @@ def paths(self) -> str | Sequence[str]:
else:
return []

@functools.cached_property
def _element_spec(self) -> Any:
if self._source is None or len(self._source) == 0:
return base.ShapeDtypeStruct(shape=(), dtype=np.float32)
first_element = self._source[0]

def _spec(x) -> base.ShapeDtypeStruct:
try:
arr = np.asarray(x)
return base.ShapeDtypeStruct(arr.shape, arr.dtype)
except Exception as e:
raise TypeError(f"Cannot infer element spec for leaf {x}.") from e

return tree_lib.map_structure(
_spec, first_element, is_leaf=lambda x: not isinstance(x, Mapping)
)


def log_lineage_for_sources(
root: Union[dataset.MapDataset, dataset.IterDataset],
Expand Down
50 changes: 50 additions & 0 deletions grain/_src/python/dataset/transformations/source_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from unittest import mock

from absl.testing import absltest
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import source
import numpy as np
Expand Down Expand Up @@ -161,6 +162,55 @@ def test_set_slice(self):
)
self.assertEqual(list(ds), [14, 15, 16, 17, 18, 19, 20])

def test_element_spec_list(self):
ds = source.SourceMapDataset([1, 2, 3])
element_spec = dataset.get_element_spec(ds)
self.assertEqual(element_spec.shape, ())
self.assertEqual(element_spec.dtype, np.int64)

def test_element_spec_2d_list(self):
ds = source.SourceMapDataset(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
)
element_spec = dataset.get_element_spec(ds)
self.assertEqual(element_spec.shape, (2, 3))
self.assertEqual(element_spec.dtype, np.int64)

def test_element_spec_np_array(self):
ds = source.SourceMapDataset(
np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
)
element_spec = dataset.get_element_spec(ds)
self.assertEqual(element_spec.shape, (2, 3))
self.assertEqual(element_spec.dtype, np.int64)

def test_element_spec_batch(self):
ds = source.SourceMapDataset(range(10))
ds = ds.batch(3, drop_remainder=True)
element_spec = dataset.get_element_spec(ds)
self.assertEqual(element_spec.shape, (3,))
self.assertEqual(element_spec.dtype, np.int64)

def test_element_spec_dict(self):
ds = source.SourceMapDataset(
[{"a": 1, "b": "Hello, world!", "c": [[1, 2], [3, 4]]}]
)
element_spec = dataset.get_element_spec(ds)
self.assertEqual(
element_spec,
{
"a": base.ShapeDtypeStruct((), np.int64),
"b": base.ShapeDtypeStruct((), np.dtype("<U13")),
"c": base.ShapeDtypeStruct((2, 2), np.int64),
},
)

def test_element_spec_empty(self):
ds = source.SourceMapDataset([])
element_spec = dataset.get_element_spec(ds)
self.assertEqual(element_spec.shape, ())
self.assertEqual(element_spec.dtype, np.float32)


class RangeMapDatasetTest(absltest.TestCase):

Expand Down
Loading