|
7 | 7 | import torch |
8 | 8 | from loguru import logger as log |
9 | 9 |
|
10 | | -from metasim.cfg.objects import ArticulationObjCfg, BaseObjCfg, RigidObjCfg |
| 10 | +from metasim.cfg.objects import ArticulationObjCfg, BaseObjCfg, PrimitiveFrameCfg, RigidObjCfg |
11 | 11 | from metasim.cfg.scenario import ScenarioCfg |
12 | 12 | from metasim.sim import BaseSimHandler, EnvWrapper, IdentityEnvWrapper |
13 | 13 | from metasim.types import Action, EnvState, Extra, Obs, Reward, Success, TimeOut |
@@ -128,6 +128,23 @@ def step(self, action: list[Action]) -> tuple[Obs, Reward, Success, TimeOut, Ext |
128 | 128 | time_out = time_out.cpu() |
129 | 129 | success = self.checker.check(self) |
130 | 130 | states = self.get_states() |
| 131 | + |
| 132 | + ## TODO: organize this |
| 133 | + for obj in self.objects: |
| 134 | + if isinstance(obj, PrimitiveFrameCfg): |
| 135 | + if obj.base_link is None: |
| 136 | + pos = torch.zeros((self.num_envs, 3), device=self.device) |
| 137 | + rot = torch.zeros((self.num_envs, 4), device=self.device) |
| 138 | + rot[:, 0] = 1.0 |
| 139 | + elif isinstance(obj.base_link, str): |
| 140 | + pos, rot = (states.objects | states.robots)[obj.base_link].root_state[:, :7].split([3, 4], dim=-1) |
| 141 | + else: |
| 142 | + base_obj_name = obj.base_link[0] |
| 143 | + base_body_name = obj.base_link[1] |
| 144 | + merged_states = states.objects | states.robots |
| 145 | + body_idx = merged_states[base_obj_name].body_names.index(base_body_name) |
| 146 | + pos, rot = merged_states[base_obj_name].body_state[:, body_idx, :7].split([3, 4], dim=-1) |
| 147 | + self._set_object_pose(obj, pos, rot) |
131 | 148 | return states, None, success, time_out, extras |
132 | 149 |
|
133 | 150 | def reset(self, env_ids: list[int] | None = None) -> tuple[list[EnvState], Extra]: |
|
0 commit comments