diff --git a/get_started/obj_layout/object_layout_task.py b/get_started/obj_layout/object_layout_task.py index fdc75cad9..909e13961 100644 --- a/get_started/obj_layout/object_layout_task.py +++ b/get_started/obj_layout/object_layout_task.py @@ -427,6 +427,9 @@ class Args: ## Step timing min_step_time: float = 0.001 + ## Device + device: str = "cuda" + def __post_init__(self): log.info(f"Args: {self}") @@ -476,7 +479,7 @@ def __post_init__(self): # Create task environment tic = time.time() - device = torch.device("cuda") + device = torch.device(args.device) env = task_cls(scenario, device=device) # Optionally wrap with Viser visualization