diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index fc6f553bb..534c99cfc 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -40,6 +40,7 @@ Guidelines for modifications: * Xinjie Wang * Xinying Guo * Yikai Tang +* Yixiao Huang * Yongce Liu * Yu Hong * Yuchen Huang diff --git a/roboverse_learn/il/dp/models/ddpm_image_policy.py b/roboverse_learn/il/dp/models/ddpm_image_policy.py index d1ec3a38d..d54abd275 100644 --- a/roboverse_learn/il/dp/models/ddpm_image_policy.py +++ b/roboverse_learn/il/dp/models/ddpm_image_policy.py @@ -101,6 +101,10 @@ def conditional_sample( # Set diffusion steps. scheduler.set_timesteps(self.num_inference_steps) + # Ensure timesteps are on the same device as trajectory + device = trajectory.device + scheduler.timesteps = scheduler.timesteps.to(device=device) + step_kwargs = dict(self.scheduler_step_kwargs) step_kwargs.update(kwargs) @@ -109,7 +113,7 @@ def conditional_sample( trajectory[condition_mask] = condition_data[condition_mask] # 2. Predict model output. - t = t.to(device=trajectory.device) + # t = t.to(device=trajectory.device) model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond) # 3. Compute previous sample x_t -> x_{t-1}.