From 24cd5ff2c8109f676f34f173511f0cd7cf120011 Mon Sep 17 00:00:00 2001
From: Yang Chen
Date: Wed, 18 Mar 2026 11:20:33 -0700
Subject: [PATCH] Try implementing element_spec for `SourceMapDataset`.

PiperOrigin-RevId: 885696249
---
 grain/_src/python/dataset/dataset_test.py     |  6 +++
 .../_src/python/dataset/transformations/BUILD |  1 +
 .../python/dataset/transformations/source.py  | 32 ++++++++++++
 .../dataset/transformations/source_test.py    | 50 +++++++++++++++++++
 4 files changed, 89 insertions(+)

diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py
index fb76f2078..695ea780e 100644
--- a/grain/_src/python/dataset/dataset_test.py
+++ b/grain/_src/python/dataset/dataset_test.py
@@ -1414,6 +1414,12 @@ def test_execution_summary_with_no_logging(self):
 
 class GetElementSpecTest(parameterized.TestCase):
 
+  def test_get_element_spec_from_source(self):
+    ds = dataset.MapDataset.source(range(10))
+    spec = dataset.get_element_spec(ds)
+    self.assertEqual(spec.shape, ())
+    self.assertEqual(spec.dtype, np.int64)
+
   def test_get_element_spec_from_map_dataset(self):
     ds = dataset.MapDataset.range(10)
     spec = dataset.get_element_spec(ds)
diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD
index aa832cc05..b8944fd58 100644
--- a/grain/_src/python/dataset/transformations/BUILD
+++ b/grain/_src/python/dataset/transformations/BUILD
@@ -82,6 +82,7 @@ py_test(
     srcs_version = "PY3",
     deps = [
         "//grain/_src/python/dataset",
+        "//grain/_src/python/dataset:base",
         "@abseil-py//absl/testing:absltest",
         "@pypi//numpy:pkg",
     ],
diff --git a/grain/_src/python/dataset/transformations/source.py b/grain/_src/python/dataset/transformations/source.py
index 4ba0b3ede..b0e271226 100644
--- a/grain/_src/python/dataset/transformations/source.py
+++ b/grain/_src/python/dataset/transformations/source.py
@@ -15,6 +15,7 @@
 
 from __future__ import annotations
 
+from collections.abc import Mapping
 import contextlib  # pylint: disable=unused-import
 import functools
 import time
@@ -132,6 +133,37 @@ def paths(self) -> str | Sequence[str]:
     else:
       return []
 
+  @functools.cached_property
+  def _element_spec(self) -> Any:
+    """Infers the element spec from the first element of the source.
+
+    Returns:
+      A scalar float32 `ShapeDtypeStruct` when the source is `None` or empty.
+      Otherwise the first element with each leaf replaced by a
+      `ShapeDtypeStruct` carrying the shape and dtype that `np.asarray`
+      gives it. Every non-`Mapping` value is marked as a leaf via `is_leaf`.
+
+    Raises:
+      TypeError: If a leaf cannot be turned into a `ShapeDtypeStruct`.
+    """
+    if self._source is None or len(self._source) == 0:
+      # NOTE(review): placeholder spec, nothing to inspect -- confirm default.
+      return base.ShapeDtypeStruct(shape=(), dtype=np.float32)
+    first_element = self._source[0]
+
+    def _spec(x) -> base.ShapeDtypeStruct:
+      try:
+        arr = np.asarray(x)
+        return base.ShapeDtypeStruct(arr.shape, arr.dtype)
+      except Exception as e:
+        raise TypeError(f"Cannot infer element spec for leaf {x}.") from e
+
+    # Non-Mapping values (lists included) are leaves, so a nested list is
+    # converted to a single array rather than traversed.
+    return tree_lib.map_structure(
+        _spec, first_element, is_leaf=lambda x: not isinstance(x, Mapping)
+    )
+
 
 def log_lineage_for_sources(
     root: Union[dataset.MapDataset, dataset.IterDataset],
diff --git a/grain/_src/python/dataset/transformations/source_test.py b/grain/_src/python/dataset/transformations/source_test.py
index bedbba3ee..94e45db49 100644
--- a/grain/_src/python/dataset/transformations/source_test.py
+++ b/grain/_src/python/dataset/transformations/source_test.py
@@ -18,6 +18,7 @@
 from unittest import mock
 
 from absl.testing import absltest
+from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset.transformations import source
 import numpy as np
@@ -161,6 +162,55 @@ def test_set_slice(self):
     )
     self.assertEqual(list(ds), [14, 15, 16, 17, 18, 19, 20])
 
+  def test_element_spec_list(self):
+    ds = source.SourceMapDataset([1, 2, 3])
+    element_spec = dataset.get_element_spec(ds)
+    self.assertEqual(element_spec.shape, ())
+    self.assertEqual(element_spec.dtype, np.int64)
+
+  def test_element_spec_2d_list(self):
+    ds = source.SourceMapDataset(
+        [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
+    )
+    element_spec = dataset.get_element_spec(ds)
+    self.assertEqual(element_spec.shape, (2, 3))
+    self.assertEqual(element_spec.dtype, np.int64)
+
+  def test_element_spec_np_array(self):
+    ds = source.SourceMapDataset(
+        np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
+    )
+    element_spec = dataset.get_element_spec(ds)
+    self.assertEqual(element_spec.shape, (2, 3))
+    self.assertEqual(element_spec.dtype, np.int64)
+
+  def test_element_spec_batch(self):
+    ds = source.SourceMapDataset(range(10))
+    ds = ds.batch(3, drop_remainder=True)
+    element_spec = dataset.get_element_spec(ds)
+    self.assertEqual(element_spec.shape, (3,))
+    self.assertEqual(element_spec.dtype, np.int64)
+
+  def test_element_spec_dict(self):
+    ds = source.SourceMapDataset(
+        [{"a": 1, "b": "Hello, world!", "c": [[1, 2], [3, 4]]}]
+    )
+    element_spec = dataset.get_element_spec(ds)
+    self.assertEqual(
+        element_spec,
+        {
+            "a": base.ShapeDtypeStruct((), np.int64),
+            "b": base.ShapeDtypeStruct((), np.dtype("