diff --git a/grain/_src/python/checkpoint/handler_test.py b/grain/_src/python/checkpoint/handler_test.py deleted file mode 100644 index 0f322ba5d..000000000 --- a/grain/_src/python/checkpoint/handler_test.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for checkpoint handlers.""" - -from grain._src.core import sharding -from grain._src.python import data_loader -from grain._src.python import data_sources -from grain._src.python import samplers -from grain._src.python.checkpoint import handler -from grain._src.python.dataset import dataset -import orbax.checkpoint as ocp - -from absl.testing import absltest - - -class CheckpointHandlersTest(absltest.TestCase): - - def _create_data_loader(self) -> data_loader.DataLoader: - # Generates elements [0, 1, 2, 3, 4, 5, 6, 7]. - range_data_source = data_sources.RangeDataSource(0, 8, 1) - sampler = samplers.SequentialSampler( - num_records=len(range_data_source), - shard_options=sharding.NoSharding(), - ) - return data_loader.DataLoader( - data_source=range_data_source, - sampler=sampler, - ) - - def _create_data_loader_iter_to_checkpoint(self): - ds = self._create_data_loader() - break_at = 4 - ds_iter = iter(ds) - for _ in range(break_at): - _ = next(ds_iter) - return ds_iter - - def _assert_restored_data_loader_iter(self, ds_iter): - expected_data = list(self._create_data_loader_iter_to_checkpoint()) - self.assertEqual(list(ds_iter), expected_data) - - def _create_dataset(self): - return ( - dataset.MapDataset.range(35) - .seed(23) - .map(lambda x: x + 100) - .shuffle() - .to_iter_dataset() - .batch(3) - .map(lambda x: x.tolist()) - ) - - def _create_dataset_iter_to_checkpoint(self): - ds = self._create_dataset() - break_at = 5 - ds_iter = iter(ds) - for _ in range(break_at): - _ = next(ds_iter) - return ds_iter - - def _assert_restored_dataset_iter(self, ds_iter): - expected_data = list(self._create_dataset_iter_to_checkpoint()) - self.assertEqual(list(ds_iter), expected_data) - - def test_data_loader_checkpoint_save_and_restore(self): - tmpdir = f"{self.create_tempdir().full_path}/checkpoint" - mngr = ocp.CheckpointManager(tmpdir) - mngr.save( - 0, - args=handler.CheckpointSave( - self._create_data_loader_iter_to_checkpoint() - ), - ) - mngr.wait_until_finished() - ds = self._create_data_loader() - ds_iter = iter(ds) - mngr.restore(0, args=handler.CheckpointRestore(ds_iter)) - self._assert_restored_data_loader_iter(ds_iter) - - def test_dataset_checkpoint_save_and_restore(self): - tmpdir = f"{self.create_tempdir().full_path}/checkpoint" - mngr = ocp.CheckpointManager(tmpdir) - mngr.save( - 0, - args=handler.CheckpointSave(self._create_dataset_iter_to_checkpoint()), - ) - mngr.wait_until_finished() - ds = self._create_dataset() - ds_iter = iter(ds) - mngr.restore(0, args=handler.CheckpointRestore(ds_iter)) - self._assert_restored_dataset_iter(ds_iter) - - def test_composite_checkpoint_save_and_restore(self): - tmpdir = f"{self.create_tempdir().full_path}/checkpoint" - mngr = ocp.CheckpointManager(tmpdir, item_names=("state", "ds_iter")) - mngr.save( - 0, - args=ocp.args.Composite( - state=ocp.args.StandardSave({"values": [0]}), - ds_iter=handler.CheckpointSave( - self._create_data_loader_iter_to_checkpoint() - ), - ), - ) - mngr.wait_until_finished() - ds = self._create_data_loader() - ds_iter = iter(ds) - mngr.restore( - 0, - args=ocp.args.Composite( - state=ocp.args.StandardRestore({"values": [0]}), - ds_iter=handler.CheckpointRestore(ds_iter), - ), - ) - self._assert_restored_data_loader_iter(ds_iter) - - -if __name__ == "__main__": - absltest.main() diff --git a/grain/_src/python/checkpoint/orbax_integration_test.py b/grain/_src/python/checkpoint/orbax_integration_test.py deleted file mode 100644 index b85b4fb26..000000000 --- a/grain/_src/python/checkpoint/orbax_integration_test.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for checkpoint handlers.""" -from grain._src.core import sharding -from grain._src.python import data_loader -from grain._src.python import data_sources -from grain._src.python import samplers -from grain._src.python.dataset import dataset -from orbax.checkpoint import v1 as ocp - -from absl.testing import absltest - - -class CheckpointHandlersTest(absltest.TestCase): - - def _create_data_loader(self) -> data_loader.DataLoader: - # Generates elements [0, 1, 2, 3, 4, 5, 6, 7]. - range_data_source = data_sources.RangeDataSource(0, 8, 1) - sampler = samplers.SequentialSampler( - num_records=len(range_data_source), - shard_options=sharding.NoSharding(), - ) - return data_loader.DataLoader( - data_source=range_data_source, - sampler=sampler, - ) - - def _create_data_loader_iter_to_checkpoint(self): - ds = self._create_data_loader() - break_at = 4 - ds_iter = iter(ds) - for _ in range(break_at): - _ = next(ds_iter) - return ds_iter - - def _assert_restored_data_loader_iter(self, ds_iter): - expected_data = list(self._create_data_loader_iter_to_checkpoint()) - self.assertEqual(list(ds_iter), expected_data) - - def _create_dataset(self): - return ( - dataset.MapDataset.range(35) - .seed(23) - .map(lambda x: x + 100) - .shuffle() - .to_iter_dataset() - .batch(3) - .map(lambda x: x.tolist()) - ) - - def _create_dataset_iter_to_checkpoint(self): - ds = self._create_dataset() - break_at = 5 - ds_iter = iter(ds) - for _ in range(break_at): - _ = next(ds_iter) - return ds_iter - - def _assert_restored_dataset_iter(self, ds_iter): - expected_data = list(self._create_dataset_iter_to_checkpoint()) - self.assertEqual(list(ds_iter), expected_data) - - def test_data_loader_checkpoint_save_and_restore(self): - tmpdir = f"{self.create_tempdir().full_path}/checkpoint" - ocp.save_checkpointables( - tmpdir, dict(dataset=self._create_data_loader_iter_to_checkpoint()) - ) - ds = self._create_data_loader() - ds_iter = iter(ds) - ocp.load_checkpointables(tmpdir, dict(dataset=ds_iter)) - self._assert_restored_data_loader_iter(ds_iter) - - def test_dataset_checkpoint_save_and_restore(self): - tmpdir = f"{self.create_tempdir().full_path}/checkpoint" - ocp.save_checkpointables( - tmpdir, dict(dataset=self._create_dataset_iter_to_checkpoint()) - ) - ds = self._create_dataset() - ds_iter = iter(ds) - ocp.load_checkpointables(tmpdir, dict(dataset=ds_iter)) - self._assert_restored_dataset_iter(ds_iter) - - def test_composite_checkpoint_save_and_restore(self): - tmpdir = f"{self.create_tempdir().full_path}/checkpoint" - ocp.save_checkpointables( - tmpdir, - dict( - state={"values": [0]}, - dataset=self._create_data_loader_iter_to_checkpoint(), - ), - ) - ds = self._create_data_loader() - ds_iter = iter(ds) - ocp.load_checkpointables(tmpdir, dict(state=None, dataset=ds_iter)) - self._assert_restored_data_loader_iter(ds_iter) - - -if __name__ == "__main__": - absltest.main() diff --git a/grain/_src/python/dataset/sources/BUILD b/grain/_src/python/dataset/sources/BUILD index af8cafd3a..85e42f15c 100644 --- a/grain/_src/python/dataset/sources/BUILD +++ b/grain/_src/python/dataset/sources/BUILD @@ -29,10 +29,6 @@ py_library( name = "tfrecord_dataset", srcs = ["tfrecord_dataset.py"], srcs_version = "PY3", - visibility = [ - "//third_party/py/grain:internal", - "//third_party/py/maxtext:__pkg__", - ], deps = ["//grain/_src/python/dataset"], ) diff --git a/grain/oss/build_whl.sh b/grain/oss/build_whl.sh index abe9f18ca..4d3f78e52 100644 --- a/grain/oss/build_whl.sh +++ b/grain/oss/build_whl.sh @@ -21,7 +21,6 @@ main() { # Enable host OS specific configs. For instance, "build:linux" will be used # automatically when building on Linux. write_to_bazelrc "build --enable_platform_specific_config" - write_to_bazelrc "build --verbose_failures" # Bazel 7.0.0 no longer supports dynamic symbol lookup on macOS. To resolve # undefined symbol errors in macOS arm64 builds, explicitly add the necessary # linker flags until dependencies are well defined. See