From f8e836a0bb098c6ceee2140e24c8267f98d50c51 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 3 Dec 2025 20:22:23 -0800 Subject: [PATCH 1/9] migrated add inequality functions in constraint solver --- .../solvers/rigid/constraint_solver_decomp.py | 346 ++++++++++-------- .../solvers/rigid/rigid_solver_decomp.py | 13 + genesis/utils/array_class.py | 2 + 3 files changed, 218 insertions(+), 143 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index ba93534bc3..02abceb498 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -437,94 +437,113 @@ def add_collision_constraints( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_col in range(collider_state.n_contacts[i_b]): - contact_data_link_a = collider_state.contact_data.link_a[i_col, i_b] - contact_data_link_b = collider_state.contact_data.link_b[i_col, i_b] - - contact_data_pos = collider_state.contact_data.pos[i_col, i_b] - contact_data_normal = collider_state.contact_data.normal[i_col, i_b] - contact_data_friction = collider_state.contact_data.friction[i_col, i_b] - contact_data_sol_params = collider_state.contact_data.sol_params[i_col, i_b] - contact_data_penetration = collider_state.contact_data.penetration[i_col, i_b] - - link_a = contact_data_link_a - link_b = contact_data_link_b - link_a_maybe_batch = [link_a, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_a - link_b_maybe_batch = [link_b, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_b + for i_b in range(dofs_state.ctrl_mode.shape[1]): + EPS = rigid_global_info.EPS[None] + 
n_dofs = dofs_state.ctrl_mode.shape[0] + + for i_col in ( + range(collider_state.n_contacts[i_b]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_contact_pairs)) + ): + if i_col < collider_state.n_contacts[i_b]: + contact_data_link_a = collider_state.contact_data.link_a[i_col, i_b] + contact_data_link_b = collider_state.contact_data.link_b[i_col, i_b] - d1, d2 = gu.ti_orthogonals(contact_data_normal) + contact_data_pos = collider_state.contact_data.pos[i_col, i_b] + contact_data_normal = collider_state.contact_data.normal[i_col, i_b] + contact_data_friction = collider_state.contact_data.friction[i_col, i_b] + contact_data_sol_params = collider_state.contact_data.sol_params[i_col, i_b] + contact_data_penetration = collider_state.contact_data.penetration[i_col, i_b] - invweight = links_info.invweight[link_a_maybe_batch][0] - if link_b > -1: - invweight = invweight + links_info.invweight[link_b_maybe_batch][0] + link_a = contact_data_link_a + link_b = contact_data_link_b + link_a_maybe_batch = [link_a, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_a + link_b_maybe_batch = [link_b, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_b - for i in range(4): - d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) - n = d * contact_data_friction - contact_data_normal + d1, d2 = gu.ti_orthogonals(contact_data_normal) - n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) - if ti.static(static_rigid_sim_config.sparse_solve): - for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): - i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b] - constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) - else: - for i_d in range(n_dofs): - constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) + invweight = links_info.invweight[link_a_maybe_batch][0] + if link_b > -1: + invweight = invweight + links_info.invweight[link_b_maybe_batch][0] - con_n_relevant_dofs = 0 
- jac_qvel = gs.ti_float(0.0) - for i_ab in range(2): - sign = gs.ti_float(-1.0) - link = link_a - if i_ab == 1: - sign = gs.ti_float(1.0) - link = link_b + for i in range(4) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(4)): + d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) + n = d * contact_data_friction - contact_data_normal - while link > -1: - link_maybe_batch = [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link + n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) + if ti.static(static_rigid_sim_config.sparse_solve): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): + i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b] + constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) + else: + for i_d in ( + range(n_dofs) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_dofs)) + ): + constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) + + con_n_relevant_dofs = 0 + jac_qvel = gs.ti_float(0.0) + for i_ab in range(2) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(2)): + sign = gs.ti_float(-1.0) + link = link_a + if i_ab == 1: + sign = gs.ti_float(1.0) + link = link_b + + # FIXME: Set number of iterations to look for parent to certain value for autodiff + # while link > -1: + for i_parent in ( + range(20) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(20)) + ): + if link > -1: + link_maybe_batch = ( + [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link + ) - # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending - for i_d_ in range(links_info.n_dofs[link_maybe_batch]): - i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ + # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending + for i_d_ in ( + range(links_info.n_dofs[link_maybe_batch]) + if ti.static(not 
static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + if i_d_ < links_info.n_dofs[link_maybe_batch]: + i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ - cdof_ang = dofs_state.cdof_ang[i_d, i_b] - cdot_vel = dofs_state.cdof_vel[i_d, i_b] + cdof_ang = dofs_state.cdof_ang[i_d, i_b] + cdot_vel = dofs_state.cdof_vel[i_d, i_b] - t_quat = gu.ti_identity_quat() - t_pos = contact_data_pos - links_state.root_COM[link, i_b] - _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) + t_quat = gu.ti_identity_quat() + t_pos = contact_data_pos - links_state.root_COM[link, i_b] + _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) - diff = sign * vel - jac = diff @ n - jac_qvel = jac_qvel + jac * dofs_state.vel[i_d, i_b] - constraint_state.jac[n_con, i_d, i_b] = constraint_state.jac[n_con, i_d, i_b] + jac + diff = sign * vel + jac = diff @ n + jac_qvel += jac * dofs_state.vel[i_d, i_b] + constraint_state.jac[n_con, i_d, i_b] += jac - if ti.static(static_rigid_sim_config.sparse_solve): - constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d - con_n_relevant_dofs += 1 + if ti.static(static_rigid_sim_config.sparse_solve): + constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d + con_n_relevant_dofs += 1 - link = links_info.parent_idx[link_maybe_batch] + link = links_info.parent_idx[link_maybe_batch] - if ti.static(static_rigid_sim_config.sparse_solve): - constraint_state.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs - imp, aref = gu.imp_aref( - contact_data_sol_params, -contact_data_penetration, jac_qvel, -contact_data_penetration - ) + if ti.static(static_rigid_sim_config.sparse_solve): + constraint_state.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs + imp, aref = gu.imp_aref( + contact_data_sol_params, -contact_data_penetration, jac_qvel, -contact_data_penetration + ) - diag = invweight + contact_data_friction 
* contact_data_friction * invweight - diag *= 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp - diag = ti.max(diag, EPS) + diag_0 = invweight + contact_data_friction * contact_data_friction * invweight + diag_1 = diag_0 * 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp + diag = ti.max(diag_1, EPS) - constraint_state.diag[n_con, i_b] = diag - constraint_state.aref[n_con, i_b] = aref - constraint_state.efc_D[n_con, i_b] = 1 / diag + constraint_state.diag[n_con, i_b] = diag + constraint_state.aref[n_con, i_b] = aref + constraint_state.efc_D[n_con, i_b] = 1 / diag @ti.func @@ -1026,52 +1045,69 @@ def add_joint_limit_constraints( constraint_state: array_class.ConstraintState, static_rigid_sim_config: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - _B = constraint_state.jac.shape[2] - n_links = links_info.root_idx.shape[0] - n_dofs = dofs_state.ctrl_mode.shape[0] - # TODO: sparse mode ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_l in range(n_links): + for i_b in range(constraint_state.jac.shape[2]): + EPS = rigid_global_info.EPS[None] + n_links = links_info.root_idx.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[0] + + for i_l in ( + range(n_links) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_links)) + ): I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - - if joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC: - i_q = joints_info.q_start[I_j] - i_d = joints_info.dof_start[I_j] - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0] - pos_delta_max 
= dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b] - pos_delta = min(pos_delta_min, pos_delta_max) - - if pos_delta < 0: - jac = (pos_delta_min < pos_delta_max) * 2 - 1 - jac_qvel = jac * dofs_state.vel[i_d, i_b] - imp, aref = gu.imp_aref(joints_info.sol_params[I_j], pos_delta, jac_qvel, pos_delta) - diag = ti.max(dofs_info.invweight[I_d] * (1 - imp) / imp, EPS) - - n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) - constraint_state.diag[n_con, i_b] = diag - constraint_state.aref[n_con, i_b] = aref - constraint_state.efc_D[n_con, i_b] = 1 / diag - - if ti.static(static_rigid_sim_config.sparse_solve): - for i_d2_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): - i_d2 = constraint_state.jac_relevant_dofs[n_con, i_d2_, i_b] - constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) - else: - for i_d2 in range(n_dofs): - constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) - constraint_state.jac[n_con, i_d, i_b] = jac + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = ( + i_j_ if ti.static(not static_rigid_sim_config.is_backward) else (i_j_ + links_info.joint_start[I_l]) + ) + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + + if ( + joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE + or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC + ): + i_q = joints_info.q_start[I_j] + i_d = joints_info.dof_start[I_j] + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0] + pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b] + pos_delta = min(pos_delta_min, pos_delta_max) + + if pos_delta < 0: + jac = (pos_delta_min < pos_delta_max) * 2 - 1 + jac_qvel = jac * dofs_state.vel[i_d, 
i_b] + imp, aref = gu.imp_aref(joints_info.sol_params[I_j], pos_delta, jac_qvel, pos_delta) + diag = ti.max(dofs_info.invweight[I_d] * (1 - imp) / imp, EPS) + + n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) + constraint_state.diag[n_con, i_b] = diag + constraint_state.aref[n_con, i_b] = aref + constraint_state.efc_D[n_con, i_b] = 1 / diag - if ti.static(static_rigid_sim_config.sparse_solve): - constraint_state.jac_n_relevant_dofs[n_con, i_b] = 1 - constraint_state.jac_relevant_dofs[n_con, 0, i_b] = i_d + if ti.static(static_rigid_sim_config.sparse_solve): + for i_d2_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): + i_d2 = constraint_state.jac_relevant_dofs[n_con, i_d2_, i_b] + constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) + else: + for i_d2 in ( + range(n_dofs) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_dofs)) + ): + constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) + constraint_state.jac[n_con, i_d, i_b] = jac + + if ti.static(static_rigid_sim_config.sparse_solve): + constraint_state.jac_n_relevant_dofs[n_con, i_b] = 1 + constraint_state.jac_relevant_dofs[n_con, 0, i_b] = i_d @ti.func @@ -1084,46 +1120,70 @@ def add_frictionloss_constraints( constraint_state: array_class.ConstraintState, static_rigid_sim_config: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - _B = constraint_state.jac.shape[2] - n_links = links_info.root_idx.shape[0] - n_dofs = dofs_state.ctrl_mode.shape[0] - # TODO: sparse mode # FIXME: The condition `if dofs_info.frictionloss[I_d] > EPS:` is not correctly evaluated on Apple Metal # if `serialize=True`... 
ti.loop_config( serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL and gs.backend != gs.metal) ) - for i_b in range(_B): + for i_b in range(constraint_state.jac.shape[2]): constraint_state.n_constraints_frictionloss[i_b] = 0 - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - - for i_d in range(joints_info.dof_start[I_j], joints_info.dof_end[I_j]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + EPS = rigid_global_info.EPS[None] + n_links = links_info.root_idx.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[0] - if dofs_info.frictionloss[I_d] > EPS: - jac = 1.0 - jac_qvel = jac * dofs_state.vel[i_d, i_b] - imp, aref = gu.imp_aref(joints_info.sol_params[I_j], 0.0, jac_qvel, 0.0) - diag = ti.max(dofs_info.invweight[I_d] * (1.0 - imp) / imp, EPS) - - i_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) - ti.atomic_add(constraint_state.n_constraints_frictionloss[i_b], 1) + for i_l in ( + range(n_links) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_links)) + ): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - constraint_state.diag[i_con, i_b] = diag - constraint_state.aref[i_con, i_b] = aref - constraint_state.efc_D[i_con, i_b] = 1.0 / diag - constraint_state.efc_frictionloss[i_con, i_b] = dofs_info.frictionloss[I_d] - for i_d2 in range(n_dofs): - constraint_state.jac[i_con, i_d2, i_b] = gs.ti_float(0.0) - constraint_state.jac[i_con, i_d, i_b] = jac + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = ( + i_j_ if ti.static(not 
static_rigid_sim_config.is_backward) else (i_j_ + links_info.joint_start[I_l]) + ) + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + + for i_d_ in ( + range(joints_info.dof_start[I_j], joints_info.dof_end[I_j]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d_ + joints_info.dof_start[I_j]) + ) + if i_d < joints_info.dof_end[I_j]: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + + if dofs_info.frictionloss[I_d] > EPS: + jac = 1.0 + jac_qvel = jac * dofs_state.vel[i_d, i_b] + imp, aref = gu.imp_aref(joints_info.sol_params[I_j], 0.0, jac_qvel, 0.0) + diag = ti.max(dofs_info.invweight[I_d] * (1.0 - imp) / imp, EPS) + + i_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) + ti.atomic_add(constraint_state.n_constraints_frictionloss[i_b], 1) + + constraint_state.diag[i_con, i_b] = diag + constraint_state.aref[i_con, i_b] = aref + constraint_state.efc_D[i_con, i_b] = 1.0 / diag + constraint_state.efc_frictionloss[i_con, i_b] = dofs_info.frictionloss[I_d] + for i_d2 in ( + range(n_dofs) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_dofs)) + ): + constraint_state.jac[i_con, i_d2, i_b] = gs.ti_float(0.0) + constraint_state.jac[i_con, i_d, i_b] = jac @ti.func diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index e666d31559..44b6917efc 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -271,6 +271,8 @@ def build(self): max_n_qs_per_link=max(link.n_qs for link in self.links) if self.links else 0, n_links=self._n_links, n_geoms=self._n_geoms, + n_dofs=self._n_dofs, + # 
max_contact_pairs=self.collider._collider_state.contact_data.geom_a.shape[0], ) self._static_rigid_sim_config = array_class.StructRigidSimStaticConfig(**static_rigid_sim_config) else: @@ -299,6 +301,9 @@ def build(self): if getattr(self._options, "noslip_iterations", 0) > 0: gs.raise_exception("Noslip is not supported yet when requires_grad is True.") + if getattr(self._options, "sparse_solve", False): + gs.raise_exception("Sparse solve is not supported yet when requires_grad is True.") + # when the migration is finished, we will remove the about two lines self._func_vel_at_point = func_vel_at_point self._func_apply_coupling_force = func_apply_coupling_force @@ -816,6 +821,9 @@ def _init_sdf(self): def _init_collider(self): self.collider = Collider(self) + if self.sim.options.requires_grad: + self._static_rigid_sim_config.max_contact_pairs = self.collider._collider_state.contact_data.geom_a.shape[0] + if self.collider._collider_static_config.has_terrain: link_idx_ = next( i for i, _type in enumerate(ti_to_numpy(self.geoms_info.type)) if _type == gs.GEOM_TYPE.TERRAIN @@ -1241,6 +1249,11 @@ def substep_pre_coupling_grad(self, f): contact_island_state=self.constraint_solver.contact_island.contact_island_state, ) + if not self._disable_constraint: + dL_dqacc = self.dofs_state.acc.grad.to_numpy() + self.constraint_solver.backward(dL_dqacc) + pass + # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc]. 
diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index a20a254636..40f0bc9e93 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -1813,6 +1813,8 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta): max_n_geoms_per_entity: int = -1 n_links: int = -1 n_geoms: int = -1 + n_dofs: int = -1 + max_contact_pairs: int = -1 # =========================================== DataManager =========================================== From 123f228443c4cbe14d3460189fc489d38cea4545 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Thu, 4 Dec 2025 02:02:35 -0800 Subject: [PATCH 2/9] stash --- examples/diffrigid/pingpong.py | 103 +++++++++++ .../entities/rigid_entity/rigid_entity.py | 2 +- .../solvers/rigid/constraint_solver_decomp.py | 55 ++++-- .../solvers/rigid/rigid_solver_decomp.py | 164 +++++++++++++++--- genesis/utils/array_class.py | 11 +- 5 files changed, 292 insertions(+), 43 deletions(-) create mode 100644 examples/diffrigid/pingpong.py diff --git a/examples/diffrigid/pingpong.py b/examples/diffrigid/pingpong.py new file mode 100644 index 0000000000..e1751cf5bc --- /dev/null +++ b/examples/diffrigid/pingpong.py @@ -0,0 +1,103 @@ +import torch +import genesis as gs + +show_viewer = True + +gs.init(precision="32", logging_level="warn", backend=gs.cpu) + +dt = 1e-2 +horizon = 200 +substeps = 4 +goal_pos = gs.tensor([0.0, 0.1, -0.1]) + +scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True), + rigid_options=gs.options.RigidOptions( + use_gjk_collision=True, + enable_joint_limit=False, + ), + show_viewer=show_viewer, +) + +ball = scene.add_entity( + gs.morphs.Sphere( + pos=(0, 0, 0.5), + radius=0.1, + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), + material=gs.materials.Rigid( + rho=0.001, + ) +) +# if show_viewer: +# target = scene.add_entity( +# gs.morphs.Sphere( +# pos=goal_pos.cpu().numpy().tolist(), +# radius=0.1, +# ), +# surface=gs.surfaces.Default( +# 
color=(0.0, 0.9, 0.0, 0.5), +# ), +# ) + +racket = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(5.0, 5.0, 0.01), + #fixed=True, + ), + surface=gs.surfaces.Default( + color=(0.0, 0.0, 0.9, 1.0), + ), + material=gs.materials.Rigid( + gravity_compensation=1, + ) +) + +scene.build() + +num_iter = 200 +lr = 1e-4 + +init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True) +init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) +optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) + +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + +prev_loss = float('inf') +for iter in range(num_iter): + scene.reset() + + racket.set_pos(init_pos) + racket.set_quat(init_quat) + #ball.set_dofs_velocity(gs.tensor([0, 0, -2.0, 0, 0, 0])) + + losses = [] + for i in range(horizon): + scene.step() + # ball_state = ball.get_state() + # ball_pos = ball_state.pos + # losses.append(torch.abs(ball_pos - goal_pos).sum()) + # if show_viewer: + # target.set_pos(goal_pos) + + ball_state = ball.get_state() + ball_pos = ball_state.pos + loss = torch.abs(ball_pos - goal_pos).sum() + # loss = sum(losses) / len(losses) + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + optimizer.step() + scheduler.step() + + with torch.no_grad(): + init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + + print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g}") + prev_loss = loss.item() + +# assert_allclose(loss, 0.0, atol=1e-2) diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index d57776a15c..7e8db78c42 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -2330,7 +2330,7 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, * @gs.assert_built def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, 
velocity_grad): - dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) + dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, velocity_grad.data) @gs.assert_built diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index 02abceb498..79ad804b52 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -174,6 +174,20 @@ def add_inequality_constraints(self): static_rigid_sim_config=solver._static_rigid_sim_config, ) + def add_inequality_constraints_grad(self): + solver = self._solver + add_inequality_constraints.grad( + links_info=solver.links_info, + links_state=solver.links_state, + dofs_state=solver.dofs_state, + dofs_info=solver.dofs_info, + joints_info=solver.joints_info, + constraint_state=self.constraint_state, + collider_state=self._collider._collider_state, + rigid_global_info=solver._rigid_global_info, + static_rigid_sim_config=solver._static_rigid_sim_config, + ) + def resolve(self): solver = self._solver @@ -438,15 +452,20 @@ def add_collision_constraints( static_rigid_sim_config: ti.template(), ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(dofs_state.ctrl_mode.shape[1]): + for i_b, i_0 in ( + ti.ndrange(dofs_state.ctrl_mode.shape[1], 1) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.ndrange(dofs_state.ctrl_mode.shape[1], static_rigid_sim_config.max_contact_pairs) + ): EPS = rigid_global_info.EPS[None] n_dofs = dofs_state.ctrl_mode.shape[0] - for i_col in ( + for i_1 in ( range(collider_state.n_contacts[i_b]) if ti.static(not static_rigid_sim_config.is_backward) - else ti.static(range(static_rigid_sim_config.max_contact_pairs)) + else ti.static(range(1)) ): + i_col = i_1 if ti.static(not 
static_rigid_sim_config.is_backward) else i_0 if i_col < collider_state.n_contacts[i_b]: contact_data_link_a = collider_state.contact_data.link_a[i_col, i_b] contact_data_link_b = collider_state.contact_data.link_b[i_col, i_b] @@ -472,7 +491,7 @@ def add_collision_constraints( d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) n = d * contact_data_friction - contact_data_normal - n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) + n_con = i_col * 4 + i # + constraint_state.n_constraints[i_b] if ti.static(static_rigid_sim_config.sparse_solve): for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b] @@ -495,9 +514,8 @@ def add_collision_constraints( link = link_b # FIXME: Set number of iterations to look for parent to certain value for autodiff - # while link > -1: for i_parent in ( - range(20) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(20)) + range(20) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(1)) ): if link > -1: link_maybe_batch = ( @@ -531,6 +549,10 @@ def add_collision_constraints( link = links_info.parent_idx[link_maybe_batch] + if ti.static(static_rigid_sim_config.is_backward): + if i_parent == 4 and link > -1: + print("Warning: Number of parents is too large for backward mode in add_collision_constraints") + if ti.static(static_rigid_sim_config.sparse_solve): constraint_state.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs imp, aref = gu.imp_aref( @@ -545,6 +567,9 @@ def add_collision_constraints( constraint_state.aref[n_con, i_b] = aref constraint_state.efc_D[n_con, i_b] = 1 / diag + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(dofs_state.ctrl_mode.shape[1]): + constraint_state.n_constraints[i_b] += 4 * collider_state.n_contacts[i_b] @ti.func def func_equality_connect( @@ -811,15 +836,15 @@ def add_inequality_constraints( rigid_global_info: 
array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - add_frictionloss_constraints( - links_info=links_info, - joints_info=joints_info, - dofs_info=dofs_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - constraint_state=constraint_state, - static_rigid_sim_config=static_rigid_sim_config, - ) + # add_frictionloss_constraints( + # links_info=links_info, + # joints_info=joints_info, + # dofs_info=dofs_info, + # dofs_state=dofs_state, + # rigid_global_info=rigid_global_info, + # constraint_state=constraint_state, + # static_rigid_sim_config=static_rigid_sim_config, + # ) if ti.static(static_rigid_sim_config.enable_collision): add_collision_constraints( links_info=links_info, diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 44b6917efc..46b441a89e 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -272,6 +272,7 @@ def build(self): n_links=self._n_links, n_geoms=self._n_geoms, n_dofs=self._n_dofs, + n_entities=self._n_entities, # max_contact_pairs=self.collider._collider_state.contact_data.geom_a.shape[0], ) self._static_rigid_sim_config = array_class.StructRigidSimStaticConfig(**static_rigid_sim_config) @@ -1195,7 +1196,31 @@ def substep_pre_coupling_grad(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, ) - kernel_update_cartesian_space.grad( + kernel_update_geoms.grad( + envs_idx, + entities_info=self.entities_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + links_state=self.links_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) + kernel_COM_links.grad( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + 
dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) + kernel_forward_kinematics.grad( links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, @@ -1209,6 +1234,20 @@ def substep_pre_coupling_grad(self, f): static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=False, ) + # kernel_update_cartesian_space.grad( + # links_state=self.links_state, + # links_info=self.links_info, + # joints_state=self.joints_state, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # geoms_state=self.geoms_state, + # geoms_info=self.geoms_info, + # entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # force_update_fixed_geoms=False, + # ) is_grad_valid = kernel_begin_backward_substep( f=f, @@ -1250,25 +1289,47 @@ def substep_pre_coupling_grad(self, f): ) if not self._disable_constraint: + # Solver backward dL_dqacc = self.dofs_state.acc.grad.to_numpy() - self.constraint_solver.backward(dL_dqacc) - pass + self.dofs_state.acc.grad.fill(0.0) - # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, - # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). - # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc]. - # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through - # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. 
This copy operation - # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel. - # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a - # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid - # the data access violation. However, both of these require major changes. - kernel_compute_qacc.grad( - dofs_state=self.dofs_state, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - ) + self.constraint_solver.backward(dL_dqacc) + dL_dM = self.constraint_solver.constraint_state.dL_dM.to_numpy() + dL_djac = self.constraint_solver.constraint_state.dL_djac.to_numpy() + dL_daref = self.constraint_solver.constraint_state.dL_daref.to_numpy() + dL_defc_D = self.constraint_solver.constraint_state.dL_defc_D.to_numpy() + dL_dforce = self.constraint_solver.constraint_state.dL_dforce.to_numpy() + + self._rigid_global_info.mass_mat.grad.from_numpy(dL_dM) + self.constraint_solver.constraint_state.jac.grad.from_numpy(dL_djac) + self.constraint_solver.constraint_state.aref.grad.from_numpy(dL_daref) + self.constraint_solver.constraint_state.efc_D.grad.from_numpy(dL_defc_D) + self.dofs_state.force.grad.from_numpy(dL_dforce) + + self.constraint_solver.constraint_state.n_constraints.fill(0) + self.constraint_solver.add_inequality_constraints_grad() + + # Collider backward + dL_dcontact_pos = self.collider._collider_state.contact_data.pos.grad.to_numpy() + dL_dcontact_normal = self.collider._collider_state.contact_data.normal.grad.to_numpy() + dL_dcontact_penetration = self.collider._collider_state.contact_data.penetration.grad.to_numpy() + self.collider.backward(dL_dcontact_pos, dL_dcontact_normal, dL_dcontact_penetration) + else: + # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, + # which 
is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). + # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc]. + # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through + # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation + # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel. + # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a + # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid + # the data access violation. However, both of these require major changes. + kernel_compute_qacc.grad( + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) kernel_copy_acc( f=f, dofs_state=self.dofs_state, @@ -3915,7 +3976,7 @@ def func_solve_mass_batched( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) if ti.static(static_rigid_sim_config.use_hibernation) - else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + else ti.static(range(static_rigid_sim_config.n_entities)) ) ): if func_check_index_range(i_0, 0, rigid_global_info.n_awake_entities[i_b], static_rigid_sim_config.is_backward): @@ -4241,6 +4302,65 @@ def kernel_update_cartesian_space( force_update_fixed_geoms=force_update_fixed_geoms, ) +@ti.kernel +def kernel_forward_kinematics( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: 
array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(links_state.pos.shape[1]): + func_forward_kinematics( + i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + +@ti.kernel +def kernel_COM_links( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(links_state.pos.shape[1]): + func_COM_links( + i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.func def func_update_cartesian_space( @@ -4967,7 +5087,7 @@ def func_forward_kinematics( else ( ti.static(range(static_rigid_sim_config.max_n_awake_entities)) if ti.static(static_rigid_sim_config.use_hibernation) - else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + else ti.static(range(static_rigid_sim_config.n_entities)) ) ): if func_check_index_range( @@ -5017,7 
+5137,7 @@ def func_forward_velocity( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) if ti.static(static_rigid_sim_config.use_hibernation) - else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + else ti.static(range(static_rigid_sim_config.n_entities)) ) ): if func_check_index_range( @@ -5446,7 +5566,7 @@ def func_update_geoms( ): i_g = i_1 + entities_info.geom_start[i_e] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 if func_check_index_range( - i_g, entities_info.geom_start[i_e], entities_info.geom_end[i_e], static_rigid_sim_config.is_backward + i_g, entities_info.geom_start[i_e], entities_info.geom_end[i_e], static_rigid_sim_config.use_hibernation ): if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: ( diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 40f0bc9e93..4f182824fd 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -286,15 +286,15 @@ def get_constraint_state(constraint_solver, solver): efc_AR=V(dtype=gs.ti_float, shape=efc_AR_shape), active=V(dtype=gs.ti_bool, shape=(len_constraints_, _B)), prev_active=V(dtype=gs.ti_bool, shape=(len_constraints_, _B)), - diag=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), - aref=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), + diag=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), + aref=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), Jaref=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), - efc_frictionloss=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), + efc_frictionloss=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), efc_force=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), - efc_D=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), + efc_D=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), 
jv=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), quad=V(dtype=gs.ti_float, shape=(len_constraints_, 3, _B)), - jac=V(dtype=gs.ti_float, shape=jac_shape), + jac=V(dtype=gs.ti_float, shape=jac_shape, needs_grad=solver._requires_grad), jac_relevant_dofs=V(dtype=gs.ti_int, shape=jac_relevant_dofs_shape), jac_n_relevant_dofs=V(dtype=gs.ti_int, shape=jac_n_relevant_dofs_shape), # Backward gradients @@ -1814,6 +1814,7 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta): n_links: int = -1 n_geoms: int = -1 n_dofs: int = -1 + n_entities: int = -1 max_contact_pairs: int = -1 From 8c33184cf8b2658ed922d3ca078efc3e4288e229 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Fri, 5 Dec 2025 16:32:06 -0800 Subject: [PATCH 3/9] make grad zero in reset_grad --- examples/differentiable_rigid.py | 92 +++++++++++++++ examples/differentiable_rigid_demo_1.py | 92 +++++++++++++++ examples/diffrigid/one_step.py | 96 +++++++++++++++ examples/diffrigid/slide_ball.py | 110 ++++++++++++++++++ .../solvers/rigid/rigid_solver_decomp.py | 29 +++++ 5 files changed, 419 insertions(+) create mode 100644 examples/differentiable_rigid.py create mode 100644 examples/differentiable_rigid_demo_1.py create mode 100644 examples/diffrigid/one_step.py create mode 100644 examples/diffrigid/slide_ball.py diff --git a/examples/differentiable_rigid.py b/examples/differentiable_rigid.py new file mode 100644 index 0000000000..9ee90ac18e --- /dev/null +++ b/examples/differentiable_rigid.py @@ -0,0 +1,92 @@ +import torch +import genesis as gs + +show_viewer = False + +gs.init(precision="32", logging_level="info") + +dt = 1e-2 +horizon = 100 +substeps = 1 +goal_pos = gs.tensor([0.7, 1.0, 0.05]) +goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9]) +goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True) + +scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, gravity=(0, 0, -1)), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + 
enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_contact_island=False, + use_hibernation=False, + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.5, -0.15, 2.42), + camera_lookat=(0.5, 0.5, 0.1), + ), + show_viewer=show_viewer, +) + +box = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), +) +if show_viewer: + target = scene.add_entity( + gs.morphs.Box( + pos=goal_pos, + quat=goal_quat, + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.0, 0.9, 0.0, 0.5), + ), + ) + +scene.build() + +num_iter = 200 +lr = 1e-2 + +init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True) +init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) +optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) + +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + +for iter in range(num_iter): + scene.reset() + + box.set_pos(init_pos) + box.set_quat(init_quat) + + loss = 0 + for i in range(horizon): + scene.step() + if show_viewer: + target.set_pos(goal_pos) + target.set_quat(goal_quat) + + box_state = box.get_state() + box_pos = box_state.pos + box_quat = box_state.quat + loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum() + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + optimizer.step() + scheduler.step() + + with torch.no_grad(): + init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + + print("loss: ", loss.item()) + +# assert_allclose(loss, 0.0, atol=1e-2) diff --git a/examples/differentiable_rigid_demo_1.py b/examples/differentiable_rigid_demo_1.py new file mode 100644 index 0000000000..9ee90ac18e --- /dev/null +++ b/examples/differentiable_rigid_demo_1.py @@ -0,0 +1,92 @@ +import torch +import genesis as gs + +show_viewer = False + +gs.init(precision="32", 
logging_level="info") + +dt = 1e-2 +horizon = 100 +substeps = 1 +goal_pos = gs.tensor([0.7, 1.0, 0.05]) +goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9]) +goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True) + +scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, gravity=(0, 0, -1)), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_contact_island=False, + use_hibernation=False, + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.5, -0.15, 2.42), + camera_lookat=(0.5, 0.5, 0.1), + ), + show_viewer=show_viewer, +) + +box = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), +) +if show_viewer: + target = scene.add_entity( + gs.morphs.Box( + pos=goal_pos, + quat=goal_quat, + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.0, 0.9, 0.0, 0.5), + ), + ) + +scene.build() + +num_iter = 200 +lr = 1e-2 + +init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True) +init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) +optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) + +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + +for iter in range(num_iter): + scene.reset() + + box.set_pos(init_pos) + box.set_quat(init_quat) + + loss = 0 + for i in range(horizon): + scene.step() + if show_viewer: + target.set_pos(goal_pos) + target.set_quat(goal_quat) + + box_state = box.get_state() + box_pos = box_state.pos + box_quat = box_state.quat + loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum() + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + optimizer.step() + scheduler.step() + + with torch.no_grad(): + init_quat.data = init_quat / 
torch.norm(init_quat, dim=-1, keepdim=True) + + print("loss: ", loss.item()) + +# assert_allclose(loss, 0.0, atol=1e-2) diff --git a/examples/diffrigid/one_step.py b/examples/diffrigid/one_step.py new file mode 100644 index 0000000000..a7afd6cbf7 --- /dev/null +++ b/examples/diffrigid/one_step.py @@ -0,0 +1,96 @@ +""" +One step optimization for the basic debugging of differentiable rigid simulation. +""" +import torch +import genesis as gs +import matplotlib.pyplot as plt + +show_viewer = False + +gs.init(precision="32", logging_level="warn", backend=gs.cpu) + +dt = 1e-2 +horizon = 1 +substeps = 1 +goal_pos = gs.tensor([0.0, 0.0, 1.0]) + +scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True), + rigid_options=gs.options.RigidOptions( + use_gjk_collision=True, + enable_joint_limit=False, + ), + show_viewer=show_viewer, +) + +ball = scene.add_entity( + gs.morphs.Sphere( + pos=(0, 0, 0.109), # small penetration with ground + radius=0.1, + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), +) + +ground = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(5.0, 5.0, 0.02), + #fixed=True, + ), + surface=gs.surfaces.Default( + color=(0.0, 0.0, 0.9, 1.0), + ), + material=gs.materials.Rigid( + gravity_compensation=1, + ) +) + +scene.build() + +num_iter = 400 +lr = 1e-4 + +init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True) +#init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) +optimizer = torch.optim.Adam([init_pos], lr=lr) + +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + +prev_loss = float('inf') +losses = [] +for iter in range(num_iter): + scene.reset() + + ground.set_pos(init_pos) + # ground.set_quat(init_quat) + + for i in range(horizon): + scene.step() + + ball_state = ball.get_state() + ball_pos = ball_state.pos + loss = torch.abs(ball_pos - goal_pos).sum() + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all 
the way back to tensor input + optimizer.step() + scheduler.step() + + grad_norm = torch.nn.utils.clip_grad_norm_(init_pos.grad, 1.0) + + # with torch.no_grad(): + # init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + # with torch.no_grad(): + # init_pos.data[0] = 0.0 + # init_pos.data[1] = 0.0 + + print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g}") + prev_loss = loss.item() + + losses.append(loss.item()) + + plt.plot(losses) + plt.savefig("loss.png") + plt.close() \ No newline at end of file diff --git a/examples/diffrigid/slide_ball.py b/examples/diffrigid/slide_ball.py new file mode 100644 index 0000000000..87530b86f8 --- /dev/null +++ b/examples/diffrigid/slide_ball.py @@ -0,0 +1,110 @@ +import torch +import genesis as gs +import matplotlib.pyplot as plt + +show_viewer = False + +gs.init(precision="32", logging_level="warn", backend=gs.cpu) + +dt = 1e-2 +horizon = 100 +substeps = 4 +goal_pos = gs.tensor([0.0, 0.1, -0.1]) + +scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True), + rigid_options=gs.options.RigidOptions( + use_gjk_collision=True, + enable_joint_limit=False, + ), + show_viewer=show_viewer, +) + +ball = scene.add_entity( + gs.morphs.Sphere( + pos=(0, 0, 0.11), + radius=0.1, + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), + material=gs.materials.Rigid( + rho=0.001, + ) +) +# if show_viewer: +# target = scene.add_entity( +# gs.morphs.Sphere( +# pos=goal_pos.cpu().numpy().tolist(), +# radius=0.1, +# ), +# surface=gs.surfaces.Default( +# color=(0.0, 0.9, 0.0, 0.5), +# ), +# ) + +racket = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(5.0, 5.0, 0.02), + #fixed=True, + ), + surface=gs.surfaces.Default( + color=(0.0, 0.0, 0.9, 1.0), + ), + material=gs.materials.Rigid( + gravity_compensation=1, + ) +) + +scene.build() + +num_iter = 200 +lr = 1e-4 + +init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True) +init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], 
requires_grad=True) +optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) + +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + +prev_loss = float('inf') +losses = [] +for iter in range(num_iter): + scene.reset() + + racket.set_pos(init_pos) + racket.set_quat(init_quat) + #ball.set_dofs_velocity(gs.tensor([0, 0, -2.0, 0, 0, 0])) + + for i in range(horizon): + scene.step() + # ball_state = ball.get_state() + # ball_pos = ball_state.pos + # losses.append(torch.abs(ball_pos - goal_pos).sum()) + # if show_viewer: + # target.set_pos(goal_pos) + + ball_state = ball.get_state() + ball_pos = ball_state.pos + loss = torch.abs(ball_pos - goal_pos).sum() + # loss = sum(losses) / len(losses) + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + optimizer.step() + scheduler.step() + + with torch.no_grad(): + init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + + print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g}") + prev_loss = loss.item() + + losses.append(loss.item()) + + plt.plot(losses) + plt.savefig("loss.png") + plt.close() + +# assert_allclose(loss, 0.0, atol=1e-2) diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 0ae8e901d9..5dd209b406 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -1467,6 +1467,24 @@ def reset_grad(self): entity.reset_grad() self._queried_states.clear() + # zero grad + for state in [ + self._rigid_global_info, + self.links_state, + self.dofs_state, + self.geoms_state, + self.joints_state, + self.entities_state, + self.constraint_solver.constraint_state, + self.collider._collider_state.diff_contact_input, + self.collider._collider_state.contact_data, + self._rigid_adjoint_cache, + ]: + for attr in state.__dict__.values(): + if hasattr(attr, 'grad') and attr.grad is not None: + 
attr.grad.fill(0.0) + + def update_geoms_render_T(self): kernel_update_geoms_render_T( self._geoms_render_T, @@ -7986,3 +8004,14 @@ def func_write_and_read_field_if(field: array_class.V_ANNOTATION, I, value, cond def func_check_index_range(idx: ti.i32, min: ti.i32, max: ti.i32, cond: ti.template()): # Conditionally check if the index is in the range [min, max) to save computational cost return (idx >= min and idx < max) if ti.static(cond) else True + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_zero_grad( + links_state: array_class.LinksState, + dofs_state: array_class.DofsState, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + pass \ No newline at end of file From bbdfffb492edc689263f9604c5081e42bf4a2b3e Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Mon, 8 Dec 2025 19:30:09 -0800 Subject: [PATCH 4/9] 1. added one another rigid adjoint cache to check integrity of the backward pass 2. fixed bug in save_ckpt, needed to set copy=True in converting taichi field to numpy array --- .../solvers/rigid/rigid_solver_decomp.py | 133 +++++++++++++++--- genesis/utils/array_class.py | 12 +- 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 5dd209b406..62a57cfdbd 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -316,7 +316,6 @@ def build(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info - self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs self.awake_dofs = self._rigid_global_info.awake_dofs @@ -329,6 +328,8 @@ def build(self): self.links_state_adjoint_cache = self.data_manager.links_state_adjoint_cache 
self.joints_state_adjoint_cache = self.data_manager.joints_state_adjoint_cache self.geoms_state_adjoint_cache = self.data_manager.geoms_state_adjoint_cache + self._rigid_adjoint_cache_fw = self.data_manager.rigid_adjoint_cache_fw + self._rigid_adjoint_cache_bw = self.data_manager.rigid_adjoint_cache_bw self._init_mass_mat() self._init_dof_fields() @@ -868,8 +869,9 @@ def substep(self, f): kernel_save_adjoint_cache( f=f, dofs_state=self.dofs_state, + constraint_state=self.constraint_solver.constraint_state, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) @@ -917,8 +919,9 @@ def substep(self, f): kernel_save_adjoint_cache( f=f + 1, dofs_state=self.dofs_state, + constraint_state=self.constraint_solver.constraint_state, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) @@ -1172,13 +1175,14 @@ def substep_pre_coupling_grad(self, f): dofs_info=self.dofs_info, geoms_state=self.geoms_state, geoms_info=self.geoms_info, + constraint_state=self.constraint_solver.constraint_state, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, dofs_state_adjoint_cache=self.dofs_state_adjoint_cache, links_state_adjoint_cache=self.links_state_adjoint_cache, joints_state_adjoint_cache=self.joints_state_adjoint_cache, geoms_state_adjoint_cache=self.geoms_state_adjoint_cache, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) self.substep(f) @@ -1249,7 +1253,7 @@ def substep_pre_coupling_grad(self, f): # force_update_fixed_geoms=False, # ) - is_grad_valid = kernel_begin_backward_substep( + errno = kernel_begin_backward_substep( f=f, links_state=self.links_state, 
links_info=self.links_info, @@ -1265,11 +1269,15 @@ def substep_pre_coupling_grad(self, f): links_state_adjoint_cache=self.links_state_adjoint_cache, joints_state_adjoint_cache=self.joints_state_adjoint_cache, geoms_state_adjoint_cache=self.geoms_state_adjoint_cache, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache_fw=self._rigid_adjoint_cache_fw, + rigid_adjoint_cache_bw=self._rigid_adjoint_cache_bw, static_rigid_sim_config=self._static_rigid_sim_config, ) - if not is_grad_valid: - gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") + match errno: + case 1: + gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") + case 2: + gs.raise_exception(f"The backward computation result does not match the forward computation result at step {self._sim.cur_step_global}") kernel_step_2.grad( dofs_state=self.dofs_state, @@ -1293,6 +1301,9 @@ def substep_pre_coupling_grad(self, f): dL_dqacc = self.dofs_state.acc.grad.to_numpy() self.dofs_state.acc.grad.fill(0.0) + qacc_ws = self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy()[f] + self.constraint_solver.constraint_state.qacc_ws.from_numpy(qacc_ws) + self.constraint_solver.backward(dL_dqacc) dL_dM = self.constraint_solver.constraint_state.dL_dM.to_numpy() dL_djac = self.constraint_solver.constraint_state.dL_djac.to_numpy() @@ -1333,7 +1344,7 @@ def substep_pre_coupling_grad(self, f): kernel_copy_acc( f=f, dofs_state=self.dofs_state, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) @@ -1478,7 +1489,8 @@ def reset_grad(self): self.constraint_solver.constraint_state, self.collider._collider_state.diff_contact_input, self.collider._collider_state.contact_data, - self._rigid_adjoint_cache, + self._rigid_adjoint_cache_fw, + self._rigid_adjoint_cache_bw, ]: for attr in state.__dict__.values(): if hasattr(attr, 'grad') and 
attr.grad is not None: @@ -1586,20 +1598,32 @@ def save_ckpt(self, ckpt_name): if ckpt_name not in self._ckpt: self._ckpt[ckpt_name] = dict() - self._ckpt[ckpt_name]["qpos"] = ti_to_numpy(self._rigid_adjoint_cache.qpos) - self._ckpt[ckpt_name]["dofs_vel"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_vel) - self._ckpt[ckpt_name]["dofs_acc"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_acc) + self._ckpt[ckpt_name]["qpos"] = ti_to_numpy(self._rigid_adjoint_cache_fw.qpos, copy=True) + self._ckpt[ckpt_name]["dofs_vel"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_vel, copy=True) + self._ckpt[ckpt_name]["dofs_acc"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_acc, copy=True) + self._ckpt[ckpt_name]["dofs_acc_smooth"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_acc_smooth, copy=True) + self._ckpt[ckpt_name]["solver_qacc_ws"] = ti_to_numpy(self._rigid_adjoint_cache_fw.solver_qacc_ws, copy=True) for entity in self._entities: entity.save_ckpt(ckpt_name) def load_ckpt(self, ckpt_name): + # Load adjoint cache for backward pass + self._rigid_adjoint_cache_bw.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"]) + self._rigid_adjoint_cache_bw.dofs_vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"]) + self._rigid_adjoint_cache_bw.dofs_acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"]) + self._rigid_adjoint_cache_bw.dofs_acc_smooth.from_numpy(self._ckpt[ckpt_name]["dofs_acc_smooth"]) + self._rigid_adjoint_cache_bw.solver_qacc_ws.from_numpy(self._ckpt[ckpt_name]["solver_qacc_ws"]) + # Set first frame self._rigid_global_info.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"][0]) self.dofs_state.vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"][0]) self.dofs_state.acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"][0]) + self.dofs_state.acc_smooth.from_numpy(self._ckpt[ckpt_name]["dofs_acc_smooth"][0]) + self.constraint_solver.constraint_state.qacc_ws.from_numpy(self._ckpt[ckpt_name]["solver_qacc_ws"][0]) if not self._enable_mujoco_compatibility: + envs_idx = 
self._scene._sanitize_envs_idx(None) kernel_update_cartesian_space( links_state=self.links_state, links_info=self.links_info, @@ -1614,6 +1638,16 @@ def load_ckpt(self, ckpt_name): static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=False, ) + kernel_forward_velocity( + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) for entity in self._entities: entity.load_ckpt(ckpt_name) @@ -6681,17 +6715,19 @@ def func_copy_next_to_curr_grad( def kernel_save_adjoint_cache( f: ti.int32, dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), ): - func_save_adjoint_cache(f, dofs_state, rigid_global_info, rigid_adjoint_cache, static_rigid_sim_config) + func_save_adjoint_cache(f, dofs_state, constraint_state, rigid_global_info, rigid_adjoint_cache, static_rigid_sim_config) @ti.func def func_save_adjoint_cache( f: ti.int32, dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), @@ -6704,6 +6740,8 @@ def func_save_adjoint_cache( for i_d, i_b in ti.ndrange(n_dofs, _B): rigid_adjoint_cache.dofs_vel[f, i_d, i_b] = dofs_state.vel[i_d, i_b] rigid_adjoint_cache.dofs_acc[f, i_d, i_b] = dofs_state.acc[i_d, i_b] + rigid_adjoint_cache.dofs_acc_smooth[f, i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + rigid_adjoint_cache.solver_qacc_ws[f, i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_q, i_b in ti.ndrange(n_qs, _B): @@ 
-6714,6 +6752,7 @@ def func_save_adjoint_cache( def func_load_adjoint_cache( f: ti.int32, dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), @@ -6726,6 +6765,8 @@ def func_load_adjoint_cache( for i_d, i_b in ti.ndrange(n_dofs, _B): dofs_state.vel[i_d, i_b] = rigid_adjoint_cache.dofs_vel[f, i_d, i_b] dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b] + dofs_state.acc_smooth[i_d, i_b] = rigid_adjoint_cache.dofs_acc_smooth[f, i_d, i_b] + constraint_state.qacc_ws[i_d, i_b] = rigid_adjoint_cache.solver_qacc_ws[f, i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_q, i_b in ti.ndrange(n_qs, _B): @@ -6743,6 +6784,7 @@ def kernel_prepare_backward_substep( dofs_info: array_class.DofsInfo, geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, + constraint_state: array_class.ConstraintState, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, dofs_state_adjoint_cache: array_class.DofsState, @@ -6756,6 +6798,7 @@ def kernel_prepare_backward_substep( func_load_adjoint_cache( f=f, dofs_state=dofs_state, + constraint_state=constraint_state, rigid_global_info=rigid_global_info, rigid_adjoint_cache=rigid_adjoint_cache, static_rigid_sim_config=static_rigid_sim_config, @@ -6781,6 +6824,16 @@ def kernel_prepare_backward_substep( static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, ) + func_forward_velocity( + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) # FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names. 
# Save results of [update_cartesian_space] to adjoint cache @@ -6814,20 +6867,29 @@ def kernel_begin_backward_substep( links_state_adjoint_cache: array_class.LinksState, joints_state_adjoint_cache: array_class.JointsState, geoms_state_adjoint_cache: array_class.GeomsState, - rigid_adjoint_cache: array_class.RigidAdjointCache, + rigid_adjoint_cache_fw: array_class.RigidAdjointCache, + rigid_adjoint_cache_bw: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), ) -> ti.i32: + errno = 0 is_grad_valid = func_is_grad_valid( rigid_global_info=rigid_global_info, dofs_state=dofs_state, static_rigid_sim_config=static_rigid_sim_config, ) - if is_grad_valid: + # Check the integrity of the next frame's adjoint cache as it is the computation result of the current frame + is_cache_valid = func_check_cache_integrity( + f=f + 1, + rigid_adjoint_cache_fw=rigid_adjoint_cache_fw, + rigid_adjoint_cache_bw=rigid_adjoint_cache_bw, + static_rigid_sim_config=static_rigid_sim_config, + ) + if is_grad_valid and is_cache_valid: func_copy_next_to_curr_grad( f=f, dofs_state=dofs_state, rigid_global_info=rigid_global_info, - rigid_adjoint_cache=rigid_adjoint_cache, + rigid_adjoint_cache=rigid_adjoint_cache_fw, static_rigid_sim_config=static_rigid_sim_config, ) @@ -6845,8 +6907,12 @@ def kernel_begin_backward_substep( geoms_state_adjoint_cache=geoms_state_adjoint_cache, static_rigid_sim_config=static_rigid_sim_config, ) + elif not is_grad_valid: + errno = 1 + elif not is_cache_valid: + errno = 2 - return is_grad_valid + return errno @ti.func @@ -6869,6 +6935,37 @@ def func_is_grad_valid( return is_valid +@ti.func +def func_check_cache_integrity( + f: ti.int32, + rigid_adjoint_cache_fw: array_class.RigidAdjointCache, + rigid_adjoint_cache_bw: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + is_valid = True + n_qs = rigid_adjoint_cache_fw.qpos.shape[1] + n_dofs = rigid_adjoint_cache_fw.dofs_vel.shape[1] + _B = rigid_adjoint_cache_fw.qpos.shape[2] + + 
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(n_qs, _B): + if rigid_adjoint_cache_fw.qpos[f, i_q, i_b] != rigid_adjoint_cache_bw.qpos[f, i_q, i_b]: + is_valid = False + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + if rigid_adjoint_cache_fw.dofs_vel[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_vel[f, i_d, i_b]: + is_valid = False + if rigid_adjoint_cache_fw.dofs_acc[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_acc[f, i_d, i_b]: + is_valid = False + if rigid_adjoint_cache_fw.dofs_acc_smooth[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_acc_smooth[f, i_d, i_b]: + is_valid = False + if rigid_adjoint_cache_fw.solver_qacc_ws[f, i_d, i_b] != rigid_adjoint_cache_bw.solver_qacc_ws[f, i_d, i_b]: + is_valid = False + + return is_valid + + @ti.func def func_copy_cartesian_space( dofs_state: array_class.DofsState, diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 5cb75dae2a..3bc2eea24c 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -1776,6 +1776,11 @@ class StructRigidAdjointCache(metaclass=BASE_METACLASS): qpos: V_ANNOTATION dofs_vel: V_ANNOTATION dofs_acc: V_ANNOTATION + # We also store the initial solutions (acc_smooth, qacc_ws) for the constraint solver to use in the backward pass. + # For [acc_smooth], even though it could be reproduced during the backward pass and thus we do not need to store it, + # we do it to compare the reproduced value with the stored one to ensure the integrity of the backward pass. 
+ dofs_acc_smooth: V_ANNOTATION + solver_qacc_ws: V_ANNOTATION def get_rigid_adjoint_cache(solver): @@ -1786,6 +1791,8 @@ def get_rigid_adjoint_cache(solver): qpos=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_qs_, solver._B), needs_grad=requires_grad), dofs_vel=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), dofs_acc=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), + dofs_acc_smooth=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B)), + solver_qacc_ws=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B)), ) @@ -1865,7 +1872,10 @@ def __init__(self, solver): self.joints_state_adjoint_cache = get_joints_state(solver) self.geoms_state_adjoint_cache = get_geoms_state(solver) - self.rigid_adjoint_cache = get_rigid_adjoint_cache(solver) + # We use a pair of adjoint cache, one of which is used for the forward pass and the other is used for the + # backward pass to check the integrity of the backward pass. + self.rigid_adjoint_cache_fw = get_rigid_adjoint_cache(solver) + self.rigid_adjoint_cache_bw = get_rigid_adjoint_cache(solver) self.errno = V_SCALAR_FROM(dtype=gs.ti_int, value=0) From f6f6746f49c89c715d96a04de1ac71d67a0f368d Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Tue, 9 Dec 2025 14:00:26 -0800 Subject: [PATCH 5/9] 1. use dofs force as differentiable input 2. 
use backward mode for rigid body sim only when computing gradient --- .../entities/rigid_entity/rigid_entity.py | 23 ++++++++++- genesis/engine/simulator.py | 7 +++- .../solvers/rigid/rigid_solver_decomp.py | 41 ++++++++++++++++--- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index c2cef6ec54..922592a842 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -117,7 +117,7 @@ def __init__( self._load_model() # Initialize target variables and checkpoint - self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity") + self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity", "control_dofs_force") self._tgt = dict() self._tgt_buffer = list() self._ckpt = dict() @@ -1661,6 +1661,8 @@ def process_input(self, in_backward=False): self.set_quat(**data_kwargs) case "set_dofs_velocity": self.set_dofs_velocity(**data_kwargs) + case "control_dofs_force": + self.control_dofs_force(**data_kwargs) case _: gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") @@ -1693,6 +1695,16 @@ def process_input_grad(self): data_kwargs["dofs_idx_local"], data_kwargs["envs_idx"], ) + + case "control_dofs_force": + force = data_kwargs.pop("force") + if force.requires_grad: + force._backward_from_ti( + self.control_dofs_force_grad, + data_kwargs["dofs_idx_local"], + data_kwargs["envs_idx"], + ) + case _: gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") @@ -2355,6 +2367,7 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer self._solver.set_dofs_position(position, dofs_idx, envs_idx) @gs.assert_built + @tracked def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None): """ Control the entity's dofs' motor force. This is used for force/torque control. 
@@ -2371,6 +2384,14 @@ def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None): dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) self._solver.control_dofs_force(force, dofs_idx, envs_idx) + @gs.assert_built + def control_dofs_force_grad(self, dofs_idx_local, envs_idx, force_grad): + dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) + self._solver.control_dofs_force_grad(dofs_idx, envs_idx, force_grad.data) + + pass + + @gs.assert_built def control_dofs_velocity(self, velocity, dofs_idx_local=None, envs_idx=None): """ diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index a269c998b4..ca6bbc0e80 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -280,7 +280,9 @@ def step(self, in_backward=False): self.save_ckpt() if self.rigid_solver.is_active: - self.rigid_solver.clear_external_force() + # In backward pass, we need to keep the external force for gradient computation + if not in_backward: + self.rigid_solver.clear_external_force() if self._cur_substep_global % RATE_CHECK_ERRNO == 0: self.rigid_solver.check_errno() @@ -294,6 +296,9 @@ def _step_grad(self): self._cur_substep_global -= 1 self.sub_step_grad(self.cur_substep_local) + if self.rigid_solver.is_active: + # Clear external force after gradient computation + self.rigid_solver.clear_external_force() self.process_input_grad() diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 62a57cfdbd..b787891abc 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -1160,9 +1160,6 @@ def substep_pre_coupling(self, f): self.substep(f) def substep_pre_coupling_grad(self, f): - # Change to backward mode - self._static_rigid_sim_config.is_backward = True - # Run forward substep again to restore this step's information, this is needed because 
we do not store info # of every substep. kernel_prepare_backward_substep( @@ -1188,6 +1185,12 @@ def substep_pre_coupling_grad(self, f): self.substep(f) # =================== Backward substep ====================== + # Change to backward mode: Note that we use forward mode in the [substep] function right above. This is because + # we need to reproduce the same data as in the forward pass for the backward pass. The backward pass should be + # logically same as the forward pass, but it is tweaked to use autodiff, and thus numerical differences could + # arise. To prevent this, we change to backward mode here. + self._static_rigid_sim_config.is_backward = True + envs_idx = self._scene._sanitize_envs_idx(None) if not self._enable_mujoco_compatibility: kernel_forward_velocity.grad( @@ -2182,6 +2185,14 @@ def control_dofs_force(self, force, dofs_idx=None, envs_idx=None): kernel_control_dofs_force(force, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + def control_dofs_force_grad(self, dofs_idx, envs_idx, force_grad): + force_grad_, dofs_idx, envs_idx = self._sanitize_io_variables( + force_grad, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True + ) + if self.n_envs == 0: + force_grad_ = force_grad_.unsqueeze(0) + kernel_control_dofs_force_grad(force_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + def control_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None): if gs.use_zerocopy: mask = (0, *indices_to_mask(dofs_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, dofs_idx) @@ -2559,11 +2570,11 @@ def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True def clear_external_force(self): if gs.use_zerocopy: - for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel): + for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel, self.dofs_state.ctrl_force): out = ti_to_torch(tensor, copy=False) out.zero_() else: - 
kernel_clear_external_force(self.links_state, self._rigid_global_info, self._static_rigid_sim_config) + kernel_clear_external_force(self.links_state, self.dofs_state, self._rigid_global_info, self._static_rigid_sim_config) def update_vgeoms(self): kernel_update_vgeoms(self.vgeoms_info, self.vgeoms_state, self.links_state, self._static_rigid_sim_config) @@ -4251,11 +4262,13 @@ def kernel_forward_dynamics_without_qacc( @ti.kernel(fastcache=gs.use_fastcache) def kernel_clear_external_force( links_state: array_class.LinksState, + dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): func_clear_external_force( links_state=links_state, + dofs_state=dofs_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, ) @@ -5907,6 +5920,7 @@ def func_apply_link_external_torque( @ti.func def func_clear_external_force( links_state: array_class.LinksState, + dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): @@ -5925,6 +5939,10 @@ def func_clear_external_force( i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for I in ti.grouped(dofs_state.ctrl_force): + dofs_state.ctrl_force[I] = 0.0 @ti.func @@ -7839,6 +7857,19 @@ def kernel_control_dofs_force( dofs_state.ctrl_mode[dofs_idx[i_d_], envs_idx[i_b_]] = gs.CTRL_MODE.FORCE dofs_state.ctrl_force[dofs_idx[i_d_], envs_idx[i_b_]] = force[i_b_, i_d_] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_control_dofs_force_grad( + force_grad: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + dofs_state:
array_class.DofsState, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + force_grad[i_b_, i_d_] = dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] + dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 + @ti.kernel(fastcache=gs.use_fastcache) def kernel_control_dofs_velocity( From a8ffcd07e61a392a7bd65b2c368585f8384793e9 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 10 Dec 2025 09:29:02 -0800 Subject: [PATCH 6/9] 1. added grad window 2. stabilized differentiable contact, now can solve simple optimization problem --- genesis/engine/simulator.py | 11 ++++++++--- genesis/engine/solvers/rigid/diff_gjk_decomp.py | 4 +++- genesis/engine/solvers/rigid/rigid_solver_decomp.py | 7 +++++++ genesis/options/solvers.py | 3 +++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index ca6bbc0e80..80859ba86d 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -296,12 +296,17 @@ def _step_grad(self): self._cur_substep_global -= 1 self.sub_step_grad(self.cur_substep_local) - if self.rigid_solver.is_active: - # Clear external force after gradient computation - self.rigid_solver.clear_external_force() + + if self.rigid_solver.is_active: + # Clear external force after gradient computation + self.rigid_solver.clear_external_force() self.process_input_grad() + if self.options.grad_window_steps is not None and self.cur_step_global % self.options.grad_window_steps == 0: + # Truncate upstream gradient flow + self.rigid_solver.zero_grad() + def process_input(self, in_backward=False): """ setting _tgt state using external commands diff --git a/genesis/engine/solvers/rigid/diff_gjk_decomp.py b/genesis/engine/solvers/rigid/diff_gjk_decomp.py index 5d3e407e08..8f36df3e8c 100644 --- 
a/genesis/engine/solvers/rigid/diff_gjk_decomp.py +++ b/genesis/engine/solvers/rigid/diff_gjk_decomp.py @@ -77,7 +77,7 @@ def func_gjk_contact( found_default_epa = False # 4 (small) + 4 (large) perturbated configurations - num_perturb = 8 + num_perturb = 4 ### Detect multiple possible contact points and gather the non-differentiable contact data. for i in range(1 + num_perturb): @@ -209,6 +209,8 @@ def func_gjk_contact( ) found_default_epa = True + # Do not use extended EPA algorithm for now, practically it seems not needed. + break # Break the loop if we found enough contact points for default configuration. As we can find at most # 8 contact points for perturbed configurations, we can find at most max_contacts_per_pair - 8 diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index b787891abc..cc977a1b60 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -1280,6 +1280,11 @@ def substep_pre_coupling_grad(self, f): case 1: gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") case 2: + qpos_diff = self._rigid_adjoint_cache_fw.qpos.to_numpy() - self._rigid_adjoint_cache_bw.qpos.to_numpy() + vel_diff = self._rigid_adjoint_cache_fw.dofs_vel.to_numpy() - self._rigid_adjoint_cache_bw.dofs_vel.to_numpy() + acc_diff = self._rigid_adjoint_cache_fw.dofs_acc.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc.to_numpy() + acc_smooth_diff = self._rigid_adjoint_cache_fw.dofs_acc_smooth.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc_smooth.to_numpy() + solver_qacc_ws_diff = self._rigid_adjoint_cache_fw.solver_qacc_ws.to_numpy() - self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy() gs.raise_exception(f"The backward computation result does not match the forward computation result at step {self._sim.cur_step_global}") kernel_step_2.grad( @@ -1480,7 +1485,9 @@ def reset_grad(self): for entity in 
self._entities: entity.reset_grad() self._queried_states.clear() + self.zero_grad() + def zero_grad(self): # zero grad for state in [ self._rigid_global_info, diff --git a/genesis/options/solvers.py b/genesis/options/solvers.py index 67da878c2c..5315b7b61c 100644 --- a/genesis/options/solvers.py +++ b/genesis/options/solvers.py @@ -36,6 +36,8 @@ class SimOptions(Options): Height of the floor in meters. Defaults to 0.0. requires_grad : bool, optional Whether to enable differentiable mode. Defaults to False. + grad_window_steps : int, optional + Number of steps that constitute a window for gradient computation. Defaults to None (window not used). use_hydroelastic_contact : bool, optional Whether to use hydroelastic contact. Defaults to False. """ @@ -46,6 +48,7 @@ class SimOptions(Options): gravity: tuple = (0.0, 0.0, -9.81) floor_height: float = 0.0 requires_grad: bool = False + grad_window_steps: Optional[int] = None # not set by user _steps_local: Optional[int] = None From b2910721b7e0140d7cf8cb8c4aa6e48afc1bdc81 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 10 Dec 2025 09:51:04 -0800 Subject: [PATCH 7/9] stash --- .../diffrigid/{one_step.py => 1_one_step.py} | 0 examples/diffrigid/2_lift_ball.py | 106 ++++++++++++++++ examples/diffrigid/ant.py | 52 ++++++++ examples/diffrigid/debug_rot.py | 113 ++++++++++++++++++ examples/diffrigid/slide_ball.py | 37 ++++-- 5 files changed, 298 insertions(+), 10 deletions(-) rename examples/diffrigid/{one_step.py => 1_one_step.py} (100%) create mode 100644 examples/diffrigid/2_lift_ball.py create mode 100644 examples/diffrigid/ant.py create mode 100644 examples/diffrigid/debug_rot.py diff --git a/examples/diffrigid/one_step.py b/examples/diffrigid/1_one_step.py similarity index 100% rename from examples/diffrigid/one_step.py rename to examples/diffrigid/1_one_step.py diff --git a/examples/diffrigid/2_lift_ball.py b/examples/diffrigid/2_lift_ball.py new file mode 100644 index 0000000000..8da89913d3 --- /dev/null +++
b/examples/diffrigid/2_lift_ball.py @@ -0,0 +1,106 @@ +""" +One step optimization for the basic debugging of differentiable rigid simulation. +""" +import torch +import genesis as gs +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("Agg") + +show_viewer = False + +gs.init(precision="32", logging_level="warn", backend=gs.cpu) + +dt = 1e-2 +horizon = 10 +substeps = 1 +grad_window = 5 #None + +scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, grad_window_steps=grad_window), + rigid_options=gs.options.RigidOptions( + use_gjk_collision=True, + enable_joint_limit=False, + ), + show_viewer=show_viewer, +) + +ball = scene.add_entity( + gs.morphs.Sphere( + pos=(0, 0, 0.109), # small penetration with ground + radius=0.1, + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), +) + +ground = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(5.0, 5.0, 0.02), + #fixed=True, + ), + surface=gs.surfaces.Default( + color=(0.0, 0.0, 0.9, 1.0), + ), + material=gs.materials.Rigid( + gravity_compensation=1, + ) +) +cam = scene.add_camera( + pos=(3.5, 0.5, 2.5), + lookat=(0.0, 0.0, 0.5), + fov=40, + GUI=False, +) + +scene.build() + +num_iter = 10000 +lr = 1e-2 + +force = gs.zeros((horizon, 6), requires_grad=True) +optimizer = torch.optim.Adam([force], lr=lr) + +render_every = 100 +prev_loss = float('inf') +losses = [] +for iter in range(num_iter): + scene.reset() + + curr_losses = [] + if iter % render_every == 0: + cam.start_recording() + for i in range(horizon): + curr_force = force[i] + ball.control_dofs_force(curr_force) + scene.step() + + ball_state = ball.get_state() + ball_pos = ball_state.pos + curr_loss = -ball_pos[:, 2].sum() # make x, y, z larger + curr_losses.append(curr_loss) + + if iter % render_every == 0: + cam.render() + if iter % render_every == 0: + cam.stop_recording(save_to_filename=f"video_{iter:06d}.mp4", fps=30) + + loss = sum(curr_losses) / len(curr_losses) + 
optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + grad_norm = torch.nn.utils.clip_grad_norm_([force], 1.0) + optimizer.step() + + with torch.no_grad(): + force.data[:, 6:] = 0.0 + + print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Grad Norm: {grad_norm:.6g} | Force: {force.data.mean(dim=0).cpu().numpy().tolist()}") + prev_loss = loss.item() + + losses.append(loss.item()) + + plt.plot(losses) + plt.savefig("loss.png") + plt.close() \ No newline at end of file diff --git a/examples/diffrigid/ant.py b/examples/diffrigid/ant.py new file mode 100644 index 0000000000..60d8e1ab9f --- /dev/null +++ b/examples/diffrigid/ant.py @@ -0,0 +1,52 @@ +import argparse + +import numpy as np + +import genesis as gs + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-v", "--vis", action="store_true", default=False) + parser.add_argument("-n", "--n_envs", type=int, default=49) + args = parser.parse_args() + + args.vis = True + args.n_envs = 1 + + ########################## init ########################## + gs.init(backend=gs.cpu) + + ########################## create a scene ########################## + viewer_options = gs.options.ViewerOptions( + camera_pos=(3, -1, 1.5), + camera_lookat=(0.0, 0.0, 0.0), + camera_fov=30, + max_FPS=60, + ) + + scene = gs.Scene( + viewer_options=viewer_options, + rigid_options=gs.options.RigidOptions( + dt=0.01, + ), + show_viewer=args.vis, + ) + + ########################## entities ########################## + plane = scene.add_entity( + gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True), + ) + ant = scene.add_entity( + gs.morphs.MJCF(file="xml/humanoid.xml"), + ) + + ########################## build ########################## + scene.build(n_envs=args.n_envs, env_spacing=(1, 1)) + + for i in range(10000): + scene.step() + + +if __name__ == "__main__": + main() diff --git a/examples/diffrigid/debug_rot.py b/examples/diffrigid/debug_rot.py new file mode 100644
index 0000000000..3ffba27e5f --- /dev/null +++ b/examples/diffrigid/debug_rot.py @@ -0,0 +1,113 @@ +""" +One step optimization for the basic debugging of differentiable rigid simulation. +""" +import torch +import genesis as gs +import matplotlib.pyplot as plt +import matplotlib +import numpy as np +matplotlib.use("Agg") + +show_viewer = False + +gs.init(precision="32", logging_level="warn", backend=gs.cpu) + +dt = 1e-2 +horizon = 10 +substeps = 1 +grad_window = None +np.random.seed(0) +goal_quat = np.random.randn(4) +goal_quat = goal_quat / np.linalg.norm(goal_quat) +goal_quat = gs.tensor(goal_quat) + +scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, grad_window_steps=grad_window), + rigid_options=gs.options.RigidOptions( + use_gjk_collision=True, + enable_joint_limit=False, + ), + show_viewer=show_viewer, +) + +ground = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(5.0, 5.0, 0.02), + #fixed=True, + ), + surface=gs.surfaces.Default( + color=(0.0, 0.0, 0.9, 1.0), + ), + material=gs.materials.Rigid( + gravity_compensation=1, + ) +) +ball = scene.add_entity( + gs.morphs.Sphere( + pos=(0, 0, 0.109), # small penetration with ground + radius=0.1, + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), +) +cam = scene.add_camera( + pos=(3.5, 0.5, 2.5), + lookat=(0.0, 0.0, 0.5), + fov=40, + GUI=False, +) + +scene.build() + +num_iter = 10000 +lr = 1e-2 + +force = gs.zeros((horizon, 6), requires_grad=True) +with torch.no_grad(): + torch.manual_seed(0) + force.data[:, 3:] = torch.randn_like(force.data[:, 3:]) +optimizer = torch.optim.Adam([force], lr=lr) + +render_every = 100 +prev_loss = float('inf') +losses = [] +for iter in range(num_iter): + scene.reset() + + curr_losses = [] + if iter % render_every == 0: + cam.start_recording() + for i in range(horizon): + curr_force = force[i] + ground.control_dofs_force(curr_force) + scene.step() + + box_state = ground.get_state() + box_quat = 
box_state.quat + curr_loss = (box_quat - goal_quat).abs().sum() + curr_losses.append(curr_loss) + + if iter % render_every == 0: + cam.render() + if iter % render_every == 0: + cam.stop_recording(save_to_filename=f"video_{iter:06d}.mp4", fps=30) + + loss = sum(curr_losses) / len(curr_losses) + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + grad_norm = torch.nn.utils.clip_grad_norm_([force], 1.0) + optimizer.step() + + with torch.no_grad(): + force.data[:, :3] = 0.0 + + print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Grad Norm: {grad_norm:.6g} | Force: {force.data.mean(dim=0).cpu().numpy().tolist()}") + prev_loss = loss.item() + + losses.append(loss.item()) + + plt.plot(losses) + plt.savefig("loss.png") + plt.close() \ No newline at end of file diff --git a/examples/diffrigid/slide_ball.py b/examples/diffrigid/slide_ball.py index 87530b86f8..a674c479c0 100644 --- a/examples/diffrigid/slide_ball.py +++ b/examples/diffrigid/slide_ball.py @@ -9,7 +9,8 @@ dt = 1e-2 horizon = 100 substeps = 4 -goal_pos = gs.tensor([0.0, 0.1, -0.1]) +goal_pos = gs.tensor([0.2, 0.1, 0.1]) +render_every = 100 scene = gs.Scene( sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True), @@ -57,16 +58,23 @@ ) ) +cam = scene.add_camera( + pos=(3.5, 0.5, 2.5), + lookat=(0.0, 0.0, 0.5), + fov=40, + GUI=False, +) + scene.build() -num_iter = 200 -lr = 1e-4 +num_iter = 300 +lr = 1e-2 init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True) init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) -scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-4) prev_loss = float('inf') losses = [] @@ -76,14 +84,18 @@ racket.set_pos(init_pos) racket.set_quat(init_quat) #ball.set_dofs_velocity(gs.tensor([0, 0, -2.0, 0, 0, 0])) +
+ record = (iter % render_every == 0) or (iter == num_iter - 1) + if record: + cam.start_recording() for i in range(horizon): scene.step() - # ball_state = ball.get_state() - # ball_pos = ball_state.pos - # losses.append(torch.abs(ball_pos - goal_pos).sum()) - # if show_viewer: - # target.set_pos(goal_pos) + if record: + cam.render() + + if record: + cam.stop_recording(save_to_filename=f"video_{iter:06d}.mp4", fps=30) ball_state = ball.get_state() ball_pos = ball_state.pos @@ -96,9 +108,14 @@ scheduler.step() with torch.no_grad(): + # init_quat.data[0] = 1.0 + # init_quat.data[1] = 0.0 + # init_quat.data[2] = 0.0 + # init_quat.data[3] = 0.0 init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + #init_pos.data = init_pos.data.clamp(0.0, 0.0) - print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g}") + print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Ball Pos: {ball_pos.detach().cpu().numpy().tolist()}") prev_loss = loss.item() losses.append(loss.item()) From 4b472b89017f6f601aad92b82871e4b49c6269f6 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 10 Dec 2025 16:44:49 -0800 Subject: [PATCH 8/9] 1. accelerated backward pass for some functions by pulling out static inner loops outside 2. fixed some geometric functions in utils.geom.py to avoid nan grads 3. added rigid entity state queries (e.g. 
qpos, dofs_vel) --- genesis/assets/xml/ant_no_ground.xml | 93 +++ genesis/assets/xml/walker_no_ground.xml | 80 +++ .../entities/rigid_entity/rigid_entity.py | 8 +- .../solvers/rigid/rigid_solver_decomp.py | 679 +++++++++++++++++- genesis/engine/states/entities.py | 20 +- genesis/utils/geom.py | 17 +- 6 files changed, 862 insertions(+), 35 deletions(-) create mode 100644 genesis/assets/xml/ant_no_ground.xml create mode 100644 genesis/assets/xml/walker_no_ground.xml diff --git a/genesis/assets/xml/ant_no_ground.xml b/genesis/assets/xml/ant_no_ground.xml new file mode 100644 index 0000000000..7eb2522ec4 --- /dev/null +++ b/genesis/assets/xml/ant_no_ground.xml @@ -0,0 +1,93 @@ + + + diff --git a/genesis/assets/xml/walker_no_ground.xml b/genesis/assets/xml/walker_no_ground.xml new file mode 100644 index 0000000000..827e776c7c --- /dev/null +++ b/genesis/assets/xml/walker_no_ground.xml @@ -0,0 +1,80 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 922592a842..47ff9c9a3a 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -1727,10 +1727,16 @@ def get_state(self): solver_state = self._solver.get_state() pos = solver_state.links_pos[:, self.base_link_idx] quat = solver_state.links_quat[:, self.base_link_idx] + qpos = solver_state.qpos[:, self._q_start:self._q_start + self.n_qs] + dofs_vel = solver_state.dofs_vel[:, self._dof_start:self._dof_start + self.n_dofs] + dofs_acc = solver_state.dofs_acc[:, self._dof_start:self._dof_start + self.n_dofs] state._pos = pos state._quat = quat - + state._qpos = qpos + state._dofs_vel = dofs_vel + state._dofs_acc = dofs_acc + return state def _get_global_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False): diff 
--git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index cc977a1b60..a0a352eb6f 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -1193,16 +1193,28 @@ def substep_pre_coupling_grad(self, f): envs_idx = self._scene._sanitize_envs_idx(None) if not self._enable_mujoco_compatibility: - kernel_forward_velocity.grad( - envs_idx=envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - ) + # kernel_forward_velocity.grad( + # envs_idx=envs_idx, + # links_state=self.links_state, + # links_info=self.links_info, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # ) + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_velocity_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) kernel_update_geoms.grad( envs_idx, entities_info=self.entities_info, @@ -1213,34 +1225,75 @@ def substep_pre_coupling_grad(self, f): static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=False, ) - kernel_COM_links.grad( - links_state=self.links_state, - links_info=self.links_info, - joints_state=self.joints_state, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - 
geoms_state=self.geoms_state, - geoms_info=self.geoms_info, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - force_update_fixed_geoms=False, - ) - kernel_forward_kinematics.grad( + # kernel_COM_links.grad( + # links_state=self.links_state, + # links_info=self.links_info, + # joints_state=self.joints_state, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # geoms_state=self.geoms_state, + # geoms_info=self.geoms_info, + # entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # force_update_fixed_geoms=False, + # ) + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_COM_links_ad_1.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_COM_links_ad_0.grad( links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, joints_info=self.joints_info, dofs_state=self.dofs_state, dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - force_update_fixed_geoms=False, ) + # kernel_forward_kinematics.grad( + # links_state=self.links_state, + # links_info=self.links_info, + # joints_state=self.joints_state, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # geoms_state=self.geoms_state, + # geoms_info=self.geoms_info, + # 
entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # force_update_fixed_geoms=False, + # ) + + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_kinematics_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) # kernel_update_cartesian_space.grad( # links_state=self.links_state, # links_info=self.links_info, @@ -4345,6 +4398,39 @@ def kernel_forward_kinematics( static_rigid_sim_config=static_rigid_sim_config, ) +@ti.kernel +def kernel_forward_kinematics_ad( + i_l_:ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]): + func_forward_kinematics_entity_ad( + i_e=i_e, + i_l_=i_l_, + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + 
rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + @ti.kernel def kernel_COM_links( @@ -4738,6 +4824,34 @@ def kernel_forward_velocity( ) +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_forward_velocity_ad( + i_l_:ti.int32, + envs_idx: ti.types.ndarray(), + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + for i_e, i_b_ in ti.ndrange(entities_info.n_links.shape[0], envs_idx.shape[0]): + i_b = envs_idx[i_b_] + func_forward_velocity_entity_ad( + i_e=i_e, + i_l_=i_l_, + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + @ti.func def func_COM_links( i_b, @@ -5071,6 +5185,240 @@ def func_COM_links( dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] ) +@ti.kernel +def kernel_COM_links_ad_0( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + links_state.root_COM_bw[i_l, i_b].fill(0.0) + links_state.mass_sum[i_l, i_b] = 0.0 + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if 
ti.static(static_rigid_sim_config.batch_links_info) else i_l + + mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] + ( + links_state.i_pos_bw[i_l, i_b], + links_state.i_quat[i_l, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], + links_info.inertial_quat[I_l], + links_state.pos[i_l, i_b], + links_state.quat[i_l, i_b], + ) + + i_r = links_info.root_idx[I_l] + links_state.mass_sum[i_r, i_b] += mass + links_state.root_COM_bw[i_r, i_b] += mass * links_state.i_pos_bw[i_l, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + i_r = links_info.root_idx[I_l] + if i_l == i_r: + links_state.root_COM[i_l, i_b] = links_state.root_COM_bw[i_l, i_b] / links_state.mass_sum[i_l, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + i_r = links_info.root_idx[I_l] + links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + links_state.i_pos[i_l, i_b] = links_state.i_pos_bw[i_l, i_b] - links_state.root_COM[i_l, i_b] + + i_inertial = links_info.inertial_i[I_l] + i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] + ( + links_state.cinr_inertial[i_l, i_b], + links_state.cinr_pos[i_l, i_b], + links_state.cinr_quat[i_l, i_b], + links_state.cinr_mass[i_l, i_b], + ) = gu.ti_transform_inertia_by_trans_quat( + i_inertial, 
+ i_mass, + links_state.i_pos[i_l, i_b], + links_state.i_quat[i_l, i_b], + rigid_global_info.EPS[None], + ) + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + BW = ti.static(static_rigid_sim_config.is_backward) + + if links_info.n_dofs[I_l] > 0: + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ if ti.static(not BW) else (i_j_ + links_info.joint_start[I_l]) + + if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW): + offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + + dof_start = joints_info.dof_start[I_j] + + EPS = rigid_global_info.EPS[None] + if joint_type == gs.JOINT_TYPE.REVOLUTE: + dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start, i_b] = xmat_T[i, :].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.FREE: + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + 
dof_start, i_b][i] = 1.0 + + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start + 3, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start + 3, i_b] = xmat_T[i, :].cross(offset_pos) + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start) + if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW): + dofs_state.cdofvel_ang[i_d, i_b] = ( + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + dofs_state.cdofvel_vel[i_d, i_b] = ( + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + +@ti.kernel +def kernel_COM_links_ad_1( + i_l_:ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]): + func_COM_links_ad_1( + i_e=i_e, + i_l_=i_l_, + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + +@ti.func +def func_COM_links_ad_1( + i_e: ti.int32, + i_l_:ti.int32, + i_b: ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: 
array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + BW = ti.static(static_rigid_sim_config.is_backward) + i_l = i_l_ + entities_info.link_start[i_e] + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + if links_info.n_dofs[I_l] > 0: + i_p = links_info.parent_idx[I_l] + + _i_j = links_info.joint_start[I_l] + _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j + joint_type = joints_info.type[_I_j] + + p_pos = ti.Vector.zero(gs.ti_float, 3) + p_quat = gu.ti_identity_quat() + if i_p != -1: + p_pos = links_state.pos[i_p, i_b] + p_quat = links_state.quat[i_p, i_b] + + if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): + links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] + links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] + else: + ( + links_state.j_pos_bw[i_l, 0, i_b], + links_state.j_quat_bw[i_l, 0, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) + + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + for i_j_ in ( + range(n_joints) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] + + curr_i_j = 0 if ti.static(not BW) else i_j_ + next_i_j = 0 if ti.static(not BW) else i_j_ + 1 + + if func_check_index_range( + i_j, + links_info.joint_start[I_l], + links_info.joint_end[I_l], + BW, + ): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + + ( + links_state.j_pos_bw[i_l, next_i_j, i_b], + links_state.j_quat_bw[i_l, next_i_j, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + joints_info.pos[I_j], + gu.ti_identity_quat(), + links_state.j_pos_bw[i_l, curr_i_j, i_b], + links_state.j_quat_bw[i_l, curr_i_j, i_b], + ) 
+ + i_j_ = 0 if ti.static(not BW) else n_joints + links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, i_j_, i_b] + links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, i_j_, i_b] @ti.func def func_forward_kinematics( @@ -5363,6 +5711,160 @@ def func_forward_kinematics_entity( links_state.pos[i_l, i_b] = R(links_state.pos_bw, I_jf, pos, BW) links_state.quat[i_l, i_b] = R(links_state.quat_bw, I_jf, quat, BW) +@ti.func +def func_forward_kinematics_entity_ad( + i_e, + i_l_, + i_b, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + BW = ti.static(static_rigid_sim_config.is_backward) + W = ti.static(func_write_field_if) + R = ti.static(func_read_field_if) + WR = ti.static(func_write_and_read_field_if) + + EPS = rigid_global_info.EPS[None] + + i_l = i_l_ + entities_info.link_start[i_e] + + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + I_l0 = (i_l, 0, i_b) + + pos = W(links_state.pos_bw, I_l0, links_info.pos[I_l], BW) + quat = W(links_state.quat_bw, I_l0, links_info.quat[I_l], BW) + if links_info.parent_idx[I_l] != -1: + parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b] + parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b] + pos_ = parent_pos + gu.ti_transform_by_quat(links_info.pos[I_l], parent_quat) + quat_ = gu.ti_transform_quat_by_quat(links_info.quat[I_l], parent_quat) + + pos = W(links_state.pos_bw, I_l0, pos_, BW) + quat = W(links_state.quat_bw, I_l0, quat_, BW) + + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + for i_j_ in ( + range(n_joints) + if ti.static(not BW) + else 
ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] + + curr_I = (i_l, 0 if ti.static(not BW) else i_j_, i_b) + next_I = (i_l, 0 if ti.static(not BW) else i_j_ + 1, i_b) + + if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] + I_d = [dof_start, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start + + # compute axis and anchor + if joint_type == gs.JOINT_TYPE.FREE: + joints_state.xanchor[i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ] + ) + joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + else: + axis = ti.Vector([0.0, 0.0, 1.0], dt=gs.ti_float) + if joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + axis = dofs_info.motion_vel[I_d] + + pos_ = R(links_state.pos_bw, curr_I, pos, BW) + quat_ = R(links_state.quat_bw, curr_I, quat, BW) + + joints_state.xanchor[i_j, i_b] = gu.ti_transform_by_quat(joints_info.pos[I_j], quat_) + pos_ + joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat(axis, quat_) + + if joint_type == gs.JOINT_TYPE.FREE: + pos_ = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ], + dt=gs.ti_float, + ) + quat_ = ti.Vector( + [ + rigid_global_info.qpos[q_start + 3, i_b], + rigid_global_info.qpos[q_start + 4, i_b], + rigid_global_info.qpos[q_start + 5, i_b], + rigid_global_info.qpos[q_start + 6, i_b], + ], + dt=gs.ti_float, + ) + pos = WR(links_state.pos_bw, next_I, pos_, BW) + quat = WR(links_state.quat_bw, 
next_I, quat_, BW) + + xyz = gu.ti_quat_to_xyz(quat, EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = pos[j] + dofs_state.pos[dof_start + 3 + j, i_b] = xyz[j] + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + print("SPHERICAL") + qloc = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[q_start + 3, i_b], + ], + dt=gs.ti_float, + ) + xyz = gu.ti_quat_to_xyz(qloc, EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = xyz[j] + quat_ = gu.ti_transform_quat_by_quat(qloc, R(links_state.quat_bw, curr_I, quat, BW)) + quat = WR(links_state.quat_bw, next_I, quat_, BW) + pos_ = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos = W(links_state.pos_bw, next_I, pos_, BW) + elif joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b], EPS) + quat_ = gu.ti_transform_quat_by_quat(qloc, R(links_state.quat_bw, curr_I, quat, BW)) + quat = WR(links_state.quat_bw, next_I, quat_, BW) + pos_ = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos = W(links_state.pos_bw, next_I, pos_, BW) + else: # joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + pos_ = ( + R(links_state.pos_bw, curr_I, pos, BW) + + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] + ) + pos = W(links_state.pos_bw, next_I, pos_, BW) + + # Skip link pose update for fixed root links to let users manually overwrite them + I_jf = (i_l, 0 if ti.static(not BW) else n_joints, i_b) + if not (links_info.parent_idx[I_l] == 
-1 and links_info.is_fixed[I_l]): + links_state.pos[i_l, i_b] = R(links_state.pos_bw, I_jf, pos, BW) + links_state.quat[i_l, i_b] = R(links_state.quat_bw, I_jf, quat, BW) + @ti.func def func_forward_velocity_entity( @@ -5489,6 +5991,127 @@ def func_forward_velocity_entity( links_state.cd_vel[i_l, i_b] = R(links_state.cd_vel_bw, I_jf, cvel_vel, BW) links_state.cd_ang[i_l, i_b] = R(links_state.cd_ang_bw, I_jf, cvel_ang, BW) +@ti.func +def func_forward_velocity_entity_ad( + i_e, + i_l_, + i_b, + entities_info: array_class.EntitiesInfo, + links_info: array_class.LinksInfo, + links_state: array_class.LinksState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + BW = ti.static(static_rigid_sim_config.is_backward) + W = ti.static(func_write_field_if) + R = ti.static(func_read_field_if) + A = ti.static(func_atomic_add_if) + + i_l = i_l_ + entities_info.link_start[i_e] + + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + I_j0 = (i_l, 0, i_b) + cvel_vel = W(links_state.cd_vel_bw, I_j0, ti.Vector.zero(gs.ti_float, 3), BW) + cvel_ang = W(links_state.cd_ang_bw, I_j0, ti.Vector.zero(gs.ti_float, 3), BW) + + if links_info.parent_idx[I_l] != -1: + cvel_vel = W(links_state.cd_vel_bw, I_j0, links_state.cd_vel[links_info.parent_idx[I_l], i_b], BW) + cvel_ang = W(links_state.cd_ang_bw, I_j0, links_state.cd_ang[links_info.parent_idx[I_l], i_b], BW) + + for i_j_ in ( + range(n_joints) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] + + if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = 
joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] + + curr_I = (i_l, 0 if ti.static(not BW) else i_j_, i_b) + next_I = (i_l, 0 if ti.static(not BW) else i_j_ + 1, i_b) + + if joint_type == gs.JOINT_TYPE.FREE: + for i_3 in ti.static(range(3)): + _vel = dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + _ang = dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + + cvel_vel = cvel_vel + A(links_state.cd_vel_bw, curr_I, _vel, BW) + cvel_ang = cvel_ang + A(links_state.cd_ang_bw, curr_I, _ang, BW) + + for i_3 in ti.static(range(3)): + ( + dofs_state.cdofd_ang[dof_start + i_3, i_b], + dofs_state.cdofd_vel[dof_start + i_3, i_b], + ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3) + + ( + dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b], + ) = gu.motion_cross_motion( + R(links_state.cd_ang_bw, curr_I, cvel_ang, BW), + R(links_state.cd_vel_bw, curr_I, cvel_vel, BW), + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b], + ) + + if ti.static(BW): + links_state.cd_vel_bw[next_I] = links_state.cd_vel_bw[curr_I] + links_state.cd_ang_bw[next_I] = links_state.cd_ang_bw[curr_I] + + for i_3 in ti.static(range(3)): + _vel = ( + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + _ang = ( + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + cvel_vel = cvel_vel + A(links_state.cd_vel_bw, next_I, _vel, BW) + cvel_ang = cvel_ang + A(links_state.cd_ang_bw, next_I, _ang, BW) + + else: + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start) + if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW): + 
dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( + R(links_state.cd_ang_bw, curr_I, cvel_ang, BW), + R(links_state.cd_vel_bw, curr_I, cvel_vel, BW), + dofs_state.cdof_ang[i_d, i_b], + dofs_state.cdof_vel[i_d, i_b], + ) + + if ti.static(BW): + links_state.cd_vel_bw[next_I] = links_state.cd_vel_bw[curr_I] + links_state.cd_ang_bw[next_I] = links_state.cd_ang_bw[curr_I] + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start) + if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW): + _vel = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + _ang = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + cvel_vel = cvel_vel + A(links_state.cd_vel_bw, next_I, _vel, BW) + cvel_ang = cvel_ang + A(links_state.cd_ang_bw, next_I, _ang, BW) + + I_jf = (i_l, 0 if ti.static(not BW) else n_joints, i_b) + links_state.cd_vel[i_l, i_b] = R(links_state.cd_vel_bw, I_jf, cvel_vel, BW) + links_state.cd_ang[i_l, i_b] = R(links_state.cd_ang_bw, I_jf, cvel_ang, BW) + @ti.kernel(fastcache=gs.use_fastcache) def kernel_update_geoms( diff --git a/genesis/engine/states/entities.py b/genesis/engine/states/entities.py index b6ee1f6dd1..06ce22a501 100644 --- a/genesis/engine/states/entities.py +++ b/genesis/engine/states/entities.py @@ -204,13 +204,19 @@ def __init__(self, entity, s_global): scene = self._entity.scene self._pos = gs.zeros((num_batch, 3), dtype=float, requires_grad=requires_grad, scene=scene) self._quat = gs.zeros((num_batch, 4), dtype=float, requires_grad=requires_grad, scene=scene) + self._qpos = gs.zeros((num_batch, entity.n_qs), dtype=float, requires_grad=requires_grad, scene=scene) + self._dofs_vel = gs.zeros((num_batch, entity.n_dofs), dtype=float, requires_grad=requires_grad, scene=scene) + self._dofs_acc = gs.zeros((num_batch, entity.n_dofs), 
dtype=float, requires_grad=requires_grad, scene=scene) def serializable(self): self._entity = None self._pos = self._pos.detach() self._quat = self._quat.detach() - + self._qpos = self._qpos.detach() + self._dofs_vel = self._dofs_vel.detach() + self._dofs_acc = self._dofs_acc.detach() + @property def entity(self): return self._entity @@ -226,3 +232,15 @@ def pos(self): @property def quat(self): return self._quat + + @property + def qpos(self): + return self._qpos + + @property + def dofs_vel(self): + return self._dofs_vel + + @property + def dofs_acc(self): + return self._dofs_acc \ No newline at end of file diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py index e1a052c8e3..20397402b4 100644 --- a/genesis/utils/geom.py +++ b/genesis/utils/geom.py @@ -82,6 +82,7 @@ def ti_rotvec_to_R(rotvec, eps): @ti.func def ti_rotvec_to_quat(rotvec, eps): quat = ti.Vector.zero(gs.ti_float, 4) + res = ti.Vector.zero(gs.ti_float, 4) # We need to use [norm_sqr] instead of [norm] to avoid nan gradients in the backward pass. Even when theta = 0, # the gradient of [norm] operation is computed and used (note that the gradient becomes NaN when theta = 0). 
This @@ -98,11 +99,12 @@ def ti_rotvec_to_quat(rotvec, eps): quat[i + 1] = xyz[i] # First order quaternion normalization is accurate enough yet necessary - quat *= 0.5 * (3.0 - quat.norm_sqr()) + # quat *= 0.5 * (3.0 - quat.norm_sqr()) + res = quat * 0.5 * (3.0 - quat.norm_sqr()) else: - quat[0] = 1.0 + res[0] = 1.0 - return quat + return res @ti.func @@ -221,7 +223,12 @@ def ti_transform_quat_by_quat(v, u): This is equivalent to quatmul(quat_u, quat_v) or R_u @ R_v """ vec = ti_quat_mul(u, v) - return vec.normalized() + res = ti.Vector([1.0, 0.0, 0.0, 0.0], dt=gs.ti_float) + vec_norm_sqr = vec.norm_sqr() + if vec_norm_sqr > gs.EPS ** 2: + res = vec / (ti.sqrt(vec_norm_sqr) + gs.EPS) + return res + # return vec.normalized() @ti.func @@ -239,7 +246,7 @@ def ti_transform_by_quat(v, quat): v.x * (-2.0 * q_wy + 2.0 * q_xz) + v.y * (2.0 * q_wx + 2.0 * q_yz) + v.z * (q_ww - q_xx - q_yy + q_zz), ], dt=gs.ti_float, - ) / (q_ww + q_xx + q_yy + q_zz) + ) / (q_ww + q_xx + q_yy + q_zz + gs.EPS) @ti.func From e1655804f9ab449ec27eddb71977b2e86cbd1367 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Mon, 15 Dec 2025 10:29:58 -0800 Subject: [PATCH 9/9] stash --- examples/diffrigid/ant.py | 261 +++++++++++++++++- examples/diffrigid/slide_ball.py | 4 +- .../solvers/rigid/constraint_solver_decomp.py | 167 +++++------ .../solvers/rigid/rigid_solver_decomp.py | 256 +++++++++++++++-- 4 files changed, 576 insertions(+), 112 deletions(-) diff --git a/examples/diffrigid/ant.py b/examples/diffrigid/ant.py index 60d8e1ab9f..9935066ea1 100644 --- a/examples/diffrigid/ant.py +++ b/examples/diffrigid/ant.py @@ -4,6 +4,99 @@ import genesis as gs +import torch + +from copy import deepcopy + +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("Agg") + +class Controller(torch.nn.Module): + def __init__(self, obs_dim, n_dofs, hidden_dim=64): + super().__init__() + self.obs_dim = obs_dim + self.n_dofs = n_dofs + + # Batch normalization layer + self.bn = torch.nn.BatchNorm1d(obs_dim) + + 
# MLP layers (2-3 layers) + self.fc1 = torch.nn.Linear(obs_dim, hidden_dim) + self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.fc3 = torch.nn.Linear(hidden_dim, hidden_dim) + + # Output layers for mean and log_std + self.mean_layer = torch.nn.Linear(hidden_dim, n_dofs) + self.log_std_layer = torch.nn.Linear(hidden_dim, n_dofs) + + # Initialize log_std to small values + self.log_std_layer.weight.data.fill_(0.0) + self.log_std_layer.bias.data.fill_(-0.5) + + def forward(self, obs): + """ + Args: + obs: observation tensor of shape (batch_size, obs_dim) + Returns: + mean: mean of action distribution, shape (batch_size, n_dofs) + std: standard deviation of action distribution, shape (batch_size, n_dofs) + """ + # Batch normalization + if obs.shape[0] > 1: + x = self.bn(obs) + else: + x = obs + + # MLP layers with ReLU activation + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + x = torch.relu(self.fc3(x)) + + # Output mean and log_std + mean = self.mean_layer(x) + log_std = self.log_std_layer(x) + + # Clamp log_std to prevent extreme values + log_std = torch.clamp(log_std, min=-10, max=2) + std = torch.exp(log_std) + + return mean, std + + def sample_action(self, obs): + """ + Sample action from the policy distribution. 
+ + Args: + obs: observation tensor of shape (batch_size, obs_dim) + Returns: + action: sampled action, shape (batch_size, n_dofs) + mean: mean of action distribution, shape (batch_size, n_dofs) + std: standard deviation of action distribution, shape (batch_size, n_dofs) + """ + mean, std = self.forward(obs) + noise = torch.randn_like(mean) + action = mean + std * noise + return action, mean, std + +def observe_fn(state): + qpos = state.qpos + dofs_vel = state.dofs_vel + dofs_acc = state.dofs_acc + + return torch.cat([qpos, dofs_vel, dofs_acc], dim=1).detach() + +def reward_fn(state, dt, prev_state=None): + pos = state.pos + + height_clip = torch.clamp(pos[:, 2] - 0.8, -float('inf'), 1.0) + height_reward = torch.where(height_clip <= 0.0, -200 * (height_clip ** 2), height_clip) + forward_reward = (pos[:, 0] - prev_state.pos[:, 0].detach()) / dt if prev_state is not None else 0.0 + + reward = height_reward * 0.01 + forward_reward + + return reward + def main(): parser = argparse.ArgumentParser() @@ -11,11 +104,20 @@ def main(): parser.add_argument("-n", "--n_envs", type=int, default=49) args = parser.parse_args() - args.vis = True - args.n_envs = 1 + args.vis = False + args.n_envs = 64 + + dt = 0.01 + substeps = 1 + horizon_steps = 128 + window_substeps = 32 + window_steps = int(window_substeps / substeps) + iteration = 10000 + lr = 1e-4 + render_every = 100 ########################## init ########################## - gs.init(backend=gs.cpu) + gs.init(backend=gs.gpu, logging_level="warn") ########################## create a scene ########################## viewer_options = gs.options.ViewerOptions( @@ -26,27 +128,170 @@ def main(): ) scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, grad_window_steps=window_steps), viewer_options=viewer_options, rigid_options=gs.options.RigidOptions( - dt=0.01, + use_gjk_collision=True, + enable_joint_limit=False, + ), + vis_options=gs.options.VisOptions( + rendered_envs_idx=(0,), + 
show_world_frame=True, ), show_viewer=args.vis, ) ########################## entities ########################## + # plane = scene.add_entity( + # gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True, pos=(0, 0, -0.5)), + # ) plane = scene.add_entity( - gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True), + gs.morphs.Box(size=(10, 10, 0.6), pos=(0, 0, -0.3), fixed=True), ) ant = scene.add_entity( - gs.morphs.MJCF(file="xml/humanoid.xml"), + gs.morphs.MJCF(file="xml/walker_no_ground.xml"), + # vis_mode="collision" ) + cam = scene.add_camera( + pos=(3.5, 0.5, 2.5), + lookat=(0.0, 0.0, 0.5), + fov=40, + GUI=False, + env_idx=0, + ) + cam.follow_entity(ant) ########################## build ########################## scene.build(n_envs=args.n_envs, env_spacing=(1, 1)) - for i in range(10000): - scene.step() + n_dofs = ant.n_dofs + + # rand_force = torch.randn((n_dofs,), dtype=torch.float32) + # for i in range(10000): + # ant.control_dofs_force(rand_force) + # scene.step() + + # if i % 1000 == 0 and i > 0: + # print("----------------------------------------------------------") + # rand_force = torch.randn((n_dofs,), dtype=torch.float32) + + # # Reset env + # rigid_solver = scene.sim.rigid_solver + # qpos = rigid_solver._rigid_global_info.qpos.to_numpy() + # dofs_vel = rigid_solver.dofs_state.vel.to_numpy() + # dofs_acc = rigid_solver.dofs_state.acc.to_numpy() + # dofs_acc_smooth = rigid_solver.dofs_state.acc_smooth.to_numpy() + # solver_qacc_ws = rigid_solver.constraint_solver.constraint_state.qacc_ws.to_numpy() + + # scene.reset() + # scene.sim.rigid_solver._rigid_global_info.qpos.from_numpy(qpos) + # rigid_solver.dofs_state.vel.from_numpy(dofs_vel) + # rigid_solver.dofs_state.acc.from_numpy(dofs_acc) + # rigid_solver.dofs_state.acc_smooth.from_numpy(dofs_acc_smooth) + # rigid_solver.constraint_solver.constraint_state.qacc_ws.from_numpy(solver_qacc_ws) + # rigid_solver.load_test() + + + # Initialize controller + # Get obs_dim by computing observation once + 
scene.reset() + state = ant.get_state() + obs = observe_fn(state) + obs_dim = obs.shape[1] + + controller = Controller(obs_dim=obs_dim, n_dofs=n_dofs, hidden_dim=64) + optimizer = torch.optim.Adam(controller.parameters(), lr=lr) + + rewards = [] + for iter in range(iteration): + scene.reset() + acc_reward = None + prev_state = None + + record = (iter % render_every == 0) or (iter == iteration - 1) + if record: + cam.start_recording() + print("running forward pass...") + for step in range(horizon_steps): + scene.step() + if record: + cam.render() + + # Determine observation and reward + state = ant.get_state() + obs = observe_fn(state) + reward = reward_fn(state, dt, prev_state) + if acc_reward is None: + acc_reward = reward + else: + acc_reward += reward + + prev_state = state + + # Determine action + action, mean, std = controller.sample_action(obs) + # Apply action (assuming action is force/torque for dofs) + ant.control_dofs_force(action) + + truncate = step == horizon_steps - 1 # step % window_steps == 0 or step == horizon_steps - 1 + if truncate and step > 0: + print("running backward pass...") + acc_reward = acc_reward / (step + 1) + mean_reward = acc_reward.mean() + loss = -mean_reward + + optimizer.zero_grad() + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0) + optimizer.step() + + print(f"[ITER {iter}] Mean Reward: {mean_reward.item():.4g} | Grad Norm: {grad_norm:.4g}") + + rewards.append(mean_reward.detach().item()) + + plt.plot(rewards) + plt.savefig("rewards.png") + plt.close() + + # Reset env + # scene.sim.reset_grad() + # scene._forward_ready = True + # scene._backward_ready = True + # scene._t = t + # scene.sim._cur_substep_global = substep_global + + # rigid_solver._rigid_global_info.qpos.from_numpy(_qpos) + # rigid_solver.dofs_state.vel.from_numpy(_dofs_vel) + # rigid_solver.dofs_state.acc.from_numpy(_dofs_acc) + # rigid_solver.dofs_state.acc_smooth.from_numpy(_dofs_acc_smooth) + # 
rigid_solver.constraint_solver.constraint_state.qacc_ws.from_numpy(_solver_qacc_ws) + # rigid_solver.load_test() + + # rigid_solver = scene.sim.rigid_solver + # qpos = rigid_solver._rigid_global_info.qpos.to_numpy() + # dofs_vel = rigid_solver.dofs_state.vel.to_numpy() + # dofs_acc = rigid_solver.dofs_state.acc.to_numpy() + # dofs_acc_smooth = rigid_solver.dofs_state.acc_smooth.to_numpy() + # solver_qacc_ws = rigid_solver.constraint_solver.constraint_state.qacc_ws.to_numpy() + # scene_t = scene._t + # scene.reset() + # scene._t = scene_t + # rigid_solver._rigid_global_info.qpos.from_numpy(qpos) + # rigid_solver.dofs_state.vel.from_numpy(dofs_vel) + # rigid_solver.dofs_state.acc.from_numpy(dofs_acc) + # rigid_solver.dofs_state.acc_smooth.from_numpy(dofs_acc_smooth) + # rigid_solver.constraint_solver.constraint_state.qacc_ws.from_numpy(solver_qacc_ws) + # rigid_solver.load_test() + + acc_reward = None + # if step // window_steps > 1: + break + + if record: + cam.stop_recording(save_to_filename=f"ant_video_{iter:06d}.mp4", fps=30) + if __name__ == "__main__": main() diff --git a/examples/diffrigid/slide_ball.py b/examples/diffrigid/slide_ball.py index a674c479c0..a12c286ede 100644 --- a/examples/diffrigid/slide_ball.py +++ b/examples/diffrigid/slide_ball.py @@ -68,7 +68,7 @@ scene.build() num_iter = 300 -lr = 1e-2 +lr = 1e-4 init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True) init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) @@ -121,6 +121,8 @@ losses.append(loss.item()) plt.plot(losses) + # set y axis to log scale + plt.yscale('log') plt.savefig("loss.png") plt.close() diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index a9288546bd..a6aebb3d42 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -452,10 +452,10 @@ def add_collision_constraints( static_rigid_sim_config: ti.template(), 
): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b, i_0 in ( - ti.ndrange(dofs_state.ctrl_mode.shape[1], 1) + for i_b, i_0, i_4 in ( + ti.ndrange(dofs_state.ctrl_mode.shape[1], 1, 4) if ti.static(not static_rigid_sim_config.is_backward) - else ti.ndrange(dofs_state.ctrl_mode.shape[1], static_rigid_sim_config.max_contact_pairs) + else ti.ndrange(dofs_state.ctrl_mode.shape[1], static_rigid_sim_config.max_contact_pairs, 4) ): EPS = rigid_global_info.EPS[None] n_dofs = dofs_state.ctrl_mode.shape[0] @@ -487,87 +487,88 @@ def add_collision_constraints( if link_b > -1: invweight = invweight + links_info.invweight[link_b_maybe_batch][0] - for i in range(4) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(4)): - d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) - n = d * contact_data_friction - contact_data_normal + #for i in range(4) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(4)): + i = i_4 + d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) + n = d * contact_data_friction - contact_data_normal - n_con = i_col * 4 + i # + constraint_state.n_constraints[i_b] - if ti.static(static_rigid_sim_config.sparse_solve): - for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): - i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b] - constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) - else: - for i_d in ( - range(n_dofs) - if ti.static(not static_rigid_sim_config.is_backward) - else ti.static(range(static_rigid_sim_config.n_dofs)) - ): - constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) - - con_n_relevant_dofs = 0 - jac_qvel = gs.ti_float(0.0) - for i_ab in range(2) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(2)): - sign = gs.ti_float(-1.0) - link = link_a - if i_ab == 1: - sign = gs.ti_float(1.0) - link = link_b - - # FIXME: Set number of iterations to look for parent to certain value for autodiff - for i_parent in ( - range(20) 
if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(1)) - ): - if link > -1: - link_maybe_batch = ( - [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link - ) + n_con = i_col * 4 + i # + constraint_state.n_constraints[i_b] + if ti.static(static_rigid_sim_config.sparse_solve): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): + i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b] + constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) + else: + for i_d in ( + range(n_dofs) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_dofs)) + ): + constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) + + con_n_relevant_dofs = 0 + jac_qvel = gs.ti_float(0.0) + for i_ab in range(2) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(2)): + sign = gs.ti_float(-1.0) + link = link_a + if i_ab == 1: + sign = gs.ti_float(1.0) + link = link_b + + # FIXME: Set number of iterations to look for parent to certain value for autodiff + for i_parent in ( + range(20) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(5)) + ): + if link > -1: + link_maybe_batch = ( + [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link + ) - # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending - for i_d_ in ( - range(links_info.n_dofs[link_maybe_batch]) - if ti.static(not static_rigid_sim_config.is_backward) - else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) - ): - if i_d_ < links_info.n_dofs[link_maybe_batch]: - i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ + # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending + for i_d_ in ( + range(links_info.n_dofs[link_maybe_batch]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + 
if i_d_ < links_info.n_dofs[link_maybe_batch]: + i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ - cdof_ang = dofs_state.cdof_ang[i_d, i_b] - cdot_vel = dofs_state.cdof_vel[i_d, i_b] + cdof_ang = dofs_state.cdof_ang[i_d, i_b] + cdot_vel = dofs_state.cdof_vel[i_d, i_b] - t_quat = gu.ti_identity_quat() - t_pos = contact_data_pos - links_state.root_COM[link, i_b] - _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) + t_quat = gu.ti_identity_quat() + t_pos = contact_data_pos - links_state.root_COM[link, i_b] + _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) - diff = sign * vel - jac = diff @ n - jac_qvel += jac * dofs_state.vel[i_d, i_b] - constraint_state.jac[n_con, i_d, i_b] += jac + diff = sign * vel + jac = diff @ n + jac_qvel += jac * dofs_state.vel[i_d, i_b] + constraint_state.jac[n_con, i_d, i_b] += jac - if ti.static(static_rigid_sim_config.sparse_solve): - constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d - con_n_relevant_dofs += 1 + if ti.static(static_rigid_sim_config.sparse_solve): + constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d + con_n_relevant_dofs += 1 - link = links_info.parent_idx[link_maybe_batch] + link = links_info.parent_idx[link_maybe_batch] - if ti.static(static_rigid_sim_config.is_backward): - if i_parent == 4 and link > -1: - print( - "Warning: Number of parents is too large for backward mode in add_collision_constraints" - ) + if ti.static(static_rigid_sim_config.is_backward): + if i_parent == 4 and link > -1: + print( + "Warning: Number of parents is too large for backward mode in add_collision_constraints" + ) - if ti.static(static_rigid_sim_config.sparse_solve): - constraint_state.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs - imp, aref = gu.imp_aref( - contact_data_sol_params, -contact_data_penetration, jac_qvel, -contact_data_penetration - ) + if ti.static(static_rigid_sim_config.sparse_solve): + 
constraint_state.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs + imp, aref = gu.imp_aref( + contact_data_sol_params, -contact_data_penetration, jac_qvel, -contact_data_penetration + ) - diag_0 = invweight + contact_data_friction * contact_data_friction * invweight - diag_1 = diag_0 * 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp - diag = ti.max(diag_1, EPS) + diag_0 = invweight + contact_data_friction * contact_data_friction * invweight + diag_1 = diag_0 * 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp + diag = ti.max(diag_1, EPS) - constraint_state.diag[n_con, i_b] = diag - constraint_state.aref[n_con, i_b] = aref - constraint_state.efc_D[n_con, i_b] = 1 / diag + constraint_state.diag[n_con, i_b] = diag + constraint_state.aref[n_con, i_b] = aref + constraint_state.efc_D[n_con, i_b] = 1 / diag ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(dofs_state.ctrl_mode.shape[1]): @@ -858,16 +859,16 @@ def add_inequality_constraints( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, ) - if ti.static(static_rigid_sim_config.enable_joint_limit): - add_joint_limit_constraints( - links_info=links_info, - joints_info=joints_info, - dofs_info=dofs_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - constraint_state=constraint_state, - static_rigid_sim_config=static_rigid_sim_config, - ) + # if ti.static(static_rigid_sim_config.enable_joint_limit): + # add_joint_limit_constraints( + # links_info=links_info, + # joints_info=joints_info, + # dofs_info=dofs_info, + # dofs_state=dofs_state, + # rigid_global_info=rigid_global_info, + # constraint_state=constraint_state, + # static_rigid_sim_config=static_rigid_sim_config, + # ) @ti.func diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index a0a352eb6f..2f28b7637e 100644 --- 
a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -1332,13 +1332,13 @@ def substep_pre_coupling_grad(self, f): match errno: case 1: gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") - case 2: - qpos_diff = self._rigid_adjoint_cache_fw.qpos.to_numpy() - self._rigid_adjoint_cache_bw.qpos.to_numpy() - vel_diff = self._rigid_adjoint_cache_fw.dofs_vel.to_numpy() - self._rigid_adjoint_cache_bw.dofs_vel.to_numpy() - acc_diff = self._rigid_adjoint_cache_fw.dofs_acc.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc.to_numpy() - acc_smooth_diff = self._rigid_adjoint_cache_fw.dofs_acc_smooth.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc_smooth.to_numpy() - solver_qacc_ws_diff = self._rigid_adjoint_cache_fw.solver_qacc_ws.to_numpy() - self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy() - gs.raise_exception(f"The backward computation result does not match the forward computation result at step {self._sim.cur_step_global}") + # case 2: + # qpos_diff = self._rigid_adjoint_cache_fw.qpos.to_numpy() - self._rigid_adjoint_cache_bw.qpos.to_numpy() + # vel_diff = self._rigid_adjoint_cache_fw.dofs_vel.to_numpy() - self._rigid_adjoint_cache_bw.dofs_vel.to_numpy() + # acc_diff = self._rigid_adjoint_cache_fw.dofs_acc.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc.to_numpy() + # acc_smooth_diff = self._rigid_adjoint_cache_fw.dofs_acc_smooth.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc_smooth.to_numpy() + # solver_qacc_ws_diff = self._rigid_adjoint_cache_fw.solver_qacc_ws.to_numpy() - self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy() + # gs.raise_exception(f"The backward computation result does not match the forward computation result at step {self._sim.cur_step_global}") kernel_step_2.grad( dofs_state=self.dofs_state, @@ -1409,46 +1409,138 @@ def substep_pre_coupling_grad(self, f): static_rigid_sim_config=self._static_rigid_sim_config, ) - 
kernel_forward_dynamics_without_qacc.grad( + kernel_bias_force.grad( + dofs_state=self.dofs_state, + links_state=self.links_state, + links_info=self.links_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_update_force.grad( links_state=self.links_state, links_info=self.links_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_update_acc.grad( dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - joints_info=self.joints_info, + links_info=self.links_info, + links_state=self.links_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_torque_and_passive_force.grad( entities_state=self.entities_state, entities_info=self.entities_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, geoms_state=self.geoms_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, ) + kernel_factor_mass.grad( + implicit_damping=False, + entities_info=self.entities_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_compute_mass_matrix_ad.grad( + implicit_damping=self._static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast, + links_state=self.links_state, + links_info=self.links_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + + # kernel_forward_dynamics_without_qacc.grad( + # 
links_state=self.links_state, + # links_info=self.links_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # joints_info=self.joints_info, + # entities_state=self.entities_state, + # entities_info=self.entities_info, + # geoms_state=self.geoms_state, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # contact_island_state=self.constraint_solver.contact_island.contact_island_state, + # ) # If it was the very first substep, we need to backpropagate through the initial update of the cartesian space if self._enable_mujoco_compatibility or self._sim.cur_substep_global == 0: - kernel_forward_velocity.grad( - envs_idx=envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_info=self.joints_info, - dofs_state=self.dofs_state, + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_velocity_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_update_geoms.grad( + envs_idx, entities_info=self.entities_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + links_state=self.links_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, ) - kernel_update_cartesian_space.grad( + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_COM_links_ad_1.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + 
dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_COM_links_ad_0.grad( links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, joints_info=self.joints_info, dofs_state=self.dofs_state, dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - force_update_fixed_geoms=False, ) + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_kinematics_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) # Change back to forward mode self._static_rigid_sim_config.is_backward = False @@ -1714,6 +1806,34 @@ def load_ckpt(self, ckpt_name): for entity in self._entities: entity.load_ckpt(ckpt_name) + + def load_test(self): + if not self._enable_mujoco_compatibility: + envs_idx = self._scene._sanitize_envs_idx(None) + kernel_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) + 
kernel_forward_velocity( + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) @property def is_active(self): @@ -3537,6 +3657,27 @@ def func_vel_at_point(pos_world, link_idx, i_b, links_state: array_class.LinksSt vel_lin = links_state.cd_vel[link_idx, i_b] return vel_rot + vel_lin +@ti.kernel +def kernel_compute_mass_matrix_ad( + implicit_damping: ti.template(), + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_compute_mass_matrix( + implicit_damping=implicit_damping, + links_state=links_state, + links_info=links_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.func def func_compute_mass_matrix( @@ -3790,6 +3931,23 @@ def func_compute_mass_matrix( # qM += d qfrc_actuator / d qvel rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] +@ti.kernel +def kernel_factor_mass( + implicit_damping: ti.template(), + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_factor_mass( + implicit_damping=implicit_damping, + entities_info=entities_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.func def func_factor_mass( @@ -6574,6 +6732,33 @@ def func_clear_external_force( 
for I in ti.grouped(dofs_state.ctrl_force): dofs_state.ctrl_force[I] = ti.Vector.zero(gs.ti_float, 3) +@ti.kernel +def kernel_torque_and_passive_force( + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_info: array_class.JointsInfo, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + contact_island_state: array_class.ContactIslandState, +): + func_torque_and_passive_force( + entities_state=entities_state, + entities_info=entities_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + links_state=links_state, + links_info=links_info, + joints_info=joints_info, + geoms_state=geoms_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + contact_island_state=contact_island_state, + ) @ti.func def func_torque_and_passive_force( @@ -6893,6 +7078,21 @@ def func_update_acc( BW, ) +@ti.kernel +def kernel_update_force( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_update_force( + links_state=links_state, + links_info=links_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.func def func_update_force( @@ -7033,6 +7233,22 @@ def func_actuation(self): self.dofs_state.act_length[i_d, i_b] = 0.0 self.dofs_state.qf_actuator[i_d, i_b] = self.dofs_state.act_length[i_d, i_b] +@ti.kernel +def kernel_bias_force( + dofs_state: array_class.DofsState, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): 
+ func_bias_force( + dofs_state=dofs_state, + links_state=links_state, + links_info=links_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + @ti.func def func_bias_force(