From 1b8c30c38a8c951f50ded6eba9ce814f8be51cc1 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 23 Mar 2026 22:15:52 +0100 Subject: [PATCH] Fix GPU synchronization issue on Apple Metal. --- genesis/engine/solvers/rigid/collider/collider.py | 4 ++++ genesis/engine/solvers/rigid/constraint/solver.py | 4 ++++ genesis/engine/solvers/rigid/rigid_solver.py | 12 ++++++++++++ 3 files changed, 20 insertions(+) diff --git a/genesis/engine/solvers/rigid/collider/collider.py b/genesis/engine/solvers/rigid/collider/collider.py index 36b3cf12ed..270307adbc 100644 --- a/genesis/engine/solvers/rigid/collider/collider.py +++ b/genesis/engine/solvers/rigid/collider/collider.py @@ -550,6 +550,8 @@ def reset(self, envs_idx=None, *, cache_only: bool = True) -> None: normal.zero_() else: normal[:, envs_idx] = 0.0 + if gs.backend == gs.metal: + torch.mps.synchronize() return envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx) @@ -603,6 +605,8 @@ def clear(self, envs_idx=None): pos[:, envs_idx] = 0.0 normal[:, envs_idx] = 0.0 force[:, envs_idx] = 0.0 + if gs.backend == gs.metal: + torch.mps.synchronize() return if not isinstance(envs_idx, torch.Tensor): diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 357f084c90..c911b7da64 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -127,6 +127,8 @@ def reset(self, envs_idx=None): else: is_warmstart[envs_idx] = False qacc_ws[:, envs_idx] = 0.0 + if gs.backend == gs.metal: + torch.mps.synchronize() return envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx) @@ -159,6 +161,8 @@ def clear(self, envs_idx=None): assign_indexed_tensor(n_constraints_equality, env_mask, 0) assign_indexed_tensor(n_constraints_frictionloss, env_mask, 0) assign_indexed_tensor(qd_n_equalities, env_mask, n_eq) + if gs.backend == gs.metal: + torch.mps.synchronize() return if not isinstance(envs_idx, torch.Tensor): diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index 8b3ede2ba0..2719034023 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -1946,6 +1946,8 @@ def _set_dofs_info(self, tensor_list, dofs_idx, name, envs_idx=None): num_values = len(tensor_list) for j, mask_j in enumerate(((*mask, ..., j) for j in range(num_values)) if num_values > 1 else (mask,)): assign_indexed_tensor(data, mask_j, tensor_list[j]) + if gs.backend == gs.metal: + torch.mps.synchronize() return tensor_list = list(tensor_list) @@ -2037,6 +2039,8 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None): if gs.use_zerocopy: errno = qd_to_torch(self._errno, copy=False) errno[envs_idx] = 0 + if gs.backend == gs.metal: + torch.mps.synchronize() else: kernel_set_zero(envs_idx, self._errno) @@ -2064,6 +2068,8 @@ def control_dofs_force(self, force, dofs_idx=None, envs_idx=None): ctrl_mode[mask] = gs.CTRL_MODE.FORCE ctrl_force = qd_to_torch(self.dofs_state.ctrl_force, transpose=True, copy=False) assign_indexed_tensor(ctrl_force, mask, force) + if gs.backend == gs.metal: + torch.mps.synchronize() return force, dofs_idx, envs_idx = self._sanitize_io_variables( @@ -2083,6 +2089,8 @@ def control_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None): ctrl_pos[mask] = 0.0 ctrl_vel = qd_to_torch(self.dofs_state.ctrl_vel, transpose=True, copy=False) assign_indexed_tensor(ctrl_vel, mask, velocity) + if gs.backend == gs.metal: + torch.mps.synchronize() return velocity, dofs_idx, envs_idx = self._sanitize_io_variables( @@ -2102,6 +2110,8 @@ def control_dofs_position(self, position, dofs_idx=None, envs_idx=None): assign_indexed_tensor(ctrl_pos, mask, position) ctrl_vel = qd_to_torch(self.dofs_state.ctrl_vel, transpose=True, copy=False) ctrl_vel[mask] = 0.0 + if gs.backend == gs.metal: + torch.mps.synchronize() return position, dofs_idx, envs_idx = self._sanitize_io_variables( @@ -2121,6 +2131,8 @@ def control_dofs_position_velocity(self, position, velocity, dofs_idx=None, envs assign_indexed_tensor(ctrl_pos, mask, position) ctrl_vel = qd_to_torch(self.dofs_state.ctrl_vel, transpose=True, copy=False) assign_indexed_tensor(ctrl_vel, mask, velocity) + if gs.backend == gs.metal: + torch.mps.synchronize() return position, dofs_idx, _ = self._sanitize_io_variables(