diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 9a387ae07..9db05ba90 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -336,10 +336,12 @@ def _add_prefetch_and_make_iterator( raise RuntimeError("InterleaveDatasetIterator has been garbage collected.") if isinstance(ds, dataset.MapDataset): # Prefetch is automatically added in `MapDataset.__iter__`. - return ds.__iter__() - iterator = prefetch.ThreadPrefetchIterDataset( - ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access - ).__iter__() + iter_dataset = ds.to_iter_dataset() + else: + iter_dataset = prefetch.ThreadPrefetchIterDataset( + ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access + ) + iterator = iter_dataset.__iter__() # Propagate options applied after InterleaveIterDataset to the iterators that # are being interleaved. iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options) # pylint: disable=protected-access