From b4b45b82c4095fba29a87a145151987f0646077f Mon Sep 17 00:00:00 2001 From: the-shadow-0 Date: Fri, 20 Mar 2026 22:55:46 +0000 Subject: [PATCH] Add cross-platform execution backend (Windows/macOS support) --- grain/_src/python/data_loader.py | 12 +- .../transformations/process_prefetch.py | 98 +++++++++----- .../python/dataset/transformations/shuffle.py | 9 +- grain/_src/python/execution_backend.py | 128 ++++++++++++++++++ grain/experimental.py | 9 +- grain/proto/execution_summary_pb2.py | 34 +++++ 6 files changed, 252 insertions(+), 38 deletions(-) create mode 100644 grain/_src/python/execution_backend.py create mode 100644 grain/proto/execution_summary_pb2.py diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index 65851a50e..5ac7a774d 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -456,10 +456,20 @@ def _create_dataset(self) -> dataset.IterDataset: ds = ds.to_iter_dataset(self._read_options) for operation in self._operations: ds = _apply_transform_to_dataset(operation, ds) + from grain._src.python import execution_backend + backend = execution_backend.get_execution_backend() if self.multiprocessing_options.num_workers > 0 and isinstance( ds, dataset_base.SupportsSharedMemoryOutput ): - ds.enable_shared_memory_output() + if backend.is_multiprocess(): + ds.enable_shared_memory_output() + else: + import warnings + warnings.warn( + "Shared memory fallback is active (ThreadingBackend). " + "This will result in severely degraded performance for large data volumes " + "because memory buffers must be repeatedly copied across thread boundaries." 
+ ) ds = ds.map(lambda r: r.data) if self.multiprocessing_options.num_workers > 0: ds = ds.mp_prefetch(self.multiprocessing_options) diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index 8d5a3fc27..c2b1da32e 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -22,6 +22,8 @@ from multiprocessing import synchronize from multiprocessing import util import queue +import sys +import threading import time from typing import Any, TypeVar import weakref @@ -39,6 +41,7 @@ from grain._src.python.dataset.transformations import prefetch from grain._src.python.ipc import queue as grain_queue from grain._src.python.ipc import shared_memory_array +from grain._src.python import execution_backend T = TypeVar("T") @@ -54,7 +57,12 @@ # Timeout for getting an element from the worker process. _QUEUE_WAIT_TIMEOUT_S = 1 -_is_in_worker_process = False +_worker_state = threading.local() +_worker_state.is_in_worker_process = False + +def _get_is_in_worker_process() -> bool: + return getattr(_worker_state, "is_in_worker_process", False) + def _run_all(fns: Sequence[Callable[[], None]]): @@ -192,16 +200,17 @@ def _put_dataset_elements_in_buffer( debug_flags: dict[str, Any], ): """Prefetches elements in a separate process.""" - global _is_in_worker_process - _is_in_worker_process = True + _worker_state.is_in_worker_process = True try: - parse_debug_flags_fn = cloudpickle.loads(pickled_parse_debug_flags_fn) + parse_debug_flags_fn = cloudpickle.loads(pickled_parse_debug_flags_fn) if isinstance(pickled_parse_debug_flags_fn, bytes) else pickled_parse_debug_flags_fn parse_debug_flags_fn(debug_flags) - worker_init_fn = cloudpickle.loads(pickled_worker_init_fn) + worker_init_fn = cloudpickle.loads(pickled_worker_init_fn) if isinstance(pickled_worker_init_fn, bytes) else pickled_worker_init_fn if worker_init_fn is not 
None: worker_init_fn() - ds = cloudpickle.loads(pickled_ds) - if isinstance(ds, base.SupportsSharedMemoryOutput): + ds = cloudpickle.loads(pickled_ds) if isinstance(pickled_ds, bytes) else pickled_ds + + backend = execution_backend.get_execution_backend() + if backend.is_multiprocess() and isinstance(ds, base.SupportsSharedMemoryOutput): ds.enable_shared_memory_output() it = ds.__iter__() min_shm_size = it._ctx.dataset_options.min_shm_size # pylint: disable=protected-access @@ -215,7 +224,15 @@ def _put_dataset_elements_in_buffer( with set_state_request_count.get_lock(): if set_state_request_count.value > 0: set_state_request_count.value -= 1 - new_state_or_index = set_state_queue.get() + new_state_or_index = None # stays None if stop is requested before an item arrives + while not should_stop.is_set(): + try: + new_state_or_index = set_state_queue.get(timeout=_QUEUE_WAIT_TIMEOUT_S) + break + except queue.Empty: + pass + if new_state_or_index is None: + continue parent_exhausted = False if new_state_or_index is not None: if not grain_queue.add_element_to_queue( # pytype: disable=wrong-arg-types @@ -243,8 +260,10 @@ def _put_dataset_elements_in_buffer( ) parent_exhausted = True continue - element = shared_memory_array.copy_to_shm(element, min_size=min_shm_size) + if backend.is_multiprocess(): + element = shared_memory_array.copy_to_shm(element, min_size=min_shm_size) # If the node is prefetch, we already record the bytes produced in it's + # __next__ method. if not it._stats._config.is_prefetch: # pylint: disable=protected-access it._stats.record_bytes_produced(element) # pylint: disable=protected-access @@ -310,22 +329,18 @@ def __init__( # propagate them. 
self._ctx.dataset_options = _get_dataset_options(parent) - self._process_ctx = mp.get_context("spawn") + self._backend = execution_backend.get_execution_backend() self._state: StateT | None = None self._prefetch_process: Any | None = None - self._prefetch_should_stop: synchronize.Event = self._process_ctx.Event() - self._set_state_request_count: sharedctypes.Synchronized[int] = ( - self._process_ctx.Value("i", 0) - ) - self._set_state_queue: queues.Queue[StateT | int] = self._process_ctx.Queue( - 5 - ) + self._prefetch_should_stop: synchronize.Event | threading.Event = self._backend.Event() + self._set_state_request_count = self._backend.SynchronizedInt(0) + self._set_state_queue: queues.Queue[StateT | int] | queue.Queue = self._backend.Queue(5) self._buffer: queues.Queue[ tuple[T, StateT | None, int | None, Exception | None] - ] = self._process_ctx.Queue(maxsize=self._buffer_size) - self._stats_in_queue = self._process_ctx.Queue(maxsize=5) - self._start_profiling_event = self._process_ctx.Event() - self._stop_profiling_event = self._process_ctx.Event() + ] | queue.Queue = self._backend.Queue(maxsize=self._buffer_size) + self._stats_in_queue = self._backend.Queue(maxsize=5) + self._start_profiling_event = self._backend.Event() + self._stop_profiling_event = self._backend.Event() self._set_state_count = 0 self._exhausted = False self._prefetch_ds_iter = None @@ -392,12 +407,23 @@ def start_prefetch(self) -> None: self._iter_parent, options=self._ctx.dataset_options, ) - self._prefetch_process = self._process_ctx.Process( + import os + if self._backend.is_multiprocess() or os.environ.get("GRAIN_STRICT_PICKLING", "0") == "1": + pickled_ds = _serialize_dataset(ds) + pickled_parse = cloudpickle.dumps(_parse_debug_flags) + pickled_init = cloudpickle.dumps(self._worker_init_fn) if self._worker_init_fn is not None else None + else: + pickled_parse = _parse_debug_flags + pickled_init = self._worker_init_fn + pickled_ds = ds + + self._prefetch_process = 
self._backend.Process( target=_put_dataset_elements_in_buffer, kwargs=dict( - pickled_parse_debug_flags_fn=cloudpickle.dumps(_parse_debug_flags), - pickled_worker_init_fn=cloudpickle.dumps(self._worker_init_fn), - pickled_ds=_serialize_dataset(ds), + pickled_parse_debug_flags_fn=pickled_parse, + pickled_worker_init_fn=pickled_init, + pickled_ds=pickled_ds, + buffer=self._buffer, should_stop=self._prefetch_should_stop, set_state_request_count=self._set_state_request_count, @@ -412,7 +438,7 @@ def start_prefetch(self) -> None: ), ), ), - daemon=True, + daemon=self._backend.is_multiprocess(), name=f"grain-process-prefetch-{str(self)}", ) self._prefetch_process.start() @@ -421,7 +447,8 @@ def start_prefetch(self) -> None: def _process_failed(self) -> bool: if self._prefetch_process is None: return False - exit_code = self._prefetch_process.exitcode + # threading.Thread has no `exitcode`; a missing attribute is treated as 0 (not failed). + exit_code = getattr(self._prefetch_process, "exitcode", 0) return exit_code is not None and exit_code != 0 @dataset_stats.record_next_duration_if_output @@ -467,7 +494,9 @@ def __next__(self): self._next_index = next_index with self._stats.record_self_time(offset_ns=timer.value()): element = self._stats.record_bytes_produced(element) - return shared_memory_array.open_from_shm(element) + if self._backend.is_multiprocess(): + return shared_memory_array.open_from_shm(element) + return element def close(self): """Stops the iterator. No further calls to the iterator are expected.""" @@ -495,8 +524,15 @@ def _stop_prefetch(self): # In case all our attempts to terminate the system fails, we forcefully # kill the child processes. 
- if self._prefetch_process.is_alive(): - self._prefetch_process.kill() + if getattr(self._prefetch_process, "is_alive", lambda: False)(): + if hasattr(self._prefetch_process, "kill"): + self._prefetch_process.kill() + else: + import warnings + warnings.warn( + "Background thread failed to exit gracefully within timeout during shutdown. " + "This likely means a dataset transform is hanging indefinitely." + ) else: _clear_queue_and_maybe_unlink_shm(self._buffer) self._prefetch_process = None @@ -583,7 +619,7 @@ def __init__( self._sequential_slice = sequential_slice def __iter__(self) -> dataset.DatasetIterator[T]: - if not _is_in_worker_process: + if not _get_is_in_worker_process(): return self._parent.__iter__() prefetch._set_slice_iter_dataset( self._parent, self._slice, self._sequential_slice diff --git a/grain/_src/python/dataset/transformations/shuffle.py b/grain/_src/python/dataset/transformations/shuffle.py index 23e56c4e2..004a08ed2 100644 --- a/grain/_src/python/dataset/transformations/shuffle.py +++ b/grain/_src/python/dataset/transformations/shuffle.py @@ -20,9 +20,12 @@ from grain._src.python.dataset import dataset from grain._src.python.dataset import stats -from grain._src.python.experimental.index_shuffle.python import ( - index_shuffle_module as index_shuffle, -) +try: + from grain._src.python.experimental.index_shuffle.python import ( + index_shuffle_module as index_shuffle, + ) +except ImportError: + index_shuffle = None T = TypeVar("T") diff --git a/grain/_src/python/execution_backend.py b/grain/_src/python/execution_backend.py new file mode 100644 index 000000000..510195417 --- /dev/null +++ b/grain/_src/python/execution_backend.py @@ -0,0 +1,128 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ExecutionBackend layer for cross-platform process/thread management.""" + +import abc +import multiprocessing as mp +from multiprocessing import queues +from multiprocessing import sharedctypes +from multiprocessing import synchronize +import os +import platform +import queue +import threading +from typing import Any, Callable + +class ExecutionBackend(abc.ABC): + """Abstract mapping for concurrency primitives.""" + + @abc.abstractmethod + def Queue(self, maxsize: int = 0) -> "queue.Queue[Any]": + """Returns a Queue instance.""" + + @abc.abstractmethod + def Event(self) -> synchronize.Event | threading.Event: + """Returns an Event instance.""" + + @abc.abstractmethod + def SynchronizedInt(self, initial_value: int) -> Any: + """Returns a synchronized integer with .value and .get_lock().""" + + @abc.abstractmethod + def Process(self, target: Callable, kwargs: dict, daemon: bool, name: str) -> Any: + """Returns a Process or Thread instance.""" + + @abc.abstractmethod + def is_multiprocess(self) -> bool: + """Returns True if this backend runs tasks in separate processes.""" + + +class MultiprocessingBackend(ExecutionBackend): + """Execution backend using multiprocessing; start method from GRAIN_MP_START (default 'spawn').""" + + def __init__(self): + start_method = os.environ.get("GRAIN_MP_START", "spawn") + if start_method not in ["fork", "spawn", "forkserver"]: + raise ValueError(f"Invalid GRAIN_MP_START: {start_method}") + try: + self._ctx = mp.get_context(start_method) + except ValueError: + self._ctx = mp.get_context("spawn") + + def Queue(self, maxsize: int = 0) -> 
"queue.Queue[Any]": + return self._ctx.Queue(maxsize=maxsize) + + def Event(self) -> synchronize.Event: + return self._ctx.Event() + + def SynchronizedInt(self, initial_value: int) -> Any: + return self._ctx.Value("i", initial_value) + + def Process(self, target: Callable, kwargs: dict, daemon: bool, name: str): + return self._ctx.Process(target=target, kwargs=kwargs, daemon=daemon, name=name) + + def is_multiprocess(self) -> bool: + return True + + +class _ThreadSynchronizedInt: + def __init__(self, initial_value: int): + self._value = initial_value + self._lock = threading.Lock() + + @property + def value(self): + return self._value + + @value.setter + def value(self, v): + self._value = v + + def get_lock(self): + return self._lock + + +class ThreadingBackend(ExecutionBackend): + """Execution backend utilizing threading for platforms lacking solid fork support or shared_memory (Windows/macOS).""" + + def Queue(self, maxsize: int = 0) -> "queue.Queue[Any]": + return queue.Queue(maxsize=maxsize) + + def Event(self) -> threading.Event: + return threading.Event() + + def SynchronizedInt(self, initial_value: int) -> _ThreadSynchronizedInt: + return _ThreadSynchronizedInt(initial_value) + + def Process(self, target: Callable, kwargs: dict, daemon: bool, name: str) -> threading.Thread: + return threading.Thread(target=target, kwargs=kwargs, daemon=daemon, name=name) + + def is_multiprocess(self) -> bool: + return False + + +def get_execution_backend() -> ExecutionBackend: + """Returns the optimal ExecutionBackend based on the platform and environment.""" + env_backend = os.environ.get("GRAIN_EXECUTION_BACKEND", "").lower() + if env_backend and env_backend not in ["multiprocessing", "threading"]: + raise ValueError(f"Invalid GRAIN_EXECUTION_BACKEND: {env_backend}") + if env_backend == "multiprocessing": + return MultiprocessingBackend() + elif env_backend == "threading": + return ThreadingBackend() + + if platform.system() == "Linux": + return MultiprocessingBackend() + 
else: + return ThreadingBackend() diff --git a/grain/experimental.py b/grain/experimental.py index 297c40b56..c70772fef 100644 --- a/grain/experimental.py +++ b/grain/experimental.py @@ -73,9 +73,12 @@ from grain._src.python.experimental.example_packing.packing import ( PackAndBatchOperation, ) -from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import ( - index_shuffle, -) +try: + from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import ( + index_shuffle, + ) +except ImportError: + index_shuffle = None # This should eventually live under grain.testing. from grain._src.python.testing.experimental import ( diff --git a/grain/proto/execution_summary_pb2.py b/grain/proto/execution_summary_pb2.py new file mode 100644 index 000000000..c034b41f6 --- /dev/null +++ b/grain/proto/execution_summary_pb2.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: grain/proto/execution_summary.proto +# Protobuf Python Version: 4.25.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#grain/proto/execution_summary.proto\x12\x1egrain.python.execution_summary\"\xfd\x03\n\x10\x45xecutionSummary\x12J\n\x05nodes\x18\x01 \x03(\x0b\x32;.grain.python.execution_summary.ExecutionSummary.NodesEntry\x1a\xb7\x02\n\x04Node\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06inputs\x18\x04 \x03(\x05\x12\x17\n\x0fwait_time_ratio\x18\x05 \x01(\x01\x12 \n\x18total_processing_time_ns\x18\x06 \x01(\x03\x12\x1e\n\x16min_processing_time_ns\x18\x07 
\x01(\x03\x12\x1e\n\x16max_processing_time_ns\x18\x08 \x01(\x03\x12\x1d\n\x15num_produced_elements\x18\t \x01(\x03\x12\x13\n\x0boutput_spec\x18\n \x01(\t\x12\x11\n\tis_output\x18\x0b \x01(\x08\x12\x13\n\x0bis_prefetch\x18\x0c \x01(\x08\x12\x16\n\x0e\x62ytes_consumed\x18\r \x01(\x03\x12\x16\n\x0e\x62ytes_produced\x18\x0e \x01(\x03\x1a\x63\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x44\n\x05value\x18\x02 \x01(\x0b\x32\x35.grain.python.execution_summary.ExecutionSummary.Node:\x02\x38\x01\"c\n\x12\x45xecutionSummaries\x12M\n\x13\x65xecution_summaries\x18\x01 \x03(\x0b\x32\x30.grain.python.execution_summary.ExecutionSummaryb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grain.proto.execution_summary_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_EXECUTIONSUMMARY_NODESENTRY']._options = None + _globals['_EXECUTIONSUMMARY_NODESENTRY']._serialized_options = b'8\001' + _globals['_EXECUTIONSUMMARY']._serialized_start=72 + _globals['_EXECUTIONSUMMARY']._serialized_end=581 + _globals['_EXECUTIONSUMMARY_NODE']._serialized_start=169 + _globals['_EXECUTIONSUMMARY_NODE']._serialized_end=480 + _globals['_EXECUTIONSUMMARY_NODESENTRY']._serialized_start=482 + _globals['_EXECUTIONSUMMARY_NODESENTRY']._serialized_end=581 + _globals['_EXECUTIONSUMMARIES']._serialized_start=583 + _globals['_EXECUTIONSUMMARIES']._serialized_end=682 +# @@protoc_insertion_point(module_scope)