Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion scripts/compute_norm_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
to the config assets directory.
"""

import dataclasses

import numpy as np
import tqdm
import tyro
Expand Down Expand Up @@ -86,8 +88,10 @@ def create_rlds_dataloader(
return data_loader, num_batches


def main(config_name: str, max_frames: int | None = None):
def main(config_name: str, repo_id: str | None = None, max_frames: int | None = None):
config = _config.get_config(config_name)
if repo_id is not None:
config = dataclasses.replace(config, data=dataclasses.replace(config.data, repo_id=repo_id))
data_config = config.data.create(config.assets_dirs, config.model)

if data_config.rlds_data_dir is not None:
Expand Down
40 changes: 37 additions & 3 deletions src/openpi/policies/policy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@
import openpi.transforms as transforms


def _infer_checkpoint_asset_id(assets_dir: pathlib.Path) -> str:
"""Infer the single asset_id stored under a checkpoint's `assets/` directory."""
if not assets_dir.exists():
raise FileNotFoundError(f"Checkpoint assets directory does not exist: {assets_dir}")

candidates = sorted([p.name for p in assets_dir.iterdir() if p.is_dir()])
if len(candidates) == 1:
return candidates[0]
if len(candidates) == 0:
raise FileNotFoundError(f"No asset directories found under checkpoint assets dir: {assets_dir}")
raise ValueError(
"Cannot infer which asset_id to use from checkpoint assets directory "
f"(multiple candidates found): {candidates}. "
"Please specify the correct asset_id via the training config's data.assets.asset_id."
)


def create_trained_policy(
train_config: _config.TrainConfig,
checkpoint_dir: pathlib.Path | str,
Expand Down Expand Up @@ -59,9 +76,26 @@ def create_trained_policy(
if norm_stats is None:
# We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
# that the policy is using the same normalization stats as the original training process.
if data_config.asset_id is None:
raise ValueError("Asset id is required to load norm stats.")
norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
assets_dir = pathlib.Path(checkpoint_dir) / "assets"
asset_id = data_config.asset_id

# Fine-tunes can override `--data.repo-id` during training, so the checkpoint may contain a different
# `assets/<asset_id>` than the static train config. In that case, infer the correct asset id from the
# checkpoint itself.
if asset_id is None or not (assets_dir / asset_id).is_dir():
inferred = _infer_checkpoint_asset_id(assets_dir)
if asset_id is None:
logging.info("No asset_id set in config; inferred %r from checkpoint assets.", inferred)
else:
logging.warning(
"asset_id=%r not found under checkpoint assets (%s); using inferred %r instead.",
asset_id,
assets_dir,
inferred,
)
asset_id = inferred

norm_stats = _checkpoints.load_norm_stats(assets_dir, asset_id)

# Determine the device to use for PyTorch models
if is_pytorch and pytorch_device is None:
Expand Down
30 changes: 30 additions & 0 deletions src/openpi/policies/policy_config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from openpi.policies import policy_config as _policy_config


def test_infer_checkpoint_asset_id_single(tmp_path):
    # Exactly one asset directory under assets/ -> inference returns its name.
    checkpoint_assets = tmp_path / "assets"
    checkpoint_assets.mkdir()
    (checkpoint_assets / "sim1").mkdir()

    inferred = _policy_config._infer_checkpoint_asset_id(checkpoint_assets)
    assert inferred == "sim1"


def test_infer_checkpoint_asset_id_none(tmp_path):
    # An empty assets/ directory has nothing to infer from and must raise.
    checkpoint_assets = tmp_path / "assets"
    checkpoint_assets.mkdir()

    with pytest.raises(FileNotFoundError):
        _policy_config._infer_checkpoint_asset_id(checkpoint_assets)


def test_infer_checkpoint_asset_id_multiple(tmp_path):
    # Two or more asset directories are ambiguous, so inference must refuse.
    checkpoint_assets = tmp_path / "assets"
    checkpoint_assets.mkdir()
    for name in ("a", "b"):
        (checkpoint_assets / name).mkdir()

    with pytest.raises(ValueError):
        _policy_config._infer_checkpoint_asset_id(checkpoint_assets)

35 changes: 35 additions & 0 deletions src/openpi/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,41 @@ def __post_init__(self) -> None:
#
# ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
#
TrainConfig(
name="pi0_sim1_aloha_low_mem_finetune",
# LoRA fine-tuning config intended for 24GB-class GPUs (e.g. RTX 4090).
model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
data=LeRobotAlohaDataConfig(
# Load prompt from LeRobot task metadata (multi-task friendly).
base_config=DataConfig(prompt_from_task=True),
repack_transforms=_transforms.Group(
inputs=[
_transforms.RepackTransform(
{
"images": {
"cam_high": "observation.images.cam_high",
"cam_low": "observation.images.cam_low",
"cam_left_wrist": "observation.images.cam_left_wrist",
"cam_right_wrist": "observation.images.cam_right_wrist",
},
"state": "observation.state",
"actions": "action",
# Keep prompt injected by PromptFromLeRobotTask.
"prompt": "prompt",
}
)
]
),
),
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
).get_freeze_filter(),
ema_decay=None,
batch_size=16,
num_train_steps=20_000,
wandb_enabled=False,
),
TrainConfig(
name="pi0_aloha_sim",
model=pi0_config.Pi0Config(),
Expand Down
29 changes: 22 additions & 7 deletions src/openpi/training/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import jax
import jax.numpy as jnp
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
try:
import lerobot.datasets.lerobot_dataset as lerobot_dataset # type: ignore
except ModuleNotFoundError:
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset # type: ignore
import numpy as np
import torch

Expand Down Expand Up @@ -138,12 +141,24 @@ def create_torch_dataset(
return FakeDataset(model_config, num_samples=1024)

dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
dataset = lerobot_dataset.LeRobotDataset(
data_config.repo_id,
delta_timestamps={
key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
},
)
delta_timestamps = {
key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
}

dataset_kwargs: dict[str, typing.Any] = {"delta_timestamps": delta_timestamps}
try:
import torchcodec # noqa: F401
except Exception:
logging.getLogger(__name__).warning(
"'torchcodec' is not available; falling back to 'pyav' video backend for LeRobotDataset"
)
dataset_kwargs["video_backend"] = "pyav"

try:
dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, **dataset_kwargs)
except TypeError:
dataset_kwargs.pop("video_backend", None)
dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, **dataset_kwargs)

if data_config.prompt_from_task:
dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])
Expand Down