diff --git a/src/openpi/models/pi0_config.py b/src/openpi/models/pi0_config.py index b0f6b662ac..9e51c569bf 100644 --- a/src/openpi/models/pi0_config.py +++ b/src/openpi/models/pi0_config.py @@ -23,6 +23,7 @@ class Pi0Config(_model.BaseModelConfig): # Set the model specific defaults. action_dim: int = 32 + actual_action_dim: int = 32 # the actual action dim in your dataset action_horizon: int = 50 max_token_len: int = None # type: ignore # Pi05 has two differences from Pi0: diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 25f0580ba8..aa808395f2 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -86,6 +86,11 @@ def __init__(self, config): super().__init__() self.config = config self.pi05 = config.pi05 + + # introduce action mask for later loss compute + action_mask = torch.zeros(1, 1, self.config.action_dim, dtype=torch.bool) + action_mask[:, :, :self.config.actual_action_dim] = True + self.register_buffer("action_mask", action_mask) paligemma_config = _gemma.get_config(config.paligemma_variant) action_expert_config = _gemma.get_config(config.action_expert_variant) @@ -370,7 +375,7 @@ def action_out_proj_func(suffix_out): v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) - return F.mse_loss(u_t, v_t, reduction="none") + return F.mse_loss(u_t, v_t, reduction="none") * self.action_mask # use mask to mask out losses from padding actions @torch.no_grad() def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: