diff --git a/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py b/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py index 365c440..0ea956e 100644 --- a/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py +++ b/source/uwlab_tasks/uwlab_tasks/manager_based/locomotion/advance_skills/config/spot/augment.py @@ -89,13 +89,14 @@ def aug_action(actions: torch.Tensor) -> torch.Tensor: return new_actions -def aug_func(obs=None, actions=None, env=None, is_critic=False): +def aug_func(obs=None, actions=None, env=None, obs_type="policy"): aug_obs = None aug_act = None if obs is not None: - if is_critic: + if obs_type == "critic": + aug_obs = aug_observation(obs) + else: aug_obs = aug_observation(obs) - aug_obs = aug_observation(obs) if actions is not None: aug_act = aug_action(actions) return aug_obs, aug_act