diff --git a/examples/differentiable_rigid.py b/examples/differentiable_rigid.py
new file mode 100644
index 0000000000..9ee90ac18e
--- /dev/null
+++ b/examples/differentiable_rigid.py
@@ -0,0 +1,92 @@
+import torch
+import genesis as gs
+
+show_viewer = False
+
+gs.init(precision="32", logging_level="info")
+
+dt = 1e-2
+horizon = 100
+substeps = 1
+goal_pos = gs.tensor([0.7, 1.0, 0.05])
+goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9])
+goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True)
+
+scene = gs.Scene(
+ sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, gravity=(0, 0, -1)),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ enable_self_collision=False,
+ enable_joint_limit=False,
+ disable_constraint=True,
+ use_contact_island=False,
+ use_hibernation=False,
+ ),
+ viewer_options=gs.options.ViewerOptions(
+ camera_pos=(2.5, -0.15, 2.42),
+ camera_lookat=(0.5, 0.5, 0.1),
+ ),
+ show_viewer=show_viewer,
+)
+
+box = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(0.1, 0.1, 0.2),
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+)
+if show_viewer:
+ target = scene.add_entity(
+ gs.morphs.Box(
+ pos=goal_pos,
+ quat=goal_quat,
+ size=(0.1, 0.1, 0.2),
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.9, 0.0, 0.5),
+ ),
+ )
+
+scene.build()
+
+num_iter = 200
+lr = 1e-2
+
+init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True)
+init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
+optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)
+
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)
+
+for it in range(num_iter):
+    scene.reset()
+
+    box.set_pos(init_pos)
+    box.set_quat(init_quat)
+
+    for _ in range(horizon):
+        scene.step()
+        if show_viewer:
+            target.set_pos(goal_pos)
+            target.set_quat(goal_quat)
+
+    # L1 distance between the pose after `horizon` steps and the goal pose.
+    box_state = box.get_state()
+    pos_err = torch.abs(box_state.pos - goal_pos).sum()
+    loss = pos_err + torch.abs(box_state.quat - goal_quat).sum()
+
+    optimizer.zero_grad()
+    loss.backward()  # this lets gradient flow all the way back to tensor input
+    optimizer.step()
+    scheduler.step()
+
+    # Adam does not keep the quaternion at unit length, so renormalize after each update.
+    with torch.no_grad():
+        init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)
+
+    print("loss: ", loss.item())
+
+# assert_allclose(loss, 0.0, atol=1e-2)
diff --git a/examples/differentiable_rigid_demo_1.py b/examples/differentiable_rigid_demo_1.py
new file mode 100644
index 0000000000..9ee90ac18e
--- /dev/null
+++ b/examples/differentiable_rigid_demo_1.py
@@ -0,0 +1,92 @@
+import torch
+import genesis as gs
+
+show_viewer = False
+
+gs.init(precision="32", logging_level="info")
+
+dt = 1e-2
+horizon = 100
+substeps = 1
+goal_pos = gs.tensor([0.7, 1.0, 0.05])
+goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9])
+goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True)
+
+scene = gs.Scene(
+ sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, gravity=(0, 0, -1)),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ enable_self_collision=False,
+ enable_joint_limit=False,
+ disable_constraint=True,
+ use_contact_island=False,
+ use_hibernation=False,
+ ),
+ viewer_options=gs.options.ViewerOptions(
+ camera_pos=(2.5, -0.15, 2.42),
+ camera_lookat=(0.5, 0.5, 0.1),
+ ),
+ show_viewer=show_viewer,
+)
+
+box = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(0.1, 0.1, 0.2),
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+)
+if show_viewer:
+ target = scene.add_entity(
+ gs.morphs.Box(
+ pos=goal_pos,
+ quat=goal_quat,
+ size=(0.1, 0.1, 0.2),
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.9, 0.0, 0.5),
+ ),
+ )
+
+scene.build()
+
+num_iter = 200
+lr = 1e-2
+
+init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True)
+init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
+optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)
+
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)
+
+for it in range(num_iter):
+    scene.reset()
+
+    box.set_pos(init_pos)
+    box.set_quat(init_quat)
+
+    for _ in range(horizon):
+        scene.step()
+        if show_viewer:
+            target.set_pos(goal_pos)
+            target.set_quat(goal_quat)
+
+    # L1 distance between the pose after `horizon` steps and the goal pose.
+    box_state = box.get_state()
+    pos_err = torch.abs(box_state.pos - goal_pos).sum()
+    loss = pos_err + torch.abs(box_state.quat - goal_quat).sum()
+
+    optimizer.zero_grad()
+    loss.backward()  # this lets gradient flow all the way back to tensor input
+    optimizer.step()
+    scheduler.step()
+
+    # Adam does not keep the quaternion at unit length, so renormalize after each update.
+    with torch.no_grad():
+        init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)
+
+    print("loss: ", loss.item())
+
+# assert_allclose(loss, 0.0, atol=1e-2)
diff --git a/examples/diffrigid/1_one_step.py b/examples/diffrigid/1_one_step.py
new file mode 100644
index 0000000000..a7afd6cbf7
--- /dev/null
+++ b/examples/diffrigid/1_one_step.py
@@ -0,0 +1,96 @@
+"""
+One step optimization for the basic debugging of differentiable rigid simulation.
+"""
+import torch
+import genesis as gs
+import matplotlib.pyplot as plt
+
+show_viewer = False
+
+gs.init(precision="32", logging_level="warn", backend=gs.cpu)
+
+dt = 1e-2
+horizon = 1
+substeps = 1
+goal_pos = gs.tensor([0.0, 0.0, 1.0])
+
+scene = gs.Scene(
+ sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True),
+ rigid_options=gs.options.RigidOptions(
+ use_gjk_collision=True,
+ enable_joint_limit=False,
+ ),
+ show_viewer=show_viewer,
+)
+
+ball = scene.add_entity(
+ gs.morphs.Sphere(
+ pos=(0, 0, 0.109), # small penetration with ground
+ radius=0.1,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+)
+
+ground = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(5.0, 5.0, 0.02),
+ #fixed=True,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.0, 0.9, 1.0),
+ ),
+ material=gs.materials.Rigid(
+ gravity_compensation=1,
+ )
+)
+
+scene.build()
+
+num_iter = 400
+lr = 1e-4
+
+init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True)
+#init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
+optimizer = torch.optim.Adam([init_pos], lr=lr)
+
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)
+
+prev_loss = float('inf')
+losses = []
+for it in range(num_iter):
+    scene.reset()
+
+    # Only the ground position is optimized here.
+    ground.set_pos(init_pos)
+
+    for _ in range(horizon):
+        scene.step()
+
+    # L1 distance between the ball position after the rollout and the goal.
+    ball_state = ball.get_state()
+    ball_pos = ball_state.pos
+    loss = torch.abs(ball_pos - goal_pos).sum()
+
+    optimizer.zero_grad()
+    loss.backward()  # this lets gradient flow all the way back to tensor input
+
+    # Clip before the update. clip_grad_norm_ expects the parameters themselves (it
+    # reads their `.grad`); passing `init_pos.grad`, or clipping after
+    # optimizer.step(), would leave the applied update unclipped.
+    torch.nn.utils.clip_grad_norm_([init_pos], 1.0)
+    optimizer.step()
+    scheduler.step()
+
+    loss_value = loss.item()
+    print(f"Loss: {prev_loss:.6g} -> {loss_value:.6g}")
+    prev_loss = loss_value
+
+    losses.append(loss_value)
+
+    # Re-write the loss curve every iteration so progress can be inspected mid-run.
+    plt.plot(losses)
+    plt.savefig("loss.png")
+    plt.close()
\ No newline at end of file
diff --git a/examples/diffrigid/2_lift_ball.py b/examples/diffrigid/2_lift_ball.py
new file mode 100644
index 0000000000..8da89913d3
--- /dev/null
+++ b/examples/diffrigid/2_lift_ball.py
@@ -0,0 +1,106 @@
+"""
+Multi-step force optimization that lifts a ball, for debugging differentiable rigid simulation.
+"""
+import torch
+import genesis as gs
+import matplotlib.pyplot as plt
+import matplotlib
+matplotlib.use("Agg")
+
+show_viewer = False
+
+gs.init(precision="32", logging_level="warn", backend=gs.cpu)
+
+dt = 1e-2
+horizon = 10
+substeps = 1
+grad_window = 5 #None
+
+scene = gs.Scene(
+ sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, grad_window_steps=grad_window),
+ rigid_options=gs.options.RigidOptions(
+ use_gjk_collision=True,
+ enable_joint_limit=False,
+ ),
+ show_viewer=show_viewer,
+)
+
+ball = scene.add_entity(
+ gs.morphs.Sphere(
+ pos=(0, 0, 0.109), # small penetration with ground
+ radius=0.1,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+)
+
+ground = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(5.0, 5.0, 0.02),
+ #fixed=True,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.0, 0.9, 1.0),
+ ),
+ material=gs.materials.Rigid(
+ gravity_compensation=1,
+ )
+)
+cam = scene.add_camera(
+ pos=(3.5, 0.5, 2.5),
+ lookat=(0.0, 0.0, 0.5),
+ fov=40,
+ GUI=False,
+)
+
+scene.build()
+
+num_iter = 10000
+lr = 1e-2
+
+force = gs.zeros((horizon, 6), requires_grad=True)
+optimizer = torch.optim.Adam([force], lr=lr)
+
+render_every = 100
+prev_loss = float('inf')
+losses = []
+for it in range(num_iter):
+    scene.reset()
+
+    record = it % render_every == 0
+    curr_losses = []
+    if record:
+        cam.start_recording()
+    for i in range(horizon):
+        ball.control_dofs_force(force[i])
+        scene.step()
+
+        ball_pos = ball.get_state().pos
+        curr_losses.append(-ball_pos[:, 2].sum())  # maximize the ball height (z only)
+        if record:
+            cam.render()
+    if record:
+        cam.stop_recording(save_to_filename=f"video_{it:06d}.mp4", fps=30)
+
+    loss = sum(curr_losses) / len(curr_losses)
+    optimizer.zero_grad()
+    loss.backward()  # this lets gradient flow all the way back to tensor input
+
+    # clip_grad_norm_ takes the parameters (it reads their `.grad`) and has to run before
+    # optimizer.step(); given `force.grad` it finds nothing to clip and reports a norm of 0.
+    grad_norm = torch.nn.utils.clip_grad_norm_([force], 1.0)
+    optimizer.step()
+
+    with torch.no_grad():
+        # NOTE(review): `force` has 6 columns, so `6:` selects nothing - confirm the intended mask.
+        force.data[:, 6:] = 0.0
+
+    print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Grad Norm: {grad_norm:.6g} | Force: {force.data.mean(dim=0).cpu().numpy().tolist()}")
+    prev_loss = loss.item()
+    losses.append(loss.item())
+
+    plt.plot(losses)
+    plt.savefig("loss.png")
+    plt.close()
\ No newline at end of file
diff --git a/examples/diffrigid/ant.py b/examples/diffrigid/ant.py
new file mode 100644
index 0000000000..9935066ea1
--- /dev/null
+++ b/examples/diffrigid/ant.py
@@ -0,0 +1,297 @@
+import argparse
+
+import numpy as np
+
+import genesis as gs
+
+import torch
+
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+import matplotlib
+matplotlib.use("Agg")
+
+class Controller(torch.nn.Module):
+    """
+    Gaussian MLP policy.
+
+    Maps an observation batch to a per-dof action mean and standard deviation.
+    """
+
+    def __init__(self, obs_dim, n_dofs, hidden_dim=64):
+        super().__init__()
+        self.obs_dim = obs_dim
+        self.n_dofs = n_dofs
+
+        # Input normalization (bypassed in `forward` for single-sample batches)
+        self.bn = torch.nn.BatchNorm1d(obs_dim)
+
+        # Three-layer MLP trunk
+        self.fc1 = torch.nn.Linear(obs_dim, hidden_dim)
+        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
+        self.fc3 = torch.nn.Linear(hidden_dim, hidden_dim)
+
+        # Separate output heads for the action mean and the log standard deviation
+        self.mean_layer = torch.nn.Linear(hidden_dim, n_dofs)
+        self.log_std_layer = torch.nn.Linear(hidden_dim, n_dofs)
+
+        # Zero weights and a bias of -0.5 make the initial std exp(-0.5) for every input
+        with torch.no_grad():
+            self.log_std_layer.weight.zero_()
+            self.log_std_layer.bias.fill_(-0.5)
+
+    def forward(self, obs):
+        """
+        Compute the parameters of the action distribution.
+
+        Args:
+            obs: observation tensor of shape (batch_size, obs_dim)
+        Returns:
+            mean: mean of action distribution, shape (batch_size, n_dofs)
+            std: standard deviation of action distribution, shape (batch_size, n_dofs)
+        """
+        # BatchNorm is only applied when the batch holds more than one sample
+        hidden = self.bn(obs) if obs.shape[0] > 1 else obs
+
+        # ReLU after each trunk layer
+        for layer in (self.fc1, self.fc2, self.fc3):
+            hidden = torch.relu(layer(hidden))
+
+        mean = self.mean_layer(hidden)
+
+        # Bound log_std to [-10, 2] before exponentiating
+        log_std = self.log_std_layer(hidden).clamp(min=-10, max=2)
+        return mean, log_std.exp()
+
+    def sample_action(self, obs):
+        """
+        Draw an action as mean + std * standard-normal noise.
+
+        Args:
+            obs: observation tensor of shape (batch_size, obs_dim)
+        Returns:
+            action: sampled action, shape (batch_size, n_dofs)
+            mean: mean of action distribution, shape (batch_size, n_dofs)
+            std: standard deviation of action distribution, shape (batch_size, n_dofs)
+        """
+        mean, std = self.forward(obs)
+        action = mean + std * torch.randn_like(mean)
+        return action, mean, std
+
+def observe_fn(state):
+    """Concatenate qpos, dof velocities and dof accelerations into a detached observation."""
+    parts = (state.qpos, state.dofs_vel, state.dofs_acc)
+    # Joined along dim 1, then detached from the autograd graph
+    observation = torch.cat(parts, dim=1)
+    return observation.detach()
+
+def reward_fn(state, dt, prev_state=None):
+    """Reward = 0.01 * height term (quadratic penalty below z = 0.8) + forward x-velocity."""
+    pos = state.pos
+    height_clip = torch.clamp(pos[:, 2] - 0.8, max=1.0)
+    height_reward = torch.where(height_clip <= 0.0, -200 * height_clip.square(), height_clip)
+    if prev_state is None:
+        forward_reward = 0.0
+    else:
+        forward_reward = (pos[:, 0] - prev_state.pos[:, 0].detach()) / dt
+    return height_reward * 0.01 + forward_reward
+
+
+def main():
+    """Train an MLP controller by backpropagating the mean rollout reward through the simulation."""
+    parser = argparse.ArgumentParser()
+    # The parsed values are used as given (no hard-coded overrides); with no flags the
+    # run is headless with 64 environments.
+    parser.add_argument("-v", "--vis", action="store_true", default=False)
+    parser.add_argument("-n", "--n_envs", type=int, default=64)
+    args = parser.parse_args()
+
+    dt = 0.01
+    substeps = 1
+    horizon_steps = 128
+    window_substeps = 32
+    window_steps = int(window_substeps / substeps)
+    iteration = 10000
+    lr = 1e-4
+    render_every = 100
+
+    ########################## init ##########################
+    gs.init(backend=gs.gpu, logging_level="warn")
+
+    ########################## create a scene ##########################
+    viewer_options = gs.options.ViewerOptions(
+        camera_pos=(3, -1, 1.5),
+        camera_lookat=(0.0, 0.0, 0.0),
+        camera_fov=30,
+        max_FPS=60,
+    )
+
+    scene = gs.Scene(
+        sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, grad_window_steps=window_steps),
+        viewer_options=viewer_options,
+        rigid_options=gs.options.RigidOptions(
+            use_gjk_collision=True,
+            enable_joint_limit=False,
+        ),
+        vis_options=gs.options.VisOptions(
+            rendered_envs_idx=(0,),
+            show_world_frame=True,
+        ),
+        show_viewer=args.vis,
+    )
+
+    ########################## entities ##########################
+    # plane = scene.add_entity(
+    # gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True, pos=(0, 0, -0.5)),
+    # )
+    plane = scene.add_entity(
+        gs.morphs.Box(size=(10, 10, 0.6), pos=(0, 0, -0.3), fixed=True),
+    )
+    ant = scene.add_entity(
+        gs.morphs.MJCF(file="xml/walker_no_ground.xml"),
+        # vis_mode="collision"
+    )
+    cam = scene.add_camera(
+        pos=(3.5, 0.5, 2.5),
+        lookat=(0.0, 0.0, 0.5),
+        fov=40,
+        GUI=False,
+        env_idx=0,
+    )
+    cam.follow_entity(ant)
+
+    ########################## build ##########################
+    scene.build(n_envs=args.n_envs, env_spacing=(1, 1))
+
+    n_dofs = ant.n_dofs
+
+    # rand_force = torch.randn((n_dofs,), dtype=torch.float32)
+    # for i in range(10000):
+    # ant.control_dofs_force(rand_force)
+    # scene.step()
+
+    # if i % 1000 == 0 and i > 0:
+    # print("----------------------------------------------------------")
+    # rand_force = torch.randn((n_dofs,), dtype=torch.float32)
+
+    # # Reset env
+    # rigid_solver = scene.sim.rigid_solver
+    # qpos = rigid_solver._rigid_global_info.qpos.to_numpy()
+    # dofs_vel = rigid_solver.dofs_state.vel.to_numpy()
+    # dofs_acc = rigid_solver.dofs_state.acc.to_numpy()
+    # dofs_acc_smooth = rigid_solver.dofs_state.acc_smooth.to_numpy()
+    # solver_qacc_ws = rigid_solver.constraint_solver.constraint_state.qacc_ws.to_numpy()
+
+    # scene.reset()
+    # scene.sim.rigid_solver._rigid_global_info.qpos.from_numpy(qpos)
+    # rigid_solver.dofs_state.vel.from_numpy(dofs_vel)
+    # rigid_solver.dofs_state.acc.from_numpy(dofs_acc)
+    # rigid_solver.dofs_state.acc_smooth.from_numpy(dofs_acc_smooth)
+    # rigid_solver.constraint_solver.constraint_state.qacc_ws.from_numpy(solver_qacc_ws)
+    # rigid_solver.load_test()
+
+
+    # Initialize controller
+    # Get obs_dim by computing observation once
+    scene.reset()
+    state = ant.get_state()
+    obs = observe_fn(state)
+    obs_dim = obs.shape[1]
+
+    controller = Controller(obs_dim=obs_dim, n_dofs=n_dofs, hidden_dim=64)
+    optimizer = torch.optim.Adam(controller.parameters(), lr=lr)
+
+    rewards = []
+    for it in range(iteration):
+        scene.reset()
+        acc_reward = None
+        prev_state = None
+
+        record = (it % render_every == 0) or (it == iteration - 1)
+        if record:
+            cam.start_recording()
+        print("running forward pass...")
+        for step in range(horizon_steps):
+            scene.step()
+            if record:
+                cam.render()
+
+            # Determine observation and reward
+            state = ant.get_state()
+            obs = observe_fn(state)
+            reward = reward_fn(state, dt, prev_state)
+            if acc_reward is None:
+                acc_reward = reward
+            else:
+                acc_reward = acc_reward + reward  # out-of-place: earlier reward tensors stay untouched
+
+            prev_state = state
+
+            # Determine action
+            action, _, _ = controller.sample_action(obs)
+            # Apply action (assuming action is force/torque for dofs)
+            ant.control_dofs_force(action)
+
+            truncate = step == horizon_steps - 1  # step % window_steps == 0 or step == horizon_steps - 1
+            if truncate and step > 0:
+                print("running backward pass...")
+                acc_reward = acc_reward / (step + 1)
+                mean_reward = acc_reward.mean()
+                loss = -mean_reward
+
+                optimizer.zero_grad()
+                loss.backward()
+
+                grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
+                optimizer.step()
+
+                print(f"[ITER {it}] Mean Reward: {mean_reward.item():.4g} | Grad Norm: {grad_norm:.4g}")
+
+                rewards.append(mean_reward.detach().item())
+
+                plt.plot(rewards)
+                plt.savefig("rewards.png")
+                plt.close()
+
+                # Reset env
+                # scene.sim.reset_grad()
+                # scene._forward_ready = True
+                # scene._backward_ready = True
+                # scene._t = t
+                # scene.sim._cur_substep_global = substep_global
+
+                # rigid_solver._rigid_global_info.qpos.from_numpy(_qpos)
+                # rigid_solver.dofs_state.vel.from_numpy(_dofs_vel)
+                # rigid_solver.dofs_state.acc.from_numpy(_dofs_acc)
+                # rigid_solver.dofs_state.acc_smooth.from_numpy(_dofs_acc_smooth)
+                # rigid_solver.constraint_solver.constraint_state.qacc_ws.from_numpy(_solver_qacc_ws)
+                # rigid_solver.load_test()
+
+                # rigid_solver = scene.sim.rigid_solver
+                # qpos = rigid_solver._rigid_global_info.qpos.to_numpy()
+                # dofs_vel = rigid_solver.dofs_state.vel.to_numpy()
+                # dofs_acc = rigid_solver.dofs_state.acc.to_numpy()
+                # dofs_acc_smooth = rigid_solver.dofs_state.acc_smooth.to_numpy()
+                # solver_qacc_ws = rigid_solver.constraint_solver.constraint_state.qacc_ws.to_numpy()
+                # scene_t = scene._t
+                # scene.reset()
+                # scene._t = scene_t
+                # rigid_solver._rigid_global_info.qpos.from_numpy(qpos)
+                # rigid_solver.dofs_state.vel.from_numpy(dofs_vel)
+                # rigid_solver.dofs_state.acc.from_numpy(dofs_acc)
+                # rigid_solver.dofs_state.acc_smooth.from_numpy(dofs_acc_smooth)
+                # rigid_solver.constraint_solver.constraint_state.qacc_ws.from_numpy(solver_qacc_ws)
+                # rigid_solver.load_test()
+
+                acc_reward = None
+
+                # if step // window_steps > 1:
+                break
+
+        if record:
+            cam.stop_recording(save_to_filename=f"ant_video_{it:06d}.mp4", fps=30)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/diffrigid/debug_rot.py b/examples/diffrigid/debug_rot.py
new file mode 100644
index 0000000000..3ffba27e5f
--- /dev/null
+++ b/examples/diffrigid/debug_rot.py
@@ -0,0 +1,113 @@
+"""
+Multi-step dof-force optimization that drives a box's orientation toward a goal quaternion, for debugging differentiable rigid simulation.
+"""
+import torch
+import genesis as gs
+import matplotlib.pyplot as plt
+import matplotlib
+import numpy as np
+matplotlib.use("Agg")
+
+show_viewer = False
+
+gs.init(precision="32", logging_level="warn", backend=gs.cpu)
+
+dt = 1e-2
+horizon = 10
+substeps = 1
+grad_window = None
+np.random.seed(0)
+goal_quat = np.random.randn(4)
+goal_quat = goal_quat / np.linalg.norm(goal_quat)
+goal_quat = gs.tensor(goal_quat)
+
+scene = gs.Scene(
+ sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, grad_window_steps=grad_window),
+ rigid_options=gs.options.RigidOptions(
+ use_gjk_collision=True,
+ enable_joint_limit=False,
+ ),
+ show_viewer=show_viewer,
+)
+
+ground = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(5.0, 5.0, 0.02),
+ #fixed=True,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.0, 0.9, 1.0),
+ ),
+ material=gs.materials.Rigid(
+ gravity_compensation=1,
+ )
+)
+ball = scene.add_entity(
+ gs.morphs.Sphere(
+ pos=(0, 0, 0.109), # small penetration with ground
+ radius=0.1,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+)
+cam = scene.add_camera(
+ pos=(3.5, 0.5, 2.5),
+ lookat=(0.0, 0.0, 0.5),
+ fov=40,
+ GUI=False,
+)
+
+scene.build()
+
+num_iter = 10000
+lr = 1e-2
+
+force = gs.zeros((horizon, 6), requires_grad=True)
+with torch.no_grad():
+ torch.manual_seed(0)
+ force.data[:, 3:] = torch.randn_like(force.data[:, 3:])
+optimizer = torch.optim.Adam([force], lr=lr)
+
+render_every = 100
+prev_loss = float('inf')
+losses = []
+for it in range(num_iter):
+    scene.reset()
+
+    record = it % render_every == 0
+    curr_losses = []
+    if record:
+        cam.start_recording()
+    for i in range(horizon):
+        ground.control_dofs_force(force[i])
+        scene.step()
+
+        box_quat = ground.get_state().quat
+        curr_losses.append((box_quat - goal_quat).abs().sum())
+        if record:
+            cam.render()
+    if record:
+        cam.stop_recording(save_to_filename=f"video_{it:06d}.mp4", fps=30)
+
+    loss = sum(curr_losses) / len(curr_losses)
+    optimizer.zero_grad()
+    loss.backward()  # this lets gradient flow all the way back to tensor input
+
+    # clip_grad_norm_ takes the parameters (it reads their `.grad`) and has to run before
+    # optimizer.step(); given `force.grad` it finds nothing to clip and reports a norm of 0.
+    grad_norm = torch.nn.utils.clip_grad_norm_([force], 1.0)
+    optimizer.step()
+
+    with torch.no_grad():
+        # Keep the first three dof forces at zero so only the last three are optimized.
+        force.data[:, :3] = 0.0
+
+    print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Grad Norm: {grad_norm:.6g} | Force: {force.data.mean(dim=0).cpu().numpy().tolist()}")
+    prev_loss = loss.item()
+    losses.append(loss.item())
+
+    plt.plot(losses)
+    plt.savefig("loss.png")
+    plt.close()
\ No newline at end of file
diff --git a/examples/diffrigid/pingpong.py b/examples/diffrigid/pingpong.py
new file mode 100644
index 0000000000..e1751cf5bc
--- /dev/null
+++ b/examples/diffrigid/pingpong.py
@@ -0,0 +1,103 @@
+import torch
+import genesis as gs
+
+show_viewer = True
+
+gs.init(precision="32", logging_level="warn", backend=gs.cpu)
+
+dt = 1e-2
+horizon = 200
+substeps = 4
+goal_pos = gs.tensor([0.0, 0.1, -0.1])
+
+scene = gs.Scene(
+ sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True),
+ rigid_options=gs.options.RigidOptions(
+ use_gjk_collision=True,
+ enable_joint_limit=False,
+ ),
+ show_viewer=show_viewer,
+)
+
+ball = scene.add_entity(
+ gs.morphs.Sphere(
+ pos=(0, 0, 0.5),
+ radius=0.1,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+ material=gs.materials.Rigid(
+ rho=0.001,
+ )
+)
+# if show_viewer:
+# target = scene.add_entity(
+# gs.morphs.Sphere(
+# pos=goal_pos.cpu().numpy().tolist(),
+# radius=0.1,
+# ),
+# surface=gs.surfaces.Default(
+# color=(0.0, 0.9, 0.0, 0.5),
+# ),
+# )
+
+racket = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(5.0, 5.0, 0.01),
+ #fixed=True,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.0, 0.9, 1.0),
+ ),
+ material=gs.materials.Rigid(
+ gravity_compensation=1,
+ )
+)
+
+scene.build()
+
+num_iter = 200
+lr = 1e-4
+
+init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True)
+init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
+optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)
+
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)
+
+prev_loss = float('inf')
+for it in range(num_iter):
+    scene.reset()
+
+    # The racket pose is the optimized input: setting it from the tensors that require
+    # grad is what lets loss.backward() reach them.
+    racket.set_pos(init_pos)
+    racket.set_quat(init_quat)
+
+    for _ in range(horizon):
+        scene.step()
+
+    # L1 distance between the ball position after the rollout and the goal.
+    ball_state = ball.get_state()
+    ball_pos = ball_state.pos
+    loss = torch.abs(ball_pos - goal_pos).sum()
+
+    optimizer.zero_grad()
+    loss.backward()  # this lets gradient flow all the way back to tensor input
+    optimizer.step()
+
+    # NOTE(review): the scheduler anneals from lr=1e-4 towards eta_min=1e-3, i.e. the step
+    # size grows over the run - confirm that this is intended.
+    scheduler.step()
+
+    # Adam does not keep the quaternion at unit length, so renormalize after each update.
+    with torch.no_grad():
+        init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)
+
+    loss_value = loss.item()
+    print(f"Loss: {prev_loss:.6g} -> {loss_value:.6g}")
+    prev_loss = loss_value
+
+# assert_allclose(loss, 0.0, atol=1e-2)
diff --git a/examples/diffrigid/slide_ball.py b/examples/diffrigid/slide_ball.py
new file mode 100644
index 0000000000..a12c286ede
--- /dev/null
+++ b/examples/diffrigid/slide_ball.py
@@ -0,0 +1,129 @@
+import torch
+import genesis as gs
+import matplotlib.pyplot as plt
+
+show_viewer = False
+
+gs.init(precision="32", logging_level="warn", backend=gs.cpu)
+
+dt = 1e-2
+horizon = 100
+substeps = 4
+goal_pos = gs.tensor([0.2, 0.1, 0.1])
+render_every = 100
+
+scene = gs.Scene(
+ sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True),
+ rigid_options=gs.options.RigidOptions(
+ use_gjk_collision=True,
+ enable_joint_limit=False,
+ ),
+ show_viewer=show_viewer,
+)
+
+ball = scene.add_entity(
+ gs.morphs.Sphere(
+ pos=(0, 0, 0.11),
+ radius=0.1,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+ material=gs.materials.Rigid(
+ rho=0.001,
+ )
+)
+# if show_viewer:
+# target = scene.add_entity(
+# gs.morphs.Sphere(
+# pos=goal_pos.cpu().numpy().tolist(),
+# radius=0.1,
+# ),
+# surface=gs.surfaces.Default(
+# color=(0.0, 0.9, 0.0, 0.5),
+# ),
+# )
+
+racket = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(5.0, 5.0, 0.02),
+ #fixed=True,
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.0, 0.9, 1.0),
+ ),
+ material=gs.materials.Rigid(
+ gravity_compensation=1,
+ )
+)
+
+cam = scene.add_camera(
+ pos=(3.5, 0.5, 2.5),
+ lookat=(0.0, 0.0, 0.5),
+ fov=40,
+ GUI=False,
+)
+
+scene.build()
+
+num_iter = 300
+lr = 1e-4
+
+init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True)
+init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
+optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)
+
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-4)
+
+prev_loss = float('inf')
+losses = []
+for it in range(num_iter):
+    scene.reset()
+
+    racket.set_pos(init_pos)
+    racket.set_quat(init_quat)
+
+    # A video is recorded on every `render_every`-th iteration and on the
+    # final iteration.
+    record = it == num_iter - 1 or it % render_every == 0
+
+    if record:
+        cam.start_recording()
+    for _ in range(horizon):
+        scene.step()
+        if record:
+            cam.render()
+
+    if record:
+        cam.stop_recording(save_to_filename=f"video_{it:06d}.mp4", fps=30)
+
+    # L1 distance between the ball position after the rollout and the goal.
+    ball_state = ball.get_state()
+    ball_pos = ball_state.pos
+    loss = (ball_pos - goal_pos).abs().sum()
+
+    optimizer.zero_grad()
+    loss.backward()  # this lets gradient flow all the way back to tensor input
+    optimizer.step()
+    scheduler.step()
+
+    # Project the optimized quaternion back to unit length after the update.
+    with torch.no_grad():
+        init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)
+
+    loss_value = loss.item()
+    ball_pos_list = ball_pos.detach().cpu().numpy().tolist()
+    print(f"Loss: {prev_loss:.6g} -> {loss_value:.6g} | Ball Pos: {ball_pos_list}")
+    prev_loss = loss_value
+
+    losses.append(loss_value)
+
+    # Re-write the loss curve every iteration, with a log-scaled y axis, so
+    # progress can be inspected mid-run.
+    plt.plot(losses)
+    plt.yscale('log')
+    plt.savefig("loss.png")
+    plt.close()
+
+# assert_allclose(loss, 0.0, atol=1e-2)
diff --git a/genesis/assets/xml/ant_no_ground.xml b/genesis/assets/xml/ant_no_ground.xml
new file mode 100644
index 0000000000..7eb2522ec4
--- /dev/null
+++ b/genesis/assets/xml/ant_no_ground.xml
@@ -0,0 +1,93 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/walker_no_ground.xml b/genesis/assets/xml/walker_no_ground.xml
new file mode 100644
index 0000000000..827e776c7c
--- /dev/null
+++ b/genesis/assets/xml/walker_no_ground.xml
@@ -0,0 +1,80 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py
index 7297b71af9..47ff9c9a3a 100644
--- a/genesis/engine/entities/rigid_entity/rigid_entity.py
+++ b/genesis/engine/entities/rigid_entity/rigid_entity.py
@@ -117,7 +117,7 @@ def __init__(
self._load_model()
# Initialize target variables and checkpoint
- self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity")
+ self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity", "control_dofs_force")
self._tgt = dict()
self._tgt_buffer = list()
self._ckpt = dict()
@@ -1661,6 +1661,8 @@ def process_input(self, in_backward=False):
self.set_quat(**data_kwargs)
case "set_dofs_velocity":
self.set_dofs_velocity(**data_kwargs)
+ case "control_dofs_force":
+ self.control_dofs_force(**data_kwargs)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")
@@ -1693,6 +1695,16 @@ def process_input_grad(self):
data_kwargs["dofs_idx_local"],
data_kwargs["envs_idx"],
)
+
+ case "control_dofs_force":
+ force = data_kwargs.pop("force")
+ if force.requires_grad:
+ force._backward_from_ti(
+ self.control_dofs_force_grad,
+ data_kwargs["dofs_idx_local"],
+ data_kwargs["envs_idx"],
+ )
+
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")
@@ -1715,10 +1727,16 @@ def get_state(self):
solver_state = self._solver.get_state()
pos = solver_state.links_pos[:, self.base_link_idx]
quat = solver_state.links_quat[:, self.base_link_idx]
+ qpos = solver_state.qpos[:, self._q_start:self._q_start + self.n_qs]
+ dofs_vel = solver_state.dofs_vel[:, self._dof_start:self._dof_start + self.n_dofs]
+ dofs_acc = solver_state.dofs_acc[:, self._dof_start:self._dof_start + self.n_dofs]
state._pos = pos
state._quat = quat
-
+ state._qpos = qpos
+ state._dofs_vel = dofs_vel
+ state._dofs_acc = dofs_acc
+
return state
def _get_global_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False):
@@ -2330,7 +2348,7 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *
@gs.assert_built
def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, velocity_grad):
- dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
+ dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, velocity_grad.data)
@gs.assert_built
@@ -2355,6 +2373,7 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer
self._solver.set_dofs_position(position, dofs_idx, envs_idx)
@gs.assert_built
+ @tracked
def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None):
"""
Control the entity's dofs' motor force. This is used for force/torque control.
@@ -2371,6 +2390,14 @@ def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None):
dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.control_dofs_force(force, dofs_idx, envs_idx)
+ @gs.assert_built
+ def control_dofs_force_grad(self, dofs_idx_local, envs_idx, force_grad):
+     """Pass the gradient of a tracked `control_dofs_force` input on to the solver."""
+     dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
+     # Only the raw values (`.data`) of the gradient tensor are handed to the solver.
+     self._solver.control_dofs_force_grad(dofs_idx, envs_idx, force_grad.data)
+
+
@gs.assert_built
def control_dofs_velocity(self, velocity, dofs_idx_local=None, envs_idx=None):
"""
diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py
index a269c998b4..80859ba86d 100644
--- a/genesis/engine/simulator.py
+++ b/genesis/engine/simulator.py
@@ -280,7 +280,9 @@ def step(self, in_backward=False):
self.save_ckpt()
if self.rigid_solver.is_active:
- self.rigid_solver.clear_external_force()
+ # In backward pass, we need to keep the external force for gradient computation
+ if not in_backward:
+ self.rigid_solver.clear_external_force()
if self._cur_substep_global % RATE_CHECK_ERRNO == 0:
self.rigid_solver.check_errno()
@@ -295,8 +297,16 @@ def _step_grad(self):
self.sub_step_grad(self.cur_substep_local)
+ if self.rigid_solver.is_active:
+ # Clear external force after gradient computation
+ self.rigid_solver.clear_external_force()
+
self.process_input_grad()
+ if self.options.grad_window_steps is not None and self.cur_step_global % self.options.grad_window_steps == 0:
+ # Truncate upstream gradient flow
+ self.rigid_solver.zero_grad()
+
def process_input(self, in_backward=False):
"""
setting _tgt state using external commands
diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py
index e7c35378af..a6aebb3d42 100644
--- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py
+++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py
@@ -174,6 +174,20 @@ def add_inequality_constraints(self):
static_rigid_sim_config=solver._static_rigid_sim_config,
)
+ def add_inequality_constraints_grad(self):
+ solver = self._solver
+ add_inequality_constraints.grad(
+ links_info=solver.links_info,
+ links_state=solver.links_state,
+ dofs_state=solver.dofs_state,
+ dofs_info=solver.dofs_info,
+ joints_info=solver.joints_info,
+ constraint_state=self.constraint_state,
+ collider_state=self._collider._collider_state,
+ rigid_global_info=solver._rigid_global_info,
+ static_rigid_sim_config=solver._static_rigid_sim_config,
+ )
+
def resolve(self):
solver = self._solver
@@ -437,80 +451,110 @@ def add_collision_constraints(
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
):
- EPS = rigid_global_info.EPS[None]
-
- _B = dofs_state.ctrl_mode.shape[1]
- n_dofs = dofs_state.ctrl_mode.shape[0]
-
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
- for i_b in range(_B):
- for i_col in range(collider_state.n_contacts[i_b]):
- contact_data_link_a = collider_state.contact_data.link_a[i_col, i_b]
- contact_data_link_b = collider_state.contact_data.link_b[i_col, i_b]
-
- contact_data_pos = collider_state.contact_data.pos[i_col, i_b]
- contact_data_normal = collider_state.contact_data.normal[i_col, i_b]
- contact_data_friction = collider_state.contact_data.friction[i_col, i_b]
- contact_data_sol_params = collider_state.contact_data.sol_params[i_col, i_b]
- contact_data_penetration = collider_state.contact_data.penetration[i_col, i_b]
-
- link_a = contact_data_link_a
- link_b = contact_data_link_b
- link_a_maybe_batch = [link_a, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_a
- link_b_maybe_batch = [link_b, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_b
-
- d1, d2 = gu.ti_orthogonals(contact_data_normal)
-
- invweight = links_info.invweight[link_a_maybe_batch][0]
- if link_b > -1:
- invweight = invweight + links_info.invweight[link_b_maybe_batch][0]
-
- for i in range(4):
+ for i_b, i_0, i_4 in (
+ ti.ndrange(dofs_state.ctrl_mode.shape[1], 1, 4)
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.ndrange(dofs_state.ctrl_mode.shape[1], static_rigid_sim_config.max_contact_pairs, 4)
+ ):
+ EPS = rigid_global_info.EPS[None]
+ n_dofs = dofs_state.ctrl_mode.shape[0]
+
+ for i_1 in (
+ range(collider_state.n_contacts[i_b])
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(1))
+ ):
+ i_col = i_1 if ti.static(not static_rigid_sim_config.is_backward) else i_0
+ if i_col < collider_state.n_contacts[i_b]:
+ contact_data_link_a = collider_state.contact_data.link_a[i_col, i_b]
+ contact_data_link_b = collider_state.contact_data.link_b[i_col, i_b]
+
+ contact_data_pos = collider_state.contact_data.pos[i_col, i_b]
+ contact_data_normal = collider_state.contact_data.normal[i_col, i_b]
+ contact_data_friction = collider_state.contact_data.friction[i_col, i_b]
+ contact_data_sol_params = collider_state.contact_data.sol_params[i_col, i_b]
+ contact_data_penetration = collider_state.contact_data.penetration[i_col, i_b]
+
+ link_a = contact_data_link_a
+ link_b = contact_data_link_b
+ link_a_maybe_batch = [link_a, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_a
+ link_b_maybe_batch = [link_b, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_b
+
+ d1, d2 = gu.ti_orthogonals(contact_data_normal)
+
+ invweight = links_info.invweight[link_a_maybe_batch][0]
+ if link_b > -1:
+ invweight = invweight + links_info.invweight[link_b_maybe_batch][0]
+
+ #for i in range(4) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(4)):
+ i = i_4
d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2)
n = d * contact_data_friction - contact_data_normal
- n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1)
+ n_con = i_col * 4 + i # + constraint_state.n_constraints[i_b]
if ti.static(static_rigid_sim_config.sparse_solve):
for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]):
i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b]
constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0)
else:
- for i_d in range(n_dofs):
+ for i_d in (
+ range(n_dofs)
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.n_dofs))
+ ):
constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0)
con_n_relevant_dofs = 0
jac_qvel = gs.ti_float(0.0)
- for i_ab in range(2):
+ for i_ab in range(2) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(2)):
sign = gs.ti_float(-1.0)
link = link_a
if i_ab == 1:
sign = gs.ti_float(1.0)
link = link_b
- while link > -1:
- link_maybe_batch = [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link
+ # FIXME: Walk up the parent chain with a fixed iteration cap (20 forward, 5 in backward mode) instead of `while link > -1`, for autodiff
+ for i_parent in (
+ range(20) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(5))
+ ):
+ if link > -1:
+ link_maybe_batch = (
+ [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link
+ )
+
+ # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending
+ for i_d_ in (
+ range(links_info.n_dofs[link_maybe_batch])
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link))
+ ):
+ if i_d_ < links_info.n_dofs[link_maybe_batch]:
+ i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_
- # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending
- for i_d_ in range(links_info.n_dofs[link_maybe_batch]):
- i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_
+ cdof_ang = dofs_state.cdof_ang[i_d, i_b]
+ cdot_vel = dofs_state.cdof_vel[i_d, i_b]
- cdof_ang = dofs_state.cdof_ang[i_d, i_b]
- cdot_vel = dofs_state.cdof_vel[i_d, i_b]
+ t_quat = gu.ti_identity_quat()
+ t_pos = contact_data_pos - links_state.root_COM[link, i_b]
+ _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat)
- t_quat = gu.ti_identity_quat()
- t_pos = contact_data_pos - links_state.root_COM[link, i_b]
- _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat)
+ diff = sign * vel
+ jac = diff @ n
+ jac_qvel += jac * dofs_state.vel[i_d, i_b]
+ constraint_state.jac[n_con, i_d, i_b] += jac
- diff = sign * vel
- jac = diff @ n
- jac_qvel = jac_qvel + jac * dofs_state.vel[i_d, i_b]
- constraint_state.jac[n_con, i_d, i_b] = constraint_state.jac[n_con, i_d, i_b] + jac
+ if ti.static(static_rigid_sim_config.sparse_solve):
+ constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d
+ con_n_relevant_dofs += 1
- if ti.static(static_rigid_sim_config.sparse_solve):
- constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d
- con_n_relevant_dofs += 1
+ link = links_info.parent_idx[link_maybe_batch]
- link = links_info.parent_idx[link_maybe_batch]
+ if ti.static(static_rigid_sim_config.is_backward):
+ if i_parent == 4 and link > -1:
+ print(
+ "Warning: Number of parents is too large for backward mode in add_collision_constraints"
+ )
if ti.static(static_rigid_sim_config.sparse_solve):
constraint_state.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs
@@ -518,14 +562,18 @@ def add_collision_constraints(
contact_data_sol_params, -contact_data_penetration, jac_qvel, -contact_data_penetration
)
- diag = invweight + contact_data_friction * contact_data_friction * invweight
- diag *= 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp
- diag = ti.max(diag, EPS)
+ diag_0 = invweight + contact_data_friction * contact_data_friction * invweight
+ diag_1 = diag_0 * 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp
+ diag = ti.max(diag_1, EPS)
constraint_state.diag[n_con, i_b] = diag
constraint_state.aref[n_con, i_b] = aref
constraint_state.efc_D[n_con, i_b] = 1 / diag
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_b in range(dofs_state.ctrl_mode.shape[1]):
+ constraint_state.n_constraints[i_b] += 4 * collider_state.n_contacts[i_b]
+
@ti.func
def func_equality_connect(
@@ -792,15 +840,15 @@ def add_inequality_constraints(
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
):
- add_frictionloss_constraints(
- links_info=links_info,
- joints_info=joints_info,
- dofs_info=dofs_info,
- dofs_state=dofs_state,
- rigid_global_info=rigid_global_info,
- constraint_state=constraint_state,
- static_rigid_sim_config=static_rigid_sim_config,
- )
+ # add_frictionloss_constraints(
+ # links_info=links_info,
+ # joints_info=joints_info,
+ # dofs_info=dofs_info,
+ # dofs_state=dofs_state,
+ # rigid_global_info=rigid_global_info,
+ # constraint_state=constraint_state,
+ # static_rigid_sim_config=static_rigid_sim_config,
+ # )
if ti.static(static_rigid_sim_config.enable_collision):
add_collision_constraints(
links_info=links_info,
@@ -811,16 +859,16 @@ def add_inequality_constraints(
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
)
- if ti.static(static_rigid_sim_config.enable_joint_limit):
- add_joint_limit_constraints(
- links_info=links_info,
- joints_info=joints_info,
- dofs_info=dofs_info,
- dofs_state=dofs_state,
- rigid_global_info=rigid_global_info,
- constraint_state=constraint_state,
- static_rigid_sim_config=static_rigid_sim_config,
- )
+ # if ti.static(static_rigid_sim_config.enable_joint_limit):
+ # add_joint_limit_constraints(
+ # links_info=links_info,
+ # joints_info=joints_info,
+ # dofs_info=dofs_info,
+ # dofs_state=dofs_state,
+ # rigid_global_info=rigid_global_info,
+ # constraint_state=constraint_state,
+ # static_rigid_sim_config=static_rigid_sim_config,
+ # )
@ti.func
@@ -1026,52 +1074,69 @@ def add_joint_limit_constraints(
constraint_state: array_class.ConstraintState,
static_rigid_sim_config: ti.template(),
):
- EPS = rigid_global_info.EPS[None]
-
- _B = constraint_state.jac.shape[2]
- n_links = links_info.root_idx.shape[0]
- n_dofs = dofs_state.ctrl_mode.shape[0]
-
# TODO: sparse mode
ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL))
- for i_b in range(_B):
- for i_l in range(n_links):
+ for i_b in range(constraint_state.jac.shape[2]):
+ EPS = rigid_global_info.EPS[None]
+ n_links = links_info.root_idx.shape[0]
+ n_dofs = dofs_state.ctrl_mode.shape[0]
+
+ for i_l in (
+ range(n_links)
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.n_links))
+ ):
I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
- for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]):
- I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
-
- if joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC:
- i_q = joints_info.q_start[I_j]
- i_d = joints_info.dof_start[I_j]
- I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d
- pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0]
- pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b]
- pos_delta = min(pos_delta_min, pos_delta_max)
-
- if pos_delta < 0:
- jac = (pos_delta_min < pos_delta_max) * 2 - 1
- jac_qvel = jac * dofs_state.vel[i_d, i_b]
- imp, aref = gu.imp_aref(joints_info.sol_params[I_j], pos_delta, jac_qvel, pos_delta)
- diag = ti.max(dofs_info.invweight[I_d] * (1 - imp) / imp, EPS)
-
- n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1)
- constraint_state.diag[n_con, i_b] = diag
- constraint_state.aref[n_con, i_b] = aref
- constraint_state.efc_D[n_con, i_b] = 1 / diag
-
- if ti.static(static_rigid_sim_config.sparse_solve):
- for i_d2_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]):
- i_d2 = constraint_state.jac_relevant_dofs[n_con, i_d2_, i_b]
- constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0)
- else:
- for i_d2 in range(n_dofs):
- constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0)
- constraint_state.jac[n_con, i_d, i_b] = jac
+ for i_j_ in (
+ range(links_info.joint_start[I_l], links_info.joint_end[I_l])
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.max_n_joints_per_link))
+ ):
+ i_j = (
+ i_j_ if ti.static(not static_rigid_sim_config.is_backward) else (i_j_ + links_info.joint_start[I_l])
+ )
+ if i_j < links_info.joint_end[I_l]:
+ I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
+
+ if (
+ joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE
+ or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC
+ ):
+ i_q = joints_info.q_start[I_j]
+ i_d = joints_info.dof_start[I_j]
+ I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d
+ pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0]
+ pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b]
+ pos_delta = min(pos_delta_min, pos_delta_max)
+
+ if pos_delta < 0:
+ jac = (pos_delta_min < pos_delta_max) * 2 - 1
+ jac_qvel = jac * dofs_state.vel[i_d, i_b]
+ imp, aref = gu.imp_aref(joints_info.sol_params[I_j], pos_delta, jac_qvel, pos_delta)
+ diag = ti.max(dofs_info.invweight[I_d] * (1 - imp) / imp, EPS)
+
+ n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1)
+ constraint_state.diag[n_con, i_b] = diag
+ constraint_state.aref[n_con, i_b] = aref
+ constraint_state.efc_D[n_con, i_b] = 1 / diag
- if ti.static(static_rigid_sim_config.sparse_solve):
- constraint_state.jac_n_relevant_dofs[n_con, i_b] = 1
- constraint_state.jac_relevant_dofs[n_con, 0, i_b] = i_d
+ if ti.static(static_rigid_sim_config.sparse_solve):
+ for i_d2_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]):
+ i_d2 = constraint_state.jac_relevant_dofs[n_con, i_d2_, i_b]
+ constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0)
+ else:
+ for i_d2 in (
+ range(n_dofs)
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.n_dofs))
+ ):
+ constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0)
+ constraint_state.jac[n_con, i_d, i_b] = jac
+
+ if ti.static(static_rigid_sim_config.sparse_solve):
+ constraint_state.jac_n_relevant_dofs[n_con, i_b] = 1
+ constraint_state.jac_relevant_dofs[n_con, 0, i_b] = i_d
@ti.func
@@ -1084,46 +1149,70 @@ def add_frictionloss_constraints(
constraint_state: array_class.ConstraintState,
static_rigid_sim_config: ti.template(),
):
- EPS = rigid_global_info.EPS[None]
-
- _B = constraint_state.jac.shape[2]
- n_links = links_info.root_idx.shape[0]
- n_dofs = dofs_state.ctrl_mode.shape[0]
-
# TODO: sparse mode
# FIXME: The condition `if dofs_info.frictionloss[I_d] > EPS:` is not correctly evaluated on Apple Metal
# if `serialize=True`...
ti.loop_config(
serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL and gs.backend != gs.metal)
)
- for i_b in range(_B):
+ for i_b in range(constraint_state.jac.shape[2]):
constraint_state.n_constraints_frictionloss[i_b] = 0
- for i_l in range(n_links):
- I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
-
- for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]):
- I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
-
- for i_d in range(joints_info.dof_start[I_j], joints_info.dof_end[I_j]):
- I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d
+ EPS = rigid_global_info.EPS[None]
+ n_links = links_info.root_idx.shape[0]
+ n_dofs = dofs_state.ctrl_mode.shape[0]
- if dofs_info.frictionloss[I_d] > EPS:
- jac = 1.0
- jac_qvel = jac * dofs_state.vel[i_d, i_b]
- imp, aref = gu.imp_aref(joints_info.sol_params[I_j], 0.0, jac_qvel, 0.0)
- diag = ti.max(dofs_info.invweight[I_d] * (1.0 - imp) / imp, EPS)
-
- i_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1)
- ti.atomic_add(constraint_state.n_constraints_frictionloss[i_b], 1)
+ for i_l in (
+ range(n_links)
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.n_links))
+ ):
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
- constraint_state.diag[i_con, i_b] = diag
- constraint_state.aref[i_con, i_b] = aref
- constraint_state.efc_D[i_con, i_b] = 1.0 / diag
- constraint_state.efc_frictionloss[i_con, i_b] = dofs_info.frictionloss[I_d]
- for i_d2 in range(n_dofs):
- constraint_state.jac[i_con, i_d2, i_b] = gs.ti_float(0.0)
- constraint_state.jac[i_con, i_d, i_b] = jac
+ for i_j_ in (
+ range(links_info.joint_start[I_l], links_info.joint_end[I_l])
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.max_n_joints_per_link))
+ ):
+ i_j = (
+ i_j_ if ti.static(not static_rigid_sim_config.is_backward) else (i_j_ + links_info.joint_start[I_l])
+ )
+ if i_j < links_info.joint_end[I_l]:
+ I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
+
+ for i_d_ in (
+ range(joints_info.dof_start[I_j], joints_info.dof_end[I_j])
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint))
+ ):
+ i_d = (
+ i_d_
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else (i_d_ + joints_info.dof_start[I_j])
+ )
+ if i_d < joints_info.dof_end[I_j]:
+ I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d
+
+ if dofs_info.frictionloss[I_d] > EPS:
+ jac = 1.0
+ jac_qvel = jac * dofs_state.vel[i_d, i_b]
+ imp, aref = gu.imp_aref(joints_info.sol_params[I_j], 0.0, jac_qvel, 0.0)
+ diag = ti.max(dofs_info.invweight[I_d] * (1.0 - imp) / imp, EPS)
+
+ i_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1)
+ ti.atomic_add(constraint_state.n_constraints_frictionloss[i_b], 1)
+
+ constraint_state.diag[i_con, i_b] = diag
+ constraint_state.aref[i_con, i_b] = aref
+ constraint_state.efc_D[i_con, i_b] = 1.0 / diag
+ constraint_state.efc_frictionloss[i_con, i_b] = dofs_info.frictionloss[I_d]
+ for i_d2 in (
+ range(n_dofs)
+ if ti.static(not static_rigid_sim_config.is_backward)
+ else ti.static(range(static_rigid_sim_config.n_dofs))
+ ):
+ constraint_state.jac[i_con, i_d2, i_b] = gs.ti_float(0.0)
+ constraint_state.jac[i_con, i_d, i_b] = jac
@ti.func
diff --git a/genesis/engine/solvers/rigid/diff_gjk_decomp.py b/genesis/engine/solvers/rigid/diff_gjk_decomp.py
index 5d3e407e08..8f36df3e8c 100644
--- a/genesis/engine/solvers/rigid/diff_gjk_decomp.py
+++ b/genesis/engine/solvers/rigid/diff_gjk_decomp.py
@@ -77,7 +77,7 @@ def func_gjk_contact(
found_default_epa = False
# 4 (small) + 4 (large) perturbated configurations
- num_perturb = 8
+ num_perturb = 4
### Detect multiple possible contact points and gather the non-differentiable contact data.
for i in range(1 + num_perturb):
@@ -209,6 +209,8 @@ def func_gjk_contact(
)
found_default_epa = True
+ # Do not use the extended EPA algorithm for now; in practice it does not seem to be needed.
+ break
# Break the loop if we found enough contact points for default configuration. As we can find at most
# 8 contact points for perturbed configurations, we can find at most max_contacts_per_pair - 8
diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py
index 48b1a693c8..2f28b7637e 100644
--- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py
+++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py
@@ -271,6 +271,9 @@ def build(self):
max_n_qs_per_link=max(link.n_qs for link in self.links) if self.links else 0,
n_links=self._n_links,
n_geoms=self._n_geoms,
+ n_dofs=self._n_dofs,
+ n_entities=self._n_entities,
+ # max_contact_pairs=self.collider._collider_state.contact_data.geom_a.shape[0],
)
self._static_rigid_sim_config = array_class.StructRigidSimStaticConfig(**static_rigid_sim_config)
else:
@@ -299,6 +302,9 @@ def build(self):
if getattr(self._options, "noslip_iterations", 0) > 0:
gs.raise_exception("Noslip is not supported yet when requires_grad is True.")
+ if getattr(self._options, "sparse_solve", False):
+ gs.raise_exception("Sparse solve is not supported yet when requires_grad is True.")
+
# when the migration is finished, we will remove the about two lines
self._func_vel_at_point = func_vel_at_point
self._func_apply_coupling_force = func_apply_coupling_force
@@ -310,7 +316,6 @@ def build(self):
self._errno = self.data_manager.errno
self._rigid_global_info = self.data_manager.rigid_global_info
- self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache
if self._use_hibernation:
self.n_awake_dofs = self._rigid_global_info.n_awake_dofs
self.awake_dofs = self._rigid_global_info.awake_dofs
@@ -323,6 +328,8 @@ def build(self):
self.links_state_adjoint_cache = self.data_manager.links_state_adjoint_cache
self.joints_state_adjoint_cache = self.data_manager.joints_state_adjoint_cache
self.geoms_state_adjoint_cache = self.data_manager.geoms_state_adjoint_cache
+ self._rigid_adjoint_cache_fw = self.data_manager.rigid_adjoint_cache_fw
+ self._rigid_adjoint_cache_bw = self.data_manager.rigid_adjoint_cache_bw
self._init_mass_mat()
self._init_dof_fields()
@@ -816,6 +823,9 @@ def _init_sdf(self):
def _init_collider(self):
self.collider = Collider(self)
+ if self.sim.options.requires_grad:
+ self._static_rigid_sim_config.max_contact_pairs = self.collider._collider_state.contact_data.geom_a.shape[0]
+
if self.collider._collider_static_config.has_terrain:
link_idx_ = next(
i for i, _type in enumerate(ti_to_numpy(self.geoms_info.type)) if _type == gs.GEOM_TYPE.TERRAIN
@@ -859,8 +869,9 @@ def substep(self, f):
kernel_save_adjoint_cache(
f=f,
dofs_state=self.dofs_state,
+ constraint_state=self.constraint_solver.constraint_state,
rigid_global_info=self._rigid_global_info,
- rigid_adjoint_cache=self._rigid_adjoint_cache,
+ rigid_adjoint_cache=self._rigid_adjoint_cache_fw,
static_rigid_sim_config=self._static_rigid_sim_config,
)
@@ -908,8 +919,9 @@ def substep(self, f):
kernel_save_adjoint_cache(
f=f + 1,
dofs_state=self.dofs_state,
+ constraint_state=self.constraint_solver.constraint_state,
rigid_global_info=self._rigid_global_info,
- rigid_adjoint_cache=self._rigid_adjoint_cache,
+ rigid_adjoint_cache=self._rigid_adjoint_cache_fw,
static_rigid_sim_config=self._static_rigid_sim_config,
)
@@ -1148,9 +1160,6 @@ def substep_pre_coupling(self, f):
self.substep(f)
def substep_pre_coupling_grad(self, f):
- # Change to backward mode
- self._static_rigid_sim_config.is_backward = True
-
# Run forward substep again to restore this step's information, this is needed because we do not store info
# of every substep.
kernel_prepare_backward_substep(
@@ -1163,46 +1172,144 @@ def substep_pre_coupling_grad(self, f):
dofs_info=self.dofs_info,
geoms_state=self.geoms_state,
geoms_info=self.geoms_info,
+ constraint_state=self.constraint_solver.constraint_state,
entities_info=self.entities_info,
rigid_global_info=self._rigid_global_info,
dofs_state_adjoint_cache=self.dofs_state_adjoint_cache,
links_state_adjoint_cache=self.links_state_adjoint_cache,
joints_state_adjoint_cache=self.joints_state_adjoint_cache,
geoms_state_adjoint_cache=self.geoms_state_adjoint_cache,
- rigid_adjoint_cache=self._rigid_adjoint_cache,
+ rigid_adjoint_cache=self._rigid_adjoint_cache_fw,
static_rigid_sim_config=self._static_rigid_sim_config,
)
self.substep(f)
# =================== Backward substep ======================
+ # Switch to backward mode only now: the [substep] call right above deliberately runs in forward mode so that it
+ # reproduces exactly the data computed during the forward pass. The backward-mode kernels are logically the same
+ # as the forward ones, but they are restructured for autodiff, so their results could differ numerically. Re-running
+ # the substep in forward mode avoids that discrepancy; from here on the kernels run in backward mode.
+ self._static_rigid_sim_config.is_backward = True
+
envs_idx = self._scene._sanitize_envs_idx(None)
if not self._enable_mujoco_compatibility:
- kernel_forward_velocity.grad(
- envs_idx=envs_idx,
- links_state=self.links_state,
- links_info=self.links_info,
- joints_info=self.joints_info,
- dofs_state=self.dofs_state,
+ # kernel_forward_velocity.grad(
+ # envs_idx=envs_idx,
+ # links_state=self.links_state,
+ # links_info=self.links_info,
+ # joints_info=self.joints_info,
+ # dofs_state=self.dofs_state,
+ # entities_info=self.entities_info,
+ # rigid_global_info=self._rigid_global_info,
+ # static_rigid_sim_config=self._static_rigid_sim_config,
+ # )
+ for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity):
+ kernel_forward_velocity_ad.grad(
+ i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_,
+ envs_idx=envs_idx,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_update_geoms.grad(
+ envs_idx,
entities_info=self.entities_info,
+ geoms_info=self.geoms_info,
+ geoms_state=self.geoms_state,
+ links_state=self.links_state,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
+ force_update_fixed_geoms=False,
)
- kernel_update_cartesian_space.grad(
+ # kernel_COM_links.grad(
+ # links_state=self.links_state,
+ # links_info=self.links_info,
+ # joints_state=self.joints_state,
+ # joints_info=self.joints_info,
+ # dofs_state=self.dofs_state,
+ # dofs_info=self.dofs_info,
+ # geoms_state=self.geoms_state,
+ # geoms_info=self.geoms_info,
+ # entities_info=self.entities_info,
+ # rigid_global_info=self._rigid_global_info,
+ # static_rigid_sim_config=self._static_rigid_sim_config,
+ # force_update_fixed_geoms=False,
+ # )
+ for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity):
+ kernel_COM_links_ad_1.grad(
+ i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_COM_links_ad_0.grad(
links_state=self.links_state,
links_info=self.links_info,
joints_state=self.joints_state,
joints_info=self.joints_info,
dofs_state=self.dofs_state,
dofs_info=self.dofs_info,
- geoms_state=self.geoms_state,
- geoms_info=self.geoms_info,
entities_info=self.entities_info,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
- force_update_fixed_geoms=False,
)
-
- is_grad_valid = kernel_begin_backward_substep(
+ # kernel_forward_kinematics.grad(
+ # links_state=self.links_state,
+ # links_info=self.links_info,
+ # joints_state=self.joints_state,
+ # joints_info=self.joints_info,
+ # dofs_state=self.dofs_state,
+ # dofs_info=self.dofs_info,
+ # geoms_state=self.geoms_state,
+ # geoms_info=self.geoms_info,
+ # entities_info=self.entities_info,
+ # rigid_global_info=self._rigid_global_info,
+ # static_rigid_sim_config=self._static_rigid_sim_config,
+ # force_update_fixed_geoms=False,
+ # )
+
+ for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity):
+ kernel_forward_kinematics_ad.grad(
+ i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ geoms_state=self.geoms_state,
+ geoms_info=self.geoms_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ force_update_fixed_geoms=False,
+ )
+ # kernel_update_cartesian_space.grad(
+ # links_state=self.links_state,
+ # links_info=self.links_info,
+ # joints_state=self.joints_state,
+ # joints_info=self.joints_info,
+ # dofs_state=self.dofs_state,
+ # dofs_info=self.dofs_info,
+ # geoms_state=self.geoms_state,
+ # geoms_info=self.geoms_info,
+ # entities_info=self.entities_info,
+ # rigid_global_info=self._rigid_global_info,
+ # static_rigid_sim_config=self._static_rigid_sim_config,
+ # force_update_fixed_geoms=False,
+ # )
+
+ errno = kernel_begin_backward_substep(
f=f,
links_state=self.links_state,
links_info=self.links_info,
@@ -1218,11 +1325,20 @@ def substep_pre_coupling_grad(self, f):
links_state_adjoint_cache=self.links_state_adjoint_cache,
joints_state_adjoint_cache=self.joints_state_adjoint_cache,
geoms_state_adjoint_cache=self.geoms_state_adjoint_cache,
- rigid_adjoint_cache=self._rigid_adjoint_cache,
+ rigid_adjoint_cache_fw=self._rigid_adjoint_cache_fw,
+ rigid_adjoint_cache_bw=self._rigid_adjoint_cache_bw,
static_rigid_sim_config=self._static_rigid_sim_config,
)
- if not is_grad_valid:
- gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}")
+ match errno:
+ case 1:
+ gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}")
+ # case 2:
+ # qpos_diff = self._rigid_adjoint_cache_fw.qpos.to_numpy() - self._rigid_adjoint_cache_bw.qpos.to_numpy()
+ # vel_diff = self._rigid_adjoint_cache_fw.dofs_vel.to_numpy() - self._rigid_adjoint_cache_bw.dofs_vel.to_numpy()
+ # acc_diff = self._rigid_adjoint_cache_fw.dofs_acc.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc.to_numpy()
+ # acc_smooth_diff = self._rigid_adjoint_cache_fw.dofs_acc_smooth.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc_smooth.to_numpy()
+ # solver_qacc_ws_diff = self._rigid_adjoint_cache_fw.solver_qacc_ws.to_numpy() - self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy()
+ # gs.raise_exception(f"The backward computation result does not match the forward computation result at step {self._sim.cur_step_global}")
kernel_step_2.grad(
dofs_state=self.dofs_state,
@@ -1241,68 +1357,190 @@ def substep_pre_coupling_grad(self, f):
contact_island_state=self.constraint_solver.contact_island.contact_island_state,
)
- # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel,
- # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules).
- # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc].
- # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through
- # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation
- # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel.
- # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a
- # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid
- # the data access violation. However, both of these require major changes.
- kernel_compute_qacc.grad(
- dofs_state=self.dofs_state,
- entities_info=self.entities_info,
- rigid_global_info=self._rigid_global_info,
- static_rigid_sim_config=self._static_rigid_sim_config,
- )
+ if not self._disable_constraint:
+ # Solver backward
+ dL_dqacc = self.dofs_state.acc.grad.to_numpy()
+ self.dofs_state.acc.grad.fill(0.0)
+
+ qacc_ws = self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy()[f]
+ self.constraint_solver.constraint_state.qacc_ws.from_numpy(qacc_ws)
+
+ self.constraint_solver.backward(dL_dqacc)
+ dL_dM = self.constraint_solver.constraint_state.dL_dM.to_numpy()
+ dL_djac = self.constraint_solver.constraint_state.dL_djac.to_numpy()
+ dL_daref = self.constraint_solver.constraint_state.dL_daref.to_numpy()
+ dL_defc_D = self.constraint_solver.constraint_state.dL_defc_D.to_numpy()
+ dL_dforce = self.constraint_solver.constraint_state.dL_dforce.to_numpy()
+
+ self._rigid_global_info.mass_mat.grad.from_numpy(dL_dM)
+ self.constraint_solver.constraint_state.jac.grad.from_numpy(dL_djac)
+ self.constraint_solver.constraint_state.aref.grad.from_numpy(dL_daref)
+ self.constraint_solver.constraint_state.efc_D.grad.from_numpy(dL_defc_D)
+ self.dofs_state.force.grad.from_numpy(dL_dforce)
+
+ self.constraint_solver.constraint_state.n_constraints.fill(0)
+ self.constraint_solver.add_inequality_constraints_grad()
+
+ # Collider backward
+ dL_dcontact_pos = self.collider._collider_state.contact_data.pos.grad.to_numpy()
+ dL_dcontact_normal = self.collider._collider_state.contact_data.normal.grad.to_numpy()
+ dL_dcontact_penetration = self.collider._collider_state.contact_data.penetration.grad.to_numpy()
+ self.collider.backward(dL_dcontact_pos, dL_dcontact_normal, dL_dcontact_penetration)
+ else:
+ # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel,
+ # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules).
+ # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc].
+ # As [kernel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through
+ # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation
+ # cannot be merged with [kernel_compute_qacc.grad] because the .grad function is itself a standalone kernel.
+ # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a
+ # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid
+ # the data access violation. However, both of these require major changes.
+ kernel_compute_qacc.grad(
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
kernel_copy_acc(
f=f,
dofs_state=self.dofs_state,
- rigid_adjoint_cache=self._rigid_adjoint_cache,
+ rigid_adjoint_cache=self._rigid_adjoint_cache_fw,
static_rigid_sim_config=self._static_rigid_sim_config,
)
- kernel_forward_dynamics_without_qacc.grad(
+ kernel_bias_force.grad(
+ dofs_state=self.dofs_state,
links_state=self.links_state,
links_info=self.links_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_update_force.grad(
+ links_state=self.links_state,
+ links_info=self.links_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_update_acc.grad(
dofs_state=self.dofs_state,
- dofs_info=self.dofs_info,
- joints_info=self.joints_info,
+ links_info=self.links_info,
+ links_state=self.links_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_torque_and_passive_force.grad(
entities_state=self.entities_state,
entities_info=self.entities_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_info=self.joints_info,
geoms_state=self.geoms_state,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
contact_island_state=self.constraint_solver.contact_island.contact_island_state,
)
+ kernel_factor_mass.grad(
+ implicit_damping=False,
+ entities_info=self.entities_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_compute_mass_matrix_ad.grad(
+ implicit_damping=self._static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+
+ # kernel_forward_dynamics_without_qacc.grad(
+ # links_state=self.links_state,
+ # links_info=self.links_info,
+ # dofs_state=self.dofs_state,
+ # dofs_info=self.dofs_info,
+ # joints_info=self.joints_info,
+ # entities_state=self.entities_state,
+ # entities_info=self.entities_info,
+ # geoms_state=self.geoms_state,
+ # rigid_global_info=self._rigid_global_info,
+ # static_rigid_sim_config=self._static_rigid_sim_config,
+ # contact_island_state=self.constraint_solver.contact_island.contact_island_state,
+ # )
# If it was the very first substep, we need to backpropagate through the initial update of the cartesian space
if self._enable_mujoco_compatibility or self._sim.cur_substep_global == 0:
- kernel_forward_velocity.grad(
- envs_idx=envs_idx,
- links_state=self.links_state,
- links_info=self.links_info,
- joints_info=self.joints_info,
- dofs_state=self.dofs_state,
+ for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity):
+ kernel_forward_velocity_ad.grad(
+ i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_,
+ envs_idx=envs_idx,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_update_geoms.grad(
+ envs_idx,
entities_info=self.entities_info,
+ geoms_info=self.geoms_info,
+ geoms_state=self.geoms_state,
+ links_state=self.links_state,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
+ force_update_fixed_geoms=False,
)
- kernel_update_cartesian_space.grad(
+ for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity):
+ kernel_COM_links_ad_1.grad(
+ i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ kernel_COM_links_ad_0.grad(
links_state=self.links_state,
links_info=self.links_info,
joints_state=self.joints_state,
joints_info=self.joints_info,
dofs_state=self.dofs_state,
dofs_info=self.dofs_info,
- geoms_state=self.geoms_state,
- geoms_info=self.geoms_info,
entities_info=self.entities_info,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
- force_update_fixed_geoms=False,
)
+ for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity):
+ kernel_forward_kinematics_ad.grad(
+ i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ geoms_state=self.geoms_state,
+ geoms_info=self.geoms_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ force_update_fixed_geoms=False,
+ )
# Change back to forward mode
self._static_rigid_sim_config.is_backward = False
@@ -1392,6 +1630,27 @@ def reset_grad(self):
for entity in self._entities:
entity.reset_grad()
self._queried_states.clear()
+ self.zero_grad()
+
+ def zero_grad(self):
+ # Zero the gradient of every field in the state containers below that exposes a non-None `.grad`.
+ for state in [
+ self._rigid_global_info,
+ self.links_state,
+ self.dofs_state,
+ self.geoms_state,
+ self.joints_state,
+ self.entities_state,
+ self.constraint_solver.constraint_state,
+ self.collider._collider_state.diff_contact_input,
+ self.collider._collider_state.contact_data,
+ self._rigid_adjoint_cache_fw,
+ self._rigid_adjoint_cache_bw,
+ ]:
+ for attr in state.__dict__.values():
+ if hasattr(attr, 'grad') and attr.grad is not None:
+ attr.grad.fill(0.0)
+
def update_geoms_render_T(self):
kernel_update_geoms_render_T(
@@ -1494,20 +1753,32 @@ def save_ckpt(self, ckpt_name):
if ckpt_name not in self._ckpt:
self._ckpt[ckpt_name] = dict()
- self._ckpt[ckpt_name]["qpos"] = ti_to_numpy(self._rigid_adjoint_cache.qpos)
- self._ckpt[ckpt_name]["dofs_vel"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_vel)
- self._ckpt[ckpt_name]["dofs_acc"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_acc)
+ self._ckpt[ckpt_name]["qpos"] = ti_to_numpy(self._rigid_adjoint_cache_fw.qpos, copy=True)
+ self._ckpt[ckpt_name]["dofs_vel"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_vel, copy=True)
+ self._ckpt[ckpt_name]["dofs_acc"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_acc, copy=True)
+ self._ckpt[ckpt_name]["dofs_acc_smooth"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_acc_smooth, copy=True)
+ self._ckpt[ckpt_name]["solver_qacc_ws"] = ti_to_numpy(self._rigid_adjoint_cache_fw.solver_qacc_ws, copy=True)
for entity in self._entities:
entity.save_ckpt(ckpt_name)
def load_ckpt(self, ckpt_name):
+ # Load adjoint cache for backward pass
+ self._rigid_adjoint_cache_bw.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"])
+ self._rigid_adjoint_cache_bw.dofs_vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"])
+ self._rigid_adjoint_cache_bw.dofs_acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"])
+ self._rigid_adjoint_cache_bw.dofs_acc_smooth.from_numpy(self._ckpt[ckpt_name]["dofs_acc_smooth"])
+ self._rigid_adjoint_cache_bw.solver_qacc_ws.from_numpy(self._ckpt[ckpt_name]["solver_qacc_ws"])
+
# Set first frame
self._rigid_global_info.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"][0])
self.dofs_state.vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"][0])
self.dofs_state.acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"][0])
+ self.dofs_state.acc_smooth.from_numpy(self._ckpt[ckpt_name]["dofs_acc_smooth"][0])
+ self.constraint_solver.constraint_state.qacc_ws.from_numpy(self._ckpt[ckpt_name]["solver_qacc_ws"][0])
if not self._enable_mujoco_compatibility:
+ envs_idx = self._scene._sanitize_envs_idx(None)
kernel_update_cartesian_space(
links_state=self.links_state,
links_info=self.links_info,
@@ -1522,9 +1793,47 @@ def load_ckpt(self, ckpt_name):
static_rigid_sim_config=self._static_rigid_sim_config,
force_update_fixed_geoms=False,
)
+ kernel_forward_velocity(
+ envs_idx=envs_idx,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
for entity in self._entities:
entity.load_ckpt(ckpt_name)
+
+ def load_test(self):
+ if not self._enable_mujoco_compatibility:
+ envs_idx = self._scene._sanitize_envs_idx(None)
+ kernel_update_cartesian_space(
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ geoms_state=self.geoms_state,
+ geoms_info=self.geoms_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ force_update_fixed_geoms=False,
+ )
+ kernel_forward_velocity(
+ envs_idx=envs_idx,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
@property
def is_active(self):
@@ -2056,6 +2365,14 @@ def control_dofs_force(self, force, dofs_idx=None, envs_idx=None):
kernel_control_dofs_force(force, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config)
+ def control_dofs_force_grad(self, dofs_idx, envs_idx, force_grad):
+ force_grad_, dofs_idx, envs_idx = self._sanitize_io_variables(
+ force_grad, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True
+ )
+ if self.n_envs == 0:
+ force_grad_ = force_grad_.unsqueeze(0)
+ kernel_control_dofs_force_grad(force_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config)
+
def control_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None):
if gs.use_zerocopy:
mask = (0, *indices_to_mask(dofs_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, dofs_idx)
@@ -2433,11 +2750,11 @@ def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True
def clear_external_force(self):
if gs.use_zerocopy:
- for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel):
+ for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel, self.dofs_state.ctrl_force):
out = ti_to_torch(tensor, copy=False)
out.zero_()
else:
- kernel_clear_external_force(self.links_state, self._rigid_global_info, self._static_rigid_sim_config)
+ kernel_clear_external_force(self.links_state, self.dofs_state, self._rigid_global_info, self._static_rigid_sim_config)
def update_vgeoms(self):
kernel_update_vgeoms(self.vgeoms_info, self.vgeoms_state, self.links_state, self._static_rigid_sim_config)
@@ -3340,6 +3657,27 @@ def func_vel_at_point(pos_world, link_idx, i_b, links_state: array_class.LinksSt
vel_lin = links_state.cd_vel[link_idx, i_b]
return vel_rot + vel_lin
+@ti.kernel
+def kernel_compute_mass_matrix_ad(
+ implicit_damping: ti.template(),
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ func_compute_mass_matrix(
+ implicit_damping=implicit_damping,
+ links_state=links_state,
+ links_info=links_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
@ti.func
def func_compute_mass_matrix(
@@ -3593,6 +3931,23 @@ def func_compute_mass_matrix(
# qM += d qfrc_actuator / d qvel
rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.kv[I_d] * rigid_global_info.substep_dt[None]
+@ti.kernel
+def kernel_factor_mass(
+ implicit_damping: ti.template(),
+ entities_info: array_class.EntitiesInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ func_factor_mass(
+ implicit_damping=implicit_damping,
+ entities_info=entities_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
@ti.func
def func_factor_mass(
@@ -3864,7 +4219,7 @@ def func_solve_mass_batched(
# Static inner loop for backward pass
ti.static(range(static_rigid_sim_config.max_n_awake_entities))
if ti.static(static_rigid_sim_config.use_hibernation)
- else ti.static(range(static_rigid_sim_config.max_n_links_per_entity))
+ else ti.static(range(static_rigid_sim_config.n_entities))
)
):
if func_check_index_range(i_0, 0, rigid_global_info.n_awake_entities[i_b], BW):
@@ -4125,11 +4480,13 @@ def kernel_forward_dynamics_without_qacc(
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_clear_external_force(
links_state: array_class.LinksState,
+ dofs_state: array_class.DofsState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
):
func_clear_external_force(
links_state=links_state,
+ dofs_state=dofs_state,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
)
@@ -4169,9 +4526,8 @@ def kernel_update_cartesian_space(
)
-@ti.func
-def func_update_cartesian_space(
- i_b,
+@ti.kernel
+def kernel_forward_kinematics(
links_state: array_class.LinksState,
links_info: array_class.LinksInfo,
joints_state: array_class.JointsState,
@@ -4185,64 +4541,160 @@ def func_update_cartesian_space(
static_rigid_sim_config: ti.template(),
force_update_fixed_geoms: ti.template(),
):
- func_forward_kinematics(
- i_b,
- links_state=links_state,
- links_info=links_info,
- joints_state=joints_state,
- joints_info=joints_info,
- dofs_state=dofs_state,
- dofs_info=dofs_info,
- entities_info=entities_info,
- rigid_global_info=rigid_global_info,
- static_rigid_sim_config=static_rigid_sim_config,
- )
- func_COM_links(
- i_b,
- links_state=links_state,
- links_info=links_info,
- joints_state=joints_state,
- joints_info=joints_info,
- dofs_state=dofs_state,
- dofs_info=dofs_info,
- entities_info=entities_info,
- rigid_global_info=rigid_global_info,
- static_rigid_sim_config=static_rigid_sim_config,
- )
- func_update_geoms(
- i_b=i_b,
- entities_info=entities_info,
- geoms_info=geoms_info,
- geoms_state=geoms_state,
- links_state=links_state,
- rigid_global_info=rigid_global_info,
- static_rigid_sim_config=static_rigid_sim_config,
- force_update_fixed_geoms=force_update_fixed_geoms,
- )
-
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_b in range(links_state.pos.shape[1]):
+ func_forward_kinematics(
+ i_b,
+ links_state=links_state,
+ links_info=links_info,
+ joints_state=joints_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
-@ti.kernel(fastcache=gs.use_fastcache)
-def kernel_step_1(
+@ti.kernel
+def kernel_forward_kinematics_ad(
+ i_l_:ti.int32,
links_state: array_class.LinksState,
links_info: array_class.LinksInfo,
joints_state: array_class.JointsState,
joints_info: array_class.JointsInfo,
dofs_state: array_class.DofsState,
dofs_info: array_class.DofsInfo,
- geoms_state: array_class.GeomsState,
geoms_info: array_class.GeomsInfo,
- entities_state: array_class.EntitiesState,
+ geoms_state: array_class.GeomsState,
entities_info: array_class.EntitiesInfo,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
- contact_island_state: array_class.ContactIslandState,
+ force_update_fixed_geoms: ti.template(),
):
- if ti.static(static_rigid_sim_config.enable_mujoco_compatibility):
- _B = links_state.pos.shape[1]
- ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
- for i_b in range(_B):
- func_update_cartesian_space(
- i_b=i_b,
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]):
+ func_forward_kinematics_entity_ad(
+ i_e=i_e,
+ i_l_=i_l_,
+ i_b=i_b,
+ links_state=links_state,
+ links_info=links_info,
+ joints_state=joints_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+
+
+@ti.kernel
+def kernel_COM_links(
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_state: array_class.JointsState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ geoms_info: array_class.GeomsInfo,
+ geoms_state: array_class.GeomsState,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+ force_update_fixed_geoms: ti.template(),
+):
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_b in range(links_state.pos.shape[1]):
+ func_COM_links(
+ i_b,
+ links_state=links_state,
+ links_info=links_info,
+ joints_state=joints_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+
+
+@ti.func
+def func_update_cartesian_space(
+ i_b,
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_state: array_class.JointsState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ geoms_info: array_class.GeomsInfo,
+ geoms_state: array_class.GeomsState,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+ force_update_fixed_geoms: ti.template(),
+):
+ func_forward_kinematics(
+ i_b,
+ links_state=links_state,
+ links_info=links_info,
+ joints_state=joints_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+ func_COM_links(
+ i_b,
+ links_state=links_state,
+ links_info=links_info,
+ joints_state=joints_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+ func_update_geoms(
+ i_b=i_b,
+ entities_info=entities_info,
+ geoms_info=geoms_info,
+ geoms_state=geoms_state,
+ links_state=links_state,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ force_update_fixed_geoms=force_update_fixed_geoms,
+ )
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_step_1(
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_state: array_class.JointsState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ geoms_state: array_class.GeomsState,
+ geoms_info: array_class.GeomsInfo,
+ entities_state: array_class.EntitiesState,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+ contact_island_state: array_class.ContactIslandState,
+):
+ if ti.static(static_rigid_sim_config.enable_mujoco_compatibility):
+ _B = links_state.pos.shape[1]
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_b in range(_B):
+ func_update_cartesian_space(
+ i_b=i_b,
links_state=links_state,
links_info=links_info,
joints_state=joints_state,
@@ -4530,6 +4982,34 @@ def kernel_forward_velocity(
)
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_forward_velocity_ad(
+ i_l_:ti.int32,
+ envs_idx: ti.types.ndarray(),
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ for i_e, i_b_ in ti.ndrange(entities_info.n_links.shape[0], envs_idx.shape[0]):
+ i_b = envs_idx[i_b_]
+ func_forward_velocity_entity_ad(
+ i_e=i_e,
+ i_l_=i_l_,
+ i_b=i_b,
+ entities_info=entities_info,
+ links_info=links_info,
+ links_state=links_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+
+
@ti.func
def func_COM_links(
i_b,
@@ -4863,6 +5343,240 @@ def func_COM_links(
dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b]
)
+@ti.kernel
+def kernel_COM_links_ad_0(
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_state: array_class.JointsState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]):
+ links_state.root_COM_bw[i_l, i_b].fill(0.0)
+ links_state.mass_sum[i_l, i_b] = 0.0
+
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]):
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+
+ mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b]
+ (
+ links_state.i_pos_bw[i_l, i_b],
+ links_state.i_quat[i_l, i_b],
+ ) = gu.ti_transform_pos_quat_by_trans_quat(
+ links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b],
+ links_info.inertial_quat[I_l],
+ links_state.pos[i_l, i_b],
+ links_state.quat[i_l, i_b],
+ )
+
+ i_r = links_info.root_idx[I_l]
+ links_state.mass_sum[i_r, i_b] += mass
+ links_state.root_COM_bw[i_r, i_b] += mass * links_state.i_pos_bw[i_l, i_b]
+
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]):
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+
+ i_r = links_info.root_idx[I_l]
+ if i_l == i_r:
+ links_state.root_COM[i_l, i_b] = links_state.root_COM_bw[i_l, i_b] / links_state.mass_sum[i_l, i_b]
+
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]):
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+
+ i_r = links_info.root_idx[I_l]
+ links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b]
+
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]):
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+
+ links_state.i_pos[i_l, i_b] = links_state.i_pos_bw[i_l, i_b] - links_state.root_COM[i_l, i_b]
+
+ i_inertial = links_info.inertial_i[I_l]
+ i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b]
+ (
+ links_state.cinr_inertial[i_l, i_b],
+ links_state.cinr_pos[i_l, i_b],
+ links_state.cinr_quat[i_l, i_b],
+ links_state.cinr_mass[i_l, i_b],
+ ) = gu.ti_transform_inertia_by_trans_quat(
+ i_inertial,
+ i_mass,
+ links_state.i_pos[i_l, i_b],
+ links_state.i_quat[i_l, i_b],
+ rigid_global_info.EPS[None],
+ )
+
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]):
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+ BW = ti.static(static_rigid_sim_config.is_backward)
+
+ if links_info.n_dofs[I_l] > 0:
+ for i_j_ in (
+ range(links_info.joint_start[I_l], links_info.joint_end[I_l])
+ if ti.static(not BW)
+ else ti.static(range(static_rigid_sim_config.max_n_joints_per_link))
+ ):
+ i_j = i_j_ if ti.static(not BW) else (i_j_ + links_info.joint_start[I_l])
+
+ if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW):
+ offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b]
+ I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
+ joint_type = joints_info.type[I_j]
+
+ dof_start = joints_info.dof_start[I_j]
+
+ EPS = rigid_global_info.EPS[None]
+ if joint_type == gs.JOINT_TYPE.REVOLUTE:
+ dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b]
+ dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos)
+ elif joint_type == gs.JOINT_TYPE.PRISMATIC:
+ dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3)
+ dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b]
+ elif joint_type == gs.JOINT_TYPE.SPHERICAL:
+ xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose()
+ for i in ti.static(range(3)):
+ dofs_state.cdof_ang[i + dof_start, i_b] = xmat_T[i, :]
+ dofs_state.cdof_vel[i + dof_start, i_b] = xmat_T[i, :].cross(offset_pos)
+ elif joint_type == gs.JOINT_TYPE.FREE:
+ for i in ti.static(range(3)):
+ dofs_state.cdof_ang[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3)
+ dofs_state.cdof_vel[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3)
+ dofs_state.cdof_vel[i + dof_start, i_b][i] = 1.0
+
+ xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose()
+ for i in ti.static(range(3)):
+ dofs_state.cdof_ang[i + dof_start + 3, i_b] = xmat_T[i, :]
+ dofs_state.cdof_vel[i + dof_start + 3, i_b] = xmat_T[i, :].cross(offset_pos)
+
+ for i_d_ in (
+ range(dof_start, joints_info.dof_end[I_j])
+ if ti.static(not BW)
+ else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint))
+ ):
+ i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start)
+ if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW):
+ dofs_state.cdofvel_ang[i_d, i_b] = (
+ dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b]
+ )
+ dofs_state.cdofvel_vel[i_d, i_b] = (
+ dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b]
+ )
+
+@ti.kernel
+def kernel_COM_links_ad_1(
+ i_l_:ti.int32,
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_state: array_class.JointsState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]):
+ func_COM_links_ad_1(
+ i_e=i_e,
+ i_l_=i_l_,
+ i_b=i_b,
+ links_state=links_state,
+ links_info=links_info,
+ joints_state=joints_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+
+@ti.func
+def func_COM_links_ad_1(
+ i_e: ti.int32,
+ i_l_:ti.int32,
+ i_b: ti.int32,
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_state: array_class.JointsState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ BW = ti.static(static_rigid_sim_config.is_backward)
+ i_l = i_l_ + entities_info.link_start[i_e]
+ if i_l < entities_info.link_end[i_e]:
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+
+ if links_info.n_dofs[I_l] > 0:
+ i_p = links_info.parent_idx[I_l]
+
+ _i_j = links_info.joint_start[I_l]
+ _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j
+ joint_type = joints_info.type[_I_j]
+
+ p_pos = ti.Vector.zero(gs.ti_float, 3)
+ p_quat = gu.ti_identity_quat()
+ if i_p != -1:
+ p_pos = links_state.pos[i_p, i_b]
+ p_quat = links_state.quat[i_p, i_b]
+
+ if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1):
+ links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b]
+ links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b]
+ else:
+ (
+ links_state.j_pos_bw[i_l, 0, i_b],
+ links_state.j_quat_bw[i_l, 0, i_b],
+ ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat)
+
+ n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]
+
+ for i_j_ in (
+ range(n_joints)
+ if ti.static(not BW)
+ else ti.static(range(static_rigid_sim_config.max_n_joints_per_link))
+ ):
+ i_j = i_j_ + links_info.joint_start[I_l]
+
+ curr_i_j = 0 if ti.static(not BW) else i_j_
+ next_i_j = 0 if ti.static(not BW) else i_j_ + 1
+
+ if func_check_index_range(
+ i_j,
+ links_info.joint_start[I_l],
+ links_info.joint_end[I_l],
+ BW,
+ ):
+ I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
+
+ (
+ links_state.j_pos_bw[i_l, next_i_j, i_b],
+ links_state.j_quat_bw[i_l, next_i_j, i_b],
+ ) = gu.ti_transform_pos_quat_by_trans_quat(
+ joints_info.pos[I_j],
+ gu.ti_identity_quat(),
+ links_state.j_pos_bw[i_l, curr_i_j, i_b],
+ links_state.j_quat_bw[i_l, curr_i_j, i_b],
+ )
+
+ i_j_ = 0 if ti.static(not BW) else n_joints
+ links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, i_j_, i_b]
+ links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, i_j_, i_b]
@ti.func
def func_forward_kinematics(
@@ -4889,7 +5603,7 @@ def func_forward_kinematics(
else (
ti.static(range(static_rigid_sim_config.max_n_awake_entities))
if ti.static(static_rigid_sim_config.use_hibernation)
- else ti.static(range(static_rigid_sim_config.max_n_links_per_entity))
+ else ti.static(range(static_rigid_sim_config.n_entities))
)
):
if func_check_index_range(
@@ -4941,7 +5655,7 @@ def func_forward_velocity(
# Static inner loop for backward pass
ti.static(range(static_rigid_sim_config.max_n_awake_entities))
if ti.static(static_rigid_sim_config.use_hibernation)
- else ti.static(range(static_rigid_sim_config.max_n_links_per_entity))
+ else ti.static(range(static_rigid_sim_config.n_entities))
)
):
if func_check_index_range(
@@ -5155,6 +5869,160 @@ def func_forward_kinematics_entity(
links_state.pos[i_l, i_b] = R(links_state.pos_bw, I_jf, pos, BW)
links_state.quat[i_l, i_b] = R(links_state.quat_bw, I_jf, quat, BW)
+@ti.func
+def func_forward_kinematics_entity_ad(
+ i_e,
+ i_l_,
+ i_b,
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_state: array_class.JointsState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ # Forward kinematics of a single link: composes the parent link pose with the link's own joints (whose positions
+ # are read from rigid_global_info.qpos) and writes joints_state.xanchor / xaxis, dofs_state.pos and finally
+ # links_state.pos / quat for that link.
+ #
+ # i_e : entity index.
+ # i_l_ : link offset relative to entities_info.link_start[i_e].
+ # i_b : batch (environment) index.
+ #
+ # Nothing is returned. The function does nothing when the offset points at or past entities_info.link_end[i_e].
+ #
+ # When static_rigid_sim_config.is_backward is set, the joint loop is statically unrolled over
+ # max_n_joints_per_link and each joint reads slot `i_j_` and writes slot `i_j_ + 1` of links_state.pos_bw /
+ # quat_bw; otherwise slot 0 is used for every joint.
+ # NOTE(review): W / R / WR are defined elsewhere in this file; they presumably write to / read from the *_bw
+ # buffer only when BW is set and otherwise pass the value through - confirm against their definitions.
+ BW = ti.static(static_rigid_sim_config.is_backward)
+ W = ti.static(func_write_field_if)
+ R = ti.static(func_read_field_if)
+ WR = ti.static(func_write_and_read_field_if)
+
+ EPS = rigid_global_info.EPS[None]
+
+ i_l = i_l_ + entities_info.link_start[i_e]
+
+ if i_l < entities_info.link_end[i_e]:
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+ I_l0 = (i_l, 0, i_b)
+
+ # Initial pose (slot 0): the link's own offset, re-expressed through the parent pose when there is a parent.
+ pos = W(links_state.pos_bw, I_l0, links_info.pos[I_l], BW)
+ quat = W(links_state.quat_bw, I_l0, links_info.quat[I_l], BW)
+ if links_info.parent_idx[I_l] != -1:
+ parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b]
+ parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b]
+ pos_ = parent_pos + gu.ti_transform_by_quat(links_info.pos[I_l], parent_quat)
+ quat_ = gu.ti_transform_quat_by_quat(links_info.quat[I_l], parent_quat)
+
+ pos = W(links_state.pos_bw, I_l0, pos_, BW)
+ quat = W(links_state.quat_bw, I_l0, quat_, BW)
+
+ n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]
+
+ # Dynamic loop over the link's joints in forward mode, static loop over the maximum joint count in backward
+ # mode (out-of-range iterations are filtered by func_check_index_range below).
+ for i_j_ in (
+ range(n_joints)
+ if ti.static(not BW)
+ else ti.static(range(static_rigid_sim_config.max_n_joints_per_link))
+ ):
+ i_j = i_j_ + links_info.joint_start[I_l]
+
+ curr_I = (i_l, 0 if ti.static(not BW) else i_j_, i_b)
+ next_I = (i_l, 0 if ti.static(not BW) else i_j_ + 1, i_b)
+
+ if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW):
+ I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
+ joint_type = joints_info.type[I_j]
+ q_start = joints_info.q_start[I_j]
+ dof_start = joints_info.dof_start[I_j]
+ I_d = [dof_start, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start
+
+ # compute axis and anchor
+ if joint_type == gs.JOINT_TYPE.FREE:
+ # Free joint: the anchor is the translation stored in qpos, the axis is fixed to +z.
+ joints_state.xanchor[i_j, i_b] = ti.Vector(
+ [
+ rigid_global_info.qpos[q_start, i_b],
+ rigid_global_info.qpos[q_start + 1, i_b],
+ rigid_global_info.qpos[q_start + 2, i_b],
+ ]
+ )
+ joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0])
+ elif joint_type == gs.JOINT_TYPE.FIXED:
+ pass
+ else:
+ # Default axis is +z; revolute / prismatic joints take it from the first dof's motion vector.
+ axis = ti.Vector([0.0, 0.0, 1.0], dt=gs.ti_float)
+ if joint_type == gs.JOINT_TYPE.REVOLUTE:
+ axis = dofs_info.motion_ang[I_d]
+ elif joint_type == gs.JOINT_TYPE.PRISMATIC:
+ axis = dofs_info.motion_vel[I_d]
+
+ pos_ = R(links_state.pos_bw, curr_I, pos, BW)
+ quat_ = R(links_state.quat_bw, curr_I, quat, BW)
+
+ # Anchor and axis are expressed using the pose accumulated before this joint is applied.
+ joints_state.xanchor[i_j, i_b] = gu.ti_transform_by_quat(joints_info.pos[I_j], quat_) + pos_
+ joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat(axis, quat_)
+
+ # Apply the joint to the accumulated pose and refresh dofs_state.pos.
+ if joint_type == gs.JOINT_TYPE.FREE:
+ # The pose is taken directly from qpos: 3 position entries followed by 4 quaternion entries.
+ pos_ = ti.Vector(
+ [
+ rigid_global_info.qpos[q_start, i_b],
+ rigid_global_info.qpos[q_start + 1, i_b],
+ rigid_global_info.qpos[q_start + 2, i_b],
+ ],
+ dt=gs.ti_float,
+ )
+ quat_ = ti.Vector(
+ [
+ rigid_global_info.qpos[q_start + 3, i_b],
+ rigid_global_info.qpos[q_start + 4, i_b],
+ rigid_global_info.qpos[q_start + 5, i_b],
+ rigid_global_info.qpos[q_start + 6, i_b],
+ ],
+ dt=gs.ti_float,
+ )
+ pos = WR(links_state.pos_bw, next_I, pos_, BW)
+ quat = WR(links_state.quat_bw, next_I, quat_, BW)
+
+ xyz = gu.ti_quat_to_xyz(quat, EPS)
+ for j in ti.static(range(3)):
+ dofs_state.pos[dof_start + j, i_b] = pos[j]
+ dofs_state.pos[dof_start + 3 + j, i_b] = xyz[j]
+ elif joint_type == gs.JOINT_TYPE.FIXED:
+ # NOTE(review): nothing is written to slot `next_I` for fixed joints - confirm this is intended in
+ # backward mode, where the final pose is read from slot `n_joints`.
+ pass
+ elif joint_type == gs.JOINT_TYPE.SPHERICAL:
+ # Local rotation is the quaternion stored in qpos; the position is chosen so the anchor stays fixed.
+ qloc = ti.Vector(
+ [
+ rigid_global_info.qpos[q_start, i_b],
+ rigid_global_info.qpos[q_start + 1, i_b],
+ rigid_global_info.qpos[q_start + 2, i_b],
+ rigid_global_info.qpos[q_start + 3, i_b],
+ ],
+ dt=gs.ti_float,
+ )
+ xyz = gu.ti_quat_to_xyz(qloc, EPS)
+ for j in ti.static(range(3)):
+ dofs_state.pos[dof_start + j, i_b] = xyz[j]
+ quat_ = gu.ti_transform_quat_by_quat(qloc, R(links_state.quat_bw, curr_I, quat, BW))
+ quat = WR(links_state.quat_bw, next_I, quat_, BW)
+ pos_ = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat)
+ pos = W(links_state.pos_bw, next_I, pos_, BW)
+ elif joint_type == gs.JOINT_TYPE.REVOLUTE:
+ # Joint angle is measured relative to qpos0; the rotation is built from axis * angle.
+ axis = dofs_info.motion_ang[I_d]
+ dofs_state.pos[dof_start, i_b] = (
+ rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b]
+ )
+ qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b], EPS)
+ quat_ = gu.ti_transform_quat_by_quat(qloc, R(links_state.quat_bw, curr_I, quat, BW))
+ quat = WR(links_state.quat_bw, next_I, quat_, BW)
+ pos_ = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat)
+ pos = W(links_state.pos_bw, next_I, pos_, BW)
+ else: # joint_type == gs.JOINT_TYPE.PRISMATIC:
+ # Displacement is measured relative to qpos0 and applied along the joint axis computed above.
+ # NOTE(review): only the position slot at `next_I` is written here, the quaternion slot is left
+ # untouched - confirm this is intended in backward mode.
+ dofs_state.pos[dof_start, i_b] = (
+ rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b]
+ )
+ pos_ = (
+ R(links_state.pos_bw, curr_I, pos, BW)
+ + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b]
+ )
+ pos = W(links_state.pos_bw, next_I, pos_, BW)
+
+ # Skip link pose update for fixed root links to let users manually overwrite them
+ I_jf = (i_l, 0 if ti.static(not BW) else n_joints, i_b)
+ if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]):
+ links_state.pos[i_l, i_b] = R(links_state.pos_bw, I_jf, pos, BW)
+ links_state.quat[i_l, i_b] = R(links_state.quat_bw, I_jf, quat, BW)
+
@ti.func
def func_forward_velocity_entity(
@@ -5281,6 +6149,127 @@ def func_forward_velocity_entity(
links_state.cd_vel[i_l, i_b] = R(links_state.cd_vel_bw, I_jf, cvel_vel, BW)
links_state.cd_ang[i_l, i_b] = R(links_state.cd_ang_bw, I_jf, cvel_ang, BW)
+@ti.func
+def func_forward_velocity_entity_ad(
+ i_e,
+ i_l_,
+ i_b,
+ entities_info: array_class.EntitiesInfo,
+ links_info: array_class.LinksInfo,
+ links_state: array_class.LinksState,
+ joints_info: array_class.JointsInfo,
+ dofs_state: array_class.DofsState,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ # Forward velocity pass of a single link: starts from the parent link's cd_vel / cd_ang (or zero for a link
+ # without parent), adds `cdof * vel` for each dof of the link's joints, fills dofs_state.cdofd_ang / cdofd_vel
+ # along the way and finally writes links_state.cd_vel / cd_ang for the link.
+ #
+ # i_e : entity index.
+ # i_l_ : link offset relative to entities_info.link_start[i_e].
+ # i_b : batch (environment) index.
+ #
+ # Nothing is returned. The function does nothing when the offset points at or past entities_info.link_end[i_e].
+ #
+ # NOTE(review): W / R / A are defined elsewhere in this file; they presumably touch the *_bw buffers only when BW
+ # is set, and A presumably returns the value it was asked to add - confirm against their definitions.
+ BW = ti.static(static_rigid_sim_config.is_backward)
+ W = ti.static(func_write_field_if)
+ R = ti.static(func_read_field_if)
+ A = ti.static(func_atomic_add_if)
+
+ i_l = i_l_ + entities_info.link_start[i_e]
+
+ if i_l < entities_info.link_end[i_e]:
+ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l
+ n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]
+
+ # Slot 0 holds the velocity before any joint of this link is applied.
+ I_j0 = (i_l, 0, i_b)
+ cvel_vel = W(links_state.cd_vel_bw, I_j0, ti.Vector.zero(gs.ti_float, 3), BW)
+ cvel_ang = W(links_state.cd_ang_bw, I_j0, ti.Vector.zero(gs.ti_float, 3), BW)
+
+ if links_info.parent_idx[I_l] != -1:
+ cvel_vel = W(links_state.cd_vel_bw, I_j0, links_state.cd_vel[links_info.parent_idx[I_l], i_b], BW)
+ cvel_ang = W(links_state.cd_ang_bw, I_j0, links_state.cd_ang[links_info.parent_idx[I_l], i_b], BW)
+
+ # Dynamic loop in forward mode, static loop over the maximum joint count in backward mode (out-of-range
+ # iterations are filtered by func_check_index_range below).
+ for i_j_ in (
+ range(n_joints)
+ if ti.static(not BW)
+ else ti.static(range(static_rigid_sim_config.max_n_joints_per_link))
+ ):
+ i_j = i_j_ + links_info.joint_start[I_l]
+
+ if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW):
+ I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j
+ joint_type = joints_info.type[I_j]
+ # NOTE(review): q_start is read here but not used anywhere below in this function.
+ q_start = joints_info.q_start[I_j]
+ dof_start = joints_info.dof_start[I_j]
+
+ # In backward mode each joint reads slot `i_j_` and writes slot `i_j_ + 1`; otherwise slot 0 is used.
+ curr_I = (i_l, 0 if ti.static(not BW) else i_j_, i_b)
+ next_I = (i_l, 0 if ti.static(not BW) else i_j_ + 1, i_b)
+
+ if joint_type == gs.JOINT_TYPE.FREE:
+ # Step 1: add the contribution of the first three dofs into the current slot.
+ for i_3 in ti.static(range(3)):
+ _vel = dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b]
+ _ang = dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b]
+
+ cvel_vel = cvel_vel + A(links_state.cd_vel_bw, curr_I, _vel, BW)
+ cvel_ang = cvel_ang + A(links_state.cd_ang_bw, curr_I, _ang, BW)
+
+ # Step 2: cdofd is zero for the first three dofs and the motion cross product of the current
+ # velocity with cdof for the last three dofs.
+ for i_3 in ti.static(range(3)):
+ (
+ dofs_state.cdofd_ang[dof_start + i_3, i_b],
+ dofs_state.cdofd_vel[dof_start + i_3, i_b],
+ ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3)
+
+ (
+ dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b],
+ dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b],
+ ) = gu.motion_cross_motion(
+ R(links_state.cd_ang_bw, curr_I, cvel_ang, BW),
+ R(links_state.cd_vel_bw, curr_I, cvel_vel, BW),
+ dofs_state.cdof_ang[dof_start + i_3 + 3, i_b],
+ dofs_state.cdof_vel[dof_start + i_3 + 3, i_b],
+ )
+
+ # Step 3 (backward mode only): seed the next slot with the current one before accumulating into it.
+ if ti.static(BW):
+ links_state.cd_vel_bw[next_I] = links_state.cd_vel_bw[curr_I]
+ links_state.cd_ang_bw[next_I] = links_state.cd_ang_bw[curr_I]
+
+ # Step 4: add the contribution of the last three dofs into the next slot.
+ for i_3 in ti.static(range(3)):
+ _vel = (
+ dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b]
+ )
+ _ang = (
+ dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b]
+ )
+ cvel_vel = cvel_vel + A(links_state.cd_vel_bw, next_I, _vel, BW)
+ cvel_ang = cvel_ang + A(links_state.cd_ang_bw, next_I, _ang, BW)
+
+ else:
+ # Non-free joints: first compute cdofd for every dof of the joint from the velocity before the joint.
+ for i_d_ in (
+ range(dof_start, joints_info.dof_end[I_j])
+ if ti.static(not BW)
+ else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint))
+ ):
+ # In backward mode the static loop index is an offset and must be shifted by dof_start.
+ i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start)
+ if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW):
+ dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion(
+ R(links_state.cd_ang_bw, curr_I, cvel_ang, BW),
+ R(links_state.cd_vel_bw, curr_I, cvel_vel, BW),
+ dofs_state.cdof_ang[i_d, i_b],
+ dofs_state.cdof_vel[i_d, i_b],
+ )
+
+ # Backward mode only: seed the next slot with the current one before accumulating into it.
+ if ti.static(BW):
+ links_state.cd_vel_bw[next_I] = links_state.cd_vel_bw[curr_I]
+ links_state.cd_ang_bw[next_I] = links_state.cd_ang_bw[curr_I]
+
+ # Then add `cdof * vel` of every dof of the joint into the next slot.
+ for i_d_ in (
+ range(dof_start, joints_info.dof_end[I_j])
+ if ti.static(not BW)
+ else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint))
+ ):
+ i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start)
+ if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW):
+ _vel = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b]
+ _ang = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b]
+ cvel_vel = cvel_vel + A(links_state.cd_vel_bw, next_I, _vel, BW)
+ cvel_ang = cvel_ang + A(links_state.cd_ang_bw, next_I, _ang, BW)
+
+ # Final velocity: slot `n_joints` in backward mode, slot 0 otherwise.
+ I_jf = (i_l, 0 if ti.static(not BW) else n_joints, i_b)
+ links_state.cd_vel[i_l, i_b] = R(links_state.cd_vel_bw, I_jf, cvel_vel, BW)
+ links_state.cd_ang[i_l, i_b] = R(links_state.cd_ang_bw, I_jf, cvel_ang, BW)
+
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_update_geoms(
@@ -5358,7 +6347,9 @@ def func_update_geoms(
)
):
i_g = i_1 + entities_info.geom_start[i_e] if ti.static(static_rigid_sim_config.use_hibernation) else i_0
- if func_check_index_range(i_g, entities_info.geom_start[i_e], entities_info.geom_end[i_e], BW):
+ if func_check_index_range(
+ i_g, entities_info.geom_start[i_e], entities_info.geom_end[i_e], static_rigid_sim_config.use_hibernation
+ ):
if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]:
(
geoms_state.pos[i_g, i_b],
@@ -5717,6 +6708,7 @@ def func_apply_link_external_torque(
@ti.func
def func_clear_external_force(
links_state: array_class.LinksState,
+ dofs_state: array_class.DofsState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
):
@@ -5735,7 +6727,38 @@ def func_clear_external_force(
i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0
links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3)
links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3)
+
+ ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL))
+ for I in ti.grouped(dofs_state.ctrl_force):
+ dofs_state.ctrl_force[I] = ti.Vector.zero(gs.ti_float, 3)
+@ti.kernel
+def kernel_torque_and_passive_force(
+ entities_state: array_class.EntitiesState,
+ entities_info: array_class.EntitiesInfo,
+ dofs_state: array_class.DofsState,
+ dofs_info: array_class.DofsInfo,
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ joints_info: array_class.JointsInfo,
+ geoms_state: array_class.GeomsState,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+ contact_island_state: array_class.ContactIslandState,
+):
+ # Kernel entry point that forwards all of its arguments, unchanged and by keyword, to
+ # func_torque_and_passive_force so that this stage can be launched as a standalone kernel.
+ # NOTE(review): declared with a bare @ti.kernel, unlike the kernels in this file that pass
+ # fastcache=gs.use_fastcache - confirm this difference is intentional.
+ func_torque_and_passive_force(
+ entities_state=entities_state,
+ entities_info=entities_info,
+ dofs_state=dofs_state,
+ dofs_info=dofs_info,
+ links_state=links_state,
+ links_info=links_info,
+ joints_info=joints_info,
+ geoms_state=geoms_state,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ contact_island_state=contact_island_state,
+ )
@ti.func
def func_torque_and_passive_force(
@@ -6055,6 +7078,21 @@ def func_update_acc(
BW,
)
+@ti.kernel
+def kernel_update_force(
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ # Kernel entry point that forwards all of its arguments, unchanged and by keyword, to func_update_force so that
+ # this stage can be launched as a standalone kernel.
+ func_update_force(
+ links_state=links_state,
+ links_info=links_info,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
@ti.func
def func_update_force(
@@ -6195,6 +7233,22 @@ def func_actuation(self):
self.dofs_state.act_length[i_d, i_b] = 0.0
self.dofs_state.qf_actuator[i_d, i_b] = self.dofs_state.act_length[i_d, i_b]
+@ti.kernel
+def kernel_bias_force(
+ dofs_state: array_class.DofsState,
+ links_state: array_class.LinksState,
+ links_info: array_class.LinksInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ # Kernel entry point that forwards all of its arguments, unchanged and by keyword, to func_bias_force so that
+ # this stage can be launched as a standalone kernel.
+ func_bias_force(
+ dofs_state=dofs_state,
+ links_state=links_state,
+ links_info=links_info,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+
@ti.func
def func_bias_force(
@@ -6525,17 +7579,19 @@ def func_copy_next_to_curr_grad(
def kernel_save_adjoint_cache(
f: ti.int32,
dofs_state: array_class.DofsState,
+ constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
rigid_adjoint_cache: array_class.RigidAdjointCache,
static_rigid_sim_config: ti.template(),
):
- func_save_adjoint_cache(f, dofs_state, rigid_global_info, rigid_adjoint_cache, static_rigid_sim_config)
+ func_save_adjoint_cache(f, dofs_state, constraint_state, rigid_global_info, rigid_adjoint_cache, static_rigid_sim_config)
@ti.func
def func_save_adjoint_cache(
f: ti.int32,
dofs_state: array_class.DofsState,
+ constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
rigid_adjoint_cache: array_class.RigidAdjointCache,
static_rigid_sim_config: ti.template(),
@@ -6548,6 +7604,8 @@ def func_save_adjoint_cache(
for i_d, i_b in ti.ndrange(n_dofs, _B):
rigid_adjoint_cache.dofs_vel[f, i_d, i_b] = dofs_state.vel[i_d, i_b]
rigid_adjoint_cache.dofs_acc[f, i_d, i_b] = dofs_state.acc[i_d, i_b]
+ rigid_adjoint_cache.dofs_acc_smooth[f, i_d, i_b] = dofs_state.acc_smooth[i_d, i_b]
+ rigid_adjoint_cache.solver_qacc_ws[f, i_d, i_b] = constraint_state.qacc_ws[i_d, i_b]
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_q, i_b in ti.ndrange(n_qs, _B):
@@ -6558,6 +7616,7 @@ def func_save_adjoint_cache(
def func_load_adjoint_cache(
f: ti.int32,
dofs_state: array_class.DofsState,
+ constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
rigid_adjoint_cache: array_class.RigidAdjointCache,
static_rigid_sim_config: ti.template(),
@@ -6570,6 +7629,8 @@ def func_load_adjoint_cache(
for i_d, i_b in ti.ndrange(n_dofs, _B):
dofs_state.vel[i_d, i_b] = rigid_adjoint_cache.dofs_vel[f, i_d, i_b]
dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b]
+ dofs_state.acc_smooth[i_d, i_b] = rigid_adjoint_cache.dofs_acc_smooth[f, i_d, i_b]
+ constraint_state.qacc_ws[i_d, i_b] = rigid_adjoint_cache.solver_qacc_ws[f, i_d, i_b]
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_q, i_b in ti.ndrange(n_qs, _B):
@@ -6587,6 +7648,7 @@ def kernel_prepare_backward_substep(
dofs_info: array_class.DofsInfo,
geoms_state: array_class.GeomsState,
geoms_info: array_class.GeomsInfo,
+ constraint_state: array_class.ConstraintState,
entities_info: array_class.EntitiesInfo,
rigid_global_info: array_class.RigidGlobalInfo,
dofs_state_adjoint_cache: array_class.DofsState,
@@ -6600,6 +7662,7 @@ def kernel_prepare_backward_substep(
func_load_adjoint_cache(
f=f,
dofs_state=dofs_state,
+ constraint_state=constraint_state,
rigid_global_info=rigid_global_info,
rigid_adjoint_cache=rigid_adjoint_cache,
static_rigid_sim_config=static_rigid_sim_config,
@@ -6625,6 +7688,16 @@ def kernel_prepare_backward_substep(
static_rigid_sim_config=static_rigid_sim_config,
force_update_fixed_geoms=False,
)
+ func_forward_velocity(
+ i_b=i_b,
+ entities_info=entities_info,
+ links_info=links_info,
+ links_state=links_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
# FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names.
# Save results of [update_cartesian_space] to adjoint cache
@@ -6658,20 +7731,29 @@ def kernel_begin_backward_substep(
links_state_adjoint_cache: array_class.LinksState,
joints_state_adjoint_cache: array_class.JointsState,
geoms_state_adjoint_cache: array_class.GeomsState,
- rigid_adjoint_cache: array_class.RigidAdjointCache,
+ rigid_adjoint_cache_fw: array_class.RigidAdjointCache,
+ rigid_adjoint_cache_bw: array_class.RigidAdjointCache,
static_rigid_sim_config: ti.template(),
) -> ti.i32:
+ errno = 0
is_grad_valid = func_is_grad_valid(
rigid_global_info=rigid_global_info,
dofs_state=dofs_state,
static_rigid_sim_config=static_rigid_sim_config,
)
- if is_grad_valid:
+ # Check the integrity of the next frame's adjoint cache as it is the computation result of the current frame
+ is_cache_valid = func_check_cache_integrity(
+ f=f + 1,
+ rigid_adjoint_cache_fw=rigid_adjoint_cache_fw,
+ rigid_adjoint_cache_bw=rigid_adjoint_cache_bw,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
+ if is_grad_valid and is_cache_valid:
func_copy_next_to_curr_grad(
f=f,
dofs_state=dofs_state,
rigid_global_info=rigid_global_info,
- rigid_adjoint_cache=rigid_adjoint_cache,
+ rigid_adjoint_cache=rigid_adjoint_cache_fw,
static_rigid_sim_config=static_rigid_sim_config,
)
@@ -6689,8 +7771,12 @@ def kernel_begin_backward_substep(
geoms_state_adjoint_cache=geoms_state_adjoint_cache,
static_rigid_sim_config=static_rigid_sim_config,
)
+ elif not is_grad_valid:
+ errno = 1
+ elif not is_cache_valid:
+ errno = 2
- return is_grad_valid
+ return errno
@ti.func
@@ -6713,6 +7799,37 @@ def func_is_grad_valid(
return is_valid
+@ti.func
+def func_check_cache_integrity(
+ f: ti.int32,
+ rigid_adjoint_cache_fw: array_class.RigidAdjointCache,
+ rigid_adjoint_cache_bw: array_class.RigidAdjointCache,
+ static_rigid_sim_config: ti.template(),
+):
+ # Compares frame `f` of the two adjoint caches entry by entry and returns True only if qpos, dofs_vel, dofs_acc,
+ # dofs_acc_smooth and solver_qacc_ws are all identical. The comparison is exact (`!=`), without any tolerance.
+ #
+ # f : frame index into the first axis of the cache fields.
+ # rigid_adjoint_cache_fw / rigid_adjoint_cache_bw : the two caches to compare; loop bounds are taken from the
+ # shapes of the `fw` cache.
+ is_valid = True
+ n_qs = rigid_adjoint_cache_fw.qpos.shape[1]
+ n_dofs = rigid_adjoint_cache_fw.dofs_vel.shape[1]
+ _B = rigid_adjoint_cache_fw.qpos.shape[2]
+
+ # Every mismatch stores the same value (False), so the flag does not depend on which iteration wrote it.
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_q, i_b in ti.ndrange(n_qs, _B):
+ if rigid_adjoint_cache_fw.qpos[f, i_q, i_b] != rigid_adjoint_cache_bw.qpos[f, i_q, i_b]:
+ is_valid = False
+
+ ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_d, i_b in ti.ndrange(n_dofs, _B):
+ if rigid_adjoint_cache_fw.dofs_vel[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_vel[f, i_d, i_b]:
+ is_valid = False
+ if rigid_adjoint_cache_fw.dofs_acc[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_acc[f, i_d, i_b]:
+ is_valid = False
+ if rigid_adjoint_cache_fw.dofs_acc_smooth[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_acc_smooth[f, i_d, i_b]:
+ is_valid = False
+ if rigid_adjoint_cache_fw.solver_qacc_ws[f, i_d, i_b] != rigid_adjoint_cache_bw.solver_qacc_ws[f, i_d, i_b]:
+ is_valid = False
+
+ return is_valid
+
+
@ti.func
def func_copy_cartesian_space(
dofs_state: array_class.DofsState,
@@ -7586,6 +8703,19 @@ def kernel_control_dofs_force(
dofs_state.ctrl_mode[dofs_idx[i_d_], envs_idx[i_b_]] = gs.CTRL_MODE.FORCE
dofs_state.ctrl_force[dofs_idx[i_d_], envs_idx[i_b_]] = force[i_b_, i_d_]
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_control_dofs_force_grad(
+ force_grad: ti.types.ndarray(),
+ dofs_idx: ti.types.ndarray(),
+ envs_idx: ti.types.ndarray(),
+ dofs_state: array_class.DofsState,
+ static_rigid_sim_config: ti.template(),
+):
+ # Reads the gradient stored in dofs_state.ctrl_force.grad for the selected dofs / environments into `force_grad`
+ # and then resets the stored gradient to zero.
+ #
+ # force_grad : output array, indexed as [position in envs_idx, position in dofs_idx].
+ # dofs_idx : indices of the dofs to read.
+ # envs_idx : indices of the environments to read.
+ ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
+ for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]):
+ force_grad[i_b_, i_d_] = dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]]
+ # Clear the gradient right after it has been read out.
+ dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0
+
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_control_dofs_velocity(
@@ -7848,3 +8978,14 @@ def func_write_and_read_field_if(field: array_class.V_ANNOTATION, I, value, cond
def func_check_index_range(idx: ti.i32, min: ti.i32, max: ti.i32, cond: ti.template()):
# Conditionally check if the index is in the range [min, max) to save computational cost
return (idx >= min and idx < max) if ti.static(cond) else True
+
+
+@ti.kernel(fastcache=gs.use_fastcache)
+def kernel_zero_grad(
+ links_state: array_class.LinksState,
+ dofs_state: array_class.DofsState,
+ geoms_state: array_class.GeomsState,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: ti.template(),
+):
+ # Placeholder kernel: the body is empty, so calling it currently leaves every argument untouched.
+ # NOTE(review): judging by the name this is meant to reset gradients - confirm whether an implementation is
+ # still pending.
+ pass
\ No newline at end of file
diff --git a/genesis/engine/states/entities.py b/genesis/engine/states/entities.py
index b6ee1f6dd1..06ce22a501 100644
--- a/genesis/engine/states/entities.py
+++ b/genesis/engine/states/entities.py
@@ -204,13 +204,19 @@ def __init__(self, entity, s_global):
scene = self._entity.scene
self._pos = gs.zeros((num_batch, 3), dtype=float, requires_grad=requires_grad, scene=scene)
self._quat = gs.zeros((num_batch, 4), dtype=float, requires_grad=requires_grad, scene=scene)
+ self._qpos = gs.zeros((num_batch, entity.n_qs), dtype=float, requires_grad=requires_grad, scene=scene)
+ self._dofs_vel = gs.zeros((num_batch, entity.n_dofs), dtype=float, requires_grad=requires_grad, scene=scene)
+ self._dofs_acc = gs.zeros((num_batch, entity.n_dofs), dtype=float, requires_grad=requires_grad, scene=scene)
def serializable(self):
self._entity = None
self._pos = self._pos.detach()
self._quat = self._quat.detach()
-
+ self._qpos = self._qpos.detach()
+ self._dofs_vel = self._dofs_vel.detach()
+ self._dofs_acc = self._dofs_acc.detach()
+
@property
def entity(self):
return self._entity
@@ -226,3 +232,15 @@ def pos(self):
@property
def quat(self):
return self._quat
+
+ @property
+ def qpos(self):
+ return self._qpos
+
+ @property
+ def dofs_vel(self):
+ return self._dofs_vel
+
+ @property
+ def dofs_acc(self):
+ return self._dofs_acc
\ No newline at end of file
diff --git a/genesis/options/solvers.py b/genesis/options/solvers.py
index 67da878c2c..5315b7b61c 100644
--- a/genesis/options/solvers.py
+++ b/genesis/options/solvers.py
@@ -36,6 +36,8 @@ class SimOptions(Options):
Height of the floor in meters. Defaults to 0.0.
requires_grad : bool, optional
Whether to enable differentiable mode. Defaults to False.
+ grad_window_steps : int, optional
+ Number of steps that constitute a window for gradient computation. Defaults to None (no window is used).
use_hydroelastic_contact : bool, optional
Whether to use hydroelastic contact. Defaults to False.
"""
@@ -46,6 +48,7 @@ class SimOptions(Options):
gravity: tuple = (0.0, 0.0, -9.81)
floor_height: float = 0.0
requires_grad: bool = False
+ grad_window_steps: Optional[int] = None
# not set by user
_steps_local: Optional[int] = None
diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py
index f0baaca6b0..3bc2eea24c 100644
--- a/genesis/utils/array_class.py
+++ b/genesis/utils/array_class.py
@@ -291,15 +291,15 @@ def get_constraint_state(constraint_solver, solver):
efc_AR=V(dtype=gs.ti_float, shape=efc_AR_shape),
active=V(dtype=gs.ti_bool, shape=(len_constraints_, _B)),
prev_active=V(dtype=gs.ti_bool, shape=(len_constraints_, _B)),
- diag=V(dtype=gs.ti_float, shape=(len_constraints_, _B)),
- aref=V(dtype=gs.ti_float, shape=(len_constraints_, _B)),
+ diag=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad),
+ aref=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad),
Jaref=V(dtype=gs.ti_float, shape=(len_constraints_, _B)),
- efc_frictionloss=V(dtype=gs.ti_float, shape=(len_constraints_, _B)),
+ efc_frictionloss=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad),
efc_force=V(dtype=gs.ti_float, shape=(len_constraints_, _B)),
- efc_D=V(dtype=gs.ti_float, shape=(len_constraints_, _B)),
+ efc_D=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad),
jv=V(dtype=gs.ti_float, shape=(len_constraints_, _B)),
quad=V(dtype=gs.ti_float, shape=(len_constraints_, 3, _B)),
- jac=V(dtype=gs.ti_float, shape=jac_shape),
+ jac=V(dtype=gs.ti_float, shape=jac_shape, needs_grad=solver._requires_grad),
jac_relevant_dofs=V(dtype=gs.ti_int, shape=jac_relevant_dofs_shape),
jac_n_relevant_dofs=V(dtype=gs.ti_int, shape=jac_n_relevant_dofs_shape),
# Backward gradients
@@ -1776,6 +1776,11 @@ class StructRigidAdjointCache(metaclass=BASE_METACLASS):
qpos: V_ANNOTATION
dofs_vel: V_ANNOTATION
dofs_acc: V_ANNOTATION
+ # We also store the initial solutions (acc_smooth, qacc_ws) for the constraint solver to use in the backward pass.
+ # For [acc_smooth], even though it could be reproduced during the backward pass and thus we do not need to store it,
+ # we do it to compare the reproduced value with the stored one to ensure the integrity of the backward pass.
+ dofs_acc_smooth: V_ANNOTATION
+ solver_qacc_ws: V_ANNOTATION
def get_rigid_adjoint_cache(solver):
@@ -1786,6 +1791,8 @@ def get_rigid_adjoint_cache(solver):
qpos=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_qs_, solver._B), needs_grad=requires_grad),
dofs_vel=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad),
dofs_acc=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad),
+ dofs_acc_smooth=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B)),
+ solver_qacc_ws=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B)),
)
@@ -1818,6 +1825,9 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta):
max_n_geoms_per_entity: int = -1
n_links: int = -1
n_geoms: int = -1
+ n_dofs: int = -1
+ n_entities: int = -1
+ max_contact_pairs: int = -1
# =========================================== DataManager ===========================================
@@ -1862,7 +1872,10 @@ def __init__(self, solver):
self.joints_state_adjoint_cache = get_joints_state(solver)
self.geoms_state_adjoint_cache = get_geoms_state(solver)
- self.rigid_adjoint_cache = get_rigid_adjoint_cache(solver)
+ # We use a pair of adjoint caches: one is used for the forward pass and the other for the backward pass, so
+ # that the integrity of the backward pass can be checked.
+ self.rigid_adjoint_cache_fw = get_rigid_adjoint_cache(solver)
+ self.rigid_adjoint_cache_bw = get_rigid_adjoint_cache(solver)
self.errno = V_SCALAR_FROM(dtype=gs.ti_int, value=0)
diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py
index e1a052c8e3..20397402b4 100644
--- a/genesis/utils/geom.py
+++ b/genesis/utils/geom.py
@@ -82,6 +82,7 @@ def ti_rotvec_to_R(rotvec, eps):
@ti.func
def ti_rotvec_to_quat(rotvec, eps):
quat = ti.Vector.zero(gs.ti_float, 4)
+ res = ti.Vector.zero(gs.ti_float, 4)
# We need to use [norm_sqr] instead of [norm] to avoid nan gradients in the backward pass. Even when theta = 0,
# the gradient of [norm] operation is computed and used (note that the gradient becomes NaN when theta = 0). This
@@ -98,11 +99,12 @@ def ti_rotvec_to_quat(rotvec, eps):
quat[i + 1] = xyz[i]
# First order quaternion normalization is accurate enough yet necessary
- quat *= 0.5 * (3.0 - quat.norm_sqr())
+ # quat *= 0.5 * (3.0 - quat.norm_sqr())
+ res = quat * 0.5 * (3.0 - quat.norm_sqr())
else:
- quat[0] = 1.0
+ res[0] = 1.0
- return quat
+ return res
@ti.func
@@ -221,7 +223,12 @@ def ti_transform_quat_by_quat(v, u):
This is equivalent to quatmul(quat_u, quat_v) or R_u @ R_v
"""
vec = ti_quat_mul(u, v)
- return vec.normalized()
+ res = ti.Vector([1.0, 0.0, 0.0, 0.0], dt=gs.ti_float)
+ vec_norm_sqr = vec.norm_sqr()
+ if vec_norm_sqr > gs.EPS ** 2:
+ res = vec / (ti.sqrt(vec_norm_sqr) + gs.EPS)
+ return res
+ # return vec.normalized()
@ti.func
@@ -239,7 +246,7 @@ def ti_transform_by_quat(v, quat):
v.x * (-2.0 * q_wy + 2.0 * q_xz) + v.y * (2.0 * q_wx + 2.0 * q_yz) + v.z * (q_ww - q_xx - q_yy + q_zz),
],
dt=gs.ti_float,
- ) / (q_ww + q_xx + q_yy + q_zz)
+ ) / (q_ww + q_xx + q_yy + q_zz + gs.EPS)
@ti.func