From 63e90b74f4a9357ffe6a3657df6d4a6597d037de Mon Sep 17 00:00:00 2001
From: Armstrong
Date: Tue, 3 Mar 2026 12:03:04 -0800
Subject: [PATCH 1/4] mask loss on padded actions

---
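Note: the dataset side marks action-chunk steps that were padded past the end
of an episode with a boolean "actions_is_pad" field (True = padded step), as
wired through the repack transform below. This patch threads that mask through
the batch so the flow-matching loss is averaged over real steps only; FAST
variants produce token-level losses of a different shape and fall through to
the plain mean. Illustrative sketch of the masked reduction (standalone toy
example, not part of the applied diff; shapes assumed [batch, action_horizon]):

    import jax.numpy as jnp

    chunked_loss = jnp.array([[1.0, 1.0, 4.0], [2.0, 2.0, 2.0]])
    actions_is_pad = jnp.array([[False, False, True], [False, False, False]])

    keep = ~actions_is_pad  # True where the action step is real
    masked = jnp.sum(chunked_loss * keep) / (jnp.sum(keep) + 1e-8)
    # masked == 1.6: the padded 4.0 is excluded, while jnp.mean would give 2.0.
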
 scripts/train.py                     | 15 ++++++++++-----
 src/openpi/policies/libero_policy.py |  2 ++
 src/openpi/training/config.py        |  1 +
 src/openpi/training/data_loader.py   |  2 +-
 4 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index 5d289413ab..0245defcee 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -138,24 +138,29 @@ def train_step(
     config: _config.TrainConfig,
     rng: at.KeyArrayLike,
     state: training_utils.TrainState,
-    batch: tuple[_model.Observation, _model.Actions],
+    batch: tuple[_model.Observation, _model.Actions, at.Array | None],
 ) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
     model = nnx.merge(state.model_def, state.params)
     model.train()
 
     @at.typecheck
     def loss_fn(
-        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
+        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, action_pad_mask: at.Array | None
     ):
         chunked_loss = model.compute_loss(rng, observation, actions, train=True)
-        return jnp.mean(chunked_loss)
+        if action_pad_mask is not None and chunked_loss.shape == action_pad_mask.shape:
+            # applied only to flow matching variants, not FAST
+            action_pad_mask = ~action_pad_mask
+            return jnp.sum(chunked_loss * action_pad_mask) / (jnp.sum(action_pad_mask) + 1e-8)
+        else:
+            return jnp.mean(chunked_loss)
 
     train_rng = jax.random.fold_in(rng, state.step)
-    observation, actions = batch
+    observation, actions, action_pad_mask = batch
 
     # Filter out frozen params.
     diff_state = nnx.DiffState(0, config.trainable_filter)
-    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
+    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions, action_pad_mask)
 
     params = state.params.filter(config.trainable_filter)
     updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
diff --git a/src/openpi/policies/libero_policy.py b/src/openpi/policies/libero_policy.py
index 10611f61be..a60fadddab 100644
--- a/src/openpi/policies/libero_policy.py
+++ b/src/openpi/policies/libero_policy.py
@@ -73,6 +73,8 @@ def __call__(self, data: dict) -> dict:
         # Actions are only available during training.
         if "actions" in data:
             inputs["actions"] = data["actions"]
+        if "actions_is_pad" in data:
+            inputs["actions_is_pad"] = data["actions_is_pad"]
 
         # Pass the prompt (aka language instruction) to the model.
         # Keep this for your own dataset (but modify the key if the instruction is not
diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py
index 4ca47e1286..5ffcae35c9 100644
--- a/src/openpi/training/config.py
+++ b/src/openpi/training/config.py
@@ -307,6 +307,7 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
                     "observation/state": "state",
                     "actions": "actions",
                     "prompt": "prompt",
+                    "actions_is_pad" : "actions_is_pad"
                 }
             )
         ]
diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py
index e2ee7dd06b..74e337b23c 100644
--- a/src/openpi/training/data_loader.py
+++ b/src/openpi/training/data_loader.py
@@ -537,4 +537,4 @@ def data_config(self) -> _config.DataConfig:
 
     def __iter__(self):
         for batch in self._data_loader:
-            yield _model.Observation.from_dict(batch), batch["actions"]
+            yield _model.Observation.from_dict(batch), batch["actions"], batch["actions_is_pad"] if "actions_is_pad" in batch else None

From 3ae80f6e53177f976c8bf65b01f3f313da45297b Mon Sep 17 00:00:00 2001
From: Armstrong
Date: Tue, 3 Mar 2026 13:09:11 -0800
Subject: [PATCH 2/4] formatting

---
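Note: mechanical cleanup of patch 1, plus updating every remaining consumer
(notebook, PyTorch trainer, tests) to the new three-element batch. The
contract, sketched with a hypothetical stand-in loader (not openpi code):

    def fake_loader():
        # (Observation, Actions, actions_is_pad); the mask may be None
        yield ("obs", [[0.0, 0.1]], None)

    observation, actions, action_pad_mask = next(iter(fake_loader()))
    assert action_pad_mask is None  # consumers must tolerate a missing mask
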
 examples/inference.ipynb                |  2 +-
 scripts/train.py                        | 17 ++++++++++++-----
 scripts/train_pytorch.py                |  4 ++--
 src/openpi/training/config.py           |  2 +-
 src/openpi/training/data_loader.py      |  2 +-
 src/openpi/training/data_loader_test.py |  4 ++--
 6 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/examples/inference.ipynb b/examples/inference.ipynb
index 2f125880f9..4d01992100 100644
--- a/examples/inference.ipynb
+++ b/examples/inference.ipynb
@@ -101,7 +101,7 @@
     "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
     "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
     "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
-    "obs, act = next(iter(loader))\n",
+    "obs, act, _ = next(iter(loader))\n",
     "\n",
     "# Sample actions from the model.\n",
     "loss = model.compute_loss(key, obs, act)\n",
diff --git a/scripts/train.py b/scripts/train.py
index 0245defcee..42a3abddc2 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -145,22 +145,29 @@ def train_step(
 
     @at.typecheck
     def loss_fn(
-        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, action_pad_mask: at.Array | None
+        model: _model.BaseModel,
+        rng: at.KeyArrayLike,
+        observation: _model.Observation,
+        actions: _model.Actions,
+        action_pad_mask: at.Array | None,
     ):
         chunked_loss = model.compute_loss(rng, observation, actions, train=True)
-        if action_pad_mask is not None and chunked_loss.shape == action_pad_mask.shape:
+        if action_pad_mask is not None and chunked_loss.shape == action_pad_mask.shape:
             # applied only to flow matching variants, not FAST
             action_pad_mask = ~action_pad_mask
-            return jnp.sum(chunked_loss * action_pad_mask) / (jnp.sum(action_pad_mask) + 1e-8)
+            loss = jnp.sum(chunked_loss * action_pad_mask) / (jnp.sum(action_pad_mask) + 1e-8)
         else:
-            return jnp.mean(chunked_loss)
+            loss = jnp.mean(chunked_loss)
+        return loss
 
     train_rng = jax.random.fold_in(rng, state.step)
     observation, actions, action_pad_mask = batch
 
     # Filter out frozen params.
     diff_state = nnx.DiffState(0, config.trainable_filter)
-    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions, action_pad_mask)
+    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(
+        model, train_rng, observation, actions, action_pad_mask
+    )
 
     params = state.params.filter(config.trainable_filter)
     updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py
index c7ddd2b595..f225275ff7 100644
--- a/scripts/train_pytorch.py
+++ b/scripts/train_pytorch.py
@@ -364,7 +364,7 @@ def train_loop(config: _config.TrainConfig):
     sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
     sample_batch = next(iter(sample_data_loader))
     # Convert observation and actions to torch tensors
-    observation, actions = sample_batch
+    observation, actions, _ = sample_batch
     sample_batch = observation.to_dict()
     sample_batch["actions"] = actions
 
@@ -511,7 +511,7 @@ def lr_schedule(step: int):
         if use_ddp and hasattr(loader, "set_epoch"):
             loader.set_epoch(global_step // len(loader))
 
-        for observation, actions in loader:
+        for observation, actions, _ in loader:
             # Check if we've reached the target number of steps
             if global_step >= config.num_train_steps:
                 break
diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py
index 5ffcae35c9..5ef5629de1 100644
--- a/src/openpi/training/config.py
+++ b/src/openpi/training/config.py
@@ -307,7 +307,7 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
                     "observation/state": "state",
                     "actions": "actions",
                     "prompt": "prompt",
-                    "actions_is_pad" : "actions_is_pad"
+                    "actions_is_pad": "actions_is_pad",
                 }
             )
         ]
diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py
index 74e337b23c..c29fb831ed 100644
--- a/src/openpi/training/data_loader.py
+++ b/src/openpi/training/data_loader.py
@@ -537,4 +537,4 @@ def data_config(self) -> _config.DataConfig:
 
     def __iter__(self):
         for batch in self._data_loader:
-            yield _model.Observation.from_dict(batch), batch["actions"], batch["actions_is_pad"] if "actions_is_pad" in batch else None
+            yield _model.Observation.from_dict(batch), batch["actions"], batch.get("actions_is_pad", None)
diff --git a/src/openpi/training/data_loader_test.py b/src/openpi/training/data_loader_test.py
index d15a73529e..b4ea5909e1 100644
--- a/src/openpi/training/data_loader_test.py
+++ b/src/openpi/training/data_loader_test.py
@@ -58,7 +58,7 @@ def test_with_fake_dataset():
     for batch in batches:
         assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch))
 
-    for _, actions in batches:
+    for _, actions, _ in batches:
         assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)
 
 
@@ -80,5 +80,5 @@ def test_with_real_dataset():
 
     assert len(batches) == 2
 
-    for _, actions in batches:
+    for _, actions, _ in batches:
         assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)

From e29848d9def4ec44a204b53f6f096b23f8ff8d9f Mon Sep 17 00:00:00 2001
From: Armstrong
Date: Tue, 3 Mar 2026 13:57:27 -0800
Subject: [PATCH 3/4] padded action masking for pytorch

---
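Note: same masking for the PyTorch trainer. Here the per-element losses keep
the action dimension, so the boolean mask is unsqueezed to broadcast and the
denominator counts steps times action_dim. Standalone toy example (not part of
the applied diff; losses assumed [batch, horizon, action_dim], mask
[batch, horizon] with True = padded step):

    import torch

    losses = torch.ones(2, 3, 4)
    losses[0, 2] = 9.0  # garbage loss on a padded step
    actions_is_pad = torch.tensor([[False, False, True], [False, False, False]])

    keep = ~actions_is_pad[..., None]  # [batch, horizon, 1], broadcasts over dims
    loss = (losses * keep).sum() / (keep.sum() * losses.shape[-1] + 1e-8)
    # keep.sum() counts each step once (trailing dim is 1), hence the extra
    # action_dim factor; loss == 1.0 here, while losses.mean() would be ~2.33.
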
 scripts/train_pytorch.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py
index f225275ff7..16c0a100b8 100644
--- a/scripts/train_pytorch.py
+++ b/scripts/train_pytorch.py
@@ -511,7 +511,7 @@ def lr_schedule(step: int):
         if use_ddp and hasattr(loader, "set_epoch"):
             loader.set_epoch(global_step // len(loader))
 
-        for observation, actions, _ in loader:
+        for observation, actions, action_pad_mask in loader:
             # Check if we've reached the target number of steps
             if global_step >= config.num_train_steps:
                 break
@@ -520,6 +520,7 @@ def lr_schedule(step: int):
             observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
             actions = actions.to(torch.float32)  # noqa: PLW2901
             actions = actions.to(device)  # noqa: PLW2901
+            action_pad_mask = action_pad_mask.to(device)  # noqa: PLW2901
 
             # Update LR
             for pg in optim.param_groups:
@@ -533,7 +534,11 @@ def lr_schedule(step: int):
             elif not isinstance(losses, torch.Tensor):
                 losses = torch.tensor(losses, device=device, dtype=torch.float32)
 
-            loss = losses.mean()
+            if action_pad_mask is not None and action_pad_mask.shape == losses.shape[:-1]:
+                action_pad_mask = ~action_pad_mask[..., None]  # noqa: PLW2901
+                loss = (losses * action_pad_mask).sum() / (action_pad_mask.sum() * losses.shape[-1] + 1e-8)
+            else:
+                loss = losses.mean()
 
             # Backward pass
             loss.backward()

From 7cc30a8780ebf2dc4fcf07069b03f954754738ad Mon Sep 17 00:00:00 2001
From: Armstrong
Date: Tue, 3 Mar 2026 14:06:57 -0800
Subject: [PATCH 4/4] handle None

---
 scripts/train_pytorch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py
index 16c0a100b8..eca7000cd7 100644
--- a/scripts/train_pytorch.py
+++ b/scripts/train_pytorch.py
@@ -520,7 +520,8 @@ def lr_schedule(step: int):
             observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
             actions = actions.to(torch.float32)  # noqa: PLW2901
             actions = actions.to(device)  # noqa: PLW2901
-            action_pad_mask = action_pad_mask.to(device)  # noqa: PLW2901
+            if action_pad_mask is not None:
+                action_pad_mask = action_pad_mask.to(device)  # noqa: PLW2901
 
             # Update LR
             for pg in optim.param_groups:
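
End-of-series note: with patch 4, a missing "actions_is_pad" key flows through
as None from the loader to the loss, so the mask is optional end to end. A
quick regression check one could run (hypothetical test, not part of the
series) to confirm the reduction ignores padded steps entirely:

    import torch

    def masked_loss(losses, actions_is_pad):
        # mirrors the reduction added in patches 3 and 4
        if actions_is_pad is not None and actions_is_pad.shape == losses.shape[:-1]:
            keep = ~actions_is_pad[..., None]
            return (losses * keep).sum() / (keep.sum() * losses.shape[-1] + 1e-8)
        return losses.mean()

    losses = torch.rand(2, 5, 7)
    pad = torch.zeros(2, 5, dtype=torch.bool)
    pad[:, -2:] = True  # last two steps of every chunk are padding
    corrupted = losses.clone()
    corrupted[:, -2:] = 1e6  # garbage only where the mask says "padded"
    assert torch.allclose(masked_loss(losses, pad), masked_loss(corrupted, pad))
    assert masked_loss(losses, None) == losses.mean()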