diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py
index 49359ccf5e..15a69321b2 100644
--- a/genesis/engine/solvers/rigid/constraint/solver.py
+++ b/genesis/engine/solvers/rigid/constraint/solver.py
@@ -1,3 +1,4 @@
+import os
 from typing import TYPE_CHECKING
 
 import gstaichi as ti
@@ -20,6 +21,16 @@
 
 IS_OLD_TORCH = tuple(map(int, torch.__version__.split(".")[:2])) < (2, 8)
 
+# Parallel linesearch: evaluate K candidate alphas in parallel per env.
+# Opt-in: enabled only when the environment variable GS_SOLVER_LS_PARALLEL is exactly "1" at import time.
+USE_LS_PARALLEL = os.environ.get("GS_SOLVER_LS_PARALLEL", "0") == "1"
+# NOTE(review): not referenced anywhere in this patch and easily confused with the env var above; confirm and remove.
+GS_SOLVER_LS_PARALLEL = 1
+# Number of candidate step sizes per env; also used as the block_dim of the parallel kernels.
+LS_PARALLEL_K = 8
+# Smallest candidate step size; candidates are log-spaced from this value up to 1.0.
+LS_PARALLEL_MIN_STEP = 1e-6
+
 
 class ConstraintSolver:
     def __init__(self, rigid_solver: "RigidSolver"):
@@ -183,13 +194,18 @@ def resolve(self):
             self._solver._rigid_global_info,
             self._solver._static_rigid_sim_config,
         )
-        func_solve_body(
-            self._solver.entities_info,
-            self._solver.dofs_state,
-            self.constraint_state,
-            self._solver._rigid_global_info,
-            self._solver._static_rigid_sim_config,
-        )
+        # NOTE(review): confirm gs.backend can actually equal gs.gpu here; if it is resolved to a concrete
+        # backend at init time, this branch never runs and the env var has no effect.
+        if USE_LS_PARALLEL and gs.backend == gs.gpu:
+            self._resolve_body_parallel_ls()
+        else:
+            func_solve_body(
+                self._solver.entities_info,
+                self._solver.dofs_state,
+                self.constraint_state,
+                self._solver._rigid_global_info,
+                self._solver._static_rigid_sim_config,
+            )
 
         func_update_qacc(
             self._solver.dofs_state,
@@ -208,6 +224,16 @@ def resolve(self):
             self._solver._static_rigid_sim_config,
         )
 
+    def _resolve_body_parallel_ls(self):
+        """Run the solver iterations with the fused parallel-linesearch kernel instead of `func_solve_body`."""
+        kernel_parallel_ls_fused(
+            self._solver.entities_info,
+            self._solver.dofs_state,
+            self.constraint_state,
+            self._solver._rigid_global_info,
+            self._solver._static_rigid_sim_config,
+        )
+
     def noslip(self):
         constraint_noslip.kernel_build_efc_AR_b(
             self._solver.dofs_state,
@@ -2966,6 +2992,450 @@ def func_solve_body(
             constraint_state.improved[i_b] = False
 
 
+# =====================================================================================================================
+# ============================================ Parallel Linesearch =====================================================
+# =====================================================================================================================
+
+
+@ti.func
+def _log_scale(min_value: gs.ti_float, max_value: gs.ti_float, num_values: ti.i32, i: ti.i32) -> gs.ti_float:
+    """Return the i-th of `num_values` log-spaced values from `min_value` (i=0) to `max_value` (i=num_values-1)."""
+    # The max() guards against a division by zero when num_values == 1.
+    step = (ti.log(max_value) - ti.log(min_value)) / ti.max(1.0, gs.ti_float(num_values - 1))
+    return ti.exp(ti.log(min_value) + gs.ti_float(i) * step)
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_parallel_ls_init(
+    entities_info: array_class.EntitiesInfo,
+    dofs_state: array_class.DofsState,
+    constraint_state: array_class.ConstraintState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: ti.template(),
+):
+    """Compute mv, jv, quad_gauss, p0_cost for each env. One thread per env.
+
+    The baseline cost `p0_cost` is stored in `candidates[1, i_b]`; `improved[i_b]` is cleared when the norm of the
+    search direction is below EPS. NOTE(review): not launched anywhere in this patch.
+    """
+    _B = constraint_state.grad.shape[1]
+
+    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
+    for i_b in range(_B):
+        if constraint_state.n_constraints[i_b] > 0:
+            n_dofs = constraint_state.search.shape[0]
+
+            # Norm of the search direction: a (near-)zero direction leaves nothing to search along
+            snorm = gs.ti_float(0.0)
+            for jd in range(n_dofs):
+                snorm = snorm + constraint_state.search[jd, i_b] ** 2
+            snorm = ti.sqrt(snorm)
+
+            if snorm < rigid_global_info.EPS[None]:
+                constraint_state.candidates[0, i_b] = 0.0
+                constraint_state.candidates[1, i_b] = 0.0
+                constraint_state.improved[i_b] = False
+            else:
+                constraint_state.improved[i_b] = True
+                # Reuse existing fused init+p0 eval
+                _, p0_cost, _, _ = func_ls_init_and_eval_p0_opt(
+                    i_b,
+                    entities_info=entities_info,
+                    dofs_state=dofs_state,
+                    constraint_state=constraint_state,
+                    rigid_global_info=rigid_global_info,
+                    static_rigid_sim_config=static_rigid_sim_config,
+                )
+                constraint_state.candidates[1, i_b] = p0_cost
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_parallel_ls_eval(
+    constraint_state: array_class.ConstraintState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: ti.template(),
+):
+    """Evaluate K candidate alphas in parallel per env, pick the best via reduction.
+
+    The selected step is written to `candidates[0, i_b]`; it is 0.0 when no candidate costs strictly less than the
+    baseline stored in `candidates[1, i_b]`. NOTE(review): not launched anywhere in this patch.
+    """
+    _B = constraint_state.grad.shape[1]
+    _K = ti.static(LS_PARALLEL_K)
+    _MIN_STEP = ti.static(LS_PARALLEL_MIN_STEP)
+
+    ti.loop_config(block_dim=_K)
+    for i_ in range(_B * _K):
+        tid = i_ % _K
+        i_b = i_ // _K
+
+        # Shared memory for argmin reduction
+        sh_cost = ti.simt.block.SharedArray((_K,), gs.ti_float)
+        sh_idx = ti.simt.block.SharedArray((_K,), ti.i32)
+
+        if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
+            ne = constraint_state.n_constraints_equality[i_b]
+            nef = ne + constraint_state.n_constraints_frictionloss[i_b]
+            n_con = constraint_state.n_constraints[i_b]
+
+            # Generate log-spaced alpha: alpha[0]=MIN_STEP ... alpha[K-1]=1.0
+            alpha = _log_scale(_MIN_STEP, 1.0, _K, tid)
+
+            # Evaluate cost at this alpha
+            cost = (
+                alpha * alpha * constraint_state.quad_gauss[2, i_b]
+                + alpha * constraint_state.quad_gauss[1, i_b]
+                + constraint_state.quad_gauss[0, i_b]
+            )
+
+            # Equality constraints (always active) - use eq_sum precomputed during init
+            cost = (
+                cost
+                + alpha * alpha * constraint_state.eq_sum[2, i_b]
+                + alpha * constraint_state.eq_sum[1, i_b]
+                + constraint_state.eq_sum[0, i_b]
+            )
+
+            # Friction constraints
+            for i_c in range(ne, nef):
+                Jaref_c = constraint_state.Jaref[i_c, i_b]
+                jv_c = constraint_state.jv[i_c, i_b]
+                D = constraint_state.efc_D[i_c, i_b]
+                f = constraint_state.efc_frictionloss[i_c, i_b]
+                r = constraint_state.diag[i_c, i_b]
+                x = Jaref_c + alpha * jv_c
+                rf = r * f
+                linear_neg = x <= -rf
+                linear_pos = x >= rf
+                if linear_neg or linear_pos:
+                    cost = cost + linear_neg * f * (-0.5 * rf - Jaref_c - alpha * jv_c)
+                    cost = cost + linear_pos * f * (-0.5 * rf + Jaref_c + alpha * jv_c)
+                else:
+                    cost = cost + D * 0.5 * x * x
+
+            # Contact constraints (active if x < 0)
+            for i_c in range(nef, n_con):
+                Jaref_c = constraint_state.Jaref[i_c, i_b]
+                jv_c = constraint_state.jv[i_c, i_b]
+                D = constraint_state.efc_D[i_c, i_b]
+                x = Jaref_c + alpha * jv_c
+                if x < 0:
+                    cost = cost + D * 0.5 * x * x
+
+            sh_cost[tid] = cost
+            sh_idx[tid] = tid
+        else:
+            sh_cost[tid] = gs.ti_float(1e30)
+            sh_idx[tid] = tid
+
+        ti.simt.block.sync()
+
+        # Tree reduction for argmin
+        stride = _K // 2
+        while stride > 0:
+            if tid < stride:
+                if sh_cost[tid + stride] < sh_cost[tid]:
+                    sh_cost[tid] = sh_cost[tid + stride]
+                    sh_idx[tid] = sh_idx[tid + stride]
+            ti.simt.block.sync()
+            stride = stride // 2
+
+        # Thread 0: acceptance check and write result
+        if tid == 0:
+            if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
+                p0_cost = constraint_state.candidates[1, i_b]
+                best_tid = sh_idx[0]
+                best_cost = sh_cost[0]
+                best_alpha = _log_scale(_MIN_STEP, 1.0, _K, best_tid)
+                if best_cost < p0_cost:
+                    constraint_state.candidates[0, i_b] = best_alpha
+                else:
+                    constraint_state.candidates[0, i_b] = 0.0
+            else:
+                constraint_state.candidates[0, i_b] = 0.0
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_parallel_ls_apply(
+    entities_info: array_class.EntitiesInfo,
+    dofs_state: array_class.DofsState,
+    constraint_state: array_class.ConstraintState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: ti.template(),
+):
+    """Apply the best alpha found by kernel_parallel_ls_eval. One thread per env.
+
+    NOTE(review): not launched anywhere in this patch.
+    """
+    _B = constraint_state.grad.shape[1]
+    n_dofs = constraint_state.qacc.shape[0]
+
+    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
+    for i_b in range(_B):
+        if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
+            alpha = constraint_state.candidates[0, i_b]
+
+            if ti.abs(alpha) < rigid_global_info.EPS[None]:
+                constraint_state.improved[i_b] = False
+            else:
+                for i_d in range(n_dofs):
+                    constraint_state.qacc[i_d, i_b] = (
+                        constraint_state.qacc[i_d, i_b] + constraint_state.search[i_d, i_b] * alpha
+                    )
+                    constraint_state.Ma[i_d, i_b] = (
+                        constraint_state.Ma[i_d, i_b] + constraint_state.mv[i_d, i_b] * alpha
+                    )
+
+                for i_c in range(constraint_state.n_constraints[i_b]):
+                    constraint_state.Jaref[i_c, i_b] = (
+                        constraint_state.Jaref[i_c, i_b] + constraint_state.jv[i_c, i_b] * alpha
+                    )
+
+                if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG):
+                    for i_d in range(n_dofs):
+                        constraint_state.cg_prev_grad[i_d, i_b] = constraint_state.grad[i_d, i_b]
+                        constraint_state.cg_prev_Mgrad[i_d, i_b] = constraint_state.Mgrad[i_d, i_b]
+
+                func_update_constraint_batch(
+                    i_b,
+                    qacc=constraint_state.qacc,
+                    Ma=constraint_state.Ma,
+                    cost=constraint_state.cost,
+                    dofs_state=dofs_state,
+                    constraint_state=constraint_state,
+                    static_rigid_sim_config=static_rigid_sim_config,
+                )
+
+                if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
+                    is_degenerated = func_hessian_and_cholesky_factor_incremental_batch(
+                        i_b,
+                        constraint_state=constraint_state,
+                        rigid_global_info=rigid_global_info,
+                        static_rigid_sim_config=static_rigid_sim_config,
+                    )
+                    if is_degenerated:
+                        func_hessian_and_cholesky_factor_direct_batch(
+                            i_b,
+                            entities_info=entities_info,
+                            constraint_state=constraint_state,
+                            rigid_global_info=rigid_global_info,
+                            static_rigid_sim_config=static_rigid_sim_config,
+                        )
+
+                func_update_gradient_batch(
+                    i_b,
+                    dofs_state=dofs_state,
+                    entities_info=entities_info,
+                    rigid_global_info=rigid_global_info,
+                    constraint_state=constraint_state,
+                    static_rigid_sim_config=static_rigid_sim_config,
+                )
+
+                func_terminate_or_update_descent_batch(
+                    i_b,
+                    constraint_state=constraint_state,
+                    rigid_global_info=rigid_global_info,
+                    static_rigid_sim_config=static_rigid_sim_config,
+                )
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_parallel_ls_fused(
+    entities_info: array_class.EntitiesInfo,
+    dofs_state: array_class.DofsState,
+    constraint_state: array_class.ConstraintState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: ti.template(),
+):
+    """Fused parallel linesearch: init + eval + apply in a single kernel, no Python loop.
+
+    Each env is served by LS_PARALLEL_K consecutive loop indices (the loop's block_dim): index 0 of the group runs
+    the init and apply phases, and every index evaluates one candidate step before the shared-array argmin.
+    """
+    _B = constraint_state.grad.shape[1]
+    _K = ti.static(LS_PARALLEL_K)
+    _MIN_STEP = ti.static(LS_PARALLEL_MIN_STEP)
+
+    ti.loop_config(block_dim=_K)
+    for i_ in range(_B * _K):
+        tid = i_ % _K
+        i_b = i_ // _K
+
+        sh_cost = ti.simt.block.SharedArray((_K,), gs.ti_float)
+        sh_idx = ti.simt.block.SharedArray((_K,), ti.i32)
+
+        n_con = constraint_state.n_constraints[i_b]
+        if n_con > 0:
+            n_dofs = constraint_state.qacc.shape[0]
+
+            for _iter in range(rigid_global_info.iterations[None]):
+                # ── Phase 1: Init (thread 0 only) ──
+                if tid == 0:
+                    snorm = gs.ti_float(0.0)
+                    for jd in range(n_dofs):
+                        snorm = snorm + constraint_state.search[jd, i_b] ** 2
+                    snorm = ti.sqrt(snorm)
+
+                    if snorm < rigid_global_info.EPS[None]:
+                        constraint_state.improved[i_b] = False
+                    else:
+                        constraint_state.improved[i_b] = True
+                        _, p0_cost, _, _ = func_ls_init_and_eval_p0_opt(
+                            i_b,
+                            entities_info=entities_info,
+                            dofs_state=dofs_state,
+                            constraint_state=constraint_state,
+                            rigid_global_info=rigid_global_info,
+                            static_rigid_sim_config=static_rigid_sim_config,
+                        )
+                        constraint_state.candidates[1, i_b] = p0_cost
+
+                ti.simt.block.sync()
+
+                if not constraint_state.improved[i_b]:
+                    break
+
+                # ── Phase 2: Eval (all K threads in parallel) ──
+                ne = constraint_state.n_constraints_equality[i_b]
+                nef = ne + constraint_state.n_constraints_frictionloss[i_b]
+
+                alpha = _log_scale(_MIN_STEP, 1.0, _K, tid)
+
+                cost = (
+                    alpha * alpha * constraint_state.quad_gauss[2, i_b]
+                    + alpha * constraint_state.quad_gauss[1, i_b]
+                    + constraint_state.quad_gauss[0, i_b]
+                )
+
+                # Equality constraints — use precomputed eq_sum
+                cost = (
+                    cost
+                    + alpha * alpha * constraint_state.eq_sum[2, i_b]
+                    + alpha * constraint_state.eq_sum[1, i_b]
+                    + constraint_state.eq_sum[0, i_b]
+                )
+
+                # Friction constraints
+                for i_c in range(ne, nef):
+                    Jaref_c = constraint_state.Jaref[i_c, i_b]
+                    jv_c = constraint_state.jv[i_c, i_b]
+                    D = constraint_state.efc_D[i_c, i_b]
+                    f = constraint_state.efc_frictionloss[i_c, i_b]
+                    r = constraint_state.diag[i_c, i_b]
+                    x = Jaref_c + alpha * jv_c
+                    rf = r * f
+                    linear_neg = x <= -rf
+                    linear_pos = x >= rf
+                    if linear_neg or linear_pos:
+                        cost = cost + linear_neg * f * (-0.5 * rf - Jaref_c - alpha * jv_c)
+                        cost = cost + linear_pos * f * (-0.5 * rf + Jaref_c + alpha * jv_c)
+                    else:
+                        cost = cost + D * 0.5 * x * x
+
+                # Contact constraints
+                for i_c in range(nef, n_con):
+                    Jaref_c = constraint_state.Jaref[i_c, i_b]
+                    jv_c = constraint_state.jv[i_c, i_b]
+                    D = constraint_state.efc_D[i_c, i_b]
+                    x = Jaref_c + alpha * jv_c
+                    if x < 0:
+                        cost = cost + D * 0.5 * x * x
+
+                sh_cost[tid] = cost
+                sh_idx[tid] = tid
+                ti.simt.block.sync()
+
+                # Argmin tree reduction
+                stride = _K // 2
+                while stride > 0:
+                    if tid < stride:
+                        if sh_cost[tid + stride] < sh_cost[tid]:
+                            sh_cost[tid] = sh_cost[tid + stride]
+                            sh_idx[tid] = sh_idx[tid + stride]
+                    ti.simt.block.sync()
+                    stride = stride // 2
+
+                # ── Phase 3: Apply (thread 0 only) ──
+                if tid == 0:
+                    p0_cost = constraint_state.candidates[1, i_b]
+                    best_cost = sh_cost[0]
+                    best_tid = sh_idx[0]
+                    best_alpha = _log_scale(_MIN_STEP, 1.0, _K, best_tid)
+
+                    # Accept only a strict decrease. Written as a negated `<` (same test as kernel_parallel_ls_eval)
+                    # so that a NaN cost rejects the step instead of applying it.
+                    if not (best_cost < p0_cost):
+                        constraint_state.improved[i_b] = False
+                    else:
+                        for i_d in range(n_dofs):
+                            constraint_state.qacc[i_d, i_b] = (
+                                constraint_state.qacc[i_d, i_b] + constraint_state.search[i_d, i_b] * best_alpha
+                            )
+                            constraint_state.Ma[i_d, i_b] = (
+                                constraint_state.Ma[i_d, i_b] + constraint_state.mv[i_d, i_b] * best_alpha
+                            )
+
+                        for i_c in range(n_con):
+                            constraint_state.Jaref[i_c, i_b] = (
+                                constraint_state.Jaref[i_c, i_b] + constraint_state.jv[i_c, i_b] * best_alpha
+                            )
+
+                        if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG):
+                            for i_d in range(n_dofs):
+                                constraint_state.cg_prev_grad[i_d, i_b] = constraint_state.grad[i_d, i_b]
+                                constraint_state.cg_prev_Mgrad[i_d, i_b] = constraint_state.Mgrad[i_d, i_b]
+
+                        func_update_constraint_batch(
+                            i_b,
+                            qacc=constraint_state.qacc,
+                            Ma=constraint_state.Ma,
+                            cost=constraint_state.cost,
+                            dofs_state=dofs_state,
+                            constraint_state=constraint_state,
+                            static_rigid_sim_config=static_rigid_sim_config,
+                        )
+
+                        if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
+                            is_degenerated = func_hessian_and_cholesky_factor_incremental_batch(
+                                i_b,
+                                constraint_state=constraint_state,
+                                rigid_global_info=rigid_global_info,
+                                static_rigid_sim_config=static_rigid_sim_config,
+                            )
+                            if is_degenerated:
+                                func_hessian_and_cholesky_factor_direct_batch(
+                                    i_b,
+                                    entities_info=entities_info,
+                                    constraint_state=constraint_state,
+                                    rigid_global_info=rigid_global_info,
+                                    static_rigid_sim_config=static_rigid_sim_config,
+                                )
+
+                        func_update_gradient_batch(
+                            i_b,
+                            dofs_state=dofs_state,
+                            entities_info=entities_info,
+                            rigid_global_info=rigid_global_info,
+                            constraint_state=constraint_state,
+                            static_rigid_sim_config=static_rigid_sim_config,
+                        )
+
+                        func_terminate_or_update_descent_batch(
+                            i_b,
+                            constraint_state=constraint_state,
+                            rigid_global_info=rigid_global_info,
+                            static_rigid_sim_config=static_rigid_sim_config,
+                        )
+
+                ti.simt.block.sync()
+
+                if not constraint_state.improved[i_b]:
+                    break
+        else:
+            if tid == 0:
+                constraint_state.improved[i_b] = False
+
+
 # =====================================================================================================================
 # ==================================================== Finalization ===================================================
 # =====================================================================================================================