From 5f59fb3023a963bbb38ca1a9814ea7e63bda1f45 Mon Sep 17 00:00:00 2001 From: MSTLE <2437721575@qq.com> Date: Tue, 14 Oct 2025 13:50:59 +0800 Subject: [PATCH] Update convert_aloha_data_to_lerobot.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 适配 lerobot新版本3.0+ 脚本修改总结: 将导入从 lerobot.common.datasets.lerobot_dataset.LEROBOT_HOME 改为 lerobot.common.constants.HF_LEROBOT_HOME 移除了不存在的 download_raw 导入 删除了 port_aloha() 函数的 raw_repo_id 参数 将所有 LEROBOT_HOME 引用替换为 HF_LEROBOT_HOME(共 4 处) 在 create_empty_dataset() 中添加 cameras 参数,移除硬编码的相机列表 在 load_raw_episode_data() 中添加 cameras 参数,实现动态相机支持 在 populate_dataset() 中添加 cameras 参数并传递给子函数 在 port_aloha() 中添加 cameras = get_cameras(hdf5_files) 自动检测相机 在每个 frame 字典中添加 "task": task 字段 将 dataset.save_episode(task=task) 改为 dataset.save_episode() 移除了 dataset.consolidate() 调用 核心变化:适配 LeRobot 新版本 API,支持自动检测相机配置 --- .../convert_aloha_data_to_lerobot.py | 47 ++++++++----------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/examples/aloha_real/convert_aloha_data_to_lerobot.py b/examples/aloha_real/convert_aloha_data_to_lerobot.py index a3a8ddcb24..b0f7544253 100644 --- a/examples/aloha_real/convert_aloha_data_to_lerobot.py +++ b/examples/aloha_real/convert_aloha_data_to_lerobot.py @@ -10,9 +10,8 @@ from typing import Literal import h5py -from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME +from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw import numpy as np import torch import tqdm @@ -34,6 +33,7 @@ class DatasetConfig: def create_empty_dataset( repo_id: str, robot_type: str, + cameras: list[str], mode: Literal["video", "image"] = "video", *, has_velocity: bool = False, @@ -56,12 +56,6 @@ def create_empty_dataset( "left_wrist_rotate", "left_gripper", ] - cameras = [ - "cam_high", - "cam_low", - "cam_left_wrist", - "cam_right_wrist", - ] features = { "observation.state": { @@ -109,8 +103,8 @@ def create_empty_dataset( ], } - if Path(LEROBOT_HOME / repo_id).exists(): - shutil.rmtree(LEROBOT_HOME / repo_id) + if Path(HF_LEROBOT_HOME / repo_id).exists(): + shutil.rmtree(HF_LEROBOT_HOME / repo_id) return LeRobotDataset.create( repo_id=repo_id, @@ -164,6 +158,7 @@ def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, n def load_raw_episode_data( ep_path: Path, + cameras: list[str], ) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: with h5py.File(ep_path, "r") as ep: state = torch.from_numpy(ep["/observations/qpos"][:]) @@ -177,15 +172,7 @@ def load_raw_episode_data( if "/observations/effort" in ep: effort = torch.from_numpy(ep["/observations/effort"][:]) - imgs_per_cam = load_raw_images_per_camera( - ep, - [ - "cam_high", - "cam_low", - "cam_left_wrist", - "cam_right_wrist", - ], - ) + imgs_per_cam = load_raw_images_per_camera(ep, cameras) return imgs_per_cam, state, action, velocity, effort @@ -193,6 +180,7 @@ def load_raw_episode_data( def populate_dataset( dataset: LeRobotDataset, hdf5_files: list[Path], + cameras: list[str], task: str, episodes: list[int] | None = None, ) -> LeRobotDataset: @@ -202,13 +190,14 @@ def populate_dataset( for ep_idx in tqdm.tqdm(episodes): ep_path = hdf5_files[ep_idx] - imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path) + imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path, cameras) num_frames = state.shape[0] for i in range(num_frames): frame = { "observation.state": state[i], "action": action[i], + "task": task, } for camera, img_array in imgs_per_cam.items(): @@ -221,7 +210,7 @@ def populate_dataset( dataset.add_frame(frame) - dataset.save_episode(task=task) + dataset.save_episode() return dataset @@ -229,7 +218,6 @@ def populate_dataset( def port_aloha( raw_dir: Path, repo_id: str, - raw_repo_id: str | None = None, task: str = "DEBUG", *, episodes: list[int] | None = None, @@ -238,19 +226,22 @@ def port_aloha( mode: Literal["video", "image"] = "image", dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG, ): - if (LEROBOT_HOME / repo_id).exists(): - shutil.rmtree(LEROBOT_HOME / repo_id) + if (HF_LEROBOT_HOME / repo_id).exists(): + shutil.rmtree(HF_LEROBOT_HOME / repo_id) if not raw_dir.exists(): - if raw_repo_id is None: - raise ValueError("raw_repo_id must be provided if raw_dir does not exist") - download_raw(raw_dir, repo_id=raw_repo_id) + raise ValueError(f"Raw directory {raw_dir} does not exist. Please provide a valid path to the raw data.") hdf5_files = sorted(raw_dir.glob("episode_*.hdf5")) + + # Get camera names from the first episode + cameras = get_cameras(hdf5_files) + print(f"Detected cameras: {cameras}") dataset = create_empty_dataset( repo_id, robot_type="mobile_aloha" if is_mobile else "aloha", + cameras=cameras, mode=mode, has_effort=has_effort(hdf5_files), has_velocity=has_velocity(hdf5_files), @@ -259,10 +250,10 @@ def port_aloha( dataset = populate_dataset( dataset, hdf5_files, + cameras=cameras, task=task, episodes=episodes, ) - dataset.consolidate() if push_to_hub: dataset.push_to_hub()