Skip to content
Open

V0 #823

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ assets/
checkpoints/
data/
wandb/
lerobot/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6 changes: 3 additions & 3 deletions examples/aloha_real/convert_aloha_data_to_lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from typing import Literal

import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
Expand Down
3 changes: 2 additions & 1 deletion examples/aloha_sim/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def reset(self) -> None:
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = False
self._episode_reward = 0.0
# breakpoint()

@override
def is_episode_complete(self) -> bool:
Expand All @@ -46,7 +47,7 @@ def apply_action(self, action: dict) -> None:

def _convert_observation(self, gym_obs: dict) -> dict:
img = gym_obs["pixels"]["top"]
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 640, 640))
# Convert axis order from [H, W, C] --> [C, H, W]
img = np.transpose(img, (2, 0, 1))

Expand Down
2 changes: 1 addition & 1 deletion examples/aloha_sim/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class Args:
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")

task: str = "gym_aloha/AlohaTransferCube-v0"
task: str = "gym_aloha/AlohaInsertion-v0"
seed: int = 0

action_horizon: int = 10
Expand Down
4 changes: 2 additions & 2 deletions examples/droid/convert_droid_data_to_lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
from PIL import Image
from tqdm import tqdm
Expand Down
4 changes: 2 additions & 2 deletions examples/libero/convert_libero_data_to_lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import shutil

from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds
import tyro

Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"imageio>=2.36.1",
"jax[cuda12]==0.5.3",
"jaxtyping==0.2.36",
"lerobot",
# "lerobot",
"ml_collections==1.0.0",
"numpy>=1.22.4,<2.0.0",
"numpydantic>=1.6.6",
Expand All @@ -26,7 +26,7 @@ dependencies = [
"orbax-checkpoint==0.11.13",
"pillow>=11.0.0",
"sentencepiece>=0.2.0",
"torch==2.7.1",
"torch==2.8.0+cu128",
"tqdm-loggable>=0.2",
"typing-extensions>=4.12.2",
"tyro>=0.9.5",
Expand Down Expand Up @@ -61,10 +61,11 @@ rlds = [

[tool.uv]
override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"]
extra-index-url = ["https://download.pytorch.org/whl/cu128"]

[tool.uv.sources]
openpi-client = { workspace = true }
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
# lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }

[tool.uv.workspace]
Expand Down
8 changes: 8 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash

# Generate timestamp in YYYYMMDDHHMMSS format
TIMESTAMP=$(date +"%Y%m%d%H%M%S")

# Training script for pi05 aloha simulation with automatic timestamp
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
python scripts/train.py pi05_aloha_sim_transfer_cube_scripted --exp-name=${TIMESTAMP}
3 changes: 2 additions & 1 deletion src/openpi/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def initialize_checkpoint_dir(
max_to_keep=1,
keep_period=keep_period,
create=False,
async_options=ocp.AsyncOptions(timeout_secs=7200),
enable_async_checkpointing=False,
# async_options=ocp.AsyncOptions(timeout_secs=7200),
),
)

Expand Down
58 changes: 57 additions & 1 deletion src/openpi/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class TrainConfig:
# How often (in steps) to save checkpoints.
save_interval: int = 1000
# If set, any existing checkpoints matching step % keep_period == 0 will not be deleted.
keep_period: int | None = 5000
keep_period: int | None = 10000

# If true, will overwrite the checkpoint directory if it already exists.
overwrite: bool = False
Expand Down Expand Up @@ -922,6 +922,62 @@ def __post_init__(self) -> None:
num_train_steps=20_000,
),
#
# ALOHA Sim pi05 configs. This config is used to demonstrate how to train on a simple simulated environment.
#
TrainConfig(
name="pi05_aloha_sim_insertion_human",
model=pi0_config.Pi0Config(pi05=True),
data=LeRobotAlohaDataConfig(
repo_id="lerobot/aloha_sim_insertion_human",
default_prompt="Insert the peg into the socket.",
use_delta_joint_actions=False,
),
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=1_000,
peak_lr=1e-5,
decay_steps=50_000,
decay_lr=1e-6,
),

weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
num_train_steps=50_000,

batch_size=32,
num_workers=4,
),
TrainConfig(
name="pi05_aloha_sim_transfer_cube_human",
model=pi0_config.Pi0Config(pi05=True),
data=LeRobotAlohaDataConfig(
repo_id="lerobot/aloha_sim_transfer_cube_human",
default_prompt="Pick up the cube with the right arm and transfer it to the left arm.",
use_delta_joint_actions=False,
),
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
num_train_steps=50_000,

batch_size=12,
num_workers=4,

# video_backend="pyav",
),
TrainConfig(
name="pi05_aloha_sim_transfer_cube_scripted",
model=pi0_config.Pi0Config(pi05=True),
data=LeRobotAlohaDataConfig(
repo_id="lerobot/aloha_sim_transfer_cube_scripted",
default_prompt="Pick up the cube with the right arm and transfer it to the left arm.",
use_delta_joint_actions=False,
),
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
num_train_steps=50_000,

batch_size=12,
num_workers=4,

# video_backend="pyav",
),
#
# Debugging configs.
#
TrainConfig(
Expand Down
3 changes: 2 additions & 1 deletion src/openpi/training/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import jax
import jax.numpy as jnp
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
import lerobot.datasets.lerobot_dataset as lerobot_dataset
import numpy as np
import torch

Expand Down Expand Up @@ -143,6 +143,7 @@ def create_torch_dataset(
delta_timestamps={
key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
},
video_backend='pyav',
)

if data_config.prompt_from_task:
Expand Down
Loading