From 251720c59cab63af656c033bb5400e5f2e233337 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 12 Dec 2024 23:43:56 -0800 Subject: [PATCH 1/8] fix literal \w being interpreted as escape sequence --- mio/sources/__init__.py | 0 mio/types.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 mio/sources/__init__.py diff --git a/mio/sources/__init__.py b/mio/sources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mio/types.py b/mio/types.py index a77af047..871004ba 100644 --- a/mio/types.py +++ b/mio/types.py @@ -17,7 +17,7 @@ CONFIG_ID_PATTERN = r"[\w\-\/#]+" """ -Any alphanumeric string (\w), as well as +Any alphanumeric string (\\w), as well as - ``-`` - ``/`` - ``#`` From 40a368cf84e15a118e44464a46bc8c859f5a804b Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 13 Jan 2025 19:47:14 -0800 Subject: [PATCH 2/8] SDFileSource and continued splitting up of wirefree io --- mio/devices/device.py | 12 ++ mio/devices/wirefree.py | 123 +---------------- mio/models/pipeline.py | 34 ++++- mio/models/sdcard.py | 53 ++++---- mio/sources/file.py | 263 ++++++++++++++++++++++++++++++++++++- mio/transforms/__init__.py | 0 mio/transforms/frame.py | 44 +++++++ 7 files changed, 377 insertions(+), 152 deletions(-) create mode 100644 mio/transforms/__init__.py create mode 100644 mio/transforms/frame.py diff --git a/mio/devices/device.py b/mio/devices/device.py index 34324259..b74f9107 100644 --- a/mio/devices/device.py +++ b/mio/devices/device.py @@ -2,6 +2,7 @@ ABC for """ +import sys from abc import abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Union @@ -11,6 +12,11 @@ if TYPE_CHECKING: from mio.models.pipeline import Sink, Source, Transform +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self + class DeviceConfig(MiniscopeConfig): """ @@ -140,3 +146,9 @@ def transforms(self) -> dict[str, "Transform"]: def sinks(self) -> 
dict[str, "Sink"]: """Convenience method to access :attr:`.Pipeline.sinks`""" return self.pipeline.sinks + + @classmethod + def from_config(cls, config: DeviceConfig) -> Self: + """ + Instantiate a device from its (yaml) configuration. + """ diff --git a/mio/devices/wirefree.py b/mio/devices/wirefree.py index 253926ff..d167828a 100644 --- a/mio/devices/wirefree.py +++ b/mio/devices/wirefree.py @@ -16,6 +16,7 @@ from mio.exceptions import EndOfRecordingException, ReadHeaderException from mio.models.data import Frame from mio.models.sdcard import SDBufferHeader, SDConfig, SDLayout +from mio.models.pipeline import PipelineConfig from mio.types import ConfigSource, Resolution @@ -25,6 +26,12 @@ class WireFreeConfig(DeviceConfig): pass +class WireFreePipeline(PipelineConfig): + required_nodes = { + "sdcard": "sd-file-source", + } + + @dataclass(kw_only=True) class WireFreeMiniscope(Miniscope, RecordingCameraMixin): """ @@ -62,7 +69,6 @@ def __post_init__(self) -> None: # Private attributes used when the file reading context is entered self._config = None # type: Optional[SDConfig] - self._f = None # type: Optional[BinaryIO] self._frame = None # type: Optional[int] self._frame_count = None # type: Optional[int] self._array = None # type: Optional[np.ndarray] @@ -158,47 +164,6 @@ def frame(self, frame: int) -> None: for _ in range(frame - self.frame): self.skip() - @property - def frame_count(self) -> int: - """ - Total number of frames in recording. - - Inferred from :class:`~.sdcard.SDConfig.n_buffers_recorded` and - reading a single frame to get the number of buffers per frame. 
- """ - if self._frame_count is None: - if self._f is None: - with self as self_open: - frame = self_open.read(return_header=True) - headers = frame.headers - - else: - # If we're already open, great, just return to the last frame - last_frame = self.frame - # Go one frame back in case we are at the end of the data - self.frame = max(last_frame - 1, 0) - frame = self.read(return_header=True) - headers = frame.headers - self.frame = last_frame - - self._frame_count = int( - np.ceil( - (self.config.n_buffers_recorded + self.config.n_buffers_dropped) / len(headers) - ) - ) - - # if we have since read more frames than should be there, we update the - # frame count with a warning - max_pos = np.max(list(self.positions.keys())) - if max_pos > self._frame_count: - self.logger.warning( - "Got more frames than indicated in card header, expected " - f"{self._frame_count} but got {max_pos}" - ) - self._frame_count = int(max_pos) - - return self._frame_count - # -------------------------------------------------- # Context Manager methods # -------------------------------------------------- @@ -227,80 +192,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 self._f = None self._frame = 0 - # -------------------------------------------------- - # read methods - # -------------------------------------------------- - def _read_data_header(self, sd: BinaryIO) -> SDBufferHeader: - """ - Given an already open file buffer opened in bytes mode, - seeked to the start of a frame, read the data header - """ - - # read one word first, I think to get the size of the rest of the header, - # that sort of breaks the abstraction - # (it assumes the buffer length is always at position 0) - # but we'll roll with it for now - dataHeader = np.frombuffer(sd.read(self.layout.word_size), dtype=np.uint32) - dataHeader = np.append( - dataHeader, - np.frombuffer( - sd.read((dataHeader[self.layout.buffer.length] - 1) * self.layout.word_size), - dtype=np.uint32, - ), - ) - - # use construct 
because we're already sure these are ints from the numpy casting - # https://docs.pydantic.dev/latest/usage/models/#creating-models-without-validation - try: - header = SDBufferHeader.from_format(dataHeader, self.layout.buffer, construct=True) - except IndexError as e: - raise ReadHeaderException( - "Could not read header, expected header to have " - f"{len(self.layout.buffer.model_dump().keys())} fields, " - f"got {len(dataHeader)}. Likely mismatch between specified " - "and actual SD Card layout or reached end of data.\n" - f"Header Data: {dataHeader}" - ) from e - - return header - - def _n_frame_blocks(self, header: SDBufferHeader) -> int: - """ - Compute the number of blocks for a given frame buffer - - Not sure how this works! - """ - n_blocks = int( - ( - header.data_length - + (header.length * self.layout.word_size) - + (self.layout.sectors.size - 1) - ) - / self.layout.sectors.size - ) - return n_blocks - - def _read_size(self, header: SDBufferHeader) -> int: - """ - Compute the number of bytes to read for a given buffer - - Not sure how this works with :meth:`._n_frame_blocks`, but keeping - them separate in case they are separable actions for now - """ - n_blocks = self._n_frame_blocks(header) - read_size = (n_blocks * self.layout.sectors.size) - (header.length * self.layout.word_size) - return read_size - - def _read_buffer(self, sd: BinaryIO, header: SDBufferHeader) -> np.ndarray: - """ - Read a single buffer from a frame. - - Each frame has several buffers, so for a given frame we read them until we - get another that's zero! 
- """ - data = np.frombuffer(sd.read(self._read_size(header)), dtype=np.uint8) - return data - def _trim(self, data: np.ndarray, expected_size: int) -> np.ndarray: """ Trim or pad an array to match an expected size diff --git a/mio/models/pipeline.py b/mio/models/pipeline.py index 0152a1bf..a25dd6ca 100644 --- a/mio/models/pipeline.py +++ b/mio/models/pipeline.py @@ -4,9 +4,9 @@ import sys from abc import abstractmethod -from typing import ClassVar, Final, Generic, TypeVar, Union, final +from typing import ClassVar, Final, Generic, TypeVar, Union, final, TypedDict, Optional -from pydantic import Field +from pydantic import Field, model_validator from mio.exceptions import ConfigurationMismatchError from mio.models.models import MiniscopeConfig, PipelineModel @@ -42,6 +42,8 @@ class NodeConfig(MiniscopeConfig): """List of Node IDs to be used as input""" outputs: list[str] = Field(default_factory=list) """List of Node IDs to be used as output""" + config: dict + """Additional configuration for this node, parameterized by a TypedDict for the class""" class PipelineConfig(MiniscopeConfig): @@ -49,9 +51,23 @@ class PipelineConfig(MiniscopeConfig): Configuration for the nodes within a pipeline """ + required_nodes: ClassVar[Optional[dict[str, str]]] = None + """ + id: type mapping that a subclass can use to require a set of node types with specific IDs be present + """ + nodes: dict[str, NodeConfig] = Field(default_factory=dict) """The nodes that this pipeline configures""" + @model_validator(mode="after") + def validate_required_nodes(self) -> Self: + """Ensure required nodes are present, if any""" + if self.required_nodes is not None: + for id_, type_ in self.required_nodes.items(): + assert id_ in self.nodes, f"Node ID {id_} not in {self.nodes.keys()}" + assert self.nodes[id_].type_ == type_, f"Node ID {id_} is not of type {type_}" + return self + class Node(PipelineModel, Generic[T, U]): """A node within a processing pipeline""" @@ -63,24 +79,30 @@ class 
Node(PipelineModel, Generic[T, U]): id: str """Unique identifier of the node""" - config: NodeConfig + config: Optional[NodeConfig] = None input_type: ClassVar[type[T]] inputs: dict[str, Union["Source", "Transform"]] = Field(default_factory=dict) output_type: ClassVar[type[U]] outputs: dict[str, Union["Sink", "Transform"]] = Field(default_factory=dict) - @abstractmethod def start(self) -> None: """ - Start producing, processing, or receiving data + Start producing, processing, or receiving data. + + Default is a no-op. + Subclasses do not need to override if they have no initialization logic. """ + pass - @abstractmethod def stop(self) -> None: """ Stop producing, processing, or receiving data + + Default is a no-op. + Subclasses do not need to override if they have no deinit logic. """ + pass @classmethod def from_config(cls, config: NodeConfig) -> Self: diff --git a/mio/models/sdcard.py b/mio/models/sdcard.py index 76638c57..5e0f7f49 100644 --- a/mio/models/sdcard.py +++ b/mio/models/sdcard.py @@ -4,7 +4,9 @@ for consuming code to use a consistent, introspectable API """ -from typing import Optional +from typing import Optional, TYPE_CHECKING, Literal + +from pydantic import computed_field from mio.models import MiniscopeConfig from mio.models.buffer import BufferHeader, BufferHeaderFormat @@ -46,17 +48,20 @@ class SectorConfig(MiniscopeConfig): The size of an individual sector """ - def __getattr__(self, item: str) -> int: - """ - Get positions by multiplying by sector size - (__getattr__ is only called if the name can't be found, so we don't need to handle - the base case of the existing attributes) - """ - split = item.split("_") - if len(split) == 2 and split[1] == "pos": - return getattr(self, split[0]) * self.size - else: - raise AttributeError() + @property + def header_pos(self) -> int: + """header * sector size""" + return self.header * self.size + + @property + def config_pos(self) -> int: + """config * sector size""" + return self.config * self.size + 
+ @property + def data_pos(self) -> int: + """data * sector size""" + return self.data * self.size class ConfigPositions(MiniscopeConfig): @@ -94,7 +99,7 @@ class SDBufferHeaderFormat(BufferHeaderFormat): id: str = "sd-buffer-header" - length: int = 0 + length: Literal[0] = 0 linked_list: int = 1 frame_num: int = 2 buffer_count: int = 3 @@ -114,25 +119,23 @@ class SDLayout(MiniscopeConfig, ConfigYAMLMixin): Used by the :class:`.io.WireFreeMiniscope` class to tell it how data on the SD card is laid out. """ - sectors: SectorConfig - write_key0: int = 0x0D7CBA17 - write_key1: int = 0x0D7CBA17 - write_key2: int = 0x0D7CBA17 - write_key3: int = 0x0D7CBA17 - """ - These don't seem to actually be used in the existing reading/writing code, - but we will leave them here for continuity's sake :) - """ word_size: int = 4 """ - I'm actually not sure what this is, but 4 is hardcoded a few times in the - existing notebook and it appears to be used as a word size when - reading from the SD card. + Size of each header word in bytes """ + sectors: SectorConfig header: SDHeaderPositions = SDHeaderPositions() config: ConfigPositions = ConfigPositions() buffer: SDBufferHeaderFormat = SDBufferHeaderFormat() + header_dtype: str = "uint32" + """ + String form of the numpy dtype that the global and buffer headers are encoded in + """ + buffer_dtype: str = "uint8" + """ + String form of the numpy dtype that each frame buffer is encoded in + """ class SDConfig(MiniscopeConfig): diff --git a/mio/sources/file.py b/mio/sources/file.py index 0011e9ee..2727c7d9 100644 --- a/mio/sources/file.py +++ b/mio/sources/file.py @@ -2,7 +2,17 @@ File-based data sources """ -from mio.models.pipeline import Source +from io import BufferedReader, BytesIO +from pathlib import Path +from typing import BinaryIO, ClassVar, Optional + +from pydantic import Field +import numpy as np + +from mio.models.buffer import BufferHeader, BufferHeaderFormat +from mio.models.sdcard import SDLayout, SDConfig, SDBufferHeader 
+from mio.models.pipeline import Source, U +from mio.exceptions import EndOfRecordingException, ReadHeaderException class FileSource(Source): @@ -11,13 +21,47 @@ class FileSource(Source): """ -class BinaryLayout: - """Layout for binary files""" +class BinaryFileSource(FileSource): + """ + A FileSource that yields blocks of binary data + """ - pass + output_type: ClassVar[bytes] + path: Path + offset: int = 0 + """ + The offset position from the start of the file from which to consider the "zero point" + """ + block_size: int + """ + Number of bytes to read per processing loop + """ -class BinaryFileSource(FileSource): + _f: Optional[BinaryIO] = None + + def start(self) -> None: + """Open the file, seek to the offset""" + self._f = open(self.path, "rb") # noqa: SIM115 + self._f.seek(self.offset, 0) + + def stop(self): + """Close the file, remove the reference""" + self._f.close() + self._f = None + + def tell(self) -> int: + """Return the current position in the file""" + if self._f is None: + raise RuntimeError("File has not yet been opened with start") + return self._f.tell() + + def process(self) -> bytes: + """Return a block of data""" + return self._f.read(self.block_size) + + +class SDFileSource(FileSource): """ Structured binary file that has @@ -31,4 +75,213 @@ class BinaryFileSource(FileSource): * the ``config`` - getter and setter for the actual configuration values of the source * the ``layout`` - how the configuration and data are laid out within the file. + + """ + + _type = "sd-file-source" + output_type = tuple[SDBufferHeader, np.ndarray] + + path: Path + layout: SDLayout + + _f: Optional[BinaryIO] = None + _config: SDConfig = None + _positions: dict[int, int] = Field(default_factory=dict) + """ + A mapping between frame number and byte position in the video that makes for + faster seeking :) + + As we read, we store the locations of each frame before reading it. Later, + we can assign to `frame` to seek back to those positions. 
Assigning to `frame` + works without caching position, but has to manually iterate through each frame. """ + _last_buffer: int = None + _frame: int = 0 + + @property + def config(self) -> SDConfig: + """ + Global configuration of the whole SD card + """ + if self._config is None: + with open(self.path, "rb") as sd: + sd.seek(self.layout.sectors.config_pos, 0) + configSectorData = np.frombuffer( + sd.read(self.layout.sectors.size), dtype=np.dtype(self.layout.header_dtype) + ) + + self._config = SDConfig( + **{ + k: configSectorData[v] + for k, v in self.layout.config.model_dump().items() + if v is not None + } + ) + + return self._config + + @property + def offset(self) -> int: + """Start point of the data sector""" + return self.layout.sectors.data_pos + + @property + def header_size(self) -> int: + """ + Number of bytes in a buffer header + + .. note:: + + This isn't guaranteed to be accurate, see: + https://github.com/Aharoni-Lab/Miniscope-v4-Wire-Free/issues/64 + """ + return ( + max([v for v in self.layout.buffer.model_dump().values() if v is not None]) + 1 + ) * self.layout.word_size + + @property + def buffers_per_frame(self) -> int: + """ + Number of buffers per frame! 
+ + References: + https://github.com/Aharoni-Lab/Miniscope-v4-Wire-Free/blob/786663781a4bece89c89e00fc3ac9d95912faba4/Miniscope-v4-Wire-Free-MCU-Firmware/Miniscope-v4-Wire-Free/Miniscope-v4-Wire-Free/main.c#L680 + """ + n_pix = self.config.width * self.config.height + return int(np.ceil((n_pix + self.header_size) / (self.config.buffer_size))) + + @property + def frame_count(self) -> int: + """ + Total number of frames in the recording + """ + return int( + np.ceil( + (self.config.n_buffers_recorded + self.config.n_buffers_dropped) + / self.buffers_per_frame + ) + ) + + def start(self) -> None: + """Open the file, seek to the offset""" + self._last_buffer = 0 + + self._f = open(self.path, "rb") # noqa: SIM115 + self._f.seek(self.offset, 0) + + def stop(self): + """Close the file, remove the reference""" + self._f.close() + self._f = None + + def tell(self) -> int: + """Return the current position in the file""" + if self._f is None: + raise RuntimeError("File has not yet been opened with start") + return self._f.tell() + + def process(self) -> tuple[SDBufferHeader, np.ndarray]: + """ + Read a single data buffer, parsing its header and splitting it from the data + """ + start_position = self.tell() + + header = self._read_header(self._f) + self._last_buffer = header.buffer_count + self._frame = header.frame_num + + if header.frame_num not in self._positions: + self._positions[header.frame_num] = start_position + + buffer = self._read_buffer(self._f, header) + buffer = self._trim(buffer, header.data_length) + return header, buffer + + def _read_header(self, sd: BinaryIO) -> SDBufferHeader: + """ + Given an already open file buffer opened in bytes mode, + seeked to the start of a frame, read the data header + """ + # Get the length of the header from the first word + try: + dataHeader = np.frombuffer( + sd.read(self.layout.word_size), dtype=np.dtype(self.layout.header_dtype) + ) + except IndexError as e: + if "index 0 is out of bounds for axis 0 with size 0" in str(e): 
+ # end of file if we are reading from a disk image without any + # additional space on disk + raise EndOfRecordingException("Reached the end of the video!") from None + else: + raise e + + # Get the rest of the values in the header + try: + dataHeader = np.append( + dataHeader, + np.frombuffer( + sd.read(int(dataHeader[0]) * self.layout.word_size), + dtype=np.dtype(self.layout.header_dtype), + ), + ) + except ValueError as e: + if "read length must be non-negative" in str(e): + # end of file! Value error thrown because the dataHeader will be + # blank, and thus have a value of 0 for the header size, and we + # can't read 0 from the card. + raise EndOfRecordingException("Reached the end of the video!") from None + else: + raise e + + # use construct because we're already sure these are ints from the numpy casting + # https://docs.pydantic.dev/latest/usage/models/#creating-models-without-validation + try: + return SDBufferHeader.from_format(dataHeader, self.layout.buffer, construct=True) + except IndexError as e: + if ( + self._last_buffer + >= self.config.n_buffers_recorded + self.config.n_buffers_dropped - 1 + ): + raise EndOfRecordingException("Reached the end of the video!") from None + else: + raise ReadHeaderException( + "Could not read header, expected header to have " + f"{len(self.layout.buffer.model_dump().keys())} fields, " + f"got {len(dataHeader)}. 
Likely mismatch between specified " + "and actual SD Card layout or reached end of data.\n" + f"Header Data: {dataHeader}" + ) from e + + def _read_buffer(self, sd: BinaryIO, header: SDBufferHeader) -> np.ndarray: + return np.frombuffer(sd.read(self._data_read_size(header)), dtype=np.uint8) + + def _data_read_size(self, header: SDBufferHeader) -> int: + """ + After the header, how many bytes to read for the data in a buffer + """ + # blocks are quantized by sector size, so get min number of blocks that cover the data + n_blocks = np.ceil( + (header.data_length + (header.length * self.layout.word_size)) + / self.layout.sectors.size + ) + # expand back to n bytes + sector_size = n_blocks * self.layout.sectors.size + # subtract length of header + return int(sector_size - (header.length * self.layout.word_size)) + + def _trim(self, data: np.ndarray, expected_size: int) -> np.ndarray: + """ + Trim or pad an array to match an expected size. + + This should be the case most of the time - + number of bytes in a memory sector won't match bytes in a buffer + """ + if data.shape[0] != expected_size: + # trim if too long + if data.shape[0] > expected_size: + data = data[0:expected_size] + # pad if too short + else: + data = np.pad(data, (0, expected_size - data.shape[0])) + + return data diff --git a/mio/transforms/__init__.py b/mio/transforms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mio/transforms/frame.py b/mio/transforms/frame.py new file mode 100644 index 00000000..2a769109 --- /dev/null +++ b/mio/transforms/frame.py @@ -0,0 +1,44 @@ +from typing import ClassVar, TypedDict, Optional +from mio.models.pipeline import Transform, T, U +from mio.models.buffer import BufferHeader +from mio.models.data import Frame +import numpy as np + + +class MergeBuffersConfig(TypedDict): + width: int + height: int + + +class MergeBuffers(Transform): + """ + Merge sequential frame buffers into a single frame + """ + + type_ = "merge-buffers" + input_type = 
tuple[BufferHeader, np.ndarray] + output_type = Optional[Frame] + config: MergeBuffersConfig + + _headers: Optional[list[BufferHeader]] = None + _buffers: Optional[list[np.ndarray]] = None + _last_buffer_n: Optional[int] = None + + def start(self) -> None: + """Init private containers""" + self._headers = [] + self._buffers = [] + self._last_buffer_n = 0 + + def process(self, header: BufferHeader, buffer: np.ndarray) -> Optional[Frame]: + if header.frame_buffer_count == 0 and self._last_buffer_n >= 0: + frame = np.concat(self._buffers).reshape((self.config["width"], self.config["height"])) + headers = self._headers.copy() + self._headers = [] + self._buffers = [] + return Frame.model_construct(frame=frame, headers=headers) + else: + self._last_buffer_n = header.frame_buffer_count + self._headers.append(header) + self._buffers.append(buffer) + return None From 614b561dc01f6e6014a36697d820c01ee50133ae Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 16 Jan 2025 23:19:11 -0800 Subject: [PATCH 3/8] Split pipeline into pipeline runners, finish pieces of wirefree pipeline before actually working on instantiating, add target wirefree pipeline config --- .../config/wirefree/wirefree-pipeline.yaml | 28 ++++ mio/devices/wirefree.py | 6 +- mio/models/pipeline.py | 132 ++++++++++++------ mio/models/sdcard.py | 4 +- mio/pipeline/__init__.py | 3 + mio/pipeline/runners.py | 57 ++++++++ mio/sinks/__init__.py | 3 + mio/sinks/return.py | 41 ++++++ mio/sources/__init__.py | 3 + mio/sources/file.py | 42 ++++-- mio/transforms/__init__.py | 3 + mio/transforms/frame.py | 36 +++-- 12 files changed, 291 insertions(+), 67 deletions(-) create mode 100644 mio/data/config/wirefree/wirefree-pipeline.yaml create mode 100644 mio/pipeline/__init__.py create mode 100644 mio/pipeline/runners.py create mode 100644 mio/sinks/__init__.py create mode 100644 mio/sinks/return.py diff --git a/mio/data/config/wirefree/wirefree-pipeline.yaml b/mio/data/config/wirefree/wirefree-pipeline.yaml new 
file mode 100644 index 00000000..6e430b73 --- /dev/null +++ b/mio/data/config/wirefree/wirefree-pipeline.yaml @@ -0,0 +1,28 @@ +id: wirefree-pipeline +mio_model: mio.devices.wirefree.WireFreePipeline +mio_version: v0.6.0 + +nodes: + file: + type: "sd-file-source" + config: + layout: "wirefree-sd-layout" + passed: + - path + outputs: + - source: header + target: merge + - source: buffer + target: merge + merge: + type: "merge-buffers" + fill: + width: file.width + height: file.height + outputs: + - source: frame + target: data + return: + config: + key: frame + type: "return" diff --git a/mio/devices/wirefree.py b/mio/devices/wirefree.py index d167828a..a33b4126 100644 --- a/mio/devices/wirefree.py +++ b/mio/devices/wirefree.py @@ -5,7 +5,7 @@ import contextlib from dataclasses import dataclass, field from pathlib import Path -from typing import Any, BinaryIO, Literal, Optional, Union, overload +from typing import Any, Literal, Optional, Union, overload import cv2 import numpy as np @@ -15,8 +15,8 @@ from mio.devices import DeviceConfig, Miniscope, RecordingCameraMixin from mio.exceptions import EndOfRecordingException, ReadHeaderException from mio.models.data import Frame -from mio.models.sdcard import SDBufferHeader, SDConfig, SDLayout from mio.models.pipeline import PipelineConfig +from mio.models.sdcard import SDConfig, SDLayout from mio.types import ConfigSource, Resolution @@ -27,6 +27,8 @@ class WireFreeConfig(DeviceConfig): class WireFreePipeline(PipelineConfig): + """Base skeleton pipeline for the wirefree miniscope""" + required_nodes = { "sdcard": "sd-file-source", } diff --git a/mio/models/pipeline.py b/mio/models/pipeline.py index a25dd6ca..b134d156 100644 --- a/mio/models/pipeline.py +++ b/mio/models/pipeline.py @@ -4,9 +4,9 @@ import sys from abc import abstractmethod -from typing import ClassVar, Final, Generic, TypeVar, Union, final, TypedDict, Optional +from typing import ClassVar, Final, Generic, Optional, TypedDict, TypeVar, Union, final -from 
pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from mio.exceptions import ConfigurationMismatchError from mio.models.models import MiniscopeConfig, PipelineModel @@ -26,8 +26,24 @@ """ -class NodeConfig(MiniscopeConfig): - """Configuration for a single processing node""" +class _NodeMap(TypedDict): + source: str + target: str + + +class NodeConfig(TypedDict): + """ + Abstract parent TypedDict that each node inherits from to define + what fields it needs to be configured. + """ + + +class NodeSpecification(MiniscopeConfig): + """ + Specification for a single processing node within a pipeline .yaml file. + Distinct from a :class:`.NodeConfig`, which is a generic TypedDict that each + node defines to declare its parameterization. + """ type_: str = Field(..., alias="type") """ @@ -35,15 +51,55 @@ class NodeConfig(MiniscopeConfig): Subclasses should override this with a default. """ - id: str """The unique identifier of the node""" - inputs: list[str] = Field(default_factory=list) - """List of Node IDs to be used as input""" - outputs: list[str] = Field(default_factory=list) + outputs: Optional[list[_NodeMap]] = None """List of Node IDs to be used as output""" - config: dict + config: Optional[NodeConfig] = None """Additional configuration for this node, parameterized by a TypedDict for the class""" + passed: Optional[list[str]] = None + """ + List of config values that must be passed when the pipeline is instantiated + """ + fill: Optional[dict[str, str]] = None + """ + Values in the node config that should be dynamically filled from other nodes in the pipeline. + + Specified as {node_id}.{attribute}, these specify attributes and properties + on the instantiated node class, not the config values for that node. + + This is useful for accessing some properties that might not be known until runtime + like width and height of an input image. 
+ + Examples: + + For a node class `camera` that has property `frame_width`, + and node class `process` that has config value `width`, + we would fill the config value like this: + + .. code-block:: yaml + + nodes: + cam: + type: camera + proc: + type: process + fill: + width: cam.frame_width + + The Pipeline class will then do something like this on instantiation: + + .. code-block:: python + + pipeline = PipelineConfig(**the_above_values) + + cam = CameraNode(config=pipeline.nodes['cam'].config) + + proc_config = pipeline.nodes['proc'].config + proc_config['width'] = cam.frame_width + proc = ProcessingNode(config=proc_config) + + """ class PipelineConfig(MiniscopeConfig): @@ -53,10 +109,11 @@ class PipelineConfig(MiniscopeConfig): required_nodes: ClassVar[Optional[dict[str, str]]] = None """ - id: type mapping that a subclass can use to require a set of node types with specific IDs be present + id: type mapping that a subclass can use to require a set of node types + with specific IDs be present """ - nodes: dict[str, NodeConfig] = Field(default_factory=dict) + nodes: dict[str, NodeSpecification] = Field(default_factory=dict) """The nodes that this pipeline configures""" @model_validator(mode="after") @@ -68,11 +125,23 @@ def validate_required_nodes(self) -> Self: assert self.nodes[id_].type_ == type_, f"Node ID {id_} is not of type {type_}" return self + @field_validator("nodes", mode="before") + @classmethod + def fill_node_ids(cls, value: dict[str, dict]) -> dict[str, dict]: + """ + Roll down the `id` from the key in the `nodes` dictionary into the node config + """ + assert isinstance(value, dict) + for id, node in value.items(): + if "id" not in node: + node["id"] = id + return value + class Node(PipelineModel, Generic[T, U]): """A node within a processing pipeline""" - type_: ClassVar[str] + name: ClassVar[str] """ Shortname for this type of node to match configs to node types """ @@ -105,7 +174,7 @@ def stop(self) -> None: pass @classmethod - def 
from_config(cls, config: NodeConfig) -> Self: + def from_config(cls, config: NodeSpecification) -> Self: """ Create a node from its config """ @@ -115,18 +184,18 @@ def from_config(cls, config: NodeConfig) -> Self: @final def node_types(cls) -> dict[str, type["Node"]]: """ - Map of all imported :attr:`.Node.type_` names to node classes + Map of all imported :attr:`.Node.name` names to node classes """ node_types = {} to_check = cls.__subclasses__() while to_check: node = to_check.pop() - if node.type_ in node_types: + if node.name in node_types: raise ValueError( - f"Repeated node type_ identifier: {node.type_}, found in:\n" - f"- {node_types[node.type_]}\n- {node}" + f"Repeated node name identifier: {node.name}, found in:\n" + f"- {node_types[node.name]}\n- {node}" ) - node_types[node.type_] = node + node_types[node.name] = node to_check.extend(node.__subclasses__()) return node_types @@ -193,6 +262,10 @@ def process(self, data: T) -> U: class Pipeline(PipelineModel): """ A graph of nodes transforming some input source(s) to some output sink(s) + + The Pipeline model is a container for a set of nodes that are fully instantiated + (e.g. have their "passed" and "fill" keys processed) and connected. + It does not handle running the pipeline -- that is handled by a PipelineRunner. """ nodes: dict[str, Node] = Field(default_factory=dict) @@ -215,29 +288,6 @@ def sinks(self) -> dict[str, "Sink"]: """All :class:`.Sink` nodes in the processing graph""" return {k: v for k, v in self.nodes.items() if isinstance(v, Sink)} - @abstractmethod - def process(self) -> None: - """ - Process one step of data from each of the sources, - passing intermediate data to any subscribed nodes in a chain. - - The `process` method should not return anything except a to-be-implemented - result/status object, as any data intended to be received/processed by - downstream objects should be accessed via a :class:`.Sink` . 
- """ - - @abstractmethod - def start(self) -> None: - """ - Start processing data with the pipeline graph - """ - - @abstractmethod - def stop(self) -> None: - """ - Stop processing data with the pipeline graph - """ - @classmethod def from_config(cls, config: PipelineConfig) -> Self: """ diff --git a/mio/models/sdcard.py b/mio/models/sdcard.py index 5e0f7f49..d308163b 100644 --- a/mio/models/sdcard.py +++ b/mio/models/sdcard.py @@ -4,9 +4,7 @@ for consuming code to use a consistent, introspectable API """ -from typing import Optional, TYPE_CHECKING, Literal - -from pydantic import computed_field +from typing import Literal, Optional from mio.models import MiniscopeConfig from mio.models.buffer import BufferHeader, BufferHeaderFormat diff --git a/mio/pipeline/__init__.py b/mio/pipeline/__init__.py new file mode 100644 index 00000000..d2612c56 --- /dev/null +++ b/mio/pipeline/__init__.py @@ -0,0 +1,3 @@ +""" +Runtime classes for pipelines +""" diff --git a/mio/pipeline/runners.py b/mio/pipeline/runners.py new file mode 100644 index 00000000..fd5de4eb --- /dev/null +++ b/mio/pipeline/runners.py @@ -0,0 +1,57 @@ +""" +Pipeline runners for running pipelines +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional + +from mio.models import Pipeline + + +@dataclass +class PipelineRunner(ABC): + """ + Abstract parent class for pipeline runners. + + Pipeline runners handle calling the nodes and passing the + events returned by them to each other. Each runner may do so + however it needs to (synchronously, asynchronously, alone or as part of a cluster, etc.) + as long as it satisfies this abstract interface. + """ + + pipeline: Pipeline + + @abstractmethod + def process(self) -> Optional[dict[str, Any]]: + """ + Process one step of data from each of the sources, + passing intermediate data to any subscribed nodes in a chain. 
+ + The `process` method normally does not return anything, + except when using the special :class:`.ReturnSink` node - + if there are :class:`.ReturnSink` nodes in a :class:`.Pipeline` graph, + then each call to `process` will return a dictionary with one key + (from the :class:`.ReturnSink`'s `key` config value) and one value for each + :class:`.ReturnSink`. + """ + + @abstractmethod + def start(self) -> None: + """ + Start processing data with the pipeline graph + """ + + @abstractmethod + def stop(self) -> None: + """ + Stop processing data with the pipeline graph + """ + + +class SynchronousRunner(PipelineRunner): + """ + Simple, synchronous pipeline runner. + + Just run the nodes in topological order and return from return nodes. + """ diff --git a/mio/sinks/__init__.py b/mio/sinks/__init__.py new file mode 100644 index 00000000..b5ac9db1 --- /dev/null +++ b/mio/sinks/__init__.py @@ -0,0 +1,3 @@ +""" +Sink pipeline nodes that receive but do not emit events +""" diff --git a/mio/sinks/return.py b/mio/sinks/return.py new file mode 100644 index 00000000..4891553a --- /dev/null +++ b/mio/sinks/return.py @@ -0,0 +1,41 @@ +""" +Special Return sink that pipeline runners use to return values from :meth:`.PipelineRunner.process` +""" + +from typing import Any, TypedDict + +from mio.models.pipeline import Sink, T + + +class ReturnConfig(TypedDict): + """ + Config for return nodes + """ + + key: str + """The key to use in the returned dictionary""" + + +class Return(Sink): + """ + Special sink node that returns values from a pipeline runner's `process` method + """ + + name = "return" + input_type = Any + + config: ReturnConfig + + _value: Any = None + + def process(self, data: T) -> None: + """ + Store the incoming value to retrieve later with :meth:`.get` + """ + self._value = data + + def get(self) -> dict[str, T]: + """ + Get the stored value from the process call + """ + return {self.config["key"]: self._value} diff --git a/mio/sources/__init__.py 
b/mio/sources/__init__.py index e69de29b..a8191faa 100644 --- a/mio/sources/__init__.py +++ b/mio/sources/__init__.py @@ -0,0 +1,3 @@ +""" +Source pipeline nodes that emit but do not receive events +""" diff --git a/mio/sources/file.py b/mio/sources/file.py index 2727c7d9..c2636053 100644 --- a/mio/sources/file.py +++ b/mio/sources/file.py @@ -2,17 +2,15 @@ File-based data sources """ -from io import BufferedReader, BytesIO from pathlib import Path -from typing import BinaryIO, ClassVar, Optional +from typing import BinaryIO, ClassVar, Optional, TypedDict -from pydantic import Field import numpy as np +from pydantic import Field -from mio.models.buffer import BufferHeader, BufferHeaderFormat -from mio.models.sdcard import SDLayout, SDConfig, SDBufferHeader -from mio.models.pipeline import Source, U from mio.exceptions import EndOfRecordingException, ReadHeaderException +from mio.models.pipeline import Source +from mio.models.sdcard import SDBufferHeader, SDConfig, SDLayout class FileSource(Source): @@ -20,12 +18,15 @@ class FileSource(Source): Generic parent class for file sources """ + name = "file-source" + class BinaryFileSource(FileSource): """ A FileSource that yields blocks of binary data """ + name = "binary-file-source" output_type: ClassVar[bytes] path: Path @@ -45,7 +46,7 @@ def start(self) -> None: self._f = open(self.path, "rb") # noqa: SIM115 self._f.seek(self.offset, 0) - def stop(self): + def stop(self) -> None: """Close the file, remove the reference""" self._f.close() self._f = None @@ -61,6 +62,13 @@ def process(self) -> bytes: return self._f.read(self.block_size) +class SDFileSourceOutput(TypedDict): + """Output types returned by :meth:`.SDFileSource.process`""" + + header: SDBufferHeader + buffer: np.ndarray + + class SDFileSource(FileSource): """ Structured binary file that has @@ -78,8 +86,8 @@ class SDFileSource(FileSource): """ - _type = "sd-file-source" - output_type = tuple[SDBufferHeader, np.ndarray] + name = "sd-file-source" + 
output_type = SDFileSourceOutput path: Path layout: SDLayout @@ -120,6 +128,16 @@ def config(self) -> SDConfig: return self._config + @property + def width(self) -> int: + """width of the captured video in pixels""" + return self.config.width + + @property + def height(self) -> int: + """height of the captured video in pixels""" + return self.config.height + @property def offset(self) -> int: """Start point of the data sector""" @@ -169,7 +187,7 @@ def start(self) -> None: self._f = open(self.path, "rb") # noqa: SIM115 self._f.seek(self.offset, 0) - def stop(self): + def stop(self) -> None: """Close the file, remove the reference""" self._f.close() self._f = None @@ -180,7 +198,7 @@ def tell(self) -> int: raise RuntimeError("File has not yet been opened with start") return self._f.tell() - def process(self) -> tuple[SDBufferHeader, np.ndarray]: + def process(self) -> SDFileSourceOutput: """ Read a single data buffer, parsing its header and splitting it from the data """ @@ -195,7 +213,7 @@ def process(self) -> tuple[SDBufferHeader, np.ndarray]: buffer = self._read_buffer(self._f, header) buffer = self._trim(buffer, header.data_length) - return header, buffer + return {"header": header, "buffer": buffer} def _read_header(self, sd: BinaryIO) -> SDBufferHeader: """ diff --git a/mio/transforms/__init__.py b/mio/transforms/__init__.py index e69de29b..3a1cc155 100644 --- a/mio/transforms/__init__.py +++ b/mio/transforms/__init__.py @@ -0,0 +1,3 @@ +""" +Transform pipeline nodes that both receive and emit events +""" diff --git a/mio/transforms/frame.py b/mio/transforms/frame.py index 2a769109..ffaec9ba 100644 --- a/mio/transforms/frame.py +++ b/mio/transforms/frame.py @@ -1,23 +1,37 @@ -from typing import ClassVar, TypedDict, Optional -from mio.models.pipeline import Transform, T, U +""" +Nodes that receive and emit frames +""" + +from typing import Optional, TypedDict + +import numpy as np + from mio.models.buffer import BufferHeader from mio.models.data import Frame 
-import numpy as np +from mio.models.pipeline import Transform class MergeBuffersConfig(TypedDict): + """Configuration for :class:`.MergeBuffers`""" + width: int height: int +class MergeBuffersOutput(TypedDict): + """Output returned by :meth:`.MergeBuffers.process`""" + + frame: Frame + + class MergeBuffers(Transform): """ Merge sequential frame buffers into a single frame """ - type_ = "merge-buffers" + name = "merge-buffers" input_type = tuple[BufferHeader, np.ndarray] - output_type = Optional[Frame] + output_type = MergeBuffersOutput config: MergeBuffersConfig _headers: Optional[list[BufferHeader]] = None @@ -30,13 +44,17 @@ def start(self) -> None: self._buffers = [] self._last_buffer_n = 0 - def process(self, header: BufferHeader, buffer: np.ndarray) -> Optional[Frame]: + def process(self, header: BufferHeader, buffer: np.ndarray) -> Optional[MergeBuffersOutput]: + """ + Receive a header/buffer pair. If the frame buffer count has cycled back to zero, + merge into a completed frame + """ if header.frame_buffer_count == 0 and self._last_buffer_n >= 0: frame = np.concat(self._buffers).reshape((self.config["width"], self.config["height"])) headers = self._headers.copy() - self._headers = [] - self._buffers = [] - return Frame.model_construct(frame=frame, headers=headers) + self._headers = [header] + self._buffers = [buffer] + return {"frame": Frame.model_construct(frame=frame, headers=headers)} else: self._last_buffer_n = header.frame_buffer_count self._headers.append(header) From 643a7f0ec0b76b1073f6b325a726fbe237775399 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 17 Jan 2025 01:19:28 -0800 Subject: [PATCH 4/8] Split pipeline into pipeline runners, finish pieces of wirefree pipeline before actually working on instantiating, add target wirefree pipeline config --- .../config/wirefree/wirefree-pipeline.yaml | 2 +- mio/models/pipeline.py | 191 +++++++++++++++++- 2 files changed, 183 insertions(+), 10 deletions(-) diff --git 
a/mio/data/config/wirefree/wirefree-pipeline.yaml b/mio/data/config/wirefree/wirefree-pipeline.yaml index 6e430b73..852cc738 100644 --- a/mio/data/config/wirefree/wirefree-pipeline.yaml +++ b/mio/data/config/wirefree/wirefree-pipeline.yaml @@ -8,7 +8,7 @@ nodes: config: layout: "wirefree-sd-layout" passed: - - path + path: sd_path outputs: - source: header target: merge diff --git a/mio/models/pipeline.py b/mio/models/pipeline.py index b134d156..81372daa 100644 --- a/mio/models/pipeline.py +++ b/mio/models/pipeline.py @@ -4,7 +4,8 @@ import sys from abc import abstractmethod -from typing import ClassVar, Final, Generic, Optional, TypedDict, TypeVar, Union, final +from graphlib import TopologicalSorter +from typing import Any, ClassVar, Final, Generic, Optional, TypedDict, TypeVar, Union, final from pydantic import Field, field_validator, model_validator @@ -57,9 +58,36 @@ class NodeSpecification(MiniscopeConfig): """List of Node IDs to be used as output""" config: Optional[NodeConfig] = None """Additional configuration for this node, parameterized by a TypedDict for the class""" - passed: Optional[list[str]] = None + passed: Optional[dict[str, str]] = None """ - List of config values that must be passed when the pipeline is instantiated + Mapping of config values that must be passed when the pipeline is instantiated. + + Keys are the key in the config dictionary to be filled by passing, and values are a key that + those values should be passed as. + + Examples: + + For a node with config field `height` , one can specify that it must be passed + on instantiation like this: + + .. code-block:: yaml + + nodes: + node1: + type: a_node + passed: + height: height_1 + node2: + type: a_node + passed: + height: height_2 + + The pipeline should then be instantiated like: + + .. 
code-block:: python + + Pipeline.from_config(above_config, passed={'height_1': 1, 'height_2': 2}) + """ fill: Optional[dict[str, str]] = None """ @@ -137,6 +165,58 @@ def fill_node_ids(cls, value: dict[str, dict]) -> dict[str, dict]: node["id"] = id return value + # TODO: Implement these validators + # @field_validator("nodes", mode="after") + # @classmethod + # def valid_passed_and_fill_keys( + # cls, value: dict[str, NodeSpecification] + # ) -> dict[str, NodeSpecification]: + # """ + # Passed and fill keys refer to values within the node's config type + # """ + # + # @field_validator("nodes", mode="after") + # @classmethod + # def unique_passed_values( + # cls, value: dict[str, NodeSpecification] + # ) -> dict[str, NodeSpecification]: + # """ + # All passed values ( + # """ + # + # @field_validator("nodes", mode="after") + # @classmethod + # def fill_sources_present( + # cls, value: dict[str, NodeSpecification] + # # ) -> dict[str, NodeSpecification]: + # """ + # Fill values refer to nodes that are present in the node graph + # """ + # + # @field_validator("nodes", mode="after") + # @classmethod + # def fill_values_dotted( + # cls, value: dict[str, NodeSpecification] + # ) -> dict[str, NodeSpecification]: + # """ + # Fill values refer to a property or attribute of a node (i.e. 
have at least one dot) + # """ + + def graph(self) -> TopologicalSorter: + """ + For the node specifications in :attr:`.PipelineConfig.nodes`, + produce a :class:`.TopologicalSorter` that accounts for the dependencies between nodes + induced by :attr:`.NodeSpecification.fill` + """ + sorter = TopologicalSorter() + for node_id, node in self.nodes.items(): + if node.fill is None: + sorter.add(node_id) + else: + dependents = {v.split(".")[0] for v in node.fill} + sorter.add(node_id, *dependents) + return sorter + class Node(PipelineModel, Generic[T, U]): """A node within a processing pipeline""" @@ -174,11 +254,11 @@ def stop(self) -> None: pass @classmethod - def from_config(cls, config: NodeSpecification) -> Self: + def from_specification(cls, config: NodeSpecification) -> Self: """ Create a node from its config """ - return cls(id=config.id, config=config) + return cls(id=config.id, config=config.config) @classmethod @final @@ -289,16 +369,109 @@ def sinks(self) -> dict[str, "Sink"]: return {k: v for k, v in self.nodes.items() if isinstance(v, Sink)} @classmethod - def from_config(cls, config: PipelineConfig) -> Self: + def from_config(cls, config: PipelineConfig, passed: Optional[dict[str, Any]] = None) -> Self: """ Instantiate a pipeline model from its configuration + + Args: + config (PipelineConfig): the pipeline config to instantiate + passed (dict[str, Any]): If any nodes in the config declare ``passed`` keys, the values to fill them with, keyed by the names given in :attr:`.NodeSpecification.passed` """ - types = Node.node_types() + cls._validate_passed(config, passed) + + nodes = cls._init_nodes(config, passed) - nodes = {k: types[v.type_].from_config(v) for k, v in config.nodes.items()} - nodes = connect_nodes(nodes) return cls(nodes=nodes) + @classmethod + def passed_values(cls, config: PipelineConfig) -> dict[str, type]: + """ + Dictionary containing the keys that must be passed as specified by the `passed` field + of a node specification and their types.
+ + Args: + config (:class:`.PipelineConfig`): Pipeline configuration to get passed values for + """ + types = Node.node_types() + passed = {} + for node_id, node in config.nodes.items(): + if not node.passed: + continue + + for cfg_key, pass_key in node.passed.items(): + # get type of config key that needs to be passed + config_type = types[node.type_].model_fields["config"].annotation + try: + passed[pass_key] = config_type.__annotations__[cfg_key] + except KeyError as e: + raise ConfigurationMismatchError( + f"Node {node_id} requested {cfg_key} be passed as {pass_key}, " + f"but node type {node.type_}'s config has no field {cfg_key}! " + f"Possible keys: {list(config_type.__annotations__.keys())}" + ) from e + + return passed + + @classmethod + def _init_nodes( + cls, config: PipelineConfig, passed: Optional[dict[str, Any]] = None + ) -> dict[str, Node]: + """ + Initialize nodes, filling in any computed values from `fill` and `passed` + """ + if passed is None: + passed = {} + + types = Node.node_types() + graph = config.graph() + graph.prepare() + + nodes = {} + while graph.is_active(): + for node in graph.get_ready(): + complete_cfg = cls._complete_node(config.nodes[node], nodes, passed) + nodes[node] = types[complete_cfg.type_].from_specification(complete_cfg) + graph.done(node) + + return nodes + + @classmethod + def _complete_node( + cls, node: NodeSpecification, context: dict[str, Node], passed: dict + ) -> NodeSpecification: + """ + Given the context of already-instantiated nodes and passed values, + complete the configuration. 
+ """ + if node.passed: + for cfg_key, passed_key in node.passed.items(): + node.config[cfg_key] = passed[passed_key] + if node.fill: + for cfg_key, fill_key in node.fill.items(): + parts = fill_key.split(".") + val = context[parts[0]] + for part in parts[1:]: + val = getattr(val, part) + node.config[cfg_key] = val + return node + + @classmethod + def _validate_passed(cls, config: PipelineConfig, passed: dict[str, Any]) -> None: + """ + Ensure that the passed values required by the pipeline config are in fact passed + + Raise ConfigurationMismatchError if missing keys, otherwise do nothing + """ + required = cls.passed_values(config) + for key in required: + if key not in passed: + raise ConfigurationMismatchError( + f"Pipeline config requires these values to be passed:\n" + f"{required}\n" + f"But received passed values:\n" + f"{passed}" + ) + def connect_nodes(nodes: dict[str, Node]) -> dict[str, Node]: """ From 209aaed78d2867b6fabf5a9c13b179584933c036 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 24 Jan 2025 00:22:37 -0800 Subject: [PATCH 5/8] initial draft of synchronous pipeline runner --- .../config/wirefree/wirefree-pipeline.yaml | 6 +- mio/models/pipeline.py | 142 ++++++++--- mio/pipeline/runner.py | 241 ++++++++++++++++++ mio/pipeline/runners.py | 57 ----- mio/sinks/__init__.py | 4 + mio/sinks/{return.py => return_.py} | 19 +- 6 files changed, 373 insertions(+), 96 deletions(-) create mode 100644 mio/pipeline/runner.py delete mode 100644 mio/pipeline/runners.py rename mio/sinks/{return.py => return_.py} (58%) diff --git a/mio/data/config/wirefree/wirefree-pipeline.yaml b/mio/data/config/wirefree/wirefree-pipeline.yaml index 852cc738..9cf5a335 100644 --- a/mio/data/config/wirefree/wirefree-pipeline.yaml +++ b/mio/data/config/wirefree/wirefree-pipeline.yaml @@ -11,9 +11,9 @@ nodes: path: sd_path outputs: - source: header - target: merge + target: merge.header - source: buffer - target: merge + target: merge.buffer merge: type: "merge-buffers" 
fill: @@ -21,7 +21,7 @@ nodes: height: file.height outputs: - source: frame - target: data + target: return return: config: key: frame diff --git a/mio/models/pipeline.py b/mio/models/pipeline.py index 81372daa..b1331260 100644 --- a/mio/models/pipeline.py +++ b/mio/models/pipeline.py @@ -4,8 +4,9 @@ import sys from abc import abstractmethod +from datetime import datetime from graphlib import TopologicalSorter -from typing import Any, ClassVar, Final, Generic, Optional, TypedDict, TypeVar, Union, final +from typing import Any, ClassVar, Generic, Optional, TypedDict, TypeVar, Union, Unpack, final from pydantic import Field, field_validator, model_validator @@ -21,12 +22,29 @@ """ Input Type typevar """ -U = TypeVar("U") +U = TypeVar("U", bound=dict[str, Any]) """ Output Type typevar """ +class Event(TypedDict, Generic[U]): + """ + Container for a single value returned from a single :meth:`.Node.process` call + """ + + id: int + """Unique ID for each event""" + timestamp: datetime + """Timestamp of when the event was received by the :class:`.PipelineRunner`""" + node_id: str + """ID of node that emitted the event""" + slot: str + """name of the slot that emitted the event""" + value: Any + """Value emitted by the processing node""" + + class _NodeMap(TypedDict): source: str target: str @@ -231,9 +249,7 @@ class Node(PipelineModel, Generic[T, U]): config: Optional[NodeConfig] = None input_type: ClassVar[type[T]] - inputs: dict[str, Union["Source", "Transform"]] = Field(default_factory=dict) output_type: ClassVar[type[U]] - outputs: dict[str, Union["Sink", "Transform"]] = Field(default_factory=dict) def start(self) -> None: """ @@ -253,6 +269,11 @@ def stop(self) -> None: """ pass + @abstractmethod + def process(self, **kwargs: Unpack[T]) -> Optional[U]: + """Process some input, emitting it.
See subclasses for details""" + pass + @classmethod def from_specification(cls, config: NodeSpecification) -> Self: """ @@ -283,7 +304,6 @@ def node_types(cls) -> dict[str, type["Node"]]: class Source(Node, Generic[T, U]): """A source of data in a processing pipeline""" - inputs: Final[None] = None input_type: ClassVar[None] = None @abstractmethod @@ -305,10 +325,9 @@ class Sink(Node, Generic[T, U]): """A sink of data in a processing pipeline""" output_type: ClassVar[None] = None - outputs: Final[None] = None @abstractmethod - def process(self, data: T) -> None: + def process(self, **kwargs: Unpack[T]) -> None: """ Process some incoming data, returning None @@ -326,7 +345,7 @@ class Transform(Node, Generic[T, U]): """ @abstractmethod - def process(self, data: T) -> U: + def process(self, **kwargs: Unpack[T]) -> U: """ Process some incoming data, yielding a transformed output @@ -339,6 +358,17 @@ def process(self, data: T) -> U: """ +class Edge(PipelineModel): + """ + Directed connection between an output slot of a node and an input slot in another node + """ + + source_node: Node + source_slot: Optional[str] = None + target_node: Node + target_slot: Optional[str] = None + + class Pipeline(PipelineModel): """ A graph of nodes transforming some input source(s) to some output sink(s) @@ -352,6 +382,14 @@ class Pipeline(PipelineModel): """ Dictionary mapping all nodes from their ID to the instantiated node. """ + edges: list[Edge] = Field(default_factory=list) + """ + Edges connecting slots within nodes. + + The nodes within :attr:`.Edge.source_node` and :attr:`.Edge.target_node` must + be the same objects as those in :attr:`.Pipeline.nodes` + (i.e. ``edges[0].source_node is nodes[node_id]`` ).
+ """ @property def sources(self) -> dict[str, "Source"]: @@ -368,6 +406,45 @@ def sinks(self) -> dict[str, "Sink"]: """All :class:`.Sink` nodes in the processing graph""" return {k: v for k, v in self.nodes.items() if isinstance(v, Sink)} + def graph(self) -> TopologicalSorter: + """ + Produce a :class:`.TopologicalSorter` based on the graph induced by + :attr:`.Pipeline.nodes` and :attr:`.Pipeline.edges` that yields node ids + """ + sorter = TopologicalSorter() + for node_id, node in self.nodes.items(): + in_edges = [e.target_node.id for e in self.edges if e.target_node is node] + sorter.add(node_id, *in_edges) + return sorter + + def in_edges(self, node: Union[Node, str]) -> list[Edge]: + """ + Edges going towards the given node (i.e. the node is the edge's ``target`` ) + + Args: + node (:class:`.Node`, str): Either a node or its id + + Returns: + list[:class:`.Edge`] + """ + if isinstance(node, Node): + node = node.id + return [e for e in self.edges if e.target_node.id == node] + + def out_edges(self, node: Union[Node, str]) -> list[Edge]: + """ + Edges going away from the given node (i.e. 
the node is the edge's ``source`` ) + + Args: + node (:class:`.Node`, str): Either a node or its id + + Returns: + list[:class:`.Edge`] + """ + if isinstance(node, Node): + node = node.id + return [e for e in self.edges if e.source_node.id == node] + @classmethod def from_config(cls, config: PipelineConfig, passed: Optional[dict[str, Any]] = None) -> Self: """ @@ -380,8 +457,9 @@ def from_config(cls, config: PipelineConfig, passed: Optional[dict[str, Any]] = cls._validate_passed(config, passed) nodes = cls._init_nodes(config, passed) + edges = cls._init_edges(nodes, config.nodes) - return cls(nodes=nodes) + return cls(nodes=nodes, edges=edges) @classmethod def passed_values(cls, config: PipelineConfig) -> dict[str, type]: @@ -435,6 +513,30 @@ def _init_nodes( return nodes + @classmethod + def _init_edges(cls, nodes: dict[str, Node], spec: dict[str, NodeSpecification]) -> list[Edge]: + edges = [] + for node_id, node_spec in spec.items(): + if not node_spec.outputs: + continue + for output in node_spec.outputs: + # FIXME: Ugly and not DRY + target_parts = output["target"].split(".") + target_id, target_slot = ( + (target_parts[0], target_parts[1]) + if len(target_parts) == 2 + else (target_parts[0], None) + ) + edges.append( + Edge( + source_node=nodes[node_id], + target_node=nodes[target_id], + source_slot=output["source"], + target_slot=target_slot, + ) + ) + return edges + @classmethod def _complete_node( cls, node: NodeSpecification, context: dict[str, Node], passed: dict @@ -471,25 +573,3 @@ def _validate_passed(cls, config: PipelineConfig, passed: dict[str, Any]) -> Non f"But received passed values:\n" f"{passed}" ) - - -def connect_nodes(nodes: dict[str, Node]) -> dict[str, Node]: - """ - Provide references to instantiated nodes - """ - - for node in nodes.values(): - if node.config.inputs and node.inputs is None: - raise ConfigurationMismatchError( - "inputs found in node configuration, but node type allows no inputs!\n" - f"node: {node.model_dump()}" - ) - 
if node.config.outputs and not hasattr(node, "outputs"): - raise ConfigurationMismatchError( - "outputs found in node configuration, but node type allows no outputs!\n" - f"node: {node.model_dump()}" - ) - - node.inputs.update({id: nodes[id] for id in node.config.inputs}) - node.outputs.update({id: nodes[id] for id in node.config.outputs}) - return nodes diff --git a/mio/pipeline/runner.py b/mio/pipeline/runner.py new file mode 100644 index 00000000..365e07a1 --- /dev/null +++ b/mio/pipeline/runner.py @@ -0,0 +1,241 @@ +""" +Pipeline runners for running pipelines +""" + +from abc import ABC, abstractmethod +from collections.abc import Generator, MutableSequence +from contextlib import contextmanager +from dataclasses import dataclass, field +from datetime import datetime +from itertools import count +from logging import Logger +from typing import TYPE_CHECKING, Any, Optional, Self + +from mio import init_logger +from mio.models import Pipeline +from mio.models.pipeline import Edge, Event, Node, Source + +if TYPE_CHECKING: + from mio.sinks import Return + + +@dataclass +class EventStore: + """ + Container class for storing and retrieving events by node and slot + """ + + events: MutableSequence = field(default_factory=list) + counter: count = field(default_factory=count) + + def add(self, values: dict[str, Any], node_id: str) -> None: + """ + Add the result of a :meth:`.Node.process` call to the event store. 
+ + Split the dictionary of values into separate :class:`.Event` s, + store along with current timestamp + + Args: + values (dict): Dict emitted by a :meth:`.Node.process` call + node_id (str): ID of the node that emitted the events + """ + if values is None: + return + timestamp = datetime.now() + for slot, value in values.items(): + self.events.append( + Event( + id=next(self.counter), + timestamp=timestamp, + node_id=node_id, + slot=slot, + value=value, + ) + ) + + def get(self, node_id: str, slot: str) -> Optional[Event]: + """ + Get the event with the matching node_id and slot name + + Returns the most recent matching event, as for now we assume that + each combination of `node_id` and `slot` is emitted only once per processing cycle, + and we assume processing cycles are independent (and thus our events are cleared) + + ``None`` in the case that the event has not been emitted + """ + event = [e for e in self.events if e["node_id"] == node_id and e["slot"] == slot] + return None if len(event) == 0 else event[-1] + + def gather(self, edges: list[Edge]) -> Optional[dict]: + """ + Gather events into a form that can be consumed by a :meth:`.Node.process` method, + given the collection of inbound edges (usually from :meth:`.Pipeline.in_edges` ). + + If none of the requested events have been emitted, return ``None``. + + If all of the requested events have been emitted, return a kwarg-like dict + + If some of the requested events are missing but others are present, + return ``None`` for any missing events. + + .. todo:: + + Add an example + + """ + ret = {} + for edge in edges: + event = self.get(edge.source_node.id, edge.source_slot) + value = None if event is None else event["value"] + ret[edge.target_slot] = value + + return None if not ret or all(val is None for val in ret.values()) else ret + + def clear(self) -> None: + """ + Clear events for this round of processing. 
+ + Does not reset the counter (to continue giving unique ids to the next round's events) + """ + self.events = [] + + +@dataclass +class PipelineRunner(ABC): + """ + Abstract parent class for pipeline runners. + + Pipeline runners handle calling the nodes and passing the + events returned by them to each other. Each runner may do so + however it needs to (synchronously, asynchronously, alone or as part of a cluster, etc.) + as long as it satisfies this abstract interface. + """ + + pipeline: Pipeline + store: EventStore = field(default_factory=EventStore) + + _logger: Logger = field(default_factory=lambda: init_logger("pipeline.runner")) + + @abstractmethod + def process(self) -> Optional[dict[str, Any]]: + """ + Process one step of data from each of the sources, + passing intermediate data to any subscribed nodes in a chain. + + The `process` method normally does not return anything, + except when using the special :class:`.Return` node - + if there are :class:`.Return` nodes in a :class:`.Pipeline` graph, + then each call to `process` will return a dictionary with one key + (from the :class:`.Return`'s `key` config value) and one value for each + :class:`.Return`.
+ """ + + @abstractmethod + def start(self) -> None: + """ + Start processing data with the pipeline graph + """ + + @abstractmethod + def stop(self) -> None: + """ + Stop processing data with the pipeline graph + """ + + def gather_input(self, node: Node) -> Optional[dict[str, Any]]: + """ + Gather input to give to the passed Node from the :attr:`.PipelineRunner.store` + + Returns: + dict: kwargs to pass to :meth:`.Node.process` if matching events are present + dict: empty dict if Node is a :class:`.Source` + None: if no input is available + """ + if isinstance(node, Source): + return {} + + edges = self.pipeline.in_edges(node) + return self.store.gather(edges) + + def gather_return(self) -> Optional[dict]: + """ + If any :class:`.Return` nodes are in the pipeline, + gather their return values to return from :meth:`.PipelineRunner.process` + + Returns: + dict: of the Return sink's key mapped to the returned value, + None: if there are no :class:`.Return` sinks in the pipeline + """ + ret = {} + for sink in self.pipeline.sinks.values(): + if sink.name != "return": + continue + sink: Return + val = sink.get(keep=False) + ret.update(val) + + if not ret: + return None + else: + return ret + + +class SynchronousRunner(PipelineRunner): + """ + Simple, synchronous pipeline runner. + + Just run the nodes in topological order and return from return nodes. + """ + + @contextmanager + def start(self) -> Generator[Self, None, None]: + """ + Start processing data with the pipeline graph. + + Returns a contextmanager that should be used like this: + + .. code-block:: python + + with sync_runner.start() as runner: + output = runner.process() + # do something... 
+ + """ + # TODO: lock for re-entry + try: + for node in self.pipeline.nodes.values(): + node.start() + yield self + finally: + self.stop() + + def stop(self) -> None: + """Stop all nodes processing""" + # TODO: lock to ensure we've been started + for node in self.pipeline.nodes.values(): + node.stop() + + def process(self) -> Optional[dict[str, Any]]: + """ + Iterate through nodes in topological order, + calling their process method and passing events as they are emitted. + """ + self.store.clear() + + graph = self.pipeline.graph() + graph.prepare() + + while graph.is_active(): + for node_id in graph.get_ready(): + node = self.pipeline.nodes[node_id] + node_input = self.gather_input(node) + if node_input is None: + graph.done(node_id) + self._logger.debug(f"Node {node_id} received no input, skipping") + continue + value = node.process(**node_input) + self.store.add(value, node_id) + graph.done(node_id) + self._logger.debug(f"Node {node_id} emitted %s", value) + + return self.gather_return() diff --git a/mio/pipeline/runners.py b/mio/pipeline/runners.py deleted file mode 100644 index fd5de4eb..00000000 --- a/mio/pipeline/runners.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Pipeline runners for running pipelines -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Optional - -from mio.models import Pipeline - - -@dataclass -class PipelineRunner(ABC): - """ - Abstract parent class for pipeline runners. - - Pipeline runners handle calling the nodes and passing the - events returned by them to each other. Each runner may do so - however it needs to (synchronously, asynchronously, alone or as part of a cluster, etc.) - as long as it satisfies this abstract interface. - """ - - pipeline: Pipeline - - @abstractmethod - def process(self) -> Optional[dict[str, Any]]: - """ - Process one step of data from each of the sources, - passing intermediate data to any subscribed nodes in a chain. 
- - The `process` method normally does not return anything, - except when using the special :class:`.ReturnSink` node - - if there are :class:`.ReturnSink` nodes in a :class:`.Pipeline` graph, - then each call to `process` will return a dictionary with one key - (from the :class:`.ReturnSink`'s `key` config value) and one value for each - :class:`.ReturnSink`. - """ - - @abstractmethod - def start(self) -> None: - """ - Start processing data with the pipeline graph - """ - - @abstractmethod - def stop(self) -> None: - """ - Stop processing data with the pipeline graph - """ - - -class SynchronousRunner(PipelineRunner): - """ - Simple, synchronous pipeline runner. - - Just run the nodes in topological order and return from return nodes. - """ diff --git a/mio/sinks/__init__.py b/mio/sinks/__init__.py index b5ac9db1..f8d4cd7c 100644 --- a/mio/sinks/__init__.py +++ b/mio/sinks/__init__.py @@ -1,3 +1,7 @@ """ Sink pipeline nodes that receive but do not emit events """ + +from mio.sinks.return_ import Return + +__all__ = ["Return"] diff --git a/mio/sinks/return.py b/mio/sinks/return_.py similarity index 58% rename from mio/sinks/return.py rename to mio/sinks/return_.py index 4891553a..7c623ea7 100644 --- a/mio/sinks/return.py +++ b/mio/sinks/return_.py @@ -2,7 +2,7 @@ Special Return sink that pipeline runners use to return values from :meth:`.PipelineRunner.process` """ -from typing import Any, TypedDict +from typing import Any, Optional, TypedDict from mio.models.pipeline import Sink, T @@ -28,14 +28,23 @@ class Return(Sink): _value: Any = None - def process(self, data: T) -> None: + def process(self, value: T) -> None: """ Store the incoming value to retrieve later with :meth:`.get` """ - self._value = data + self._value = value - def get(self) -> dict[str, T]: + def get(self, keep: bool = False) -> Optional[dict[str, T]]: """ Get the stored value from the process call + + Args: + keep (bool): If ``True``, keep the stored value, otherwise clear it, consume it """ - 
return {self.config["key"]: self._value} + if self._value is None: + return None + else: + val = {self.config["key"]: self._value} + if not keep: + self._value = None + return val From 24a1a5214151764a4a7ce61da9fa3186e3aab8bf Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 24 Jan 2025 17:33:37 -0800 Subject: [PATCH 6/8] TypedDict from typing_extensions --- mio/models/pipeline.py | 10 +++++++--- mio/sinks/return_.py | 8 +++++++- mio/sources/file.py | 9 +++++++-- mio/transforms/frame.py | 9 +++++++-- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/mio/models/pipeline.py b/mio/models/pipeline.py index b1331260..16ec5033 100644 --- a/mio/models/pipeline.py +++ b/mio/models/pipeline.py @@ -6,7 +6,7 @@ from abc import abstractmethod from datetime import datetime from graphlib import TopologicalSorter -from typing import Any, ClassVar, Generic, Optional, TypedDict, TypeVar, Union, Unpack, final +from typing import Any, ClassVar, Generic, Optional, TypeVar, Union, Unpack, final from pydantic import Field, field_validator, model_validator @@ -14,10 +14,14 @@ from mio.models.models import MiniscopeConfig, PipelineModel if sys.version_info < (3, 11): - from typing_extensions import Self -else: + from typing_extensions import Self, TypedDict +elif sys.version_info < (3, 12): from typing import Self + from typing_extensions import TypedDict +else: + from typing import Self, TypedDict + T = TypeVar("T") """ Input Type typevar diff --git a/mio/sinks/return_.py b/mio/sinks/return_.py index 7c623ea7..c97a603e 100644 --- a/mio/sinks/return_.py +++ b/mio/sinks/return_.py @@ -2,10 +2,16 @@ Special Return sink that pipeline runners use to return values from :meth:`.PipelineRunner.process` """ -from typing import Any, Optional, TypedDict +import sys +from typing import Any, Optional from mio.models.pipeline import Sink, T +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict + class 
ReturnConfig(TypedDict): """ diff --git a/mio/sources/file.py b/mio/sources/file.py index c2636053..21f3801c 100644 --- a/mio/sources/file.py +++ b/mio/sources/file.py @@ -1,9 +1,9 @@ """ File-based data sources """ - +import sys from pathlib import Path -from typing import BinaryIO, ClassVar, Optional, TypedDict +from typing import BinaryIO, ClassVar, Optional import numpy as np from pydantic import Field @@ -12,6 +12,11 @@ from mio.models.pipeline import Source from mio.models.sdcard import SDBufferHeader, SDConfig, SDLayout +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict + class FileSource(Source): """ diff --git a/mio/transforms/frame.py b/mio/transforms/frame.py index ffaec9ba..ba2dfb4e 100644 --- a/mio/transforms/frame.py +++ b/mio/transforms/frame.py @@ -1,8 +1,8 @@ """ Nodes that receive and emit frames """ - -from typing import Optional, TypedDict +import sys +from typing import Optional import numpy as np @@ -10,6 +10,11 @@ from mio.models.data import Frame from mio.models.pipeline import Transform +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict + class MergeBuffersConfig(TypedDict): """Configuration for :class:`.MergeBuffers`""" From 42f8c9701f3e2a770493410b8669b17b60115abc Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 24 Jan 2025 21:55:17 -0800 Subject: [PATCH 7/8] renaming config to metadata and tidying up wirefree --- mio/data/config/wirefree/wirefree.yaml | 3 + mio/devices/device.py | 19 +++-- mio/devices/wirefree.py | 103 ++++++++++++++----------- mio/models/mixins.py | 36 ++++++++- mio/models/sdcard.py | 20 ++--- mio/sources/__init__.py | 4 + mio/sources/file.py | 53 +------------ mio/transforms/frame.py | 1 + notebooks/Wire-Free-DAQ.ipynb | 2 +- tests/data/config/reference_a.yaml | 3 + tests/data/config/reference_b.yaml | 2 + tests/test_devices/test_wirefree.py | 10 +-- tests/test_mixins.py | 22 ++++++ 
tests/test_sdcard.py | 4 +- 14 files changed, 161 insertions(+), 121 deletions(-) create mode 100644 mio/data/config/wirefree/wirefree.yaml create mode 100644 tests/data/config/reference_a.yaml create mode 100644 tests/data/config/reference_b.yaml diff --git a/mio/data/config/wirefree/wirefree.yaml b/mio/data/config/wirefree/wirefree.yaml new file mode 100644 index 00000000..04f904b1 --- /dev/null +++ b/mio/data/config/wirefree/wirefree.yaml @@ -0,0 +1,3 @@ +id: wirefree-default +pipeline: wirefree-pipeline +layout: wirefree-sd-layout \ No newline at end of file diff --git a/mio/devices/device.py b/mio/devices/device.py index b74f9107..8ebd28f4 100644 --- a/mio/devices/device.py +++ b/mio/devices/device.py @@ -5,9 +5,10 @@ import sys from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional from mio.models import MiniscopeConfig, Pipeline, PipelineConfig +from mio.models.mixins import ConfigYAMLMixin if TYPE_CHECKING: from mio.models.pipeline import Sink, Source, Transform @@ -18,13 +19,11 @@ from typing import Self -class DeviceConfig(MiniscopeConfig): +class DeviceConfig(MiniscopeConfig, ConfigYAMLMixin): """ Abstract base class for device configuration """ - id: Union[str, int] - """(Locally) unique identifier for this device""" pipeline: PipelineConfig = PipelineConfig() @@ -36,8 +35,9 @@ class Device: Currently a placeholder to allow room for expansion/renaming in the future """ - pipeline: Optional[Pipeline] = None - # config: Optional[DeviceConfig] = None + config: Optional[DeviceConfig] = None + + _pipeline: Optional[Pipeline] = None @abstractmethod def init(self) -> None: @@ -147,6 +147,13 @@ def sinks(self) -> dict[str, "Sink"]: """Convenience method to access :attr:`.Pipeline.sinks`""" return self.pipeline.sinks + @property + def pipeline(self) -> Pipeline: + """Instantiated Pipeline from pipeline config""" + if self._pipeline is None: + 
self._pipeline = Pipeline.from_config(self.config.pipeline) + return self._pipeline + @classmethod def from_config(cls, config: DeviceConfig) -> Self: """ diff --git a/mio/devices/wirefree.py b/mio/devices/wirefree.py index a33b4126..c31b337a 100644 --- a/mio/devices/wirefree.py +++ b/mio/devices/wirefree.py @@ -16,14 +16,9 @@ from mio.exceptions import EndOfRecordingException, ReadHeaderException from mio.models.data import Frame from mio.models.pipeline import PipelineConfig -from mio.models.sdcard import SDConfig, SDLayout -from mio.types import ConfigSource, Resolution - - -class WireFreeConfig(DeviceConfig): - """Configuration for wire free miniscope""" - - pass +from mio.models.sdcard import SDLayout, SDMetadata +from mio.sources import SDFileSource +from mio.types import Resolution class WireFreePipeline(PipelineConfig): @@ -34,6 +29,13 @@ class WireFreePipeline(PipelineConfig): } +class WireFreeConfig(DeviceConfig): + """Configuration for wire free miniscope""" + + pipeline: WireFreePipeline = "wirefree-pipeline" + layout: SDLayout = "wirefree-sd-layout" + + @dataclass(kw_only=True) class WireFreeMiniscope(Miniscope, RecordingCameraMixin): """ @@ -45,15 +47,15 @@ class WireFreeMiniscope(Miniscope, RecordingCameraMixin): Args: drive (str, :class:`pathlib.Path`): Path to the SD card drive - layout (:class:`.sdcard.SDLayout`): A layout configuration for an SD card + config (:class:`.WireFreeConfig`): Configuration, + including data layout and pipeline configs """ drive: Path """The path to the SD card drive""" - # config: WireFreeConfig - # """Configuration """ - layout: Union[SDLayout, ConfigSource] = "wirefree-sd-layout" + config: WireFreeConfig = field(default_factory=WireFreeConfig) + positions: dict[int, int] = field(default_factory=dict) """ A mapping between frame number and byte position in the video that makes for @@ -66,11 +68,11 @@ class WireFreeMiniscope(Miniscope, RecordingCameraMixin): def __post_init__(self) -> None: """post-init create private 
vars""" - self.layout = SDLayout.from_any(self.layout) + self.layout = SDLayout.from_any(self.config.layout) self.logger = init_logger("WireFreeMiniscope") # Private attributes used when the file reading context is entered - self._config = None # type: Optional[SDConfig] + self._metadata = None # type: Optional[SDMetadata] self._frame = None # type: Optional[int] self._frame_count = None # type: Optional[int] self._array = None # type: Optional[np.ndarray] @@ -83,24 +85,24 @@ def __post_init__(self) -> None: # -------------------------------------------------- @property - def config(self) -> SDConfig: + def metadata(self) -> SDMetadata: """ - Read configuration from SD Card + Read metadata from SD Card """ - if self._config is None: + if self._metadata is None: with open(self.drive, "rb") as sd: sd.seek(self.layout.sectors.config_pos, 0) configSectorData = np.frombuffer(sd.read(self.layout.sectors.size), dtype=np.uint32) - self._config = SDConfig( + self._metadata = SDMetadata( **{ k: configSectorData[v] - for k, v in self.layout.config.model_dump().items() + for k, v in self.layout.metadata.model_dump().items() if v is not None } ) - return self._config + return self._metadata @classmethod def configure(cls, drive: Union[str, Path], config: WireFreeConfig) -> None: @@ -166,6 +168,34 @@ def frame(self, frame: int) -> None: for _ in range(frame - self.frame): self.skip() + @property + def buffers_per_frame(self) -> int: + """ + Number of buffers per frame! 
+ + References: + https://github.com/Aharoni-Lab/Miniscope-v4-Wire-Free/blob/786663781a4bece89c89e00fc3ac9d95912faba4/Miniscope-v4-Wire-Free-MCU-Firmware/Miniscope-v4-Wire-Free/Miniscope-v4-Wire-Free/main.c#L680 + """ + n_pix = self.metadata.width * self.metadata.height + return int(np.ceil((n_pix + self._source.header_size) / (self.metadata.buffer_size))) + + @property + def frame_count(self) -> int: + """ + Total number of frames in the recording + """ + return int( + np.ceil( + (self.metadata.n_buffers_recorded + self.metadata.n_buffers_dropped) + / self.buffers_per_frame + ) + ) + + @property + def _source(self) -> SDFileSource: + """The SDFileSource node in the pipeline""" + return self.pipeline.nodes["sdcard"] + # -------------------------------------------------- # Context Manager methods # -------------------------------------------------- @@ -176,7 +206,7 @@ def __enter__(self) -> "WireFreeMiniscope": # init private attrs # create an empty frame to hold our data! - self._array = np.zeros((self.config.width * self.config.height, 1), dtype=np.uint8) + self._array = np.zeros((self.metadata.width * self.metadata.height, 1), dtype=np.uint8) self._pixel_count = 0 self._last_buffer_n = 0 self._frame = 0 @@ -194,27 +224,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 self._f = None self._frame = 0 - def _trim(self, data: np.ndarray, expected_size: int) -> np.ndarray: - """ - Trim or pad an array to match an expected size - """ - if data.shape[0] != expected_size: - self.logger.warning( - f"Frame: {self._frame}: Expected buffer data length: {expected_size}, " - f"got data with shape {data.shape}. " - "Padding to expected length", - stacklevel=1, - ) - - # trim if too long - if data.shape[0] > expected_size: - data = data[0:expected_size] - # pad if too short - else: - data = np.pad(data, (0, expected_size - data.shape[0])) - - return data - @overload def read(self, return_header: Literal[True] = True) -> Frame: ... 
@@ -279,7 +288,7 @@ def read(self, return_header: bool = False) -> Union[np.ndarray, Frame]: self._f.seek(last_position, 0) self._frame += 1 self.positions[self._frame] = last_position - frame = np.reshape(self._array, (self.config.width, self.config.height)) + frame = np.reshape(self._array, (self.metadata.width, self.metadata.height)) if return_header: return Frame.model_construct(frame=frame, headers=headers) else: @@ -344,8 +353,8 @@ def to_video( writer = cv2.VideoWriter( str(path), cv2.VideoWriter_fourcc(*fourcc), - self.config.fs, - (self.config.width, self.config.height), + self.metadata.fs, + (self.metadata.width, self.metadata.height), isColor=isColor, ) @@ -542,12 +551,12 @@ def excitation(self) -> float: @property def fps(self) -> int: """FPS""" - return self.config.fs + return self.metadata.fs @property def resolution(self) -> Resolution: """Resolution of recorded video""" - return Resolution(self.config.width, self.config.height) + return Resolution(self.metadata.width, self.metadata.height) def get(self, key: str) -> Any: """get a configuration value by its name""" diff --git a/mio/models/mixins.py b/mio/models/mixins.py index fd09a240..0cd27133 100644 --- a/mio/models/mixins.py +++ b/mio/models/mixins.py @@ -11,11 +11,20 @@ from typing import Any, ClassVar, List, Literal, Optional, Type, TypeVar, Union, overload import yaml -from pydantic import BaseModel, Field, ValidationError, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + GetCoreSchemaHandler, + ValidationError, + field_validator, +) +from pydantic_core import core_schema from mio.types import ConfigID, ConfigSource, PythonIdentifier, valid_config_id T = TypeVar("T") +"""Generic type of ConfigYamlMixin subclass""" class YamlDumper(yaml.SafeDumper): @@ -80,6 +89,8 @@ class ConfigYAMLMixin(BaseModel, YAMLMixin): at the top of the file. 
""" + model_config = ConfigDict(validate_default=True) + id: ConfigID mio_model: PythonIdentifier = Field(None, validate_default=True) mio_version: str = version("mio") @@ -258,6 +269,29 @@ def _complete_header(cls: Type[T], data: dict, file_path: Union[str, Path]) -> d return data + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """ + Add before_validator to allow instantiation from id + """ + + def _from_id(value: Any) -> cls: + if isinstance(value, str): + return cls.from_id(value) + else: + return value + + return core_schema.no_info_before_validator_function( + _from_id, + handler(source_type), + # TODO: add this when updating pydantic floor to 2.10 + # json_schema_input_schema=core_schema.union_schema( + # [handler(source_type), handler(ConfigID)] + # ), + ) + @overload def yaml_peek( diff --git a/mio/models/sdcard.py b/mio/models/sdcard.py index d308163b..2f40370e 100644 --- a/mio/models/sdcard.py +++ b/mio/models/sdcard.py @@ -19,7 +19,7 @@ class SectorConfig(MiniscopeConfig): Examples: - >>> sectors = SectorConfig(header=1023, config=1024, data=1025, size=512) + >>> sectors = SectorConfig(header=1023, metadata=1024, data=1025, size=512) >>> sectors.header 1023 >>> # should be 1023 * 512 @@ -33,7 +33,7 @@ class SectorConfig(MiniscopeConfig): """ Holds user settings to configure Miniscope and recording """ - config: int = 1024 + metadata: int = 1024 """ Holds final settings of the actual recording """ @@ -53,8 +53,8 @@ def header_pos(self) -> int: @property def config_pos(self) -> int: - """config * sector size""" - return self.config * self.size + """metadata * sector size""" + return self.metadata * self.size @property def data_pos(self) -> int: @@ -62,9 +62,9 @@ def data_pos(self) -> int: return self.data * self.size -class ConfigPositions(MiniscopeConfig): +class MetadataPositions(MiniscopeConfig): """ - Image acquisition configuration positions + Image acquisition 
metadata positions """ width: int = 0 @@ -124,7 +124,7 @@ class SDLayout(MiniscopeConfig, ConfigYAMLMixin): sectors: SectorConfig header: SDHeaderPositions = SDHeaderPositions() - config: ConfigPositions = ConfigPositions() + metadata: MetadataPositions = MetadataPositions() buffer: SDBufferHeaderFormat = SDBufferHeaderFormat() header_dtype: str = "uint32" """ @@ -136,11 +136,11 @@ class SDLayout(MiniscopeConfig, ConfigYAMLMixin): """ -class SDConfig(MiniscopeConfig): +class SDMetadata(MiniscopeConfig): """ - The configuration of a recording taken on this SD card. + The metadata of a recording taken on this SD card. - Read from the locations given in :class:`.ConfigPositions` + Read from the locations given in :class:`.MetadataPositions` for an SD card with a given :class:`.SDLayout` """ diff --git a/mio/sources/__init__.py b/mio/sources/__init__.py index a8191faa..61edb55f 100644 --- a/mio/sources/__init__.py +++ b/mio/sources/__init__.py @@ -1,3 +1,7 @@ """ Source pipeline nodes that emit but do not receive events """ + +from mio.sources.file import SDFileSource + +__all__ = ["SDFileSource"] diff --git a/mio/sources/file.py b/mio/sources/file.py index 21f3801c..772fb077 100644 --- a/mio/sources/file.py +++ b/mio/sources/file.py @@ -1,6 +1,7 @@ """ File-based data sources """ + import sys from pathlib import Path from typing import BinaryIO, ClassVar, Optional @@ -10,7 +11,7 @@ from mio.exceptions import EndOfRecordingException, ReadHeaderException from mio.models.pipeline import Source -from mio.models.sdcard import SDBufferHeader, SDConfig, SDLayout +from mio.models.sdcard import SDBufferHeader, SDLayout if sys.version_info < (3, 12): from typing_extensions import TypedDict @@ -78,7 +79,7 @@ class SDFileSource(FileSource): """ Structured binary file that has - * a global header with config values + * a global header with metadata values * a series of buffers, each containing a * buffer header - with metadata for that buffer and @@ -86,7 +87,7 @@ class 
SDFileSource(FileSource): The source thus has two configurations - * the ``config`` - getter and setter for the actual configuration values of the source + * the ``metadata`` - getter and setter for the actual configuration values of the source * the ``layout`` - how the configuration and data are laid out within the file. """ @@ -98,7 +99,6 @@ class SDFileSource(FileSource): layout: SDLayout _f: Optional[BinaryIO] = None - _config: SDConfig = None _positions: dict[int, int] = Field(default_factory=dict) """ A mapping between frame number and byte position in the video that makes for @@ -111,28 +111,6 @@ class SDFileSource(FileSource): _last_buffer: int = None _frame: int = 0 - @property - def config(self) -> SDConfig: - """ - Global configuration of the whole SD card - """ - if self._config is None: - with open(self.path, "rb") as sd: - sd.seek(self.layout.sectors.config_pos, 0) - configSectorData = np.frombuffer( - sd.read(self.layout.sectors.size), dtype=np.dtype(self.layout.header_dtype) - ) - - self._config = SDConfig( - **{ - k: configSectorData[v] - for k, v in self.layout.config.model_dump().items() - if v is not None - } - ) - - return self._config - @property def width(self) -> int: """width of the captured video in pixels""" @@ -162,29 +140,6 @@ def header_size(self) -> int: max([v for v in self.layout.buffer.model_dump().values() if v is not None]) + 1 ) * self.layout.word_size - @property - def buffers_per_frame(self) -> int: - """ - Number of buffers per frame! 
- - References: - https://github.com/Aharoni-Lab/Miniscope-v4-Wire-Free/blob/786663781a4bece89c89e00fc3ac9d95912faba4/Miniscope-v4-Wire-Free-MCU-Firmware/Miniscope-v4-Wire-Free/Miniscope-v4-Wire-Free/main.c#L680 - """ - n_pix = self.config.width * self.config.height - return int(np.ceil((n_pix + self.header_size) / (self.config.buffer_size))) - - @property - def frame_count(self) -> int: - """ - Total number of frames in the recording - """ - return int( - np.ceil( - (self.config.n_buffers_recorded + self.config.n_buffers_dropped) - / self.buffers_per_frame - ) - ) - def start(self) -> None: """Open the file, seek to the offset""" self._last_buffer = 0 diff --git a/mio/transforms/frame.py b/mio/transforms/frame.py index ba2dfb4e..c72cc8aa 100644 --- a/mio/transforms/frame.py +++ b/mio/transforms/frame.py @@ -1,6 +1,7 @@ """ Nodes that receive and emit frames """ + import sys from typing import Optional diff --git a/notebooks/Wire-Free-DAQ.ipynb b/notebooks/Wire-Free-DAQ.ipynb index 11c06957..04211735 100644 --- a/notebooks/Wire-Free-DAQ.ipynb +++ b/notebooks/Wire-Free-DAQ.ipynb @@ -97,7 +97,7 @@ "f = open(driveName, \"rb\") # Open drive\n", "\n", "# Make sure this is the correct drive\n", - "# Read SD Card header and config sectors\n", + "# Read SD Card header and metadata sectors\n", "f.seek(headerSector * sectorSize, 0) # Move to correct sector\n", "headerSectorData = np.fromstring(f.read(sectorSize), dtype=np.uint32)\n", "if ((WRITE_KEY0 == headerSectorData[0]) and (WRITE_KEY1 == headerSectorData[1]) and (WRITE_KEY2 == headerSectorData[2]) and (WRITE_KEY3 == headerSectorData[3])):\n", diff --git a/tests/data/config/reference_a.yaml b/tests/data/config/reference_a.yaml new file mode 100644 index 00000000..0828ee26 --- /dev/null +++ b/tests/data/config/reference_a.yaml @@ -0,0 +1,3 @@ +id: reference-a +value: ["hey", "sup"] +config: reference-b \ No newline at end of file diff --git a/tests/data/config/reference_b.yaml b/tests/data/config/reference_b.yaml new file 
mode 100644 index 00000000..34d2cf46 --- /dev/null +++ b/tests/data/config/reference_b.yaml @@ -0,0 +1,2 @@ +id: reference-b +value: [1,2,3] \ No newline at end of file diff --git a/tests/test_devices/test_wirefree.py b/tests/test_devices/test_wirefree.py index 73074cd8..0ce3adeb 100644 --- a/tests/test_devices/test_wirefree.py +++ b/tests/test_devices/test_wirefree.py @@ -39,8 +39,8 @@ def test_read(wirefree): # the frame is the right shape assert len(frame.shape) == 2 - assert frame.shape[0] == wirefree.config.height - assert frame.shape[1] == wirefree.config.width + assert frame.shape[0] == wirefree.metadata.height + assert frame.shape[1] == wirefree.metadata.width # assert they're not all zeros - ie. we read some data assert frame.any() @@ -99,8 +99,8 @@ def test_relative_path(): assert not rel_path.is_absolute() sdcard = WireFreeMiniscope(drive=rel_path, layout="wirefree-sd-layout") - # check we can do something basic like read config - assert sdcard.config is not None + # check we can do something basic like read metadata + assert sdcard.metadata is not None # check it remains relative after init assert not sdcard.drive.is_absolute() @@ -109,7 +109,7 @@ def test_relative_path(): abs_path = rel_path.resolve() assert abs_path.is_absolute() sdcard_abs = WireFreeMiniscope(drive=abs_path, layout="wirefree-sd-layout") - assert sdcard_abs.config is not None + assert sdcard_abs.metadata is not None assert sdcard_abs.drive.is_absolute() diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 7976ecbe..1260a7d4 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -155,3 +155,25 @@ def test_peek_yaml(key, expected, root, first, yaml_config): _ = yaml_peek(key, yaml_file, root=root, first=first) else: assert yaml_peek(key, yaml_file, root=root, first=first) == expected + + +def test_yamlmixin_core_schema(): + """ + The __get_pydantic_core_schema__ method in the ConfigYamlMixin + lets us use ids for keys everywhere + """ + + class B(ConfigYAMLMixin): + 
value: list[int] + + class A(ConfigYAMLMixin): + value: list[str] + config: B + + class Container(BaseModel): + model: A + + instance = Container(model="reference-a") + assert isinstance(instance.model.config, B) + assert instance.model.value == ["hey", "sup"] + assert instance.model.config.value == [1, 2, 3] diff --git a/tests/test_sdcard.py b/tests/test_sdcard.py index 7a3a8338..2854d3dc 100644 --- a/tests/test_sdcard.py +++ b/tests/test_sdcard.py @@ -8,7 +8,7 @@ def random_sectorconfig(): return SectorConfig( header=np.random.randint(0, 2048), - config=np.random.randint(0, 2048), + metadata=np.random.randint(0, 2048), data=np.random.randint(0, 2048), size=np.random.randint(0, 2048), ) @@ -20,7 +20,7 @@ def test_get_sector_position(random_sectorconfig): """ sectors = random_sectorconfig assert sectors.header_pos == sectors.header * sectors.size - assert sectors.config_pos == sectors.config * sectors.size + assert sectors.config_pos == sectors.metadata * sectors.size assert sectors.data_pos == sectors.data * sectors.size # We should raise an attribute error if we try and get a nonexistent one From a8ec180bb138dcfb765f19bda0ab143607d0a358 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 30 Jan 2025 02:28:32 -0800 Subject: [PATCH 8/8] separate pipeline mixin from device class, jittering towards runnable wirefree class, cleaning up as i go. 
--- .../config/wirefree/wirefree-pipeline.yaml | 2 +- mio/data/config/wirefree/wirefree.yaml | 4 +- mio/devices/device.py | 28 +--- mio/devices/wirefree.py | 133 ++++-------------- mio/exceptions.py | 10 ++ mio/models/__init__.py | 2 + mio/models/mixins.py | 23 +-- mio/models/pipeline.py | 125 +++++++++++++--- mio/pipeline/runner.py | 56 +++++--- mio/sources/__init__.py | 5 +- mio/sources/file.py | 108 ++++++++------ mio/transforms/__init__.py | 4 + tests/fixtures.py | 6 +- tests/test_devices/test_wirefree.py | 4 +- 14 files changed, 290 insertions(+), 220 deletions(-) diff --git a/mio/data/config/wirefree/wirefree-pipeline.yaml b/mio/data/config/wirefree/wirefree-pipeline.yaml index 9cf5a335..f5020709 100644 --- a/mio/data/config/wirefree/wirefree-pipeline.yaml +++ b/mio/data/config/wirefree/wirefree-pipeline.yaml @@ -3,7 +3,7 @@ mio_model: mio.devices.wirefree.WireFreePipeline mio_version: v0.6.0 nodes: - file: + sdcard: type: "sd-file-source" config: layout: "wirefree-sd-layout" diff --git a/mio/data/config/wirefree/wirefree.yaml b/mio/data/config/wirefree/wirefree.yaml index 04f904b1..df468490 100644 --- a/mio/data/config/wirefree/wirefree.yaml +++ b/mio/data/config/wirefree/wirefree.yaml @@ -1,3 +1,5 @@ id: wirefree-default +mio_model: mio.devices.wirefree.WireFreeConfig +mio_version: 0.6.1 pipeline: wirefree-pipeline -layout: wirefree-sd-layout \ No newline at end of file +layout: wirefree-sd-layout diff --git a/mio/devices/device.py b/mio/devices/device.py index 8ebd28f4..d6ceb381 100644 --- a/mio/devices/device.py +++ b/mio/devices/device.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional -from mio.models import MiniscopeConfig, Pipeline, PipelineConfig +from mio.models import MiniscopeConfig, Pipeline, PipelineConfig, PipelineMixin from mio.models.mixins import ConfigYAMLMixin if TYPE_CHECKING: @@ -28,7 +28,7 @@ class DeviceConfig(MiniscopeConfig, ConfigYAMLMixin): @dataclass(kw_only=True) -class Device: 
+class Device(PipelineMixin): """ Abstract base class for devices. @@ -37,8 +37,6 @@ class Device: config: Optional[DeviceConfig] = None - _pipeline: Optional[Pipeline] = None - @abstractmethod def init(self) -> None: """ @@ -132,28 +130,6 @@ def set(self, key: str, value: Any) -> None: key (str): The name of the value to get """ - @property - def sources(self) -> dict[str, "Source"]: - """Convenience method to access :attr:`.Pipeline.sources`""" - return self.pipeline.sources - - @property - def transforms(self) -> dict[str, "Transform"]: - """Convenience method to access :attr:`.Pipeline.transforms`""" - return self.pipeline.transforms - - @property - def sinks(self) -> dict[str, "Sink"]: - """Convenience method to access :attr:`.Pipeline.sinks`""" - return self.pipeline.sinks - - @property - def pipeline(self) -> Pipeline: - """Instantiated Pipeline from pipeline config""" - if self._pipeline is None: - self._pipeline = Pipeline.from_config(self.config.pipeline) - return self._pipeline - @classmethod def from_config(cls, config: DeviceConfig) -> Self: """ diff --git a/mio/devices/wirefree.py b/mio/devices/wirefree.py index c31b337a..6833a14d 100644 --- a/mio/devices/wirefree.py +++ b/mio/devices/wirefree.py @@ -14,6 +14,7 @@ from mio import init_logger from mio.devices import DeviceConfig, Miniscope, RecordingCameraMixin from mio.exceptions import EndOfRecordingException, ReadHeaderException +from mio.models import Pipeline from mio.models.data import Frame from mio.models.pipeline import PipelineConfig from mio.models.sdcard import SDLayout, SDMetadata @@ -25,7 +26,8 @@ class WireFreePipeline(PipelineConfig): """Base skeleton pipeline for the wirefree miniscope""" required_nodes = { - "sdcard": "sd-file-source", + "sdcard": {"type": "sd-file-source"}, + "return": {"type": "return", "config": {"key": "frame"}}, } @@ -68,6 +70,9 @@ class WireFreeMiniscope(Miniscope, RecordingCameraMixin): def __post_init__(self) -> None: """post-init create private vars""" + if 
isinstance(self.config, str): + self.config = WireFreeConfig.from_any(self.config) + self.layout = SDLayout.from_any(self.config.layout) self.logger = init_logger("WireFreeMiniscope") @@ -117,10 +122,7 @@ def position(self) -> Optional[int]: When entered as context manager, the current position of the internal file descriptor """ - if self._f is None: - return None - - return self._f.tell() + return self._source.position @property def frame(self) -> Optional[int]: @@ -201,106 +203,24 @@ def _source(self) -> SDFileSource: # -------------------------------------------------- def __enter__(self) -> "WireFreeMiniscope": - if self._f is not None: - raise RuntimeError("Cant enter context, and open the file twice!") - - # init private attrs - # create an empty frame to hold our data! - self._array = np.zeros((self.metadata.width * self.metadata.height, 1), dtype=np.uint8) - self._pixel_count = 0 - self._last_buffer_n = 0 - self._frame = 0 - - self._f = open(self.drive, "rb") # noqa: SIM115 - this is a context handler - # seek to the start of the data - self._f.seek(self.layout.sectors.data_pos, 0) - # store the 0th frame position - self.positions[0] = self.layout.sectors.data_pos - + self.runner.start() return self def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 - self._f.close() - self._f = None - self._frame = 0 - - @overload - def read(self, return_header: Literal[True] = True) -> Frame: ... + self.runner.stop() - @overload - def read(self, return_header: Literal[False] = False) -> np.ndarray: ... - - def read(self, return_header: bool = False) -> Union[np.ndarray, Frame]: + def read(self) -> Frame: """ - Read a single frame - - Arguments: - return_header (bool): If `True`, return headers from individual buffers - (default `False`) - - Return: - :class:`numpy.ndarray` , - or a tuple(ndarray, List[:class:`~.SDBufferHeader`]) if `return_header` - is `True` + Read a single :class:`.Frame` """ - if self._f is None: - raise RuntimeError( - "File is not open! 
Try entering the reader context by using it like " - "`with sdcard:`" - ) + if not self.runner.running: + self.logger.debug("Starting runner implicitly by calling read") + self.runner.start() - self._array[:] = 0 - pixel_count = 0 - last_buffer_n = 0 - headers = [] - while True: - # stash position before reading header - last_position = self._f.tell() - try: - header = self._read_data_header(self._f) - except ValueError as e: - if "read length must be non-negative" in str(e): - # end of file! Value error thrown because the dataHeader will be - # blank, and thus have a value of 0 for the header size, and we - # can't read 0 from the card. - self._f.seek(last_position, 0) - raise EndOfRecordingException("Reached the end of the video!") from None - else: - raise e - except IndexError as e: - if "index 0 is out of bounds for axis 0 with size 0" in str(e): - # end of file if we are reading from a disk image without any - # additional space on disk - raise EndOfRecordingException("Reached the end of the video!") from None - else: - raise e - except ReadHeaderException as e: - # if we are on the last frame, normal! signal end of iteration - if self._frame == self.frame_count - 1: - raise EndOfRecordingException("Reached the end of the video!") from None - else: - raise e - - if header.frame_buffer_count == 0 and last_buffer_n > 0: - # we are in the next frame! 
- # rewind to the beginning of the header, and return - # the last_position is the start of the header for this frame - self._f.seek(last_position, 0) - self._frame += 1 - self.positions[self._frame] = last_position - frame = np.reshape(self._array, (self.metadata.width, self.metadata.height)) - if return_header: - return Frame.model_construct(frame=frame, headers=headers) - else: - return frame - - # grab buffer data and stash - headers.append(header) - data = self._read_buffer(self._f, header) - data = self._trim(data, header.data_length) - self._array[pixel_count : pixel_count + header.data_length, 0] = data - pixel_count += header.data_length - last_buffer_n = header.frame_buffer_count + res = None + while res is None: + res: Optional[dict[Literal["frame"], Frame]] = self.runner.process() + return res["frame"] # -------------------------------------------------- # Write methods @@ -367,8 +287,8 @@ def to_video( while True: # this is sort of an awkward stack, should probably make a # generator version of `read` - frame = self.read(return_header=False) - writer.write(frame) + frame = self.read() + writer.write(frame.frame) if progress: pbar.update() @@ -533,16 +453,25 @@ def deinit(self) -> None: def start(self) -> None: """start pipeline""" - raise NotImplementedError() + self.runner.start() def stop(self) -> None: """stop pipeline""" - raise NotImplementedError() + self.runner.stop() def join(self) -> None: """join pipeline""" raise NotImplementedError() + @property + def pipeline(self) -> Pipeline: + """Create pipeline, passing needed values""" + if self._pipeline is None: + self._pipeline = Pipeline.from_config( + self.config.pipeline, passed={"sd_path": self.drive} + ) + return self._pipeline + @property def excitation(self) -> float: """LED Excitation""" diff --git a/mio/exceptions.py b/mio/exceptions.py index 1c7343a9..767a8392 100644 --- a/mio/exceptions.py +++ b/mio/exceptions.py @@ -66,3 +66,13 @@ class ConfigurationMismatchError(ConfigurationError): 
""" Mismatch between the fields in some config model and the fields in the model it is configuring """ + +class PipelineError(Exception): + """ + Base exception type for pipeline errors + """ + +class PipelineRunningError(PipelineError, RuntimeError): + """ + A pipeline has been requested to start, but it is already started! + """ \ No newline at end of file diff --git a/mio/models/__init__.py b/mio/models/__init__.py index 815b6156..2d57b5b5 100644 --- a/mio/models/__init__.py +++ b/mio/models/__init__.py @@ -12,6 +12,7 @@ Node, Pipeline, PipelineConfig, + PipelineMixin, Sink, Source, Transform, @@ -24,6 +25,7 @@ "Node", "Pipeline", "PipelineConfig", + "PipelineMixin", "PipelineModel", "Transform", "Sink", diff --git a/mio/models/mixins.py b/mio/models/mixins.py index 0cd27133..3031eb4b 100644 --- a/mio/models/mixins.py +++ b/mio/models/mixins.py @@ -3,6 +3,7 @@ to use composition for functionality and inheritance for semantics. """ +import pdb import re import shutil from importlib.metadata import version @@ -91,7 +92,7 @@ class ConfigYAMLMixin(BaseModel, YAMLMixin): model_config = ConfigDict(validate_default=True) - id: ConfigID + id: Optional[ConfigID] = None mio_model: PythonIdentifier = Field(None, validate_default=True) mio_version: str = version("mio") @@ -131,22 +132,26 @@ def from_id(cls: Type[T], id: ConfigID) -> T: """ globs = [src.rglob("*.y*ml") for src in cls.config_sources] + for config_file in chain(*globs): + # if id == "wirefree-pipeline": + # pdb.set_trace() try: file_id = yaml_peek("id", config_file) - if file_id == id: - from mio.logging import init_logger - - init_logger("config").debug( - "Model for %s found at %s", cls._model_name(), config_file - ) - return cls.from_yaml(config_file) except KeyError: continue + if file_id == id: + from mio.logging import init_logger + + init_logger("config").debug( + "Model for %s found at %s", cls._model_name(), config_file + ) + return cls.from_yaml(config_file) + from mio import Config - raise 
KeyError(f"No config with id {id} found in {Config().config_dir}") + raise KeyError(f"No config with id {id} found in {cls.config_sources}") @classmethod def from_any(cls: Type[T], source: Union[ConfigSource, T]) -> T: diff --git a/mio/models/pipeline.py b/mio/models/pipeline.py index 16ec5033..21b1ac08 100644 --- a/mio/models/pipeline.py +++ b/mio/models/pipeline.py @@ -4,23 +4,39 @@ import sys from abc import abstractmethod +from dataclasses import dataclass from datetime import datetime from graphlib import TopologicalSorter -from typing import Any, ClassVar, Generic, Optional, TypeVar, Union, Unpack, final - -from pydantic import Field, field_validator, model_validator +from typing import ( + Any, + ClassVar, + Generic, + Optional, + TypeVar, + Union, + Unpack, + final, + Protocol, + TYPE_CHECKING, +) + +from pydantic import BaseModel, Field, field_validator, model_validator from mio.exceptions import ConfigurationMismatchError from mio.models.models import MiniscopeConfig, PipelineModel +from mio.models.mixins import ConfigYAMLMixin if sys.version_info < (3, 11): - from typing_extensions import Self, TypedDict + from typing_extensions import Self, TypedDict, NotRequired elif sys.version_info < (3, 12): from typing import Self - from typing_extensions import TypedDict + from typing_extensions import TypedDict, NotRequired else: - from typing import Self, TypedDict + from typing import Self, TypedDict, NotRequired + +if TYPE_CHECKING: + from mio.pipeline.runner import PipelineRunner T = TypeVar("T") """ @@ -54,7 +70,7 @@ class _NodeMap(TypedDict): target: str -class NodeConfig(TypedDict): +class NodeConfig(TypedDict, total=False): """ Abstract parent TypedDict that each node inherits from to define what fields it needs to be configured. 
@@ -78,7 +94,7 @@ class NodeSpecification(MiniscopeConfig): """The unique identifier of the node""" outputs: Optional[list[_NodeMap]] = None """List of Node IDs to be used as output""" - config: Optional[NodeConfig] = None + config: Optional[dict] = None """Additional configuration for this node, parameterized by a TypedDict for the class""" passed: Optional[dict[str, str]] = None """ @@ -152,12 +168,19 @@ class NodeSpecification(MiniscopeConfig): """ -class PipelineConfig(MiniscopeConfig): +class _RequiredNode(TypedDict): + type: str + """Node type (as determined by its :attr:`.Node.name` attr""" + config: NotRequired[dict[str, Any]] + """Require that config values must be set to these values""" + + +class PipelineConfig(MiniscopeConfig, ConfigYAMLMixin): """ Configuration for the nodes within a pipeline """ - required_nodes: ClassVar[Optional[dict[str, str]]] = None + required_nodes: ClassVar[Optional[dict[str, _RequiredNode]]] = None """ id: type mapping that a subclass can use to require a set of node types with specific IDs be present @@ -170,9 +193,17 @@ class PipelineConfig(MiniscopeConfig): def validate_required_nodes(self) -> Self: """Ensure required nodes are present, if any""" if self.required_nodes is not None: - for id_, type_ in self.required_nodes.items(): - assert id_ in self.nodes, f"Node ID {id_} not in {self.nodes.keys()}" - assert self.nodes[id_].type_ == type_, f"Node ID {id_} is not of type {type_}" + for id_, required in self.required_nodes.items(): + assert id_ in self.nodes, f"Required node id {id_} not in {self.nodes.keys()}" + assert ( + self.nodes[id_].type_ == required["type"] + ), f"Node ID {id_} is not of type {required['type']}" + if "config" in required: + for key, val in required["config"].items(): + assert ( + self.nodes[id_].config[key] == val + ), f"Required node {id_} must have config value {key} set to {val}, " + f"got {self.nodes[id_].config[key]} instead." 
return self @field_validator("nodes", mode="before") @@ -291,17 +322,21 @@ def node_types(cls) -> dict[str, type["Node"]]: """ Map of all imported :attr:`.Node.name` names to node classes """ + from mio import sinks, sources, transforms + node_types = {} to_check = cls.__subclasses__() while to_check: node = to_check.pop() - if node.name in node_types: + if node not in (Sink, Source, Transform) and node.name in node_types: raise ValueError( f"Repeated node name identifier: {node.name}, found in:\n" f"- {node_types[node.name]}\n- {node}" ) - node_types[node.name] = node + to_check.extend(node.__subclasses__()) + if node not in (Sink, Source, Transform): + node_types[node.name] = node return node_types @@ -569,11 +604,69 @@ def _validate_passed(cls, config: PipelineConfig, passed: dict[str, Any]) -> Non Raise ConfigurationMismatchError if missing keys, otherwise do nothing """ required = cls.passed_values(config) + for key in required: - if key not in passed: + if passed is None or key not in passed: raise ConfigurationMismatchError( f"Pipeline config requires these values to be passed:\n" f"{required}\n" f"But received passed values:\n" f"{passed}" ) + + +class _ConfigProtocol(Protocol): + """ + Abstract protocol type to specify that classes consuming the PipelineMixin + must have some config attribute that specifies a pipeline + (without prescribing what that config object must be) + """ + + pipeline: Optional[PipelineConfig] = None + + +@dataclass(kw_only=True) +class PipelineMixin: + """Mixin for use with models that have pipelines!""" + + config: _ConfigProtocol + + _pipeline: Optional[Pipeline] = None + _runner: Optional["PipelineRunner"] = None + + @property + def sources(self) -> dict[str, "Source"]: + """Convenience method to access :attr:`.Pipeline.sources`""" + return self.pipeline.sources + + @property + def transforms(self) -> dict[str, "Transform"]: + """Convenience method to access :attr:`.Pipeline.transforms`""" + return self.pipeline.transforms + + 
@property + def sinks(self) -> dict[str, "Sink"]: + """Convenience method to access :attr:`.Pipeline.sinks`""" + return self.pipeline.sinks + + @property + def pipeline(self) -> Pipeline: + """Instantiated Pipeline from pipeline config""" + if self._pipeline is None: + self._pipeline = Pipeline.from_config(self.config.pipeline) + return self._pipeline + + @property + def runner(self) -> "PipelineRunner": + """ + A :class:`.PipelineRunner` that ... runs the :attr:`~.PipelineMixin.pipeline` ! + + By default, creates a :class:`.SynchronousRunner`, + and this property should be overridden by subclasses that want to specialize + runner instantiation. + """ + from mio.pipeline.runner import SynchronousRunner + + if self._runner is None: + self._runner = SynchronousRunner(self.pipeline) + return self._runner diff --git a/mio/pipeline/runner.py b/mio/pipeline/runner.py index 365e07a1..e66f216c 100644 --- a/mio/pipeline/runner.py +++ b/mio/pipeline/runner.py @@ -10,8 +10,10 @@ from itertools import count from logging import Logger from typing import TYPE_CHECKING, Any, Optional, Self +from threading import Event from mio import init_logger +from mio.exceptions import PipelineRunningError from mio.models import Pipeline from mio.models.pipeline import Edge, Event, Node, Source @@ -133,7 +135,12 @@ def process(self) -> Optional[dict[str, Any]]: @abstractmethod def start(self) -> None: """ - Start processing data with the pipeline graph + Start processing data with the pipeline graph. + + Implementations of this method must raise a :class:`.PipelineRunningError` + if the pipeline has already been started and is running, + (i.e. 
:meth:`.stop` has not been called, + or the pipeline has not exhausted itself) """ @abstractmethod @@ -142,6 +149,14 @@ def stop(self) -> None: Stop processing data with the pipeline graph """ + @property + @abstractmethod + def running(self) -> bool: + """ + Whether the pipeline is currently running + """ + pass + def gather_input(self, node: Node) -> Optional[dict[str, Any]]: """ Gather input to give to the passed Node from the :attr:`.PipelineRunner.store` @@ -187,33 +202,40 @@ class SynchronousRunner(PipelineRunner): Just run the nodes in topological order and return from return nodes. """ - @contextmanager - def start(self) -> Generator[Self, None, None]: - """ - Start processing data with the pipeline graph. - - Returns a contextmanager that should be used like this: + def __init__(self): + self._running = Event() - .. code-block:: python + def __enter__(self) -> Self: + self.start() + return self - with sync_runner.start() as runner: - output = runner.process() - # do something... + def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 + self.stop() + def start(self) -> Self: + """ + Start processing data with the pipeline graph. 
""" # TODO: lock for re-entry - try: - for node in self.pipeline.nodes.values(): - node.start() - yield self - finally: - self.stop() + if self._running.is_set(): + raise PipelineRunningError("Pipeline is already running!") + + self._running.set() + for node in self.pipeline.nodes.values(): + node.start() + return self def stop(self) -> None: """Stop all nodes processing""" # TODO: lock to ensure we've been started for node in self.pipeline.nodes.values(): node.stop() + self._running.clear() + + @property + def running(self) -> bool: + """Whether the pipeline is currently running""" + return self._running.is_set() def process(self) -> Optional[dict[str, Any]]: """ diff --git a/mio/sources/__init__.py b/mio/sources/__init__.py index 61edb55f..0e84b150 100644 --- a/mio/sources/__init__.py +++ b/mio/sources/__init__.py @@ -2,6 +2,7 @@ Source pipeline nodes that emit but do not receive events """ -from mio.sources.file import SDFileSource +from mio.sources.file import FileSource, BinaryFileSource, SDFileSource +from mio.sources.opalkelly import okDev -__all__ = ["SDFileSource"] +__all__ = ["BinaryFileSource", "FileSource", "SDFileSource", "okDev"] diff --git a/mio/sources/file.py b/mio/sources/file.py index 772fb077..854089e2 100644 --- a/mio/sources/file.py +++ b/mio/sources/file.py @@ -3,11 +3,13 @@ """ import sys +import os from pathlib import Path from typing import BinaryIO, ClassVar, Optional +from threading import Lock import numpy as np -from pydantic import Field +from pydantic import Field, PrivateAttr from mio.exceptions import EndOfRecordingException, ReadHeaderException from mio.models.pipeline import Source @@ -63,6 +65,11 @@ def tell(self) -> int: raise RuntimeError("File has not yet been opened with start") return self._f.tell() + @property + def position(self) -> int: + """Property alias for :meth:`.tell` - the current position in the file""" + return self.tell() + def process(self) -> bytes: """Return a block of data""" return 
self._f.read(self.block_size) @@ -75,6 +82,13 @@ class SDFileSourceOutput(TypedDict): buffer: np.ndarray +class SDFileSourceConfig(TypedDict): + """Config for :class:`.SDFileSource`""" + + path: Path + layout: SDLayout + + class SDFileSource(FileSource): """ Structured binary file that has @@ -95,11 +109,10 @@ class SDFileSource(FileSource): name = "sd-file-source" output_type = SDFileSourceOutput - path: Path - layout: SDLayout + config: SDFileSourceConfig _f: Optional[BinaryIO] = None - _positions: dict[int, int] = Field(default_factory=dict) + _positions: dict[int, int] = PrivateAttr(default_factory=dict) """ A mapping between frame number and byte position in the video that makes for faster seeking :) @@ -110,21 +123,12 @@ class SDFileSource(FileSource): """ _last_buffer: int = None _frame: int = 0 - - @property - def width(self) -> int: - """width of the captured video in pixels""" - return self.config.width - - @property - def height(self) -> int: - """height of the captured video in pixels""" - return self.config.height + _lock: Lock = PrivateAttr(default_factory=Lock) @property def offset(self) -> int: """Start point of the data sector""" - return self.layout.sectors.data_pos + return self.config["layout"].sectors.data_pos @property def header_size(self) -> int: @@ -137,15 +141,16 @@ def header_size(self) -> int: https://github.com/Aharoni-Lab/Miniscope-v4-Wire-Free/issues/64 """ return ( - max([v for v in self.layout.buffer.model_dump().values() if v is not None]) + 1 - ) * self.layout.word_size + max([v for v in self.config["layout"].buffer.model_dump().values() if v is not None]) + + 1 + ) * self.config["layout"].word_size def start(self) -> None: """Open the file, seek to the offset""" self._last_buffer = 0 - self._f = open(self.path, "rb") # noqa: SIM115 - self._f.seek(self.offset, 0) + self._f = open(self.config["path"], "rb") # noqa: SIM115 + self._f.seek(self.offset, os.SEEK_SET) def stop(self) -> None: """Close the file, remove the reference""" @@ 
-156,7 +161,25 @@ def tell(self) -> int: """Return the current position in the file""" if self._f is None: raise RuntimeError("File has not yet been opened with start") - return self._f.tell() + with self._lock: + return self._f.tell() + + @property + def size(self) -> int: + """ + Size of the file (in bytes) + """ + current_pos = self.tell() + with self._lock: + self._f.seek(self.offset, os.SEEK_END) + end = self._f.tell() + self._f.seek(current_pos, os.SEEK_SET) + return end + + @property + def position(self) -> int: + """Property alias for :meth:`.tell` - the current position in the file""" + return self.tell() def process(self) -> SDFileSourceOutput: """ @@ -182,9 +205,11 @@ def _read_header(self, sd: BinaryIO) -> SDBufferHeader: """ # Get the length of the header from the first word try: - dataHeader = np.frombuffer( - sd.read(self.layout.word_size), dtype=np.dtype(self.layout.header_dtype) - ) + with self._lock: + dataHeader = np.frombuffer( + sd.read(self.config["layout"].word_size), + dtype=np.dtype(self.config["layout"].header_dtype), + ) except IndexError as e: if "index 0 is out of bounds for axis 0 with size 0" in str(e): # end of file if we are reading from a disk image without any @@ -195,13 +220,14 @@ def _read_header(self, sd: BinaryIO) -> SDBufferHeader: # Get the rest of the values in the header try: - dataHeader = np.append( - dataHeader, - np.frombuffer( - sd.read(int(dataHeader[0]) * self.layout.word_size), - dtype=np.dtype(self.layout.header_dtype), - ), - ) + with self._lock: + dataHeader = np.append( + dataHeader, + np.frombuffer( + sd.read(int(dataHeader[0]) * self.config["layout"].word_size), + dtype=np.dtype(self.config["layout"].header_dtype), + ), + ) except ValueError as e: if "read length must be non-negative" in str(e): # end of file! 
Value error thrown because the dataHeader will be @@ -214,24 +240,24 @@ def _read_header(self, sd: BinaryIO) -> SDBufferHeader: # use construct because we're already sure these are ints from the numpy casting # https://docs.pydantic.dev/latest/usage/models/#creating-models-without-validation try: - return SDBufferHeader.from_format(dataHeader, self.layout.buffer, construct=True) + return SDBufferHeader.from_format( + dataHeader, self.config["layout"].buffer, construct=True + ) except IndexError as e: - if ( - self._last_buffer - >= self.config.n_buffers_recorded + self.config.n_buffers_dropped - 1 - ): + if self.tell() >= self.size: raise EndOfRecordingException("Reached the end of the video!") from None else: raise ReadHeaderException( "Could not read header, expected header to have " - f"{len(self.layout.buffer.model_dump().keys())} fields, " + f"{len(self.config['layout'].buffer.model_dump().keys())} fields, " f"got {len(dataHeader)}. Likely mismatch between specified " "and actual SD Card layout or reached end of data.\n" f"Header Data: {dataHeader}" ) from e def _read_buffer(self, sd: BinaryIO, header: SDBufferHeader) -> np.ndarray: - return np.frombuffer(sd.read(self._data_read_size(header)), dtype=np.uint8) + with self._lock: + return np.frombuffer(sd.read(self._data_read_size(header)), dtype=np.uint8) def _data_read_size(self, header: SDBufferHeader) -> int: """ @@ -239,13 +265,13 @@ def _data_read_size(self, header: SDBufferHeader) -> int: """ # blocks are quantized by sector size, so get min number of blocks that cover the data n_blocks = np.ceil( - (header.data_length + (header.length * self.layout.word_size)) - / self.layout.sectors.size + (header.data_length + (header.length * self.config["layout"].word_size)) + / self.config["layout"].sectors.size ) # expand back to n bytes - sector_size = n_blocks * self.layout.sectors.size + sector_size = n_blocks * self.config["layout"].sectors.size # subtract length of header - return int(sector_size - 
(header.length * self.layout.word_size)) + return int(sector_size - (header.length * self.config["layout"].word_size)) def _trim(self, data: np.ndarray, expected_size: int) -> np.ndarray: """ diff --git a/mio/transforms/__init__.py b/mio/transforms/__init__.py index 3a1cc155..4f14cf6d 100644 --- a/mio/transforms/__init__.py +++ b/mio/transforms/__init__.py @@ -1,3 +1,7 @@ """ Transform pipeline nodes that both receive and emit events """ + +from mio.transforms.frame import MergeBuffers + +__all__ = ["MergeBuffers"] diff --git a/tests/fixtures.py b/tests/fixtures.py index 8f162e13..f5e9cc63 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -14,20 +14,20 @@ @pytest.fixture -def wirefree() ->WireFreeMiniscope: +def wirefree() -> WireFreeMiniscope: """ SDCard with wirefree layout pointing to the sample data file """ sd_path = Path(__file__).parent.parent / "data" / "wirefree_example.img" - sdcard = WireFreeMiniscope(drive=sd_path, layout="wirefree-sd-layout") + sdcard = WireFreeMiniscope(drive=sd_path, config="wirefree-default") return sdcard @pytest.fixture def wirefree_battery() -> WireFreeMiniscope: sd_path = Path(__file__).parent.parent / "data" / "wirefree_battery_sample.img" - sdcard = WireFreeMiniscope(drive=sd_path, layout="wirefree-sd-layout-battery") + sdcard = WireFreeMiniscope(drive=sd_path, config="wirefree-sd-layout-battery") return sdcard diff --git a/tests/test_devices/test_wirefree.py b/tests/test_devices/test_wirefree.py index 0ce3adeb..d1415d35 100644 --- a/tests/test_devices/test_wirefree.py +++ b/tests/test_devices/test_wirefree.py @@ -97,7 +97,7 @@ def test_relative_path(): rel_path = abs_child.relative_to(abs_cwd) assert not rel_path.is_absolute() - sdcard = WireFreeMiniscope(drive=rel_path, layout="wirefree-sd-layout") + sdcard = WireFreeMiniscope(drive=rel_path) # check we can do something basic like read metadata assert sdcard.metadata is not None @@ -108,7 +108,7 @@ def test_relative_path(): # now try with an absolute path abs_path = 
rel_path.resolve() assert abs_path.is_absolute() - sdcard_abs = WireFreeMiniscope(drive=abs_path, layout="wirefree-sd-layout") + sdcard_abs = WireFreeMiniscope(drive=abs_path) assert sdcard_abs.metadata is not None assert sdcard_abs.drive.is_absolute()