Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,20 @@ def _create_dataset(self) -> dataset.IterDataset:
ds = ds.to_iter_dataset(self._read_options)
for operation in self._operations:
ds = _apply_transform_to_dataset(operation, ds)
from grain._src.python import execution_backend
backend = execution_backend.get_execution_backend()
if self.multiprocessing_options.num_workers > 0 and isinstance(
ds, dataset_base.SupportsSharedMemoryOutput
):
ds.enable_shared_memory_output()
if backend.is_multiprocess():
ds.enable_shared_memory_output()
else:
import warnings
warnings.warn(
"Shared memory fallback is active (ThreadingBackend). "
"This will result in severely degraded performance for large data volumes "
"because memory buffers must be repeatedly copied across thread boundaries."
)
ds = ds.map(lambda r: r.data)
if self.multiprocessing_options.num_workers > 0:
ds = ds.mp_prefetch(self.multiprocessing_options)
Expand Down
100 changes: 69 additions & 31 deletions grain/_src/python/dataset/transformations/process_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from multiprocessing import synchronize
from multiprocessing import util
import queue
import sys
import threading
import time
from typing import Any, TypeVar
import weakref
Expand All @@ -39,6 +41,7 @@
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.ipc import queue as grain_queue
from grain._src.python.ipc import shared_memory_array
from grain._src.python import execution_backend


T = TypeVar("T")
Expand All @@ -54,7 +57,12 @@
# Timeout for getting an element from the worker process.
_QUEUE_WAIT_TIMEOUT_S = 1

_is_in_worker_process = False
_worker_state = threading.local()
_worker_state.is_in_worker_process = False

def _get_is_in_worker_process() -> bool:
return getattr(_worker_state, "is_in_worker_process", False)



def _run_all(fns: Sequence[Callable[[], None]]):
Expand Down Expand Up @@ -192,16 +200,17 @@ def _put_dataset_elements_in_buffer(
debug_flags: dict[str, Any],
):
"""Prefetches elements in a separate process."""
global _is_in_worker_process
_is_in_worker_process = True
_worker_state.is_in_worker_process = True
try:
parse_debug_flags_fn = cloudpickle.loads(pickled_parse_debug_flags_fn)
parse_debug_flags_fn = cloudpickle.loads(pickled_parse_debug_flags_fn) if isinstance(pickled_parse_debug_flags_fn, bytes) else pickled_parse_debug_flags_fn
parse_debug_flags_fn(debug_flags)
worker_init_fn = cloudpickle.loads(pickled_worker_init_fn)
worker_init_fn = cloudpickle.loads(pickled_worker_init_fn) if isinstance(pickled_worker_init_fn, bytes) else pickled_worker_init_fn
if worker_init_fn is not None:
worker_init_fn()
ds = cloudpickle.loads(pickled_ds)
if isinstance(ds, base.SupportsSharedMemoryOutput):
ds = cloudpickle.loads(pickled_ds) if isinstance(pickled_ds, bytes) else pickled_ds

backend = execution_backend.get_execution_backend()
if backend.is_multiprocess() and isinstance(ds, base.SupportsSharedMemoryOutput):
ds.enable_shared_memory_output()
it = ds.__iter__()
min_shm_size = it._ctx.dataset_options.min_shm_size # pylint: disable=protected-access
Expand All @@ -215,7 +224,14 @@ def _put_dataset_elements_in_buffer(
with set_state_request_count.get_lock():
if set_state_request_count.value > 0:
set_state_request_count.value -= 1
new_state_or_index = set_state_queue.get()
while not should_stop.is_set():
try:
new_state_or_index = set_state_queue.get(timeout=_QUEUE_WAIT_TIMEOUT_S)
break
except queue.Empty:
pass
if new_state_or_index is None:
continue
parent_exhausted = False
if new_state_or_index is not None:
if not grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types
Expand Down Expand Up @@ -243,8 +259,10 @@ def _put_dataset_elements_in_buffer(
)
parent_exhausted = True
continue
element = shared_memory_array.copy_to_shm(element, min_size=min_shm_size)
if backend.is_multiprocess():
element = shared_memory_array.copy_to_shm(element, min_size=min_shm_size)
# If the node is prefetch, we already record the bytes produced in it's

# __next__ method.
if not it._stats._config.is_prefetch: # pylint: disable=protected-access
it._stats.record_bytes_produced(element) # pylint: disable=protected-access
Expand Down Expand Up @@ -310,22 +328,18 @@ def __init__(
# propagate them.
self._ctx.dataset_options = _get_dataset_options(parent)

self._process_ctx = mp.get_context("spawn")
self._backend = execution_backend.get_execution_backend()
self._state: StateT | None = None
self._prefetch_process: Any | None = None
self._prefetch_should_stop: synchronize.Event = self._process_ctx.Event()
self._set_state_request_count: sharedctypes.Synchronized[int] = (
self._process_ctx.Value("i", 0)
)
self._set_state_queue: queues.Queue[StateT | int] = self._process_ctx.Queue(
5
)
self._prefetch_should_stop: synchronize.Event | threading.Event = self._backend.Event()
self._set_state_request_count = self._backend.SynchronizedInt(0)
self._set_state_queue: queues.Queue[StateT | int] | queue.Queue = self._backend.Queue(5)
self._buffer: queues.Queue[
tuple[T, StateT | None, int | None, Exception | None]
] = self._process_ctx.Queue(maxsize=self._buffer_size)
self._stats_in_queue = self._process_ctx.Queue(maxsize=5)
self._start_profiling_event = self._process_ctx.Event()
self._stop_profiling_event = self._process_ctx.Event()
] | queue.Queue = self._backend.Queue(maxsize=self._buffer_size)
self._stats_in_queue = self._backend.Queue(maxsize=5)
self._start_profiling_event = self._backend.Event()
self._stop_profiling_event = self._backend.Event()
self._set_state_count = 0
self._exhausted = False
self._prefetch_ds_iter = None
Expand Down Expand Up @@ -392,12 +406,23 @@ def start_prefetch(self) -> None:
self._iter_parent,
options=self._ctx.dataset_options,
)
self._prefetch_process = self._process_ctx.Process(
import os
if self._backend.is_multiprocess() or os.environ.get("GRAIN_STRICT_PICKLING", "0") == "1":
pickled_ds = _serialize_dataset(ds)
pickled_parse = cloudpickle.dumps(_parse_debug_flags)
pickled_init = cloudpickle.dumps(self._worker_init_fn) if self._worker_init_fn is not None else None
else:
pickled_parse = _parse_debug_flags
pickled_init = self._worker_init_fn
pickled_ds = ds

self._prefetch_process = self._backend.Process(
target=_put_dataset_elements_in_buffer,
kwargs=dict(
pickled_parse_debug_flags_fn=cloudpickle.dumps(_parse_debug_flags),
pickled_worker_init_fn=cloudpickle.dumps(self._worker_init_fn),
pickled_ds=_serialize_dataset(ds),
pickled_parse_debug_flags_fn=pickled_parse,
pickled_worker_init_fn=pickled_init,
pickled_ds=pickled_ds,

buffer=self._buffer,
should_stop=self._prefetch_should_stop,
set_state_request_count=self._set_state_request_count,
Expand All @@ -412,7 +437,7 @@ def start_prefetch(self) -> None:
),
),
),
daemon=True,
daemon=True if self._backend.is_multiprocess() else False,
name=f"grain-process-prefetch-{str(self)}",
)
self._prefetch_process.start()
Expand All @@ -421,7 +446,11 @@ def start_prefetch(self) -> None:
def _process_failed(self) -> bool:
if self._prefetch_process is None:
return False
exit_code = self._prefetch_process.exitcode
# threading.Thread doesn't have an exitcode attribute generally, but we added it to our wrapper.
if hasattr(self._prefetch_process, 'exitcode'):
exit_code = self._prefetch_process.exitcode
else:
exit_code = getattr(self._prefetch_process, 'exitcode', 0)
return exit_code is not None and exit_code != 0

@dataset_stats.record_next_duration_if_output
Expand Down Expand Up @@ -467,7 +496,9 @@ def __next__(self):
self._next_index = next_index
with self._stats.record_self_time(offset_ns=timer.value()):
element = self._stats.record_bytes_produced(element)
return shared_memory_array.open_from_shm(element)
if self._backend.is_multiprocess():
return shared_memory_array.open_from_shm(element)
return element

def close(self):
"""Stops the iterator. No further calls to the iterator are expected."""
Expand Down Expand Up @@ -495,8 +526,15 @@ def _stop_prefetch(self):

# In case all our attempts to terminate the system fails, we forcefully
# kill the child processes.
if self._prefetch_process.is_alive():
self._prefetch_process.kill()
if getattr(self._prefetch_process, "is_alive", lambda: False)():
if hasattr(self._prefetch_process, "kill"):
self._prefetch_process.kill()
else:
import warnings
warnings.warn(
"Background thread failed to exit gracefully within timeout during shutdown. "
"This likely means a dataset transform is hanging indefinitely."
)
else:
_clear_queue_and_maybe_unlink_shm(self._buffer)
self._prefetch_process = None
Expand Down Expand Up @@ -583,7 +621,7 @@ def __init__(
self._sequential_slice = sequential_slice

def __iter__(self) -> dataset.DatasetIterator[T]:
if not _is_in_worker_process:
if not _get_is_in_worker_process():
return self._parent.__iter__()
prefetch._set_slice_iter_dataset(
self._parent, self._slice, self._sequential_slice
Expand Down
9 changes: 6 additions & 3 deletions grain/_src/python/dataset/transformations/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@

from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
from grain._src.python.experimental.index_shuffle.python import (
index_shuffle_module as index_shuffle,
)
try:
from grain._src.python.experimental.index_shuffle.python import (
index_shuffle_module as index_shuffle,
)
except ImportError:
index_shuffle = None


T = TypeVar("T")
Expand Down
128 changes: 128 additions & 0 deletions grain/_src/python/execution_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ExecutionBackend layer for cross-platform process/thread management."""

import abc
import multiprocessing as mp
from multiprocessing import queues
from multiprocessing import sharedctypes
from multiprocessing import synchronize
import os
import platform
import queue
import threading
from typing import Any, Callable

class ExecutionBackend(abc.ABC):
"""Abstract mapping for concurrency primitives."""

@abc.abstractmethod
def Queue(self, maxsize: int = 0) -> "queue.Queue[Any]":
"""Returns a Queue instance."""

@abc.abstractmethod
def Event(self) -> synchronize.Event | threading.Event:
"""Returns an Event instance."""

@abc.abstractmethod
def SynchronizedInt(self, initial_value: int) -> Any:
"""Returns a synchronized integer with .value and .get_lock()."""

@abc.abstractmethod
def Process(self, target: Callable, kwargs: dict, daemon: bool, name: str) -> Any:
"""Returns a Process or Thread instance."""

@abc.abstractmethod
def is_multiprocess(self) -> bool:
"""Returns True if this backend runs tasks in separate processes."""


class MultiprocessingBackend(ExecutionBackend):
  """Execution backend that runs workers in separate OS processes.

  The multiprocessing start method defaults to "fork" and can be overridden
  via the ``GRAIN_MP_START`` environment variable. If the requested method
  is not available on this platform (e.g. "fork" on Windows), the backend
  falls back to "spawn", which is supported everywhere.
  """

  # Start methods accepted via GRAIN_MP_START; anything else is rejected
  # eagerly so a typo fails fast instead of silently falling back.
  _VALID_START_METHODS = frozenset({"fork", "spawn", "forkserver"})

  def __init__(self):
    start_method = os.environ.get("GRAIN_MP_START", "fork")
    if start_method not in self._VALID_START_METHODS:
      raise ValueError(f"Invalid GRAIN_MP_START: {start_method}")
    try:
      self._ctx = mp.get_context(start_method)
    except ValueError:
      # The requested method is valid but unsupported on this platform
      # (mp.get_context raises ValueError in that case); "spawn" is
      # available on all platforms.
      self._ctx = mp.get_context("spawn")

  def Queue(self, maxsize: int = 0) -> "queue.Queue[Any]":
    """Returns a cross-process queue bounded at ``maxsize``."""
    return self._ctx.Queue(maxsize=maxsize)

  def Event(self) -> synchronize.Event:
    """Returns a cross-process event flag."""
    return self._ctx.Event()

  def SynchronizedInt(self, initial_value: int) -> Any:
    """Returns a shared C int with ``.value`` and ``.get_lock()``."""
    return self._ctx.Value("i", initial_value)

  def Process(self, target: Callable, kwargs: dict, daemon: bool, name: str):
    """Returns an unstarted child process running ``target(**kwargs)``."""
    return self._ctx.Process(target=target, kwargs=kwargs, daemon=daemon, name=name)

  def is_multiprocess(self) -> bool:
    return True


class _ThreadSynchronizedInt:
def __init__(self, initial_value: int):
self._value = initial_value
self._lock = threading.Lock()

@property
def value(self):
return self._value

@value.setter
def value(self, v):
self._value = v

def get_lock(self):
return self._lock


class ThreadingBackend(ExecutionBackend):
  """Execution backend that runs workers as threads in this process.

  Intended for platforms where fork-based multiprocessing or shared memory
  is unreliable (e.g. Windows/macOS). All workers share one address space,
  so plain ``queue``/``threading`` primitives are sufficient.
  """

  def is_multiprocess(self) -> bool:
    return False

  def Queue(self, maxsize: int = 0) -> "queue.Queue[Any]":
    return queue.Queue(maxsize=maxsize)

  def Event(self) -> threading.Event:
    return threading.Event()

  def SynchronizedInt(self, initial_value: int) -> _ThreadSynchronizedInt:
    return _ThreadSynchronizedInt(initial_value)

  def Process(self, target: Callable, kwargs: dict, daemon: bool, name: str) -> threading.Thread:
    # threading.Thread mirrors the mp.Process surface used here
    # (start/join/is_alive/daemon/name).
    return threading.Thread(target=target, kwargs=kwargs, daemon=daemon, name=name)


def get_execution_backend() -> ExecutionBackend:
  """Returns the optimal ExecutionBackend based on the platform and environment.

  ``GRAIN_EXECUTION_BACKEND`` ("multiprocessing" or "threading",
  case-insensitive) takes precedence; any other non-empty value raises
  ValueError. With no override, Linux gets multiprocessing and every other
  platform gets threading.
  """
  requested = os.environ.get("GRAIN_EXECUTION_BACKEND", "").lower()
  if requested == "multiprocessing":
    return MultiprocessingBackend()
  if requested == "threading":
    return ThreadingBackend()
  if requested:
    raise ValueError(f"Invalid GRAIN_EXECUTION_BACKEND: {requested}")
  # No explicit override: pick per platform.
  if platform.system() == "Linux":
    return MultiprocessingBackend()
  return ThreadingBackend()
9 changes: 6 additions & 3 deletions grain/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@
from grain._src.python.experimental.example_packing.packing import (
PackAndBatchOperation,
)
from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import (
index_shuffle,
)
try:
from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import (
index_shuffle,
)
except ImportError:
index_shuffle = None

# This should eventually live under grain.testing.
from grain._src.python.testing.experimental import (
Expand Down
Loading
Loading