# Differentiable rigid-body example: optimise the initial pose of a box so
# that, after `horizon` simulation steps, it ends up at a goal pose.
show_viewer = False

gs.init(precision="32", logging_level="info")

dt = 1e-2
horizon = 100
substeps = 1
box_size = (0.1, 0.1, 0.2)

goal_pos = gs.tensor([0.7, 1.0, 0.05])
goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9])
goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True)  # unit quaternion

# Collisions, joint limits and constraints are all switched off for this scene.
sim_options = gs.options.SimOptions(
    dt=dt,
    substeps=substeps,
    requires_grad=True,
    gravity=(0, 0, -1),
)
rigid_options = gs.options.RigidOptions(
    enable_collision=False,
    enable_self_collision=False,
    enable_joint_limit=False,
    disable_constraint=True,
    use_contact_island=False,
    use_hibernation=False,
)
viewer_options = gs.options.ViewerOptions(
    camera_pos=(2.5, -0.15, 2.42),
    camera_lookat=(0.5, 0.5, 0.1),
)
scene = gs.Scene(
    sim_options=sim_options,
    rigid_options=rigid_options,
    viewer_options=viewer_options,
    show_viewer=show_viewer,
)

box = scene.add_entity(
    gs.morphs.Box(pos=(0, 0, 0), size=box_size),
    surface=gs.surfaces.Default(color=(0.9, 0.0, 0.0, 1.0)),
)
if show_viewer:
    # Semi-transparent green marker showing the goal pose.
    target = scene.add_entity(
        gs.morphs.Box(pos=goal_pos, quat=goal_quat, size=box_size),
        surface=gs.surfaces.Default(color=(0.0, 0.9, 0.0, 0.5)),
    )

scene.build()

num_iter = 200
lr = 1e-2

# Optimisation variables: the pose the box is given right after each reset.
init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True)
init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)

for _ in range(num_iter):
    scene.reset()
    box.set_pos(init_pos)
    box.set_quat(init_quat)

    for _step in range(horizon):
        scene.step()
        if show_viewer:
            # Re-pin the marker at the goal pose after every step.
            target.set_pos(goal_pos)
            target.set_quat(goal_quat)

    # L1 distance between the final pose and the goal pose.
    final_state = box.get_state()
    pos_err = torch.abs(final_state.pos - goal_pos).sum()
    quat_err = torch.abs(final_state.quat - goal_quat).sum()
    loss = pos_err + quat_err

    optimizer.zero_grad()
    loss.backward()  # this lets gradient flow all the way back to tensor input
    optimizer.step()
    scheduler.step()

    # Project the quaternion back onto the unit sphere after the update.
    with torch.no_grad():
        init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)

    print("loss: ", loss.item())

# assert_allclose(loss, 0.0, atol=1e-2)
# One-step optimisation used to debug differentiable rigid simulation:
# the slab's initial position is optimised so that, after a single step,
# the ball (which starts slightly penetrating the slab) reaches `goal_pos`.
show_viewer = False

gs.init(precision="32", logging_level="warn", backend=gs.cpu)

dt = 1e-2
horizon = 1
substeps = 1
goal_pos = gs.tensor([0.0, 0.0, 1.0])

scene = gs.Scene(
    sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True),
    rigid_options=gs.options.RigidOptions(
        use_gjk_collision=True,
        enable_joint_limit=False,
    ),
    show_viewer=show_viewer,
)

ball = scene.add_entity(
    gs.morphs.Sphere(
        pos=(0, 0, 0.109),  # small penetration with ground
        radius=0.1,
    ),
    surface=gs.surfaces.Default(color=(0.9, 0.0, 0.0, 1.0)),
)

# The slab is not `fixed`; gravity compensation is enabled on its material instead.
ground = scene.add_entity(
    gs.morphs.Box(pos=(0, 0, 0), size=(5.0, 5.0, 0.02)),
    surface=gs.surfaces.Default(color=(0.0, 0.0, 0.9, 1.0)),
    material=gs.materials.Rigid(gravity_compensation=1),
)

scene.build()

num_iter = 400
lr = 1e-4
max_grad_norm = 1.0

init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True)
optimizer = torch.optim.Adam([init_pos], lr=lr)

# NOTE(review): eta_min (1e-3) is larger than the initial lr (1e-4), so this
# schedule ramps the learning rate *up* over the run. Values are kept as
# authored -- confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)

prev_loss = float("inf")
losses = []
for _ in range(num_iter):
    scene.reset()
    ground.set_pos(init_pos)

    for _step in range(horizon):
        scene.step()

    ball_pos = ball.get_state().pos
    loss = torch.abs(ball_pos - goal_pos).sum()

    optimizer.zero_grad()
    loss.backward()  # this lets gradient flow all the way back to tensor input
    # Clip BEFORE the optimizer step and pass the parameter itself:
    # clip_grad_norm_ reads `.grad` from the tensors it is given, so handing
    # it `init_pos.grad` (whose own `.grad` is None) clips nothing.
    grad_norm = torch.nn.utils.clip_grad_norm_([init_pos], max_grad_norm)
    optimizer.step()
    scheduler.step()

    print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g}")
    prev_loss = loss.item()
    losses.append(prev_loss)

    # Refresh the loss curve on disk every iteration.
    plt.plot(losses)
    plt.savefig("loss.png")
    plt.close()
# Multi-step optimisation for debugging differentiable rigid simulation:
# a per-step generalized force on the ball is optimised so that the ball's
# height (z) is as large as possible over the horizon.
show_viewer = False

gs.init(precision="32", logging_level="warn", backend=gs.cpu)

dt = 1e-2
horizon = 10
substeps = 1
grad_window = 5  # set to None to disable gradient truncation

scene = gs.Scene(
    sim_options=gs.options.SimOptions(
        dt=dt,
        substeps=substeps,
        requires_grad=True,
        grad_window_steps=grad_window,
    ),
    rigid_options=gs.options.RigidOptions(
        use_gjk_collision=True,
        enable_joint_limit=False,
    ),
    show_viewer=show_viewer,
)

ball = scene.add_entity(
    gs.morphs.Sphere(
        pos=(0, 0, 0.109),  # small penetration with ground
        radius=0.1,
    ),
    surface=gs.surfaces.Default(color=(0.9, 0.0, 0.0, 1.0)),
)

# The slab is not `fixed`; gravity compensation is enabled on its material instead.
ground = scene.add_entity(
    gs.morphs.Box(pos=(0, 0, 0), size=(5.0, 5.0, 0.02)),
    surface=gs.surfaces.Default(color=(0.0, 0.0, 0.9, 1.0)),
    material=gs.materials.Rigid(gravity_compensation=1),
)
cam = scene.add_camera(
    pos=(3.5, 0.5, 2.5),
    lookat=(0.0, 0.0, 0.5),
    fov=40,
    GUI=False,
)

scene.build()

num_iter = 10000
lr = 1e-2
max_grad_norm = 1.0
render_every = 100

# One 6-component force vector per simulation step.
force = gs.zeros((horizon, 6), requires_grad=True)
optimizer = torch.optim.Adam([force], lr=lr)

prev_loss = float("inf")
losses = []
for it in range(num_iter):
    scene.reset()
    record = it % render_every == 0

    if record:
        cam.start_recording()

    step_losses = []
    for i in range(horizon):
        ball.control_dofs_force(force[i])
        scene.step()

        ball_pos = ball.get_state().pos
        step_losses.append(-ball_pos[:, 2].sum())  # maximise height (z only)

        if record:
            cam.render()

    if record:
        cam.stop_recording(save_to_filename=f"video_{it:06d}.mp4", fps=30)

    loss = torch.stack(step_losses).mean()
    optimizer.zero_grad()
    loss.backward()  # this lets gradient flow all the way back to tensor input
    # Clip BEFORE stepping and pass the parameter, not its gradient:
    # clip_grad_norm_ reads `.grad` from the tensors it receives, so passing
    # `force.grad` clipped nothing and always reported a norm of 0.
    grad_norm = torch.nn.utils.clip_grad_norm_([force], max_grad_norm)
    optimizer.step()

    # NOTE(review): the original zeroed `force.data[:, 6:]` here, which is an
    # empty slice on a (horizon, 6) tensor and therefore a no-op. If some force
    # components were meant to be frozen, the intended slice needs confirming.

    print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Grad Norm: {grad_norm:.6g} | Force: {force.data.mean(dim=0).cpu().numpy().tolist()}")
    prev_loss = loss.item()
    losses.append(prev_loss)

    # Refresh the loss curve on disk every iteration.
    plt.plot(losses)
    plt.savefig("loss.png")
    plt.close()
class Controller(torch.nn.Module):
    """Gaussian MLP policy mapping observations to per-DoF actions.

    The network is BatchNorm -> three ReLU hidden layers -> two linear heads
    producing the mean and the log standard deviation of the action
    distribution.

    Args:
        obs_dim: size of one observation vector.
        n_dofs: number of controlled degrees of freedom (action size).
        hidden_dim: width of every hidden layer.
    """

    def __init__(self, obs_dim, n_dofs, hidden_dim=64):
        super().__init__()
        self.obs_dim = obs_dim
        self.n_dofs = n_dofs

        # Input normalization.
        self.bn = torch.nn.BatchNorm1d(obs_dim)

        # MLP trunk.
        self.fc1 = torch.nn.Linear(obs_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, hidden_dim)

        # Output heads for mean and log_std.
        self.mean_layer = torch.nn.Linear(hidden_dim, n_dofs)
        self.log_std_layer = torch.nn.Linear(hidden_dim, n_dofs)

        # Start with a state-independent log_std of -0.5.
        torch.nn.init.zeros_(self.log_std_layer.weight)
        torch.nn.init.constant_(self.log_std_layer.bias, -0.5)

    def forward(self, obs):
        """Compute the action distribution parameters.

        Args:
            obs: observation tensor of shape (batch_size, obs_dim).

        Returns:
            (mean, std), each of shape (batch_size, n_dofs).
        """
        # Training-mode BatchNorm1d cannot compute batch statistics from fewer
        # than two samples, so normalization is skipped only in that case. In
        # eval mode the running statistics are used for any batch size, which
        # keeps single-sample inference consistent with batched inference.
        if self.training and obs.shape[0] <= 1:
            x = obs
        else:
            x = self.bn(obs)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))

        mean = self.mean_layer(x)
        # Clamp log_std to prevent extreme standard deviations.
        log_std = torch.clamp(self.log_std_layer(x), min=-10, max=2)
        return mean, torch.exp(log_std)

    def sample_action(self, obs):
        """Draw a reparameterized action sample from the policy.

        Args:
            obs: observation tensor of shape (batch_size, obs_dim).

        Returns:
            (action, mean, std), each of shape (batch_size, n_dofs), where
            ``action = mean + std * noise`` with standard normal noise.
        """
        mean, std = self(obs)
        action = mean + std * torch.randn_like(mean)
        return action, mean, std


def observe_fn(state):
    """Build the policy observation from an entity state.

    Concatenates ``state.qpos``, ``state.dofs_vel`` and ``state.dofs_acc``
    along dim 1 and detaches the result, so no gradient flows through the
    observation.
    """
    features = (state.qpos, state.dofs_vel, state.dofs_acc)
    return torch.cat(features, dim=1).detach()


def reward_fn(state, dt, prev_state=None):
    """Per-environment reward: weighted height term plus forward velocity.

    Args:
        state: current state; only ``state.pos`` is read (indexed as
            ``pos[:, 0]`` for x and ``pos[:, 2]`` for z).
        dt: time elapsed between ``prev_state`` and ``state``.
        prev_state: previous state, or None on the first step, in which case
            the forward term is 0.

    Returns:
        Tensor with one reward per row of ``state.pos``.
    """
    pos = state.pos

    # Height relative to 0.8, capped at +1. At or below the reference the
    # term is a quadratic penalty, above it the term is linear.
    height = torch.clamp(pos[:, 2] - 0.8, max=1.0)
    height_reward = torch.where(height <= 0.0, -200 * (height ** 2), height)

    if prev_state is None:
        forward_reward = 0.0
    else:
        # The previous x position is detached, so the gradient of this term
        # only flows through the current state.
        forward_reward = (pos[:, 0] - prev_state.pos[:, 0].detach()) / dt

    return height_reward * 0.01 + forward_reward


def main():
    """Train the controller by backpropagating through the rigid simulation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--vis", action="store_true", default=False)
    # 64 was the value previously hard-coded over the parsed argument; it is
    # now the default, so the command-line flags actually take effect.
    parser.add_argument("-n", "--n_envs", type=int, default=64)
    args = parser.parse_args()

    dt = 0.01
    substeps = 1
    horizon_steps = 128
    window_substeps = 32
    window_steps = window_substeps // substeps
    iteration = 10000
    lr = 1e-4
    render_every = 100

    ########################## init ##########################
    gs.init(backend=gs.gpu, logging_level="warn")

    ########################## create a scene ##########################
    viewer_options = gs.options.ViewerOptions(
        camera_pos=(3, -1, 1.5),
        camera_lookat=(0.0, 0.0, 0.0),
        camera_fov=30,
        max_FPS=60,
    )

    scene = gs.Scene(
        sim_options=gs.options.SimOptions(
            dt=dt,
            substeps=substeps,
            requires_grad=True,
            grad_window_steps=window_steps,
        ),
        viewer_options=viewer_options,
        rigid_options=gs.options.RigidOptions(
            use_gjk_collision=True,
            enable_joint_limit=False,
        ),
        vis_options=gs.options.VisOptions(
            rendered_envs_idx=(0,),
            show_world_frame=True,
        ),
        show_viewer=args.vis,
    )

    ########################## entities ##########################
    # Fixed box used as the ground.
    scene.add_entity(
        gs.morphs.Box(size=(10, 10, 0.6), pos=(0, 0, -0.3), fixed=True),
    )
    # NOTE(review): the variable is called `ant` but the walker model is
    # loaded -- confirm which asset is intended.
    ant = scene.add_entity(
        gs.morphs.MJCF(file="xml/walker_no_ground.xml"),
    )
    cam = scene.add_camera(
        pos=(3.5, 0.5, 2.5),
        lookat=(0.0, 0.0, 0.5),
        fov=40,
        GUI=False,
        env_idx=0,
    )
    cam.follow_entity(ant)

    ########################## build ##########################
    scene.build(n_envs=args.n_envs, env_spacing=(1, 1))

    n_dofs = ant.n_dofs

    # Compute one observation to discover the observation size.
    scene.reset()
    obs = observe_fn(ant.get_state())
    obs_dim = obs.shape[1]

    # Keep the policy on the same device/dtype as the simulator's tensors.
    controller = Controller(obs_dim=obs_dim, n_dofs=n_dofs, hidden_dim=64)
    controller = controller.to(device=obs.device, dtype=obs.dtype)
    optimizer = torch.optim.Adam(controller.parameters(), lr=lr)

    rewards = []
    last_step = horizon_steps - 1
    for it in range(iteration):
        scene.reset()
        acc_reward = None
        prev_state = None

        record = (it % render_every == 0) or (it == iteration - 1)
        if record:
            cam.start_recording()

        print("running forward pass...")
        for step in range(horizon_steps):
            scene.step()
            if record:
                cam.render()

            # Observation and reward for the state just reached.
            state = ant.get_state()
            obs = observe_fn(state)
            reward = reward_fn(state, dt, prev_state)
            # Out-of-place accumulation: never mutate a tensor that is part
            # of the autograd graph.
            acc_reward = reward if acc_reward is None else acc_reward + reward
            prev_state = state

            # Sample an action and apply it as DoF forces.
            action, _, _ = controller.sample_action(obs)
            ant.control_dofs_force(action)

            # Single backward pass at the end of the rollout.
            if step == last_step and step > 0:
                print("running backward pass...")
                mean_reward = (acc_reward / (step + 1)).mean()
                loss = -mean_reward

                optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 1.0)
                optimizer.step()

                print(f"[ITER {it}] Mean Reward: {mean_reward.item():.4g} | Grad Norm: {grad_norm:.4g}")

                rewards.append(mean_reward.detach().item())

                plt.plot(rewards)
                plt.savefig("rewards.png")
                plt.close()

        if record:
            cam.stop_recording(save_to_filename=f"ant_video_{it:06d}.mp4", fps=30)


if __name__ == "__main__":
    main()
# Rotation debugging for differentiable rigid simulation: per-step torques on
# the (non-fixed) slab are optimised so that its orientation matches a random
# goal quaternion over the horizon.
show_viewer = False

gs.init(precision="32", logging_level="warn", backend=gs.cpu)

dt = 1e-2
horizon = 10
substeps = 1
grad_window = None

# Reproducible random unit quaternion used as the goal orientation.
np.random.seed(0)
goal_quat = np.random.randn(4)
goal_quat = goal_quat / np.linalg.norm(goal_quat)
goal_quat = gs.tensor(goal_quat)

scene = gs.Scene(
    sim_options=gs.options.SimOptions(
        dt=dt,
        substeps=substeps,
        requires_grad=True,
        grad_window_steps=grad_window,
    ),
    rigid_options=gs.options.RigidOptions(
        use_gjk_collision=True,
        enable_joint_limit=False,
    ),
    show_viewer=show_viewer,
)

# The slab is not `fixed`; gravity compensation is enabled on its material instead.
ground = scene.add_entity(
    gs.morphs.Box(pos=(0, 0, 0), size=(5.0, 5.0, 0.02)),
    surface=gs.surfaces.Default(color=(0.0, 0.0, 0.9, 1.0)),
    material=gs.materials.Rigid(gravity_compensation=1),
)
ball = scene.add_entity(
    gs.morphs.Sphere(
        pos=(0, 0, 0.109),  # small penetration with ground
        radius=0.1,
    ),
    surface=gs.surfaces.Default(color=(0.9, 0.0, 0.0, 1.0)),
)
cam = scene.add_camera(
    pos=(3.5, 0.5, 2.5),
    lookat=(0.0, 0.0, 0.5),
    fov=40,
    GUI=False,
)

scene.build()

num_iter = 10000
lr = 1e-2
max_grad_norm = 1.0
render_every = 100

# One 6-component force vector per step; only the last three components are
# randomly initialised, the first three start at (and are kept at) zero.
force = gs.zeros((horizon, 6), requires_grad=True)
with torch.no_grad():
    torch.manual_seed(0)
    force.data[:, 3:] = torch.randn_like(force.data[:, 3:])
optimizer = torch.optim.Adam([force], lr=lr)

prev_loss = float("inf")
losses = []
for it in range(num_iter):
    scene.reset()
    record = it % render_every == 0

    if record:
        cam.start_recording()

    step_losses = []
    for i in range(horizon):
        ground.control_dofs_force(force[i])
        scene.step()

        slab_quat = ground.get_state().quat
        step_losses.append((slab_quat - goal_quat).abs().sum())

        if record:
            cam.render()

    if record:
        cam.stop_recording(save_to_filename=f"video_{it:06d}.mp4", fps=30)

    loss = torch.stack(step_losses).mean()
    optimizer.zero_grad()
    loss.backward()  # this lets gradient flow all the way back to tensor input
    # Clip BEFORE stepping and pass the parameter, not its gradient:
    # clip_grad_norm_ reads `.grad` from the tensors it receives, so passing
    # `force.grad` clipped nothing and always reported a norm of 0.
    grad_norm = torch.nn.utils.clip_grad_norm_([force], max_grad_norm)
    optimizer.step()

    # Keep the first three force components at zero after every update.
    with torch.no_grad():
        force.data[:, :3] = 0.0

    print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Grad Norm: {grad_norm:.6g} | Force: {force.data.mean(dim=0).cpu().numpy().tolist()}")
    prev_loss = loss.item()
    losses.append(prev_loss)

    # Refresh the loss curve on disk every iteration.
    plt.plot(losses)
    plt.savefig("loss.png")
    plt.close()
# Ping-pong example: optimise the racket's initial pose so that the ball,
# dropped from above, ends up at `goal_pos` after `horizon` steps.
show_viewer = True

gs.init(precision="32", logging_level="warn", backend=gs.cpu)

dt = 1e-2
horizon = 200
substeps = 4
goal_pos = gs.tensor([0.0, 0.1, -0.1])

sim_options = gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True)
rigid_options = gs.options.RigidOptions(
    use_gjk_collision=True,
    enable_joint_limit=False,
)
scene = gs.Scene(
    sim_options=sim_options,
    rigid_options=rigid_options,
    show_viewer=show_viewer,
)

ball = scene.add_entity(
    gs.morphs.Sphere(pos=(0, 0, 0.5), radius=0.1),
    surface=gs.surfaces.Default(color=(0.9, 0.0, 0.0, 1.0)),
    material=gs.materials.Rigid(rho=0.001),
)

# The racket is not `fixed`; gravity compensation is enabled on its material instead.
racket = scene.add_entity(
    gs.morphs.Box(pos=(0, 0, 0), size=(5.0, 5.0, 0.01)),
    surface=gs.surfaces.Default(color=(0.0, 0.0, 0.9, 1.0)),
    material=gs.materials.Rigid(gravity_compensation=1),
)

scene.build()

num_iter = 200
lr = 1e-4

# Optimisation variables: the racket pose applied right after each reset.
init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True)
init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)

# NOTE(review): eta_min (1e-3) exceeds the initial lr (1e-4), so the cosine
# schedule increases the learning rate over the run -- confirm intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)

prev_loss = float('inf')
for _ in range(num_iter):
    scene.reset()
    racket.set_pos(init_pos)
    racket.set_quat(init_quat)

    for _step in range(horizon):
        scene.step()

    # L1 distance between the ball's final position and the goal.
    final_ball_pos = ball.get_state().pos
    loss = torch.abs(final_ball_pos - goal_pos).sum()

    optimizer.zero_grad()
    loss.backward()  # this lets gradient flow all the way back to tensor input
    optimizer.step()
    scheduler.step()

    # Project the quaternion back onto the unit sphere after the update.
    with torch.no_grad():
        init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)

    print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g}")
    prev_loss = loss.item()

# assert_allclose(loss, 0.0, atol=1e-2)
# Slide-ball example: optimise the slab's initial pose so that the ball
# resting on it ends up at `goal_pos` after `horizon` steps.
show_viewer = False

gs.init(precision="32", logging_level="warn", backend=gs.cpu)

dt = 1e-2
horizon = 100
substeps = 4
goal_pos = gs.tensor([0.2, 0.1, 0.1])
render_every = 100

sim_options = gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True)
rigid_options = gs.options.RigidOptions(
    use_gjk_collision=True,
    enable_joint_limit=False,
)
scene = gs.Scene(
    sim_options=sim_options,
    rigid_options=rigid_options,
    show_viewer=show_viewer,
)

ball = scene.add_entity(
    gs.morphs.Sphere(pos=(0, 0, 0.11), radius=0.1),
    surface=gs.surfaces.Default(color=(0.9, 0.0, 0.0, 1.0)),
    material=gs.materials.Rigid(rho=0.001),
)

# The slab is not `fixed`; gravity compensation is enabled on its material instead.
racket = scene.add_entity(
    gs.morphs.Box(pos=(0, 0, 0), size=(5.0, 5.0, 0.02)),
    surface=gs.surfaces.Default(color=(0.0, 0.0, 0.9, 1.0)),
    material=gs.materials.Rigid(gravity_compensation=1),
)

cam = scene.add_camera(
    pos=(3.5, 0.5, 2.5),
    lookat=(0.0, 0.0, 0.5),
    fov=40,
    GUI=False,
)

scene.build()

num_iter = 300
lr = 1e-4

# Optimisation variables: the slab pose applied right after each reset.
init_pos = gs.tensor([0.0, 0.0, 0.0], requires_grad=True)
init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-4)

prev_loss = float('inf')
losses = []
for it in range(num_iter):
    scene.reset()
    racket.set_pos(init_pos)
    racket.set_quat(init_quat)

    # Record a video on every `render_every`-th iteration and on the last one.
    is_last = it == num_iter - 1
    record = (it % render_every == 0) or is_last

    if record:
        cam.start_recording()
    for _step in range(horizon):
        scene.step()
        if record:
            cam.render()
    if record:
        cam.stop_recording(save_to_filename=f"video_{it:06d}.mp4", fps=30)

    # L1 distance between the ball's final position and the goal.
    ball_pos = ball.get_state().pos
    loss = torch.abs(ball_pos - goal_pos).sum()

    optimizer.zero_grad()
    loss.backward()  # this lets gradient flow all the way back to tensor input
    optimizer.step()
    scheduler.step()

    # Project the quaternion back onto the unit sphere after the update.
    with torch.no_grad():
        init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)

    print(f"Loss: {prev_loss:.6g} -> {loss.item():.6g} | Ball Pos: {ball_pos.detach().cpu().numpy().tolist()}")
    prev_loss = loss.item()
    losses.append(loss.item())

    # Refresh the (log-scale) loss curve on disk every iteration.
    plt.plot(losses)
    plt.yscale('log')
    plt.savefig("loss.png")
    plt.close()

# assert_allclose(loss, 0.0, atol=1e-2)
@gs.assert_built
def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, velocity_grad):
    """Forward the gradient of a `set_dofs_velocity` input to the solver.

    Entity-local DoF indices are converted to global solver indices, then the
    raw data of `velocity_grad` is handed to the solver's
    `set_dofs_velocity_grad` together with `envs_idx`.
    """
    global_dofs_idx = self._get_global_idx(
        dofs_idx_local,
        self.n_dofs,
        self._dof_start,
        unsafe=True,
    )
    grad_data = velocity_grad.data
    self._solver.set_dofs_velocity_grad(global_dofs_idx, envs_idx, grad_data)
@@ -2371,6 +2390,14 @@ def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None): dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) self._solver.control_dofs_force(force, dofs_idx, envs_idx) + @gs.assert_built + def control_dofs_force_grad(self, dofs_idx_local, envs_idx, force_grad): + dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) + self._solver.control_dofs_force_grad(dofs_idx, envs_idx, force_grad.data) + + pass + + @gs.assert_built def control_dofs_velocity(self, velocity, dofs_idx_local=None, envs_idx=None): """ diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index a269c998b4..80859ba86d 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -280,7 +280,9 @@ def step(self, in_backward=False): self.save_ckpt() if self.rigid_solver.is_active: - self.rigid_solver.clear_external_force() + # In backward pass, we need to keep the external force for gradient computation + if not in_backward: + self.rigid_solver.clear_external_force() if self._cur_substep_global % RATE_CHECK_ERRNO == 0: self.rigid_solver.check_errno() @@ -295,8 +297,16 @@ def _step_grad(self): self.sub_step_grad(self.cur_substep_local) + if self.rigid_solver.is_active: + # Clear external force after gradient computation + self.rigid_solver.clear_external_force() + self.process_input_grad() + if self.options.grad_window_steps is not None and self.cur_step_global % self.options.grad_window_steps == 0: + # Truncate upstream gradient flow + self.rigid_solver.zero_grad() + def process_input(self, in_backward=False): """ setting _tgt state using external commands diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index e7c35378af..a6aebb3d42 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -174,6 
+174,20 @@ def add_inequality_constraints(self): static_rigid_sim_config=solver._static_rigid_sim_config, ) + def add_inequality_constraints_grad(self): + solver = self._solver + add_inequality_constraints.grad( + links_info=solver.links_info, + links_state=solver.links_state, + dofs_state=solver.dofs_state, + dofs_info=solver.dofs_info, + joints_info=solver.joints_info, + constraint_state=self.constraint_state, + collider_state=self._collider._collider_state, + rigid_global_info=solver._rigid_global_info, + static_rigid_sim_config=solver._static_rigid_sim_config, + ) + def resolve(self): solver = self._solver @@ -437,80 +451,110 @@ def add_collision_constraints( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_col in range(collider_state.n_contacts[i_b]): - contact_data_link_a = collider_state.contact_data.link_a[i_col, i_b] - contact_data_link_b = collider_state.contact_data.link_b[i_col, i_b] - - contact_data_pos = collider_state.contact_data.pos[i_col, i_b] - contact_data_normal = collider_state.contact_data.normal[i_col, i_b] - contact_data_friction = collider_state.contact_data.friction[i_col, i_b] - contact_data_sol_params = collider_state.contact_data.sol_params[i_col, i_b] - contact_data_penetration = collider_state.contact_data.penetration[i_col, i_b] - - link_a = contact_data_link_a - link_b = contact_data_link_b - link_a_maybe_batch = [link_a, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_a - link_b_maybe_batch = [link_b, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_b - - d1, d2 = gu.ti_orthogonals(contact_data_normal) - - invweight = links_info.invweight[link_a_maybe_batch][0] - if link_b > -1: - invweight = invweight + 
links_info.invweight[link_b_maybe_batch][0] - - for i in range(4): + for i_b, i_0, i_4 in ( + ti.ndrange(dofs_state.ctrl_mode.shape[1], 1, 4) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.ndrange(dofs_state.ctrl_mode.shape[1], static_rigid_sim_config.max_contact_pairs, 4) + ): + EPS = rigid_global_info.EPS[None] + n_dofs = dofs_state.ctrl_mode.shape[0] + + for i_1 in ( + range(collider_state.n_contacts[i_b]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(1)) + ): + i_col = i_1 if ti.static(not static_rigid_sim_config.is_backward) else i_0 + if i_col < collider_state.n_contacts[i_b]: + contact_data_link_a = collider_state.contact_data.link_a[i_col, i_b] + contact_data_link_b = collider_state.contact_data.link_b[i_col, i_b] + + contact_data_pos = collider_state.contact_data.pos[i_col, i_b] + contact_data_normal = collider_state.contact_data.normal[i_col, i_b] + contact_data_friction = collider_state.contact_data.friction[i_col, i_b] + contact_data_sol_params = collider_state.contact_data.sol_params[i_col, i_b] + contact_data_penetration = collider_state.contact_data.penetration[i_col, i_b] + + link_a = contact_data_link_a + link_b = contact_data_link_b + link_a_maybe_batch = [link_a, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_a + link_b_maybe_batch = [link_b, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link_b + + d1, d2 = gu.ti_orthogonals(contact_data_normal) + + invweight = links_info.invweight[link_a_maybe_batch][0] + if link_b > -1: + invweight = invweight + links_info.invweight[link_b_maybe_batch][0] + + #for i in range(4) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(4)): + i = i_4 d = (2 * (i % 2) - 1) * (d1 if i < 2 else d2) n = d * contact_data_friction - contact_data_normal - n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) + n_con = i_col * 4 + i # + constraint_state.n_constraints[i_b] if 
ti.static(static_rigid_sim_config.sparse_solve): for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b] constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) else: - for i_d in range(n_dofs): + for i_d in ( + range(n_dofs) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_dofs)) + ): constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) con_n_relevant_dofs = 0 jac_qvel = gs.ti_float(0.0) - for i_ab in range(2): + for i_ab in range(2) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(2)): sign = gs.ti_float(-1.0) link = link_a if i_ab == 1: sign = gs.ti_float(1.0) link = link_b - while link > -1: - link_maybe_batch = [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link + # FIXME: Set number of iterations to look for parent to certain value for autodiff + for i_parent in ( + range(20) if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(5)) + ): + if link > -1: + link_maybe_batch = ( + [link, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else link + ) + + # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending + for i_d_ in ( + range(links_info.n_dofs[link_maybe_batch]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + if i_d_ < links_info.n_dofs[link_maybe_batch]: + i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ - # reverse order to make sure dofs in each row of self.jac_relevant_dofs is strictly descending - for i_d_ in range(links_info.n_dofs[link_maybe_batch]): - i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ + cdof_ang = dofs_state.cdof_ang[i_d, i_b] + cdot_vel = dofs_state.cdof_vel[i_d, i_b] - cdof_ang = dofs_state.cdof_ang[i_d, i_b] - cdot_vel = dofs_state.cdof_vel[i_d, i_b] + t_quat = gu.ti_identity_quat() + 
t_pos = contact_data_pos - links_state.root_COM[link, i_b] + _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) - t_quat = gu.ti_identity_quat() - t_pos = contact_data_pos - links_state.root_COM[link, i_b] - _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) + diff = sign * vel + jac = diff @ n + jac_qvel += jac * dofs_state.vel[i_d, i_b] + constraint_state.jac[n_con, i_d, i_b] += jac - diff = sign * vel - jac = diff @ n - jac_qvel = jac_qvel + jac * dofs_state.vel[i_d, i_b] - constraint_state.jac[n_con, i_d, i_b] = constraint_state.jac[n_con, i_d, i_b] + jac + if ti.static(static_rigid_sim_config.sparse_solve): + constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d + con_n_relevant_dofs += 1 - if ti.static(static_rigid_sim_config.sparse_solve): - constraint_state.jac_relevant_dofs[n_con, con_n_relevant_dofs, i_b] = i_d - con_n_relevant_dofs += 1 + link = links_info.parent_idx[link_maybe_batch] - link = links_info.parent_idx[link_maybe_batch] + if ti.static(static_rigid_sim_config.is_backward): + if i_parent == 4 and link > -1: + print( + "Warning: Number of parents is too large for backward mode in add_collision_constraints" + ) if ti.static(static_rigid_sim_config.sparse_solve): constraint_state.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs @@ -518,14 +562,18 @@ def add_collision_constraints( contact_data_sol_params, -contact_data_penetration, jac_qvel, -contact_data_penetration ) - diag = invweight + contact_data_friction * contact_data_friction * invweight - diag *= 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp - diag = ti.max(diag, EPS) + diag_0 = invweight + contact_data_friction * contact_data_friction * invweight + diag_1 = diag_0 * 2 * contact_data_friction * contact_data_friction * (1 - imp) / imp + diag = ti.max(diag_1, EPS) constraint_state.diag[n_con, i_b] = diag constraint_state.aref[n_con, i_b] = aref constraint_state.efc_D[n_con, i_b] = 1 / 
diag + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(dofs_state.ctrl_mode.shape[1]): + constraint_state.n_constraints[i_b] += 4 * collider_state.n_contacts[i_b] + @ti.func def func_equality_connect( @@ -792,15 +840,15 @@ def add_inequality_constraints( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - add_frictionloss_constraints( - links_info=links_info, - joints_info=joints_info, - dofs_info=dofs_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - constraint_state=constraint_state, - static_rigid_sim_config=static_rigid_sim_config, - ) + # add_frictionloss_constraints( + # links_info=links_info, + # joints_info=joints_info, + # dofs_info=dofs_info, + # dofs_state=dofs_state, + # rigid_global_info=rigid_global_info, + # constraint_state=constraint_state, + # static_rigid_sim_config=static_rigid_sim_config, + # ) if ti.static(static_rigid_sim_config.enable_collision): add_collision_constraints( links_info=links_info, @@ -811,16 +859,16 @@ def add_inequality_constraints( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, ) - if ti.static(static_rigid_sim_config.enable_joint_limit): - add_joint_limit_constraints( - links_info=links_info, - joints_info=joints_info, - dofs_info=dofs_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - constraint_state=constraint_state, - static_rigid_sim_config=static_rigid_sim_config, - ) + # if ti.static(static_rigid_sim_config.enable_joint_limit): + # add_joint_limit_constraints( + # links_info=links_info, + # joints_info=joints_info, + # dofs_info=dofs_info, + # dofs_state=dofs_state, + # rigid_global_info=rigid_global_info, + # constraint_state=constraint_state, + # static_rigid_sim_config=static_rigid_sim_config, + # ) @ti.func @@ -1026,52 +1074,69 @@ def add_joint_limit_constraints( constraint_state: array_class.ConstraintState, static_rigid_sim_config: 
ti.template(), ): - EPS = rigid_global_info.EPS[None] - - _B = constraint_state.jac.shape[2] - n_links = links_info.root_idx.shape[0] - n_dofs = dofs_state.ctrl_mode.shape[0] - # TODO: sparse mode ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_l in range(n_links): + for i_b in range(constraint_state.jac.shape[2]): + EPS = rigid_global_info.EPS[None] + n_links = links_info.root_idx.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[0] + + for i_l in ( + range(n_links) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_links)) + ): I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - - if joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC: - i_q = joints_info.q_start[I_j] - i_d = joints_info.dof_start[I_j] - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0] - pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b] - pos_delta = min(pos_delta_min, pos_delta_max) - - if pos_delta < 0: - jac = (pos_delta_min < pos_delta_max) * 2 - 1 - jac_qvel = jac * dofs_state.vel[i_d, i_b] - imp, aref = gu.imp_aref(joints_info.sol_params[I_j], pos_delta, jac_qvel, pos_delta) - diag = ti.max(dofs_info.invweight[I_d] * (1 - imp) / imp, EPS) - - n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) - constraint_state.diag[n_con, i_b] = diag - constraint_state.aref[n_con, i_b] = aref - constraint_state.efc_D[n_con, i_b] = 1 / diag - - if ti.static(static_rigid_sim_config.sparse_solve): - for i_d2_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): - i_d2 = 
constraint_state.jac_relevant_dofs[n_con, i_d2_, i_b] - constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) - else: - for i_d2 in range(n_dofs): - constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) - constraint_state.jac[n_con, i_d, i_b] = jac + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = ( + i_j_ if ti.static(not static_rigid_sim_config.is_backward) else (i_j_ + links_info.joint_start[I_l]) + ) + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + + if ( + joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE + or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC + ): + i_q = joints_info.q_start[I_j] + i_d = joints_info.dof_start[I_j] + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0] + pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b] + pos_delta = min(pos_delta_min, pos_delta_max) + + if pos_delta < 0: + jac = (pos_delta_min < pos_delta_max) * 2 - 1 + jac_qvel = jac * dofs_state.vel[i_d, i_b] + imp, aref = gu.imp_aref(joints_info.sol_params[I_j], pos_delta, jac_qvel, pos_delta) + diag = ti.max(dofs_info.invweight[I_d] * (1 - imp) / imp, EPS) + + n_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) + constraint_state.diag[n_con, i_b] = diag + constraint_state.aref[n_con, i_b] = aref + constraint_state.efc_D[n_con, i_b] = 1 / diag - if ti.static(static_rigid_sim_config.sparse_solve): - constraint_state.jac_n_relevant_dofs[n_con, i_b] = 1 - constraint_state.jac_relevant_dofs[n_con, 0, i_b] = i_d + if ti.static(static_rigid_sim_config.sparse_solve): + for i_d2_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): + i_d2 = constraint_state.jac_relevant_dofs[n_con, i_d2_, i_b] 
+ constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) + else: + for i_d2 in ( + range(n_dofs) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_dofs)) + ): + constraint_state.jac[n_con, i_d2, i_b] = gs.ti_float(0.0) + constraint_state.jac[n_con, i_d, i_b] = jac + + if ti.static(static_rigid_sim_config.sparse_solve): + constraint_state.jac_n_relevant_dofs[n_con, i_b] = 1 + constraint_state.jac_relevant_dofs[n_con, 0, i_b] = i_d @ti.func @@ -1084,46 +1149,70 @@ def add_frictionloss_constraints( constraint_state: array_class.ConstraintState, static_rigid_sim_config: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - _B = constraint_state.jac.shape[2] - n_links = links_info.root_idx.shape[0] - n_dofs = dofs_state.ctrl_mode.shape[0] - # TODO: sparse mode # FIXME: The condition `if dofs_info.frictionloss[I_d] > EPS:` is not correctly evaluated on Apple Metal # if `serialize=True`... ti.loop_config( serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL and gs.backend != gs.metal) ) - for i_b in range(_B): + for i_b in range(constraint_state.jac.shape[2]): constraint_state.n_constraints_frictionloss[i_b] = 0 - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - - for i_d in range(joints_info.dof_start[I_j], joints_info.dof_end[I_j]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + EPS = rigid_global_info.EPS[None] + n_links = links_info.root_idx.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[0] - if dofs_info.frictionloss[I_d] > EPS: - jac = 1.0 - jac_qvel = jac * dofs_state.vel[i_d, i_b] - imp, aref = gu.imp_aref(joints_info.sol_params[I_j], 0.0, jac_qvel, 0.0) - diag = ti.max(dofs_info.invweight[I_d] * (1.0 - imp) / imp, EPS) - - 
i_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) - ti.atomic_add(constraint_state.n_constraints_frictionloss[i_b], 1) + for i_l in ( + range(n_links) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_links)) + ): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - constraint_state.diag[i_con, i_b] = diag - constraint_state.aref[i_con, i_b] = aref - constraint_state.efc_D[i_con, i_b] = 1.0 / diag - constraint_state.efc_frictionloss[i_con, i_b] = dofs_info.frictionloss[I_d] - for i_d2 in range(n_dofs): - constraint_state.jac[i_con, i_d2, i_b] = gs.ti_float(0.0) - constraint_state.jac[i_con, i_d, i_b] = jac + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = ( + i_j_ if ti.static(not static_rigid_sim_config.is_backward) else (i_j_ + links_info.joint_start[I_l]) + ) + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + + for i_d_ in ( + range(joints_info.dof_start[I_j], joints_info.dof_end[I_j]) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d_ + joints_info.dof_start[I_j]) + ) + if i_d < joints_info.dof_end[I_j]: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + + if dofs_info.frictionloss[I_d] > EPS: + jac = 1.0 + jac_qvel = jac * dofs_state.vel[i_d, i_b] + imp, aref = gu.imp_aref(joints_info.sol_params[I_j], 0.0, jac_qvel, 0.0) + diag = ti.max(dofs_info.invweight[I_d] * (1.0 - imp) / imp, EPS) + + i_con = ti.atomic_add(constraint_state.n_constraints[i_b], 1) + ti.atomic_add(constraint_state.n_constraints_frictionloss[i_b], 1) + + 
constraint_state.diag[i_con, i_b] = diag + constraint_state.aref[i_con, i_b] = aref + constraint_state.efc_D[i_con, i_b] = 1.0 / diag + constraint_state.efc_frictionloss[i_con, i_b] = dofs_info.frictionloss[I_d] + for i_d2 in ( + range(n_dofs) + if ti.static(not static_rigid_sim_config.is_backward) + else ti.static(range(static_rigid_sim_config.n_dofs)) + ): + constraint_state.jac[i_con, i_d2, i_b] = gs.ti_float(0.0) + constraint_state.jac[i_con, i_d, i_b] = jac @ti.func diff --git a/genesis/engine/solvers/rigid/diff_gjk_decomp.py b/genesis/engine/solvers/rigid/diff_gjk_decomp.py index 5d3e407e08..8f36df3e8c 100644 --- a/genesis/engine/solvers/rigid/diff_gjk_decomp.py +++ b/genesis/engine/solvers/rigid/diff_gjk_decomp.py @@ -77,7 +77,7 @@ def func_gjk_contact( found_default_epa = False # 4 (small) + 4 (large) perturbated configurations - num_perturb = 8 + num_perturb = 4 ### Detect multiple possible contact points and gather the non-differentiable contact data. for i in range(1 + num_perturb): @@ -209,6 +209,8 @@ def func_gjk_contact( ) found_default_epa = True + # Do not use extended EPA algorithm for now, practically it seems not needed. + break # Break the loop if we found enough contact points for default configuration. 
As we can find at most # 8 contact points for perturbed configurations, we can find at most max_contacts_per_pair - 8 diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 48b1a693c8..2f28b7637e 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -271,6 +271,9 @@ def build(self): max_n_qs_per_link=max(link.n_qs for link in self.links) if self.links else 0, n_links=self._n_links, n_geoms=self._n_geoms, + n_dofs=self._n_dofs, + n_entities=self._n_entities, + # max_contact_pairs=self.collider._collider_state.contact_data.geom_a.shape[0], ) self._static_rigid_sim_config = array_class.StructRigidSimStaticConfig(**static_rigid_sim_config) else: @@ -299,6 +302,9 @@ def build(self): if getattr(self._options, "noslip_iterations", 0) > 0: gs.raise_exception("Noslip is not supported yet when requires_grad is True.") + if getattr(self._options, "sparse_solve", False): + gs.raise_exception("Sparse solve is not supported yet when requires_grad is True.") + # when the migration is finished, we will remove the about two lines self._func_vel_at_point = func_vel_at_point self._func_apply_coupling_force = func_apply_coupling_force @@ -310,7 +316,6 @@ def build(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info - self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs self.awake_dofs = self._rigid_global_info.awake_dofs @@ -323,6 +328,8 @@ def build(self): self.links_state_adjoint_cache = self.data_manager.links_state_adjoint_cache self.joints_state_adjoint_cache = self.data_manager.joints_state_adjoint_cache self.geoms_state_adjoint_cache = self.data_manager.geoms_state_adjoint_cache + self._rigid_adjoint_cache_fw = self.data_manager.rigid_adjoint_cache_fw + self._rigid_adjoint_cache_bw = 
self.data_manager.rigid_adjoint_cache_bw self._init_mass_mat() self._init_dof_fields() @@ -816,6 +823,9 @@ def _init_sdf(self): def _init_collider(self): self.collider = Collider(self) + if self.sim.options.requires_grad: + self._static_rigid_sim_config.max_contact_pairs = self.collider._collider_state.contact_data.geom_a.shape[0] + if self.collider._collider_static_config.has_terrain: link_idx_ = next( i for i, _type in enumerate(ti_to_numpy(self.geoms_info.type)) if _type == gs.GEOM_TYPE.TERRAIN @@ -859,8 +869,9 @@ def substep(self, f): kernel_save_adjoint_cache( f=f, dofs_state=self.dofs_state, + constraint_state=self.constraint_solver.constraint_state, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) @@ -908,8 +919,9 @@ def substep(self, f): kernel_save_adjoint_cache( f=f + 1, dofs_state=self.dofs_state, + constraint_state=self.constraint_solver.constraint_state, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) @@ -1148,9 +1160,6 @@ def substep_pre_coupling(self, f): self.substep(f) def substep_pre_coupling_grad(self, f): - # Change to backward mode - self._static_rigid_sim_config.is_backward = True - # Run forward substep again to restore this step's information, this is needed because we do not store info # of every substep. 
kernel_prepare_backward_substep( @@ -1163,46 +1172,144 @@ def substep_pre_coupling_grad(self, f): dofs_info=self.dofs_info, geoms_state=self.geoms_state, geoms_info=self.geoms_info, + constraint_state=self.constraint_solver.constraint_state, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, dofs_state_adjoint_cache=self.dofs_state_adjoint_cache, links_state_adjoint_cache=self.links_state_adjoint_cache, joints_state_adjoint_cache=self.joints_state_adjoint_cache, geoms_state_adjoint_cache=self.geoms_state_adjoint_cache, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) self.substep(f) # =================== Backward substep ====================== + # Change to backward mode: Note that we use forward mode in the [substep] function right above. This is because + # we need to reproduce the same data as in the forward pass for the backward pass. The backward pass should be + # logically same as the forward pass, but it is tweaked to use autodiff, and thus numerical differences could + # arise. To prevent this, we change to backward mode here. 
+ self._static_rigid_sim_config.is_backward = True + envs_idx = self._scene._sanitize_envs_idx(None) if not self._enable_mujoco_compatibility: - kernel_forward_velocity.grad( - envs_idx=envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_info=self.joints_info, - dofs_state=self.dofs_state, + # kernel_forward_velocity.grad( + # envs_idx=envs_idx, + # links_state=self.links_state, + # links_info=self.links_info, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # ) + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_velocity_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_update_geoms.grad( + envs_idx, entities_info=self.entities_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + links_state=self.links_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, ) - kernel_update_cartesian_space.grad( + # kernel_COM_links.grad( + # links_state=self.links_state, + # links_info=self.links_info, + # joints_state=self.joints_state, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # geoms_state=self.geoms_state, + # geoms_info=self.geoms_info, + # entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # force_update_fixed_geoms=False, + # ) + for i_l_ in 
range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_COM_links_ad_1.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_COM_links_ad_0.grad( links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, joints_info=self.joints_info, dofs_state=self.dofs_state, dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - force_update_fixed_geoms=False, ) - - is_grad_valid = kernel_begin_backward_substep( + # kernel_forward_kinematics.grad( + # links_state=self.links_state, + # links_info=self.links_info, + # joints_state=self.joints_state, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # geoms_state=self.geoms_state, + # geoms_info=self.geoms_info, + # entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # force_update_fixed_geoms=False, + # ) + + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_kinematics_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + 
static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) + # kernel_update_cartesian_space.grad( + # links_state=self.links_state, + # links_info=self.links_info, + # joints_state=self.joints_state, + # joints_info=self.joints_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # geoms_state=self.geoms_state, + # geoms_info=self.geoms_info, + # entities_info=self.entities_info, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # force_update_fixed_geoms=False, + # ) + + errno = kernel_begin_backward_substep( f=f, links_state=self.links_state, links_info=self.links_info, @@ -1218,11 +1325,20 @@ def substep_pre_coupling_grad(self, f): links_state_adjoint_cache=self.links_state_adjoint_cache, joints_state_adjoint_cache=self.joints_state_adjoint_cache, geoms_state_adjoint_cache=self.geoms_state_adjoint_cache, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache_fw=self._rigid_adjoint_cache_fw, + rigid_adjoint_cache_bw=self._rigid_adjoint_cache_bw, static_rigid_sim_config=self._static_rigid_sim_config, ) - if not is_grad_valid: - gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") + match errno: + case 1: + gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") + # case 2: + # qpos_diff = self._rigid_adjoint_cache_fw.qpos.to_numpy() - self._rigid_adjoint_cache_bw.qpos.to_numpy() + # vel_diff = self._rigid_adjoint_cache_fw.dofs_vel.to_numpy() - self._rigid_adjoint_cache_bw.dofs_vel.to_numpy() + # acc_diff = self._rigid_adjoint_cache_fw.dofs_acc.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc.to_numpy() + # acc_smooth_diff = self._rigid_adjoint_cache_fw.dofs_acc_smooth.to_numpy() - self._rigid_adjoint_cache_bw.dofs_acc_smooth.to_numpy() + # solver_qacc_ws_diff = self._rigid_adjoint_cache_fw.solver_qacc_ws.to_numpy() - 
self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy() + # gs.raise_exception(f"The backward computation result does not match the forward computation result at step {self._sim.cur_step_global}") kernel_step_2.grad( dofs_state=self.dofs_state, @@ -1241,68 +1357,190 @@ def substep_pre_coupling_grad(self, f): contact_island_state=self.constraint_solver.contact_island.contact_island_state, ) - # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, - # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). - # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc]. - # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through - # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation - # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel. - # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a - # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid - # the data access violation. However, both of these require major changes. 
- kernel_compute_qacc.grad( - dofs_state=self.dofs_state, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - ) + if not self._disable_constraint: + # Solver backward + dL_dqacc = self.dofs_state.acc.grad.to_numpy() + self.dofs_state.acc.grad.fill(0.0) + + qacc_ws = self._rigid_adjoint_cache_bw.solver_qacc_ws.to_numpy()[f] + self.constraint_solver.constraint_state.qacc_ws.from_numpy(qacc_ws) + + self.constraint_solver.backward(dL_dqacc) + dL_dM = self.constraint_solver.constraint_state.dL_dM.to_numpy() + dL_djac = self.constraint_solver.constraint_state.dL_djac.to_numpy() + dL_daref = self.constraint_solver.constraint_state.dL_daref.to_numpy() + dL_defc_D = self.constraint_solver.constraint_state.dL_defc_D.to_numpy() + dL_dforce = self.constraint_solver.constraint_state.dL_dforce.to_numpy() + + self._rigid_global_info.mass_mat.grad.from_numpy(dL_dM) + self.constraint_solver.constraint_state.jac.grad.from_numpy(dL_djac) + self.constraint_solver.constraint_state.aref.grad.from_numpy(dL_daref) + self.constraint_solver.constraint_state.efc_D.grad.from_numpy(dL_defc_D) + self.dofs_state.force.grad.from_numpy(dL_dforce) + + self.constraint_solver.constraint_state.n_constraints.fill(0) + self.constraint_solver.add_inequality_constraints_grad() + + # Collider backward + dL_dcontact_pos = self.collider._collider_state.contact_data.pos.grad.to_numpy() + dL_dcontact_normal = self.collider._collider_state.contact_data.normal.grad.to_numpy() + dL_dcontact_penetration = self.collider._collider_state.contact_data.penetration.grad.to_numpy() + self.collider.backward(dL_dcontact_pos, dL_dcontact_normal, dL_dcontact_penetration) + else: + # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, + # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). 
+ # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc]. + # As [kernel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through + # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation + # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel. + # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a + # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid + # the data access violation. However, both of these require major changes. + kernel_compute_qacc.grad( + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) kernel_copy_acc( f=f, dofs_state=self.dofs_state, - rigid_adjoint_cache=self._rigid_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache_fw, static_rigid_sim_config=self._static_rigid_sim_config, ) - kernel_forward_dynamics_without_qacc.grad( + kernel_bias_force.grad( + dofs_state=self.dofs_state, links_state=self.links_state, links_info=self.links_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_update_force.grad( + links_state=self.links_state, + links_info=self.links_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_update_acc.grad( dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - joints_info=self.joints_info, + links_info=self.links_info, + links_state=self.links_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_torque_and_passive_force.grad(
entities_state=self.entities_state, entities_info=self.entities_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, geoms_state=self.geoms_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, ) + kernel_factor_mass.grad( + implicit_damping=False, + entities_info=self.entities_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_compute_mass_matrix_ad.grad( + implicit_damping=self._static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast, + links_state=self.links_state, + links_info=self.links_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + + # kernel_forward_dynamics_without_qacc.grad( + # links_state=self.links_state, + # links_info=self.links_info, + # dofs_state=self.dofs_state, + # dofs_info=self.dofs_info, + # joints_info=self.joints_info, + # entities_state=self.entities_state, + # entities_info=self.entities_info, + # geoms_state=self.geoms_state, + # rigid_global_info=self._rigid_global_info, + # static_rigid_sim_config=self._static_rigid_sim_config, + # contact_island_state=self.constraint_solver.contact_island.contact_island_state, + # ) # If it was the very first substep, we need to backpropagate through the initial update of the cartesian space if self._enable_mujoco_compatibility or self._sim.cur_substep_global == 0: - kernel_forward_velocity.grad( - envs_idx=envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_info=self.joints_info, - dofs_state=self.dofs_state, + for i_l_ in 
range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_velocity_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_update_geoms.grad( + envs_idx, entities_info=self.entities_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + links_state=self.links_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, ) - kernel_update_cartesian_space.grad( + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_COM_links_ad_1.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_COM_links_ad_0.grad( links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, joints_info=self.joints_info, dofs_state=self.dofs_state, dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - force_update_fixed_geoms=False, ) + for i_l_ in range(self._static_rigid_sim_config.max_n_links_per_entity): + kernel_forward_kinematics_ad.grad( + i_l_=self._static_rigid_sim_config.max_n_links_per_entity - 1 - i_l_, + links_state=self.links_state, + links_info=self.links_info, + 
joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) # Change back to forward mode self._static_rigid_sim_config.is_backward = False @@ -1392,6 +1630,27 @@ def reset_grad(self): for entity in self._entities: entity.reset_grad() self._queried_states.clear() + self.zero_grad() + + def zero_grad(self): + # zero grad + for state in [ + self._rigid_global_info, + self.links_state, + self.dofs_state, + self.geoms_state, + self.joints_state, + self.entities_state, + self.constraint_solver.constraint_state, + self.collider._collider_state.diff_contact_input, + self.collider._collider_state.contact_data, + self._rigid_adjoint_cache_fw, + self._rigid_adjoint_cache_bw, + ]: + for attr in state.__dict__.values(): + if hasattr(attr, 'grad') and attr.grad is not None: + attr.grad.fill(0.0) + def update_geoms_render_T(self): kernel_update_geoms_render_T( @@ -1494,20 +1753,32 @@ def save_ckpt(self, ckpt_name): if ckpt_name not in self._ckpt: self._ckpt[ckpt_name] = dict() - self._ckpt[ckpt_name]["qpos"] = ti_to_numpy(self._rigid_adjoint_cache.qpos) - self._ckpt[ckpt_name]["dofs_vel"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_vel) - self._ckpt[ckpt_name]["dofs_acc"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_acc) + self._ckpt[ckpt_name]["qpos"] = ti_to_numpy(self._rigid_adjoint_cache_fw.qpos, copy=True) + self._ckpt[ckpt_name]["dofs_vel"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_vel, copy=True) + self._ckpt[ckpt_name]["dofs_acc"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_acc, copy=True) + self._ckpt[ckpt_name]["dofs_acc_smooth"] = ti_to_numpy(self._rigid_adjoint_cache_fw.dofs_acc_smooth, copy=True) + self._ckpt[ckpt_name]["solver_qacc_ws"] = 
ti_to_numpy(self._rigid_adjoint_cache_fw.solver_qacc_ws, copy=True) for entity in self._entities: entity.save_ckpt(ckpt_name) def load_ckpt(self, ckpt_name): + # Load adjoint cache for backward pass + self._rigid_adjoint_cache_bw.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"]) + self._rigid_adjoint_cache_bw.dofs_vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"]) + self._rigid_adjoint_cache_bw.dofs_acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"]) + self._rigid_adjoint_cache_bw.dofs_acc_smooth.from_numpy(self._ckpt[ckpt_name]["dofs_acc_smooth"]) + self._rigid_adjoint_cache_bw.solver_qacc_ws.from_numpy(self._ckpt[ckpt_name]["solver_qacc_ws"]) + # Set first frame self._rigid_global_info.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"][0]) self.dofs_state.vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"][0]) self.dofs_state.acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"][0]) + self.dofs_state.acc_smooth.from_numpy(self._ckpt[ckpt_name]["dofs_acc_smooth"][0]) + self.constraint_solver.constraint_state.qacc_ws.from_numpy(self._ckpt[ckpt_name]["solver_qacc_ws"][0]) if not self._enable_mujoco_compatibility: + envs_idx = self._scene._sanitize_envs_idx(None) kernel_update_cartesian_space( links_state=self.links_state, links_info=self.links_info, @@ -1522,9 +1793,47 @@ def load_ckpt(self, ckpt_name): static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=False, ) + kernel_forward_velocity( + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) for entity in self._entities: entity.load_ckpt(ckpt_name) + + def load_test(self): + if not self._enable_mujoco_compatibility: + envs_idx = self._scene._sanitize_envs_idx(None) + kernel_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + 
joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + ) + kernel_forward_velocity( + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) @property def is_active(self): @@ -2056,6 +2365,14 @@ def control_dofs_force(self, force, dofs_idx=None, envs_idx=None): kernel_control_dofs_force(force, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + def control_dofs_force_grad(self, dofs_idx, envs_idx, force_grad): + force_grad_, dofs_idx, envs_idx = self._sanitize_io_variables( + force_grad, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True + ) + if self.n_envs == 0: + force_grad_ = force_grad_.unsqueeze(0) + kernel_control_dofs_force_grad(force_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + def control_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None): if gs.use_zerocopy: mask = (0, *indices_to_mask(dofs_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, dofs_idx) @@ -2433,11 +2750,11 @@ def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True def clear_external_force(self): if gs.use_zerocopy: - for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel): + for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel, self.dofs_state.ctrl_force): out = ti_to_torch(tensor, copy=False) out.zero_() else: - kernel_clear_external_force(self.links_state, self._rigid_global_info, 
self._static_rigid_sim_config) + kernel_clear_external_force(self.links_state, self.dofs_state, self._rigid_global_info, self._static_rigid_sim_config) def update_vgeoms(self): kernel_update_vgeoms(self.vgeoms_info, self.vgeoms_state, self.links_state, self._static_rigid_sim_config) @@ -3340,6 +3657,27 @@ def func_vel_at_point(pos_world, link_idx, i_b, links_state: array_class.LinksSt vel_lin = links_state.cd_vel[link_idx, i_b] return vel_rot + vel_lin +@ti.kernel +def kernel_compute_mass_matrix_ad( + implicit_damping: ti.template(), + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_compute_mass_matrix( + implicit_damping=implicit_damping, + links_state=links_state, + links_info=links_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.func def func_compute_mass_matrix( @@ -3593,6 +3931,23 @@ def func_compute_mass_matrix( # qM += d qfrc_actuator / d qvel rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] +@ti.kernel +def kernel_factor_mass( + implicit_damping: ti.template(), + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_factor_mass( + implicit_damping=implicit_damping, + entities_info=entities_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.func def func_factor_mass( @@ -3864,7 +4219,7 @@ def func_solve_mass_batched( # Static inner loop for backward pass 
ti.static(range(static_rigid_sim_config.max_n_awake_entities)) if ti.static(static_rigid_sim_config.use_hibernation) - else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + else ti.static(range(static_rigid_sim_config.n_entities)) ) ): if func_check_index_range(i_0, 0, rigid_global_info.n_awake_entities[i_b], BW): @@ -4125,11 +4480,13 @@ def kernel_forward_dynamics_without_qacc( @ti.kernel(fastcache=gs.use_fastcache) def kernel_clear_external_force( links_state: array_class.LinksState, + dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): func_clear_external_force( links_state=links_state, + dofs_state=dofs_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, ) @@ -4169,9 +4526,8 @@ def kernel_update_cartesian_space( ) -@ti.func -def func_update_cartesian_space( - i_b, +@ti.kernel +def kernel_forward_kinematics( links_state: array_class.LinksState, links_info: array_class.LinksInfo, joints_state: array_class.JointsState, @@ -4185,64 +4541,160 @@ def func_update_cartesian_space( static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), ): - func_forward_kinematics( - i_b, - links_state=links_state, - links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, - dofs_state=dofs_state, - dofs_info=dofs_info, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - func_COM_links( - i_b, - links_state=links_state, - links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, - dofs_state=dofs_state, - dofs_info=dofs_info, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - func_update_geoms( - i_b=i_b, - entities_info=entities_info, - geoms_info=geoms_info, - geoms_state=geoms_state, - links_state=links_state, - 
rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - force_update_fixed_geoms=force_update_fixed_geoms, - ) - + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(links_state.pos.shape[1]): + func_forward_kinematics( + i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) -@ti.kernel(fastcache=gs.use_fastcache) -def kernel_step_1( +@ti.kernel +def kernel_forward_kinematics_ad( + i_l_:ti.int32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, joints_state: array_class.JointsState, joints_info: array_class.JointsInfo, dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, - geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, - entities_state: array_class.EntitiesState, + geoms_state: array_class.GeomsState, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - contact_island_state: array_class.ContactIslandState, + force_update_fixed_geoms: ti.template(), ): - if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): - _B = links_state.pos.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - func_update_cartesian_space( - i_b=i_b, + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]): + func_forward_kinematics_entity_ad( + i_e=i_e, + i_l_=i_l_, + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, 
+ rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@ti.kernel +def kernel_COM_links( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(links_state.pos.shape[1]): + func_COM_links( + i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@ti.func +def func_update_cartesian_space( + i_b, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), +): + func_forward_kinematics( + i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + func_COM_links( + i_b, + links_state=links_state, + links_info=links_info, + 
joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + func_update_geoms( + i_b=i_b, + entities_info=entities_info, + geoms_info=geoms_info, + geoms_state=geoms_state, + links_state=links_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + force_update_fixed_geoms=force_update_fixed_geoms, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_step_1( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_state: array_class.GeomsState, + geoms_info: array_class.GeomsInfo, + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + contact_island_state: array_class.ContactIslandState, +): + if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): + _B = links_state.pos.shape[1] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + func_update_cartesian_space( + i_b=i_b, links_state=links_state, links_info=links_info, joints_state=joints_state, @@ -4530,6 +4982,34 @@ def kernel_forward_velocity( ) +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_forward_velocity_ad( + i_l_:ti.int32, + envs_idx: ti.types.ndarray(), + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + for i_e, i_b_ in ti.ndrange(entities_info.n_links.shape[0], envs_idx.shape[0]): + 
i_b = envs_idx[i_b_] + func_forward_velocity_entity_ad( + i_e=i_e, + i_l_=i_l_, + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + @ti.func def func_COM_links( i_b, @@ -4863,6 +5343,240 @@ def func_COM_links( dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] ) +@ti.kernel +def kernel_COM_links_ad_0( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + links_state.root_COM_bw[i_l, i_b].fill(0.0) + links_state.mass_sum[i_l, i_b] = 0.0 + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] + ( + links_state.i_pos_bw[i_l, i_b], + links_state.i_quat[i_l, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], + links_info.inertial_quat[I_l], + links_state.pos[i_l, i_b], + links_state.quat[i_l, i_b], + ) + + i_r = links_info.root_idx[I_l] + links_state.mass_sum[i_r, i_b] += mass + links_state.root_COM_bw[i_r, i_b] += mass * links_state.i_pos_bw[i_l, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in 
ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + i_r = links_info.root_idx[I_l] + if i_l == i_r: + links_state.root_COM[i_l, i_b] = links_state.root_COM_bw[i_l, i_b] / links_state.mass_sum[i_l, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + i_r = links_info.root_idx[I_l] + links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + links_state.i_pos[i_l, i_b] = links_state.i_pos_bw[i_l, i_b] - links_state.root_COM[i_l, i_b] + + i_inertial = links_info.inertial_i[I_l] + i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] + ( + links_state.cinr_inertial[i_l, i_b], + links_state.cinr_pos[i_l, i_b], + links_state.cinr_quat[i_l, i_b], + links_state.cinr_mass[i_l, i_b], + ) = gu.ti_transform_inertia_by_trans_quat( + i_inertial, + i_mass, + links_state.i_pos[i_l, i_b], + links_state.i_quat[i_l, i_b], + rigid_global_info.EPS[None], + ) + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]): + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + BW = ti.static(static_rigid_sim_config.is_backward) + + if links_info.n_dofs[I_l] > 0: + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ if ti.static(not BW) else (i_j_ + 
links_info.joint_start[I_l]) + + if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW): + offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + + dof_start = joints_info.dof_start[I_j] + + EPS = rigid_global_info.EPS[None] + if joint_type == gs.JOINT_TYPE.REVOLUTE: + dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start, i_b] = xmat_T[i, :].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.FREE: + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + dof_start, i_b][i] = 1.0 + + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start + 3, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start + 3, i_b] = xmat_T[i, :].cross(offset_pos) + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start) + if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW): + dofs_state.cdofvel_ang[i_d, i_b] = ( + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + 
dofs_state.cdofvel_vel[i_d, i_b] = ( + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + +@ti.kernel +def kernel_COM_links_ad_1( + i_l_:ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]): + func_COM_links_ad_1( + i_e=i_e, + i_l_=i_l_, + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + +@ti.func +def func_COM_links_ad_1( + i_e: ti.int32, + i_l_:ti.int32, + i_b: ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + BW = ti.static(static_rigid_sim_config.is_backward) + i_l = i_l_ + entities_info.link_start[i_e] + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + if links_info.n_dofs[I_l] > 0: + i_p = links_info.parent_idx[I_l] + + _i_j = links_info.joint_start[I_l] + _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j + joint_type = joints_info.type[_I_j] + + p_pos = ti.Vector.zero(gs.ti_float, 3) + p_quat 
= gu.ti_identity_quat() + if i_p != -1: + p_pos = links_state.pos[i_p, i_b] + p_quat = links_state.quat[i_p, i_b] + + if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): + links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] + links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] + else: + ( + links_state.j_pos_bw[i_l, 0, i_b], + links_state.j_quat_bw[i_l, 0, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) + + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + for i_j_ in ( + range(n_joints) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] + + curr_i_j = 0 if ti.static(not BW) else i_j_ + next_i_j = 0 if ti.static(not BW) else i_j_ + 1 + + if func_check_index_range( + i_j, + links_info.joint_start[I_l], + links_info.joint_end[I_l], + BW, + ): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + + ( + links_state.j_pos_bw[i_l, next_i_j, i_b], + links_state.j_quat_bw[i_l, next_i_j, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + joints_info.pos[I_j], + gu.ti_identity_quat(), + links_state.j_pos_bw[i_l, curr_i_j, i_b], + links_state.j_quat_bw[i_l, curr_i_j, i_b], + ) + + i_j_ = 0 if ti.static(not BW) else n_joints + links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, i_j_, i_b] + links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, i_j_, i_b] @ti.func def func_forward_kinematics( @@ -4889,7 +5603,7 @@ def func_forward_kinematics( else ( ti.static(range(static_rigid_sim_config.max_n_awake_entities)) if ti.static(static_rigid_sim_config.use_hibernation) - else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + else ti.static(range(static_rigid_sim_config.n_entities)) ) ): if func_check_index_range( @@ -4941,7 +5655,7 @@ def func_forward_velocity( # Static inner loop for backward pass 
ti.static(range(static_rigid_sim_config.max_n_awake_entities)) if ti.static(static_rigid_sim_config.use_hibernation) - else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + else ti.static(range(static_rigid_sim_config.n_entities)) ) ): if func_check_index_range( @@ -5155,6 +5869,160 @@ def func_forward_kinematics_entity( links_state.pos[i_l, i_b] = R(links_state.pos_bw, I_jf, pos, BW) links_state.quat[i_l, i_b] = R(links_state.quat_bw, I_jf, quat, BW) +@ti.func +def func_forward_kinematics_entity_ad( + i_e, + i_l_, + i_b, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + BW = ti.static(static_rigid_sim_config.is_backward) + W = ti.static(func_write_field_if) + R = ti.static(func_read_field_if) + WR = ti.static(func_write_and_read_field_if) + + EPS = rigid_global_info.EPS[None] + + i_l = i_l_ + entities_info.link_start[i_e] + + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + I_l0 = (i_l, 0, i_b) + + pos = W(links_state.pos_bw, I_l0, links_info.pos[I_l], BW) + quat = W(links_state.quat_bw, I_l0, links_info.quat[I_l], BW) + if links_info.parent_idx[I_l] != -1: + parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b] + parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b] + pos_ = parent_pos + gu.ti_transform_by_quat(links_info.pos[I_l], parent_quat) + quat_ = gu.ti_transform_quat_by_quat(links_info.quat[I_l], parent_quat) + + pos = W(links_state.pos_bw, I_l0, pos_, BW) + quat = W(links_state.quat_bw, I_l0, quat_, BW) + + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + for i_j_ in ( + range(n_joints) + if ti.static(not BW) + 
else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] + + curr_I = (i_l, 0 if ti.static(not BW) else i_j_, i_b) + next_I = (i_l, 0 if ti.static(not BW) else i_j_ + 1, i_b) + + if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] + I_d = [dof_start, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start + + # compute axis and anchor + if joint_type == gs.JOINT_TYPE.FREE: + joints_state.xanchor[i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ] + ) + joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + else: + axis = ti.Vector([0.0, 0.0, 1.0], dt=gs.ti_float) + if joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + axis = dofs_info.motion_vel[I_d] + + pos_ = R(links_state.pos_bw, curr_I, pos, BW) + quat_ = R(links_state.quat_bw, curr_I, quat, BW) + + joints_state.xanchor[i_j, i_b] = gu.ti_transform_by_quat(joints_info.pos[I_j], quat_) + pos_ + joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat(axis, quat_) + + if joint_type == gs.JOINT_TYPE.FREE: + pos_ = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ], + dt=gs.ti_float, + ) + quat_ = ti.Vector( + [ + rigid_global_info.qpos[q_start + 3, i_b], + rigid_global_info.qpos[q_start + 4, i_b], + rigid_global_info.qpos[q_start + 5, i_b], + rigid_global_info.qpos[q_start + 6, i_b], + ], + dt=gs.ti_float, + ) + pos = WR(links_state.pos_bw, next_I, pos_, BW) + quat = 
WR(links_state.quat_bw, next_I, quat_, BW) + + xyz = gu.ti_quat_to_xyz(quat, EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = pos[j] + dofs_state.pos[dof_start + 3 + j, i_b] = xyz[j] + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + print("SPHERICAL") + qloc = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[q_start + 3, i_b], + ], + dt=gs.ti_float, + ) + xyz = gu.ti_quat_to_xyz(qloc, EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = xyz[j] + quat_ = gu.ti_transform_quat_by_quat(qloc, R(links_state.quat_bw, curr_I, quat, BW)) + quat = WR(links_state.quat_bw, next_I, quat_, BW) + pos_ = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos = W(links_state.pos_bw, next_I, pos_, BW) + elif joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b], EPS) + quat_ = gu.ti_transform_quat_by_quat(qloc, R(links_state.quat_bw, curr_I, quat, BW)) + quat = WR(links_state.quat_bw, next_I, quat_, BW) + pos_ = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos = W(links_state.pos_bw, next_I, pos_, BW) + else: # joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + pos_ = ( + R(links_state.pos_bw, curr_I, pos, BW) + + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] + ) + pos = W(links_state.pos_bw, next_I, pos_, BW) + + # Skip link pose update for fixed root links to let users manually overwrite them + I_jf = (i_l, 0 if ti.static(not BW) else n_joints, i_b) + if not 
(links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): + links_state.pos[i_l, i_b] = R(links_state.pos_bw, I_jf, pos, BW) + links_state.quat[i_l, i_b] = R(links_state.quat_bw, I_jf, quat, BW) + @ti.func def func_forward_velocity_entity( @@ -5281,6 +6149,127 @@ def func_forward_velocity_entity( links_state.cd_vel[i_l, i_b] = R(links_state.cd_vel_bw, I_jf, cvel_vel, BW) links_state.cd_ang[i_l, i_b] = R(links_state.cd_ang_bw, I_jf, cvel_ang, BW) +@ti.func +def func_forward_velocity_entity_ad( + i_e, + i_l_, + i_b, + entities_info: array_class.EntitiesInfo, + links_info: array_class.LinksInfo, + links_state: array_class.LinksState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + BW = ti.static(static_rigid_sim_config.is_backward) + W = ti.static(func_write_field_if) + R = ti.static(func_read_field_if) + A = ti.static(func_atomic_add_if) + + i_l = i_l_ + entities_info.link_start[i_e] + + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + I_j0 = (i_l, 0, i_b) + cvel_vel = W(links_state.cd_vel_bw, I_j0, ti.Vector.zero(gs.ti_float, 3), BW) + cvel_ang = W(links_state.cd_ang_bw, I_j0, ti.Vector.zero(gs.ti_float, 3), BW) + + if links_info.parent_idx[I_l] != -1: + cvel_vel = W(links_state.cd_vel_bw, I_j0, links_state.cd_vel[links_info.parent_idx[I_l], i_b], BW) + cvel_ang = W(links_state.cd_ang_bw, I_j0, links_state.cd_ang[links_info.parent_idx[I_l], i_b], BW) + + for i_j_ in ( + range(n_joints) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] + + if func_check_index_range(i_j, links_info.joint_start[I_l], links_info.joint_end[I_l], BW): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) 
else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] + + curr_I = (i_l, 0 if ti.static(not BW) else i_j_, i_b) + next_I = (i_l, 0 if ti.static(not BW) else i_j_ + 1, i_b) + + if joint_type == gs.JOINT_TYPE.FREE: + for i_3 in ti.static(range(3)): + _vel = dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + _ang = dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + + cvel_vel = cvel_vel + A(links_state.cd_vel_bw, curr_I, _vel, BW) + cvel_ang = cvel_ang + A(links_state.cd_ang_bw, curr_I, _ang, BW) + + for i_3 in ti.static(range(3)): + ( + dofs_state.cdofd_ang[dof_start + i_3, i_b], + dofs_state.cdofd_vel[dof_start + i_3, i_b], + ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3) + + ( + dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b], + ) = gu.motion_cross_motion( + R(links_state.cd_ang_bw, curr_I, cvel_ang, BW), + R(links_state.cd_vel_bw, curr_I, cvel_vel, BW), + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b], + ) + + if ti.static(BW): + links_state.cd_vel_bw[next_I] = links_state.cd_vel_bw[curr_I] + links_state.cd_ang_bw[next_I] = links_state.cd_ang_bw[curr_I] + + for i_3 in ti.static(range(3)): + _vel = ( + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + _ang = ( + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + cvel_vel = cvel_vel + A(links_state.cd_vel_bw, next_I, _vel, BW) + cvel_ang = cvel_ang + A(links_state.cd_ang_bw, next_I, _ang, BW) + + else: + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start) + if func_check_index_range(i_d, dof_start, 
joints_info.dof_end[I_j], BW): + dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( + R(links_state.cd_ang_bw, curr_I, cvel_ang, BW), + R(links_state.cd_vel_bw, curr_I, cvel_vel, BW), + dofs_state.cdof_ang[i_d, i_b], + dofs_state.cdof_vel[i_d, i_b], + ) + + if ti.static(BW): + links_state.cd_vel_bw[next_I] = links_state.cd_vel_bw[curr_I] + links_state.cd_ang_bw[next_I] = links_state.cd_ang_bw[curr_I] + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not BW) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not BW) else (i_d_ + dof_start) + if func_check_index_range(i_d, dof_start, joints_info.dof_end[I_j], BW): + _vel = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + _ang = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + cvel_vel = cvel_vel + A(links_state.cd_vel_bw, next_I, _vel, BW) + cvel_ang = cvel_ang + A(links_state.cd_ang_bw, next_I, _ang, BW) + + I_jf = (i_l, 0 if ti.static(not BW) else n_joints, i_b) + links_state.cd_vel[i_l, i_b] = R(links_state.cd_vel_bw, I_jf, cvel_vel, BW) + links_state.cd_ang[i_l, i_b] = R(links_state.cd_ang_bw, I_jf, cvel_ang, BW) + @ti.kernel(fastcache=gs.use_fastcache) def kernel_update_geoms( @@ -5358,7 +6347,9 @@ def func_update_geoms( ) ): i_g = i_1 + entities_info.geom_start[i_e] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 - if func_check_index_range(i_g, entities_info.geom_start[i_e], entities_info.geom_end[i_e], BW): + if func_check_index_range( + i_g, entities_info.geom_start[i_e], entities_info.geom_end[i_e], static_rigid_sim_config.use_hibernation + ): if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: ( geoms_state.pos[i_g, i_b], @@ -5717,6 +6708,7 @@ def func_apply_link_external_torque( @ti.func def func_clear_external_force( links_state: array_class.LinksState, + dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, 
static_rigid_sim_config: ti.template(), ): @@ -5735,7 +6727,38 @@ def func_clear_external_force( i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for I in ti.grouped(dofs_state.ctrl_force): + dofs_state.ctrl_force[I] = ti.Vector.zero(gs.ti_float, 3) +@ti.kernel +def kernel_torque_and_passive_force( + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_info: array_class.JointsInfo, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + contact_island_state: array_class.ContactIslandState, +): + func_torque_and_passive_force( + entities_state=entities_state, + entities_info=entities_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + links_state=links_state, + links_info=links_info, + joints_info=joints_info, + geoms_state=geoms_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + contact_island_state=contact_island_state, + ) @ti.func def func_torque_and_passive_force( @@ -6055,6 +7078,21 @@ def func_update_acc( BW, ) +@ti.kernel +def kernel_update_force( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_update_force( + links_state=links_state, + links_info=links_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, 
+ ) @ti.func def func_update_force( @@ -6195,6 +7233,22 @@ def func_actuation(self): self.dofs_state.act_length[i_d, i_b] = 0.0 self.dofs_state.qf_actuator[i_d, i_b] = self.dofs_state.act_length[i_d, i_b] +@ti.kernel +def kernel_bias_force( + dofs_state: array_class.DofsState, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_bias_force( + dofs_state=dofs_state, + links_state=links_state, + links_info=links_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + @ti.func def func_bias_force( @@ -6525,17 +7579,19 @@ def func_copy_next_to_curr_grad( def kernel_save_adjoint_cache( f: ti.int32, dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), ): - func_save_adjoint_cache(f, dofs_state, rigid_global_info, rigid_adjoint_cache, static_rigid_sim_config) + func_save_adjoint_cache(f, dofs_state, constraint_state, rigid_global_info, rigid_adjoint_cache, static_rigid_sim_config) @ti.func def func_save_adjoint_cache( f: ti.int32, dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), @@ -6548,6 +7604,8 @@ def func_save_adjoint_cache( for i_d, i_b in ti.ndrange(n_dofs, _B): rigid_adjoint_cache.dofs_vel[f, i_d, i_b] = dofs_state.vel[i_d, i_b] rigid_adjoint_cache.dofs_acc[f, i_d, i_b] = dofs_state.acc[i_d, i_b] + rigid_adjoint_cache.dofs_acc_smooth[f, i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + rigid_adjoint_cache.solver_qacc_ws[f, i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_q, i_b 
in ti.ndrange(n_qs, _B): @@ -6558,6 +7616,7 @@ def func_save_adjoint_cache( def func_load_adjoint_cache( f: ti.int32, dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), @@ -6570,6 +7629,8 @@ def func_load_adjoint_cache( for i_d, i_b in ti.ndrange(n_dofs, _B): dofs_state.vel[i_d, i_b] = rigid_adjoint_cache.dofs_vel[f, i_d, i_b] dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b] + dofs_state.acc_smooth[i_d, i_b] = rigid_adjoint_cache.dofs_acc_smooth[f, i_d, i_b] + constraint_state.qacc_ws[i_d, i_b] = rigid_adjoint_cache.solver_qacc_ws[f, i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_q, i_b in ti.ndrange(n_qs, _B): @@ -6587,6 +7648,7 @@ def kernel_prepare_backward_substep( dofs_info: array_class.DofsInfo, geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, + constraint_state: array_class.ConstraintState, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, dofs_state_adjoint_cache: array_class.DofsState, @@ -6600,6 +7662,7 @@ def kernel_prepare_backward_substep( func_load_adjoint_cache( f=f, dofs_state=dofs_state, + constraint_state=constraint_state, rigid_global_info=rigid_global_info, rigid_adjoint_cache=rigid_adjoint_cache, static_rigid_sim_config=static_rigid_sim_config, @@ -6625,6 +7688,16 @@ def kernel_prepare_backward_substep( static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, ) + func_forward_velocity( + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) # FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names. 
# Save results of [update_cartesian_space] to adjoint cache @@ -6658,20 +7731,29 @@ def kernel_begin_backward_substep( links_state_adjoint_cache: array_class.LinksState, joints_state_adjoint_cache: array_class.JointsState, geoms_state_adjoint_cache: array_class.GeomsState, - rigid_adjoint_cache: array_class.RigidAdjointCache, + rigid_adjoint_cache_fw: array_class.RigidAdjointCache, + rigid_adjoint_cache_bw: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), ) -> ti.i32: + errno = 0 is_grad_valid = func_is_grad_valid( rigid_global_info=rigid_global_info, dofs_state=dofs_state, static_rigid_sim_config=static_rigid_sim_config, ) - if is_grad_valid: + # Check the integrity of the next frame's adjoint cache as it is the computation result of the current frame + is_cache_valid = func_check_cache_integrity( + f=f + 1, + rigid_adjoint_cache_fw=rigid_adjoint_cache_fw, + rigid_adjoint_cache_bw=rigid_adjoint_cache_bw, + static_rigid_sim_config=static_rigid_sim_config, + ) + if is_grad_valid and is_cache_valid: func_copy_next_to_curr_grad( f=f, dofs_state=dofs_state, rigid_global_info=rigid_global_info, - rigid_adjoint_cache=rigid_adjoint_cache, + rigid_adjoint_cache=rigid_adjoint_cache_fw, static_rigid_sim_config=static_rigid_sim_config, ) @@ -6689,8 +7771,12 @@ def kernel_begin_backward_substep( geoms_state_adjoint_cache=geoms_state_adjoint_cache, static_rigid_sim_config=static_rigid_sim_config, ) + elif not is_grad_valid: + errno = 1 + elif not is_cache_valid: + errno = 2 - return is_grad_valid + return errno @ti.func @@ -6713,6 +7799,37 @@ def func_is_grad_valid( return is_valid +@ti.func +def func_check_cache_integrity( + f: ti.int32, + rigid_adjoint_cache_fw: array_class.RigidAdjointCache, + rigid_adjoint_cache_bw: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + is_valid = True + n_qs = rigid_adjoint_cache_fw.qpos.shape[1] + n_dofs = rigid_adjoint_cache_fw.dofs_vel.shape[1] + _B = rigid_adjoint_cache_fw.qpos.shape[2] + + 
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(n_qs, _B): + if rigid_adjoint_cache_fw.qpos[f, i_q, i_b] != rigid_adjoint_cache_bw.qpos[f, i_q, i_b]: + is_valid = False + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + if rigid_adjoint_cache_fw.dofs_vel[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_vel[f, i_d, i_b]: + is_valid = False + if rigid_adjoint_cache_fw.dofs_acc[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_acc[f, i_d, i_b]: + is_valid = False + if rigid_adjoint_cache_fw.dofs_acc_smooth[f, i_d, i_b] != rigid_adjoint_cache_bw.dofs_acc_smooth[f, i_d, i_b]: + is_valid = False + if rigid_adjoint_cache_fw.solver_qacc_ws[f, i_d, i_b] != rigid_adjoint_cache_bw.solver_qacc_ws[f, i_d, i_b]: + is_valid = False + + return is_valid + + @ti.func def func_copy_cartesian_space( dofs_state: array_class.DofsState, @@ -7586,6 +8703,19 @@ def kernel_control_dofs_force( dofs_state.ctrl_mode[dofs_idx[i_d_], envs_idx[i_b_]] = gs.CTRL_MODE.FORCE dofs_state.ctrl_force[dofs_idx[i_d_], envs_idx[i_b_]] = force[i_b_, i_d_] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_control_dofs_force_grad( + force_grad: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + dofs_state: array_class.DofsState, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + force_grad[i_b_, i_d_] = dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] + dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 + @ti.kernel(fastcache=gs.use_fastcache) def kernel_control_dofs_velocity( @@ -7848,3 +8978,14 @@ def func_write_and_read_field_if(field: array_class.V_ANNOTATION, I, value, cond def func_check_index_range(idx: ti.i32, min: ti.i32, max: ti.i32, cond: 
ti.template()): # Conditionally check if the index is in the range [min, max) to save computational cost return (idx >= min and idx < max) if ti.static(cond) else True + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_zero_grad( + links_state: array_class.LinksState, + dofs_state: array_class.DofsState, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + pass \ No newline at end of file diff --git a/genesis/engine/states/entities.py b/genesis/engine/states/entities.py index b6ee1f6dd1..06ce22a501 100644 --- a/genesis/engine/states/entities.py +++ b/genesis/engine/states/entities.py @@ -204,13 +204,19 @@ def __init__(self, entity, s_global): scene = self._entity.scene self._pos = gs.zeros((num_batch, 3), dtype=float, requires_grad=requires_grad, scene=scene) self._quat = gs.zeros((num_batch, 4), dtype=float, requires_grad=requires_grad, scene=scene) + self._qpos = gs.zeros((num_batch, entity.n_qs), dtype=float, requires_grad=requires_grad, scene=scene) + self._dofs_vel = gs.zeros((num_batch, entity.n_dofs), dtype=float, requires_grad=requires_grad, scene=scene) + self._dofs_acc = gs.zeros((num_batch, entity.n_dofs), dtype=float, requires_grad=requires_grad, scene=scene) def serializable(self): self._entity = None self._pos = self._pos.detach() self._quat = self._quat.detach() - + self._qpos = self._qpos.detach() + self._dofs_vel = self._dofs_vel.detach() + self._dofs_acc = self._dofs_acc.detach() + @property def entity(self): return self._entity @@ -226,3 +232,15 @@ def pos(self): @property def quat(self): return self._quat + + @property + def qpos(self): + return self._qpos + + @property + def dofs_vel(self): + return self._dofs_vel + + @property + def dofs_acc(self): + return self._dofs_acc \ No newline at end of file diff --git a/genesis/options/solvers.py b/genesis/options/solvers.py index 67da878c2c..5315b7b61c 100644 --- a/genesis/options/solvers.py +++ 
b/genesis/options/solvers.py @@ -36,6 +36,8 @@ class SimOptions(Options): Height of the floor in meters. Defaults to 0.0. requires_grad : bool, optional Whether to enable differentiable mode. Defaults to False. + grad_window_steps : int, optional + Number of steps that constitute a window for gradient computation. Defaults to None (window not used). use_hydroelastic_contact : bool, optional Whether to use hydroelastic contact. Defaults to False. """ @@ -46,6 +48,7 @@ class SimOptions(Options): gravity: tuple = (0.0, 0.0, -9.81) floor_height: float = 0.0 requires_grad: bool = False + grad_window_steps: Optional[int] = None # not set by user _steps_local: Optional[int] = None diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index f0baaca6b0..3bc2eea24c 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -291,15 +291,15 @@ def get_constraint_state(constraint_solver, solver): efc_AR=V(dtype=gs.ti_float, shape=efc_AR_shape), active=V(dtype=gs.ti_bool, shape=(len_constraints_, _B)), prev_active=V(dtype=gs.ti_bool, shape=(len_constraints_, _B)), - diag=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), - aref=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), + diag=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), + aref=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), Jaref=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), - efc_frictionloss=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), + efc_frictionloss=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), efc_force=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), - efc_D=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), + efc_D=V(dtype=gs.ti_float, shape=(len_constraints_, _B), needs_grad=solver._requires_grad), jv=V(dtype=gs.ti_float, shape=(len_constraints_, _B)), quad=V(dtype=gs.ti_float, shape=(len_constraints_, 3, _B)), - jac=V(dtype=gs.ti_float,
shape=jac_shape), + jac=V(dtype=gs.ti_float, shape=jac_shape, needs_grad=solver._requires_grad), jac_relevant_dofs=V(dtype=gs.ti_int, shape=jac_relevant_dofs_shape), jac_n_relevant_dofs=V(dtype=gs.ti_int, shape=jac_n_relevant_dofs_shape), # Backward gradients @@ -1776,6 +1776,11 @@ class StructRigidAdjointCache(metaclass=BASE_METACLASS): qpos: V_ANNOTATION dofs_vel: V_ANNOTATION dofs_acc: V_ANNOTATION + # We also store the initial solutions (acc_smooth, qacc_ws) for the constraint solver to use in the backward pass. + # For [acc_smooth], even though it could be reproduced during the backward pass and thus we do not need to store it, + # we do it to compare the reproduced value with the stored one to ensure the integrity of the backward pass. + dofs_acc_smooth: V_ANNOTATION + solver_qacc_ws: V_ANNOTATION def get_rigid_adjoint_cache(solver): @@ -1786,6 +1791,8 @@ def get_rigid_adjoint_cache(solver): qpos=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_qs_, solver._B), needs_grad=requires_grad), dofs_vel=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), dofs_acc=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), + dofs_acc_smooth=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B)), + solver_qacc_ws=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B)), ) @@ -1818,6 +1825,9 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta): max_n_geoms_per_entity: int = -1 n_links: int = -1 n_geoms: int = -1 + n_dofs: int = -1 + n_entities: int = -1 + max_contact_pairs: int = -1 # =========================================== DataManager =========================================== @@ -1862,7 +1872,10 @@ def __init__(self, solver): self.joints_state_adjoint_cache = get_joints_state(solver) self.geoms_state_adjoint_cache = get_geoms_state(solver) - self.rigid_adjoint_cache = get_rigid_adjoint_cache(solver) + # We use a 
pair of adjoint cache, one of which is used for the forward pass and the other is used for the + # backward pass to check the integrity of the backward pass. + self.rigid_adjoint_cache_fw = get_rigid_adjoint_cache(solver) + self.rigid_adjoint_cache_bw = get_rigid_adjoint_cache(solver) self.errno = V_SCALAR_FROM(dtype=gs.ti_int, value=0) diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py index e1a052c8e3..20397402b4 100644 --- a/genesis/utils/geom.py +++ b/genesis/utils/geom.py @@ -82,6 +82,7 @@ def ti_rotvec_to_R(rotvec, eps): @ti.func def ti_rotvec_to_quat(rotvec, eps): quat = ti.Vector.zero(gs.ti_float, 4) + res = ti.Vector.zero(gs.ti_float, 4) # We need to use [norm_sqr] instead of [norm] to avoid nan gradients in the backward pass. Even when theta = 0, # the gradient of [norm] operation is computed and used (note that the gradient becomes NaN when theta = 0). This @@ -98,11 +99,12 @@ def ti_rotvec_to_quat(rotvec, eps): quat[i + 1] = xyz[i] # First order quaternion normalization is accurate enough yet necessary - quat *= 0.5 * (3.0 - quat.norm_sqr()) + # quat *= 0.5 * (3.0 - quat.norm_sqr()) + res = quat * 0.5 * (3.0 - quat.norm_sqr()) else: - quat[0] = 1.0 + res[0] = 1.0 - return quat + return res @ti.func @@ -221,7 +223,12 @@ def ti_transform_quat_by_quat(v, u): This is equivalent to quatmul(quat_u, quat_v) or R_u @ R_v """ vec = ti_quat_mul(u, v) - return vec.normalized() + res = ti.Vector([1.0, 0.0, 0.0, 0.0], dt=gs.ti_float) + vec_norm_sqr = vec.norm_sqr() + if vec_norm_sqr > gs.EPS ** 2: + res = vec / (ti.sqrt(vec_norm_sqr) + gs.EPS) + return res + # return vec.normalized() @ti.func @@ -239,7 +246,7 @@ def ti_transform_by_quat(v, quat): v.x * (-2.0 * q_wy + 2.0 * q_xz) + v.y * (2.0 * q_wx + 2.0 * q_yz) + v.z * (q_ww - q_xx - q_yy + q_zz), ], dt=gs.ti_float, - ) / (q_ww + q_xx + q_yy + q_zz) + ) / (q_ww + q_xx + q_yy + q_zz + gs.EPS) @ti.func