diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 930e3bbb0..bb74e50a2 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -83,6 +83,86 @@ class MessageEndpoints: Class for message endpoints. """ + @staticmethod + def shared_memory_info(): + """ + Endpoint for shared memory information. This endpoint is used to publish the shared memory information using + a messages.SharedMemAllocationInfo message. + + Returns: + EndpointInfo: Endpoint for shared memory information. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/info/" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemAllocationInfo, + message_op=MessageOp.SET_PUBLISH, + ) + + @staticmethod + def shared_memory_allocate(): + """ + Endpoint for shared memory allocation. This endpoint is used to request the allocation of a shared memory object using + a messages.SharedMemAllocationRequest message. + + Returns: + EndpointInfo: Endpoint for shared memory allocation. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/allocate" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemAllocationRequest, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def shared_memory_deallocate(): + """ + Endpoint for shared memory deallocation. This endpoint is used to request the deallocation of a shared memory object using + a messages.SharedMemDeallocationRequest message. + + Returns: + EndpointInfo: Endpoint for shared memory deallocation. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/deallocate" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemDeallocationRequest, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def shared_memory_slot_written(): + """ + Endpoint for writer-complete shared memory events. This endpoint is used when a writer has finished writing + one shared memory slot. + + Returns: + EndpointInfo: Endpoint for shared memory slot written events. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/slot_written" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemSlotWritten, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def shared_memory_slot_processed(): + """ + Endpoint for reader-result shared memory events. This endpoint is used when a reader has finished processing + one shared memory slot. + + Returns: + EndpointInfo: Endpoint for shared memory slot processed events. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/slot_processed" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemSlotProcessed, + message_op=MessageOp.STREAM, + ) + # devices feedback @staticmethod def device_status(device: str): diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index b94fc6938..fc03380ee 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -17,6 +17,9 @@ from bec_lib.metadata_schema import get_metadata_schema_for_scan +# TODO remove bec_server depencency.. +from bec_server.shared_memory.models import PayloadDescriptor, SharedMemInfo + class ProcedureWorkerStatus(Enum): RUNNING = auto() @@ -94,6 +97,60 @@ def __hash__(self) -> int: return self.model_dump_json().__hash__() +class SharedMemAllocationInfo(BECMessage): + """ + This message is published by the shared memory manager and contains a list of all currently allocated shared memory objects. + Once shared memory objects are created or destroyed, this message will publish the updated list of shared memory objects. + """ + + msg_type: ClassVar[str] = "shared_mem_allocation_info" + + # Consider structure with dict[str, SharedMemInfo]. signal dotted name as key, which allows to identify this directly + # Alternatively, dict[str, dict[str, SharedMemInfo]] with device name as key, and then signal name as nested key + info: dict[str, dict[str, SharedMemInfo]] + + +class SharedMemAllocationRequest(BECMessage): + """Message to request information about a shared memory object.""" + + msg_type: ClassVar[str] = "shared_mem_allocation_request" + + client_id: str + slots: int + payload_desc: PayloadDescriptor + signal: str | None = None + + +class SharedMemDeallocationRequest(BECMessage): + """Message to request deallocation of a shared memory object.""" + + msg_type: ClassVar[str] = "shared_mem_deallocation_request" + + client_id: str + signal: str | None = None + + +class SharedMemSlotWritten(BECMessage): + """Event emitted after a writer finished writing one shared-memory slot.""" + + msg_type: ClassVar[str] = "shared_mem_slot_written" + + client_id: str + signal: str + slot_index: int + + +class SharedMemSlotProcessed(BECMessage): + """Event emitted after a reader finished processing one shared-memory slot.""" + + msg_type: ClassVar[str] = "shared_mem_slot_processed" + + client_id: str + signal: str + slot_index: int + result: dict[str, Any] + + class BundleMessage(BECMessage): """Message type to send a bundle of BECMessages. diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index 3a31df96e..c19124862 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -5,7 +5,14 @@ import pytest from bec_lib import messages +from bec_lib.endpoints import MessageEndpoints, MessageOp from bec_lib.serialization import MsgpackSerialization +from bec_server.shared_memory.models import ( + DTypeDescriptor, + PayloadDescriptor, + RingBufferDescriptor, + SharedMemInfo, +) @pytest.mark.parametrize("version", [1.0, 1.1, 1.2, None]) @@ -704,3 +711,67 @@ def test_feedback_message(): assert res_loaded == msg assert res_loaded.username == getpass.getuser() assert res_loaded.versions == messages.ServiceVersions._get_version_numbers() + + +def test_shared_memory_allocation_messages_round_trip(): + payload = PayloadDescriptor( + nbytes=16, shape=(4,), dtype=DTypeDescriptor(kind="float", itemsize=4, byte_order="little") + ) + descriptor = RingBufferDescriptor( + name="bec_psm_abcdef", + reader_count_name="bec_psm_abcdef_cnt", + data_lock_ids=("bec_psm_abcdef_d_0",), + reader_gate_ids=("bec_psm_abcdef_g_0",), + reader_count_lock_ids=("bec_psm_abcdef_c_0",), + slots=1, + payload=payload, + ) + info = SharedMemInfo(client_id="client", buffer_desc=descriptor, signal="detector.data") + + request = messages.SharedMemAllocationRequest( + client_id="client", slots=1, payload_desc=payload, signal="detector.data" + ) + allocation_info = messages.SharedMemAllocationInfo(info={"client": {"detector.data": info}}) + deallocation = messages.SharedMemDeallocationRequest(client_id="client", signal="detector.data") + + for msg in (request, allocation_info, deallocation): + assert MsgpackSerialization.loads(MsgpackSerialization.dumps(msg)) == msg + + +def test_shared_memory_endpoints_match_message_contracts(): + assert MessageEndpoints.shared_memory_info().message_type is messages.SharedMemAllocationInfo + assert MessageEndpoints.shared_memory_info().message_op is MessageOp.SET_PUBLISH + assert ( + MessageEndpoints.shared_memory_allocate().message_type + is messages.SharedMemAllocationRequest + ) + assert MessageEndpoints.shared_memory_allocate().message_op is MessageOp.STREAM + assert ( + MessageEndpoints.shared_memory_deallocate().message_type + is messages.SharedMemDeallocationRequest + ) + assert MessageEndpoints.shared_memory_deallocate().message_op is MessageOp.STREAM + + +def test_shared_memory_slot_event_messages_round_trip(): + written = messages.SharedMemSlotWritten( + client_id="client", signal="detector.data", slot_index=1 + ) + processed = messages.SharedMemSlotProcessed( + client_id="client", signal="detector.data", slot_index=1, result={"sum": 10.0} + ) + + for msg in (written, processed): + assert MsgpackSerialization.loads(MsgpackSerialization.dumps(msg)) == msg + + +def test_shared_memory_slot_event_endpoints_match_message_contracts(): + assert ( + MessageEndpoints.shared_memory_slot_written().message_type is messages.SharedMemSlotWritten + ) + assert MessageEndpoints.shared_memory_slot_written().message_op is MessageOp.STREAM + assert ( + MessageEndpoints.shared_memory_slot_processed().message_type + is messages.SharedMemSlotProcessed + ) + assert MessageEndpoints.shared_memory_slot_processed().message_op is MessageOp.STREAM diff --git a/bec_server/bec_server/shared_memory/README.md b/bec_server/bec_server/shared_memory/README.md new file mode 100644 index 000000000..e8f7c1e93 --- /dev/null +++ b/bec_server/bec_server/shared_memory/README.md @@ -0,0 +1,61 @@ +# Shared Memory Ring Buffer + +The shared-memory ring buffer keeps payload storage and control-plane policy separate. + +## Memory Layout + +The payload shared-memory object contains only slot bytes: + +```text +[ slot 0 payload ][ slot 1 payload ] ... [ slot N payload ] +``` + +The payload shape, dtype, slot count, and synchronization resource names are distributed through +`RingBufferDescriptor`. This keeps attachment explicit and avoids a mutable metadata header in the +payload memory. + +Reader counts live in a second, small shared-memory object: + +```text +[ reader_count[0] ][ reader_count[1] ] ... [ reader_count[N] ] +``` + +The counter memory stores only synchronization state. It does not store write position, slot +availability, processing state, or scheduling policy. + +## Locking + +Each slot has one logical readers/writer lock built from three named POSIX semaphores: + +- `data_lock`: held by a writer exclusively, or collectively by active readers. +- `reader_gate`: lets a waiting writer block new readers from entering the slot. +- `reader_count_lock`: protects updates to `reader_count[index]`. + +Readers briefly pass through the gate, increment the shared counter, copy the payload, and decrement +the counter. The first reader acquires the data lock, and the last reader releases it. + +Writers acquire the gate first, then the data lock. This allows existing readers to finish, prevents +new readers from entering while the writer waits, and guarantees that no reader observes a partial +write. + +## Ownership + +`RingBuffer` owns the operating-system resources. It creates and unlinks the payload memory, reader +counter memory, and all named semaphores. + +`RingBufferView` only attaches to existing resources. It closes local handles during shutdown and +never unlinks resources. + +## Write Position + +The ring buffer assumes one writer service per buffer. The writer handle keeps a local circular +cursor and returns the written slot index from `write_data(...)`. Shared memory does not contain a +global write cursor. + +Slot reuse, FIFO/LIFO ordering, release decisions, and processing results belong to the broker/event +control layer rather than the shared-memory implementation. + +## Timeout Behavior + +On macOS, positive semaphore timeouts are not reliable for this code path. Use `timeout=0` for a +non-blocking acquire or `timeout=None` to wait indefinitely. diff --git a/bec_server/bec_server/shared_memory/__init__.py b/bec_server/bec_server/shared_memory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bec_server/bec_server/shared_memory/cli/__init__.py b/bec_server/bec_server/shared_memory/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bec_server/bec_server/shared_memory/cli/launch.py b/bec_server/bec_server/shared_memory/cli/launch.py new file mode 100644 index 000000000..28bfd5f5a --- /dev/null +++ b/bec_server/bec_server/shared_memory/cli/launch.py @@ -0,0 +1,36 @@ +# Description: Launch the shared memory manager server. +# This script is the entry point for the Shared Memory Manager Server. It is called either +# by the bec-shared-mem-manager entry point or directly from the command line. +import threading + +from bec_lib.bec_service import parse_cmdline_args +from bec_lib.logger import bec_logger +from bec_lib.redis_connector import RedisConnector +from bec_server.shared_memory.manager import SharedMemoryManager + +logger = bec_logger.logger +bec_logger.level = bec_logger.LOGLEVEL.INFO + + +def main(): + """ + Launch the shared memory manager server. + """ + _, _, config = parse_cmdline_args() + + bec_server = SharedMemoryManager(config=config, connector_cls=RedisConnector) + bec_server.start() + + try: + event = threading.Event() + logger.success( + f"Started Shared Memory Manager server (id: {bec_server._service_id}). Press Ctrl+C to stop." + ) + event.wait() + except KeyboardInterrupt: + bec_server.shutdown() + event.set() + + +if __name__ == "__main__": + main() diff --git a/bec_server/bec_server/shared_memory/client.py b/bec_server/bec_server/shared_memory/client.py new file mode 100644 index 000000000..674d86ba6 --- /dev/null +++ b/bec_server/bec_server/shared_memory/client.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger +from bec_server.shared_memory.models import PayloadDescriptor +from bec_server.shared_memory.ring_buffer import RingBufferView + +if TYPE_CHECKING: + import numpy as np + + from bec_lib.connector import MessageObject + from bec_lib.messages import SharedMemAllocationInfo + from bec_lib.redis_connector import RedisConnector + +logger = bec_logger.logger + + +# TODO one per service, or N per service. +class SharedMemoryClient: + """Client for interacting with shared memory objects managed by the SharedMemoryManager.""" + + def __init__(self, name: str, connector: RedisConnector): + self.name = name + self.connector = connector + # signal name to ring buffer view mapping + self._ring_buffer_views: dict[str, RingBufferView] = {} + self.start() + + def start(self): + """Start the client by subscribing to the shared memory object.""" + self.connector.register(MessageEndpoints.shared_memory_info(), cb=self._handle_info_update) + + def _handle_info_update(self, info: MessageObject) -> None: + """Handle updates to the shared memory information.""" + info: SharedMemAllocationInfo = info.value + # Any info update can potentially contain relevant information for creating or deleting ring buffer views. + info_updates = [] + client_info = info.info.get(self.name, {}) + + for signal, buff_info in client_info.items(): + info_updates.append(signal) + if signal not in self._ring_buffer_views: # + self._ring_buffer_views[signal] = RingBufferView(descriptor=buff_info.buffer_desc) + else: + logger.error( + f"Ring buffer view for signal {signal} already exists, should not happend. Received info update: {buff_info}" + ) + if len(client_info) < len(self._ring_buffer_views): + # Some shared memory objects have been deallocated. Remove them from the local dictionary. + to_be_removed = set(self._ring_buffer_views.keys()) - set(info_updates) + for name in to_be_removed: + view = self._ring_buffer_views.pop(name) + view.close() + + def request_allocation( + self, signal_name: str, slots: int, payload_desc: PayloadDescriptor | dict + ) -> None: + """Request the allocation of a shared memory object.""" + if isinstance(payload_desc, dict): + payload_desc = PayloadDescriptor.model_validate(payload_desc) + + self.connector.xadd( + MessageEndpoints.shared_memory_allocate(), + { + "client_id": self.name, + "slots": slots, + "payload_desc": payload_desc, + "signal": signal_name, + }, + max_size=1000, # Keep history of 1000 allocation requests + ) + + def request_deallocation(self, signal_name: str) -> None: + """Request the deallocation of a shared memory object.""" + self.connector.xadd( + MessageEndpoints.shared_memory_deallocate(), + {"client_id": self.name, "signal": signal_name}, + max_size=1000, # Keep history of 1000 deallocation requests + ) + + def read_from_buffer( + self, signal_name: str, index: int, timeout: float | None = None + ) -> np.ndarray: + """ + Read data from the shared memory buffer associated with the given signal name. + If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot + read the data within that time frame. Please be aware, this is meant to block during write/read operations. + """ + # TODO add option to wait receiving an update on a specific signal in the buffer + # Also block until there is an update on the specific index in the buffer. + # Should there be a consume logic??? + buff = self._ring_buffer_views.get(signal_name) + if buff is None: + raise ValueError(f"No buffer found for signal name: {signal_name}") + return buff.copy_data(index, timeout) + + def write_to_buffer( + self, signal_name: str, data: np.ndarray, timeout: float | None = None + ) -> int: + """ + Write data to the next ring position associated with the given signal name. + If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot + write the data within that time frame. Please be aware, this is meant to block during write/read operations. + + Returns: + int: The slot index containing the newly written payload. + """ + buff = self._ring_buffer_views.get(signal_name) + if buff is None: + raise ValueError(f"No buffer found for signal name: {signal_name}") + return buff.write_data(data=data, acquire_timeout=timeout) + + def shutdown(self) -> None: + """Clean up resources and all shared memory views.""" + for view in self._ring_buffer_views.values(): + view.close() + self._ring_buffer_views.clear() + self.connector.unregister( + MessageEndpoints.shared_memory_info(), cb=self._handle_info_update + ) + + +if __name__ == "__main__": + import time + + import numpy as np + + from bec_lib.redis_connector import RedisConnector + + array = np.random.rand(5, 5) + connector = RedisConnector(bootstrap="localhost:6379") + client = SharedMemoryClient(name="test_client", connector=connector) + client.request_allocation( + signal_name="test_signal", slots=10, payload_desc=PayloadDescriptor.from_numpy(array) + ) + time.sleep(1) # Wait for the allocation to be processed + print(client._ring_buffer_views) diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py new file mode 100644 index 000000000..e61bd18f4 --- /dev/null +++ b/bec_server/bec_server/shared_memory/manager.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import threading +from collections import defaultdict +from typing import TYPE_CHECKING, Literal, Tuple + +from bec_lib import messages +from bec_lib.bec_service import BECService +from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger +from bec_server.shared_memory.models import SharedMemInfo +from bec_server.shared_memory.ring_buffer import RingBuffer + +SUPPORTED_DATATYPES = Literal["str", "float", "byte", "np.array", "list", "dict"] + +if TYPE_CHECKING: + from bec_lib.redis_connector import MessageObject, RedisConnector + +logger = bec_logger.logger + + +class SharedMemoryManager(BECService): + """ + Service to manage shared memory objects. It keeps track of all allocated shared memory objects and their descriptors. + It also handles the creation and destruction of shared memory objects, and publishes the updated list of shared memory objects + whenever a new shared memory object is created or destroyed. + """ + + def __init__(self, config, connector_cls: type[RedisConnector]) -> None: + super().__init__(config, connector_cls, unique_service=True) + # Shared memory objects are stored in a dictionary with the client_id and signal name tuple as key + # and the RingBuffer instance as value + self._shared_memory_objects: dict[Tuple[str, str], RingBuffer] = {} + self._shared_memory_info: dict[str, dict[str, SharedMemInfo]] = defaultdict( + dict + ) # Nested dict with client_id as key, and dict with signal name and ShareMemInfo as value + self.lock = threading.RLock() + + def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None: + """Callback function to handle shared memory allocation requests.""" + if isinstance(request, dict): + request = messages.SharedMemAllocationRequest.model_validate(request) + if (request.client_id, request.signal) in self._shared_memory_objects: + logger.error( + f"Shared memory object for client {request.client_id} and signal {request.signal} already exists. Overwriting." + ) + # TODO should this republish the info? + self._publish_allocation_info(self._shared_memory_info) + return + + buff = RingBuffer( + slots=request.slots, payload=request.payload_desc, name_suffix=request.signal + ) + with self.lock: + self._shared_memory_objects[(request.client_id, request.signal)] = buff + self._shared_memory_info[request.client_id][request.signal] = SharedMemInfo( + client_id=request.client_id, buffer_desc=buff.descriptor, signal=request.signal + ) + self._publish_allocation_info(self._shared_memory_info) + logger.info( + f"Allocated shared memory for client {request.client_id} and signal {request.signal} with descriptor {buff.descriptor}" + ) + + def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> None: + """Callback function to handle shared memory deallocation requests.""" + if isinstance(request, dict): + request = messages.SharedMemDeallocationRequest.model_validate(request) + if (request.client_id, request.signal) not in self._shared_memory_objects: + logger.error( + f"Shared memory object for client {request.client_id} and signal {request.signal} does not exist. Cannot deallocate." + ) + # TODO should this republish the info? + self._publish_allocation_info(self._shared_memory_info) + return + + with self.lock: + buff = self._shared_memory_objects.pop((request.client_id, request.signal)) + buff.destroy() + self._shared_memory_info[request.client_id].pop(request.signal, None) + self._publish_allocation_info(self._shared_memory_info) + logger.info( + f"Deallocated shared memory for client {request.client_id} and signal {request.signal}" + ) + + def _publish_allocation_info(self, info: dict[str, dict[str, SharedMemInfo]]) -> None: + """Publish the updated list of allocated shared memory objects.""" + self.connector.set_and_publish( + MessageEndpoints.shared_memory_info(), messages.SharedMemAllocationInfo(info=info) + ) + + def start(self) -> None: + """start the shared memory manager server""" + self.connector.register(MessageEndpoints.shared_memory_allocate(), cb=self._allocate_memory) + self.connector.register( + MessageEndpoints.shared_memory_deallocate(), cb=self._deallocate_memory + ) + + def stop(self) -> None: + with self.lock: + for buff in self._shared_memory_objects.values(): + buff.destroy() + self._shared_memory_objects.clear() + self._shared_memory_info.clear() + self._publish_allocation_info({}) + logger.info("Stopped shared memory manager and cleared all shared memory objects.") + # Cleanup bec service related resources + + def shutdown(self) -> None: + """Shutdown the shared memory manager server and destroy all shared memory objects.""" + self.stop() + super().shutdown() diff --git a/bec_server/bec_server/shared_memory/models.py b/bec_server/bec_server/shared_memory/models.py new file mode 100644 index 000000000..5d2da3823 --- /dev/null +++ b/bec_server/bec_server/shared_memory/models.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import sys +from typing import Literal, Tuple + +import numpy as np +from pydantic import BaseModel, ConfigDict + + +class SharedMemInfo(BaseModel): + """ + Store information about the shared memory object. This message has the client_id, the buffer descriptor and + the potentially a list of devices for which this shared memory object is relevant. + """ + + model_config = ConfigDict(validate_assignment=True) + client_id: str + buffer_desc: RingBufferDescriptor + signal: str | None = None # dotted signal name, e.g. "eiger.preview" + + +class DTypeDescriptor(BaseModel): + kind: Literal["uint", "int", "float", "bool"] + itemsize: int + byte_order: Literal["little", "big"] = "little" + + @classmethod + def from_numpy(cls, dtype: np.dtype) -> DTypeDescriptor: + """Class method to create DTypeDescriptor from numpy dtype.""" + dtype = np.dtype(dtype) + kind_map = {"u": "uint", "i": "int", "f": "float", "b": "bool"} + if dtype.kind not in kind_map: + raise ValueError(f"Unsupported dtype kind: {dtype.kind!r}") + + byte_order = dtype.byteorder + if byte_order in ("=", "|"): + byte_order = sys.byteorder + elif byte_order == "<": + byte_order = "little" + elif byte_order == ">": + byte_order = "big" + else: + raise ValueError(f"Unsupported byte order: {dtype.byteorder!r}") + + return cls(kind=kind_map[dtype.kind], itemsize=dtype.itemsize, byte_order=byte_order) + + @property + def numpy_dtype(self) -> np.dtype: + """Return the corresponding numpy dtype for this DTypeDescriptor.""" + byte_order_char = {"little": "<", "big": ">"}[self.byte_order] + kind_char = {"uint": "u", "int": "i", "float": "f", "bool": "b"}[self.kind] + dtype_str = f"{byte_order_char}{kind_char}{self.itemsize}" + return np.dtype(dtype_str) + + +class PayloadDescriptor(BaseModel): + """Descriptor for the data payload stored in each slot of the ring buffer.""" + + nbytes: int + shape: Tuple[int, ...] + dtype: DTypeDescriptor + layout: Literal["C"] = "C" + + @classmethod + def from_numpy(cls, array: np.ndarray) -> PayloadDescriptor: + """Class method to create PayloadDescriptor from a numpy array.""" + return cls( + nbytes=array.nbytes, + shape=array.shape, + dtype=DTypeDescriptor.from_numpy(array.dtype), + layout="C" if array.flags.c_contiguous else "C", + ) + + +class RingBufferDescriptor(BaseModel): + """Information required to attach to a shared ring buffer.""" + + name: str + reader_count_name: str + data_lock_ids: Tuple[str, ...] + reader_gate_ids: Tuple[str, ...] + reader_count_lock_ids: Tuple[str, ...] + slots: int + payload: PayloadDescriptor + + +# class AvailableDataAnalysisMethods(messages.BECMessage): +# """Message published by the DAP server on which analysis methods are available.""" + +# methods: list[str] + + +# TODO maybe not needed to warm up, could automatically start a DAP worker once a shared memory object is created, +# Then DataAnalysisRegisterRequest is designed to register analysis methods for the shared memory object, and +# DataAnalysisTrigger is designed to trigger the analysis of the shared memory object. +# DataAnalysisResponse is designed to send the results back to the client. +# class DataAnalysisRequestWarmup(BECMessage): +# """Message to request a data analysis""" + +# shared_mem: SharedMemDescriptor + + +# class DataAnalysisRegisterRequest(BECMessage): +# """Message to request processing of a shared memory object.""" + +# shared_mem: SharedMemDescriptor +# methods: list[str] +# client_id: str +# device: str | None = None + + +# class DataAnalysisTrigger(BECMessage): +# """Message to request processing of a shared memory object.""" + +# shared_mem: SharedMemDescriptor +# index: int + + +# class DataAnalysisResponse(BECMessage): +# """Message to request processing of a shared memory object.""" + +# shared_mem: SharedMemDescriptor +# index: int +# results: dict +# client_id: str +# device: str | None = None diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py new file mode 100644 index 000000000..61387bc57 --- /dev/null +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +from contextlib import contextmanager +from functools import wraps +from multiprocessing import resource_tracker, shared_memory +from threading import RLock +from typing import Any, Callable, Iterator +from uuid import uuid4 + +import numpy as np +import posix_ipc + +from bec_server.shared_memory.models import PayloadDescriptor, RingBufferDescriptor + +# pylint: disable=c-extension-no-member + +MAX_SEMAPHORE_NAME_LENGTH = 30 +READER_COUNT_DTYPE = np.dtype(np.uint32) + + +def not_destroyed(method: Callable[..., Any]) -> Callable[..., Any]: + """Check that a shared-memory handle is still open before accessing it.""" + + @wraps(method) + def wrapper(self: RingBufferView, *args: Any, **kwargs: Any) -> Any: + if self.destroyed: + raise RuntimeError( + f"Cannot perform operation on a destroyed {self.__class__.__name__} object with name {self.name!r}." + ) + return method(self, *args, **kwargs) + + return wrapper + + +class RingBufferView: + """Attached handle for accessing a ring buffer without owning its resources.""" + + def __init__( + self, + descriptor: RingBufferDescriptor, + shm: shared_memory.SharedMemory | None = None, + reader_count_shm: shared_memory.SharedMemory | None = None, + *, + owns_memory: bool = False, + ): + self._validate_descriptor(descriptor) + self._descriptor = descriptor + self._shm = shm if shm is not None else shared_memory.SharedMemory(name=descriptor.name) + self._reader_count_shm = ( + reader_count_shm + if reader_count_shm is not None + else shared_memory.SharedMemory(name=descriptor.reader_count_name) + ) + self._owns_memory = owns_memory + if not owns_memory: + self._unregister_attached_shared_memory(self._shm) + self._unregister_attached_shared_memory(self._reader_count_shm) + self._data_locks = [ + posix_ipc.Semaphore(lock_id, flags=0) for lock_id in descriptor.data_lock_ids + ] + self._reader_gates = [ + posix_ipc.Semaphore(lock_id, flags=0) for lock_id in descriptor.reader_gate_ids + ] + self._reader_count_locks = [ + posix_ipc.Semaphore(lock_id, flags=0) for lock_id in descriptor.reader_count_lock_ids + ] + self._reader_counts = np.ndarray( + shape=(descriptor.slots,), dtype=READER_COUNT_DTYPE, buffer=self._reader_count_shm.buf + ) + self._next_write_position = 0 + self.__destroyed = False + self._lifecycle_lock = RLock() + + @staticmethod + def _validate_descriptor(descriptor: RingBufferDescriptor) -> None: + lock_lengths = { + "data_lock_ids": len(descriptor.data_lock_ids), + "reader_gate_ids": len(descriptor.reader_gate_ids), + "reader_count_lock_ids": len(descriptor.reader_count_lock_ids), + } + invalid = { + name: length for name, length in lock_lengths.items() if length != descriptor.slots + } + if invalid: + raise ValueError( + f"Ring buffer descriptor must provide exactly one lock per slot: {invalid}" + ) + + @staticmethod + def _unregister_attached_shared_memory(shm: shared_memory.SharedMemory) -> None: + """Let the owning manager unlink shared memory without local tracker warnings.""" + if not getattr(shared_memory, "_USE_POSIX", False): + return + resource_tracker.unregister(shm._name, "shared_memory") + + @contextmanager + def _acquire( + self, semaphore: posix_ipc.Semaphore, timeout: float | None, operation: str + ) -> Iterator[None]: + acquired = False + try: + semaphore.acquire(timeout=None if timeout is None else timeout) + acquired = True + yield + except posix_ipc.BusyError: + raise TimeoutError( + f"Could not acquire lock for {operation} buffer {self.name!r} within {timeout} seconds." + ) from None + finally: + if acquired: + semaphore.release() + + def _acquire_lock( + self, semaphore: posix_ipc.Semaphore, timeout: float | None, operation: str + ) -> bool: + try: + semaphore.acquire(timeout=None if timeout is None else timeout) + return True + except posix_ipc.BusyError: + raise TimeoutError( + f"Could not acquire lock for {operation} buffer {self.name!r} within {timeout} seconds." + ) from None + + def _validate_index(self, index: int) -> None: + if index < 0 or index >= self.slots: + raise IndexError( + f"Index {index} is out of bounds for ring buffer with {self.slots} slots." + ) + + def _validate_payload(self, data: np.ndarray) -> None: + descriptor = PayloadDescriptor.from_numpy(data) + if descriptor != self.payload_descriptor: + raise ValueError( + f"Data shape/dtype {descriptor.shape}/{descriptor.dtype} does not match expected " + f"shape/dtype {self.payload_descriptor.shape}/{self.payload_descriptor.dtype}" + ) + + def _array_for_slot(self, index: int) -> np.ndarray: + return np.ndarray( + shape=self.payload_descriptor.shape, + dtype=self.payload_descriptor.dtype.numpy_dtype, + buffer=self._shm.buf, + offset=index * self.bytes_per_slot, + ) + + @contextmanager + def _read_slot_lock(self, index: int, acquire_timeout: float | None) -> Iterator[None]: + gate_acquired = False + count_lock_acquired = False + try: + self._acquire_lock( + self._reader_gates[index], acquire_timeout, "entering reader gate for" + ) + gate_acquired = True + self._acquire_lock( + self._reader_count_locks[index], acquire_timeout, "updating reader count for" + ) + count_lock_acquired = True + if self._reader_counts[index] == 0: + self._acquire_lock(self._data_locks[index], acquire_timeout, "reading from") + self._reader_counts[index] += 1 + finally: + if count_lock_acquired: + self._reader_count_locks[index].release() + if gate_acquired: + self._reader_gates[index].release() + + try: + yield + finally: + with self._acquire( + self._reader_count_locks[index], acquire_timeout, "updating reader count for" + ): + if self._reader_counts[index] == 0: + raise RuntimeError("Reader count underflow while releasing ring buffer slot.") + self._reader_counts[index] -= 1 + if self._reader_counts[index] == 0: + self._data_locks[index].release() + + @contextmanager + def _write_slot_lock(self, index: int, acquire_timeout: float | None) -> Iterator[None]: + gate_acquired = False + data_lock_acquired = False + try: + self._acquire_lock( + self._reader_gates[index], acquire_timeout, "entering writer gate for" + ) + gate_acquired = True + self._acquire_lock(self._data_locks[index], acquire_timeout, "writing to") + data_lock_acquired = True + yield + finally: + if data_lock_acquired: + self._data_locks[index].release() + if gate_acquired: + self._reader_gates[index].release() + + @not_destroyed + def copy_data(self, index: int, acquire_timeout: float | None = 0) -> np.ndarray: + """Copy one identified payload slot while allowing concurrent readers.""" + self._validate_index(index) + with self._read_slot_lock(index, acquire_timeout): + return self._array_for_slot(index).copy() + + @not_destroyed + def write_data(self, data: np.ndarray, acquire_timeout: float | None = 0) -> int: + """Write using this writer handle's local circular slot cursor.""" + index = self._next_write_position + self.write_data_at(index, data, acquire_timeout) + self._next_write_position = (index + 1) % self.slots + return index + + @not_destroyed + def write_data_at( + self, index: int, data: np.ndarray, acquire_timeout: float | None = 0 + ) -> None: + """Write directly to an identified slot using the slot writer lock.""" + self._validate_index(index) + self._validate_payload(data) + with self._write_slot_lock(index, acquire_timeout): + np.copyto(self._array_for_slot(index), data) + + @property + def descriptor(self) -> RingBufferDescriptor: + return self._descriptor + + @property + def destroyed(self) -> bool: + return self.__destroyed + + @property + def name(self) -> str: + return self._descriptor.name + + @property + def reader_count_name(self) -> str: + return self._descriptor.reader_count_name + + @property + def slots(self) -> int: + return self._descriptor.slots + + @property + def bytes_per_slot(self) -> int: + return self._descriptor.payload.nbytes + + @property + def payload_descriptor(self) -> PayloadDescriptor: + return self._descriptor.payload + + @property + def next_write_position(self) -> int: + return self._next_write_position + + def _close_handles(self) -> None: + for lock in (*self._data_locks, *self._reader_gates, *self._reader_count_locks): + lock.close() + self._reader_count_shm.close() + self._shm.close() + + def close(self) -> None: + """Close local handles without unlinking owner-managed resources.""" + if self.destroyed: + return + with self._lifecycle_lock: + if self.destroyed: + return + self._close_handles() + self.__destroyed = True + + def destroy(self) -> None: + """Compatibility alias for attached clients; attached handles only close resources.""" + self.close() + + +class RingBuffer(RingBufferView): + """Owner of a shared ring buffer and its semaphore resources.""" + + @staticmethod + def _semaphore_name(name: str, suffix: str) -> str: + semaphore_name = f"{name}{suffix}" + if len(semaphore_name) > MAX_SEMAPHORE_NAME_LENGTH: + raise ValueError( + f"Semaphore name {semaphore_name!r} exceeds the platform limit of " + f"{MAX_SEMAPHORE_NAME_LENGTH} characters." + ) + return semaphore_name + + def __init__(self, slots: int, payload: PayloadDescriptor, name_suffix: str = ""): + if not 0 < slots: + raise ValueError("Ring buffer must contain at least one slot.") + name = f"bec_psm_{uuid4().hex[:6]}" + reader_count_name = f"{name}_cnt" + data_lock_names = tuple(self._semaphore_name(name, f"_d_{index}") for index in range(slots)) + reader_gate_names = tuple( + self._semaphore_name(name, f"_g_{index}") for index in range(slots) + ) + reader_count_lock_names = tuple( + self._semaphore_name(name, f"_c_{index}") for index in range(slots) + ) + shm = shared_memory.SharedMemory(create=True, size=slots * payload.nbytes, name=name) + reader_count_shm = shared_memory.SharedMemory( + create=True, size=slots * READER_COUNT_DTYPE.itemsize, name=reader_count_name + ) + reader_counts = np.ndarray( + shape=(slots,), dtype=READER_COUNT_DTYPE, buffer=reader_count_shm.buf + ) + reader_counts[:] = 0 + lock_names = (*data_lock_names, *reader_gate_names, *reader_count_lock_names) + created_locks: list[posix_ipc.Semaphore] = [] + try: + created_locks.extend( + posix_ipc.Semaphore( + lock_name, flags=posix_ipc.O_CREAT | posix_ipc.O_EXCL, initial_value=1 + ) + for lock_name in lock_names + ) + for lock in created_locks: + lock.close() + descriptor = RingBufferDescriptor( + name=shm.name, + reader_count_name=reader_count_shm.name, + data_lock_ids=data_lock_names, + reader_gate_ids=reader_gate_names, + reader_count_lock_ids=reader_count_lock_names, + slots=slots, + payload=payload, + ) + super().__init__( + descriptor=descriptor, shm=shm, reader_count_shm=reader_count_shm, owns_memory=True + ) + except Exception: + for lock in created_locks: + try: + lock.close() + except OSError: + pass + try: + posix_ipc.unlink_semaphore(lock.name) + except posix_ipc.ExistentialError: + pass + reader_count_shm.close() + reader_count_shm.unlink() + shm.close() + shm.unlink() + raise + + def destroy(self) -> None: + """Close and unlink all resources created for this owned ring buffer.""" + if self.destroyed: + return + descriptor = self.descriptor + self.close() + self._reader_count_shm.unlink() + self._shm.unlink() + for lock_id in ( + *descriptor.data_lock_ids, + *descriptor.reader_gate_ids, + *descriptor.reader_count_lock_ids, + ): + try: + posix_ipc.unlink_semaphore(lock_id) + except posix_ipc.ExistentialError: + pass + + @classmethod + def _name_suffix(cls, name: str, suffix: str, max_length: int = 63) -> str: + if suffix: + name = f"{name}_{suffix}" + return name[:max_length] diff --git a/bec_server/pyproject.toml b/bec_server/pyproject.toml index 2598be6ef..f0cfa79b8 100644 --- a/bec_server/pyproject.toml +++ b/bec_server/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "pyyaml~=6.0", "python-dotenv~=1.0", "rich>=13.7,<16.0", + "posix_ipc~=1.0", ] [project.optional-dependencies] diff --git a/bec_server/tests/tests_shared_memory/test_ring_buffer.py b/bec_server/tests/tests_shared_memory/test_ring_buffer.py new file mode 100644 index 000000000..06a1c6400 --- /dev/null +++ b/bec_server/tests/tests_shared_memory/test_ring_buffer.py @@ -0,0 +1,248 @@ +from multiprocessing import shared_memory + +import numpy as np +import posix_ipc +import pytest + +from bec_server.shared_memory.models import PayloadDescriptor +from bec_server.shared_memory.ring_buffer import ( + MAX_SEMAPHORE_NAME_LENGTH, + READER_COUNT_DTYPE, + RingBuffer, + RingBufferView, +) + + +@pytest.fixture +def payload() -> PayloadDescriptor: + return PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64)) + + +@pytest.fixture +def ring_buffer(payload: PayloadDescriptor): + buffer = RingBuffer(slots=2, payload=payload) + yield buffer + buffer.destroy() + + +def test_descriptor_exposes_payload_counter_resources_and_rw_locks( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + assert ring_buffer.descriptor.name == ring_buffer.name + assert ring_buffer.descriptor.reader_count_name == ring_buffer.reader_count_name + assert ring_buffer.descriptor.slots == 2 + assert ring_buffer.descriptor.payload == payload + assert len(ring_buffer.descriptor.data_lock_ids) == 2 + assert len(ring_buffer.descriptor.reader_gate_ids) == 2 + assert len(ring_buffer.descriptor.reader_count_lock_ids) == 2 + assert ( + len( + { + *ring_buffer.descriptor.data_lock_ids, + *ring_buffer.descriptor.reader_gate_ids, + *ring_buffer.descriptor.reader_count_lock_ids, + } + ) + == 6 + ) + + +def test_attached_view_uses_descriptor_payload_and_counter_memory( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + view = RingBufferView(ring_buffer.descriptor) + try: + assert view.slots == ring_buffer.descriptor.slots + assert view.bytes_per_slot == payload.nbytes + assert view.payload_descriptor == payload + assert view._reader_counts.shape == (2,) + assert view._reader_counts.dtype == READER_COUNT_DTYPE + finally: + view.close() + + +def test_attached_view_rejects_incomplete_lock_descriptor(ring_buffer: RingBuffer): + descriptor = ring_buffer.descriptor.model_copy(update={"data_lock_ids": ("only-one",)}) + + with pytest.raises(ValueError, match="exactly one lock per slot"): + RingBufferView(descriptor) + + +def test_write_data_uses_local_circular_position_and_returns_written_slot(ring_buffer: RingBuffer): + first = np.array([1, 2, 3, 4], dtype=np.float64) + second = np.array([5, 6, 7, 8], dtype=np.float64) + third = np.array([9, 10, 11, 12], dtype=np.float64) + + assert ring_buffer.next_write_position == 0 + assert ring_buffer.write_data(first) == 0 + assert ring_buffer.next_write_position == 1 + assert ring_buffer.write_data(second) == 1 + assert ring_buffer.next_write_position == 0 + assert ring_buffer.write_data(third) == 0 + assert ring_buffer.next_write_position == 1 + np.testing.assert_array_equal(ring_buffer.copy_data(0), third) + np.testing.assert_array_equal(ring_buffer.copy_data(1), second) + + +def test_explicit_write_uses_payload_only_slot_offset_without_advancing_cursor( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + data = np.arange(4, dtype=np.float64) + + ring_buffer.write_data_at(1, data) + + raw_payload = np.ndarray( + payload.shape, + dtype=payload.dtype.numpy_dtype, + buffer=ring_buffer._shm.buf, + offset=payload.nbytes, + ) + np.testing.assert_array_equal(raw_payload, data) + assert ring_buffer.next_write_position == 0 + + +def test_attached_view_has_independent_local_write_cursor(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + try: + written_from_view = np.array([1, 2, 3, 4], dtype=np.float64) + written_from_owner = np.array([5, 6, 7, 8], dtype=np.float64) + + assert view.write_data(written_from_view) == 0 + assert view.next_write_position == 1 + assert ring_buffer.next_write_position == 0 + np.testing.assert_array_equal(ring_buffer.copy_data(0), written_from_view) + + assert ring_buffer.write_data(written_from_owner) == 0 + np.testing.assert_array_equal(view.copy_data(0), written_from_owner) + finally: + view.close() + + +def test_each_buffer_has_distinct_shared_memory_and_semaphore_names(payload: PayloadDescriptor): + first = RingBuffer(slots=2, payload=payload) + second = RingBuffer(slots=2, payload=payload) + try: + assert first.name != second.name + assert first.descriptor.reader_count_name != second.descriptor.reader_count_name + assert first.descriptor.data_lock_ids != second.descriptor.data_lock_ids + assert first.descriptor.reader_gate_ids != second.descriptor.reader_gate_ids + assert first.descriptor.reader_count_lock_ids != second.descriptor.reader_count_lock_ids + finally: + first.destroy() + second.destroy() + + +def test_slot_semaphore_name_supports_largest_header_slot_index(): + name = "bec_psm_abcdef" + lock_name = RingBuffer._semaphore_name(name, f"_d_{(2**32) - 1}") + + assert len(lock_name) <= MAX_SEMAPHORE_NAME_LENGTH + assert lock_name.endswith("_d_4294967295") + + +def test_multiple_readers_share_one_slot_lock(ring_buffer: RingBuffer): + with ring_buffer._read_slot_lock(0, acquire_timeout=0): + assert ring_buffer._reader_counts[0] == 1 + with ring_buffer._read_slot_lock(0, acquire_timeout=0): + assert ring_buffer._reader_counts[0] == 2 + assert ring_buffer._reader_counts[0] == 1 + + assert ring_buffer._reader_counts[0] == 0 + + +def test_writer_waits_while_reader_is_attached_to_same_slot(ring_buffer: RingBuffer): + with ring_buffer._read_slot_lock(0, acquire_timeout=0): + with pytest.raises(TimeoutError, match="writing to"): + ring_buffer.write_data_at(0, np.arange(4, dtype=np.float64), acquire_timeout=0) + + ring_buffer.write_data_at(0, np.arange(4, dtype=np.float64), acquire_timeout=0) + + +def test_waiting_writer_gate_blocks_new_readers(ring_buffer: RingBuffer): + reader_gate = posix_ipc.Semaphore(ring_buffer.descriptor.reader_gate_ids[0]) + try: + reader_gate.acquire() + with pytest.raises(TimeoutError, match="reader gate"): + ring_buffer.copy_data(0, acquire_timeout=0) + finally: + reader_gate.release() + reader_gate.close() + + +def test_writer_on_one_slot_does_not_block_reader_on_other_slot(ring_buffer: RingBuffer): + data_lock = posix_ipc.Semaphore(ring_buffer.descriptor.data_lock_ids[0]) + try: + data_lock.acquire() + ring_buffer.copy_data(1, acquire_timeout=0) + finally: + data_lock.release() + data_lock.close() + + +@pytest.mark.parametrize("index", [-1, 2]) +def test_copy_data_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): + with pytest.raises(IndexError, match="out of bounds"): + ring_buffer.copy_data(index) + + +@pytest.mark.parametrize("index", [-1, 2]) +def test_write_data_at_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): + with pytest.raises(IndexError, match="out of bounds"): + ring_buffer.write_data_at(index, np.zeros((4,), dtype=np.float64)) + + +@pytest.mark.parametrize( + "data", [np.zeros((2,), dtype=np.float64), np.zeros((4,), dtype=np.float32)] +) +def test_write_data_rejects_payload_shape_or_dtype_mismatch( + ring_buffer: RingBuffer, data: np.ndarray +): + with pytest.raises(ValueError, match="does not match expected"): + ring_buffer.write_data(data) + + +def test_destroy_is_idempotent_and_rejects_further_operations( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + ring_buffer.destroy() + ring_buffer.destroy() + + with pytest.raises(RuntimeError, match="destroyed"): + ring_buffer.write_data(np.zeros(payload.shape, dtype=payload.dtype.numpy_dtype)) + + +def test_only_creator_owns_shared_memory_resources(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + try: + assert ring_buffer._owns_memory is True + assert view._owns_memory is False + finally: + view.close() + + +def test_closing_view_keeps_owner_resources_attachable(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + view.close() + + attached = RingBufferView(ring_buffer.descriptor) + attached.close() + assert ring_buffer.next_write_position == 0 + + +def test_destroying_owner_unlinks_shared_memory_counter_memory_and_semaphores( + ring_buffer: RingBuffer, +): + descriptor = ring_buffer.descriptor + + ring_buffer.destroy() + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=descriptor.name) + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=descriptor.reader_count_name) + with pytest.raises(posix_ipc.ExistentialError): + posix_ipc.Semaphore(descriptor.data_lock_ids[0]) + with pytest.raises(posix_ipc.ExistentialError): + posix_ipc.Semaphore(descriptor.reader_gate_ids[0]) + with pytest.raises(posix_ipc.ExistentialError): + posix_ipc.Semaphore(descriptor.reader_count_lock_ids[0]) diff --git a/bec_server/tests/tests_shared_memory/test_ring_buffer_event_flow.py b/bec_server/tests/tests_shared_memory/test_ring_buffer_event_flow.py new file mode 100644 index 000000000..d7ec43bf2 --- /dev/null +++ b/bec_server/tests/tests_shared_memory/test_ring_buffer_event_flow.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import multiprocessing +import threading +import time +from typing import Any + +import fakeredis +import numpy as np +import pytest + +from bec_lib import messages +from bec_lib.endpoints import MessageEndpoints +from bec_lib.redis_connector import RedisConnector +from bec_server.shared_memory.models import PayloadDescriptor, RingBufferDescriptor +from bec_server.shared_memory.ring_buffer import RingBuffer, RingBufferView + + +class SharedMemorySumWorker: + """Small subprocess worker that reacts to slot-written events and publishes sums.""" + + def __init__( + self, + bootstrap: str, + descriptor: dict[str, Any], + *, + delay: float = 0, + expected_events: int = 1, + ): + self.bootstrap = bootstrap + self.descriptor = descriptor + self.delay = delay + self.expected_events = expected_events + + def run(self) -> None: + connector = RedisConnector(self.bootstrap, name="SharedMemorySumWorker RedisConnector") + view = RingBufferView(RingBufferDescriptor.model_validate(self.descriptor)) + try: + processed = 0 + while processed < self.expected_events: + records = connector.xread( + MessageEndpoints.shared_memory_slot_written(), + block=1000, + count=1, + from_start=processed == 0, + ) + if not records: + continue + event = records[0]["data"] + if not isinstance(event, messages.SharedMemSlotWritten): + continue + data = view.copy_data(event.slot_index) + if self.delay: + time.sleep(self.delay) + connector.xadd( + MessageEndpoints.shared_memory_slot_processed(), + { + "data": messages.SharedMemSlotProcessed( + client_id=event.client_id, + signal=event.signal, + slot_index=event.slot_index, + result={"sum": float(np.sum(data))}, + ) + }, + ) + processed += 1 + finally: + view.close() + connector.shutdown(per_thread_timeout_s=1) + + +@pytest.fixture +def fake_redis_tcp_server(): + server = fakeredis.TcpFakeServer(("127.0.0.1", 0)) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + try: + yield f"{server.server_address[0]}:{server.server_address[1]}" + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1) + + +def test_two_slot_ring_buffer_event_processing_flow(fake_redis_tcp_server): + payload = PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64)) + ring_buffer = RingBuffer(slots=2, payload=payload) + connector = RedisConnector(fake_redis_tcp_server, name="SharedMemoryEventTest RedisConnector") + worker = SharedMemorySumWorker( + fake_redis_tcp_server, ring_buffer.descriptor.model_dump(), delay=0.01, expected_events=2 + ) + ctx = multiprocessing.get_context("spawn") + process = ctx.Process(target=worker.run) + process.start() + + try: + for data in ( + np.array([1, 2, 3, 4], dtype=np.float64), + np.array([5, 6, 7, 8], dtype=np.float64), + ): + slot_index = ring_buffer.write_data(data) + connector.xadd( + MessageEndpoints.shared_memory_slot_written(), + { + "data": messages.SharedMemSlotWritten( + client_id="writer", signal="detector.data", slot_index=slot_index + ) + }, + ) + + results = [] + deadline = time.monotonic() + 5 + while len(results) < 2 and time.monotonic() < deadline: + records = connector.xread( + MessageEndpoints.shared_memory_slot_processed(), block=100, count=1 + ) + if records: + results.append(records[0]["data"]) + + process.join(timeout=5) + assert process.exitcode == 0 + assert [result.slot_index for result in results] == [0, 1] + assert [result.result["sum"] for result in results] == [10.0, 26.0] + finally: + if process.is_alive(): + process.terminate() + process.join(timeout=5) + connector.shutdown(per_thread_timeout_s=1) + ring_buffer.destroy()