From 43aeca25f5a92e3d954fa6fc9b35bd1769f9edd5 Mon Sep 17 00:00:00 2001 From: Sidharth Rajagopal Date: Thu, 27 Mar 2025 15:24:22 -0700 Subject: [PATCH] added obs_type to aug_func --- .../locomotion/advance_skills/config/spot/augment.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py b/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py index 365c440..0ea956e 100644 --- a/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py +++ b/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py @@ -89,13 +89,14 @@ def aug_action(actions: torch.Tensor) -> torch.Tensor: return new_actions -def aug_func(obs=None, actions=None, env=None, is_critic=False): +def aug_func(obs=None, actions=None, env=None, obs_type="policy"): aug_obs = None aug_act = None if obs is not None: - if is_critic: + if obs_type == "critic": + aug_obs = aug_observation(obs) + else: aug_obs = aug_observation(obs) - aug_obs = aug_observation(obs) if actions is not None: aug_act = aug_action(actions) return aug_obs, aug_act