diff --git a/dlimp/dataset.py b/dlimp/dataset.py index 75dd5f3..2b9ea23 100644 --- a/dlimp/dataset.py +++ b/dlimp/dataset.py @@ -57,11 +57,11 @@ def __getattribute__(self, name): return _wrap(attr, None) return attr - def _apply_options(self): + def _apply_options(self, deterministic: bool): """Applies some default options for performance.""" options = tf.data.Options() options.autotune.enabled = True - options.deterministic = False + options.deterministic = deterministic options.experimental_optimization.apply_default_optimizations = True options.experimental_optimization.map_fusion = True options.experimental_optimization.map_and_filter_fusion = True @@ -130,6 +130,7 @@ def from_rlds( split: str = "train", shuffle: bool = True, num_parallel_reads: int = tf.data.AUTOTUNE, + deterministic: bool = False, ) -> "DLataset": """Creates a DLataset from the RLDS format (which is a special case of the TFDS format). @@ -150,7 +151,7 @@ def from_rlds( num_parallel_calls_for_interleave_files=num_parallel_reads, interleave_cycle_length=num_parallel_reads, ), - )._apply_options() + )._apply_options(deterministic=deterministic) dataset = dataset.enumerate().traj_map(_broadcast_metadata_rlds)