Skip to content

Commit 23faebb

Browse files
Internal
PiperOrigin-RevId: 878126445
1 parent 5080952 commit 23faebb

3 files changed

Lines changed: 20 additions & 2 deletions

File tree

grain/_src/python/dataset/dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1740,16 +1740,23 @@ def __iter__(self) -> DatasetIterator[T]:
17401740
"""Performs any injection that needs to happen at the end of the pipeline."""
17411741

17421742
# Loaded lazily due to a circular dependency (dataset <-> prefetch).
1743+
# Loaded lazily due to a circular dependency (dataset <-> process_prefetch).
17431744
# pylint: disable=g-import-not-at-top
17441745
from grain._src.python.dataset.transformations import prefetch
1746+
from grain._src.python.dataset.transformations import process_prefetch
17451747
# pylint: enable=g-import-not-at-top
17461748
iterator = self._parent.__iter__()
17471749
if (
17481750
is_thread_prefetch_injection_enabled()
17491751
and not iterator._ctx.is_dataloader_pipeline # pylint: disable=protected-access
1752+
and not process_prefetch.is_in_worker_process()
17501753
):
17511754
if not prefetch.is_prefetch_iterator(iterator):
1752-
iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1)
1755+
try:
1756+
iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1)
1757+
except AttributeError:
1758+
# Some legacy iterators do not implement the `get_state` method.
1759+
pass
17531760

17541761
filter_mode = traceback_filter_mode()
17551762
if filter_mode != "off":

grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def _put_iterator_elements_in_buffer(
454454
try:
455455
while not should_stop.is_set():
456456
element = stats.record_bytes_consumed(iterator.__next__())
457-
state = iterator.get_state()
457+
state = copy.deepcopy(iterator.get_state())
458458
buffer.put((element, state, None))
459459
except Exception as e: # pylint: disable=broad-except
460460
buffer.put((None, None, e))
@@ -736,11 +736,17 @@ def multithread_prefetch(
736736

737737
def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool:
738738
"""Returns whether the iterator is a prefetch iterator."""
739+
# Loaded lazily due to a circular dependency (prefetch <-> process_prefetch).
740+
# pylint: disable=g-import-not-at-top
741+
from grain._src.python.dataset.transformations import process_prefetch
742+
# pylint: enable=g-import-not-at-top
743+
739744
return isinstance(
740745
it,
741746
(
742747
PrefetchDatasetIterator,
743748
ThreadPrefetchDatasetIterator,
744749
interleave.InterleaveDatasetIterator,
750+
process_prefetch._ProcessPrefetchDatasetIterator, # pylint: disable=protected-access
745751
),
746752
)

grain/_src/python/dataset/transformations/process_prefetch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
_is_in_worker_process = False
5858

5959

60+
def is_in_worker_process() -> bool:
61+
"""Returns whether the current process is a worker process."""
62+
return _is_in_worker_process
63+
64+
6065
def _run_all(fns: Sequence[Callable[[], None]]):
6166
for fn in fns:
6267
fn()

0 commit comments

Comments
 (0)