Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
"obs, act = next(iter(loader))\n",
"obs, act, _ = next(iter(loader))\n",
"\n",
"# Sample actions from the model.\n",
"loss = model.compute_loss(key, obs, act)\n",
Expand Down
22 changes: 17 additions & 5 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,36 @@ def train_step(
config: _config.TrainConfig,
rng: at.KeyArrayLike,
state: training_utils.TrainState,
batch: tuple[_model.Observation, _model.Actions],
batch: tuple[_model.Observation, _model.Actions, at.Array | None],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
model = nnx.merge(state.model_def, state.params)
model.train()

@at.typecheck
def loss_fn(
model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
model: _model.BaseModel,
rng: at.KeyArrayLike,
observation: _model.Observation,
actions: _model.Actions,
action_pad_mask: at.Array | None,
):
chunked_loss = model.compute_loss(rng, observation, actions, train=True)
return jnp.mean(chunked_loss)
if action_pad_mask is not None and chunked_loss.shape == action_pad_mask.shape:
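# Masked mean over the non-padded action steps. The shape check above restricts this
# to flow matching variants; FAST falls through to the plain mean below.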
action_pad_mask = ~action_pad_mask
loss = jnp.sum(chunked_loss * action_pad_mask) / (jnp.sum(action_pad_mask) + 1e-8)
else:
loss = jnp.mean(chunked_loss)
return loss

train_rng = jax.random.fold_in(rng, state.step)
observation, actions = batch
observation, actions, action_pad_mask = batch

# Filter out frozen params.
diff_state = nnx.DiffState(0, config.trainable_filter)
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(
model, train_rng, observation, actions, action_pad_mask
)

params = state.params.filter(config.trainable_filter)
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
Expand Down
12 changes: 9 additions & 3 deletions scripts/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def train_loop(config: _config.TrainConfig):
sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
sample_batch = next(iter(sample_data_loader))
# Convert observation and actions to torch tensors
observation, actions = sample_batch
observation, actions, _ = sample_batch
sample_batch = observation.to_dict()
sample_batch["actions"] = actions

Expand Down Expand Up @@ -511,7 +511,7 @@ def lr_schedule(step: int):
if use_ddp and hasattr(loader, "set_epoch"):
loader.set_epoch(global_step // len(loader))

for observation, actions in loader:
for observation, actions, action_pad_mask in loader:
# Check if we've reached the target number of steps
if global_step >= config.num_train_steps:
break
Expand All @@ -520,6 +520,8 @@ def lr_schedule(step: int):
observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901
actions = actions.to(torch.float32) # noqa: PLW2901
actions = actions.to(device) # noqa: PLW2901
if action_pad_mask is not None:
action_pad_mask = action_pad_mask.to(device) # noqa: PLW2901

# Update LR
for pg in optim.param_groups:
Expand All @@ -533,7 +535,11 @@ def lr_schedule(step: int):
elif not isinstance(losses, torch.Tensor):
losses = torch.tensor(losses, device=device, dtype=torch.float32)

loss = losses.mean()
if action_pad_mask is not None and action_pad_mask.shape == losses.shape[:-1]:
action_pad_mask = ~action_pad_mask[..., None] # noqa: PLW2901
loss = (losses * action_pad_mask).sum() / (action_pad_mask.sum() * losses.shape[-1] + 1e-8)
else:
loss = losses.mean()

# Backward pass
loss.backward()
Expand Down
2 changes: 2 additions & 0 deletions src/openpi/policies/libero_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def __call__(self, data: dict) -> dict:
# Actions are only available during training.
if "actions" in data:
inputs["actions"] = data["actions"]
if "actions_is_pad" in data:
inputs["actions_is_pad"] = data["actions_is_pad"]

# Pass the prompt (aka language instruction) to the model.
# Keep this for your own dataset (but modify the key if the instruction is not
Expand Down
1 change: 1 addition & 0 deletions src/openpi/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
"observation/state": "state",
"actions": "actions",
"prompt": "prompt",
"actions_is_pad": "actions_is_pad",
}
)
]
Expand Down
2 changes: 1 addition & 1 deletion src/openpi/training/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,4 @@ def data_config(self) -> _config.DataConfig:

def __iter__(self):
for batch in self._data_loader:
yield _model.Observation.from_dict(batch), batch["actions"]
yield _model.Observation.from_dict(batch), batch["actions"], batch.get("actions_is_pad", None)
4 changes: 2 additions & 2 deletions src/openpi/training/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_with_fake_dataset():
for batch in batches:
assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch))

for _, actions in batches:
for _, actions, _ in batches:
assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)


Expand All @@ -80,5 +80,5 @@ def test_with_real_dataset():

assert len(batches) == 2

for _, actions in batches:
for _, actions, _ in batches:
assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)