diff --git a/multiworld/core/image_env.py b/multiworld/core/image_env.py index 3d4b87ac..744165fb 100644 --- a/multiworld/core/image_env.py +++ b/multiworld/core/image_env.py @@ -185,7 +185,7 @@ def _get_flat_img(self): if self.normalize: image_obs = image_obs / 255.0 if self.transpose: - image_obs = image_obs.transpose() + image_obs = image_obs.transpose((2, 0, 1)) assert image_obs.shape[0] == self.channels return image_obs.flatten() diff --git a/multiworld/envs/mujoco/mujoco_env.py b/multiworld/envs/mujoco/mujoco_env.py index f9c4e64b..5e5db3fb 100644 --- a/multiworld/envs/mujoco/mujoco_env.py +++ b/multiworld/envs/mujoco/mujoco_env.py @@ -145,7 +145,7 @@ def get_image(self, width=84, height=84, camera_name=None): width=width, height=height, camera_name=camera_name, - ) + )[::-1,:,:] def initialize_camera(self, init_fctn): sim = self.sim