diff --git a/src/openpi/models/model.py b/src/openpi/models/model.py index 29618b4945..24d419eb82 100644 --- a/src/openpi/models/model.py +++ b/src/openpi/models/model.py @@ -243,7 +243,7 @@ def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseM def load_pytorch(self, train_config, weight_path: str): logger.info(f"train_config: {train_config}") model = pi0_pytorch.PI0Pytorch(config=train_config.model) - safetensors.torch.load_model(model, weight_path) + safetensors.torch.load_model(model, weight_path, strict=False) return model @abc.abstractmethod