From 692172012126b134708764b3d3f402af4c03d26a Mon Sep 17 00:00:00 2001 From: Mingrui Date: Mon, 9 Mar 2026 15:44:15 +0000 Subject: [PATCH 1/3] fix reading field in python scope avoiding gpu-cpu sync --- genesis/engine/solvers/rigid/rigid_solver.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index cf518cb6ed..74a68fc28b 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -456,6 +456,9 @@ def _create_data_manager(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info + self._rigid_global_info._n_iterations = ( + self._options.iterations + ) # Python-native mirror to avoid CPU-GPU sync in Python-scope functions self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs From 667cd07061b133b64728ea4445fefc54424f445f Mon Sep 17 00:00:00 2001 From: Mingrui Date: Mon, 9 Mar 2026 16:29:05 +0000 Subject: [PATCH 2/3] Revert Python-scope n_iterations mirror (superseded approach) --- genesis/engine/solvers/rigid/rigid_solver.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index 74a68fc28b..cf518cb6ed 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -456,9 +456,6 @@ def _create_data_manager(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info - self._rigid_global_info._n_iterations = ( - self._options.iterations - ) # Python-native mirror to avoid CPU-GPU sync in Python-scope functions self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs From 8ae30ab221cdc651d0ad3561e7020723f2ac5b06 Mon Sep 17 00:00:00 2001 From: Mingrui Date: 
Mon, 9 Mar 2026 18:09:48 +0000 Subject: [PATCH 3/3] Decompose func_solve_init with perf_dispatch and 8 separate kernels Convert func_solve_init from a plain @qd.kernel to a @qd.perf_dispatch, and register func_solve_init_decomposed for CUDA backend. This breaks the monolithic init into 8 separate kernel launches: 1. _kernel_init_warmstart (warmstart selection, ndrange dofs) 2. _kernel_init_Ma (Ma = M @ qacc, ndrange dofs) 3. _kernel_init_Jaref (Jaref = -aref + J @ qacc, ndrange constraints) 4. _kernel_init_improved (set improved flags) 5. _kernel_init_update_constraint (wraps monolith for FP match) 6. Newton hessian (conditional, reuses existing kernel) 7. _kernel_init_update_gradient (wraps monolith tiled gradient) 8. _kernel_init_search (search = -Mgrad, ndrange dofs) Co-Authored-By: Claude Opus 4.6 --- .../engine/solvers/rigid/constraint/solver.py | 16 +- .../rigid/constraint/solver_breakdown.py | 179 ++++++++++++++++++ 2 files changed, 194 insertions(+), 1 deletion(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 618ad4e0a5..f7f6bcf824 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2819,7 +2819,9 @@ def initialize_Ma( # ======================================================= Core ======================================================== -@qd.kernel(fastcache=gs.use_fastcache) +@qd.perf_dispatch( + get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=0 +) def func_solve_init( dofs_info: array_class.DofsInfo, dofs_state: array_class.DofsState, @@ -2827,6 +2829,18 @@ def func_solve_init( constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), +) -> None: ... 
+ + +@func_solve_init.register(is_compatible=lambda *args, **kwargs: True) +@qd.kernel(fastcache=gs.use_fastcache) +def func_solve_init_monolith( + dofs_info: array_class.DofsInfo, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), ): _B = dofs_state.acc_smooth.shape[1] n_dofs = dofs_state.acc_smooth.shape[0] diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index c5d59369ad..411d15b0a5 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -129,6 +129,185 @@ def _kernel_update_search_direction( ) +# ================================================ Init kernels ================================================ + + +@ti.kernel(fastcache=gs.use_fastcache) +def _kernel_init_warmstart( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: ti.template(), +): + """Select qacc from warmstart or acc_smooth, parallelized over (dof, env).""" + n_dofs = dofs_state.acc_smooth.shape[0] + _B = dofs_state.acc_smooth.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]: + constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] + else: + constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + + +@ti.kernel(fastcache=gs.use_fastcache) +def _kernel_init_Ma( + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + """Compute Ma = M @ qacc, 
parallelized over (dof, env).""" + solver.initialize_Ma( + Ma=constraint_state.Ma, + qacc=constraint_state.qacc, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def _kernel_init_Jaref( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: ti.template(), +): + """Compute Jaref = -aref + J @ qacc, parallelized over (constraint, env).""" + len_constraints = constraint_state.Jaref.shape[0] + n_dofs = constraint_state.jac.shape[1] + _B = constraint_state.grad.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in ti.ndrange(len_constraints, _B): + if i_c < constraint_state.n_constraints[i_b]: + Jaref = -constraint_state.aref[i_c, i_b] + if ti.static(static_rigid_sim_config.sparse_solve): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]): + i_d = constraint_state.jac_relevant_dofs[i_c, i_d_, i_b] + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + else: + for i_d in range(n_dofs): + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + constraint_state.Jaref[i_c, i_b] = Jaref + + +@ti.kernel(fastcache=gs.use_fastcache) +def _kernel_init_improved( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: ti.template(), +): + """Set improved = (n_constraints > 0) for each env.""" + _B = constraint_state.grad.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + constraint_state.improved[i_b] = constraint_state.n_constraints[i_b] > 0 + + +@ti.kernel(fastcache=gs.use_fastcache) +def _kernel_init_search( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: ti.template(), +): + """Set search = -Mgrad, parallelized over (dof, env).""" + n_dofs = constraint_state.search.shape[0] 
+ _B = constraint_state.grad.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] + + +@ti.kernel(fastcache=gs.use_fastcache) +def _kernel_init_update_constraint( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: ti.template(), +): + """Init-only constraint update — wraps monolith's func_update_constraint for exact FP match.""" + solver.func_update_constraint( + qacc=constraint_state.qacc, + Ma=constraint_state.Ma, + cost=constraint_state.cost, + dofs_state=dofs_state, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def _kernel_init_update_gradient( + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + """Init-only gradient update — wraps monolith's func_update_gradient (dispatches to tiled on GPU).""" + solver.func_update_gradient( + dofs_state=dofs_state, + entities_info=entities_info, + constraint_state=constraint_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) +def func_solve_init_decomposed( + dofs_info, + dofs_state, + entities_info, + constraint_state, + rigid_global_info, + static_rigid_sim_config, +): + """ + Decomposed version of func_solve_init for CUDA backend (non-mujoco path). + + Breaks the monolithic init kernel into separate kernel launches: + 1. Warmstart selection (ndrange over dofs) + 2. Ma = M @ qacc (ndrange over dofs with entity lookup) + 3. Jaref = -aref + J @ qacc (ndrange over constraints — main optimization) + 4. 
Set improved flags + 5. Update constraint (wraps monolith's func_update_constraint for exact FP match) + 6. Newton hessian (Newton only — reuse existing kernel) + 7. Update gradient (wraps monolith's func_update_gradient — uses tiled on GPU) + 8. search = -Mgrad (ndrange over dofs) + """ + # 1. Warmstart selection + _kernel_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config) + + # 2. Ma = M @ qacc + _kernel_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) + + # 3. Jaref = -aref + J @ qacc (parallelized over constraints) + _kernel_init_Jaref(constraint_state, static_rigid_sim_config) + + # 4. Set improved flags (needed by decomposed update_constraint kernels) + _kernel_init_improved(constraint_state, static_rigid_sim_config) + + # 5. Update constraint (init-specific: wraps monolith's func_update_constraint for exact FP match) + _kernel_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config) + + # 6. Newton hessian (Newton only) + if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton: + _kernel_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config) + + # 7. Update gradient (init-specific: wraps monolith's func_update_gradient, dispatches to tiled on GPU) + _kernel_init_update_gradient( + entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config + ) + + # 8. search = -Mgrad + _kernel_init_search(constraint_state, static_rigid_sim_config) + + +# ============================================== Solve body kernels ================================================ + + @solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) def func_solve_decomposed( entities_info,