Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions genesis/engine/solvers/kinematic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,23 @@ def substep_pre_coupling(self, f):
def substep_pre_coupling_grad(self, f):
pass

def _ensure_forward_vel_updated(self):
if not self._is_forward_vel_updated:
kernel_forward_velocity(
self.scene._envs_idx,
links_state=self.links_state,
links_info=self.links_info,
joints_info=self.joints_info,
dofs_state=self.dofs_state,
entities_info=self.entities_info,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
is_backward=False,
)
self._is_forward_vel_updated = True

def substep_post_coupling(self, f):
if not self._is_forward_pos_updated or not self._is_forward_vel_updated:
if not self._is_forward_pos_updated:
kernel_forward_kinematics(
self.scene._envs_idx,
links_state=self.links_state,
Expand All @@ -461,7 +476,6 @@ def substep_post_coupling(self, f):
static_rigid_sim_config=self._static_rigid_sim_config,
)
self._is_forward_pos_updated = True
self._is_forward_vel_updated = True

def substep_post_coupling_grad(self, f):
pass
Expand Down Expand Up @@ -578,6 +592,17 @@ def set_state(self, f, state, envs_idx=None):
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
)
kernel_forward_velocity(
envs_idx,
links_state=self.links_state,
links_info=self.links_info,
joints_info=self.joints_info,
dofs_state=self.dofs_state,
entities_info=self.entities_info,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
is_backward=False,
)
self._is_forward_pos_updated = True
self._is_forward_vel_updated = True

Expand Down Expand Up @@ -680,7 +705,7 @@ def set_base_links_pos(self, pos, links_idx=None, envs_idx=None, *, relative=Fal
static_rigid_sim_config=self._static_rigid_sim_config,
)
self._is_forward_pos_updated = True
self._is_forward_vel_updated = True
self._is_forward_vel_updated = False

def set_base_links_pos_grad(self, links_idx, envs_idx, relative, pos_grad):
if links_idx is None:
Expand Down Expand Up @@ -734,7 +759,7 @@ def set_base_links_quat(self, quat, links_idx=None, envs_idx=None, *, relative=F
static_rigid_sim_config=self._static_rigid_sim_config,
)
self._is_forward_pos_updated = True
self._is_forward_vel_updated = True
self._is_forward_vel_updated = False

def set_base_links_quat_grad(self, links_idx, envs_idx, relative, quat_grad):
if links_idx is None:
Expand Down Expand Up @@ -805,7 +830,7 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False):
static_rigid_sim_config=self._static_rigid_sim_config,
)
self._is_forward_pos_updated = True
self._is_forward_vel_updated = True
self._is_forward_vel_updated = False
else:
self._is_forward_pos_updated = False
self._is_forward_vel_updated = False
Expand Down Expand Up @@ -917,7 +942,7 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None):
static_rigid_sim_config=self._static_rigid_sim_config,
)
self._is_forward_pos_updated = True
self._is_forward_vel_updated = True
self._is_forward_vel_updated = False

def get_links_pos(self, links_idx=None, envs_idx=None):
if not gs.use_zerocopy:
Expand All @@ -932,6 +957,7 @@ def get_links_quat(self, links_idx=None, envs_idx=None):
return tensor[0] if self.n_envs == 0 else tensor

def get_links_vel(self, links_idx=None, envs_idx=None):
self._ensure_forward_vel_updated()
if gs.use_zerocopy:
mask = (0, *indices_to_mask(links_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, links_idx)
cd_vel = qd_to_torch(self.links_state.cd_vel, transpose=True)
Expand All @@ -949,6 +975,7 @@ def get_links_vel(self, links_idx=None, envs_idx=None):
return _tensor

def get_links_ang(self, links_idx=None, envs_idx=None):
self._ensure_forward_vel_updated()
tensor = qd_to_torch(self.links_state.cd_ang, envs_idx, links_idx, transpose=True, copy=True)
return tensor[0] if self.n_envs == 0 else tensor

Expand Down
22 changes: 0 additions & 22 deletions genesis/engine/solvers/rigid/abd/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,6 @@ def kernel_forward_kinematics(
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)
func_forward_velocity_batch(
i_b=i_b,
entities_info=entities_info,
links_info=links_info,
links_state=links_state,
joints_info=joints_info,
dofs_state=dofs_state,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)


@qd.kernel(fastcache=gs.use_fastcache)
Expand Down Expand Up @@ -211,17 +200,6 @@ def kernel_masked_forward_kinematics(
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)
func_forward_velocity_batch(
i_b=i_b,
entities_info=entities_info,
links_info=links_info,
links_state=links_state,
joints_info=joints_info,
dofs_state=dofs_state,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)


@qd.kernel(fastcache=gs.use_fastcache)
Expand Down
Loading