Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ py_library(
"//grain/_src/python/checkpoint:base",
"//grain/_src/python/ipc:queue",
"//grain/_src/python/ipc:shared_memory_array",
"//grain/_src/python/ipc:variable_size_queue",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/flags",
"@abseil-py//absl/logging",
Expand Down
84 changes: 60 additions & 24 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from grain._src.python.dataset.transformations import filter as filter_dataset
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import source
from grain._src.python.ipc import variable_size_queue

T = TypeVar("T")

Expand Down Expand Up @@ -289,7 +290,7 @@ def _set_num_threads(self, num_threads: int) -> None:
self._executor = futures.ThreadPoolExecutor(
self._target_num_threads, thread_name_prefix="grain-prefetch"
)
else:
elif hasattr(self, "_executor"):
delattr(self, "_executor")
if old_executor is not None:
# Allows the old executor to finish running the tasks it was already
Expand Down Expand Up @@ -411,13 +412,14 @@ def __init__(
self,
parent: dataset.IterDataset[T],
*,
prefetch_buffer_size: int,
prefetch_buffer_size: int | bindings.AutotuneParameter,
):
super().__init__(parent)
if prefetch_buffer_size < 0:
target_prefetch_buffer_size = prefetch_buffer_size
if target_prefetch_buffer_size < 0:
raise ValueError(
"`prefetch_buffer_size` must be greater than or equal to 0, got "
f"{prefetch_buffer_size}."
f"{target_prefetch_buffer_size}."
)
self._prefetch_buffer_size = prefetch_buffer_size

Expand All @@ -429,8 +431,6 @@ def __str__(self) -> str:

def __iter__(self) -> dataset.DatasetIterator[T]:
parent_iter = self._parent.__iter__()
if self._prefetch_buffer_size == 0:
return parent_iter
return ThreadPrefetchDatasetIterator(
parent_iter, self._prefetch_buffer_size
)
Expand All @@ -442,11 +442,13 @@ def _element_spec(self) -> Any:

# Type for the iterator state.
StateT = dict[str, Any]
# Type for the buffer elements.
BufferElementT = tuple[T, StateT, Exception | None]


def _put_iterator_elements_in_buffer(
iterator: dataset.DatasetIterator[T],
buffer: queue.Queue[tuple[T, StateT, Exception | None]],
buffer: queue.Queue[BufferElementT],
should_stop: threading.Event,
stats: dataset_stats.Stats,
):
Expand Down Expand Up @@ -478,25 +480,36 @@ class ThreadPrefetchDatasetIterator(dataset.DatasetIterator[T]):
def __init__(
self,
parent: CheckpointableIterator[T],
prefetch_buffer_size: int,
prefetch_buffer_size: int | bindings.AutotuneParameter,
):
if isinstance(parent, dataset.DatasetIterator):
super().__init__(parent)
else:
super().__init__()
self._maybe_nonnative_parent = parent

assert prefetch_buffer_size > 0, prefetch_buffer_size
self._prefetch_buffer_size = prefetch_buffer_size
target_prefetch_buffer_size = prefetch_buffer_size
autotune_buffer_size = None

assert target_prefetch_buffer_size >= 0, target_prefetch_buffer_size
self._target_prefetch_buffer_size = target_prefetch_buffer_size
self.autotune_buffer_size = autotune_buffer_size
self._step_zero_state: StateT = parent.get_state()
self._state: StateT | None = self._step_zero_state
self._next_index: int | None = 0

self._prefetch_thread: threading.Thread | None = None
self._prefetch_should_stop: threading.Event = threading.Event()
self._buffer: queue.Queue[tuple[T, StateT, Exception | None]] = queue.Queue(
maxsize=self._prefetch_buffer_size
)
if self.autotune_buffer_size is not None:
self._buffer: (
variable_size_queue.VariableSizeQueue | queue.Queue[BufferElementT]
) = variable_size_queue.VariableSizeQueue(
max_size=self._target_prefetch_buffer_size
)
else:
self._buffer: (
variable_size_queue.VariableSizeQueue | queue.Queue[BufferElementT]
) = queue.Queue(maxsize=self._target_prefetch_buffer_size)

# pytype: disable=attribute-error
# pylint: disable=protected-access
Expand Down Expand Up @@ -530,7 +543,10 @@ def start_prefetch(self):
"""
if self._closed:
raise ValueError("Attempting to use a closed iterator.")
if self._prefetch_thread is not None:
if (
self._prefetch_thread is not None
or self._target_prefetch_buffer_size == 0
):
return

self._prefetch_should_stop.clear()
Expand All @@ -552,10 +568,22 @@ def start_prefetch(self):
stage_category=dataset_stats.IPL_CAT_PREFETCH
)
def __next__(self):

timer = dataset_stats.Timer()
with timer:
self.start_prefetch()
element, state, err = self._buffer.get()
if self._target_prefetch_buffer_size > 0:
self.start_prefetch()
element, state, err = self._buffer.get()
else:
try:
# In case of 0 prefetch buffer size, we still try to get from the
# buffer as it could have been populated when the prefetch buffer size
# was greater than 0.
element, state, err = self._buffer.get_nowait()
except queue.Empty:
element = self._maybe_nonnative_parent.__next__()
state = copy.deepcopy(self._maybe_nonnative_parent.get_state())
err = None

if err is not None:
self._stop_prefetch()
Expand All @@ -581,25 +609,33 @@ def _clear_buffer(self):
except queue.Empty:
return

def _stop_prefetch(self):
def _stop_prefetch(self, clear_buffer: bool = True):
"""Stops the prefetching thread if it's currently running."""
if self._prefetch_thread is None:
return

self._prefetch_should_stop.set()
# Remove entries from the buffer to unblock the producer, so that it checks
# producer_running.is_set() and exits.
self._clear_buffer()
if clear_buffer:
# Remove entries from the buffer to unblock the producer, so that it
# checks producer_running.is_set() and exits.
self._clear_buffer()
else:
assert isinstance(self._buffer, variable_size_queue.VariableSizeQueue)
# Increase the buffer size by 1 to unblock the producer.
self._buffer.set_max_size(self._target_prefetch_buffer_size + 1) # pytype: disable=attribute-error

if not sys.is_finalizing():
# Joining the worker thread is not necessary when the Python interpreter
# is shutting down. Attempting to join can lead to hanging in Python
# 3.13 as daemon threads can hang during interpreter shutdown. See
# https://github.com/python/cpython/issues/123940#issuecomment-2976446309
self._prefetch_thread.join()
self._prefetch_thread = None
# Clear the buffer again in case the prefetch loop added more elements on
# exit.
self._clear_buffer()

if clear_buffer:
# Clear the buffer again in case the prefetch loop added more elements
# on exit.
self._clear_buffer()

def get_state(self) -> StateT:
if self._state is not None:
Expand Down Expand Up @@ -648,7 +684,7 @@ def _set_next_index(self, next_index: int):
def __str__(self) -> str:
return (
"ThreadPrefetchDatasetIterator("
f"prefetch_buffer_size={self._prefetch_buffer_size})"
f"prefetch_buffer_size={self._target_prefetch_buffer_size})"
)


Expand Down
Loading