From 2f28acf3a1c45656495febcd9ba73e27cd8081e9 Mon Sep 17 00:00:00 2001 From: Nithin Tatikonda Date: Tue, 3 Mar 2026 14:28:22 -0800 Subject: [PATCH] Internal PiperOrigin-RevId: 878126445 --- grain/_src/python/dataset/dataset.py | 11 +++++++++-- grain/_src/python/dataset/transformations/prefetch.py | 8 +++++++- .../dataset/transformations/process_prefetch.py | 9 +++++++-- grain/_src/python/dataset/transformations/repeat.py | 2 +- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index e506508b1..0ff646b53 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -1739,17 +1739,24 @@ class _OutputIterDataset(IterDataset[T]): def __iter__(self) -> DatasetIterator[T]: """Performs any injection that needs to happen at the end of the pipeline.""" - # Loaded lazily due to a circular dependency (dataset <-> prefetch). + # Loaded lazily due to a circular dependency (dataset <-> prefetch and + # dataset <-> process_prefetch). # pylint: disable=g-import-not-at-top from grain._src.python.dataset.transformations import prefetch + from grain._src.python.dataset.transformations import process_prefetch # pylint: enable=g-import-not-at-top iterator = self._parent.__iter__() if ( is_thread_prefetch_injection_enabled() and not iterator._ctx.is_dataloader_pipeline # pylint: disable=protected-access + and not process_prefetch.is_in_worker_process() ): if not prefetch.is_prefetch_iterator(iterator): - iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1) + try: + iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1) + except AttributeError: + # Some legacy iterators do not implement the `get_state` method. + pass filter_mode = traceback_filter_mode() if filter_mode != "off": diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index b2c309a92..a9cde4843 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -454,7 +454,7 @@ def _put_iterator_elements_in_buffer( try: while not should_stop.is_set(): element = stats.record_bytes_consumed(iterator.__next__()) - state = iterator.get_state() + state = copy.deepcopy(iterator.get_state()) buffer.put((element, state, None)) except Exception as e: # pylint: disable=broad-except buffer.put((None, None, e)) @@ -736,11 +736,17 @@ def multithread_prefetch( def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool: """Returns whether the iterator is a prefetch iterator.""" + # Loaded lazily due to a circular dependency (prefetch <-> process_prefetch). + # pylint: disable=g-import-not-at-top + from grain._src.python.dataset.transformations import process_prefetch + # pylint: enable=g-import-not-at-top + return isinstance( it, ( PrefetchDatasetIterator, ThreadPrefetchDatasetIterator, interleave.InterleaveDatasetIterator, + process_prefetch.ProcessPrefetchDatasetIterator, ), ) diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index 8d5a3fc27..3b033efcc 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -57,6 +57,11 @@ _is_in_worker_process = False +def is_in_worker_process() -> bool: + """Returns whether the current process is a worker process.""" + return _is_in_worker_process + + def _run_all(fns: Sequence[Callable[[], None]]): for fn in fns: fn() @@ -165,7 +170,7 @@ def __str__(self) -> str: return f"ProcessPrefetchIterDataset(buffer_size={self._buffer_size})" def __iter__(self) -> dataset.DatasetIterator[T]: - return _ProcessPrefetchDatasetIterator( + return ProcessPrefetchDatasetIterator( self._parent, self._buffer_size, self._worker_init_fn, @@ -291,7 +296,7 @@ def _close_if_alive( iterator.close() -class _ProcessPrefetchDatasetIterator(dataset.DatasetIterator[T]): +class ProcessPrefetchDatasetIterator(dataset.DatasetIterator[T]): """Iterator that performs prefetching using a background process.""" def __init__( diff --git a/grain/_src/python/dataset/transformations/repeat.py b/grain/_src/python/dataset/transformations/repeat.py index 4921cab46..e1d2cc550 100644 --- a/grain/_src/python/dataset/transformations/repeat.py +++ b/grain/_src/python/dataset/transformations/repeat.py @@ -102,7 +102,7 @@ def __init__( to_visit = [self] while to_visit: node = to_visit.pop(0) - if isinstance(node, process_prefetch._ProcessPrefetchDatasetIterator): # pylint: disable=protected-access + if isinstance(node, process_prefetch.ProcessPrefetchDatasetIterator): node.set_keep_workers_after_stop_iteration(True) if isinstance(node, interleave.InterleaveDatasetIterator): node.set_keep_iterators_after_stop_iteration(True)