Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 19 additions & 28 deletions examples/aloha_real/convert_aloha_data_to_lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from typing import Literal

import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
Expand All @@ -34,6 +33,7 @@ class DatasetConfig:
def create_empty_dataset(
repo_id: str,
robot_type: str,
cameras: list[str],
mode: Literal["video", "image"] = "video",
*,
has_velocity: bool = False,
Expand All @@ -56,12 +56,6 @@ def create_empty_dataset(
"left_wrist_rotate",
"left_gripper",
]
cameras = [
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
]

features = {
"observation.state": {
Expand Down Expand Up @@ -109,8 +103,8 @@ def create_empty_dataset(
],
}

if Path(LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if Path(HF_LEROBOT_HOME / repo_id).exists():
shutil.rmtree(HF_LEROBOT_HOME / repo_id)

return LeRobotDataset.create(
repo_id=repo_id,
Expand Down Expand Up @@ -164,6 +158,7 @@ def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, n

def load_raw_episode_data(
ep_path: Path,
cameras: list[str],
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
with h5py.File(ep_path, "r") as ep:
state = torch.from_numpy(ep["/observations/qpos"][:])
Expand All @@ -177,22 +172,15 @@ def load_raw_episode_data(
if "/observations/effort" in ep:
effort = torch.from_numpy(ep["/observations/effort"][:])

imgs_per_cam = load_raw_images_per_camera(
ep,
[
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
],
)
imgs_per_cam = load_raw_images_per_camera(ep, cameras)

return imgs_per_cam, state, action, velocity, effort


def populate_dataset(
dataset: LeRobotDataset,
hdf5_files: list[Path],
cameras: list[str],
task: str,
episodes: list[int] | None = None,
) -> LeRobotDataset:
Expand All @@ -202,13 +190,14 @@ def populate_dataset(
for ep_idx in tqdm.tqdm(episodes):
ep_path = hdf5_files[ep_idx]

imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path, cameras)
num_frames = state.shape[0]

for i in range(num_frames):
frame = {
"observation.state": state[i],
"action": action[i],
"task": task,
}

for camera, img_array in imgs_per_cam.items():
Expand All @@ -221,15 +210,14 @@ def populate_dataset(

dataset.add_frame(frame)

dataset.save_episode(task=task)
dataset.save_episode()

return dataset


def port_aloha(
raw_dir: Path,
repo_id: str,
raw_repo_id: str | None = None,
task: str = "DEBUG",
*,
episodes: list[int] | None = None,
Expand All @@ -238,19 +226,22 @@ def port_aloha(
mode: Literal["video", "image"] = "image",
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if (HF_LEROBOT_HOME / repo_id).exists():
shutil.rmtree(HF_LEROBOT_HOME / repo_id)

if not raw_dir.exists():
if raw_repo_id is None:
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
download_raw(raw_dir, repo_id=raw_repo_id)
raise ValueError(f"Raw directory {raw_dir} does not exist. Please provide a valid path to the raw data.")

hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))

# Get camera names from the first episode
cameras = get_cameras(hdf5_files)
print(f"Detected cameras: {cameras}")

dataset = create_empty_dataset(
repo_id,
robot_type="mobile_aloha" if is_mobile else "aloha",
cameras=cameras,
mode=mode,
has_effort=has_effort(hdf5_files),
has_velocity=has_velocity(hdf5_files),
Expand All @@ -259,10 +250,10 @@ def port_aloha(
dataset = populate_dataset(
dataset,
hdf5_files,
cameras=cameras,
task=task,
episodes=episodes,
)
dataset.consolidate()

if push_to_hub:
dataset.push_to_hub()
Expand Down