From f7e5a24a2d9d6524be8a465a9140e08ec642a6ef Mon Sep 17 00:00:00 2001 From: Abhay D Date: Wed, 21 Jan 2026 15:04:49 -0800 Subject: [PATCH] Fix hardcoded action dim in pi0 pytorch model --- src/openpi/models_pytorch/pi0_pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 25f0580ba8..92ff95028a 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -97,14 +97,14 @@ def __init__(self, config): precision=config.dtype, ) - self.action_in_proj = nn.Linear(32, action_expert_config.width) - self.action_out_proj = nn.Linear(action_expert_config.width, 32) + self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim) if self.pi05: self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) else: - self.state_proj = nn.Linear(32, action_expert_config.width) + self.state_proj = nn.Linear(config.action_dim, action_expert_config.width) self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)