Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 136 additions & 3 deletions docs/tutorials/data_loader_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@
},
"source": [
"## Checkpointing\n",
"We provide `PyGrainCheckpointHandler` to checkpoint the iterator returned by Grain. It is recommended to use it with [Orbax](https://orbax.readthedocs.io), which can checkpoint both input pipeline and model and handles the edge cases for distributed training."
"We provide `GrainCheckpointHandler` to checkpoint the iterator returned by Grain. It is recommended to use it with [Orbax](https://orbax.readthedocs.io), which can checkpoint both input pipeline and model and handles the edge cases for distributed training."
]
},
{
Expand Down Expand Up @@ -549,6 +549,8 @@
}
],
"source": [
"import grain.checkpoint\n",
"\n",
"data_iter = iter(data_loader)\n",
"\n",
"num_steps = 5\n",
Expand Down Expand Up @@ -604,7 +606,7 @@
"\n",
"# Save the checkpoint\n",
"assert mngr.save(\n",
" step=num_steps, args=grain.PyGrainCheckpointSave(data_iter), force=True)\n",
" step=num_steps, args=grain.checkpoint.CheckpointSave(data_iter), force=True)\n",
"# Checkpoint saving in Orbax is asynchronous by default, so we'll wait until\n",
"# finished before examining checkpoint.\n",
"mngr.wait_until_finished()\n",
Expand Down Expand Up @@ -728,7 +730,7 @@
],
"source": [
"# Restore iterator from previously saved checkpoint\n",
"mngr.restore(num_steps, args=grain.PyGrainCheckpointRestore(data_iter))"
"mngr.restore(num_steps, args=grain.checkpoint.CheckpointRestore(data_iter))"
]
},
{
Expand Down Expand Up @@ -768,6 +770,137 @@
" print(i, x[\"file_name\"], x[\"label\"])"
]
},
{
"cell_type": "markdown",
"id": "ccf933fe",
"metadata": {
"id": "composite_checkpoint_intro"
},
"source": [
"### Checkpointing alongside model state\n",
"\n",
"In a real training run you typically want to save and restore **both** your model\n",
"state and the data iterator in a single atomic checkpoint, so that resuming\n",
"training is fully reproducible. Orbax's `ocp.args.Composite` lets you combine\n",
"multiple objects in one `save` / `restore` call.\n",
"\n",
"The most important rule: **keep a direct reference to the `DataLoaderIterator`\n",
"returned by `iter(data_loader)`**. Do not call `iter()` again during the training\n",
"loop — that would reset the iterator and lose checkpoint state."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4e7614c",
"metadata": {
"id": "composite_checkpoint_setup"
},
"outputs": [],
"source": [
"# A minimal stand-in for a real model state.\n",
"# In practice this would be a JAX/Flax TrainState or a Pytree of parameters.\n",
"import numpy as np\n",
"model_weights = np.array([0.1, 0.2, 0.3])\n",
"\n",
"# Build a fresh data loader and hold on to the iterator.\n",
"ckpt_source = tfds.data_source(\"imagenet_a\", split=\"test\")\n",
"ckpt_loader = grain.DataLoader(\n",
" data_source=ckpt_source,\n",
" operations=[ResizeAndCrop()],\n",
" sampler=grain.IndexSampler(\n",
" num_records=20,\n",
" num_epochs=1,\n",
" shard_options=grain.ShardOptions(\n",
" shard_index=0, shard_count=1, drop_remainder=True),\n",
" shuffle=True,\n",
" seed=42),\n",
" worker_count=0)\n",
"\n",
"# Keep this iterator alive for the full training run.\n",
"ckpt_iter = iter(ckpt_loader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c47a99cb",
"metadata": {
"id": "composite_checkpoint_save"
},
"outputs": [],
"source": [
"ckpt_dir = \"/tmp/orbax_composite\"\n",
"!rm -rf {ckpt_dir}\n",
"\n",
"mngr = ocp.CheckpointManager(\n",
" ckpt_dir,\n",
" options=ocp.CheckpointManagerOptions(max_to_keep=3),\n",
")\n",
"\n",
"# --- Simulate a short training loop ---\n",
"num_train_steps = 5\n",
"for step in range(num_train_steps):\n",
" batch = next(ckpt_iter)\n",
" # ... model update would happen here ...\n",
"\n",
"# Save model weights and data iterator state together.\n",
"mngr.save(\n",
" step=num_train_steps,\n",
" args=ocp.args.Composite(\n",
" model_state=ocp.args.StandardSave(model_weights),\n",
" data_state=grain.checkpoint.CheckpointSave(ckpt_iter),\n",
" ),\n",
" force=True,\n",
")\n",
"mngr.wait_until_finished()\n",
"print(f\"Saved checkpoint at step {num_train_steps}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7fd20c2",
"metadata": {
"id": "composite_checkpoint_restore"
},
"outputs": [],
"source": [
"# --- Simulate resuming from a checkpoint ---\n",
"# Create a fresh loader/iterator with the same config.\n",
"resumed_loader = grain.DataLoader(\n",
" data_source=ckpt_source,\n",
" operations=[ResizeAndCrop()],\n",
" sampler=grain.IndexSampler(\n",
" num_records=20,\n",
" num_epochs=1,\n",
" shard_options=grain.ShardOptions(\n",
" shard_index=0, shard_count=1, drop_remainder=True),\n",
" shuffle=True,\n",
" seed=42),\n",
" worker_count=0)\n",
"resumed_iter = iter(resumed_loader)\n",
"resumed_weights = np.zeros(3)\n",
"\n",
"latest_step = mngr.latest_step()\n",
"if latest_step is not None:\n",
" restored = mngr.restore(\n",
" latest_step,\n",
" args=ocp.args.Composite(\n",
" model_state=ocp.args.StandardRestore(resumed_weights),\n",
" data_state=grain.checkpoint.CheckpointRestore(resumed_iter),\n",
" ),\n",
" )\n",
" resumed_weights = restored[\"model_state\"]\n",
" # resumed_iter is restored in-place; the next element will be step 5's batch.\n",
" print(f\"Restored from step {latest_step}, model_weights={resumed_weights}\")\n",
"\n",
"# Continue training — the data iterator picks up exactly where it left off.\n",
"for i in range(num_train_steps, num_train_steps + 3):\n",
" batch = next(resumed_iter)\n",
" print(i, batch[\"file_name\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down
115 changes: 112 additions & 3 deletions docs/tutorials/data_loader_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.4
jupytext_version: 1.19.1
kernelspec:
display_name: Python 3
name: python3
Expand Down Expand Up @@ -222,6 +222,8 @@ executionInfo:
id: irJix4sJkNcf
outputId: 648ffc5d-088e-4747-da1d-f0bfd87e4360
---
import grain.checkpoint

data_iter = iter(data_loader)

num_steps = 5
Expand Down Expand Up @@ -251,7 +253,7 @@ mngr = ocp.CheckpointManager("/tmp/orbax")

# Save the checkpoint
assert mngr.save(
step=num_steps, args=grain.GrainCheckpointSave(data_iter), force=True)
step=num_steps, args=grain.checkpoint.CheckpointSave(data_iter), force=True)
# Checkpoint saving in Orbax is asynchronous by default, so we'll wait until
# finished before examining checkpoint.
mngr.wait_until_finished()
Expand Down Expand Up @@ -312,7 +314,7 @@ id: Js3hheiGnykN
outputId: 31cc1d8d-9a89-4fa8-af7a-5c74650f118e
---
# Restore iterator from previously saved checkpoint
mngr.restore(num_steps, args=grain.GrainCheckpointRestore(data_iter))
mngr.restore(num_steps, args=grain.checkpoint.CheckpointRestore(data_iter))
```

```{code-cell}
Expand All @@ -334,6 +336,113 @@ for i in range(5, 10):
print(i, x["file_name"], x["label"])
```

+++ {"id": "composite_checkpoint_intro"}

### Checkpointing alongside model state

In a real training run you typically want to save and restore **both** your model
state and the data iterator in a single atomic checkpoint, so that resuming
training is fully reproducible. Orbax's `ocp.args.Composite` lets you combine
multiple objects in one `save` / `restore` call.

The most important rule: **keep a single reference to the `DataLoaderIterator`
returned by `iter(data_loader)`**. Do not call `iter()` again during the training
loop — that creates a fresh iterator starting from the beginning, so checkpoints
saved from your original iterator would no longer match the data you consume.

```{code-cell}
:id: composite_checkpoint_setup

# A minimal stand-in for a real model state.
# In practice this would be a JAX/Flax TrainState or a Pytree of parameters.
import numpy as np
model_weights = np.array([0.1, 0.2, 0.3])

# Build a fresh data loader and hold on to the iterator.
ckpt_source = tfds.data_source("imagenet_a", split="test")
ckpt_loader = grain.DataLoader(
data_source=ckpt_source,
operations=[ResizeAndCrop()],
sampler=grain.IndexSampler(
num_records=20,
num_epochs=1,
shard_options=grain.ShardOptions(
shard_index=0, shard_count=1, drop_remainder=True),
shuffle=True,
seed=42),
worker_count=0)

# Keep this iterator alive for the full training run.
ckpt_iter = iter(ckpt_loader)
```

```{code-cell}
:id: composite_checkpoint_save

ckpt_dir = "/tmp/orbax_composite"
!rm -rf {ckpt_dir}

mngr = ocp.CheckpointManager(
ckpt_dir,
options=ocp.CheckpointManagerOptions(max_to_keep=3),
)

# --- Simulate a short training loop ---
num_train_steps = 5
for step in range(num_train_steps):
batch = next(ckpt_iter)
# ... model update would happen here ...

# Save model weights and data iterator state together.
mngr.save(
step=num_train_steps,
args=ocp.args.Composite(
model_state=ocp.args.StandardSave(model_weights),
data_state=grain.checkpoint.CheckpointSave(ckpt_iter),
),
force=True,
)
mngr.wait_until_finished()
print(f"Saved checkpoint at step {num_train_steps}")
```

```{code-cell}
:id: composite_checkpoint_restore

# --- Simulate resuming from a checkpoint ---
# Create a fresh loader/iterator with the same config.
resumed_loader = grain.DataLoader(
data_source=ckpt_source,
operations=[ResizeAndCrop()],
sampler=grain.IndexSampler(
num_records=20,
num_epochs=1,
shard_options=grain.ShardOptions(
shard_index=0, shard_count=1, drop_remainder=True),
shuffle=True,
seed=42),
worker_count=0)
resumed_iter = iter(resumed_loader)
resumed_weights = np.zeros(3)

latest_step = mngr.latest_step()
if latest_step is not None:
restored = mngr.restore(
latest_step,
args=ocp.args.Composite(
model_state=ocp.args.StandardRestore(resumed_weights),
data_state=grain.checkpoint.CheckpointRestore(resumed_iter),
),
)
resumed_weights = restored["model_state"]
# resumed_iter is restored in-place; the next element will be step 5's batch.
print(f"Restored from step {latest_step}, model_weights={resumed_weights}")

# Continue training — the data iterator picks up exactly where it left off.
for i in range(num_train_steps, num_train_steps + 3):
batch = next(resumed_iter)
print(i, batch["file_name"])
```

+++ {"id": "btSRh4EL_Zbo"}

## Extras
Expand Down