diff --git a/src/openpi/training/droid_rlds_dataset.py b/src/openpi/training/droid_rlds_dataset.py index ebe863acf6..d0f7f63995 100644 --- a/src/openpi/training/droid_rlds_dataset.py +++ b/src/openpi/training/droid_rlds_dataset.py @@ -203,7 +203,8 @@ def decode_images(traj): dataset = dataset.frame_map(decode_images, num_parallel_calls) # Shuffle, batch - dataset = dataset.shuffle(shuffle_buffer_size) + if shuffle: + dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.batch(batch_size) # Note =>> Seems to reduce memory usage without affecting speed? dataset = dataset.with_ram_budget(1)