diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py
index 5382f5fb46..298863f2c9 100644
--- a/genesis/engine/solvers/rigid/constraint/solver.py
+++ b/genesis/engine/solvers/rigid/constraint/solver.py
@@ -2834,7 +2834,9 @@ def initialize_Ma(
 # ======================================================= Core ========================================================
 
 
-@qd.kernel(fastcache=gs.use_fastcache)
+@qd.perf_dispatch(
+    get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=0
+)
 def func_solve_init(
     dofs_info: array_class.DofsInfo,
     dofs_state: array_class.DofsState,
@@ -2842,6 +2844,18 @@ def func_solve_init(
     constraint_state: array_class.ConstraintState,
     rigid_global_info: array_class.RigidGlobalInfo,
     static_rigid_sim_config: qd.template(),
+) -> None: ...
+
+
+@func_solve_init.register(is_compatible=lambda *args, **kwargs: True)
+@qd.kernel(fastcache=gs.use_fastcache)
+def func_solve_init_monolith(
+    dofs_info: array_class.DofsInfo,
+    dofs_state: array_class.DofsState,
+    entities_info: array_class.EntitiesInfo,
+    constraint_state: array_class.ConstraintState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
 ):
     _B = dofs_state.acc_smooth.shape[1]
     n_dofs = dofs_state.acc_smooth.shape[0]
diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py
index e5cb7f6555..6d9411f12b 100644
--- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py
+++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py
@@ -215,6 +215,185 @@ def _kernel_update_search_direction(
     )
 
 
+# ================================================ Init kernels ================================================
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def _kernel_init_warmstart(
+    dofs_state: array_class.DofsState,
+    constraint_state: array_class.ConstraintState,
+    static_rigid_sim_config: ti.template(),
+):
+    """Pick qacc per (dof, env): qacc_ws if the env has constraints and is_warmstart is set, else acc_smooth."""
+    n_dofs = dofs_state.acc_smooth.shape[0]
+    _B = dofs_state.acc_smooth.shape[1]
+
+    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+    for i_d, i_b in ti.ndrange(n_dofs, _B):
+        if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]:
+            constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b]
+        else:
+            constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b]
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def _kernel_init_Ma(
+    dofs_info: array_class.DofsInfo,
+    entities_info: array_class.EntitiesInfo,
+    constraint_state: array_class.ConstraintState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: ti.template(),
+):
+    """Fill Ma from qacc via solver.initialize_Ma (Ma = M @ qacc per the original note; helper not shown here)."""
+    solver.initialize_Ma(
+        Ma=constraint_state.Ma,
+        qacc=constraint_state.qacc,
+        dofs_info=dofs_info,
+        entities_info=entities_info,
+        rigid_global_info=rigid_global_info,
+        static_rigid_sim_config=static_rigid_sim_config,
+    )
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def _kernel_init_Jaref(
+    constraint_state: array_class.ConstraintState,
+    static_rigid_sim_config: ti.template(),
+):
+    """Compute Jaref = -aref + J @ qacc for each active constraint (i_c < n_constraints[i_b]), per (constraint, env)."""
+    len_constraints = constraint_state.Jaref.shape[0]
+    n_dofs = constraint_state.jac.shape[1]
+    _B = constraint_state.grad.shape[1]
+
+    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+    for i_c, i_b in ti.ndrange(len_constraints, _B):
+        if i_c < constraint_state.n_constraints[i_b]:
+            Jaref = -constraint_state.aref[i_c, i_b]
+            if ti.static(static_rigid_sim_config.sparse_solve):
+                for i_d_ in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]):
+                    i_d = constraint_state.jac_relevant_dofs[i_c, i_d_, i_b]
+                    Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b]
+            else:
+                for i_d in range(n_dofs):
+                    Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b]
+            constraint_state.Jaref[i_c, i_b] = Jaref
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def _kernel_init_improved(
+    constraint_state: array_class.ConstraintState,
+    static_rigid_sim_config: ti.template(),
+):
+    """Set improved[i_b] = (n_constraints[i_b] > 0) for each env."""
+    _B = constraint_state.grad.shape[1]
+
+    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+    for i_b in range(_B):
+        constraint_state.improved[i_b] = constraint_state.n_constraints[i_b] > 0
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def _kernel_init_search(
+    constraint_state: array_class.ConstraintState,
+    static_rigid_sim_config: ti.template(),
+):
+    """Set search = -Mgrad for every (dof, env) pair."""
+    n_dofs = constraint_state.search.shape[0]
+    _B = constraint_state.grad.shape[1]
+
+    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+    for i_d, i_b in ti.ndrange(n_dofs, _B):
+        constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b]
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def _kernel_init_update_constraint(
+    dofs_state: array_class.DofsState,
+    constraint_state: array_class.ConstraintState,
+    static_rigid_sim_config: ti.template(),
+):
+    """Init-only constraint update — calls solver.func_update_constraint on qacc/Ma/cost (intended to match monolith FP)."""
+    solver.func_update_constraint(
+        qacc=constraint_state.qacc,
+        Ma=constraint_state.Ma,
+        cost=constraint_state.cost,
+        dofs_state=dofs_state,
+        constraint_state=constraint_state,
+        static_rigid_sim_config=static_rigid_sim_config,
+    )
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def _kernel_init_update_gradient(
+    entities_info: array_class.EntitiesInfo,
+    dofs_state: array_class.DofsState,
+    constraint_state: array_class.ConstraintState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: ti.template(),
+):
+    """Init-only gradient update — calls solver.func_update_gradient (said to use the tiled path on GPU; TODO confirm)."""
+    solver.func_update_gradient(
+        dofs_state=dofs_state,
+        entities_info=entities_info,
+        constraint_state=constraint_state,
+        rigid_global_info=rigid_global_info,
+        static_rigid_sim_config=static_rigid_sim_config,
+    )
+
+
+@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda})
+def func_solve_init_decomposed(
+    dofs_info,
+    dofs_state,
+    entities_info,
+    constraint_state,
+    rigid_global_info,
+    static_rigid_sim_config,
+):
+    """
+    Decomposed version of func_solve_init, compatible only when gs.backend is gs.cuda (non-mujoco path — TODO confirm).
+
+    Runs the init phase as separate kernel launches, in this order, instead of one monolithic kernel:
+    1. Warmstart selection (ndrange over dofs and envs)
+    2. Ma = M @ qacc (delegates to solver.initialize_Ma)
+    3. Jaref = -aref + J @ qacc (ndrange over constraints and envs — main optimization)
+    4. Set improved flags (n_constraints > 0 per env)
+    5. Update constraint (wraps solver.func_update_constraint, intended for exact FP match with the monolith)
+    6. Newton hessian (only when solver_type is Newton; reuses _kernel_newton_only_nt_hessian)
+    7. Update gradient (wraps solver.func_update_gradient)
+    8. search = -Mgrad (ndrange over dofs and envs)
+    """
+    # 1. Warmstart selection: writes constraint_state.qacc
+    _kernel_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config)
+
+    # 2. Ma = M @ qacc (reads the qacc written in step 1)
+    _kernel_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config)
+
+    # 3. Jaref = -aref + J @ qacc (one loop iteration per (constraint, env) pair)
+    _kernel_init_Jaref(constraint_state, static_rigid_sim_config)
+
+    # 4. Set improved flags (original note: needed by decomposed update_constraint kernels — confirm)
+    _kernel_init_improved(constraint_state, static_rigid_sim_config)
+
+    # 5. Update constraint (init-specific: wraps solver.func_update_constraint, intended for exact FP match)
+    _kernel_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config)
+
+    # 6. Newton hessian (Newton only); Python-level branch, so other solver types skip this launch
+    if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton:
+        _kernel_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config)
+
+    # 7. Update gradient (init-specific: wraps solver.func_update_gradient; tiled dispatch on GPU per original note)
+    _kernel_init_update_gradient(
+        entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config
+    )
+
+    # 8. search = -Mgrad (reads the Mgrad left in constraint_state by the earlier steps)
+    _kernel_init_search(constraint_state, static_rigid_sim_config)
+
+
+# ============================================== Solve body kernels ================================================
+
+
 @solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda})
 def func_solve_decomposed(
     entities_info,