From 8d3ef9e5c89884b02ca97c5657f07e54199f65b1 Mon Sep 17 00:00:00 2001 From: Ihor Indyk Date: Fri, 20 Mar 2026 07:45:18 -0700 Subject: [PATCH] Add safeguard tests for cross-version determinism. While we do not technically guarantee it, many users rely on it in practice, so we should try hard not to break it. PiperOrigin-RevId: 886785152 --- .../_src/python/dataset/transformations/map_test.py | 13 +++++++++++++ .../python/dataset/transformations/shuffle_test.py | 10 ++++++++++ 2 files changed, 23 insertions(+) diff --git a/grain/_src/python/dataset/transformations/map_test.py b/grain/_src/python/dataset/transformations/map_test.py index fa0c3998e..fa2cdff77 100644 --- a/grain/_src/python/dataset/transformations/map_test.py +++ b/grain/_src/python/dataset/transformations/map_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for map transformation.""" + import dataclasses import operator from typing import Any @@ -316,6 +317,18 @@ def test_random_map_element_spec_inference_raises_error(self): with self.assertRaisesRegex(ValueError, "does not implement `output_spec`"): _ = ds._element_spec + def test_cross_version_determinism(self): + # This test validates random map determinism across different versions of + # Grain given a fixed seed. Note that we technically do not guarantee + # cross-version determinism because numpy does not. Any changes to the numpy + # RNGs or how we use them could break this. Multiple users nevertheless rely + # on it because it holds in practice. Only update the values if you know + # what you're doing. + ds = dataset.MapDataset.range(10).seed(41) + ds = map_ds.RandomMapMapDataset(ds, RandomMapWithDeterminismTransform()) + ds = ds.map(lambda x: x.item()) + self.assertEqual(list(ds), [0, 6, 9, 5, 10, 14, 14, 13, 15, 13]) + class MapIterDatasetTest(parameterized.TestCase): diff --git a/grain/_src/python/dataset/transformations/shuffle_test.py b/grain/_src/python/dataset/transformations/shuffle_test.py index 8a459ea55..13030a578 100644 --- a/grain/_src/python/dataset/transformations/shuffle_test.py +++ b/grain/_src/python/dataset/transformations/shuffle_test.py @@ -75,6 +75,16 @@ def test_element_spec(self): self.assertEqual(spec.dtype, np.int64) self.assertEqual(spec.shape, ()) + def test_cross_version_determinism(self): + # This test validates shuffle determinism across different versions of + # Grain given a fixed seed. Note that we technically do not guarantee + # cross-version determinism, but multiple users nevertheless rely on it + # because it holds in practice. Any updates to the shuffle code could break + # it. Only update the values if you know what you're doing. + ds = dataset.MapDataset.range(10).seed(42) + ds = shuffle.ShuffleMapDataset(ds) + self.assertEqual(list(ds), [1, 7, 6, 9, 0, 8, 4, 5, 3, 2]) + class WindowShuffleMapDatasetTest(absltest.TestCase):