Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,17 +1739,24 @@ class _OutputIterDataset(IterDataset[T]):
def __iter__(self) -> DatasetIterator[T]:
"""Performs any injection that needs to happen at the end of the pipeline."""

# Loaded lazily due to a circular dependency (dataset <-> prefetch).
# Loaded lazily due to a circular dependency (dataset <-> prefetch and
# dataset <-> process_prefetch).
# pylint: disable=g-import-not-at-top
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.dataset.transformations import process_prefetch
# pylint: enable=g-import-not-at-top
iterator = self._parent.__iter__()
if (
is_thread_prefetch_injection_enabled()
and not iterator._ctx.is_dataloader_pipeline # pylint: disable=protected-access
and not process_prefetch.is_in_worker_process()
):
if not prefetch.is_prefetch_iterator(iterator):
iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1)
try:
iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1)
except AttributeError:
# Some legacy iterators do not implement the `get_state` method.
pass

filter_mode = traceback_filter_mode()
if filter_mode != "off":
Expand Down
8 changes: 7 additions & 1 deletion grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def _put_iterator_elements_in_buffer(
try:
while not should_stop.is_set():
element = stats.record_bytes_consumed(iterator.__next__())
state = iterator.get_state()
state = copy.deepcopy(iterator.get_state())
buffer.put((element, state, None))
except Exception as e: # pylint: disable=broad-except
buffer.put((None, None, e))
Expand Down Expand Up @@ -736,11 +736,17 @@ def multithread_prefetch(

def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool:
  """Checks whether `it` counts as a prefetch iterator.

  Args:
    it: The iterator to inspect.

  Returns:
    True if `it` is an instance of `PrefetchDatasetIterator`,
    `ThreadPrefetchDatasetIterator`, `interleave.InterleaveDatasetIterator` or
    `process_prefetch.ProcessPrefetchDatasetIterator`; False otherwise.
  """
  # Imported at call time instead of at module level because prefetch and
  # process_prefetch depend on each other.
  # pylint: disable=g-import-not-at-top
  from grain._src.python.dataset.transformations import process_prefetch
  # pylint: enable=g-import-not-at-top

  prefetch_iterator_types = (
      PrefetchDatasetIterator,
      ThreadPrefetchDatasetIterator,
      interleave.InterleaveDatasetIterator,
      process_prefetch.ProcessPrefetchDatasetIterator,
  )
  return isinstance(it, prefetch_iterator_types)
9 changes: 7 additions & 2 deletions grain/_src/python/dataset/transformations/process_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
_is_in_worker_process = False


def is_in_worker_process() -> bool:
  """Reports the current value of the module-level worker-process flag.

  Returns:
    The value of `_is_in_worker_process`, i.e. whether the current process is
    a worker process.
  """
  in_worker = _is_in_worker_process
  return in_worker


def _run_all(fns: Sequence[Callable[[], None]]):
  """Invokes every zero-argument callable in `fns` once, in sequence order.

  Args:
    fns: The callables to invoke. An exception raised by one of them
      propagates to the caller and the remaining callables are not invoked.
  """
  for callback in fns:
    callback()
Expand Down Expand Up @@ -165,7 +170,7 @@ def __str__(self) -> str:
return f"ProcessPrefetchIterDataset(buffer_size={self._buffer_size})"

def __iter__(self) -> dataset.DatasetIterator[T]:
return _ProcessPrefetchDatasetIterator(
return ProcessPrefetchDatasetIterator(
self._parent,
self._buffer_size,
self._worker_init_fn,
Expand Down Expand Up @@ -291,7 +296,7 @@ def _close_if_alive(
iterator.close()


class _ProcessPrefetchDatasetIterator(dataset.DatasetIterator[T]):
class ProcessPrefetchDatasetIterator(dataset.DatasetIterator[T]):
"""Iterator that performs prefetching using a background process."""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/transformations/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(
to_visit = [self]
while to_visit:
node = to_visit.pop(0)
if isinstance(node, process_prefetch._ProcessPrefetchDatasetIterator): # pylint: disable=protected-access
if isinstance(node, process_prefetch.ProcessPrefetchDatasetIterator):
node.set_keep_workers_after_stop_iteration(True)
if isinstance(node, interleave.InterleaveDatasetIterator):
node.set_keep_iterators_after_stop_iteration(True)
Expand Down
Loading