From 8a97d38cd0fa497a522480a225f000f05e2d92e6 Mon Sep 17 00:00:00 2001 From: Nithin Tatikonda Date: Tue, 10 Mar 2026 11:36:52 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 881542746 --- .../_src/python/dataset/transformations/interleave.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 9a387ae07..9db05ba90 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -336,10 +336,12 @@ def _add_prefetch_and_make_iterator( raise RuntimeError("InterleaveDatasetIterator has been garbage collected.") if isinstance(ds, dataset.MapDataset): # Prefetch is automatically added in `MapDataset.__iter__`. - return ds.__iter__() - iterator = prefetch.ThreadPrefetchIterDataset( - ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access - ).__iter__() + iter_dataset = ds.to_iter_dataset() + else: + iter_dataset = prefetch.ThreadPrefetchIterDataset( + ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access + ) + iterator = iter_dataset.__iter__() # Propagate options applied after InterleaveIterDataset to the iterators that # are being interleaved. iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options) # pylint: disable=protected-access