9 changes: 9 additions & 0 deletions .dockerignore
@@ -1,3 +1,12 @@
.venv
checkpoints
data
assets
weights
third_party
__pycache__
.pytest_cache
.git
.gitmodules
*.mp4
*.safetensors
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +1,7 @@ assets/
checkpoints/
data/
wandb/
.worktrees/

# Byte-compiled / optimized / DLL files
__pycache__/
49 changes: 49 additions & 0 deletions docker/Dockerfile
@@ -0,0 +1,49 @@
# openpi training container
FROM nvcr.io/nvidia/pytorch:25.01-py3

ENV DEBIAN_FRONTEND=noninteractive \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1 \
GIT_LFS_SKIP_SMUDGE=1 \
UV_PYTHON=3.11 \
UV_PROJECT_ENVIRONMENT=/.venv

# System deps + uv
RUN apt-get update && \
apt-get install -y --no-install-recommends \
git git-lfs ffmpeg libgl1 build-essential cmake pkg-config \
libavcodec-dev libavformat-dev libswscale-dev libavfilter-dev libavdevice-dev && \
rm -rf /var/lib/apt/lists/* && \
pip install uv

# Install a Python 3.11 runtime for wheels that are missing on 3.12 (e.g., mujoco).
RUN uv python install 3.11


WORKDIR /workspace/repo

# Install dependencies at build time.
COPY . /workspace/repo
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync && \
uv pip install -e .

# Install decord into the uv environment (source for arm64, pip for amd64).
# Arm64 fix: patch FFmpeg includes + AVCodec constness to match newer FFmpeg headers.
ARG TARGETARCH
RUN if [ "$TARGETARCH" = "arm64" ]; then \
git clone --recursive https://github.com/dmlc/decord.git && \
cd decord && \
sed -i 's/#include <libavcodec\/avcodec.h>/#include <libavcodec\/avcodec.h>\n#include <libavcodec\/bsf.h>/' src/video/ffmpeg/ffmpeg_common.h && \
sed -i 's/AVCodec \*dec/const AVCodec \*dec/g' src/video/video_reader.cc && \
mkdir build && cd build && \
cmake .. -DUSE_CUDA=OFF -DCMAKE_BUILD_TYPE=Release && \
make -j$(nproc) && \
cd ../python && uv pip install . && \
cd ../.. && rm -rf decord; \
else \
uv pip install --prerelease=allow decord; \
fi

CMD ["/bin/bash"]

19 changes: 19 additions & 0 deletions docker/docker_build.sh
@@ -0,0 +1,19 @@
#!/bin/bash
set -e

ARCH="${1:-amd64}"
IMAGE_NAME="${2:-openpi_${ARCH}}"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_DIR="$(dirname "$SCRIPT_DIR")"

if [ "$ARCH" = "amd64" ]; then
docker build -t "$IMAGE_NAME" -f "${SCRIPT_DIR}/Dockerfile" "$REPO_DIR"
else
docker buildx build \
--platform "linux/${ARCH}" \
-t "$IMAGE_NAME" \
--load \
-f "${SCRIPT_DIR}/Dockerfile" \
"$REPO_DIR"
fi

20 changes: 20 additions & 0 deletions docker/run_local.sh
@@ -0,0 +1,20 @@
#!/bin/bash
# Run container locally for training/testing

ARCH="${1:-amd64}"
IMAGE_NAME="${2:-openpi_${ARCH}}"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_DIR="$(dirname "$SCRIPT_DIR")"

docker run --gpus all --rm -it \
-v "${REPO_DIR}:/workspace/repo" \
-v "${REPO_DIR}/.venv:/.venv" \
-v "${HOME}/.cache/huggingface:/root/.cache/huggingface" \
-v "${HOME}/.cache/openpi:/openpi_assets" \
-e "HF_HOME=/root/.cache/huggingface" \
-e "WANDB_MODE=offline" \
-e "WANDB_ENTITY=pravsels" \
-e "OPENPI_DATA_HOME=/openpi_assets" \
-e "PYTHONPATH=/workspace/repo:${PYTHONPATH}" \
"$IMAGE_NAME"

29 changes: 29 additions & 0 deletions docs/norm_stats.md
@@ -2,6 +2,35 @@

Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.

## Per-timestep action normalization

For delta-action workflows, we support a parallel per-timestep action normalization file named
`norm_stats_actions_per_timestep.json`. This file contains **actions-only** stats computed separately
for each timestep index in the action horizon (shape `[H, D]`, where `H` is the action horizon and
`D` is the action dimension). The regular `norm_stats.json` remains the source of global stats
(including state).

By default, per-timestep action normalization is **auto-enabled** when delta actions are used
(e.g. `use_delta_actions=True`, `use_delta_joint_actions=True`, or `extra_delta_transform=True`).
You can override this behavior by setting `use_per_timestep_action_norm` in the data config.
State normalization always remains global.

To generate these files, use `scripts/compute_norm_stats_per_timestep.py`.
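The `[H, D]` stats layout can be illustrated with a minimal NumPy sketch (this is not the actual script, and the field names and function name here are assumptions for illustration):

```python
import numpy as np

def per_timestep_action_stats(actions: np.ndarray) -> dict:
    """Compute actions-only stats per timestep index.

    actions: array of shape [N, H, D] (N sampled chunks, horizon H, action dim D).
    Returns mean/std/q01/q99 arrays shaped [H, D], one row per timestep index.
    """
    return {
        "mean": actions.mean(axis=0),
        "std": actions.std(axis=0),
        "q01": np.quantile(actions, 0.01, axis=0),
        "q99": np.quantile(actions, 0.99, axis=0),
    }

# Example: 1000 sampled action chunks, horizon 50, 7-dim actions.
rng = np.random.default_rng(0)
actions = rng.normal(size=(1000, 50, 7))
stats = per_timestep_action_stats(actions)
```

Each timestep index gets its own statistics, so early and late steps of a delta-action chunk (which typically have very different magnitudes) are normalized independently.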

## Experiment-specific assets

By default, normalization stats (and related assets like `valid_indices.txt`) are stored under the
config assets directory: `assets/<config name>/<repo_id>/...`. If you want these files to be
experiment-specific, pass a custom assets base directory when running the scripts and training.

Example:

```bash
ASSETS_DIR="/scratch/.../checkpoints/<CONFIG_NAME>/<EXP_NAME>/assets"
uv run scripts/compute_valid_indices.py <CONFIG_NAME> --assets-base-dir="${ASSETS_DIR}"
uv run scripts/compute_norm_stats_per_timestep.py <CONFIG_NAME> --assets-base-dir="${ASSETS_DIR}"
uv run scripts/train.py <CONFIG_NAME> --exp-name=<EXP_NAME> --assets-base-dir="${ASSETS_DIR}"
```

## Reloading normalization statistics

When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
90 changes: 90 additions & 0 deletions docs/plans/2026-01-29-per-timestep-action-norm.md
@@ -0,0 +1,90 @@
# Per-Timestep Action Normalization Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Add a parallel per-timestep action normalization flow (actions only) with a separate stats file, auto-enabled for delta-action configs while preserving existing global normalization.

**Architecture:** Keep `norm_stats.json` unchanged for global stats; compute and store per-timestep action stats in `norm_stats_actions_per_timestep.json`. At runtime, merge per-timestep action stats into the normalization dict only when enabled; otherwise use existing global stats.

**Tech Stack:** Python, NumPy, dataclasses, JAX (existing), JSON I/O

---

### Task 1: Add per-timestep stats I/O + compute script

**Files:**
- Create: `scripts/compute_norm_stats_per_timestep.py`
- Modify: `src/openpi/shared/normalize.py`

**Step 1: Implement new stats file helpers**

- Add helpers in `normalize.py` to save/load `norm_stats_actions_per_timestep.json`.
- Use the existing `NormStats` schema for the actions-only stats file.

**Step 2: Implement compute script**

- Compute global stats for `state` and `actions` using existing `RunningStats` (same as current script).
- Compute per-timestep stats for `actions` by maintaining one `RunningStats` per timestep index and stacking into a single `NormStats` (mean/std/q01/q99 arrays shaped `[H, D]`).
- Write global stats to `norm_stats.json` and per-timestep action stats to `norm_stats_actions_per_timestep.json` in the same assets directory.

**Step 3: Quick verification**

- A dry invocation (no dataset) is not possible, so instead verify that the script imports cleanly and reaches argument parsing without errors:
  - `python scripts/compute_norm_stats_per_timestep.py --help`

---

### Task 2: Wire config + data loader + policy for per-timestep actions

**Files:**
- Modify: `src/openpi/training/config.py`
- Modify: `src/openpi/training/data_loader.py`
- Modify: `src/openpi/policies/policy_config.py`
- Modify: `src/openpi/training/checkpoints.py`

**Step 1: Add config fields**

- Extend `DataConfig` with:
- `use_per_timestep_action_norm: bool | None = None`
- `per_timestep_action_norm_stats: NormStats | None = None`
- In `DataConfigFactory.create_base_config`, load the per-timestep actions stats file into `per_timestep_action_norm_stats` (no change to existing `norm_stats`).

**Step 2: Auto-enable for delta action configs**

- In `LeRobotBinPackDataConfig.create`, if `use_delta_actions=True` and `use_per_timestep_action_norm is None`, set it to `True`.
- In `LeRobotAlohaDataConfig.create`, if `use_delta_joint_actions=True` and `use_per_timestep_action_norm is None`, set it to `True`.
- In `LeRobotLiberoDataConfig.create`, if `extra_delta_transform=True` and `use_per_timestep_action_norm is None`, set it to `True`.

**Step 3: Merge stats for normalization**

- Add a small helper (in `normalize.py` or `data_loader.py`) that merges `actions` from `per_timestep_action_norm_stats` into the `norm_stats` dict when enabled.
- Use that merged dict in `data_loader.transform_dataset` and `transform_iterable_dataset`.
- In `policy_config.create_trained_policy`, use merged stats for `Normalize`/`Unnormalize` so inference uses per-timestep actions when enabled.
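Such a merge helper could look like the following sketch (the function name and the plain-dict representation are assumptions; the real helper would operate on `NormStats` objects):

```python
def merge_norm_stats(norm_stats: dict, per_timestep_action_stats, enabled: bool) -> dict:
    """Return norm stats with `actions` replaced by per-timestep stats when enabled.

    `state` (and any other keys) keep their global stats; only `actions` is overridden.
    """
    if not enabled or per_timestep_action_stats is None:
        return norm_stats
    merged = dict(norm_stats)  # shallow copy; leave the caller's dict untouched
    merged["actions"] = per_timestep_action_stats
    return merged

# Usage: state stays global, actions become per-timestep.
merged = merge_norm_stats(
    {"state": "global_state", "actions": "global_actions"},
    "per_timestep_actions",
    enabled=True,
)
```

Keeping the merge in one helper means the data loader and `create_trained_policy` cannot drift apart on when per-timestep stats apply.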

**Step 4: Save per-timestep stats in checkpoints**

- In `checkpoints.save_state`, if `per_timestep_action_norm_stats` exists and `asset_id` is set, save it to the assets directory using the new helper.

---

### Task 3: Documentation update

**Files:**
- Modify: `docs/norm_stats.md`

**Step 1: Document new file and behavior**

- Describe `norm_stats_actions_per_timestep.json` as actions-only, per-timestep stats.
- Note auto-enabling when delta actions are used, and how to override via config.
- Mention that state normalization remains global.

**Step 2: Quick verification**

- Ensure docs build without errors (no formal command required).

---

## Execution Notes

- **TDD waived** per user permission; do not write failing tests first.
- **Commits** only if user explicitly requests them (system constraint).
51 changes: 51 additions & 0 deletions docs/training_metrics.md
@@ -0,0 +1,51 @@
# Training Metrics & Losses (openpi)

This document explains what the current training scripts log, and what those values mean.

## 1) What is logged today

### JAX trainer (`scripts/train.py`)
- `loss`: mean of the model’s `compute_loss(...)` across batch and horizon.
- `grad_norm`: global norm of raw gradients (before optimizer clipping).
- `param_norm`: global norm over kernel parameters only (bias, scale, positional-embedding, and input-embedding parameters are excluded).
- `camera_views`: logged once at step 0 from the first batch for a sanity check.
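Both `grad_norm` and `param_norm` are global L2 norms over a pytree of arrays, i.e. the norm of all entries flattened into one vector. A minimal sketch (the trainer may use something like `optax.global_norm` for this; that is an assumption):

```python
import numpy as np

def global_norm(tree_leaves) -> float:
    """Global L2 norm over a list of arrays: sqrt of the sum of all squared entries."""
    return float(np.sqrt(sum(np.sum(np.square(g)) for g in tree_leaves)))

# Two leaves with squared sums 9 and 16 give a global norm of sqrt(25).
grads = [np.array([3.0, 0.0]), np.array([[4.0]])]
print(global_norm(grads))  # → 5.0
```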

### PyTorch trainer (`scripts/train_pytorch.py`)
- `loss`: mean of per-element loss from the model’s forward pass.
- `learning_rate`: current LR (cosine schedule with warmup).
- `grad_norm`: global norm of grads as returned by `clip_grad_norm_` (pre-clip norm).
- `time_per_step`: average wall-clock time per step over the log interval.
- `checkpoint_step`: logged when a checkpoint is saved.
- `camera_views`: logged once at step 0 from the first batch for a sanity check.

Note: The current trainers have no eval loop and log no eval metrics.

## 2) Loss definitions by model

### π0 / π0.5 (JAX)
Flow-matching loss on actions:
1) Sample noise `ε ~ N(0, I)` and time `t ~ Beta(1.5, 1)` in `(0, 1)`.
2) Build a noisy action `x_t = t·ε + (1 − t)·a`.
3) Target velocity `u_t = ε − a`.
4) Model predicts `v_t`.
5) Loss = mean squared error over action dimensions: `MSE(v_t, u_t)`.

The trainer averages the loss across batch and horizon.
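The five steps above can be written out as a NumPy sketch (an illustration of the loss definition, not the trainer's code; the dummy zero-velocity model is only there to exercise the function):

```python
import numpy as np

rng = np.random.default_rng(0)

def flow_matching_loss(actions, predict_velocity):
    """actions: [B, H, D] clean action chunks; predict_velocity: stand-in for the model."""
    eps = rng.standard_normal(actions.shape)               # 1) ε ~ N(0, I)
    t = rng.beta(1.5, 1.0, size=(actions.shape[0], 1, 1))  # 1) t ~ Beta(1.5, 1)
    x_t = t * eps + (1 - t) * actions                      # 2) noisy action
    u_t = eps - actions                                    # 3) target velocity
    v_t = predict_velocity(x_t, t)                         # 4) model prediction
    return float(np.mean((v_t - u_t) ** 2))                # 5) MSE over batch and horizon

# With a dummy model predicting zero velocity, the loss is E[(ε − a)²]: finite and positive.
actions = rng.standard_normal((8, 16, 7))
loss = flow_matching_loss(actions, lambda x_t, t: np.zeros_like(x_t))
```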

### π0 / π0.5 (PyTorch)
Same flow-matching loss as in JAX, implemented in `PI0Pytorch.forward`, which returns per-element MSE with
`reduction="none"`; the trainer then averages it.

### π0-FAST (JAX only)
Autoregressive token loss:
- The model predicts the next token.
- Cross-entropy is computed only where `token_loss_mask` is true.
- `token_loss_mask` is produced by the FAST tokenizer (loss is applied only to postfix tokens).

PyTorch training does not support π0-FAST.
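The masked token loss can be sketched in NumPy as follows (an illustration of cross-entropy restricted to masked positions, not the model's actual implementation):

```python
import numpy as np

def masked_token_loss(logits, targets, token_loss_mask):
    """logits: [T, V]; targets: [T] int token ids; token_loss_mask: [T] bool.

    Cross-entropy is averaged only over positions where the mask is true.
    """
    logits = logits - logits.max(axis=-1, keepdims=True)   # stabilize the softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]     # per-token negative log-likelihood
    return float(nll[token_loss_mask].mean())

logits = np.zeros((4, 10))                    # uniform predictions over a 10-token vocab
targets = np.array([1, 2, 3, 4])
mask = np.array([False, False, True, True])   # loss applied to postfix tokens only
print(masked_token_loss(logits, targets, mask))  # → ln(10) ≈ 2.3026
```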

## 3) What is not logged by default
- No eval loss, no rollout metrics.
- No per-dimension action error.
- No policy success metrics.

34 changes: 34 additions & 0 deletions docs/training_trends.md
@@ -0,0 +1,34 @@
# Training Curves: Good vs Problematic Trends (openpi)

This document describes what to look for in training logs based on current instrumentation.

## 1) Metrics available today
- JAX: `loss`, `grad_norm`, `param_norm`
- PyTorch: `loss`, `learning_rate`, `grad_norm`, `time_per_step`
- Both: one-time `camera_views` sanity check

Note: There are no eval metrics unless you add them.

## 2) Generally good signals (based on current logs)
- `loss` decreases over time and then plateaus.
- `grad_norm` is finite and relatively stable (no NaNs/Infs).
- `param_norm` is finite and stable (JAX only).
- `learning_rate` follows the configured schedule (PyTorch only).
- `time_per_step` is roughly stable (PyTorch only).
- `camera_views` look correctly aligned and normalized.

## 3) Potentially problematic signals
- `loss` is NaN/Inf or diverges upward.
- `grad_norm` becomes NaN/Inf or shows repeated huge spikes.
- `param_norm` explodes or collapses to zero (JAX only).
- `time_per_step` steadily increases (PyTorch only), suggesting data-loader or memory issues.
- `camera_views` look corrupted, badly normalized, or mismatched with actions.
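A quick way to scan logged metric series for these failure modes (a minimal sketch; the metric names follow this document, and the 10× spike threshold is an arbitrary assumption to tune against your own runs):

```python
import math

def check_metric(values, spike_factor=10.0):
    """Flag non-finite values and sudden spikes relative to the finite history so far."""
    issues = []
    for step, v in enumerate(values):
        if not math.isfinite(v):
            issues.append((step, "non-finite"))
            continue
        history = [x for x in values[:step] if math.isfinite(x)]
        if history and v > spike_factor * max(history):
            issues.append((step, "spike"))
    return issues

grad_norms = [1.2, 1.1, 0.9, float("nan"), 120.0]
print(check_metric(grad_norms))  # → [(3, 'non-finite'), (4, 'spike')]
```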

## 4) What’s missing (needs your guidance)
- Expected loss ranges for your specific dataset/task.
- Acceptable grad/param norm ranges.
- Task success or rollout quality metrics.
- Eval loss or per-task validation curves.

If you want concrete “good/bad” thresholds, we need to define them from prior runs or desired targets.

2 changes: 1 addition & 1 deletion packages/openpi-client/pyproject.toml
@@ -5,7 +5,7 @@ requires-python = ">=3.7"
dependencies = [
"dm-tree>=0.1.8",
"msgpack>=1.0.5",
"numpy>=1.22.4,<2.0.0",
"numpy>=1.22.4,<3.0.0",
"pillow>=9.0.0",
"tree>=0.2.4",
"websockets>=11.0",
13 changes: 5 additions & 8 deletions pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"jaxtyping==0.2.36",
"lerobot",
"ml_collections==1.0.0",
"numpy>=1.22.4,<2.0.0",
"numpy>=1.22.4,<3.0.0",
"numpydantic>=1.6.6",
"opencv-python>=4.10.0.84",
"openpi-client",
@@ -37,6 +39,8 @@ dependencies = [
"transformers==4.53.2",
"rich>=14.0.0",
"polars>=1.30.0",
"robocandywrapper>=0.2.6",
"rewact_tools",
]


@@ -53,19 +55,14 @@ dev = [
"matplotlib>=3.10.0",
"pynvml>=12.0.0",
]
rlds = [
"dlimp",
"tensorflow-cpu==2.15.0",
"tensorflow-datasets==4.9.9",
]

[tool.uv]
override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"]
default-groups = []

[tool.uv.sources]
openpi-client = { workspace = true }
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "v0.4.3" }

[tool.uv.workspace]
members = ["packages/*"]