Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion genesis/engine/solvers/rigid/constraint/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2834,14 +2834,28 @@ def initialize_Ma(
# ======================================================= Core ========================================================


# Dispatcher entry point for the constraint-solver init step. The body is only a stub (`...`);
# concrete implementations attach themselves through `func_solve_init.register(...)`.
# NOTE(review): this text comes from a rendered diff whose +/- markers were lost. The hunk header
# reports a single deletion, and the `@qd.kernel` line below looks like that deleted line rather
# than a decorator meant to be stacked on top of `@qd.perf_dispatch` -- confirm against the real file.
@qd.kernel(fastcache=gs.use_fastcache)
@qd.perf_dispatch(
# The hash callback returns the positional args followed by a frozendict built from the kwargs.
# NOTE(review): warmup / active / repeat_after_seconds presumably tune the benchmarking schedule
# used to pick an implementation -- confirm against the perf_dispatch documentation.
get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=0
)
def func_solve_init(
dofs_info: array_class.DofsInfo,
dofs_state: array_class.DofsState,
entities_info: array_class.EntitiesInfo,
constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
) -> None: ...


@func_solve_init.register(is_compatible=lambda *args, **kwargs: True)
@qd.kernel(fastcache=gs.use_fastcache)
def func_solve_init_monolith(
dofs_info: array_class.DofsInfo,
dofs_state: array_class.DofsState,
entities_info: array_class.EntitiesInfo,
constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
):
_B = dofs_state.acc_smooth.shape[1]
n_dofs = dofs_state.acc_smooth.shape[0]
Expand Down
179 changes: 179 additions & 0 deletions genesis/engine/solvers/rigid/constraint/solver_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,185 @@ def _kernel_update_search_direction(
)


# ================================================ Init kernels ================================================


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_init_warmstart(
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: ti.template(),
):
    """
    Seed ``constraint_state.qacc`` for every (dof, env) pair.

    An env that has at least one constraint and is flagged ``is_warmstart`` copies its entries from
    ``constraint_state.qacc_ws``; every other env copies them from ``dofs_state.acc_smooth``.
    """
    # Both loop extents are taken from acc_smooth: axis 0 is iterated as i_d, axis 1 as i_b.
    num_dofs = dofs_state.acc_smooth.shape[0]
    num_envs = dofs_state.acc_smooth.shape[1]

    # The loop is serialized whenever the configured parallelism level is below PARA_LEVEL.ALL.
    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_d, i_b in ti.ndrange(num_dofs, num_envs):
        use_warmstart = constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]
        if use_warmstart:
            constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b]
        else:
            constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b]


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_init_Ma(
    dofs_info: array_class.DofsInfo,
    entities_info: array_class.EntitiesInfo,
    constraint_state: array_class.ConstraintState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: ti.template(),
):
    """
    Fill ``constraint_state.Ma`` from ``constraint_state.qacc`` (Ma = M @ qacc).

    This kernel is a thin launch wrapper: all of the work is delegated to ``solver.initialize_Ma``.
    """
    solver.initialize_Ma(
        static_rigid_sim_config=static_rigid_sim_config,
        rigid_global_info=rigid_global_info,
        entities_info=entities_info,
        dofs_info=dofs_info,
        # The input vector and the output buffer are both fields of the constraint state.
        qacc=constraint_state.qacc,
        Ma=constraint_state.Ma,
    )


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_init_Jaref(
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: ti.template(),
):
    """
    Initialize ``constraint_state.Jaref`` as ``-aref + J @ qacc`` for every active constraint.

    Iterates over (constraint slot, env) pairs. Slots whose index is not below
    ``constraint_state.n_constraints[i_b]`` are left untouched.
    """
    max_constraints = constraint_state.Jaref.shape[0]
    num_dofs = constraint_state.jac.shape[1]
    num_envs = constraint_state.grad.shape[1]

    # The loop is serialized whenever the configured parallelism level is below PARA_LEVEL.ALL.
    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_c, i_b in ti.ndrange(max_constraints, num_envs):
        if i_c < constraint_state.n_constraints[i_b]:
            # Start from -aref, then accumulate this constraint's Jacobian row against qacc.
            row_acc = -constraint_state.aref[i_c, i_b]
            if ti.static(static_rigid_sim_config.sparse_solve):
                # Sparse path: visit only the dofs listed as relevant for this constraint.
                for i_rel in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]):
                    i_d = constraint_state.jac_relevant_dofs[i_c, i_rel, i_b]
                    row_acc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b]
            else:
                # Dense path: walk every dof of the row.
                for i_d in range(num_dofs):
                    row_acc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b]
            constraint_state.Jaref[i_c, i_b] = row_acc


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_init_improved(
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: ti.template(),
):
    """Mark each env as ``improved`` exactly when it has at least one constraint."""
    # The env count is read from axis 1 of the grad buffer.
    num_envs = constraint_state.grad.shape[1]

    # The loop is serialized whenever the configured parallelism level is below PARA_LEVEL.ALL.
    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_b in range(num_envs):
        has_constraints = constraint_state.n_constraints[i_b] > 0
        constraint_state.improved[i_b] = has_constraints


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_init_search(
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: ti.template(),
):
    """Write the initial search direction, ``search = -Mgrad``, for every (dof, env) pair."""
    # The dof extent comes from the search buffer, the env extent from axis 1 of the grad buffer.
    num_dofs = constraint_state.search.shape[0]
    num_envs = constraint_state.grad.shape[1]

    # The loop is serialized whenever the configured parallelism level is below PARA_LEVEL.ALL.
    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_d, i_b in ti.ndrange(num_dofs, num_envs):
        mgrad = constraint_state.Mgrad[i_d, i_b]
        constraint_state.search[i_d, i_b] = -mgrad


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_init_update_constraint(
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: ti.template(),
):
    """
    Constraint update used only during solver init.

    Thin kernel wrapper around ``solver.func_update_constraint``. Reusing that shared routine, rather
    than re-deriving the update here, is intended to keep floating-point results identical to the
    monolithic init path.
    """
    solver.func_update_constraint(
        static_rigid_sim_config=static_rigid_sim_config,
        constraint_state=constraint_state,
        dofs_state=dofs_state,
        # qacc, Ma and cost are all passed from the constraint state's own buffers.
        cost=constraint_state.cost,
        Ma=constraint_state.Ma,
        qacc=constraint_state.qacc,
    )


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_init_update_gradient(
    entities_info: array_class.EntitiesInfo,
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: ti.template(),
):
    """
    Gradient update used only during solver init.

    Thin kernel wrapper around ``solver.func_update_gradient``; every argument is forwarded unchanged.
    NOTE(review): the previous docstring stated that the wrapped routine dispatches to a tiled
    implementation on GPU -- that is decided inside ``solver.func_update_gradient``, not here.
    """
    solver.func_update_gradient(
        static_rigid_sim_config=static_rigid_sim_config,
        rigid_global_info=rigid_global_info,
        constraint_state=constraint_state,
        entities_info=entities_info,
        dofs_state=dofs_state,
    )


@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda})
def func_solve_init_decomposed(
    dofs_info,
    dofs_state,
    entities_info,
    constraint_state,
    rigid_global_info,
    static_rigid_sim_config,
):
    """
    Multi-kernel implementation of ``func_solve_init``, registered as compatible only on the CUDA backend.

    Instead of a single kernel, the init step is issued as a fixed sequence of launches:

    1. pick the starting ``qacc`` (warmstart or ``acc_smooth``)
    2. ``Ma = M @ qacc``
    3. ``Jaref = -aref + J @ qacc``, iterated over (constraint, env)
    4. set the per-env ``improved`` flag
    5. constraint update (wrapper around ``solver.func_update_constraint``)
    6. Newton hessian -- launched only when the solver type is Newton
    7. gradient update (wrapper around ``solver.func_update_gradient``)
    8. ``search = -Mgrad``

    Later steps read buffers written by earlier ones, so the launch order must be preserved.
    """
    # Step 1: choose the initial qacc per env.
    _kernel_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config)

    # Step 2: Ma from the freshly selected qacc.
    _kernel_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config)

    # Step 3: Jaref from the freshly selected qacc.
    _kernel_init_Jaref(constraint_state, static_rigid_sim_config)

    # Step 4: improved[i_b] = (n_constraints[i_b] > 0).
    _kernel_init_improved(constraint_state, static_rigid_sim_config)

    # Step 5: init-specific constraint update.
    _kernel_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config)

    # Step 6: the hessian kernel is only relevant to the Newton solver.
    uses_newton = static_rigid_sim_config.solver_type == gs.constraint_solver.Newton
    if uses_newton:
        _kernel_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config)

    # Step 7: init-specific gradient update.
    _kernel_init_update_gradient(
        entities_info,
        dofs_state,
        constraint_state,
        rigid_global_info,
        static_rigid_sim_config,
    )

    # Step 8: initial search direction.
    _kernel_init_search(constraint_state, static_rigid_sim_config)


# ============================================== Solve body kernels ================================================


@solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda})
def func_solve_decomposed(
entities_info,
Expand Down
Loading