From 9f6bd3140219a1c25101553303b0e248f873d98a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:19:36 +0000 Subject: [PATCH 1/2] Initial plan From 4f04ecf44dfaa7e8ae4917cd6b08e0f4f6791544 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:32:42 +0000 Subject: [PATCH 2/2] Add pilla_rl curriculum learning framework (Phase 1) Agent-Logs-Url: https://github.com/code-name-57/pilla_rl/sessions/c6726929-a02e-4608-b1da-2496b3e18e07 Co-authored-by: Macbull <11361002+Macbull@users.noreply.github.com> --- pilla_rl/__init__.py | 3 + pilla_rl/__main__.py | 7 + pilla_rl/config_loader.py | 71 ++++ .../configs/curricula/recovery_to_walk.yaml | 31 ++ pilla_rl/configs/tasks/recovery.yaml | 118 +++++++ pilla_rl/configs/tasks/standup.yaml | 115 ++++++ .../configs/tasks/upside_down_standup.yaml | 121 +++++++ pilla_rl/configs/tasks/walk.yaml | 111 ++++++ pilla_rl/envs/__init__.py | 8 + pilla_rl/envs/base_env.py | 332 ++++++++++++++++++ pilla_rl/envs/recovery_env.py | 137 ++++++++ pilla_rl/envs/standup_env.py | 89 +++++ pilla_rl/envs/walk_env.py | 13 + pilla_rl/evaluate.py | 91 +++++ pilla_rl/rewards/__init__.py | 5 + pilla_rl/rewards/reward_functions.py | 255 ++++++++++++++ pilla_rl/train.py | 132 +++++++ 17 files changed, 1639 insertions(+) create mode 100644 pilla_rl/__init__.py create mode 100644 pilla_rl/__main__.py create mode 100644 pilla_rl/config_loader.py create mode 100644 pilla_rl/configs/curricula/recovery_to_walk.yaml create mode 100644 pilla_rl/configs/tasks/recovery.yaml create mode 100644 pilla_rl/configs/tasks/standup.yaml create mode 100644 pilla_rl/configs/tasks/upside_down_standup.yaml create mode 100644 pilla_rl/configs/tasks/walk.yaml create mode 100644 pilla_rl/envs/__init__.py create mode 100644 pilla_rl/envs/base_env.py create mode 100644 pilla_rl/envs/recovery_env.py create mode 100644 
pilla_rl/envs/standup_env.py create mode 100644 pilla_rl/envs/walk_env.py create mode 100644 pilla_rl/evaluate.py create mode 100644 pilla_rl/rewards/__init__.py create mode 100644 pilla_rl/rewards/reward_functions.py create mode 100644 pilla_rl/train.py diff --git a/pilla_rl/__init__.py b/pilla_rl/__init__.py new file mode 100644 index 0000000..3c03813 --- /dev/null +++ b/pilla_rl/__init__.py @@ -0,0 +1,3 @@ +"""pilla_rl — curriculum learning framework for Go2 quadruped robot tasks.""" + +__version__ = "0.1.0" diff --git a/pilla_rl/__main__.py b/pilla_rl/__main__.py new file mode 100644 index 0000000..d4ae29e --- /dev/null +++ b/pilla_rl/__main__.py @@ -0,0 +1,7 @@ +"""Allow running the package as ``python -m pilla_rl.train`` or +``python -m pilla_rl`` (which defaults to the train entry point). +""" + +from pilla_rl.train import main + +main() diff --git a/pilla_rl/config_loader.py b/pilla_rl/config_loader.py new file mode 100644 index 0000000..46b7144 --- /dev/null +++ b/pilla_rl/config_loader.py @@ -0,0 +1,71 @@ +"""YAML configuration loader and environment factory. + +Usage:: + + from pilla_rl.config_loader import load_task_config, instantiate_env + + cfg = load_task_config("pilla_rl/configs/tasks/walk.yaml") + env = instantiate_env(cfg, num_envs=4096, show_viewer=False) + +Dependencies: pyyaml (``pip install pyyaml``) +""" + +import importlib +from pathlib import Path + +import yaml + + +def load_task_config(config_path: str) -> dict: + """Load a YAML task config file and return it as a plain dict. + + Parameters + ---------- + config_path: + Path to the ``.yaml`` file, either absolute or relative to the + current working directory. + + Returns + ------- + dict + The parsed configuration. 
+ """ + config_path = Path(config_path) + with config_path.open("r") as fh: + cfg = yaml.safe_load(fh) + return cfg + + +def instantiate_env(config: dict, num_envs: int = 4096, show_viewer: bool = False): + """Dynamically import the env class specified in *config* and instantiate it. + + The config dict must contain an ``env_class`` key with a fully-qualified + class name, e.g. ``"pilla_rl.envs.walk_env.WalkEnv"``. + + Parameters + ---------- + config: + Parsed task config dict (as returned by :func:`load_task_config`). + num_envs: + Number of parallel simulation environments. + show_viewer: + Whether to open the Genesis viewer window. + + Returns + ------- + BaseQuadrupedEnv + An instantiated environment object. + """ + env_class_path: str = config["env_class"] + module_path, class_name = env_class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + env_cls = getattr(module, class_name) + + return env_cls( + num_envs=num_envs, + env_cfg=config["env_cfg"], + obs_cfg=config["obs_cfg"], + reward_cfg=config["reward_cfg"], + command_cfg=config["command_cfg"], + show_viewer=show_viewer, + ) diff --git a/pilla_rl/configs/curricula/recovery_to_walk.yaml b/pilla_rl/configs/curricula/recovery_to_walk.yaml new file mode 100644 index 0000000..4b32eee --- /dev/null +++ b/pilla_rl/configs/curricula/recovery_to_walk.yaml @@ -0,0 +1,31 @@ +# Curriculum: recovery → walk +# +# Example multi-stage curriculum that chains two task configs. +# Each stage can override reward scales and command ranges, and can +# optionally load a checkpoint from the previous stage. 
+# +# Usage (conceptual): +# python -m pilla_rl.train \ +# --config pilla_rl/configs/curricula/recovery_to_walk.yaml \ +# --num_envs 4096 + +stages: + + - name: "recovery" + config: "pilla_rl/configs/tasks/recovery.yaml" + max_iterations: 20000 + reward_overrides: {} + command_overrides: {} + + - name: "walk" + config: "pilla_rl/configs/tasks/walk.yaml" + max_iterations: 10000 + load_from: "previous" # load checkpoint from the preceding stage + reward_overrides: + # gradually re-introduce locomotion rewards + tracking_lin_vel: 1.0 + tracking_ang_vel: 0.5 + command_overrides: + lin_vel_x_range: [-1.0, 2.0] + lin_vel_y_range: [-0.5, 0.5] + ang_vel_range: [-0.5, 0.5] diff --git a/pilla_rl/configs/tasks/recovery.yaml b/pilla_rl/configs/tasks/recovery.yaml new file mode 100644 index 0000000..33cb2f6 --- /dev/null +++ b/pilla_rl/configs/tasks/recovery.yaml @@ -0,0 +1,118 @@ +# Upside-down recovery task configuration +# Matches go2/upside_down_recovery/go2_train.py get_cfgs() / get_train_cfg() +# Requires: pyyaml + +env_class: "pilla_rl.envs.recovery_env.RecoveryEnv" + +env_cfg: + num_actions: 12 + default_joint_angles: + FL_hip_joint: 0.0 + FR_hip_joint: 0.0 + RL_hip_joint: 0.0 + RR_hip_joint: 0.0 + FL_thigh_joint: 0.8 + FR_thigh_joint: 0.8 + RL_thigh_joint: 1.0 + RR_thigh_joint: 1.0 + FL_calf_joint: -1.5 + FR_calf_joint: -1.5 + RL_calf_joint: -1.5 + RR_calf_joint: -1.5 + joint_names: + - FR_hip_joint + - FR_thigh_joint + - FR_calf_joint + - FL_hip_joint + - FL_thigh_joint + - FL_calf_joint + - RR_hip_joint + - RR_thigh_joint + - RR_calf_joint + - RL_hip_joint + - RL_thigh_joint + - RL_calf_joint + kp: 20.0 + kd: 0.5 + termination_if_roll_greater_than: 180 + termination_if_pitch_greater_than: 90 + base_init_pos: [0.0, 0.0, 0.42] + base_init_quat: [1.0, 0.0, 0.0, 0.0] + episode_length_s: 20.0 + resampling_time_s: 10.0 + action_scale: 0.3 + simulate_action_latency: true + clip_actions: 100.0 + +obs_cfg: + num_obs: 48 + obs_scales: + lin_vel: 2.0 + ang_vel: 0.25 + dof_pos: 
1.0 + dof_vel: 0.05 + +reward_cfg: + tracking_sigma: 0.25 + base_height_target: 0.42 + feet_height_target: 0.075 + reward_scales: + tracking_lin_vel: 0.0 + tracking_ang_vel: 0.0 + lin_vel_z: -1.0 + base_height: -2.0 + action_rate: -0.02 + similar_to_default: -0.1 + upright_orientation: 20.0 + recovery_progress: 30.0 + minimize_base_roll: 15.0 + stability: 5.0 + legs_not_in_air: 8.0 + energy_efficiency: 3.0 + forward_progress: 2.0 + +command_cfg: + num_commands: 3 + lin_vel_x_range: [0.0, 0.0] + lin_vel_y_range: [0.0, 0.0] + ang_vel_range: [0.0, 0.0] + +train: + exp_name: "go2-upside-down-recovery" + algorithm: + class_name: "PPO" + clip_param: 0.2 + desired_kl: 0.01 + entropy_coef: 0.01 + gamma: 0.998 + lam: 0.95 + learning_rate: 0.0003 + max_grad_norm: 1.0 + num_learning_epochs: 10 + num_mini_batches: 4 + schedule: "adaptive" + use_clipped_value_loss: true + value_loss_coef: 1.0 + init_member_classes: {} + policy: + activation: "elu" + actor_hidden_dims: [512, 256, 128] + critic_hidden_dims: [512, 256, 128] + init_noise_std: 1.0 + class_name: "ActorCritic" + runner: + checkpoint: -1 + experiment_name: "go2-upside-down-recovery" + load_run: -1 + log_interval: 1 + max_iterations: 20000 + record_interval: -1 + resume: false + resume_path: null + run_name: "" + logger: "tensorboard" + runner_class_name: "OnPolicyRunner" + num_steps_per_env: 24 + save_interval: 100 + empirical_normalization: null + seed: 1 diff --git a/pilla_rl/configs/tasks/standup.yaml b/pilla_rl/configs/tasks/standup.yaml new file mode 100644 index 0000000..5467d06 --- /dev/null +++ b/pilla_rl/configs/tasks/standup.yaml @@ -0,0 +1,115 @@ +# Standup task configuration +# Matches go2/standup_copilot/go2_train.py get_cfgs() / get_train_cfg() +# Requires: pyyaml + +env_class: "pilla_rl.envs.standup_env.StandupEnv" + +env_cfg: + num_actions: 12 + default_joint_angles: + FL_hip_joint: 0.0 + FR_hip_joint: 0.0 + RL_hip_joint: 0.0 + RR_hip_joint: 0.0 + FL_thigh_joint: 0.8 + FR_thigh_joint: 0.8 + 
RL_thigh_joint: 1.0 + RR_thigh_joint: 1.0 + FL_calf_joint: -1.5 + FR_calf_joint: -1.5 + RL_calf_joint: -1.5 + RR_calf_joint: -1.5 + joint_names: + - FR_hip_joint + - FR_thigh_joint + - FR_calf_joint + - FL_hip_joint + - FL_thigh_joint + - FL_calf_joint + - RR_hip_joint + - RR_thigh_joint + - RR_calf_joint + - RL_hip_joint + - RL_thigh_joint + - RL_calf_joint + kp: 20.0 + kd: 0.5 + termination_if_roll_greater_than: 45 + termination_if_pitch_greater_than: 45 + base_init_pos: [0.0, 0.0, 0.42] + base_init_quat: [1.0, 0.0, 0.0, 0.0] + episode_length_s: 15.0 + resampling_time_s: 8.0 + action_scale: 0.25 + simulate_action_latency: true + clip_actions: 100.0 + +obs_cfg: + num_obs: 48 + obs_scales: + lin_vel: 2.0 + ang_vel: 0.25 + dof_pos: 1.0 + dof_vel: 0.05 + +reward_cfg: + tracking_sigma: 0.25 + base_height_target: 0.42 + feet_height_target: 0.075 + reward_scales: + tracking_lin_vel: 0.0 + tracking_ang_vel: 0.0 + lin_vel_z: -2.0 + base_height: -5.0 + action_rate: -0.01 + similar_to_default: -0.5 + upright_orientation: 15.0 + stability: 10.0 + stand_up_progress: 25.0 + joint_regularization: 2.0 + +command_cfg: + num_commands: 3 + lin_vel_x_range: [0.0, 0.0] + lin_vel_y_range: [0.0, 0.0] + ang_vel_range: [0.0, 0.0] + +train: + exp_name: "go2-standup" + algorithm: + class_name: "PPO" + clip_param: 0.2 + desired_kl: 0.01 + entropy_coef: 0.005 + gamma: 0.99 + lam: 0.95 + learning_rate: 0.0005 + max_grad_norm: 1.0 + num_learning_epochs: 8 + num_mini_batches: 4 + schedule: "adaptive" + use_clipped_value_loss: true + value_loss_coef: 1.0 + init_member_classes: {} + policy: + activation: "elu" + actor_hidden_dims: [512, 256, 128] + critic_hidden_dims: [512, 256, 128] + init_noise_std: 1.0 + class_name: "ActorCritic" + runner: + checkpoint: -1 + experiment_name: "go2-standup" + load_run: -1 + log_interval: 1 + max_iterations: 15000 + record_interval: -1 + resume: false + resume_path: null + run_name: "" + logger: "tensorboard" + runner_class_name: "OnPolicyRunner" + 
num_steps_per_env: 24 + save_interval: 100 + empirical_normalization: null + seed: 1 diff --git a/pilla_rl/configs/tasks/upside_down_standup.yaml b/pilla_rl/configs/tasks/upside_down_standup.yaml new file mode 100644 index 0000000..4e17e22 --- /dev/null +++ b/pilla_rl/configs/tasks/upside_down_standup.yaml @@ -0,0 +1,121 @@ +# Upside-down standup task configuration +# Matches go2/upside_down_standup/go2_train.py get_cfgs() / get_train_cfg() +# Requires: pyyaml + +env_class: "pilla_rl.envs.recovery_env.RecoveryEnv" + +env_cfg: + num_actions: 12 + default_joint_angles: + FL_hip_joint: 0.0 + FR_hip_joint: 0.0 + RL_hip_joint: 0.0 + RR_hip_joint: 0.0 + FL_thigh_joint: 0.8 + FR_thigh_joint: 0.8 + RL_thigh_joint: 1.0 + RR_thigh_joint: 1.0 + FL_calf_joint: -1.5 + FR_calf_joint: -1.5 + RL_calf_joint: -1.5 + RR_calf_joint: -1.5 + joint_names: + - FR_hip_joint + - FR_thigh_joint + - FR_calf_joint + - FL_hip_joint + - FL_thigh_joint + - FL_calf_joint + - RR_hip_joint + - RR_thigh_joint + - RR_calf_joint + - RL_hip_joint + - RL_thigh_joint + - RL_calf_joint + kp: 20.0 + kd: 0.5 + termination_if_roll_greater_than: 180 + termination_if_pitch_greater_than: 90 + base_init_pos: [0.0, 0.0, 0.42] + base_init_quat: [1.0, 0.0, 0.0, 0.0] + episode_length_s: 20.0 + resampling_time_s: 10.0 + action_scale: 0.3 + simulate_action_latency: true + clip_actions: 100.0 + +obs_cfg: + num_obs: 48 + obs_scales: + lin_vel: 2.0 + ang_vel: 0.25 + dof_pos: 1.0 + dof_vel: 0.05 + +reward_cfg: + tracking_sigma: 0.25 + base_height_target: 0.42 + feet_height_target: 0.075 + reward_scales: + tracking_lin_vel: 0.0 + tracking_ang_vel: 0.0 + lin_vel_z: -1.0 + base_height: -1.0 + action_rate: -0.02 + similar_to_default: -0.1 + upright_orientation: 15.0 + recovery_progress: 25.0 + minimize_base_roll: 12.0 + standup_height: 20.0 + complete_standup: 50.0 + height_when_upright: 30.0 + stability: 5.0 + legs_not_in_air: 8.0 + energy_efficiency: 2.0 + forward_progress: 1.0 + +command_cfg: + num_commands: 3 + 
lin_vel_x_range: [0.0, 0.0] + lin_vel_y_range: [0.0, 0.0] + ang_vel_range: [0.0, 0.0] + +train: + exp_name: "go2-upside-down-standup" + algorithm: + class_name: "PPO" + clip_param: 0.2 + desired_kl: 0.01 + entropy_coef: 0.01 + gamma: 0.998 + lam: 0.95 + learning_rate: 0.0003 + max_grad_norm: 1.0 + num_learning_epochs: 10 + num_mini_batches: 4 + schedule: "adaptive" + use_clipped_value_loss: true + value_loss_coef: 1.0 + init_member_classes: {} + policy: + activation: "elu" + actor_hidden_dims: [512, 256, 128] + critic_hidden_dims: [512, 256, 128] + init_noise_std: 1.0 + class_name: "ActorCritic" + runner: + checkpoint: -1 + experiment_name: "go2-upside-down-standup" + load_run: -1 + log_interval: 1 + max_iterations: 20000 + record_interval: -1 + resume: false + resume_path: null + run_name: "" + logger: "tensorboard" + runner_class_name: "OnPolicyRunner" + num_steps_per_env: 24 + save_interval: 100 + empirical_normalization: null + seed: 1 diff --git a/pilla_rl/configs/tasks/walk.yaml b/pilla_rl/configs/tasks/walk.yaml new file mode 100644 index 0000000..e997c0f --- /dev/null +++ b/pilla_rl/configs/tasks/walk.yaml @@ -0,0 +1,111 @@ +# Walk task configuration +# Matches go2/walk/go2_train.py get_cfgs() / get_train_cfg() +# Requires: pyyaml + +env_class: "pilla_rl.envs.walk_env.WalkEnv" + +env_cfg: + num_actions: 12 + default_joint_angles: + FL_hip_joint: 0.0 + FR_hip_joint: 0.0 + RL_hip_joint: 0.0 + RR_hip_joint: 0.0 + FL_thigh_joint: 0.8 + FR_thigh_joint: 0.8 + RL_thigh_joint: 1.0 + RR_thigh_joint: 1.0 + FL_calf_joint: -1.5 + FR_calf_joint: -1.5 + RL_calf_joint: -1.5 + RR_calf_joint: -1.5 + joint_names: + - FR_hip_joint + - FR_thigh_joint + - FR_calf_joint + - FL_hip_joint + - FL_thigh_joint + - FL_calf_joint + - RR_hip_joint + - RR_thigh_joint + - RR_calf_joint + - RL_hip_joint + - RL_thigh_joint + - RL_calf_joint + kp: 20.0 + kd: 0.5 + termination_if_roll_greater_than: 10 + termination_if_pitch_greater_than: 10 + base_init_pos: [0.0, 0.0, 0.42] + base_init_quat: 
[1.0, 0.0, 0.0, 0.0] + episode_length_s: 20.0 + resampling_time_s: 4.0 + action_scale: 0.25 + simulate_action_latency: true + clip_actions: 100.0 + +obs_cfg: + num_obs: 45 + obs_scales: + lin_vel: 2.0 + ang_vel: 0.25 + dof_pos: 1.0 + dof_vel: 0.05 + +reward_cfg: + tracking_sigma: 0.25 + base_height_target: 0.3 + feet_height_target: 0.075 + reward_scales: + tracking_lin_vel: 1.0 + tracking_ang_vel: 0.5 + lin_vel_z: -1.0 + base_height: -50.0 + action_rate: -0.005 + similar_to_default: -0.1 + +command_cfg: + num_commands: 3 + lin_vel_x_range: [-1.0, 2.0] + lin_vel_y_range: [-0.5, 0.5] + ang_vel_range: [-0.5, 0.5] + +train: + exp_name: "go2-walking" + algorithm: + class_name: "PPO" + clip_param: 0.2 + desired_kl: 0.01 + entropy_coef: 0.01 + gamma: 0.99 + lam: 0.95 + learning_rate: 0.001 + max_grad_norm: 1.0 + num_learning_epochs: 5 + num_mini_batches: 4 + schedule: "adaptive" + use_clipped_value_loss: true + value_loss_coef: 1.0 + init_member_classes: {} + policy: + activation: "elu" + actor_hidden_dims: [512, 256, 128] + critic_hidden_dims: [512, 256, 128] + init_noise_std: 1.0 + class_name: "ActorCritic" + runner: + checkpoint: -1 + experiment_name: "go2-walking" + load_run: -1 + log_interval: 1 + max_iterations: 10000 + record_interval: -1 + resume: false + resume_path: null + run_name: "" + logger: "tensorboard" + runner_class_name: "OnPolicyRunner" + num_steps_per_env: 24 + save_interval: 100 + empirical_normalization: null + seed: 1 diff --git a/pilla_rl/envs/__init__.py b/pilla_rl/envs/__init__.py new file mode 100644 index 0000000..5f032ca --- /dev/null +++ b/pilla_rl/envs/__init__.py @@ -0,0 +1,8 @@ +"""pilla_rl.envs — task-specific environment subclasses.""" + +from pilla_rl.envs.base_env import BaseQuadrupedEnv +from pilla_rl.envs.walk_env import WalkEnv +from pilla_rl.envs.standup_env import StandupEnv +from pilla_rl.envs.recovery_env import RecoveryEnv + +__all__ = ["BaseQuadrupedEnv", "WalkEnv", "StandupEnv", "RecoveryEnv"] diff --git 
a/pilla_rl/envs/base_env.py b/pilla_rl/envs/base_env.py new file mode 100644 index 0000000..1837b07 --- /dev/null +++ b/pilla_rl/envs/base_env.py @@ -0,0 +1,332 @@ +"""Base quadruped environment shared by all task-specific subclasses. + +All the logic that was previously duplicated across every per-task ``go2_env.py`` +lives here. Task-specific differences are handled by overriding the three +protected methods: + +* ``_check_termination()`` — when to reset an episode +* ``_reset_robot_pose(envs_idx)`` — initial body / joint state on reset +* ``_compute_observations()`` — what the agent sees + +Reward functions are **not** methods on this class. Instead they are looked +up from ``pilla_rl.rewards.REWARD_REGISTRY`` and their arguments are +auto-resolved from the environment's attributes via ``inspect.signature()``. +""" + +import inspect +import math + +import torch +import genesis as gs +from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat + +from pilla_rl.rewards.reward_functions import REWARD_REGISTRY + + +def _gs_rand_float(lower, upper, shape, device): + return (upper - lower) * torch.rand(size=shape, device=device) + lower + + +class BaseQuadrupedEnv: + """Shared quadruped environment for all Go2 tasks.""" + + def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False): + self.num_envs = num_envs + self.num_obs = obs_cfg["num_obs"] + self.num_privileged_obs = None + self.num_actions = env_cfg["num_actions"] + self.num_commands = command_cfg["num_commands"] + self.device = gs.device + + self.simulate_action_latency = True + self.dt = 0.02 + self.max_episode_length = math.ceil(env_cfg["episode_length_s"] / self.dt) + + self.env_cfg = env_cfg + self.obs_cfg = obs_cfg + self.reward_cfg = reward_cfg + self.command_cfg = command_cfg + + self.obs_scales = obs_cfg["obs_scales"] + self.reward_scales = reward_cfg["reward_scales"] + + # 
------------------------------------------------------------------ + # Scene + # ------------------------------------------------------------------ + self.scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=self.dt, substeps=2), + viewer_options=gs.options.ViewerOptions( + max_FPS=int(0.5 / self.dt), + camera_pos=(2.0, 0.0, 2.5), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=40, + ), + vis_options=gs.options.VisOptions(rendered_envs_idx=list(range(1))), + rigid_options=gs.options.RigidOptions( + dt=self.dt, + constraint_solver=gs.constraint_solver.Newton, + enable_collision=True, + enable_joint_limit=True, + ), + show_viewer=show_viewer, + ) + + self.scene.add_entity(gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True)) + + self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=gs.device) + self.base_init_quat = torch.tensor(self.env_cfg["base_init_quat"], device=gs.device) + self.inv_base_init_quat = inv_quat(self.base_init_quat) + self.robot = self.scene.add_entity( + gs.morphs.URDF( + file="urdf/go2/urdf/go2.urdf", + pos=self.base_init_pos.cpu().numpy(), + quat=self.base_init_quat.cpu().numpy(), + ), + ) + + self.scene.build(n_envs=num_envs) + + # ------------------------------------------------------------------ + # PD control + # ------------------------------------------------------------------ + self.motors_dof_idx = [self.robot.get_joint(name).dof_start for name in self.env_cfg["joint_names"]] + self.robot.set_dofs_kp([self.env_cfg["kp"]] * self.num_actions, self.motors_dof_idx) + self.robot.set_dofs_kv([self.env_cfg["kd"]] * self.num_actions, self.motors_dof_idx) + + # ------------------------------------------------------------------ + # Buffers + # ------------------------------------------------------------------ + self.base_lin_vel = torch.zeros((self.num_envs, 3), device=gs.device, dtype=gs.tc_float) + self.base_ang_vel = torch.zeros((self.num_envs, 3), device=gs.device, dtype=gs.tc_float) + self.projected_gravity = 
torch.zeros((self.num_envs, 3), device=gs.device, dtype=gs.tc_float) + self.global_gravity = torch.tensor([0.0, 0.0, -1.0], device=gs.device, dtype=gs.tc_float).repeat( + self.num_envs, 1 + ) + self.obs_buf = torch.zeros((self.num_envs, self.num_obs), device=gs.device, dtype=gs.tc_float) + self.rew_buf = torch.zeros((self.num_envs,), device=gs.device, dtype=gs.tc_float) + self.reset_buf = torch.ones((self.num_envs,), device=gs.device, dtype=gs.tc_int) + self.episode_length_buf = torch.zeros((self.num_envs,), device=gs.device, dtype=gs.tc_int) + self.commands = torch.zeros((self.num_envs, self.num_commands), device=gs.device, dtype=gs.tc_float) + self.commands_scale = torch.tensor( + [self.obs_scales["lin_vel"], self.obs_scales["lin_vel"], self.obs_scales["ang_vel"]], + device=gs.device, + dtype=gs.tc_float, + ) + self.actions = torch.zeros((self.num_envs, self.num_actions), device=gs.device, dtype=gs.tc_float) + self.last_actions = torch.zeros_like(self.actions) + self.dof_pos = torch.zeros_like(self.actions) + self.dof_vel = torch.zeros_like(self.actions) + self.last_dof_vel = torch.zeros_like(self.actions) + self.base_pos = torch.zeros((self.num_envs, 3), device=gs.device, dtype=gs.tc_float) + self.base_quat = torch.zeros((self.num_envs, 4), device=gs.device, dtype=gs.tc_float) + self.base_euler = torch.zeros((self.num_envs, 3), device=gs.device, dtype=gs.tc_float) + self.default_dof_pos = torch.tensor( + [self.env_cfg["default_joint_angles"][name] for name in self.env_cfg["joint_names"]], + device=gs.device, + dtype=gs.tc_float, + ) + self.extras = dict() + self.extras["observations"] = dict() + + # ------------------------------------------------------------------ + # Rewards (uses REWARD_REGISTRY instead of per-class methods) + # ------------------------------------------------------------------ + self._setup_rewards() + + # ------------------------------------------------------------------ + # Reward wiring + # 
------------------------------------------------------------------ + + def _setup_rewards(self): + """Register reward functions from REWARD_REGISTRY and scale by dt.""" + self.reward_functions = {} + self.episode_sums = {} + for name in list(self.reward_scales.keys()): + self.reward_scales[name] *= self.dt + if name not in REWARD_REGISTRY: + raise ValueError( + f"Reward '{name}' is not in REWARD_REGISTRY. " + f"Available rewards: {sorted(REWARD_REGISTRY.keys())}" + ) + self.reward_functions[name] = REWARD_REGISTRY[name] + self.episode_sums[name] = torch.zeros((self.num_envs,), device=gs.device, dtype=gs.tc_float) + + def _resolve_reward_args(self, fn): + """Map function parameter names to current environment state tensors.""" + param_map = { + "base_lin_vel": lambda: self.base_lin_vel, + "base_ang_vel": lambda: self.base_ang_vel, + "base_pos": lambda: self.base_pos, + "base_euler": lambda: self.base_euler, + "commands": lambda: self.commands, + "dof_pos": lambda: self.dof_pos, + "dof_vel": lambda: self.dof_vel, + "actions": lambda: self.actions, + "last_actions": lambda: self.last_actions, + "default_dof_pos": lambda: self.default_dof_pos, + "tracking_sigma": lambda: self.reward_cfg["tracking_sigma"], + "target_height": lambda: self.reward_cfg["base_height_target"], + "base_height_target": lambda: self.reward_cfg["base_height_target"], + } + sig = inspect.signature(fn) + kwargs = {} + for param_name in sig.parameters: + if param_name not in param_map: + raise ValueError( + f"Reward function '{fn.__name__}' has unknown parameter '{param_name}'. 
" + f"Supported parameters: {sorted(param_map.keys())}" + ) + kwargs[param_name] = param_map[param_name]() + return kwargs + + def _compute_rewards(self): + self.rew_buf[:] = 0.0 + for name, fn in self.reward_functions.items(): + kwargs = self._resolve_reward_args(fn) + rew = fn(**kwargs) * self.reward_scales[name] + self.rew_buf += rew + self.episode_sums[name] += rew + + # ------------------------------------------------------------------ + # Overridable per-task methods + # ------------------------------------------------------------------ + + def _check_termination(self): + """Default termination: episode length or pitch/roll limit.""" + self.reset_buf = self.episode_length_buf > self.max_episode_length + self.reset_buf |= torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"] + self.reset_buf |= torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"] + + def _reset_robot_pose(self, envs_idx): + """Default reset: upright standing pose.""" + self.dof_pos[envs_idx] = self.default_dof_pos + self.dof_vel[envs_idx] = 0.0 + self.robot.set_dofs_position( + position=self.dof_pos[envs_idx], + dofs_idx_local=self.motors_dof_idx, + zero_velocity=True, + envs_idx=envs_idx, + ) + self.base_pos[envs_idx] = self.base_init_pos + self.base_quat[envs_idx] = self.base_init_quat.reshape(1, -1) + self.robot.set_pos(self.base_pos[envs_idx], zero_velocity=False, envs_idx=envs_idx) + self.robot.set_quat(self.base_quat[envs_idx], zero_velocity=False, envs_idx=envs_idx) + self.base_lin_vel[envs_idx] = 0 + self.base_ang_vel[envs_idx] = 0 + self.robot.zero_all_dofs_velocity(envs_idx) + + def _compute_observations(self): + """Default 45-dim observation vector.""" + self.obs_buf = torch.cat( + [ + self.base_ang_vel * self.obs_scales["ang_vel"], # 3 + self.projected_gravity, # 3 + self.commands * self.commands_scale, # 3 + (self.dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"], # 12 + self.dof_vel * 
self.obs_scales["dof_vel"], # 12 + self.actions, # 12 + ], + axis=-1, + ) + + # ------------------------------------------------------------------ + # Command resampling + # ------------------------------------------------------------------ + + def _resample_commands(self, envs_idx): + self.commands[envs_idx, 0] = _gs_rand_float( + *self.command_cfg["lin_vel_x_range"], (len(envs_idx),), gs.device + ) + self.commands[envs_idx, 1] = _gs_rand_float( + *self.command_cfg["lin_vel_y_range"], (len(envs_idx),), gs.device + ) + self.commands[envs_idx, 2] = _gs_rand_float( + *self.command_cfg["ang_vel_range"], (len(envs_idx),), gs.device + ) + + # ------------------------------------------------------------------ + # Core loop + # ------------------------------------------------------------------ + + def step(self, actions): + self.actions = torch.clip(actions, -self.env_cfg["clip_actions"], self.env_cfg["clip_actions"]) + exec_actions = self.last_actions if self.simulate_action_latency else self.actions + target_dof_pos = exec_actions * self.env_cfg["action_scale"] + self.default_dof_pos + self.robot.control_dofs_position(target_dof_pos, self.motors_dof_idx) + self.scene.step() + + # Update buffers + self.episode_length_buf += 1 + self.base_pos[:] = self.robot.get_pos() + self.base_quat[:] = self.robot.get_quat() + self.base_euler = quat_to_xyz( + transform_quat_by_quat(torch.ones_like(self.base_quat) * self.inv_base_init_quat, self.base_quat), + rpy=True, + degrees=True, + ) + inv_base_quat = inv_quat(self.base_quat) + self.base_lin_vel[:] = transform_by_quat(self.robot.get_vel(), inv_base_quat) + self.base_ang_vel[:] = transform_by_quat(self.robot.get_ang(), inv_base_quat) + self.projected_gravity = transform_by_quat(self.global_gravity, inv_base_quat) + self.dof_pos[:] = self.robot.get_dofs_position(self.motors_dof_idx) + self.dof_vel[:] = self.robot.get_dofs_velocity(self.motors_dof_idx) + + # Resample commands + envs_idx = ( + (self.episode_length_buf % 
int(self.env_cfg["resampling_time_s"] / self.dt) == 0) + .nonzero(as_tuple=False) + .flatten() + ) + self._resample_commands(envs_idx) + + # Termination check + self._check_termination() + time_out_idx = (self.episode_length_buf > self.max_episode_length).nonzero(as_tuple=False).flatten() + self.extras["time_outs"] = torch.zeros_like(self.reset_buf, device=gs.device, dtype=gs.tc_float) + self.extras["time_outs"][time_out_idx] = 1.0 + + self.reset_idx(self.reset_buf.nonzero(as_tuple=False).flatten()) + + # Rewards and observations + self._compute_rewards() + self._compute_observations() + + self.last_actions[:] = self.actions[:] + self.last_dof_vel[:] = self.dof_vel[:] + self.extras["observations"]["critic"] = self.obs_buf + + return self.obs_buf, self.rew_buf, self.reset_buf, self.extras + + def get_observations(self): + self.extras["observations"]["critic"] = self.obs_buf + return self.obs_buf, self.extras + + def get_privileged_observations(self): + return None + + def reset_idx(self, envs_idx): + if len(envs_idx) == 0: + return + + self._reset_robot_pose(envs_idx) + + # Reset buffers + self.last_actions[envs_idx] = 0.0 + self.last_dof_vel[envs_idx] = 0.0 + self.episode_length_buf[envs_idx] = 0 + self.reset_buf[envs_idx] = True + + # Fill extras + self.extras["episode"] = {} + for key in self.episode_sums.keys(): + self.extras["episode"]["rew_" + key] = ( + torch.mean(self.episode_sums[key][envs_idx]).item() / self.env_cfg["episode_length_s"] + ) + self.episode_sums[key][envs_idx] = 0.0 + + self._resample_commands(envs_idx) + + def reset(self): + self.reset_buf[:] = True + self.reset_idx(torch.arange(self.num_envs, device=gs.device)) + return self.obs_buf, None diff --git a/pilla_rl/envs/recovery_env.py b/pilla_rl/envs/recovery_env.py new file mode 100644 index 0000000..b363880 --- /dev/null +++ b/pilla_rl/envs/recovery_env.py @@ -0,0 +1,137 @@ +"""RecoveryEnv — thin subclass of BaseQuadrupedEnv for upside-down recovery / +standup tasks. 
"""RecoveryEnv — thin subclass of BaseQuadrupedEnv for upside-down recovery /
standup tasks.

Differences from the base class:
* Very lenient termination (only base_pos z < 0.05 or episode timeout)
* Upside-down biased reset (70% upside-down ~180° roll, 30% on sides)
* 48-dim observation vector (adds base_euler * pi/180)
"""

import torch
import genesis as gs
from genesis.utils.geom import xyz_to_quat

from pilla_rl.envs.base_env import BaseQuadrupedEnv, _gs_rand_float

# Degrees-to-radians factor, hoisted so the sampling helper reads cleanly.
_DEG2RAD = torch.pi / 180.0


class RecoveryEnv(BaseQuadrupedEnv):
    """Recovery / upside-down standup task environment."""

    # ------------------------------------------------------------------
    # Termination: only z < 0.05 or episode timeout (no orientation limit)
    # ------------------------------------------------------------------

    def _check_termination(self):
        """Flag envs for reset on timeout or when the base sinks below 5 cm.

        Unlike the base class there is no pitch/roll limit: being upside
        down is the whole point of this task.
        """
        self.reset_buf = self.episode_length_buf > self.max_episode_length
        self.reset_buf |= self.base_pos[:, 2] < 0.05

    # ------------------------------------------------------------------
    # Reset: upside-down biased orientations (70/30 split)
    # ------------------------------------------------------------------

    @staticmethod
    def _sample_angles_rad(lo_deg: float, hi_deg: float, n: int) -> torch.Tensor:
        """Sample ``n`` uniform angles in ``[lo_deg, hi_deg]`` degrees, returned in radians.

        Consolidates the six near-identical ``_gs_rand_float(...) * pi / 180``
        call sites the original spelled out inline (two of which passed the
        bounds positionally as plain floats, the rest as keyword tensors).
        """
        return _gs_rand_float(
            lower=torch.tensor(lo_deg, device=gs.device) * _DEG2RAD,
            upper=torch.tensor(hi_deg, device=gs.device) * _DEG2RAD,
            shape=(n,),
            device=gs.device,
        )

    def _reset_robot_pose(self, envs_idx):
        """Drop the selected robots mid-air in a mostly upside-down pose.

        * joints: uniform within their limits
        * base position: init position + up to ±0.2 m in xy, 0.15–0.5 m up
        * orientation: 70% near-upside-down (roll 150–210°), 30% on a side
        """
        n = len(envs_idx)
        dofs_lower_limits, dofs_upper_limits = self.robot.get_dofs_limit(self.motors_dof_idx)

        # Random DOF positions within joint limits.
        self.dof_pos[envs_idx] = _gs_rand_float(
            lower=dofs_lower_limits,
            upper=dofs_upper_limits,
            shape=(n, self.num_actions),
            device=gs.device,
        )
        self.dof_vel[envs_idx] = 0.0
        self.robot.set_dofs_position(
            position=self.dof_pos[envs_idx],
            dofs_idx_local=self.motors_dof_idx,
            zero_velocity=True,
            envs_idx=envs_idx,
        )

        # Base position — keep some height so the robot has room to recover.
        self.base_pos[envs_idx] = self.base_init_pos + _gs_rand_float(
            lower=torch.tensor([-0.2, -0.2, 0.15], device=gs.device),
            upper=torch.tensor([0.2, 0.2, 0.5], device=gs.device),
            shape=(n, 3),
            device=gs.device,
        )

        # 70% upside-down, 30% on sides.
        prob_upside_down = torch.rand(n, device=gs.device)
        random_euler = torch.zeros((n, 3), device=gs.device)

        upside_down_mask = prob_upside_down < 0.7
        upside_down_idx = upside_down_mask.nonzero(as_tuple=False).flatten()
        if len(upside_down_idx) > 0:
            m = len(upside_down_idx)
            # Roll 150°–210° (around upside-down); pitch ±60° for variety;
            # yaw over the full circle.
            random_euler[upside_down_idx, 0] = self._sample_angles_rad(150.0, 210.0, m)
            random_euler[upside_down_idx, 1] = self._sample_angles_rad(-60.0, 60.0, m)
            random_euler[upside_down_idx, 2] = self._sample_angles_rad(-180.0, 180.0, m)

        other_idx = (~upside_down_mask).nonzero(as_tuple=False).flatten()
        if len(other_idx) > 0:
            m = len(other_idx)
            # Roll: 60° to 120° or -120° to -60° (lying on either side).
            side_roll = torch.rand(m, device=gs.device)
            random_euler[other_idx, 0] = torch.where(
                side_roll < 0.5,
                self._sample_angles_rad(60.0, 120.0, m),
                self._sample_angles_rad(-120.0, -60.0, m),
            )
            # Pitch ±90°; yaw over the full circle.
            random_euler[other_idx, 1] = self._sample_angles_rad(-90.0, 90.0, m)
            random_euler[other_idx, 2] = self._sample_angles_rad(-180.0, 180.0, m)

        self.base_quat[envs_idx] = xyz_to_quat(random_euler, rpy=True, degrees=False)

        self.robot.set_pos(self.base_pos[envs_idx], zero_velocity=False, envs_idx=envs_idx)
        self.robot.set_quat(self.base_quat[envs_idx], zero_velocity=False, envs_idx=envs_idx)
        self.base_lin_vel[envs_idx] = 0
        self.base_ang_vel[envs_idx] = 0
        self.robot.zero_all_dofs_velocity(envs_idx)

    # ------------------------------------------------------------------
    # Observations: 48-dim (base + euler angles in radians)
    # ------------------------------------------------------------------

    def _compute_observations(self):
        """Build the 48-dim observation: the base 45 dims plus base_euler in radians."""
        self.obs_buf = torch.cat(
            [
                self.base_ang_vel * self.obs_scales["ang_vel"],  # 3
                self.projected_gravity,  # 3
                self.commands * self.commands_scale,  # 3
                (self.dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"],  # 12
                self.dof_vel * self.obs_scales["dof_vel"],  # 12
                self.actions,  # 12
                self.base_euler * _DEG2RAD,  # 3 — base_euler is in degrees
            ],
            dim=-1,  # torch-native keyword (original used the numpy alias `axis`)
        )
"""StandupEnv — thin subclass of BaseQuadrupedEnv for the stand-up task.

Differences from the base class:
* Lenient termination (no pitch/roll limit; only episode timeout)
* Random pose reset (random euler angles up to ±60°; random DOF positions)
* 48-dim observation vector (adds base_euler * pi/180)
"""

import torch
import genesis as gs
from genesis.utils.geom import xyz_to_quat

from pilla_rl.envs.base_env import BaseQuadrupedEnv, _gs_rand_float


class StandupEnv(BaseQuadrupedEnv):
    """Stand-up task environment with randomised starting pose."""

    # ------------------------------------------------------------------
    # Termination: only episode length (no orientation limit)
    # ------------------------------------------------------------------

    def _check_termination(self):
        """Reset on timeout only — tipping over must not end the episode."""
        self.reset_buf = self.episode_length_buf > self.max_episode_length

    # ------------------------------------------------------------------
    # Reset: random euler angles and random DOF positions
    # ------------------------------------------------------------------

    def _reset_robot_pose(self, envs_idx):
        """Re-seed the selected envs with a randomised pose.

        * joints: uniform within their limits
        * base position: init position + up to ±0.3 m in xy, 0.05–0.6 m up
        * orientation: roll/pitch within ±60°, yaw over the full circle
        """
        dofs_lower_limits, dofs_upper_limits = self.robot.get_dofs_limit(self.motors_dof_idx)

        # Random DOF positions within joint limits.
        self.dof_pos[envs_idx] = _gs_rand_float(
            lower=dofs_lower_limits,
            upper=dofs_upper_limits,
            shape=(len(envs_idx), self.num_actions),
            device=gs.device,
        )
        self.dof_vel[envs_idx] = 0.0
        self.robot.set_dofs_position(
            position=self.dof_pos[envs_idx],
            dofs_idx_local=self.motors_dof_idx,
            zero_velocity=True,
            envs_idx=envs_idx,
        )

        # Random base position with some variation.
        self.base_pos[envs_idx] = self.base_init_pos + _gs_rand_float(
            lower=torch.tensor([-0.3, -0.3, 0.05], device=gs.device),
            upper=torch.tensor([0.3, 0.3, 0.6], device=gs.device),
            shape=(len(envs_idx), 3),
            device=gs.device,
        )

        # Random euler angles (curriculum: max ±60° for roll/pitch).
        # Kept as a plain Python float: the original wrapped it in a 0-dim
        # device tensor, so the bound lists below copy-constructed a tensor
        # from tensors — PyTorch warns about that and it forces a host sync.
        max_roll_pitch = 60.0
        random_euler = _gs_rand_float(
            lower=torch.tensor([-max_roll_pitch, -max_roll_pitch, -180.0], device=gs.device) * torch.pi / 180,
            upper=torch.tensor([max_roll_pitch, max_roll_pitch, 180.0], device=gs.device) * torch.pi / 180,
            shape=(len(envs_idx), 3),
            device=gs.device,
        )
        self.base_quat[envs_idx] = xyz_to_quat(random_euler, rpy=True, degrees=False)

        self.robot.set_pos(self.base_pos[envs_idx], zero_velocity=False, envs_idx=envs_idx)
        self.robot.set_quat(self.base_quat[envs_idx], zero_velocity=False, envs_idx=envs_idx)
        self.base_lin_vel[envs_idx] = 0
        self.base_ang_vel[envs_idx] = 0
        self.robot.zero_all_dofs_velocity(envs_idx)

    # ------------------------------------------------------------------
    # Observations: 48-dim (base + euler angles in radians)
    # ------------------------------------------------------------------

    def _compute_observations(self):
        """Build the 48-dim observation: the base 45 dims plus base_euler in radians."""
        self.obs_buf = torch.cat(
            [
                self.base_ang_vel * self.obs_scales["ang_vel"],  # 3
                self.projected_gravity,  # 3
                self.commands * self.commands_scale,  # 3
                (self.dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"],  # 12
                self.dof_vel * self.obs_scales["dof_vel"],  # 12
                self.actions,  # 12
                self.base_euler * torch.pi / 180,  # 3 — base_euler is in degrees
            ],
            dim=-1,  # torch-native keyword (original used the numpy alias `axis`)
        )


# ---------------------------------------------------------------------------
# pilla_rl/envs/walk_env.py
# ---------------------------------------------------------------------------

"""WalkEnv — thin subclass of BaseQuadrupedEnv for the walking task.

Uses all defaults from the base class:
* termination on pitch/roll limit or episode timeout
* reset to upright standing pose
* 45-dim observation vector
"""


class WalkEnv(BaseQuadrupedEnv):
    """Walking task environment. All behaviour is inherited from the base class."""
"""Unified evaluation entry point for all pilla_rl tasks.

Usage::

    python -m pilla_rl.evaluate \\
        --config pilla_rl/configs/tasks/walk.yaml \\
        --checkpoint logs/go2-walking/model_5000.pt

Dependencies: rsl-rl-lib==2.3.3, pyyaml
"""

import argparse
import os
from importlib import metadata


def _require_rsl_rl_lib() -> None:
    """Fail fast unless ``rsl-rl-lib==2.3.3`` is installed and the legacy ``rsl-rl`` is absent."""
    try:
        try:
            if metadata.version("rsl-rl"):
                raise ImportError
        except metadata.PackageNotFoundError:
            if metadata.version("rsl-rl-lib") != "2.3.3":
                raise ImportError
    except (metadata.PackageNotFoundError, ImportError) as e:
        raise ImportError("Please uninstall 'rsl_rl' and install 'rsl-rl-lib==2.3.3'.") from e


_require_rsl_rl_lib()

import torch
from rsl_rl.runners import OnPolicyRunner

import genesis as gs

from pilla_rl.config_loader import load_task_config, instantiate_env


def _parse_args() -> argparse.Namespace:
    """CLI: ``--config`` and ``--checkpoint`` are mandatory, ``--num_envs`` defaults to 1."""
    parser = argparse.ArgumentParser(description="pilla_rl unified evaluation script")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML task config file (e.g. pilla_rl/configs/tasks/walk.yaml)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to model checkpoint (.pt file)",
    )
    parser.add_argument("--num_envs", type=int, default=1, help="Number of parallel environments (default: 1)")
    return parser.parse_args()


def main():
    """Load a trained policy and run it indefinitely in a viewer-enabled env."""
    cli = _parse_args()

    gs.init(logging_level="warning")

    # Rewards are irrelevant during evaluation — blank the scales so the
    # environment skips reward computation entirely.
    cfg = load_task_config(cli.config)
    cfg["reward_cfg"]["reward_scales"] = {}

    # Environment with the interactive viewer attached.
    env = instantiate_env(cfg, num_envs=cli.num_envs, show_viewer=True)

    # OnPolicyRunner insists on a log dir even for pure inference.
    run_name = cfg.get("train", {}).get("exp_name", "eval_run")
    eval_log_dir = os.path.join("logs", run_name + "_eval")
    os.makedirs(eval_log_dir, exist_ok=True)

    runner_cfg = cfg.get("train", {})
    runner_cfg.setdefault("runner", {})
    runner_cfg["runner"]["resume"] = True
    runner_cfg["runner"]["resume_path"] = cli.checkpoint

    ppo_runner = OnPolicyRunner(env, runner_cfg, eval_log_dir, device=gs.device)
    ppo_runner.load(cli.checkpoint)
    policy = ppo_runner.get_inference_policy(device=gs.device)

    # Endless inference loop — terminate with Ctrl-C or by closing the viewer.
    obs, _ = env.reset()
    with torch.no_grad():
        while True:
            obs, _, _, _ = env.step(policy(obs))


if __name__ == "__main__":
    main()
"""Centralized reward function library for pilla_rl.

All reward functions are standalone pure functions that accept explicit tensor
arguments rather than accessing environment state via ``self``. Each function
returns a per-environment reward tensor of shape ``(num_envs,)``.

The ``REWARD_REGISTRY`` dict maps string names to functions so that
``BaseQuadrupedEnv`` can look up and call them by name.

Convention: ``base_euler`` is expressed in *degrees* (hence the pi/180
conversions below); velocities and joint quantities are used as given.
"""

import torch

# base_euler arrives in degrees; all internal math is done in radians.
_DEG2RAD = torch.pi / 180.0


# ---------------------------------------------------------------------------
# Shared helpers (the originals duplicated these expressions verbatim)
# ---------------------------------------------------------------------------


def _roll_pitch_sq(base_euler):
    """Sum of squared roll and pitch (radians²) — the shared tilt penalty term."""
    return torch.square(base_euler[:, 0] * _DEG2RAD) + torch.square(base_euler[:, 1] * _DEG2RAD)


def _upright_quality(base_euler):
    """Blend of a sharp peak at upright (70%) and a wide basin (30%).

    The wide component keeps a useful gradient alive far from upright.
    """
    penalty = _roll_pitch_sq(base_euler)
    return 0.7 * torch.exp(-penalty / 0.1) + 0.3 * torch.exp(-penalty / 2.0)


# ---------------------------------------------------------------------------
# Basic locomotion rewards (from go2/walk/go2_env.py)
# ---------------------------------------------------------------------------


def tracking_lin_vel(commands, base_lin_vel, tracking_sigma):
    """Tracking of linear velocity commands (xy axes)."""
    lin_vel_error = torch.sum(torch.square(commands[:, :2] - base_lin_vel[:, :2]), dim=1)
    return torch.exp(-lin_vel_error / tracking_sigma)


def tracking_ang_vel(commands, base_ang_vel, tracking_sigma):
    """Tracking of angular velocity commands (yaw)."""
    ang_vel_error = torch.square(commands[:, 2] - base_ang_vel[:, 2])
    return torch.exp(-ang_vel_error / tracking_sigma)


def lin_vel_z(base_lin_vel):
    """Penalize z axis base linear velocity."""
    return torch.square(base_lin_vel[:, 2])


def action_rate(last_actions, actions):
    """Penalize changes in actions."""
    return torch.sum(torch.square(last_actions - actions), dim=1)


def similar_to_default(dof_pos, default_dof_pos):
    """Penalize joint poses far away from default pose."""
    return torch.sum(torch.abs(dof_pos - default_dof_pos), dim=1)


def base_height(base_pos, base_height_target):
    """Penalize base height away from target."""
    return torch.square(base_pos[:, 2] - base_height_target)


# ---------------------------------------------------------------------------
# Standup rewards (from go2/standup_copilot/go2_env.py)
# ---------------------------------------------------------------------------


def upright_orientation(base_euler):
    """Reward for maintaining upright orientation (small roll and pitch).

    Combines a sharp peak at upright with a gradual component that provides
    gradients even when the robot is far from upright.

    Source: go2/standup_copilot/go2_env.py
    """
    return _upright_quality(base_euler)


def stability(base_ang_vel):
    """Reward for low angular velocity (stability).

    Source: go2/standup_copilot/go2_env.py
    """
    ang_vel_penalty = torch.sum(torch.square(base_ang_vel), dim=1)
    return torch.exp(-ang_vel_penalty / 0.5)


def stand_up_progress(base_pos, base_euler, base_height_target):
    """Progressive reward for standing up — combines height and orientation.

    Source: go2/standup_copilot/go2_env.py
    """
    height_reward = torch.exp(-torch.square(base_pos[:, 2] - base_height_target) / 0.1)
    return height_reward * _upright_quality(base_euler)


def recovery_effort(base_euler, dof_vel):
    """Encourage movement when far from upright to prevent the policy from
    getting stuck.

    Source: go2/standup_copilot/go2_env.py
    """
    is_far_from_upright = _roll_pitch_sq(base_euler) > (45 * torch.pi / 180) ** 2
    joint_movement = torch.sum(torch.square(dof_vel), dim=1)
    movement_reward = torch.exp(-joint_movement / 10.0)
    # Reward is the *complement* of stillness, but only when tipped over.
    return torch.where(is_far_from_upright, 1.0 - movement_reward, torch.zeros_like(movement_reward))


def joint_regularization(dof_pos, default_dof_pos):
    """Encourage joint positions closer to default for stability.

    Source: go2/standup_copilot/go2_env.py
    """
    joint_deviation = torch.sum(torch.square(dof_pos - default_dof_pos), dim=1)
    return torch.exp(-joint_deviation / 2.0)


# ---------------------------------------------------------------------------
# Recovery rewards (from go2/upside_down_recovery/go2_env.py)
# ---------------------------------------------------------------------------


def recovery_progress(base_euler, base_pos, base_height_target):
    """Main reward for upside-down recovery progress.

    Source: go2/upside_down_recovery/go2_env.py
    """
    roll_rad = base_euler[:, 0] * _DEG2RAD
    pitch_rad = base_euler[:, 1] * _DEG2RAD
    # (cos + 1) / 2 maps upside-down (180°) to 0 and upright (0°) to 1.
    roll_progress = (torch.cos(roll_rad) + 1.0) / 2.0
    pitch_progress = (torch.cos(pitch_rad) + 1.0) / 2.0
    orientation_progress = roll_progress * pitch_progress
    height_factor = torch.exp(-torch.square(base_pos[:, 2] - base_height_target) / 0.2)
    return orientation_progress * (0.7 + 0.3 * height_factor)


def legs_not_in_air(dof_vel):
    """Encourage legs to be grounded rather than flailing.

    Source: go2/upside_down_recovery/go2_env.py
    """
    joint_vel_penalty = torch.sum(torch.square(dof_vel), dim=1)
    return torch.exp(-joint_vel_penalty / 20.0)


def energy_efficiency(actions):
    """Encourage efficient movements — penalize excessive joint torques.

    Source: go2/upside_down_recovery/go2_env.py
    """
    action_magnitude = torch.sum(torch.square(actions), dim=1)
    return torch.exp(-action_magnitude / 5.0)


def forward_progress(base_euler, base_lin_vel):
    """Small reward for forward movement during recovery (only when reasonably
    upright).

    Source: go2/upside_down_recovery/go2_env.py
    """
    roll_rad = base_euler[:, 0] * _DEG2RAD
    is_reasonably_upright = torch.abs(roll_rad) < (torch.pi / 2)
    forward_vel = base_lin_vel[:, 0]
    forward_reward = torch.clamp(forward_vel, 0.0, 1.0)
    return torch.where(is_reasonably_upright, forward_reward, torch.zeros_like(forward_reward))


def minimize_base_roll(base_euler):
    """Reward for minimizing roll angle — key for upside-down recovery.

    Source: go2/upside_down_recovery/go2_env.py
    """
    roll_rad = torch.abs(base_euler[:, 0] * _DEG2RAD)
    return torch.exp(-roll_rad / (torch.pi / 4))


# ---------------------------------------------------------------------------
# Combined upside-down standup rewards (from go2/upside_down_standup/go2_env.py)
# ---------------------------------------------------------------------------


def standup_height(base_euler, base_pos, base_height_target):
    """Strong reward for achieving proper standing height once upright.

    Source: go2/upside_down_standup/go2_env.py
    """
    roll_rad = torch.abs(base_euler[:, 0] * _DEG2RAD)
    pitch_rad = torch.abs(base_euler[:, 1] * _DEG2RAD)
    is_upright = (roll_rad < torch.pi / 6) & (pitch_rad < torch.pi / 6)  # within 30°
    height_error = torch.abs(base_pos[:, 2] - base_height_target)
    height_reward = torch.exp(-height_error / 0.1)
    return torch.where(is_upright, height_reward, torch.zeros_like(height_reward))


def complete_standup(base_euler, base_pos, base_ang_vel, base_height_target):
    """Bonus reward for achieving both upright orientation AND proper height.

    Source: go2/upside_down_standup/go2_env.py
    """
    roll_rad = torch.abs(base_euler[:, 0] * _DEG2RAD)
    pitch_rad = torch.abs(base_euler[:, 1] * _DEG2RAD)
    height_error = torch.abs(base_pos[:, 2] - base_height_target)
    is_orientation_good = (roll_rad < torch.pi / 12) & (pitch_rad < torch.pi / 12)  # within 15°
    is_height_good = height_error < 0.05
    is_stable = torch.sum(torch.square(base_ang_vel), dim=1) < 2.0
    is_fully_standup = is_orientation_good & is_height_good & is_stable
    # Binary bonus: 1.0 only when all three conditions hold.
    return torch.where(is_fully_standup, torch.ones_like(roll_rad), torch.zeros_like(roll_rad))


def height_when_upright(base_euler, base_pos, base_height_target):
    """Progressive height reward that increases as orientation improves.

    Source: go2/upside_down_standup/go2_env.py
    """
    roll_rad = torch.abs(base_euler[:, 0] * _DEG2RAD)
    pitch_rad = torch.abs(base_euler[:, 1] * _DEG2RAD)
    roll_quality = torch.exp(-roll_rad / (torch.pi / 6))
    pitch_quality = torch.exp(-pitch_rad / (torch.pi / 6))
    orientation_quality = roll_quality * pitch_quality
    height_error = torch.abs(base_pos[:, 2] - base_height_target)
    height_quality = torch.exp(-height_error / 0.08)
    return orientation_quality * height_quality


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------

REWARD_REGISTRY: dict = {
    # walk / locomotion
    "tracking_lin_vel": tracking_lin_vel,
    "tracking_ang_vel": tracking_ang_vel,
    "lin_vel_z": lin_vel_z,
    "action_rate": action_rate,
    "similar_to_default": similar_to_default,
    "base_height": base_height,
    # standup
    "upright_orientation": upright_orientation,
    "stability": stability,
    "stand_up_progress": stand_up_progress,
    "recovery_effort": recovery_effort,
    "joint_regularization": joint_regularization,
    # recovery
    "recovery_progress": recovery_progress,
    "legs_not_in_air": legs_not_in_air,
    "energy_efficiency": energy_efficiency,
    "forward_progress": forward_progress,
    "minimize_base_roll": minimize_base_roll,
    # upside-down standup
    "standup_height": standup_height,
    "complete_standup": complete_standup,
    "height_when_upright": height_when_upright,
}
"""Unified training entry point for all pilla_rl tasks.

Usage::

    # Walk task with default settings
    python -m pilla_rl.train --config pilla_rl/configs/tasks/walk.yaml

    # Standup task with custom num_envs
    python -m pilla_rl.train --config pilla_rl/configs/tasks/standup.yaml --num_envs 2048

    # Transfer learning from a walk checkpoint
    python -m pilla_rl.train \\
        --config pilla_rl/configs/tasks/standup.yaml \\
        --resume_from logs/go2-walking/model_5000.pt

Dependencies: rsl-rl-lib==2.3.3, pyyaml
"""

import argparse
import copy
import os
import pickle
import shutil
from importlib import metadata

# Fail fast unless exactly rsl-rl-lib==2.3.3 is installed (the legacy
# 'rsl-rl' package shadows the same import name and must be absent).
try:
    try:
        if metadata.version("rsl-rl"):
            raise ImportError
    except metadata.PackageNotFoundError:
        if metadata.version("rsl-rl-lib") != "2.3.3":
            raise ImportError
except (metadata.PackageNotFoundError, ImportError) as e:
    raise ImportError("Please uninstall 'rsl_rl' and install 'rsl-rl-lib==2.3.3'.") from e

from rsl_rl.runners import OnPolicyRunner

import genesis as gs

from pilla_rl.config_loader import load_task_config, instantiate_env


def _build_train_cfg(config: dict, exp_name: str, max_iterations: int) -> dict:
    """Construct the rsl-rl train config dict from the YAML ``train`` section.

    Returns a deep copy (the caller's ``config`` is never mutated) with
    ``runner.experiment_name`` and ``runner.max_iterations`` overridden.
    """
    train = copy.deepcopy(config.get("train", {}))
    train.setdefault("runner", {})
    train["runner"]["experiment_name"] = exp_name
    train["runner"]["max_iterations"] = max_iterations
    return train


def main():
    """Parse CLI args, build the environment, and run PPO training."""
    parser = argparse.ArgumentParser(description="pilla_rl unified training script")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML task config file (e.g. pilla_rl/configs/tasks/walk.yaml)",
    )
    parser.add_argument("--num_envs", type=int, default=4096, help="Number of parallel environments")
    parser.add_argument(
        "--max_iterations",
        type=int,
        default=None,
        help="Override max training iterations (defaults to value in config)",
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to checkpoint (.pt) for transfer learning / curriculum continuation",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Override experiment name (defaults to value in config train section)",
    )
    args = parser.parse_args()

    gs.init(logging_level="warning")

    # ------------------------------------------------------------------
    # Load config
    # ------------------------------------------------------------------
    config = load_task_config(args.config)

    # Explicit None checks: the original used `args.x or default`, which
    # silently discards legitimate falsy overrides (e.g. --max_iterations 0
    # for a dry run, or an empty experiment name).
    if args.exp_name is not None:
        exp_name = args.exp_name
    else:
        exp_name = config.get("train", {}).get("exp_name", "pilla_rl_run")
    if args.max_iterations is not None:
        max_iterations = args.max_iterations
    else:
        max_iterations = config.get("train", {}).get("runner", {}).get("max_iterations", 10000)

    log_dir = os.path.join("logs", exp_name)

    # ------------------------------------------------------------------
    # Build and save configs for backward compatibility with eval scripts
    # ------------------------------------------------------------------
    env_cfg = config["env_cfg"]
    obs_cfg = config["obs_cfg"]
    reward_cfg = config["reward_cfg"]
    command_cfg = config["command_cfg"]
    train_cfg = _build_train_cfg(config, exp_name, max_iterations)

    # Fresh log dir per run — a previous run with the same exp_name is wiped.
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    os.makedirs(log_dir, exist_ok=True)

    # Context manager so the handle is closed even if pickling fails — the
    # original leaked it via pickle.dump(..., open(...)).
    with open(os.path.join(log_dir, "cfgs.pkl"), "wb") as f:
        pickle.dump([env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg], f)

    # ------------------------------------------------------------------
    # Create environment
    # ------------------------------------------------------------------
    env = instantiate_env(config, num_envs=args.num_envs, show_viewer=False)

    # ------------------------------------------------------------------
    # Runner setup — optional checkpoint resume
    # ------------------------------------------------------------------
    if args.resume_from is not None:
        train_cfg.setdefault("runner", {})
        train_cfg["runner"]["resume"] = True
        train_cfg["runner"]["resume_path"] = args.resume_from

    runner = OnPolicyRunner(env, train_cfg, log_dir, device=gs.device)
    runner.learn(num_learning_iterations=max_iterations, init_at_random_ep_len=True)


if __name__ == "__main__":
    main()