From 2d4fb34f413242d893dadcbdd30536452ed37c7f Mon Sep 17 00:00:00 2001 From: Vasilybib1 Date: Thu, 26 Feb 2026 20:39:13 -0800 Subject: [PATCH] fixed broken checkpoint API names and add composite model+data example (#938) --- docs/tutorials/data_loader_tutorial.ipynb | 139 +++++++++++++++++++++- docs/tutorials/data_loader_tutorial.md | 115 +++++++++++++++++- 2 files changed, 248 insertions(+), 6 deletions(-) diff --git a/docs/tutorials/data_loader_tutorial.ipynb b/docs/tutorials/data_loader_tutorial.ipynb index 8be9733fc..0fd287382 100644 --- a/docs/tutorials/data_loader_tutorial.ipynb +++ b/docs/tutorials/data_loader_tutorial.ipynb @@ -505,7 +505,7 @@ }, "source": [ "## Checkpointing\n", - "We provide `PyGrainCheckpointHandler` to checkpoint the iterator returned by Grain. It is recommended to use it with [Orbax](https://orbax.readthedocs.io), which can checkpoint both input pipeline and model and handles the edge cases for distributed training." + "We provide `grain.checkpoint.CheckpointHandler` to checkpoint the iterator returned by Grain. It is recommended to use it with [Orbax](https://orbax.readthedocs.io), which can checkpoint both input pipeline and model and handles the edge cases for distributed training." 
] }, { @@ -549,6 +549,8 @@ } ], "source": [ + "import grain.checkpoint\n", + "\n", "data_iter = iter(data_loader)\n", "\n", "num_steps = 5\n", @@ -604,7 +606,7 @@ "\n", "# Save the checkpoint\n", "assert mngr.save(\n", - " step=num_steps, args=grain.PyGrainCheckpointSave(data_iter), force=True)\n", + " step=num_steps, args=grain.checkpoint.CheckpointSave(data_iter), force=True)\n", "# Checkpoint saving in Orbax is asynchronous by default, so we'll wait until\n", "# finished before examining checkpoint.\n", "mngr.wait_until_finished()\n", @@ -728,7 +730,7 @@ ], "source": [ "# Restore iterator from previously saved checkpoint\n", - "mngr.restore(num_steps, args=grain.PyGrainCheckpointRestore(data_iter))" + "mngr.restore(num_steps, args=grain.checkpoint.CheckpointRestore(data_iter))" ] }, { @@ -768,6 +770,137 @@ " print(i, x[\"file_name\"], x[\"label\"])" ] }, + { + "cell_type": "markdown", + "id": "ccf933fe", + "metadata": { + "id": "composite_checkpoint_intro" + }, + "source": [ + "### Checkpointing alongside model state\n", + "\n", + "In a real training run you typically want to save and restore **both** your model\n", + "state and the data iterator in a single atomic checkpoint, so that resuming\n", + "training is fully reproducible. Orbax's `ocp.args.Composite` lets you combine\n", + "multiple objects in one `save` / `restore` call.\n", + "\n", + "The most important rule: **keep a direct reference to the `DataLoaderIterator`\n", + "returned by `iter(data_loader)`**. Do not call `iter()` again during the training\n", + "loop — that would reset the iterator and lose checkpoint state." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4e7614c", + "metadata": { + "id": "composite_checkpoint_setup" + }, + "outputs": [], + "source": [ + "# A minimal stand-in for a real model state.\n", + "# In practice this would be a JAX/Flax TrainState or a Pytree of parameters.\n", + "import numpy as np\n", + "model_weights = np.array([0.1, 0.2, 0.3])\n", + "\n", + "# Build a fresh data loader and hold on to the iterator.\n", + "ckpt_source = tfds.data_source(\"imagenet_a\", split=\"test\")\n", + "ckpt_loader = grain.DataLoader(\n", + " data_source=ckpt_source,\n", + " operations=[ResizeAndCrop()],\n", + " sampler=grain.IndexSampler(\n", + " num_records=20,\n", + " num_epochs=1,\n", + " shard_options=grain.ShardOptions(\n", + " shard_index=0, shard_count=1, drop_remainder=True),\n", + " shuffle=True,\n", + " seed=42),\n", + " worker_count=0)\n", + "\n", + "# Keep this iterator alive for the full training run.\n", + "ckpt_iter = iter(ckpt_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c47a99cb", + "metadata": { + "id": "composite_checkpoint_save" + }, + "outputs": [], + "source": [ + "ckpt_dir = \"/tmp/orbax_composite\"\n", + "!rm -rf {ckpt_dir}\n", + "\n", + "mngr = ocp.CheckpointManager(\n", + " ckpt_dir,\n", + " options=ocp.CheckpointManagerOptions(max_to_keep=3),\n", + ")\n", + "\n", + "# --- Simulate a short training loop ---\n", + "num_train_steps = 5\n", + "for step in range(num_train_steps):\n", + " batch = next(ckpt_iter)\n", + " # ... 
model update would happen here ...\n", + "\n", + "# Save model weights and data iterator state together.\n", + "mngr.save(\n", + " step=num_train_steps,\n", + " args=ocp.args.Composite(\n", + " model_state=ocp.args.StandardSave(model_weights),\n", + " data_state=grain.checkpoint.CheckpointSave(ckpt_iter),\n", + " ),\n", + " force=True,\n", + ")\n", + "mngr.wait_until_finished()\n", + "print(f\"Saved checkpoint at step {num_train_steps}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7fd20c2", + "metadata": { + "id": "composite_checkpoint_restore" + }, + "outputs": [], + "source": [ + "# --- Simulate resuming from a checkpoint ---\n", + "# Create a fresh loader/iterator with the same config.\n", + "resumed_loader = grain.DataLoader(\n", + " data_source=ckpt_source,\n", + " operations=[ResizeAndCrop()],\n", + " sampler=grain.IndexSampler(\n", + " num_records=20,\n", + " num_epochs=1,\n", + " shard_options=grain.ShardOptions(\n", + " shard_index=0, shard_count=1, drop_remainder=True),\n", + " shuffle=True,\n", + " seed=42),\n", + " worker_count=0)\n", + "resumed_iter = iter(resumed_loader)\n", + "resumed_weights = np.zeros(3)\n", + "\n", + "latest_step = mngr.latest_step()\n", + "if latest_step is not None:\n", + " restored = mngr.restore(\n", + " latest_step,\n", + " args=ocp.args.Composite(\n", + " model_state=ocp.args.StandardRestore(resumed_weights),\n", + " data_state=grain.checkpoint.CheckpointRestore(resumed_iter),\n", + " ),\n", + " )\n", + " resumed_weights = restored[\"model_state\"]\n", + " # resumed_iter is restored in-place; the next element will be step 5's batch.\n", + " print(f\"Restored from step {latest_step}, model_weights={resumed_weights}\")\n", + "\n", + "# Continue training — the data iterator picks up exactly where it left off.\n", + "for i in range(num_train_steps, num_train_steps + 3):\n", + " batch = next(resumed_iter)\n", + " print(i, batch[\"file_name\"])" + ] + }, { "cell_type": "markdown", "metadata": { diff 
--git a/docs/tutorials/data_loader_tutorial.md b/docs/tutorials/data_loader_tutorial.md index a2d3a9d67..be95098e1 100644 --- a/docs/tutorials/data_loader_tutorial.md +++ b/docs/tutorials/data_loader_tutorial.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.4 + jupytext_version: 1.19.1 kernelspec: display_name: Python 3 name: python3 @@ -222,6 +222,8 @@ executionInfo: id: irJix4sJkNcf outputId: 648ffc5d-088e-4747-da1d-f0bfd87e4360 --- +import grain.checkpoint + data_iter = iter(data_loader) num_steps = 5 @@ -251,7 +253,7 @@ mngr = ocp.CheckpointManager("/tmp/orbax") # Save the checkpoint assert mngr.save( - step=num_steps, args=grain.GrainCheckpointSave(data_iter), force=True) + step=num_steps, args=grain.checkpoint.CheckpointSave(data_iter), force=True) # Checkpoint saving in Orbax is asynchronous by default, so we'll wait until # finished before examining checkpoint. mngr.wait_until_finished() @@ -312,7 +314,7 @@ id: Js3hheiGnykN outputId: 31cc1d8d-9a89-4fa8-af7a-5c74650f118e --- # Restore iterator from previously saved checkpoint -mngr.restore(num_steps, args=grain.GrainCheckpointRestore(data_iter)) +mngr.restore(num_steps, args=grain.checkpoint.CheckpointRestore(data_iter)) ``` ```{code-cell} @@ -334,6 +336,113 @@ for i in range(5, 10): print(i, x["file_name"], x["label"]) ``` ++++ {"id": "composite_checkpoint_intro"} + +### Checkpointing alongside model state + +In a real training run you typically want to save and restore **both** your model +state and the data iterator in a single atomic checkpoint, so that resuming +training is fully reproducible. Orbax's `ocp.args.Composite` lets you combine +multiple objects in one `save` / `restore` call. + +The most important rule: **keep a direct reference to the `DataLoaderIterator` +returned by `iter(data_loader)`**. Do not call `iter()` again during the training +loop — that would reset the iterator and lose checkpoint state. 
+ +```{code-cell} +:id: composite_checkpoint_setup + +# A minimal stand-in for a real model state. +# In practice this would be a JAX/Flax TrainState or a Pytree of parameters. +import numpy as np +model_weights = np.array([0.1, 0.2, 0.3]) + +# Build a fresh data loader and hold on to the iterator. +ckpt_source = tfds.data_source("imagenet_a", split="test") +ckpt_loader = grain.DataLoader( + data_source=ckpt_source, + operations=[ResizeAndCrop()], + sampler=grain.IndexSampler( + num_records=20, + num_epochs=1, + shard_options=grain.ShardOptions( + shard_index=0, shard_count=1, drop_remainder=True), + shuffle=True, + seed=42), + worker_count=0) + +# Keep this iterator alive for the full training run. +ckpt_iter = iter(ckpt_loader) +``` + +```{code-cell} +:id: composite_checkpoint_save + +ckpt_dir = "/tmp/orbax_composite" +!rm -rf {ckpt_dir} + +mngr = ocp.CheckpointManager( + ckpt_dir, + options=ocp.CheckpointManagerOptions(max_to_keep=3), +) + +# --- Simulate a short training loop --- +num_train_steps = 5 +for step in range(num_train_steps): + batch = next(ckpt_iter) + # ... model update would happen here ... + +# Save model weights and data iterator state together. +mngr.save( + step=num_train_steps, + args=ocp.args.Composite( + model_state=ocp.args.StandardSave(model_weights), + data_state=grain.checkpoint.CheckpointSave(ckpt_iter), + ), + force=True, +) +mngr.wait_until_finished() +print(f"Saved checkpoint at step {num_train_steps}") +``` + +```{code-cell} +:id: composite_checkpoint_restore + +# --- Simulate resuming from a checkpoint --- +# Create a fresh loader/iterator with the same config. 
+resumed_loader = grain.DataLoader( + data_source=ckpt_source, + operations=[ResizeAndCrop()], + sampler=grain.IndexSampler( + num_records=20, + num_epochs=1, + shard_options=grain.ShardOptions( + shard_index=0, shard_count=1, drop_remainder=True), + shuffle=True, + seed=42), + worker_count=0) +resumed_iter = iter(resumed_loader) +resumed_weights = np.zeros(3) + +latest_step = mngr.latest_step() +if latest_step is not None: + restored = mngr.restore( + latest_step, + args=ocp.args.Composite( + model_state=ocp.args.StandardRestore(resumed_weights), + data_state=grain.checkpoint.CheckpointRestore(resumed_iter), + ), + ) + resumed_weights = restored["model_state"] + # resumed_iter is restored in-place; the next element will be step 5's batch. + print(f"Restored from step {latest_step}, model_weights={resumed_weights}") + +# Continue training — the data iterator picks up exactly where it left off. +for i in range(num_train_steps, num_train_steps + 3): + batch = next(resumed_iter) + print(i, batch["file_name"]) +``` + +++ {"id": "btSRh4EL_Zbo"} ## Extras