Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions grain/_src/python/dataset/transformations/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for map transformation."""

import dataclasses
import operator
from typing import Any
Expand Down Expand Up @@ -316,6 +317,18 @@ def test_random_map_element_spec_inference_raises_error(self):
with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"):
_ = ds._element_spec

def test_cross_version_determinism(self):
# This test validates random map determinism across different versions of
# Grain given a fixed seed. Note that we technically do not guarantee
# cross-version determinism because numpy does not. Any changes to the numpy
# RNGs or how we use them could break this. Multiple users nevertheless rely
# on it because it holds in practice. Only update the values if you know
# what you're doing.
ds = dataset.MapDataset.range(10).seed(41)
ds = map_ds.RandomMapMapDataset(ds, RandomMapWithDeterminismTransform())
ds = ds.map(lambda x: x.item())
self.assertEqual(list(ds), [0, 6, 9, 5, 10, 14, 14, 13, 15, 13])


class MapIterDatasetTest(parameterized.TestCase):

Expand Down
10 changes: 10 additions & 0 deletions grain/_src/python/dataset/transformations/shuffle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ def test_element_spec(self):
self.assertEqual(spec.dtype, np.int64)
self.assertEqual(spec.shape, ())

def test_cross_version_determinism(self):
# This test validates shuffle determinism across different versions of
# Grain given a fixed seed. Note that we technically do not guarantee
# cross-version determinism, but multiple users nevertheless rely on it
# because it holds in practice. Any updates to the shuffle code could break
# it. Only update the values if you know what you're doing.
ds = dataset.MapDataset.range(10).seed(42)
ds = shuffle.ShuffleMapDataset(ds)
self.assertEqual(list(ds), [1, 7, 6, 9, 0, 8, 4, 5, 3, 2])


class WindowShuffleMapDatasetTest(absltest.TestCase):

Expand Down
Loading