From 5a211239ac083971886d8044f9c75ed040b90636 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Sun, 22 Feb 2026 01:32:16 +0000 Subject: [PATCH 01/11] add solver optimizations --- .../engine/solvers/rigid/constraint/solver.py | 25 +- .../rigid/constraint/solver_breakdown.py | 669 +++++++++++++++++- 2 files changed, 670 insertions(+), 24 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index e55196fd6f..f1ca0b23e8 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -90,7 +90,6 @@ def __init__(self, rigid_solver: "RigidSolver"): self.mv = cs.mv self.jv = cs.jv self.quad_gauss = cs.quad_gauss - self.candidates = cs.candidates self.ls_it = cs.ls_it self.ls_result = cs.ls_result @@ -188,6 +187,7 @@ def resolve(self): func_solve_body( self._solver.entities_info, + self._solver.dofs_info, self._solver.dofs_state, self.constraint_state, self._solver._rigid_global_info, @@ -2262,6 +2262,12 @@ def update_bracket_no_eval_local( return flag, p_alpha, p_cost, p_grad, p_hess, p_next_alpha +@qd.func +def _log_scale(min_value: gs.qd_float, max_value: gs.qd_float, num_values: qd.i32, i: qd.i32) -> gs.qd_float: + step = (qd.log(max_value) - qd.log(min_value)) / qd.max(1.0, gs.qd_float(num_values - 1)) + return qd.exp(qd.log(min_value) + gs.qd_float(i) * step) + + @qd.func def func_linesearch_and_apply_alpha( i_b, @@ -2801,7 +2807,9 @@ def initialize_Ma( # ======================================================= Core ======================================================== -@qd.kernel(fastcache=gs.use_fastcache) +@qd.perf_dispatch( + get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=1.0 +) def func_solve_init( dofs_info: array_class.DofsInfo, dofs_state: array_class.DofsState, @@ -2809,6 +2817,18 @@ def func_solve_init( constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), +) -> None: ... + + +@func_solve_init.register(is_compatible=lambda *args, **kwargs: True) +@qd.kernel(fastcache=gs.use_fastcache) +def func_solve_init_monolith( + dofs_info: array_class.DofsInfo, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), ): _B = dofs_state.acc_smooth.shape[1] n_dofs = dofs_state.acc_smooth.shape[0] @@ -3013,6 +3033,7 @@ def func_solve_iter( ) def func_solve_body( entities_info: array_class.EntitiesInfo, + dofs_info: array_class.DofsInfo, dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index 1f2529248e..cd4b21fa17 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -1,20 +1,25 @@ -import quadrants as ti +import quadrants as qd import genesis as gs import genesis.utils.array_class as array_class from genesis.engine.solvers.rigid.constraint import solver +LS_PARALLEL_K = 8 +LS_PARALLEL_MIN_STEP = 1e-6 +_P0_BLOCK = 32 +_JV_BLOCK = 32 -@ti.kernel(fastcache=gs.use_fastcache) + +@qd.kernel(fastcache=gs.use_fastcache) def _kernel_linesearch( entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: ti.template(), + static_rigid_sim_config: qd.template(), ): _B = constraint_state.grad.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) for i_b in range(_B): if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: solver.func_linesearch_and_apply_alpha( @@ -29,27 +34,364 @@ def _kernel_linesearch( constraint_state.improved[i_b] = False -@ti.kernel(fastcache=gs.use_fastcache) +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_parallel_linesearch_mv( + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Compute mv = M @ search, parallelized over (dof, env). + + Uses per-dof entity lookup to find the entity block boundaries, giving n_dofs * B + threads (each computing a single ~6-element dot product) instead of n_entities * B + threads (each computing the full block matvec). + """ + n_dofs = constraint_state.search.shape[0] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d1, i_b in qd.ndrange(n_dofs, _B): + if constraint_state.n_constraints[i_b] > 0: + I_d1 = [i_d1, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d1 + i_e = dofs_info.entity_idx[I_d1] + mv = gs.qd_float(0.0) + for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): + mv = mv + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * constraint_state.search[i_d2, i_b] + constraint_state.mv[i_d1, i_b] = mv + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_parallel_linesearch_jv( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Compute jv = J @ search, parallelized over (constraint, env).""" + n_dofs = constraint_state.search.shape[0] + len_constraints = constraint_state.jac.shape[0] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in qd.ndrange(len_constraints, _B): + if i_c < constraint_state.n_constraints[i_b]: + jv = gs.qd_float(0.0) + if qd.static(static_rigid_sim_config.sparse_solve): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]): + i_d = constraint_state.jac_relevant_dofs[i_c, i_d_, i_b] + jv = jv + constraint_state.jac[i_c, i_d, i_b] * constraint_state.search[i_d, i_b] + else: + for i_d in range(n_dofs): + jv = jv + constraint_state.jac[i_c, i_d, i_b] * constraint_state.search[i_d, i_b] + constraint_state.jv[i_c, i_b] = jv + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_parallel_linesearch_p0( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Snorm check, quad_gauss, eq_sum, and p0_cost. T threads per env with shared memory reductions. + + Phase 1: Fused snorm + quad_gauss parallel reduction over n_dofs (Options A+B). + Phase 2: Parallel reduction over n_constraints for eq_sum and p0_cost. + """ + _B = constraint_state.grad.shape[1] + _T = qd.static(_P0_BLOCK) + + qd.loop_config(block_dim=_T) + for i_ in range(_B * _T): + tid = i_ % _T + i_b = i_ // _T + + # 4 shared arrays for parallel reductions (reused across phases) + sh_a = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_b = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_c = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_d = qd.simt.block.SharedArray((_T,), gs.qd_float) + + if constraint_state.n_constraints[i_b] > 0: + n_dofs = constraint_state.search.shape[0] + + # === Phase 1: Fused snorm + quad_gauss, parallel over n_dofs === + local_snorm_sq = gs.qd_float(0.0) + local_qg1 = gs.qd_float(0.0) + local_qg2 = gs.qd_float(0.0) + + i_d = tid + while i_d < n_dofs: + s = constraint_state.search[i_d, i_b] + local_snorm_sq += s * s + local_qg1 += s * constraint_state.Ma[i_d, i_b] - s * dofs_state.force[i_d, i_b] + local_qg2 += 0.5 * s * constraint_state.mv[i_d, i_b] + i_d += _T + + sh_a[tid] = local_snorm_sq + sh_b[tid] = local_qg1 + sh_c[tid] = local_qg2 + + qd.simt.block.sync() + + # Tree reduction for 3 accumulators + stride = _T // 2 + while stride > 0: + if tid < stride: + sh_a[tid] += sh_a[tid + stride] + sh_b[tid] += sh_b[tid + stride] + sh_c[tid] += sh_c[tid + stride] + qd.simt.block.sync() + stride //= 2 + + # All threads read the reduced snorm + snorm = qd.sqrt(sh_a[0]) + + if snorm < rigid_global_info.EPS[None]: + # Converged — only thread 0 writes + if tid == 0: + constraint_state.candidates[0, i_b] = 0.0 + constraint_state.candidates[1, i_b] = 0.0 + constraint_state.improved[i_b] = False + else: + # Thread 0 writes quad_gauss to global memory + if tid == 0: + constraint_state.improved[i_b] = True + constraint_state.quad_gauss[0, i_b] = constraint_state.gauss[i_b] + constraint_state.quad_gauss[1, i_b] = sh_b[0] + constraint_state.quad_gauss[2, i_b] = sh_c[0] + + # === Phase 2: Constraint cost, parallel over n_constraints === + ne = constraint_state.n_constraints_equality[i_b] + nef = ne + constraint_state.n_constraints_frictionloss[i_b] + n_con = constraint_state.n_constraints[i_b] + + local_eq0 = gs.qd_float(0.0) + local_eq1 = gs.qd_float(0.0) + local_eq2 = gs.qd_float(0.0) + local_tmp0 = gs.qd_float(0.0) + + i_c = tid + while i_c < n_con: + Jaref_c = constraint_state.Jaref[i_c, i_b] + D = constraint_state.efc_D[i_c, i_b] + qf_0 = D * (0.5 * Jaref_c * Jaref_c) + + if i_c < ne: + # Equality: always active, need jv for eq_sum + jv_c = constraint_state.jv[i_c, i_b] + qf_1 = D * (jv_c * Jaref_c) + qf_2 = D * (0.5 * jv_c * jv_c) + local_eq0 += qf_0 + local_eq1 += qf_1 + local_eq2 += qf_2 + local_tmp0 += qf_0 + elif i_c < nef: + # Friction: only qf_0 needed (qf_1/qf_2 not stored) + f = constraint_state.efc_frictionloss[i_c, i_b] + r = constraint_state.diag[i_c, i_b] + rf = r * f + linear_neg = Jaref_c <= -rf + linear_pos = Jaref_c >= rf + if linear_neg or linear_pos: + qf_0 = linear_neg * f * (-0.5 * rf - Jaref_c) + linear_pos * f * (-0.5 * rf + Jaref_c) + local_tmp0 += qf_0 + else: + # Contact: active if Jaref < 0 + active = Jaref_c < 0 + local_tmp0 += qf_0 * active + + i_c += _T + + # Reuse shared arrays for Phase 2 reduction + sh_a[tid] = local_eq0 + sh_b[tid] = local_eq1 + sh_c[tid] = local_eq2 + sh_d[tid] = local_tmp0 + + qd.simt.block.sync() + + # Tree reduction for 4 accumulators + stride = _T // 2 + while stride > 0: + if tid < stride: + sh_a[tid] += sh_a[tid + stride] + sh_b[tid] += sh_b[tid + stride] + sh_c[tid] += sh_c[tid + stride] + sh_d[tid] += sh_d[tid + stride] + qd.simt.block.sync() + stride //= 2 + + if tid == 0: + constraint_state.eq_sum[0, i_b] = sh_a[0] + constraint_state.eq_sum[1, i_b] = sh_b[0] + constraint_state.eq_sum[2, i_b] = sh_c[0] + constraint_state.ls_it[i_b] = 1 + constraint_state.candidates[1, i_b] = constraint_state.gauss[i_b] + sh_d[0] + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_parallel_linesearch_eval( + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Evaluate K candidate alphas in parallel per env, pick the best via reduction.""" + _B = constraint_state.grad.shape[1] + _K = qd.static(LS_PARALLEL_K) + _MIN_STEP = qd.static(LS_PARALLEL_MIN_STEP) + + qd.loop_config(block_dim=_K) + for i_ in range(_B * _K): + tid = i_ % _K + i_b = i_ // _K + + # Shared memory for argmin reduction + sh_cost = qd.simt.block.SharedArray((_K,), gs.qd_float) + sh_idx = qd.simt.block.SharedArray((_K,), qd.i32) + + if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + ne = constraint_state.n_constraints_equality[i_b] + nef = ne + constraint_state.n_constraints_frictionloss[i_b] + n_con = constraint_state.n_constraints[i_b] + + # Generate log-spaced alpha: alpha[0]=MIN_STEP ... alpha[K-1]=1.0 + alpha = solver._log_scale(_MIN_STEP, 1.0, _K, tid) + + # Evaluate cost at this alpha + cost = ( + alpha * alpha * constraint_state.quad_gauss[2, i_b] + + alpha * constraint_state.quad_gauss[1, i_b] + + constraint_state.quad_gauss[0, i_b] + ) + + # Equality constraints (always active) - use eq_sum precomputed during init + cost = ( + cost + + alpha * alpha * constraint_state.eq_sum[2, i_b] + + alpha * constraint_state.eq_sum[1, i_b] + + constraint_state.eq_sum[0, i_b] + ) + + # Friction constraints + for i_c in range(ne, nef): + Jaref_c = constraint_state.Jaref[i_c, i_b] + jv_c = constraint_state.jv[i_c, i_b] + D = constraint_state.efc_D[i_c, i_b] + f = constraint_state.efc_frictionloss[i_c, i_b] + r = constraint_state.diag[i_c, i_b] + x = Jaref_c + alpha * jv_c + rf = r * f + linear_neg = x <= -rf + linear_pos = x >= rf + if linear_neg or linear_pos: + cost = cost + linear_neg * f * (-0.5 * rf - Jaref_c - alpha * jv_c) + cost = cost + linear_pos * f * (-0.5 * rf + Jaref_c + alpha * jv_c) + else: + cost = cost + D * 0.5 * x * x + + # Contact constraints (active if x < 0) + for i_c in range(nef, n_con): + Jaref_c = constraint_state.Jaref[i_c, i_b] + jv_c = constraint_state.jv[i_c, i_b] + D = constraint_state.efc_D[i_c, i_b] + x = Jaref_c + alpha * jv_c + if x < 0: + cost += D * 0.5 * x * x + + sh_cost[tid] = cost + sh_idx[tid] = tid + else: + sh_cost[tid] = gs.qd_float(1e30) + sh_idx[tid] = tid + + qd.simt.block.sync() + + # Tree reduction for argmin + stride = _K // 2 + while stride > 0: + if tid < stride: + if sh_cost[tid + stride] < sh_cost[tid]: + sh_cost[tid] = sh_cost[tid + stride] + sh_idx[tid] = sh_idx[tid + stride] + qd.simt.block.sync() + stride = stride // 2 + + # Thread 0: acceptance check and write result + if tid == 0: + if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + p0_cost = constraint_state.candidates[1, i_b] + best_tid = sh_idx[0] + best_cost = sh_cost[0] + best_alpha = solver._log_scale(_MIN_STEP, 1.0, _K, best_tid) + if best_cost < p0_cost: + constraint_state.candidates[0, i_b] = best_alpha + else: + constraint_state.candidates[0, i_b] = 0.0 + else: + constraint_state.candidates[0, i_b] = 0.0 + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_parallel_linesearch_apply_alpha_dofs( + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Apply best alpha to qacc and Ma, parallelized over (dof, env).""" + n_dofs = constraint_state.qacc.shape[0] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + alpha = constraint_state.candidates[0, i_b] + if qd.abs(alpha) < rigid_global_info.EPS[None]: + if i_d == 0: + constraint_state.improved[i_b] = False + else: + constraint_state.qacc[i_d, i_b] += constraint_state.search[i_d, i_b] * alpha + constraint_state.Ma[i_d, i_b] += constraint_state.mv[i_d, i_b] * alpha + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_parallel_linesearch_apply_alpha_constraints( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Apply best alpha to Jaref, parallelized over (constraint, env).""" + len_constraints = constraint_state.Jaref.shape[0] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in qd.ndrange(len_constraints, _B): + if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]: + alpha = constraint_state.candidates[0, i_b] + constraint_state.Jaref[i_c, i_b] += constraint_state.jv[i_c, i_b] * alpha + + +@qd.kernel(fastcache=gs.use_fastcache) def _kernel_cg_only_save_prev_grad( constraint_state: array_class.ConstraintState, - static_rigid_sim_config: ti.template(), + static_rigid_sim_config: qd.template(), ): """Save prev_grad and prev_Mgrad (CG only)""" _B = constraint_state.grad.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) for i_b in range(_B): if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: solver.func_save_prev_grad(i_b, constraint_state=constraint_state) -@ti.kernel(fastcache=gs.use_fastcache) +@qd.kernel(fastcache=gs.use_fastcache) def _kernel_update_constraint( dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, - static_rigid_sim_config: ti.template(), + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), ): _B = constraint_state.grad.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) for i_b in range(_B): if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: solver.func_update_constraint_batch( @@ -63,15 +405,122 @@ def _kernel_update_constraint( ) -@ti.kernel(fastcache=gs.use_fastcache) +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_update_constraint_forces( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Compute active flags and efc_force, parallelized over (constraint, env).""" + len_constraints = constraint_state.active.shape[0] + _B = constraint_state.grad.shape[1] + + for i_c, i_b in qd.ndrange(len_constraints, _B): + if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]: + ne = constraint_state.n_constraints_equality[i_b] + nef = ne + constraint_state.n_constraints_frictionloss[i_b] + + if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): + constraint_state.prev_active[i_c, i_b] = constraint_state.active[i_c, i_b] + + constraint_state.active[i_c, i_b] = True + floss_force = gs.qd_float(0.0) + + if ne <= i_c and i_c < nef: + f = constraint_state.efc_frictionloss[i_c, i_b] + r = constraint_state.diag[i_c, i_b] + rf = r * f + linear_neg = constraint_state.Jaref[i_c, i_b] <= -rf + linear_pos = constraint_state.Jaref[i_c, i_b] >= rf + constraint_state.active[i_c, i_b] = not (linear_neg or linear_pos) + floss_force = linear_neg * f + linear_pos * -f + elif nef <= i_c: + constraint_state.active[i_c, i_b] = constraint_state.Jaref[i_c, i_b] < 0 + + constraint_state.efc_force[i_c, i_b] = floss_force + ( + -constraint_state.Jaref[i_c, i_b] * constraint_state.efc_D[i_c, i_b] * constraint_state.active[i_c, i_b] + ) + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_update_constraint_qfrc( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Compute qfrc_constraint = J^T @ efc_force, parallelized over (dof, env).""" + n_dofs = constraint_state.qfrc_constraint.shape[0] + _B = constraint_state.grad.shape[1] + + for i_d, i_b in qd.ndrange(n_dofs, _B): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + n_con = constraint_state.n_constraints[i_b] + qfrc = gs.qd_float(0.0) + for i_c in range(n_con): + qfrc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.efc_force[i_c, i_b] + constraint_state.qfrc_constraint[i_d, i_b] = qfrc + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_update_constraint_cost( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Compute gauss and cost (reductions over dofs and constraints). One thread per env.""" + _B = constraint_state.grad.shape[1] + + qd.loop_config(block_dim=32) + for i_b in range(_B): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + n_dofs = constraint_state.qfrc_constraint.shape[0] + ne = constraint_state.n_constraints_equality[i_b] + nef = ne + constraint_state.n_constraints_frictionloss[i_b] + n_con = constraint_state.n_constraints[i_b] + + constraint_state.prev_cost[i_b] = constraint_state.cost[i_b] + + cost_i = gs.qd_float(0.0) + gauss_i = gs.qd_float(0.0) + + # Gauss cost from dofs + for i_d in range(n_dofs): + v = ( + 0.5 + * (constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b]) + * (constraint_state.qacc[i_d, i_b] - dofs_state.acc_smooth[i_d, i_b]) + ) + gauss_i += v + cost_i += v + + # Constraint cost: quadratic + friction linear + for i_c in range(n_con): + cost_i += 0.5 * ( + constraint_state.Jaref[i_c, i_b] ** 2 + * constraint_state.efc_D[i_c, i_b] + * constraint_state.active[i_c, i_b] + ) + if ne <= i_c and i_c < nef: + f = constraint_state.efc_frictionloss[i_c, i_b] + r = constraint_state.diag[i_c, i_b] + rf = r * f + linear_neg = constraint_state.Jaref[i_c, i_b] <= -rf + linear_pos = constraint_state.Jaref[i_c, i_b] >= rf + cost_i += linear_neg * f * (-0.5 * rf - constraint_state.Jaref[i_c, i_b]) + linear_pos * f * ( + -0.5 * rf + constraint_state.Jaref[i_c, i_b] + ) + + constraint_state.gauss[i_b] = gauss_i + constraint_state.cost[i_b] = cost_i + + +@qd.kernel(fastcache=gs.use_fastcache) def _kernel_newton_only_nt_hessian( constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: ti.template(), + static_rigid_sim_config: qd.template(), ): """Step 4: Newton Hessian update (Newton only)""" solver.func_hessian_direct_tiled(constraint_state=constraint_state, rigid_global_info=rigid_global_info) - if ti.static(static_rigid_sim_config.enable_tiled_cholesky_hessian): + if qd.static(static_rigid_sim_config.enable_tiled_cholesky_hessian): solver.func_cholesky_factor_direct_tiled( constraint_state=constraint_state, rigid_global_info=rigid_global_info, @@ -79,7 +528,7 @@ def _kernel_newton_only_nt_hessian( ) else: _B = constraint_state.jac.shape[2] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) for i_b in range(_B): if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: solver.func_cholesky_factor_direct_batch( @@ -87,17 +536,17 @@ def _kernel_newton_only_nt_hessian( ) -@ti.kernel(fastcache=gs.use_fastcache) +@qd.kernel(fastcache=gs.use_fastcache) def _kernel_update_gradient( entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: ti.template(), + static_rigid_sim_config: qd.template(), ): """Step 5: Update gradient""" _B = constraint_state.grad.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) for i_b in range(_B): if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: solver.func_update_gradient_batch( @@ -110,15 +559,15 @@ def _kernel_update_gradient( ) -@ti.kernel(fastcache=gs.use_fastcache) +@qd.kernel(fastcache=gs.use_fastcache) def _kernel_update_search_direction( constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: ti.template(), + static_rigid_sim_config: qd.template(), ): """Step 6: Check convergence and update search direction""" _B = constraint_state.grad.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) for i_b in range(_B): if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: solver.func_terminate_or_update_descent_batch( @@ -129,9 +578,151 @@ def _kernel_update_search_direction( ) +# ================================================ Init kernels ================================================ + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_warmstart( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Select qacc from warmstart or acc_smooth, parallelized over (dof, env).""" + n_dofs = dofs_state.acc_smooth.shape[0] + _B = dofs_state.acc_smooth.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]: + constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] + else: + constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_Ma( + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Compute Ma = M @ qacc, parallelized over (dof, env).""" + solver.initialize_Ma( + Ma=constraint_state.Ma, + qacc=constraint_state.qacc, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_Jaref( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Compute Jaref = -aref + J @ qacc, parallelized over (constraint, env).""" + len_constraints = constraint_state.Jaref.shape[0] + n_dofs = constraint_state.jac.shape[1] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in qd.ndrange(len_constraints, _B): + if i_c < constraint_state.n_constraints[i_b]: + Jaref = -constraint_state.aref[i_c, i_b] + if qd.static(static_rigid_sim_config.sparse_solve): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]): + i_d = constraint_state.jac_relevant_dofs[i_c, i_d_, i_b] + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + else: + for i_d in range(n_dofs): + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + constraint_state.Jaref[i_c, i_b] = Jaref + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_improved( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Set improved = (n_constraints > 0) for each env.""" + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + constraint_state.improved[i_b] = constraint_state.n_constraints[i_b] > 0 + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_search( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Set search = -Mgrad, parallelized over (dof, env).""" + n_dofs = constraint_state.search.shape[0] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] + + +@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) +def func_solve_init_decomposed( + dofs_info, + dofs_state, + entities_info, + constraint_state, + rigid_global_info, + static_rigid_sim_config, +): + """ + Decomposed version of func_solve_init for CUDA backend (non-mujoco path). + + Breaks the monolithic init kernel into separate kernel launches: + 1. Warmstart selection (ndrange over dofs) + 2. Ma = M @ qacc (ndrange over dofs with entity lookup) + 3. Jaref = -aref + J @ qacc (ndrange over constraints — main optimization) + 4. Set improved flags + 5. Update constraint (forces / qfrc / cost — reuse decomposed kernels) + 6. Newton hessian (Newton only — reuse existing kernel) + 7. Update gradient (reuse existing kernel) + 8. search = -Mgrad (ndrange over dofs) + """ + # 1. Warmstart selection + _kernel_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config) + + # 2. Ma = M @ qacc + _kernel_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) + + # 3. Jaref = -aref + J @ qacc (parallelized over constraints) + _kernel_init_Jaref(constraint_state, static_rigid_sim_config) + + # 4. Set improved flags (needed by decomposed update_constraint kernels) + _kernel_init_improved(constraint_state, static_rigid_sim_config) + + # 5. Update constraint (reuse decomposed kernels) + _kernel_update_constraint_forces(constraint_state, static_rigid_sim_config) + _kernel_update_constraint_qfrc(constraint_state, static_rigid_sim_config) + _kernel_update_constraint_cost(dofs_state, constraint_state, static_rigid_sim_config) + + # 6. Newton hessian (Newton only) + if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton: + _kernel_newton_only_nt_hessian(entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) + + # 7. Update gradient + _kernel_update_gradient(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config) + + # 8. search = -Mgrad + _kernel_init_search(constraint_state, static_rigid_sim_config) + + @solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) def func_solve_decomposed( entities_info, + dofs_info, dofs_state, constraint_state, rigid_global_info, @@ -145,23 +736,57 @@ def func_solve_decomposed( """ iterations = rigid_global_info.iterations[None] for _it in range(iterations): - _kernel_linesearch( + _kernel_parallel_linesearch_mv( + dofs_info, entities_info, + constraint_state, + rigid_global_info, + static_rigid_sim_config, + ) + _kernel_parallel_linesearch_jv( + constraint_state, + static_rigid_sim_config, + ) + _kernel_parallel_linesearch_p0( dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config, ) + _kernel_parallel_linesearch_eval( + constraint_state, + rigid_global_info, + static_rigid_sim_config, + ) + _kernel_parallel_linesearch_apply_alpha_dofs( + constraint_state, + rigid_global_info, + static_rigid_sim_config, + ) + _kernel_parallel_linesearch_apply_alpha_constraints( + constraint_state, + static_rigid_sim_config, + ) if static_rigid_sim_config.solver_type == gs.constraint_solver.CG: _kernel_cg_only_save_prev_grad( constraint_state, static_rigid_sim_config, ) - _kernel_update_constraint( + + _kernel_update_constraint_forces( + constraint_state, + static_rigid_sim_config, + ) + _kernel_update_constraint_qfrc( + constraint_state, + static_rigid_sim_config, + ) + _kernel_update_constraint_cost( dofs_state, constraint_state, static_rigid_sim_config, ) + if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton: _kernel_newton_only_nt_hessian( constraint_state, From ec625a58b68c6e9b014b7b18c017bf7d1015a47b Mon Sep 17 00:00:00 2001 From: Mingrui Date: Sun, 22 Feb 2026 01:49:07 +0000 Subject: [PATCH 02/11] update --- genesis/engine/solvers/rigid/constraint/solver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index f1ca0b23e8..4239eff9d3 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -3045,6 +3045,7 @@ def func_solve_body( @qd.kernel(fastcache=gs.use_fastcache) def func_solve_body_monolith( entities_info: array_class.EntitiesInfo, + dofs_info: array_class.DofsInfo, dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, From 0d43c2b2f72fc96b19dd9603a1b68fd353f3ddb5 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Sun, 22 Feb 2026 02:22:38 +0000 Subject: [PATCH 03/11] fix test --- tests/test_grad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_grad.py b/tests/test_grad.py index eff264f5c6..9bdb000b26 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -263,6 +263,7 @@ def constraint_solver_resolve(): ) func_solve_body( entities_info=rigid_solver.entities_info, + dofs_info=rigid_solver.dofs_info, dofs_state=rigid_solver.dofs_state, constraint_state=constraint_solver.constraint_state, rigid_global_info=rigid_solver._rigid_global_info, From 6e0b038727310d7a33179b3522e8385be296fd28 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Tue, 3 Mar 2026 19:37:17 +0000 Subject: [PATCH 04/11] recover init --- .../rigid/constraint/solver_breakdown.py | 100 +++++++++++++----- 1 file changed, 76 insertions(+), 24 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index cd4b21fa17..ef27bdeb04 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -4,8 +4,9 @@ import genesis.utils.array_class as array_class from genesis.engine.solvers.rigid.constraint import solver -LS_PARALLEL_K = 8 +LS_PARALLEL_K = 16 LS_PARALLEL_MIN_STEP = 1e-6 +LS_PARALLEL_N_REFINE = 1 # number of successive refinement passes in parallel linesearch _P0_BLOCK = 32 _JV_BLOCK = 32 @@ -106,11 +107,13 @@ def _kernel_parallel_linesearch_p0( tid = i_ % _T i_b = i_ // _T - # 4 shared arrays for parallel reductions (reused across phases) + # 6 shared arrays for parallel reductions (reused across phases) sh_a = qd.simt.block.SharedArray((_T,), gs.qd_float) sh_b = qd.simt.block.SharedArray((_T,), gs.qd_float) sh_c = qd.simt.block.SharedArray((_T,), gs.qd_float) sh_d = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_e = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_f = qd.simt.block.SharedArray((_T,), gs.qd_float) if constraint_state.n_constraints[i_b] > 0: n_dofs = constraint_state.search.shape[0] @@ -170,24 +173,28 @@ def _kernel_parallel_linesearch_p0( local_eq1 = gs.qd_float(0.0) local_eq2 = gs.qd_float(0.0) local_tmp0 = gs.qd_float(0.0) + local_total_1 = gs.qd_float(0.0) # full gradient at alpha=0 + local_total_2 = gs.qd_float(0.0) # full hessian/2 at alpha=0 i_c = tid while i_c < n_con: Jaref_c = constraint_state.Jaref[i_c, i_b] + jv_c = constraint_state.jv[i_c, i_b] D = constraint_state.efc_D[i_c, i_b] qf_0 = D * (0.5 * Jaref_c * Jaref_c) + qf_1 = D * (jv_c * Jaref_c) + qf_2 = D * (0.5 * jv_c * jv_c) if i_c < ne: - # Equality: always active, need jv for eq_sum - jv_c = constraint_state.jv[i_c, i_b] - qf_1 = D * (jv_c * Jaref_c) - qf_2 = D * (0.5 * jv_c * jv_c) + # Equality: always active local_eq0 += qf_0 local_eq1 += qf_1 local_eq2 += qf_2 local_tmp0 += qf_0 + local_total_1 += qf_1 + local_total_2 += qf_2 elif i_c < nef: - # Friction: only qf_0 needed (qf_1/qf_2 not stored) + # Friction: check linear regime at alpha=0 f = constraint_state.efc_frictionloss[i_c, i_b] r = constraint_state.diag[i_c, i_b] rf = r * f @@ -195,11 +202,17 @@ def _kernel_parallel_linesearch_p0( linear_pos = Jaref_c >= rf if linear_neg or linear_pos: qf_0 = linear_neg * f * (-0.5 * rf - Jaref_c) + linear_pos * f * (-0.5 * rf + Jaref_c) + qf_1 = linear_neg * (-f * jv_c) + linear_pos * (f * jv_c) + qf_2 = 0.0 local_tmp0 += qf_0 + local_total_1 += qf_1 + local_total_2 += qf_2 else: # Contact: active if Jaref < 0 active = Jaref_c < 0 local_tmp0 += qf_0 * active + local_total_1 += qf_1 * active + local_total_2 += qf_2 * active i_c += _T @@ -208,10 +221,12 @@ def _kernel_parallel_linesearch_p0( sh_b[tid] = local_eq1 sh_c[tid] = local_eq2 sh_d[tid] = local_tmp0 + sh_e[tid] = local_total_1 + sh_f[tid] = local_total_2 qd.simt.block.sync() - # Tree reduction for 4 accumulators + # Tree reduction for 6 accumulators stride = _T // 2 while stride > 0: if tid < stride: @@ -219,6 +234,8 @@ def _kernel_parallel_linesearch_p0( sh_b[tid] += sh_b[tid + stride] sh_c[tid] += sh_c[tid + stride] sh_d[tid] += sh_d[tid + stride] + sh_e[tid] += sh_e[tid + stride] + sh_f[tid] += sh_f[tid + stride] qd.simt.block.sync() stride //= 2 @@ -228,6 +245,21 @@ def _kernel_parallel_linesearch_p0( constraint_state.eq_sum[2, i_b] = sh_c[0] constraint_state.ls_it[i_b] = 1 constraint_state.candidates[1, i_b] = constraint_state.gauss[i_b] + sh_d[0] + # Initialize best alpha, search range, and best-cost tracker for parallel linesearch + constraint_state.candidates[0, i_b] = 0.0 # default: no step + + # Use full Newton step (DOF + all constraints) as the range center. + # sh_e[0] = total constraint gradient, sh_f[0] = total constraint hess/2 + total_hess = 2.0 * (constraint_state.quad_gauss[2, i_b] + sh_f[0]) + if total_hess > 0.0: + total_grad = constraint_state.quad_gauss[1, i_b] + sh_e[0] + alpha_newton = qd.max(qd.abs(total_grad / total_hess), gs.qd_float(LS_PARALLEL_MIN_STEP)) + constraint_state.candidates[2, i_b] = alpha_newton * 1e-2 + constraint_state.candidates[3, i_b] = alpha_newton * 1e2 + else: + constraint_state.candidates[2, i_b] = 1e-6 + constraint_state.candidates[3, i_b] = 1e2 + constraint_state.candidates[4, i_b] = gs.qd_float(1e30) # best cost across passes @qd.kernel(fastcache=gs.use_fastcache) @@ -236,10 +268,13 @@ def _kernel_parallel_linesearch_eval( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), ): - """Evaluate K candidate alphas in parallel per env, pick the best via reduction.""" + """Evaluate K candidate alphas in parallel per env, pick the best via reduction. + + Reads the search range from candidates[2] (lo) and candidates[3] (hi). + Writes narrowed range back to candidates[2,3] for successive refinement. + """ _B = constraint_state.grad.shape[1] _K = qd.static(LS_PARALLEL_K) - _MIN_STEP = qd.static(LS_PARALLEL_MIN_STEP) qd.loop_config(block_dim=_K) for i_ in range(_B * _K): @@ -255,8 +290,11 @@ def _kernel_parallel_linesearch_eval( nef = ne + constraint_state.n_constraints_frictionloss[i_b] n_con = constraint_state.n_constraints[i_b] - # Generate log-spaced alpha: alpha[0]=MIN_STEP ... alpha[K-1]=1.0 - alpha = solver._log_scale(_MIN_STEP, 1.0, _K, tid) + lo = constraint_state.candidates[2, i_b] + hi = constraint_state.candidates[3, i_b] + + # Generate log-spaced alpha within [lo, hi] + alpha = solver._log_scale(lo, hi, _K, tid) # Evaluate cost at this alpha cost = ( @@ -317,17 +355,27 @@ def _kernel_parallel_linesearch_eval( qd.simt.block.sync() stride = stride // 2 - # Thread 0: acceptance check and write result + # Thread 0: acceptance check, write result, and narrow range for next pass if tid == 0: if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: p0_cost = constraint_state.candidates[1, i_b] best_tid = sh_idx[0] best_cost = sh_cost[0] - best_alpha = solver._log_scale(_MIN_STEP, 1.0, _K, best_tid) - if best_cost < p0_cost: + lo = constraint_state.candidates[2, i_b] + hi = constraint_state.candidates[3, i_b] + best_alpha = solver._log_scale(lo, hi, _K, best_tid) + + # Only update best alpha if this pass improved over ALL previous passes + best_cost_prev = constraint_state.candidates[4, i_b] + if best_cost < p0_cost and best_cost < best_cost_prev: constraint_state.candidates[0, i_b] = best_alpha - else: - constraint_state.candidates[0, i_b] = 0.0 + constraint_state.candidates[4, i_b] = best_cost + + # Narrow range around accepted point for next refinement pass + lo_idx = qd.max(0, best_tid - 1) + hi_idx = qd.min(_K - 1, best_tid + 1) + constraint_state.candidates[2, i_b] = solver._log_scale(lo, hi, _K, lo_idx) + constraint_state.candidates[3, i_b] = solver._log_scale(lo, hi, _K, hi_idx) else: constraint_state.candidates[0, i_b] = 0.0 @@ -669,7 +717,10 @@ def _kernel_init_search( constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] -@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) +# NOTE: decomposed init disabled — causes non-deterministic results on CUDA due to inter-kernel data races +# when multiple @qd.kernel functions write/read shared state (qacc, Ma, Jaref) without synchronization. +# The monolith init (single kernel) is used instead. See test_box_box_dynamics[gpu-implicitfast-Newton]. +@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: False) def func_solve_init_decomposed( dofs_info, dofs_state, @@ -710,7 +761,7 @@ def func_solve_init_decomposed( # 6. Newton hessian (Newton only) if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton: - _kernel_newton_only_nt_hessian(entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) + _kernel_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config) # 7. Update gradient _kernel_update_gradient(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config) @@ -753,11 +804,12 @@ def func_solve_decomposed( rigid_global_info, static_rigid_sim_config, ) - _kernel_parallel_linesearch_eval( - constraint_state, - rigid_global_info, - static_rigid_sim_config, - ) + for _refine in range(LS_PARALLEL_N_REFINE): + _kernel_parallel_linesearch_eval( + constraint_state, + rigid_global_info, + static_rigid_sim_config, + ) _kernel_parallel_linesearch_apply_alpha_dofs( constraint_state, rigid_global_info, From e8422e95024be9ac206b04356e05e73d53400b52 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Tue, 3 Mar 2026 19:39:11 +0000 Subject: [PATCH 05/11] better initial alpha with no init --- genesis/engine/solvers/rigid/constraint/solver_breakdown.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index ef27bdeb04..ab6085172b 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -717,7 +717,7 @@ def _kernel_init_search( constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] -# NOTE: decomposed init disabled — causes non-deterministic results on CUDA due to inter-kernel data races +# FIXME: decomposed init disabled — causes non-deterministic results on CUDA due to inter-kernel data races # when multiple @qd.kernel functions write/read shared state (qacc, Ma, Jaref) without synchronization. # The monolith init (single kernel) is used instead. See test_box_box_dynamics[gpu-implicitfast-Newton]. @solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: False) From 07768c3d5f08d789d2054f3bbb85594593e8a041 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Tue, 3 Mar 2026 21:49:24 +0000 Subject: [PATCH 06/11] enforce test grad not using cache --- tests/test_grad.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_grad.py b/tests/test_grad.py index 9bdb000b26..fed367bbaa 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -220,6 +220,8 @@ def compute_dL_error(dL_dx, x_type): # stable way. @pytest.mark.required @pytest.mark.precision("64") +@pytest.mark.disable_cache(True) +@pytest.mark.parametrize("backend", [gs.cpu]) def test_diff_solver(monkeypatch): from genesis.engine.solvers.rigid.constraint.solver import func_solve_init, func_solve_body from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1 From d621552328ba0d97789c83f6eb4be4b523fdbf81 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Tue, 3 Mar 2026 23:12:08 +0000 Subject: [PATCH 07/11] grad computation not using parallel search --- .../engine/solvers/rigid/constraint/solver_breakdown.py | 7 ++++++- tests/test_grad.py | 2 -- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index ab6085172b..ea50a76166 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -770,7 +770,12 @@ def func_solve_init_decomposed( _kernel_init_search(constraint_state, static_rigid_sim_config) -@solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) +@solver.func_solve_body.register( + is_compatible=lambda *args, **kwargs: ( + gs.backend in {gs.cuda} + and not (args[5] if len(args) > 5 else kwargs["static_rigid_sim_config"]).requires_grad + ) +) def func_solve_decomposed( entities_info, dofs_info, diff --git a/tests/test_grad.py b/tests/test_grad.py index fed367bbaa..9bdb000b26 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -220,8 +220,6 @@ def compute_dL_error(dL_dx, x_type): # stable way. @pytest.mark.required @pytest.mark.precision("64") -@pytest.mark.disable_cache(True) -@pytest.mark.parametrize("backend", [gs.cpu]) def test_diff_solver(monkeypatch): from genesis.engine.solvers.rigid.constraint.solver import func_solve_init, func_solve_body from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1 From 24c9f81ab9f79900a5f0f872b98392069c430f11 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Tue, 3 Mar 2026 23:14:02 +0000 Subject: [PATCH 08/11] update format --- genesis/engine/solvers/rigid/constraint/solver_breakdown.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index ea50a76166..cb1ad27210 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -772,8 +772,7 @@ def func_solve_init_decomposed( @solver.func_solve_body.register( is_compatible=lambda *args, **kwargs: ( - gs.backend in {gs.cuda} - and not (args[5] if len(args) > 5 else kwargs["static_rigid_sim_config"]).requires_grad + gs.backend in {gs.cuda} and not (args[5] if len(args) > 5 else kwargs["static_rigid_sim_config"]).requires_grad ) ) def func_solve_decomposed( From 437d03787e891744b9ad2375d53ce86e4fdfbcf5 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Thu, 5 Mar 2026 00:24:52 +0000 Subject: [PATCH 09/11] disable perf dispatch re-benchmarking for decomp --- genesis/engine/solvers/rigid/constraint/solver_breakdown.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index cb1ad27210..f2bf54b793 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -772,6 +772,8 @@ def func_solve_init_decomposed( @solver.func_solve_body.register( is_compatible=lambda *args, **kwargs: ( + # Note: we do not use parallel linesearch for finite difference gradient validation, as it is highly + # sensitive to numerical precision and GPU float64 rounding errors can accumulate over many trials. gs.backend in {gs.cuda} and not (args[5] if len(args) > 5 else kwargs["static_rigid_sim_config"]).requires_grad ) ) From 289063566124b0a1838ac9ce0837a06097794a39 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Thu, 5 Mar 2026 03:11:45 +0000 Subject: [PATCH 10/11] fix implicit cpu gpu sync --- genesis/engine/solvers/rigid/constraint/solver.py | 4 ++-- genesis/engine/solvers/rigid/constraint/solver_breakdown.py | 2 +- genesis/engine/solvers/rigid/rigid_solver.py | 1 + genesis/utils/array_class.py | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 4239eff9d3..3c85cca2ba 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2808,7 +2808,7 @@ def initialize_Ma( @qd.perf_dispatch( - get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=1.0 + get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=0 ) def func_solve_init( dofs_info: array_class.DofsInfo, @@ -3029,7 +3029,7 @@ def func_solve_iter( @qd.perf_dispatch( - get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=1.0 + get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=0 ) def func_solve_body( entities_info: array_class.EntitiesInfo, diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index f2bf54b793..d8f182b440 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -791,7 +791,7 @@ def func_solve_decomposed( This maximizes kernel granularity, potentially allowing better GPU scheduling and more flexibility in execution, at the cost of more Python→C++ boundary crossings. """ - iterations = rigid_global_info.iterations[None] + iterations = static_rigid_sim_config.iterations for _it in range(iterations): _kernel_parallel_linesearch_mv( dofs_info, diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index 3558bc69c8..46e48307db 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -415,6 +415,7 @@ def build(self): sparse_solve=self._options.sparse_solve, integrator=self._integrator, solver_type=self._options.constraint_solver, + iterations=self._options.iterations, ) if self.is_active: diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 89ae88d283..e87ca200a0 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -1876,6 +1876,7 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta): n_entities: int = -1 n_links: int = -1 n_geoms: int = -1 + iterations: int = -1 # =========================================== DataManager =========================================== From bdb89f2d27b36a44e2b4195822bd1d1f3c0e18b3 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Thu, 5 Mar 2026 05:17:35 +0000 Subject: [PATCH 11/11] clean --- .../rigid/constraint/solver_breakdown.py | 209 ++++++++++-------- 1 file changed, 111 insertions(+), 98 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index d8f182b440..a0091f9904 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -11,30 +11,6 @@ _JV_BLOCK = 32 -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_linesearch( - entities_info: array_class.EntitiesInfo, - dofs_state: array_class.DofsState, - constraint_state: array_class.ConstraintState, - rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: qd.template(), -): - _B = constraint_state.grad.shape[1] - qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) - for i_b in range(_B): - if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: - solver.func_linesearch_and_apply_alpha( - i_b, - entities_info=entities_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - constraint_state=constraint_state, - static_rigid_sim_config=static_rigid_sim_config, - ) - else: - constraint_state.improved[i_b] = False - - @qd.kernel(fastcache=gs.use_fastcache) def _kernel_parallel_linesearch_mv( dofs_info: array_class.DofsInfo, @@ -103,37 +79,37 @@ def _kernel_parallel_linesearch_p0( _T = qd.static(_P0_BLOCK) qd.loop_config(block_dim=_T) - for i_ in range(_B * _T): - tid = i_ % _T - i_b = i_ // _T + for i_flat in range(_B * _T): + tid = i_flat % _T + i_b = i_flat // _T # 6 shared arrays for parallel reductions (reused across phases) - sh_a = qd.simt.block.SharedArray((_T,), gs.qd_float) - sh_b = qd.simt.block.SharedArray((_T,), gs.qd_float) - sh_c = qd.simt.block.SharedArray((_T,), gs.qd_float) - sh_d = qd.simt.block.SharedArray((_T,), gs.qd_float) - sh_e = qd.simt.block.SharedArray((_T,), gs.qd_float) - sh_f = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_snorm_sq = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_qg_grad = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_qg_hess = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_p0_cost = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_constraint_grad = qd.simt.block.SharedArray((_T,), gs.qd_float) + sh_constraint_hess = qd.simt.block.SharedArray((_T,), gs.qd_float) if constraint_state.n_constraints[i_b] > 0: n_dofs = constraint_state.search.shape[0] # === Phase 1: Fused snorm + quad_gauss, parallel over n_dofs === local_snorm_sq = gs.qd_float(0.0) - local_qg1 = gs.qd_float(0.0) - local_qg2 = gs.qd_float(0.0) + local_qg_grad = gs.qd_float(0.0) + local_qg_hess = gs.qd_float(0.0) i_d = tid while i_d < n_dofs: s = constraint_state.search[i_d, i_b] local_snorm_sq += s * s - local_qg1 += s * constraint_state.Ma[i_d, i_b] - s * dofs_state.force[i_d, i_b] - local_qg2 += 0.5 * s * constraint_state.mv[i_d, i_b] + local_qg_grad += s * constraint_state.Ma[i_d, i_b] - s * dofs_state.force[i_d, i_b] + local_qg_hess += 0.5 * s * constraint_state.mv[i_d, i_b] i_d += _T - sh_a[tid] = local_snorm_sq - sh_b[tid] = local_qg1 - sh_c[tid] = local_qg2 + sh_snorm_sq[tid] = local_snorm_sq + sh_qg_grad[tid] = local_qg_grad + sh_qg_hess[tid] = local_qg_hess qd.simt.block.sync() @@ -141,14 +117,14 @@ def _kernel_parallel_linesearch_p0( stride = _T // 2 while stride > 0: if tid < stride: - sh_a[tid] += sh_a[tid + stride] - sh_b[tid] += sh_b[tid + stride] - sh_c[tid] += sh_c[tid + stride] + sh_snorm_sq[tid] += sh_snorm_sq[tid + stride] + sh_qg_grad[tid] += sh_qg_grad[tid + stride] + sh_qg_hess[tid] += sh_qg_hess[tid + stride] qd.simt.block.sync() stride //= 2 # All threads read the reduced snorm - snorm = qd.sqrt(sh_a[0]) + snorm = qd.sqrt(sh_snorm_sq[0]) if snorm < rigid_global_info.EPS[None]: # Converged — only thread 0 writes @@ -161,20 +137,20 @@ def _kernel_parallel_linesearch_p0( if tid == 0: constraint_state.improved[i_b] = True constraint_state.quad_gauss[0, i_b] = constraint_state.gauss[i_b] - constraint_state.quad_gauss[1, i_b] = sh_b[0] - constraint_state.quad_gauss[2, i_b] = sh_c[0] + constraint_state.quad_gauss[1, i_b] = sh_qg_grad[0] + constraint_state.quad_gauss[2, i_b] = sh_qg_hess[0] # === Phase 2: Constraint cost, parallel over n_constraints === ne = constraint_state.n_constraints_equality[i_b] nef = ne + constraint_state.n_constraints_frictionloss[i_b] n_con = constraint_state.n_constraints[i_b] - local_eq0 = gs.qd_float(0.0) - local_eq1 = gs.qd_float(0.0) - local_eq2 = gs.qd_float(0.0) - local_tmp0 = gs.qd_float(0.0) - local_total_1 = gs.qd_float(0.0) # full gradient at alpha=0 - local_total_2 = gs.qd_float(0.0) # full hessian/2 at alpha=0 + local_eq_cost = gs.qd_float(0.0) + local_eq_grad = gs.qd_float(0.0) + local_eq_hess = gs.qd_float(0.0) + local_p0_cost = gs.qd_float(0.0) + local_constraint_grad = gs.qd_float(0.0) + local_constraint_hess = gs.qd_float(0.0) i_c = tid while i_c < n_con: @@ -187,12 +163,12 @@ def _kernel_parallel_linesearch_p0( if i_c < ne: # Equality: always active - local_eq0 += qf_0 - local_eq1 += qf_1 - local_eq2 += qf_2 - local_tmp0 += qf_0 - local_total_1 += qf_1 - local_total_2 += qf_2 + local_eq_cost += qf_0 + local_eq_grad += qf_1 + local_eq_hess += qf_2 + local_p0_cost += qf_0 + local_constraint_grad += qf_1 + local_constraint_hess += qf_2 elif i_c < nef: # Friction: check linear regime at alpha=0 f = constraint_state.efc_frictionloss[i_c, i_b] @@ -204,25 +180,25 @@ def _kernel_parallel_linesearch_p0( qf_0 = linear_neg * f * (-0.5 * rf - Jaref_c) + linear_pos * f * (-0.5 * rf + Jaref_c) qf_1 = linear_neg * (-f * jv_c) + linear_pos * (f * jv_c) qf_2 = 0.0 - local_tmp0 += qf_0 - local_total_1 += qf_1 - local_total_2 += qf_2 + local_p0_cost += qf_0 + local_constraint_grad += qf_1 + local_constraint_hess += qf_2 else: # Contact: active if Jaref < 0 active = Jaref_c < 0 - local_tmp0 += qf_0 * active - local_total_1 += qf_1 * active - local_total_2 += qf_2 * active + local_p0_cost += qf_0 * active + local_constraint_grad += qf_1 * active + local_constraint_hess += qf_2 * active i_c += _T # Reuse shared arrays for Phase 2 reduction - sh_a[tid] = local_eq0 - sh_b[tid] = local_eq1 - sh_c[tid] = local_eq2 - sh_d[tid] = local_tmp0 - sh_e[tid] = local_total_1 - sh_f[tid] = local_total_2 + sh_snorm_sq[tid] = local_eq_cost + sh_qg_grad[tid] = local_eq_grad + sh_qg_hess[tid] = local_eq_hess + sh_p0_cost[tid] = local_p0_cost + sh_constraint_grad[tid] = local_constraint_grad + sh_constraint_hess[tid] = local_constraint_hess qd.simt.block.sync() @@ -230,29 +206,28 @@ def _kernel_parallel_linesearch_p0( stride = _T // 2 while stride > 0: if tid < stride: - sh_a[tid] += sh_a[tid + stride] - sh_b[tid] += sh_b[tid + stride] - sh_c[tid] += sh_c[tid + stride] - sh_d[tid] += sh_d[tid + stride] - sh_e[tid] += sh_e[tid + stride] - sh_f[tid] += sh_f[tid + stride] + sh_snorm_sq[tid] += sh_snorm_sq[tid + stride] + sh_qg_grad[tid] += sh_qg_grad[tid + stride] + sh_qg_hess[tid] += sh_qg_hess[tid + stride] + sh_p0_cost[tid] += sh_p0_cost[tid + stride] + sh_constraint_grad[tid] += sh_constraint_grad[tid + stride] + sh_constraint_hess[tid] += sh_constraint_hess[tid + stride] qd.simt.block.sync() stride //= 2 if tid == 0: - constraint_state.eq_sum[0, i_b] = sh_a[0] - constraint_state.eq_sum[1, i_b] = sh_b[0] - constraint_state.eq_sum[2, i_b] = sh_c[0] + constraint_state.eq_sum[0, i_b] = sh_snorm_sq[0] + constraint_state.eq_sum[1, i_b] = sh_qg_grad[0] + constraint_state.eq_sum[2, i_b] = sh_qg_hess[0] constraint_state.ls_it[i_b] = 1 - constraint_state.candidates[1, i_b] = constraint_state.gauss[i_b] + sh_d[0] + constraint_state.candidates[1, i_b] = constraint_state.gauss[i_b] + sh_p0_cost[0] # Initialize best alpha, search range, and best-cost tracker for parallel linesearch constraint_state.candidates[0, i_b] = 0.0 # default: no step # Use full Newton step (DOF + all constraints) as the range center. - # sh_e[0] = total constraint gradient, sh_f[0] = total constraint hess/2 - total_hess = 2.0 * (constraint_state.quad_gauss[2, i_b] + sh_f[0]) + total_hess = 2.0 * (constraint_state.quad_gauss[2, i_b] + sh_constraint_hess[0]) if total_hess > 0.0: - total_grad = constraint_state.quad_gauss[1, i_b] + sh_e[0] + total_grad = constraint_state.quad_gauss[1, i_b] + sh_constraint_grad[0] alpha_newton = qd.max(qd.abs(total_grad / total_hess), gs.qd_float(LS_PARALLEL_MIN_STEP)) constraint_state.candidates[2, i_b] = alpha_newton * 1e-2 constraint_state.candidates[3, i_b] = alpha_newton * 1e2 @@ -277,9 +252,9 @@ def _kernel_parallel_linesearch_eval( _K = qd.static(LS_PARALLEL_K) qd.loop_config(block_dim=_K) - for i_ in range(_B * _K): - tid = i_ % _K - i_b = i_ // _K + for i_flat in range(_B * _K): + tid = i_flat % _K + i_b = i_flat // _K # Shared memory for argmin reduction sh_cost = qd.simt.block.SharedArray((_K,), gs.qd_float) @@ -717,10 +692,42 @@ def _kernel_init_search( constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] -# FIXME: decomposed init disabled — causes non-deterministic results on CUDA due to inter-kernel data races -# when multiple @qd.kernel functions write/read shared state (qacc, Ma, Jaref) without synchronization. -# The monolith init (single kernel) is used instead. See test_box_box_dynamics[gpu-implicitfast-Newton]. -@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: False) +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_update_constraint( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Init-only constraint update — wraps monolith's func_update_constraint for exact FP match.""" + solver.func_update_constraint( + qacc=constraint_state.qacc, + Ma=constraint_state.Ma, + cost=constraint_state.cost, + dofs_state=dofs_state, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_update_gradient( + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Init-only gradient update — wraps monolith's func_update_gradient (dispatches to tiled on GPU).""" + solver.func_update_gradient( + dofs_state=dofs_state, + entities_info=entities_info, + constraint_state=constraint_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) def func_solve_init_decomposed( dofs_info, dofs_state, @@ -737,9 +744,9 @@ def func_solve_init_decomposed( 2. Ma = M @ qacc (ndrange over dofs with entity lookup) 3. Jaref = -aref + J @ qacc (ndrange over constraints — main optimization) 4. Set improved flags - 5. Update constraint (forces / qfrc / cost — reuse decomposed kernels) + 5. Update constraint (wraps monolith's func_update_constraint for exact FP match) 6. Newton hessian (Newton only — reuse existing kernel) - 7. Update gradient (reuse existing kernel) + 7. Update gradient (wraps monolith's func_update_gradient — uses tiled on GPU) 8. search = -Mgrad (ndrange over dofs) """ # 1. Warmstart selection @@ -754,17 +761,17 @@ def func_solve_init_decomposed( # 4. Set improved flags (needed by decomposed update_constraint kernels) _kernel_init_improved(constraint_state, static_rigid_sim_config) - # 5. Update constraint (reuse decomposed kernels) - _kernel_update_constraint_forces(constraint_state, static_rigid_sim_config) - _kernel_update_constraint_qfrc(constraint_state, static_rigid_sim_config) - _kernel_update_constraint_cost(dofs_state, constraint_state, static_rigid_sim_config) + # 5. Update constraint (init-specific: wraps monolith's func_update_constraint for exact FP match) + _kernel_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config) # 6. Newton hessian (Newton only) if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton: _kernel_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config) - # 7. Update gradient - _kernel_update_gradient(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config) + # 7. Update gradient (init-specific: wraps monolith's func_update_gradient, dispatches to tiled on GPU) + _kernel_init_update_gradient( + entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config + ) # 8. search = -Mgrad _kernel_init_search(constraint_state, static_rigid_sim_config) @@ -791,6 +798,9 @@ def func_solve_decomposed( This maximizes kernel granularity, potentially allowing better GPU scheduling and more flexibility in execution, at the cost of more Python→C++ boundary crossings. """ + # Read iterations from the Python-side static config instead of rigid_global_info.iterations[None]. + # The [None] read triggers an implicit GPU→CPU sync that drains the GPU command queue, destroying + # async kernel pipelining and causing ~2x slowdown on contact-heavy benchmarks. iterations = static_rigid_sim_config.iterations for _it in range(iterations): _kernel_parallel_linesearch_mv( @@ -810,6 +820,9 @@ def func_solve_decomposed( rigid_global_info, static_rigid_sim_config, ) + # Successive refinement: each pass narrows the search range around the best alpha. + # Currently N_REFINE=1 is sufficient — no benchmark has shown improvement with more + # passes. The loop is kept as an interface for potential future cases. for _refine in range(LS_PARALLEL_N_REFINE): _kernel_parallel_linesearch_eval( constraint_state,