diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD
index 8fcd0bbc..0b1aa92e 100644
--- a/grain/_src/python/dataset/BUILD
+++ b/grain/_src/python/dataset/BUILD
@@ -57,6 +57,7 @@ py_library(
         "//grain/_src/python/checkpoint:base",
         "//grain/_src/python/ipc:queue",
         "//grain/_src/python/ipc:shared_memory_array",
+        "//grain/_src/python/ipc:variable_size_queue",
         "//grain/proto:execution_summary_py_pb2",
         "@abseil-py//absl/flags",
         "@abseil-py//absl/logging",
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index a9cde484..008a5160 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -36,6 +36,7 @@
 from grain._src.python.dataset.transformations import filter as filter_dataset
 from grain._src.python.dataset.transformations import interleave
 from grain._src.python.dataset.transformations import source
+from grain._src.python.ipc import variable_size_queue
 
 
 T = TypeVar("T")
@@ -289,7 +290,7 @@ def _set_num_threads(self, num_threads: int) -> None:
       self._executor = futures.ThreadPoolExecutor(
           self._target_num_threads, thread_name_prefix="grain-prefetch"
       )
-    else:
+    elif hasattr(self, "_executor"):
       delattr(self, "_executor")
     if old_executor is not None:
       # Allows the old executor to finish running the tasks it was already
@@ -411,13 +412,14 @@
   def __init__(
       self,
       parent: dataset.IterDataset[T],
       *,
-      prefetch_buffer_size: int,
+      prefetch_buffer_size: int | bindings.AutotuneParameter,
   ):
     super().__init__(parent)
-    if prefetch_buffer_size < 0:
+    target_prefetch_buffer_size = prefetch_buffer_size
+    if target_prefetch_buffer_size < 0:
       raise ValueError(
           "`prefetch_buffer_size` must be greater than or equal to 0, got "
-          f"{prefetch_buffer_size}."
+          f"{target_prefetch_buffer_size}."
       )
     self._prefetch_buffer_size = prefetch_buffer_size
@@ -429,8 +431,6 @@ def __str__(self) -> str:
 
   def __iter__(self) -> dataset.DatasetIterator[T]:
     parent_iter = self._parent.__iter__()
-    if self._prefetch_buffer_size == 0:
-      return parent_iter
     return ThreadPrefetchDatasetIterator(
         parent_iter, self._prefetch_buffer_size
     )
@@ -442,11 +442,13 @@ def _element_spec(self) -> Any:
 
 # Type for the iterator state.
 StateT = dict[str, Any]
+# Type for the buffer elements.
+BufferElementT = tuple[T, StateT, Exception | None]
 
 
 def _put_iterator_elements_in_buffer(
     iterator: dataset.DatasetIterator[T],
-    buffer: queue.Queue[tuple[T, StateT, Exception | None]],
+    buffer: queue.Queue[BufferElementT],
     should_stop: threading.Event,
     stats: dataset_stats.Stats,
 ):
@@ -478,7 +480,7 @@ class ThreadPrefetchDatasetIterator(dataset.DatasetIterator[T]):
   def __init__(
       self,
       parent: CheckpointableIterator[T],
-      prefetch_buffer_size: int,
+      prefetch_buffer_size: int | bindings.AutotuneParameter,
   ):
     if isinstance(parent, dataset.DatasetIterator):
       super().__init__(parent)
@@ -486,17 +488,28 @@ def __init__(
       super().__init__()
     self._maybe_nonnative_parent = parent
-    assert prefetch_buffer_size > 0, prefetch_buffer_size
-    self._prefetch_buffer_size = prefetch_buffer_size
+    target_prefetch_buffer_size = prefetch_buffer_size
+    autotune_buffer_size = None
+
+    assert target_prefetch_buffer_size >= 0, target_prefetch_buffer_size
+    self._target_prefetch_buffer_size = target_prefetch_buffer_size
+    self.autotune_buffer_size = autotune_buffer_size
     self._step_zero_state: StateT = parent.get_state()
     self._state: StateT | None = self._step_zero_state
     self._next_index: int | None = 0
     self._prefetch_thread: threading.Thread | None = None
     self._prefetch_should_stop: threading.Event = threading.Event()
-    self._buffer: queue.Queue[tuple[T, StateT, Exception | None]] = queue.Queue(
-        maxsize=self._prefetch_buffer_size
-    )
+    if self.autotune_buffer_size is not None:
+      self._buffer: (
+          variable_size_queue.VariableSizeQueue | queue.Queue[BufferElementT]
+      ) = variable_size_queue.VariableSizeQueue(
+          max_size=self._target_prefetch_buffer_size
+      )
+    else:
+      self._buffer: (
+          variable_size_queue.VariableSizeQueue | queue.Queue[BufferElementT]
+      ) = queue.Queue(maxsize=self._target_prefetch_buffer_size)
 
     # pytype: disable=attribute-error
     # pylint: disable=protected-access
 
 
@@ -530,7 +543,10 @@ def start_prefetch(self):
     """
     if self._closed:
       raise ValueError("Attempting to use a closed iterator.")
-    if self._prefetch_thread is not None:
+    if (
+        self._prefetch_thread is not None
+        or self._target_prefetch_buffer_size == 0
+    ):
       return
 
     self._prefetch_should_stop.clear()
@@ -552,10 +568,22 @@ def start_prefetch(self):
         stage_category=dataset_stats.IPL_CAT_PREFETCH
     )
   def __next__(self):
     timer = dataset_stats.Timer()
     with timer:
-      self.start_prefetch()
-      element, state, err = self._buffer.get()
+      if self._target_prefetch_buffer_size > 0:
+        self.start_prefetch()
+        element, state, err = self._buffer.get()
+      else:
+        try:
+          # In case of 0 prefetch buffer size, we still try to get from the
+          # buffer as it could have been populated when the prefetch buffer size
+          # was greater than 0.
+          element, state, err = self._buffer.get_nowait()
+        except queue.Empty:
+          element = self._maybe_nonnative_parent.__next__()
+          state = copy.deepcopy(self._maybe_nonnative_parent.get_state())
+          err = None
+
     if err is not None:
       self._stop_prefetch()
 
@@ -581,15 +609,21 @@ def _clear_buffer(self):
     except queue.Empty:
       return
 
-  def _stop_prefetch(self):
+  def _stop_prefetch(self, clear_buffer: bool = True):
     """Stops the prefetching thread if it's currently running."""
    if self._prefetch_thread is None:
       return
 
     self._prefetch_should_stop.set()
-    # Remove entries from the buffer to unblock the producer, so that it checks
-    # producer_running.is_set() and exits.
-    self._clear_buffer()
+    if clear_buffer:
+      # Remove entries from the buffer to unblock the producer, so that it
+      # checks producer_running.is_set() and exits.
+      self._clear_buffer()
+    else:
+      assert isinstance(self._buffer, variable_size_queue.VariableSizeQueue)
+      # Increase the buffer size by 1 to unblock the producer.
+      self._buffer.set_max_size(self._target_prefetch_buffer_size + 1)  # pytype: disable=attribute-error
+
     if not sys.is_finalizing():
       # Joining the worker thread is not necessary when the Python interpreter
       # is shutting down. Attempting to join can lead to hanging in Python
@@ -597,9 +631,11 @@
       # https://github.com/python/cpython/issues/123940#issuecomment-2976446309
       self._prefetch_thread.join()
     self._prefetch_thread = None
-    # Clear the buffer again in case the prefetch loop added more elements on
-    # exit.
-    self._clear_buffer()
+
+    if clear_buffer:
+      # Clear the buffer again in case the prefetch loop added more elements
+      # on exit.
+      self._clear_buffer()
 
   def get_state(self) -> StateT:
     if self._state is not None:
@@ -648,7 +684,7 @@ def _set_next_index(self, next_index: int):
 
   def __str__(self) -> str:
     return (
         "ThreadPrefetchDatasetIterator("
-        f"prefetch_buffer_size={self._prefetch_buffer_size})"
+        f"prefetch_buffer_size={self._target_prefetch_buffer_size})"
     )