diff --git a/.gitmodules b/.gitmodules index 4abd60e977..ebce9d0865 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,7 @@ url = https://github.com/Physical-Intelligence/aloha.git [submodule "third_party/libero"] path = third_party/libero - url = https://github.com/Lifelong-Robot-Learning/LIBERO.git + url = https://github.com/szhaovas/LIBERO.git +[submodule "third_party/SimplerEnv"] + path = third_party/SimplerEnv + url = https://github.com/simpler-env/SimplerEnv.git diff --git a/README.md b/README.md index 7280d04ea4..cf9dbab808 100644 --- a/README.md +++ b/README.md @@ -1,195 +1,30 @@ -# openpi - -openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/). - -Currently, this repo contains two types of models: -- the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based diffusion vision-language-action model (VLA) -- the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer. - -For both models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets. - -This is an experiment: $\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\pi_0$ to their own platforms, we do not expect every such attempt to be successful. All this is to say: $\pi_0$ may or may not work for you, but you are welcome to try it and see! - -## Updates - -- [Jun 2025]: We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). 
This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID. - - -## Requirements - -To run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimations assume a single GPU, but you can also use multiple GPUs with model parallelism to reduce per-GPU memory requirements by configuring `fsdp_devices` in the training config. Please also note that the current training script does not yet support multi-node training. - -| Mode | Memory Required | Example GPU | -| ------------------ | --------------- | ------------------ | -| Inference | > 8 GB | RTX 4090 | -| Fine-Tuning (LoRA) | > 22.5 GB | RTX 4090 | -| Fine-Tuning (Full) | > 70 GB | A100 (80GB) / H100 | - -The repo has been tested with Ubuntu 22.04, we do not currently support other operating systems. - ## Installation - -When cloning this repo, make sure to update submodules: - -```bash -git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git - -# Or if you already cloned the repo: -git submodule update --init --recursive -``` - -We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment: - ```bash +git clone --recurse-submodules https://github.com/szhaovas/openpi.git +cd openpi GIT_LFS_SKIP_SMUDGE=1 uv sync GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . +uv venv --python 3.8 examples/libero/.venv +source examples/libero/.venv/bin/activate +uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match +uv pip install -e packages/openpi-client +uv pip install -e third_party/libero ``` -NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency. 
- -**Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details. - - - - -## Model Checkpoints - -### Base Models -We provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning. - -| Model | Use Case | Description | Checkpoint Path | -| ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- | -| $\pi_0$ | Fine-Tuning | Base diffusion [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_base` | -| $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` | - -### Fine-Tuned Models -We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice. 
- -| Model | Use Case | Description | Checkpoint Path | -| ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | -| $\pi_0$-FAST-DROID | Inference | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid` | -| $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `gs://openpi-assets/checkpoints/pi0_droid` | -| $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on [ALOHA](https://tonyzhaozh.github.io/aloha/) robot platforms | `gs://openpi-assets/checkpoints/pi0_aloha_towel` | -| $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can unpack food from a tupperware container | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` | -| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](https://dit-policy.github.io/), can uncap a pen | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap` | - - -By default, checkpoints are automatically downloaded from `gs://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can overwrite the download path by setting the `OPENPI_DATA_HOME` environment variable. 
- - - - -## Running Inference for a Pre-Trained Model - -Our pre-trained model checkpoints can be run with a few lines of code (here our $\pi_0$-FAST-DROID model): -```python -from openpi.training import config -from openpi.policies import policy_config -from openpi.shared import download - -config = config.get_config("pi0_fast_droid") -checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_fast_droid") - -# Create a trained policy. -policy = policy_config.create_trained_policy(config, checkpoint_dir) - -# Run inference on a dummy example. -example = { - "observation/exterior_image_1_left": ..., - "observation/wrist_image_left": ..., - ... - "prompt": "pick up the fork" -} -action_chunk = policy.infer(example)["actions"] -``` -You can also test this out in the [example notebook](examples/inference.ipynb). - -We provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots. - -**Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate. - -**Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. See [here](examples/simple_client/README.md) for more details. - - - - - -## Fine-Tuning Base Models on Your Own Data - -We will fine-tune the $\pi_0$-FAST model on the [Libero dataset](https://libero-project.github.io/datasets) as a running example for how to fine-tune a base model on your own data. We will explain three steps: -1. Convert your data to a LeRobot dataset (which we use for training) -2. 
Defining training configs and running training -3. Spinning up a policy server and running inference - -### 1. Convert your data to a LeRobot dataset - -We provide a minimal example script for converting Libero data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw Libero dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with: - -```bash -uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data -``` - -### 2. Defining training configs and running training - -To fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for Libero below, which you can modify for your own dataset: - -- [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the Libero environment to the model and vice versa. Will be used for both, training and inference. -- [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw Libero data from LeRobot dataset for training. -- [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader. - -We provide example fine-tuning configs for both, [π₀](src/openpi/training/config.py) and [π₀-FAST](src/openpi/training/config.py) on Libero data. - -Before we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config: - +## Experiment ```bash -uv run scripts/compute_norm_stats.py --config-name pi0_fast_libero +./run_experiment.sh ``` +Archive heatmaps, checkpoints, and other metrics are logged to `test_logs`. 
-Now we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config): - +## Visualization +Run this inside the LIBERO venv, and make sure [dash](https://pypi.org/project/dash/) is installed. ```bash -XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp-name=my_experiment --overwrite +source examples/libero/.venv/bin/activate +python -m pip install dash ``` - -The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. For maximally using the GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%). - -**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file. - -### 3. Spinning up a policy server and running inference - -Once training is complete, we can run inference by spinning up a policy server and then querying it from a Libero evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed): - -```bash -uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero --policy.dir=checkpoints/pi0_fast_libero/my_experiment/20000 +Set the path of the scheduler `.pkl` checkpoint you wish to visualize near the end of `viz_spatial_attack.py`, then run: +```bash +python viz_spatial_attack.py ``` - -This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run the Libero evaluation script to query the server. 
For instructions how to install Libero and run the evaluation script, see the [Libero README](examples/libero/README.md). - -If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md). - - - -### More Examples - -We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs: -- [ALOHA Simulator](examples/aloha_sim) -- [ALOHA Real](examples/aloha_real) -- [UR5](examples/ur5) - - - -## Troubleshooting - -We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines). - -| Issue | Resolution | -| ----------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `uv sync` fails with dependency conflicts | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). | -| Training runs out of GPU memory | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training to allow JAX to use more GPU memory. You can also try reducing the batch size in your training config. | -| Policy server connection errors | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server. | -| Missing norm stats error when training | Run `scripts/compute_norm_stats.py` with your config name before starting training. | -| Dataset download fails | Check your internet connection. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`). 
| -| CUDA/GPU errors | Verify NVIDIA drivers and CUDA toolkit are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. | -| Import errors when running examples | Make sure you've installed all dependencies with `uv sync` and activated the virtual environment. Some examples may have additional requirements listed in their READMEs. | -| Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. | - +This will display an interactive archive heatmap at `localhost:8050`. You can view it in the browser and click on a cell to save rollouts of that cell's solution to `interactive_vids`. If you are connected over SSH, you can also configure port forwarding to view and interact with the heatmap on your own computer. \ No newline at end of file diff --git a/convert_google_robot_to_lerobot.py b/convert_google_robot_to_lerobot.py new file mode 100644 index 0000000000..34e1780c43 --- /dev/null +++ b/convert_google_robot_to_lerobot.py @@ -0,0 +1,74 @@ +import numpy as np +import json + +from pathlib import Path + +from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +import tensorflow_datasets as tfds +import tyro + +google_demo_conversion_cfg = { + "repo_id": "hchen/google_robot", + "demo_dir": "./simpler_env_demos/demo_collection/collected_data" +} + +def main(data_cfg): + demo_root_dir = Path(data_cfg["demo_dir"]) + task_dirs = [d for d in demo_root_dir.iterdir() if d.is_dir()] + task_metadata_paths = [d / "metadata.json" for d in task_dirs] + + task_metadata = [json.loads(p.read_text()) for p in task_metadata_paths] + + dataset = LeRobotDataset.create( + repo_id=data_cfg["repo_id"], + robot_type="google_robot", + fps=10, + features={ + "image": { + "dtype": "image", + "shape": (224, 224, 3), + "names": ["height", "width", "channel"] + 
}, + "state": { + "dtype": "float32", + "shape": (8,), + "names": ["state"], + }, + "actions": { + "dtype": "float32", + "shape": (7,), + "names": ["actions"], + } + }, + image_writer_threads=10, + image_writer_processes=5 + ) + + for single_task_ds in task_metadata: + success_episode_files = ["./simpler_env_demos/demo_collection/" + file for file in single_task_ds["success_episode_files"]] + for i, episode_file in enumerate(success_episode_files): + traj_data = np.load(episode_file) + + traj_images = traj_data["images"] + traj_states = traj_data["states"] + traj_actions = traj_data["actions"] + traj_language = str(traj_data["language"]) + + episode_len = len(traj_images) + + for t in range(episode_len): + dataset.add_frame( + { + "image": traj_images[t], + "state": traj_states[t], + "actions": traj_actions[t], + "task": traj_language + } + ) + + dataset.save_episode() + print(f"Saved trajectory {i} out of {single_task_ds['total_saved_episodes']} for {single_task_ds['environment_name']}") + +if __name__ == "__main__": + main(data_cfg=google_demo_conversion_cfg) \ No newline at end of file diff --git a/create_lerobot_ds.py b/create_lerobot_ds.py new file mode 100644 index 0000000000..0e4b106c04 --- /dev/null +++ b/create_lerobot_ds.py @@ -0,0 +1,106 @@ +""" +Minimal example script for converting a dataset to LeRobot format. + +We use the Libero dataset (stored in RLDS) for this example, but it can be easily +modified for any other data you have saved in a custom format. 
+ +Usage: +uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data + +If you want to push your dataset to the Hugging Face Hub, you can use the following command: +uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub + +Note: to run the script, you need to install tensorflow_datasets: +`uv pip install tensorflow tensorflow_datasets` + +You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds +The resulting dataset will get saved to the $HF_LEROBOT_HOME directory. +Running this conversion script will take approximately 30 minutes. +""" + +import shutil + +from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +import tensorflow_datasets as tfds +import tyro + +import numpy as np +import pickle + +import os, struct + +def main(scheduler_pkl_path="./test_logs/scheduler_00000010.pkl", + repo_name="hchen/libero"): + # Clean up any existing dataset in the output directory + output_path = HF_LEROBOT_HOME / repo_name + if output_path.exists(): + shutil.rmtree(output_path) + + # Create LeRobot dataset, define features to store + # OpenPi assumes that proprio is stored in `state` and actions in `action` + # LeRobot assumes that dtype of image data is `image` + dataset = LeRobotDataset.create( + repo_id=repo_name, + robot_type="panda", + fps=10, + features={ + "image": { + "dtype": "image", + "shape": (256, 256, 3), + "names": ["height", "width", "channel"], + }, + "wrist_image": { + "dtype": "image", + "shape": (256, 256, 3), + "names": ["height", "width", "channel"], + }, + "state": { + "dtype": "float64", + "shape": (8,), + "names": ["state"], + }, + "actions": { + "dtype": "float64", + "shape": (7,), + "names": ["actions"], + }, + }, + image_writer_threads=10, + image_writer_processes=5, + ) + + with open(scheduler_pkl_path, "rb") as f: + scheduler = pickle.load(f) 
+ + archive = scheduler.archive + all_trajectories = archive.data("trajectories") + + include_failures = True + + for elite_trajectories in all_trajectories: + for traj_id, traj in enumerate(elite_trajectories): + episode_len = np.array(traj["image"]).shape[0] + + if not include_failures and not traj["success"]: + continue + + for step in range(episode_len): + dataset.add_frame( + { + "image": traj["image"][step], + "wrist_image": traj["wrist_image"][step], + "state": traj["state"][step], + "actions": traj["action"][step], + "task": traj["prompt"] + } + ) + + dataset.save_episode() + print(f"Saved trajectory {traj_id}!") + + print(dataset) + +if __name__ == "__main__": + main(scheduler_pkl_path="./test_logs/scheduler_00000010.pkl", + repo_name="hchen/libero") \ No newline at end of file diff --git a/eval_base_libero_logs/summary.csv b/eval_base_libero_logs/summary.csv new file mode 100644 index 0000000000..19d7a3b0df --- /dev/null +++ b/eval_base_libero_logs/summary.csv @@ -0,0 +1,51 @@ +env_num,success_rate +0,1.0 +1,1.0 +2,1.0 +3,1.0 +4,0.8 +5,1.0 +6,1.0 +7,1.0 +8,1.0 +9,1.0 +10,1.0 +11,1.0 +12,1.0 +13,1.0 +14,1.0 +15,1.0 +16,1.0 +17,1.0 +18,1.0 +19,1.0 +20,1.0 +21,0.8 +22,1.0 +23,1.0 +24,1.0 +25,1.0 +26,1.0 +27,1.0 +28,1.0 +29,1.0 +30,0.8 +31,1.0 +32,1.0 +33,1.0 +34,0.8 +35,1.0 +36,1.0 +37,1.0 +38,0.8 +39,1.0 +40,1.0 +41,1.0 +42,0.8 +43,1.0 +44,1.0 +45,1.0 +46,1.0 +47,0.8 +48,1.0 +49,1.0 diff --git a/eval_finetuned_libero_logs/summary.csv b/eval_finetuned_libero_logs/summary.csv new file mode 100644 index 0000000000..3f84a41231 --- /dev/null +++ b/eval_finetuned_libero_logs/summary.csv @@ -0,0 +1,51 @@ +env_num,success_rate +0,1.0 +1,1.0 +2,1.0 +3,1.0 +4,1.0 +5,1.0 +6,1.0 +7,0.8 +8,1.0 +9,0.8 +10,1.0 +11,1.0 +12,1.0 +13,0.8 +14,1.0 +15,1.0 +16,0.8 +17,1.0 +18,1.0 +19,0.8 +20,1.0 +21,0.8 +22,1.0 +23,0.8 +24,1.0 +25,1.0 +26,1.0 +27,1.0 +28,1.0 +29,1.0 +30,1.0 +31,1.0 +32,0.8 +33,1.0 +34,1.0 +35,1.0 +36,1.0 +37,1.0 +38,1.0 +39,1.0 +40,1.0 +41,0.8 +42,0.8 +43,1.0 
+44,1.0 +45,1.0 +46,1.0 +47,1.0 +48,0.8 +49,1.0 diff --git a/eval_finetuned_random_logs/summary.csv b/eval_finetuned_random_logs/summary.csv new file mode 100644 index 0000000000..d4790cce13 --- /dev/null +++ b/eval_finetuned_random_logs/summary.csv @@ -0,0 +1,225 @@ +env_num,success_rate +0,0 +1,0 +2,0 +3,0 +4,0 +5,0 +6,0 +7,0 +8,0 +9,0 +10,0 +11,0 +12,0 +13,0 +14,0 +15,0 +16,0 +17,0 +18,0 +19,0 +20,0 +21,0 +22,0 +23,0 +24,0 +25,0 +26,0 +27,0 +28,0 +29,0 +30,0 +31,0 +32,0 +33,0 +34,0 +35,0 +36,0 +37,0 +38,0 +39,0 +40,0 +41,0 +42,0 +43,0 +44,0 +45,0 +46,0 +47,0 +48,0 +49,0 +50,0 +51,0 +52,0 +53,0 +54,0 +55,0 +56,0 +57,0 +58,0 +59,0 +60,0 +61,0 +62,0 +63,0 +64,0 +65,0 +66,0 +67,0 +68,0 +69,0 +70,0 +71,0 +72,0 +73,0 +74,0 +75,0 +76,0 +77,0 +78,0 +79,0 +80,0 +81,0 +82,0.4 +83,0 +84,0 +85,0 +86,0 +87,0 +88,0 +89,0 +90,0 +91,0 +92,0 +93,0 +94,0 +95,0 +96,0 +97,0 +98,0 +99,0 +100,0 +101,0 +102,0 +103,0 +104,0 +105,0 +106,0 +107,0 +108,0 +109,0 +110,0 +111,0 +112,0 +113,0 +114,0 +115,0 +116,0 +117,0 +118,0 +119,0 +120,0 +121,0 +122,0 +123,0 +124,0 +125,0 +126,0 +127,0 +128,0 +129,0 +130,0 +131,0 +132,0 +133,0 +134,0 +135,0 +136,0 +137,0 +138,0 +139,0 +140,0 +141,0 +142,0 +143,0 +144,0 +145,0 +146,0 +147,0 +148,0 +149,0 +150,0 +151,0 +152,0 +153,0 +154,0 +155,0 +156,0 +157,0 +158,0 +159,0 +160,0 +161,0 +162,0 +163,0 +164,0 +165,0 +166,0 +167,0 +168,0 +169,0 +170,0 +171,0 +172,0 +173,0 +174,0 +175,0 +176,0 +177,0 +178,0 +179,0 +180,0 +181,0 +182,0 +183,0 +184,0 +185,0 +186,0 +187,0 +188,0 +189,0 +190,0 +191,0 +192,0 +193,0 +194,0 +195,0 +196,0 +197,0 +198,0 +199,0 +200,0 +201,0 +202,0.4 +203,0 +204,0 +205,0 +206,0 +207,0 +208,0 +209,0 +210,0 +211,0 +212,0 +213,0 +214,0 +215,0 +216,0 +217,0 +218,0 +219,0 +220,0 +221,0 +222,0 +223,0 diff --git a/eval_random_logs/summary.csv b/eval_random_logs/summary.csv new file mode 100644 index 0000000000..62fb5383ad --- /dev/null +++ b/eval_random_logs/summary.csv @@ -0,0 +1,51 @@ +env_num,success_rate +0,0 +1,0 +2,0.2 +3,0 +4,0 +5,0 +6,0 
+7,0 +8,0 +9,0 +10,0 +11,0 +12,0 +13,0 +14,0 +15,0 +16,0 +17,0 +18,0 +19,0 +20,0 +21,0 +22,0 +23,0 +24,0 +25,0 +26,0 +27,0 +28,0 +29,0 +30,0 +31,0 +32,0 +33,0 +34,0 +35,0 +36,0 +37,0 +38,0 +39,0 +40,0 +41,0 +42,0 +43,0 +44,0 +45,0 +46,0 +47,0 +48,0 +49,"(1e-06, 0, 0, None)" diff --git a/evaluate_google_robot.py b/evaluate_google_robot.py new file mode 100644 index 0000000000..bed94f9566 --- /dev/null +++ b/evaluate_google_robot.py @@ -0,0 +1,84 @@ +import third_party.SimplerEnv.simpler_env as simpler_env +from third_party.SimplerEnv.simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict +import numpy as np +import imageio +import collections + +import openpi.training.config as config +from openpi.policies import policy_config +from openpi.shared import download +from openpi_client import websocket_client_policy as _websocket_client_policy + +from tqdm import trange, tqdm + +experiment_cfg = { + "env_name": "google_robot_pick_coke_can", + + "max_timesteps": 50, + "ntrials": 5, + "env_seed": 42 +} + +pi_cfg = { + "replan_steps": 5, + "host": "0.0.0.0", + "port": 8000, +} + +def main(experiment_cfg, + pi_cfg): + # load simpler env + env = simpler_env.make(experiment_cfg["env_name"]) + prompt = env.get_language_instruction() + + # load openpi client + client = _websocket_client_policy.WebsocketClientPolicy(pi_cfg["host"], + pi_cfg["port"]) + + for i in trange(experiment_cfg["ntrials"]): + t = 0 + + image_seq = [] + action_plan = collections.deque() + obs, reset_info = env.reset(seed=(experiment_cfg["env_seed"] + i)) + + while t < experiment_cfg["max_timesteps"]: + tcp_pose = np.asarray(obs["extra"]["tcp_pose"], dtype=np.float32) + qpos = np.asarray(obs["agent"]["qpos"], dtype=np.float32) + finger_l, finger_r = qpos[7], qpos[8] + gripper = 0.5 * (finger_l + finger_r) + input_obs = np.concatenate([tcp_pose, [gripper]]) + + image = get_image_from_maniskill2_obs_dict(env, obs) + + if not action_plan: + element = { + "state": input_obs, + "image": 
image, + "prompt": str(prompt) + } + + action_chunk = client.infer(element)["actions"] + action_plan.extend(action_chunk[: pi_cfg["replan_steps"]]) + + action = action_plan.popleft() + action = np.array(action, dtype=np.float32, copy=True) + + print("ACTION") + print(action) + print("STATE") + print(input_obs) + + obs, reward, done, truncated, info = env.step(action) + + image_seq.append(image) + t += 1 + + imageio.mimwrite( + f"trial_{i}.mp4", + [np.asarray(x) for x in image_seq], + fps=10, + ) + +if __name__ == "__main__": + main(experiment_cfg, pi_cfg) \ No newline at end of file diff --git a/evaluate_libero_env.py b/evaluate_libero_env.py new file mode 100644 index 0000000000..8da3109493 --- /dev/null +++ b/evaluate_libero_env.py @@ -0,0 +1,356 @@ +import collections +import csv +import datetime +import math +import pickle as pkl +import re +from functools import partial +from pathlib import Path + +import fire +import imageio +import numpy as np + +from libero.libero import benchmark +from libero.libero import get_libero_path +from libero.libero.envs import OffScreenRenderEnv + +from openpi_client import websocket_client_policy as _websocket_client_policy + +from tqdm import tqdm, trange + +task_5_bddl = ( + Path(get_libero_path("bddl_files")) + / "custom" + / "pick_up_the_black_bowl_next_to_the_ramekin_and_place_it_on_the_plate.bddl" +) + +TASK_ENV = partial( + OffScreenRenderEnv, + bddl_file_name=task_5_bddl, + camera_heights=256, + camera_widths=256, +) + +def _quat2axisangle(quat): + """ + Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 + """ + # clip quaternion + if quat[3] > 1.0: + quat[3] = 1.0 + elif quat[3] < -1.0: + quat[3] = -1.0 + + den = np.sqrt(1.0 - quat[3] * quat[3]) + if math.isclose(den, 0.0): + # This is (close to) a zero degree rotation, immediately return + return np.zeros(3) + + return (quat[:3] * 2.0 * math.acos(quat[3])) 
/ den + +def extract_env_solutions(scheduler_pkl): + with open(scheduler_pkl, "rb") as f: + scheduler = pkl.load(f) + + archive = scheduler.archive + params = archive.data(fields="solution") + return params + +def evaluate_libero_base(host, + port, + ntrials, + max_steps, + num_steps_wait, + replan_steps, + seed): + # setting up libero + benchmark_dict = benchmark.get_benchmark_dict() + task_suite = benchmark_dict["libero_spatial"]() + + task_id = -1 + for i in range(10): + task = task_suite.get_task(i) + if task.language == "pick up the black bowl next to the ramekin and place it on the plate": + task_id = i + break + + initial_states = task_suite.get_task_init_states(task_id) + task = task_suite.get_task(task_id) + task_description = task.language + + print(task_description) + + task_bddl_file = Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file + env_args = {"bddl_file_name": task_bddl_file, "camera_heights": 256, "camera_widths": 256} + + env = OffScreenRenderEnv(**env_args) + env.seed(seed) + + # setting up openpi + client = _websocket_client_policy.WebsocketClientPolicy(host, port) + + success_rate = 0 + for episode_idx in trange(ntrials): + # Reset environment + env.reset() + action_plan = collections.deque() + + # Set initial states + if initial_states is None: + obs = env.env._get_observations() + else: + obs = env.set_init_state(initial_states[episode_idx]) + + success = False + t = 0 + while t < max_steps + num_steps_wait: + try: + if t < num_steps_wait: + obs, reward, done, info = env.step([0.0] * 6 + [-1.0]) + t += 1 + continue + + # Get preprocessed image + # IMPORTANT: rotate 180 degrees to match train preprocessing + img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1]) + wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1]) + + if not action_plan: + # Finished executing previous action chunk -- compute new chunk + # Prepare observations dict + element = { + "observation/image": img, + 
"observation/wrist_image": wrist_img, + "observation/state": np.concatenate( + ( + obs["robot0_eef_pos"], + _quat2axisangle(obs["robot0_eef_quat"]), + obs["robot0_gripper_qpos"], + ) + ), + "prompt": env.language_instruction, + } + + # Query model to get action + action_chunk = client.infer(element)["actions"] + assert ( + len(action_chunk) >= replan_steps + ), f"We want to replan every {replan_steps} steps, but policy only predicts {len(action_chunk)} steps." + action_plan.extend(action_chunk[: replan_steps]) + + action = action_plan.popleft() + + # Execute action in environment + obs, reward, done, info = env.step(action.tolist()) + if done: + success_rate += 1 / ntrials + success = True + break + t += 1 + + except Exception as e: + break + + return success_rate + + +def evaluate(params, + host, + port, + ntrials, + max_steps, + num_steps_wait, + replan_steps, + seed, + video_logdir=None): + np.random.seed(seed) + openpi_client = _websocket_client_policy.WebsocketClientPolicy(host, port) + + env = TASK_ENV( + params=params, + repair_env=True, + repair_config={ + 'time_limit':1500, + 'seed':seed + } + ) + + env.seed(seed) + obs = env.reset() + + if obs is None: + return 1e-6, 0, 0, None + + if video_logdir is not None: + sol_logdir = Path(video_logdir) / f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" + sol_logdir.mkdir(parents=True) + + success_rate = 0 + for trial_id in trange(ntrials): + obs = env.reset() + action_plan = collections.deque() + + success = False + images = [] + for t in trange(max_steps + num_steps_wait): + try: + if t < num_steps_wait: + # Do nothing at the start to wait for env to settle + obs, reward, done, info = env.step([0.0] * 6 + [-1.0]) + continue + + img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1]) + images.append(img) + wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1]) + + if not action_plan: + element = { + "observation/image": img, + "observation/wrist_image": wrist_img, + 
"observation/state": np.concatenate( + ( + obs["robot0_eef_pos"], + _quat2axisangle(obs["robot0_eef_quat"]), + obs["robot0_gripper_qpos"], + ) + ), + "prompt": env.language_instruction, + } + + action_chunk = openpi_client.infer(element)["actions"] + assert ( + len(action_chunk) >= replan_steps + ), f"We want to replan every {replan_steps} steps, but policy only predicts {len(action_chunk)} steps." + action_plan.extend(action_chunk[: replan_steps]) + + action = action_plan.popleft() + obs, reward, done, info = env.step(action.tolist()) + if done: + success_rate += 1 / ntrials + success = True + break + except Exception as e: + print(e) + # TODO: How to handle solutions that fail to evaluate + return 1e-6, 0, 0, None + + if video_logdir is not None: + imageio.mimwrite( + sol_logdir / f"trial{trial_id}_{'success' if success else 'fail'}.mp4", + images, + fps=10, + ) + + return success_rate + +def main( + experiment_cfg_name +): + experiment_configs = [ + { + "name": "pi0_libero_base", + "outdir": "eval_base_libero_logs", + "scheduler_pkl": "./test_logs/scheduler_00000105.pkl", + "finetune_true": False, + "ntrials": 5, + "seed": 42, + "max_steps": 220, + "num_steps_wait": 10, + "host": "0.0.0.0", + "port": 8000, + "replan_steps": 5 + }, + { + "name": "pi0_libero_finetuned", + "outdir": "eval_finetuned_libero_logs", + "scheduler_pkl": "./test_logs/scheduler_00000105.pkl", + "finetune_true": False, + "ntrials": 5, + "seed": 42, + "max_steps": 220, + "num_steps_wait": 10, + "host": "0.0.0.0", + "port": 8001, + "replan_steps": 5 + }, + { + "name": "pi0_qd_random_base", + "outdir": "eval_random_logs", + + "scheduler_pkl": "./test_logs_random/scheduler_00000080.pkl", + "finetune_true": False, + "ntrials": 5, + "seed": 42, + "max_steps": 220, + "num_steps_wait": 10, + "host": "0.0.0.0", + "port": 8002, + "replan_steps": 5 + }, + { + "name": "pi0_qd_random_finetuned", + "outdir": "eval_finetuned_random_logs", + "scheduler_pkl": "./test_logs_random/scheduler_00000080.pkl", + 
"finetune_true": False, + "ntrials": 5, + "seed": 42, + "max_steps": 220, + "num_steps_wait": 10, + "host": "0.0.0.0", + "port": 8003, + "replan_steps": 5 + } + ] + + for cfg in experiment_configs: + if cfg["name"] == experiment_cfg_name: + experiment_cfg = cfg + break + + logdir = Path(experiment_cfg["outdir"]) + + logdir.mkdir(exist_ok=True) + summary_filename = logdir / "summary.csv" + + with open(summary_filename, "w") as summary_file: + writer = csv.writer(summary_file) + writer.writerow(["env_num", "success_rate"]) + + params = None + if "qd" in experiment_cfg["name"]: + params = extract_env_solutions(experiment_cfg["scheduler_pkl"]) + + if params is not None: + for sol_id, sol in enumerate(params): + success_rate = evaluate(params=sol, + host=experiment_cfg["host"], + port=experiment_cfg["port"], + ntrials=experiment_cfg["ntrials"], + max_steps=experiment_cfg["max_steps"], + num_steps_wait=experiment_cfg["num_steps_wait"], + replan_steps=experiment_cfg["replan_steps"], + seed=experiment_cfg["seed"]+sol_id, + video_logdir="./vids") + + with open(summary_filename, "a", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow([sol_id, success_rate]) + else: + for sol_id in np.arange(50): + success_rate = evaluate_libero_base(host=experiment_cfg["host"], + port=experiment_cfg["port"], + ntrials=experiment_cfg["ntrials"], + max_steps=experiment_cfg["max_steps"], + num_steps_wait=experiment_cfg["num_steps_wait"], + replan_steps=experiment_cfg["replan_steps"], + seed=experiment_cfg["seed"]+sol_id) + + with open(summary_filename, "a", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow([sol_id, success_rate]) + +if __name__ == "__main__": + # python evaluate_libero_env.py --experiment_cfg_name="pi0_libero_finetuned" + + fire.Fire(main) \ No newline at end of file diff --git a/examples/libero/main.py b/examples/libero/main.py index dc015a6174..7d04ec9e71 100644 --- a/examples/libero/main.py +++ 
b/examples/libero/main.py @@ -32,10 +32,10 @@ class Args: # LIBERO environment-specific parameters ################################################################################################################# task_suite_name: str = ( - "libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90 + "custom" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90, custom ) num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize i n sim - num_trials_per_task: int = 50 # Number of rollouts per task + num_trials_per_task: int = 5 # Number of rollouts per task ################################################################################################################# # Utils @@ -57,7 +57,7 @@ def eval_libero(args: Args) -> None: pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True) - if args.task_suite_name == "libero_spatial": + if args.task_suite_name in ["libero_spatial", "custom"]: max_steps = 220 # longest training demo has 193 steps elif args.task_suite_name == "libero_object": max_steps = 280 # longest training demo has 254 steps @@ -94,7 +94,10 @@ def eval_libero(args: Args) -> None: action_plan = collections.deque() # Set initial states - obs = env.set_init_state(initial_states[episode_idx]) + if initial_states is None: + obs = env.env._get_observations() + else: + obs = env.set_init_state(initial_states[episode_idx]) # Setup t = 0 @@ -168,7 +171,7 @@ def eval_libero(args: Args) -> None: suffix = "success" if done else "failure" task_segment = task_description.replace(" ", "_") imageio.mimwrite( - pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4", + pathlib.Path(args.video_out_path) / f"rollout_{task_id}_{episode_idx}.mp4", [np.asarray(x) for x in replay_images], fps=10, ) diff --git a/examples/libero/requirements.in b/examples/libero/requirements.in index 149006564d..20a6b439ae 100644 --- 
a/examples/libero/requirements.in +++ b/examples/libero/requirements.in @@ -4,8 +4,13 @@ tqdm tyro PyYaml opencv-python==4.6.0.66 +--extra-index-url https://download.pytorch.org/whl/cu113 torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 robosuite==1.4.1 matplotlib==3.5.3 +ribs==0.7.1 +fire +shapely +dask[distributed] \ No newline at end of file diff --git a/examples/libero/requirements.txt b/examples/libero/requirements.txt index 1a52b42887..e2020925ab 100644 --- a/examples/libero/requirements.txt +++ b/examples/libero/requirements.txt @@ -6,8 +6,22 @@ certifi==2024.12.14 # via requests charset-normalizer==3.4.0 # via requests +click==8.1.8 + # via + # dask + # distributed +cloudpickle==3.1.1 + # via + # dask + # distributed cycler==0.12.1 # via matplotlib +dask==2023.5.0 + # via + # -r examples/libero/requirements.in + # distributed +distributed==2023.5.0 + # via dask docstring-parser==0.16 # via tyro etils==1.3.0 @@ -16,8 +30,12 @@ eval-type-backport==0.2.0 # via tyro evdev==1.7.1 # via pynput +fire==0.7.0 + # via -r examples/libero/requirements.in fonttools==4.55.3 # via matplotlib +fsspec==2025.3.0 + # via dask glfw==1.12.0 # via mujoco idna==3.10 @@ -27,23 +45,39 @@ imageio==2.35.1 imageio-ffmpeg==0.5.1 # via imageio importlib-metadata==8.5.0 - # via typeguard + # via + # dask + # typeguard importlib-resources==6.4.5 # via etils +jinja2==3.1.6 + # via distributed +joblib==1.4.2 + # via scikit-learn kiwisolver==1.4.7 # via matplotlib llvmlite==0.36.0 # via numba +locket==1.0.0 + # via + # distributed + # partd markdown-it-py==3.0.0 # via rich +markupsafe==2.1.5 + # via jinja2 matplotlib==3.5.3 # via -r examples/libero/requirements.in mdurl==0.1.2 # via markdown-it-py +msgpack==1.1.1 + # via distributed mujoco==3.2.3 # via robosuite numba==0.53.1 - # via robosuite + # via + # ribs + # robosuite numpy==1.22.4 # via # -r examples/libero/requirements.in @@ -51,16 +85,30 @@ numpy==1.22.4 # matplotlib # mujoco # numba + # numpy-groupies # 
opencv-python + # pandas + # ribs # robosuite + # scikit-learn # scipy + # shapely # torchvision +numpy-groupies==0.9.22 + # via ribs opencv-python==4.6.0.66 # via # -r examples/libero/requirements.in # robosuite packaging==24.2 - # via matplotlib + # via + # dask + # distributed + # matplotlib +pandas==2.0.3 + # via ribs +partd==1.4.1 + # via dask pillow==10.4.0 # via # imageio @@ -68,7 +116,9 @@ pillow==10.4.0 # robosuite # torchvision psutil==6.1.0 - # via imageio + # via + # distributed + # imageio pygments==2.18.0 # via rich pynput==1.7.7 @@ -78,23 +128,39 @@ pyopengl==3.1.7 pyparsing==3.1.4 # via matplotlib python-dateutil==2.9.0.post0 - # via matplotlib + # via + # matplotlib + # pandas python-xlib==0.33 # via pynput +pytz==2025.2 + # via pandas pyyaml==6.0.2 - # via -r examples/libero/requirements.in + # via + # -r examples/libero/requirements.in + # dask + # distributed requests==2.32.3 # via torchvision +ribs==0.7.1 + # via -r examples/libero/requirements.in rich==13.9.4 # via tyro robosuite==1.4.1 # via -r examples/libero/requirements.in +scikit-learn==1.3.2 + # via ribs scipy==1.10.1 - # via robosuite + # via + # ribs + # robosuite + # scikit-learn setuptools==75.3.0 # via # imageio-ffmpeg # numba +shapely==2.0.7 + # via -r examples/libero/requirements.in shtab==1.7.1 # via tyro six==1.17.0 @@ -102,8 +168,25 @@ six==1.17.0 # pynput # python-dateutil # python-xlib +sortedcontainers==2.4.0 + # via + # distributed + # ribs +tblib==3.0.0 + # via distributed termcolor==2.4.0 - # via robosuite + # via + # fire + # robosuite +threadpoolctl==3.5.0 + # via + # ribs + # scikit-learn +toolz==1.0.0 + # via + # dask + # distributed + # partd torch==1.11.0+cu113 # via # -r examples/libero/requirements.in @@ -113,6 +196,8 @@ torchaudio==0.11.0+cu113 # via -r examples/libero/requirements.in torchvision==0.12.0+cu113 # via -r examples/libero/requirements.in +tornado==6.4.2 + # via distributed tqdm==4.67.1 # via -r examples/libero/requirements.in typeguard==4.4.0 @@ 
-127,8 +212,14 @@ typing-extensions==4.12.2 # tyro tyro==0.9.2 # via -r examples/libero/requirements.in +tzdata==2025.2 + # via pandas urllib3==2.2.3 - # via requests + # via + # distributed + # requests +zict==3.0.0 + # via distributed zipp==3.20.2 # via # etils diff --git a/plot_eval.py b/plot_eval.py new file mode 100644 index 0000000000..53274c7ec2 --- /dev/null +++ b/plot_eval.py @@ -0,0 +1,155 @@ +import pandas as pd +import numpy as np + +from matplotlib import pyplot as plt +import pickle + +from tqdm import trange, tqdm + +def graph_and_value(finetuned_csv, + base_csv): + finetuned_df = pd.read_csv(finetuned_csv) + base_df = pd.read_csv(base_csv) + + finetuned_sr = finetuned_df["success_rate"] + base_sr = base_df["success_rate"] + + # Use the smaller dataset length for comparison + less_n = min(len(base_sr), len(finetuned_sr)) + + finetuned_avg = np.mean(finetuned_sr[:less_n]) + base_avg = np.mean(base_sr[:less_n]) + + print(f"finetuned avg: {finetuned_avg}") + print(f"base avg: {base_avg}") + + # Create a wider figure for a stretched x-axis + plt.figure(figsize=(14, 5)) + + x = np.arange(less_n) # environment numbers 0..299 + width = 1.5 # make bars thicker (0.8 of the bin width) + + # Plot base first, then finetuned on top + plt.bar( + x, + base_sr[:less_n], + width, + color="steelblue", + alpha=0.6, + label="Base", + edgecolor=None + ) + plt.bar( + x, + finetuned_sr[:less_n], + width, + color="darkorange", + alpha=0.6, + label="Finetuned", + edgecolor=None + ) + + plt.xlabel("Environment Number") + plt.ylabel("Success Rate") + plt.title("Finetuned vs Base Success Rates per Environment (Overlaid Bars)") + plt.legend() + plt.grid(True, linestyle="--", alpha=0.6, axis="y") + + # Adjust spacing and save + plt.tight_layout() + plt.savefig("comparison_all.png", dpi=300) + plt.close() + +def train_stats(sched_pkl="./test_logs_legacy/scheduler_00000010.pkl"): + with open(sched_pkl, "rb") as f: + scheduler = pickle.load(f) + + archive = scheduler.archive + + 
all_traj = archive.data("solution") + print(len(all_traj)) + +def extract_unseen(finetune_sched_pkl="./test_logs_legacy/scheduler_00000010.pkl", + sched_pkl="./test_logs/scheduler_00000105.pkl"): + with open(finetune_sched_pkl, "rb") as f: + training_scheduler = pickle.load(f) + + with open(sched_pkl, "rb") as f: + eval_scheduler = pickle.load(f) + + training_params = training_scheduler.archive.data("solution") + eval_params = eval_scheduler.archive.data("solution") + + sol_id_to_keep = [] + for sol_id, sol in tqdm(enumerate(training_params)): + # print(sol) + # print(eval_params[sol_id]) + if not np.array_equal(sol, eval_params[sol_id]): + sol_id_to_keep.append(sol_id) + + return sol_id_to_keep + +def graph_and_value_unseen(finetuned_csv, + base_csv, + to_keep_sol_ids): + finetuned_df = pd.read_csv(finetuned_csv) + base_df = pd.read_csv(base_csv) + + old_finetuned_sr = np.array(finetuned_df["success_rate"]) + old_base_sr = np.array(base_df["success_rate"]) + + finetuned_sr = old_finetuned_sr[to_keep_sol_ids] + base_sr = old_base_sr[to_keep_sol_ids] + + # Use the smaller dataset length for comparison + less_n = min(len(base_sr), len(finetuned_sr)) + + finetuned_avg = np.mean(finetuned_sr[:less_n]) + base_avg = np.mean(base_sr[:less_n]) + + print(f"finetuned avg: {finetuned_avg}") + print(f"base avg: {base_avg}") + + # Create a wider figure for a stretched x-axis + plt.figure(figsize=(14, 5)) + + x = np.arange(less_n) # environment numbers 0..299 + width = 1.5 # make bars thicker (0.8 of the bin width) + + # Plot base first, then finetuned on top + plt.bar( + x, + base_sr[:less_n], + width, + color="steelblue", + alpha=0.6, + label="Base", + edgecolor=None + ) + plt.bar( + x, + finetuned_sr[:less_n], + width, + color="darkorange", + alpha=0.6, + label="Finetuned", + edgecolor=None + ) + + plt.xlabel("Environment Number") + plt.ylabel("Success Rate") + plt.title("Finetuned vs Base Success Rates per Environment (Overlaid Bars)") + plt.legend() + plt.grid(True, 
linestyle="--", alpha=0.6, axis="y") + + # Adjust spacing and save + plt.tight_layout() + plt.savefig("comparison_unseen.png", dpi=300) + plt.close() + +# graph_and_value_unseen("./qd/eval_logs_finetuned/summary.csv", +# "./qd/eval_logs_base/summary.csv", +# sol_id_to_keep) + +graph_and_value(finetuned_csv="./eval_finetuned_libero_logs/summary.csv", + base_csv="./eval_base_libero_logs/summary.csv") \ No newline at end of file diff --git a/qd_spatial.py b/qd_spatial.py new file mode 100644 index 0000000000..4d6de57751 --- /dev/null +++ b/qd_spatial.py @@ -0,0 +1,415 @@ +import collections +import csv +import datetime +import math +import pickle as pkl +import re +from functools import partial +from pathlib import Path + +import fire +import imageio +import matplotlib.pyplot as plt +import numpy as np +from dask.distributed import Client, LocalCluster +from libero.libero import get_libero_path +from libero.libero.envs import OffScreenRenderEnv +from openpi_client import websocket_client_policy as _websocket_client_policy +from ribs.archives import GridArchive +from ribs.emitters import EvolutionStrategyEmitter +from ribs.schedulers import Scheduler +from ribs.visualize import grid_archive_heatmap + +from tqdm import tqdm, trange + +task_5_bddl = ( + Path(get_libero_path("bddl_files")) + / "custom" + / "pick_up_the_black_bowl_next_to_the_ramekin_and_place_it_on_the_plate.bddl" +) +TASK_ENV = partial( + OffScreenRenderEnv, + bddl_file_name=task_5_bddl, + camera_heights=256, + camera_widths=256, +) + +max_steps = 220 +num_steps_wait = 10 +host = "0.0.0.0" +port = 8000 +replan_steps = 5 + +def _quat2axisangle(quat): + """ + Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 + """ + # clip quaternion + if quat[3] > 1.0: + quat[3] = 1.0 + elif quat[3] < -1.0: + quat[3] = -1.0 + + den = np.sqrt(1.0 - quat[3] * quat[3]) + if math.isclose(den, 0.0): + # This is (close to) a zero degree
rotation, immediately return + return np.zeros(3) + + return (quat[:3] * 2.0 * math.acos(quat[3])) / den + +def _extract_scheduler_itr(filename): + """Tries extracting the iteration number from a scheduler filename following + the ``*/scheduler_[0-9]{8}.pkl`` format, where ``[0-9]{8}`` is the iteration + number. If match fails, returns None. + + Args: + filename (str): A scheduler filename following the + ``*/scheduler_[0-9]{8}.pkl`` format. + + Returns: + itr (int or None): Iteration number if match succeeds, else None. + """ + pattern = r"scheduler_(\d{8})\.pkl" + match = re.search(pattern, filename) + if match: + return int(match.group(1)) + return None + +def evaluate(params, ntrials, seed, video_logdir=None): + """Evaluates param by creating LIBERO environments and computing + objective and measure values from the environments' features and VLA + rollout. + + Args: + params (np.ndarray): Array of shape (solution_dim,) containing a single + solution to be evaluated. + ntrials (int): Number of rollouts for each solution. + seed (int): Seed. + video_logdir (str): Folder for saving rollout videos. If None no video + is saved. + + Return: + objective (float): Entropy of VLA's success rate on the environment + created from ``params``. + spread (float): In the environment created from ``params``, how well do + objects cover the table. + similarity (float): In the environment created from ``params``, how + tightly are objects clustered. + trajectories (np.ndarray): Array of shape (ntrials,) containing all + rollout trajectories. 
Each rollout trajectory is a dictionary of + the following format: + { + "success": bool, + "prompt": str, + "image": List, + "wrist_image": List, + "state": List, + "action": List + } + """ + np.random.seed(seed) + openpi_client = _websocket_client_policy.WebsocketClientPolicy(host, port) + + env = TASK_ENV( + params=params, + repair_env=True, + repair_config={ + 'time_limit':1500, + 'seed':seed + } + ) + + env.seed(seed) + obs = env.reset() + if obs is None: + # TODO: How to handle solutions that fail to evaluate + return 1e-6, 0, 0, None + + if video_logdir is not None: + # ID each sol with datetime to prevent overwriting + sol_logdir = Path(video_logdir) / f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" + sol_logdir.mkdir(parents=True) + + # compute_spread_similarity must be called at the start before any + # action since actions might change objects' locations + spread, similarity = env.env.compute_spread_similarity() + + trajectories = [] + # Get success rates by running openpi on env + success_rate = 0 + for trial_id in trange(ntrials): + obs = env.reset() + action_plan = collections.deque() + + new_trajectory = { + "success": False, + "prompt": env.language_instruction, + "image": [], + "wrist_image": [], + "state": [], + "action": [] + } + print(f"Evaluating trial {trial_id}") + for t in range(max_steps + num_steps_wait): + try: + if t < num_steps_wait: + # Do nothing at the start to wait for env to settle + obs, reward, done, info = env.step([0.0] * 6 + [-1.0]) + continue + + img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1]) + wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1]) + + if not action_plan: + element = { + "observation/image": img, + "observation/wrist_image": wrist_img, + "observation/state": np.concatenate( + ( + obs["robot0_eef_pos"], + _quat2axisangle(obs["robot0_eef_quat"]), + obs["robot0_gripper_qpos"], + ) + ), + "prompt": env.language_instruction, + } + + action_chunk = 
openpi_client.infer(element)["actions"] + assert ( + len(action_chunk) >= replan_steps + ), f"We want to replan every {replan_steps} steps, but policy only predicts {len(action_chunk)} steps." + action_plan.extend(action_chunk[: replan_steps]) + + action = action_plan.popleft() + + # store in trajectory list + new_trajectory["image"].append(img) + new_trajectory["wrist_image"].append(wrist_img) + new_trajectory["state"].append(element["observation/state"]) + new_trajectory["action"].append(action) + + obs, reward, done, info = env.step(action.tolist()) + if done: + success_rate += 1 / ntrials + new_trajectory['success'] = True + break + + except Exception as e: + print(e) + # TODO: How to handle solutions that fail to evaluate + return 1e-6, 0, 0, None + + trajectories.append(new_trajectory) + print(f"\t trial{trial_id}: {'success' if new_trajectory['success'] else 'fail'}") + + if video_logdir is not None: + imageio.mimwrite( + sol_logdir / f"trial{trial_id}_{'success' if new_trajectory['success'] else 'fail'}.mp4", + [np.asarray(x) for x in new_trajectory["image"]], + fps=10, + ) + + # Maximizes entropy as objective, i.e. we want more uncertain + # success rates + success_rate = np.clip(success_rate, 1e-6, 1 - 1e-6) + entropy = -success_rate*math.log2(success_rate) - (1-success_rate)*math.log2(1-success_rate) + + openpi_client._ws.close() + + return entropy, spread, similarity, np.array(trajectories) + +def evaluate_parallel(client, params, ntrials, seed, video_logdir=None): + """Parallelized version of :func:`evaluate`. + + Args: + params (np.ndarray): Array of shape (batch_size, solution_dim) + containing solutions to be evaluated. + ntrials (int): Number of rollouts for each solution. + seed (int): Seed. + video_logdir (str): Folder for saving rollout videos. If None no video + is saved. + + Return: + objective (np.ndarray): Array of shape (batch_size,). Entropies of + VLA's success rates on the environments created from ``params``. 
+ measures (np.ndarray): Array of shape (batch_size, measure_dim). + Spread and similarity of environments created from ``params``. + trajectories (np.ndarray): Array of shape (batch_size, ntrials). + Rollout trajectories. + """ + batch_size = params.shape[0] + nworkers = len(client.scheduler_info()['workers']) + assert nworkers >= batch_size, ( + f"batch_size={batch_size} exceeds the number of workers " + f"{nworkers}" + ) + + futures = [ + client.submit( + evaluate, + params=sol, + ntrials=ntrials, + seed=seed+sol_id, + video_logdir=video_logdir, + pure=False, + ) + for sol_id, sol in enumerate(params) + ] + results = client.gather(futures) + + objs, meas, trajs = [], [], [] + + # Process the results. + for entropy, spread, similarity, trajectories in results: + objs.append(entropy) + meas.append([spread, similarity]) + trajs.append(trajectories) + + print(np.array(objs).shape) + print(np.array(meas).shape) + print(np.array(trajs).shape) + + return np.array(objs), np.array(meas), np.array(trajs, dtype=object) + +def save_heatmap(archive, heatmap_path): + """Saves a heatmap of the archive to the given path. + + Args: + archive (GridArchive): The archive to save. + heatmap_path: Image path for the heatmap.
+ """ + plt.figure(figsize=(8, 6)) + grid_archive_heatmap(archive, vmin=0, vmax=1, cmap="viridis") + plt.tight_layout() + plt.savefig(heatmap_path) + plt.close(plt.gcf()) + +def main( + iterations=1000, + num_trials_per_sol=5, + batch_size=8, + num_emitters=1, + archive_resolution=[100,100], + seed=42, + outdir="test_logs", + reload_from=None, + log_every=5 +): + logdir = Path(outdir) + logdir.mkdir(exist_ok=True) + summary_filename = logdir / "summary.csv" + + if reload_from is None: + # For now ``params`` should be an array listing object + # coordinates in the following order: + # [ + # akita_black_bowl_1_x, akita_black_bowl_1_y, + # akita_black_bowl_2_x, akita_black_bowl_2_y, + # cookies_1_x, cookies_1_y, + # glazed_rim_porcelain_ramekin_1_x, + # glazed_rim_porcelain_ramekin_1_y, + # plate_1_x, plate_1_y + # ] + main_archive = GridArchive( + solution_dim=10, + dims=archive_resolution, + ranges=[(0, 1)] * 2, + # learning_rate=0.1, + # threshold_min=0, + seed=seed, + extra_fields={ + "trajectories": ((num_trials_per_sol,), object) + } + ) + passive_archive = GridArchive( + solution_dim=10, + dims=archive_resolution, + ranges=[(0, 1)] * 2, + seed=seed, + extra_fields={ + "trajectories": ((num_trials_per_sol,), object) + } + ) + + emitters = [ + EvolutionStrategyEmitter( + archive=main_archive, + # Range centers copied from BDDL file + x0=[-0.18, 0.32, 0.13, -0.07, 0.07, 0.03, -0.20, 0.20, 0.06, 0.20], + sigma0=0.02, + # TODO: Define bounds if we want to stay close to the original BDDL + bounds=None, + batch_size=batch_size, + seed=seed + i, + ) + for i in range(num_emitters) + ] + + scheduler = Scheduler(main_archive, emitters, result_archive=passive_archive) + + with open(summary_filename, "w") as summary_file: + writer = csv.writer(summary_file) + writer.writerow(["Iteration", "QD-Score", "Coverage", "Maximum", "Average"]) + else: + reload_itr = _extract_scheduler_itr(reload_from) + assert reload_itr is not None, ( + f'Received invalid reload_from parameter 
{reload_from}; ' + 'expected */scheduler_[0-9]{8}.pkl' + ) + with open(file=reload_from, mode="rb") as f: + scheduler = pkl.load(f) + + cluster = LocalCluster( + processes=True, + n_workers=batch_size, + threads_per_worker=1, + ) + client = Client(cluster) + + start = 1 if reload_from is None else reload_itr + 1 + end = start + iterations + for i in trange(start, end): + solutions = scheduler.ask() + objectives, measures, trajectories = evaluate_parallel(client=client, params=solutions, ntrials=num_trials_per_sol, seed=seed, video_logdir=None) + scheduler.tell(objectives, measures, trajectories=trajectories) + + print( + f"\n------------------ Iteration{i} ------------------\n" + f"\t QD-Score: {scheduler.result_archive.stats.qd_score}\n" + f"\t Coverage: {scheduler.result_archive.stats.coverage}\n" + f"\t Maximum : {scheduler.result_archive.stats.obj_max}\n" + f"\t Average : {scheduler.result_archive.stats.obj_mean}\n" + ) + + final_itr = i == end - 1 + if i % log_every == 0 or final_itr: + directory = Path(logdir) + + for pkl_file in directory.glob("*.pkl"): + pkl_file.unlink() + print(f"Deleted: {pkl_file}") + + pkl.dump( + scheduler, + open(logdir / f"scheduler_{i:08d}.pkl", "wb"), + ) + + with open(summary_filename, "a") as summary_file: + writer = csv.writer(summary_file) + data = [ + i, + scheduler.result_archive.stats.qd_score, + scheduler.result_archive.stats.coverage, + scheduler.result_archive.stats.obj_max, + scheduler.result_archive.stats.obj_mean, + ] + writer.writerow(data) + + save_heatmap( + scheduler.result_archive, + logdir / f"heatmap_{i:08d}.png", + ) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/run_experiment.sh b/run_experiment.sh new file mode 100755 index 0000000000..6f6aedd4c5 --- /dev/null +++ b/run_experiment.sh @@ -0,0 +1,29 @@ +SESSION=openpi-libero + +# Check if session exists +if tmux has-session -t "$SESSION" 2>/dev/null; then + echo "Session $SESSION already exists. Attaching..."
+ tmux attach-session -t "$SESSION" + exit 0 +fi + +# Create new session +tmux new-session -d -s "$SESSION" + +# Pane 0: Terminal 1 +tmux send-keys -t $SESSION " +cd \"$(pwd)\" +source examples/libero/.venv/bin/activate +export PYTHONPATH=\$PYTHONPATH:\$PWD/third_party/libero +python qd_spatial.py +" C-m + +# Pane 1: Terminal 2 +tmux split-window -h -t "$SESSION" +tmux send-keys -t "$SESSION:0.1" " +cd \"$(pwd)\" +uv run scripts/serve_policy.py --env LIBERO +" C-m + +# Attach +tmux attach-session -t "$SESSION" diff --git a/scheduler_00001000.pkl b/scheduler_00001000.pkl new file mode 100644 index 0000000000..f853d0ca87 Binary files /dev/null and b/scheduler_00001000.pkl differ diff --git a/scripts/compute_norm_stats.py b/scripts/compute_norm_stats.py index 93a59625a8..6859e3e3cd 100644 --- a/scripts/compute_norm_stats.py +++ b/scripts/compute_norm_stats.py @@ -15,6 +15,22 @@ import openpi.training.data_loader as _data_loader import openpi.transforms as transforms +# Monkey-patch to fix 'List' feature type error in old datasets +try: + import datasets.features.features as features + + _OLD_GENERATE_FROM_DICT = features.generate_from_dict + + def _new_generate_from_dict(obj): + if isinstance(obj, dict) and obj.get("_type") == "List": + obj["_type"] = "Sequence" + return _OLD_GENERATE_FROM_DICT(obj) + + features.generate_from_dict = _new_generate_from_dict +except (ImportError, AttributeError): + # If datasets or the function doesn't exist, do nothing. 
+ pass +# End of monkey-patch class RemoveStrings(transforms.DataTransformFn): def __call__(self, x: dict) -> dict: diff --git a/src/openpi/policies/google_robot_policy.py b/src/openpi/policies/google_robot_policy.py new file mode 100644 index 0000000000..f79d45ec37 --- /dev/null +++ b/src/openpi/policies/google_robot_policy.py @@ -0,0 +1,54 @@ +import dataclasses + +import einops +import numpy as np + +from openpi import transforms +from openpi.models import model as _model + +def _parse_image(image) -> np.ndarray: + image = np.asarray(image) + if np.issubdtype(image.dtype, np.floating): + image = (255 * image).astype(np.uint8) + if image.shape[0] == 3: + image = einops.rearrange(image, "c h w -> h w c") + return image + +@dataclasses.dataclass(frozen=True) +class GoogleRobotInputs(transforms.DataTransformFn): + action_dim: int + model_type: _model.ModelType = _model.ModelType.PI0 + + def __call__(self, data: dict) -> dict: + # state = np.concatenate([data["joints_qpos"], data["gripper_qpos"]]) + state = transforms.pad_to_dim(data["state"], self.action_dim) + + base_image = _parse_image(data["image"]) + + inputs = { + "state": state, + "image": { + "base_0_rgb": base_image, + "left_wrist_0_rgb": np.zeros_like(base_image), + "right_wrist_0_rgb": np.zeros_like(base_image) + }, + "image_mask": { + "base_0_rgb": np.True_, + "left_wrist_0_rgb": np.False_, + "right_wrist_0_rgb": np.False_ + }, + } + + if "actions" in data: + actions = transforms.pad_to_dim(data["actions"], self.action_dim) + inputs["actions"] = actions + + if "prompt" in data: + inputs["prompt"] = data["prompt"] + + return inputs + +@dataclasses.dataclass(frozen=True) +class GoogleRobotOutputs(transforms.DataTransformFn): + def __call__(self, data: dict) -> dict: + return {"actions": np.asarray(data["actions"][:, :7])} \ No newline at end of file diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 19ce34ee3c..0cd01b3dee 100644 --- a/src/openpi/training/config.py +++ 
b/src/openpi/training/config.py @@ -20,6 +20,7 @@ import openpi.policies.aloha_policy as aloha_policy import openpi.policies.droid_policy as droid_policy import openpi.policies.libero_policy as libero_policy +import openpi.policies.google_robot_policy as google_robot_policy import openpi.shared.download as _download import openpi.shared.normalize as _normalize import openpi.training.droid_rlds_dataset as droid_rlds_dataset @@ -389,6 +390,44 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig ) +@dataclasses.dataclass(frozen=True) +class LeRobotGoogleRobotDataConfig(DataConfigFactory): + @override + def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: + repack_transform = _transforms.Group( + inputs=[ + _transforms.RepackTransform( + { + "state": "state", + "image": "image", + "actions": "actions", + "prompt": "prompt" + } + ) + ] + ) + + data_transforms = _transforms.Group( + inputs=[google_robot_policy.GoogleRobotInputs(action_dim=model_config.action_dim, model_type=model_config.model_type)], + outputs=[google_robot_policy.GoogleRobotOutputs()] + ) + + delta_action_map = _transforms.make_bool_mask(6, -1) + data_transforms = data_transforms.push( + inputs=[_transforms.DeltaActions(delta_action_map)], + outputs=[_transforms.AbsoluteActions(delta_action_map)] + ) + + model_transforms = ModelTransformFactory()(model_config) + + return dataclasses.replace( + self.create_base_config(assets_dirs, model_config), + repack_transforms=repack_transform, + data_transforms=data_transforms, + model_transforms=model_transforms + ) + + @dataclasses.dataclass(frozen=True) class TrainConfig: # Name of the config. Must be unique. Will be used to reference this config. 
@@ -622,10 +661,10 @@ def __post_init__(self) -> None: action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora" ), data=LeRobotLiberoDataConfig( - repo_id="physical-intelligence/libero", + repo_id="hchen/libero", base_config=DataConfig(prompt_from_task=True), ), - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_libero/params"), num_train_steps=30_000, # Again, make sure to match the model config above when extracting the freeze filter # that specifies which parameters should be frozen during LoRA finetuning. @@ -714,6 +753,26 @@ def __post_init__(self) -> None: num_train_steps=20_000, ), # + # Fine-tuning Google Robot configs + # + TrainConfig( + name="pi0_google_robot_low_mem_finetune", + model=pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"), + data=LeRobotGoogleRobotDataConfig( + repo_id="hchen/google_robot", + base_config=DataConfig( + prompt_from_task=True + ) + ), + weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), + num_train_steps=30_000, + save_interval=1000, + freeze_filter=pi0.Pi0Config( + paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora" + ).get_freeze_filter(), + ema_decay=None, + ), + # # Debugging configs. 
# TrainConfig( diff --git a/third_party/SimplerEnv b/third_party/SimplerEnv new file mode 160000 index 0000000000..4ab7178e83 --- /dev/null +++ b/third_party/SimplerEnv @@ -0,0 +1 @@ +Subproject commit 4ab7178e83e84ee06894034ec6dbf9e7aad1e882 diff --git a/third_party/libero b/third_party/libero index f78abd68ee..de521c0a33 160000 --- a/third_party/libero +++ b/third_party/libero @@ -1 +1 @@ -Subproject commit f78abd68ee283de9f9be3c8f7e2a9ad60246e95c +Subproject commit de521c0a33e9af445876cefb5979962e79019189 diff --git a/viz_spatial_attack.py b/viz_spatial_attack.py new file mode 100644 index 0000000000..da9bf0cca8 --- /dev/null +++ b/viz_spatial_attack.py @@ -0,0 +1,109 @@ +import pickle as pkl + +import matplotlib.pyplot as plt +import numpy as np +import plotly.graph_objects as go +from dash import Dash, Input, Output, dcc, html, no_update +from ribs.visualize import grid_archive_heatmap + +from qd_spatial import evaluate + + +def show_interactive_archive(archive): + fig = plt.figure(figsize=(8, 6)) + grid_archive_heatmap(archive, vmin=0, vmax=1, cmap="viridis") + plt.tight_layout() + + def onclick(event): + occupied, data = archive.retrieve_single([event.xdata, event.ydata]) + + if occupied: + print( + f'Recorded objective: {data["objective"]}; Recorded measures: {data["measures"]}' + ) + obj, meas = evaluate(params=[data["solution"]], ntrials=1, seed=42, video_logdir='interactive_vids') + else: + print("Archive cell not occupied") + + fig.canvas.mpl_connect("button_press_event", onclick) + plt.show() + + +def _plotly_grid_archive_heatmap(archive, shape={'width': 600, 'height': 600}): + x_dim, y_dim = archive.dims + colors = np.full((y_dim, x_dim), np.nan) + index_batch = archive.data("index") + objective_batch = archive.data("objective") + grid_index_batch = archive.int_to_grid_index(index_batch) + colors[grid_index_batch[:, 1], grid_index_batch[:, 0]] = objective_batch + + x_bounds = archive.boundaries[0] + y_bounds = archive.boundaries[1] + + fig = 
go.Figure(data=go.Heatmap( + z=colors, + x=x_bounds, + y=y_bounds, + colorbar= { + "title": 'Ent. success' + }, + colorscale='Viridis' + )) + fig.update_layout(**shape) + fig.update_xaxes(title='Spread') + fig.update_yaxes(title='Similarity') + + return fig + + +def host_interactive_archive(archive, port=8050): + '''Similar to :func:`show_interactive_archive`, except it hosts the + interactive plot at localhost:<port> to allow generating the plot on a + remote machine and then viewing it on your local machine (e.g. if you only + have access to ssh). After configuring port forwarding between your local + machine and ``port`` on the remote machine, you will be able to view and + interact with the plot on your local machine's browser. + + Args: + archive (GridArchive): Archive to be displayed. + port (int): The port on which to display the plot. + ''' + app = Dash(__name__) + + app.layout = html.Div([ + dcc.Graph(id='archive-heatmap', figure=_plotly_grid_archive_heatmap(archive)), + html.Div(id='dummy-output', style={'display': 'none'}) # hidden dummy output + ], style={'display': 'flex', 'justifyContent': 'center'}) + + @app.callback( + Output('dummy-output', 'children'), + Input("archive-heatmap", "clickData"), + ) + def onclick(clickData): + if clickData is None: + return no_update + + occupied, data = archive.retrieve_single([clickData["points"][0]["x"], clickData["points"][0]["y"]]) + + if occupied: + print( + f'Recorded objective: {data["objective"]}; Recorded measures: {data["measures"]}' + ) + evaluate(params=[data["solution"]], ntrials=5, seed=42, video_logdir='interactive_vids') + else: + print("Archive cell not occupied") + + return None + + app.run(host="0.0.0.0", port=port, debug=True) + + +if __name__ == "__main__": + with open( + # Enter the scheduler checkpoint you want to visualize here + file="scheduler_00001000.pkl", + mode="rb", + ) as f: + archive = pkl.load(f).result_archive + # show_interactive_archive(archive) + host_interactive_archive(archive)
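Review note on `_plotly_grid_archive_heatmap`: the objectives are scattered into a `(y_dim, x_dim)` array with the grid indices swapped, so rows follow the second measure and columns follow the first. The NumPy-only sketch below reproduces that indexing step with made-up data. `fill_heatmap` and the sample values are hypothetical, and no archive object is involved.

```python
import numpy as np


def fill_heatmap(dims, grid_indices, objectives):
    """Scatter objectives into a (y_dim, x_dim) array; empty cells stay NaN."""
    x_dim, y_dim = dims
    colors = np.full((y_dim, x_dim), np.nan)
    grid_indices = np.asarray(grid_indices)
    # Column 0 of grid_indices is the x (first measure) index, column 1 the y index,
    # so the array is indexed as [row = y, column = x].
    colors[grid_indices[:, 1], grid_indices[:, 0]] = objectives
    return colors


# Made-up example: a 4 x 3 grid with three occupied cells.
colors = fill_heatmap(
    dims=(4, 3),
    grid_indices=[[0, 0], [3, 2], [1, 2]],
    objectives=[0.2, 0.9, 0.5],
)
```

Leaving unoccupied cells as NaN is what lets the heatmap render them as blank rather than as a zero objective.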