diff --git a/scripts/compute_norm_stats.py b/scripts/compute_norm_stats.py
index c8aef87222..82cc83df4b 100644
--- a/scripts/compute_norm_stats.py
+++ b/scripts/compute_norm_stats.py
@@ -5,6 +5,8 @@
 to the config assets directory.
 """
 
+import dataclasses
+
 import numpy as np
 import tqdm
 import tyro
@@ -86,8 +88,10 @@ def create_rlds_dataloader(
     return data_loader, num_batches
 
 
-def main(config_name: str, max_frames: int | None = None):
+def main(config_name: str, max_frames: int | None = None, repo_id: str | None = None):
     config = _config.get_config(config_name)
+    if repo_id is not None:
+        config = dataclasses.replace(config, data=dataclasses.replace(config.data, repo_id=repo_id))
     data_config = config.data.create(config.assets_dirs, config.model)
 
     if data_config.rlds_data_dir is not None:
diff --git a/src/openpi/policies/policy_config.py b/src/openpi/policies/policy_config.py
index 6570df05ed..eb78528c63 100644
--- a/src/openpi/policies/policy_config.py
+++ b/src/openpi/policies/policy_config.py
@@ -13,6 +13,23 @@
 import openpi.transforms as transforms
 
 
+def _infer_checkpoint_asset_id(assets_dir: pathlib.Path) -> str:
+    """Infer the single asset_id stored under a checkpoint's `assets/` directory."""
+    if not assets_dir.exists():
+        raise FileNotFoundError(f"Checkpoint assets directory does not exist: {assets_dir}")
+
+    candidates = sorted(p.name for p in assets_dir.iterdir() if p.is_dir())
+    if len(candidates) == 1:
+        return candidates[0]
+    if not candidates:
+        raise FileNotFoundError(f"No asset directories found under checkpoint assets dir: {assets_dir}")
+    raise ValueError(
+        "Cannot infer which asset_id to use from checkpoint assets directory "
+        f"(multiple candidates found): {candidates}. "
+        "Please specify the correct asset_id via the training config's data.assets.asset_id."
+    )
+
+
 def create_trained_policy(
     train_config: _config.TrainConfig,
     checkpoint_dir: pathlib.Path | str,
@@ -59,9 +76,26 @@ def create_trained_policy(
     if norm_stats is None:
         # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
         # that the policy is using the same normalization stats as the original training process.
-        if data_config.asset_id is None:
-            raise ValueError("Asset id is required to load norm stats.")
-        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
+        assets_dir = pathlib.Path(checkpoint_dir) / "assets"
+        asset_id = data_config.asset_id
+
+        # Fine-tunes can override `--data.repo-id` during training, so the checkpoint may contain a different
+        # `assets/` than the static train config. In that case, infer the correct asset id from the
+        # checkpoint itself.
+        if asset_id is None or not (assets_dir / asset_id).is_dir():
+            inferred = _infer_checkpoint_asset_id(assets_dir)
+            if asset_id is None:
+                logging.info("No asset_id set in config; inferred %r from checkpoint assets.", inferred)
+            else:
+                logging.warning(
+                    "asset_id=%r not found under checkpoint assets (%s); using inferred %r instead.",
+                    asset_id,
+                    assets_dir,
+                    inferred,
+                )
+            asset_id = inferred
+
+        norm_stats = _checkpoints.load_norm_stats(assets_dir, asset_id)
 
     # Determine the device to use for PyTorch models
     if is_pytorch and pytorch_device is None:
diff --git a/src/openpi/policies/policy_config_test.py b/src/openpi/policies/policy_config_test.py
new file mode 100644
index 0000000000..cf3e9a5d22
--- /dev/null
+++ b/src/openpi/policies/policy_config_test.py
@@ -0,0 +1,30 @@
+import pytest
+
+from openpi.policies import policy_config as _policy_config
+
+
+def test_infer_checkpoint_asset_id_single(tmp_path):
+    assets_dir = tmp_path / "assets"
+    assets_dir.mkdir()
+    (assets_dir / "sim1").mkdir()
+
+    assert _policy_config._infer_checkpoint_asset_id(assets_dir) == "sim1"
+
+
+def test_infer_checkpoint_asset_id_none(tmp_path):
+    assets_dir = tmp_path / "assets"
+    assets_dir.mkdir()
+
+    with pytest.raises(FileNotFoundError):
+        _policy_config._infer_checkpoint_asset_id(assets_dir)
+
+
+def test_infer_checkpoint_asset_id_multiple(tmp_path):
+    assets_dir = tmp_path / "assets"
+    assets_dir.mkdir()
+    (assets_dir / "a").mkdir()
+    (assets_dir / "b").mkdir()
+
+    with pytest.raises(ValueError):
+        _policy_config._infer_checkpoint_asset_id(assets_dir)
+
diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py
index 4ca47e1286..d3a914d1f6 100644
--- a/src/openpi/training/config.py
+++ b/src/openpi/training/config.py
@@ -919,6 +919,41 @@ def __post_init__(self) -> None:
     #
     # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
     #
+    TrainConfig(
+        name="pi0_sim1_aloha_low_mem_finetune",
+        # LoRA fine-tuning config intended for 24GB-class GPUs (e.g. RTX 4090).
+        model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
+        data=LeRobotAlohaDataConfig(
+            # Load prompt from LeRobot task metadata (multi-task friendly).
+            base_config=DataConfig(prompt_from_task=True),
+            repack_transforms=_transforms.Group(
+                inputs=[
+                    _transforms.RepackTransform(
+                        {
+                            "images": {
+                                "cam_high": "observation.images.cam_high",
+                                "cam_low": "observation.images.cam_low",
+                                "cam_left_wrist": "observation.images.cam_left_wrist",
+                                "cam_right_wrist": "observation.images.cam_right_wrist",
+                            },
+                            "state": "observation.state",
+                            "actions": "action",
+                            # Keep prompt injected by PromptFromLeRobotTask.
+                            "prompt": "prompt",
+                        }
+                    )
+                ]
+            ),
+        ),
+        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
+        freeze_filter=pi0_config.Pi0Config(
+            paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
+        ).get_freeze_filter(),
+        ema_decay=None,
+        batch_size=16,
+        num_train_steps=20_000,
+        wandb_enabled=False,
+    ),
     TrainConfig(
         name="pi0_aloha_sim",
         model=pi0_config.Pi0Config(),
diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py
index e2ee7dd06b..2a45f7c22c 100644
--- a/src/openpi/training/data_loader.py
+++ b/src/openpi/training/data_loader.py
@@ -7,7 +7,10 @@
 
 import jax
 import jax.numpy as jnp
-import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
+try:
+    import lerobot.datasets.lerobot_dataset as lerobot_dataset  # type: ignore
+except ModuleNotFoundError:
+    import lerobot.common.datasets.lerobot_dataset as lerobot_dataset  # type: ignore
 import numpy as np
 import torch
 
@@ -138,12 +141,28 @@ def create_torch_dataset(
         return FakeDataset(model_config, num_samples=1024)
 
     dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
-    dataset = lerobot_dataset.LeRobotDataset(
-        data_config.repo_id,
-        delta_timestamps={
-            key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
-        },
-    )
+    delta_timestamps = {
+        key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
+    }
+
+    dataset_kwargs: dict[str, typing.Any] = {"delta_timestamps": delta_timestamps}
+    try:
+        import torchcodec  # noqa: F401
+    except Exception:
+        logging.getLogger(__name__).warning(
+            "'torchcodec' is not available; falling back to 'pyav' video backend for LeRobotDataset"
+        )
+        dataset_kwargs["video_backend"] = "pyav"
+
+    try:
+        dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, **dataset_kwargs)
+    except TypeError:
+        # Presumably a LeRobotDataset version whose constructor rejects `video_backend`. Retry without it
+        # only if we actually passed it; any other TypeError is a genuine error and must propagate.
+        if "video_backend" not in dataset_kwargs:
+            raise
+        del dataset_kwargs["video_backend"]
+        dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, **dataset_kwargs)
 
     if data_config.prompt_from_task:
         dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])