From 3b4b6a36264584f9c3e07234a846c402c0eb01c0 Mon Sep 17 00:00:00 2001 From: Kevin Zakka Date: Wed, 24 Apr 2024 18:08:55 -0700 Subject: [PATCH] Make deterministic an option. --- dlimp/dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dlimp/dataset.py b/dlimp/dataset.py index 75dd5f3..2b9ea23 100644 --- a/dlimp/dataset.py +++ b/dlimp/dataset.py @@ -57,11 +57,11 @@ def __getattribute__(self, name): return _wrap(attr, None) return attr - def _apply_options(self): + def _apply_options(self, deterministic: bool): """Applies some default options for performance.""" options = tf.data.Options() options.autotune.enabled = True - options.deterministic = False + options.deterministic = deterministic options.experimental_optimization.apply_default_optimizations = True options.experimental_optimization.map_fusion = True options.experimental_optimization.map_and_filter_fusion = True @@ -130,6 +130,7 @@ def from_rlds( split: str = "train", shuffle: bool = True, num_parallel_reads: int = tf.data.AUTOTUNE, + deterministic: bool = False, ) -> "DLataset": """Creates a DLataset from the RLDS format (which is a special case of the TFDS format). @@ -150,7 +151,7 @@ def from_rlds( num_parallel_calls_for_interleave_files=num_parallel_reads, interleave_cycle_length=num_parallel_reads, ), - )._apply_options() + )._apply_options(deterministic=deterministic) dataset = dataset.enumerate().traj_map(_broadcast_metadata_rlds)