File tree Expand file tree Collapse file tree
grain/_src/python/dataset Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1740,16 +1740,23 @@ def __iter__(self) -> DatasetIterator[T]:
17401740 """Performs any injection that needs to happen at the end of the pipeline."""
17411741
17421742 # Loaded lazily due to a circular dependency (dataset <-> prefetch).
1743+ # Loaded lazily due to a circular dependency (dataset <-> process_prefetch).
17431744 # pylint: disable=g-import-not-at-top
17441745 from grain ._src .python .dataset .transformations import prefetch
1746+ from grain ._src .python .dataset .transformations import process_prefetch
17451747 # pylint: enable=g-import-not-at-top
17461748 iterator = self ._parent .__iter__ ()
17471749 if (
17481750 is_thread_prefetch_injection_enabled ()
17491751 and not iterator ._ctx .is_dataloader_pipeline # pylint: disable=protected-access
1752+ and not process_prefetch .is_in_worker_process ()
17501753 ):
17511754 if not prefetch .is_prefetch_iterator (iterator ):
1752- iterator = prefetch .ThreadPrefetchDatasetIterator (iterator , 1 )
1755+ try :
1756+ iterator = prefetch .ThreadPrefetchDatasetIterator (iterator , 1 )
1757+ except AttributeError :
1758+ # Some legacy iterators do not implement the `get_state` method.
1759+ pass
17531760
17541761 filter_mode = traceback_filter_mode ()
17551762 if filter_mode != "off" :
Original file line number Diff line number Diff line change @@ -454,7 +454,7 @@ def _put_iterator_elements_in_buffer(
454454 try :
455455 while not should_stop .is_set ():
456456 element = stats .record_bytes_consumed (iterator .__next__ ())
457- state = iterator .get_state ()
457+ state = copy . deepcopy ( iterator .get_state () )
458458 buffer .put ((element , state , None ))
459459 except Exception as e : # pylint: disable=broad-except
460460 buffer .put ((None , None , e ))
@@ -736,11 +736,17 @@ def multithread_prefetch(
736736
def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool:
  """Reports whether `it` is one of the known prefetching iterator types.

  Args:
    it: The dataset iterator to inspect.

  Returns:
    True if `it` is an instance of `PrefetchDatasetIterator`,
    `ThreadPrefetchDatasetIterator`, `interleave.InterleaveDatasetIterator` or
    `process_prefetch._ProcessPrefetchDatasetIterator`; False otherwise.
  """
  # Imported inside the function because of a circular dependency
  # (prefetch <-> process_prefetch).
  # pylint: disable=g-import-not-at-top
  from grain._src.python.dataset.transformations import process_prefetch
  # pylint: enable=g-import-not-at-top

  prefetch_iterator_types = (
      PrefetchDatasetIterator,
      ThreadPrefetchDatasetIterator,
      interleave.InterleaveDatasetIterator,
      process_prefetch._ProcessPrefetchDatasetIterator,  # pylint: disable=protected-access
  )
  return isinstance(it, prefetch_iterator_types)
Original file line number Diff line number Diff line change 5757_is_in_worker_process = False
5858
5959
def is_in_worker_process() -> bool:
  """Reports whether the current process is a worker process.

  Returns:
    The current value of the module-level `_is_in_worker_process` flag.
  """
  return _is_in_worker_process
63+
64+
6065def _run_all (fns : Sequence [Callable [[], None ]]):
6166 for fn in fns :
6267 fn ()
You can’t perform that action at this time.
0 commit comments