diff --git a/mio/data/config/wirefree/wirefree-pipeline.yaml b/mio/data/config/wirefree/wirefree-pipeline.yaml new file mode 100644 index 00000000..f5020709 --- /dev/null +++ b/mio/data/config/wirefree/wirefree-pipeline.yaml @@ -0,0 +1,28 @@ +id: wirefree-pipeline +mio_model: mio.devices.wirefree.WireFreePipeline +mio_version: v0.6.0 + +nodes: + sdcard: + type: "sd-file-source" + config: + layout: "wirefree-sd-layout" + passed: + path: sd_path + outputs: + - source: header + target: merge.header + - source: buffer + target: merge.buffer + merge: + type: "merge-buffers" + fill: + width: file.width + height: file.height + outputs: + - source: frame + target: return + return: + config: + key: frame + type: "return" diff --git a/mio/data/config/wirefree/wirefree.yaml b/mio/data/config/wirefree/wirefree.yaml new file mode 100644 index 00000000..df468490 --- /dev/null +++ b/mio/data/config/wirefree/wirefree.yaml @@ -0,0 +1,5 @@ +id: wirefree-default +mio_model: mio.devices.wirefree.WireFreeConfig +mio_version: 0.6.1 +pipeline: wirefree-pipeline +layout: wirefree-sd-layout diff --git a/mio/devices/device.py b/mio/devices/device.py index 34324259..d6ceb381 100644 --- a/mio/devices/device.py +++ b/mio/devices/device.py @@ -2,36 +2,40 @@ ABC for """ +import sys from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional -from mio.models import MiniscopeConfig, Pipeline, PipelineConfig +from mio.models import MiniscopeConfig, Pipeline, PipelineConfig, PipelineMixin +from mio.models.mixins import ConfigYAMLMixin if TYPE_CHECKING: from mio.models.pipeline import Sink, Source, Transform +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self -class DeviceConfig(MiniscopeConfig): + +class DeviceConfig(MiniscopeConfig, ConfigYAMLMixin): """ Abstract base class for device configuration """ - id: Union[str, int] - """(Locally) 
unique identifier for this device""" pipeline: PipelineConfig = PipelineConfig() @dataclass(kw_only=True) -class Device: +class Device(PipelineMixin): """ Abstract base class for devices. Currently a placeholder to allow room for expansion/renaming in the future """ - pipeline: Optional[Pipeline] = None - # config: Optional[DeviceConfig] = None + config: Optional[DeviceConfig] = None @abstractmethod def init(self) -> None: @@ -126,17 +130,8 @@ def set(self, key: str, value: Any) -> None: key (str): The name of the value to get """ - @property - def sources(self) -> dict[str, "Source"]: - """Convenience method to access :attr:`.Pipeline.sources`""" - return self.pipeline.sources - - @property - def transforms(self) -> dict[str, "Transform"]: - """Convenience method to access :attr:`.Pipeline.transforms`""" - return self.pipeline.transforms - - @property - def sinks(self) -> dict[str, "Sink"]: - """Convenience method to access :attr:`.Pipeline.sinks`""" - return self.pipeline.sinks + @classmethod + def from_config(cls, config: DeviceConfig) -> Self: + """ + Instantiate a device from its (yaml) configuration. 
+ """ diff --git a/mio/devices/wirefree.py b/mio/devices/wirefree.py index 253926ff..6833a14d 100644 --- a/mio/devices/wirefree.py +++ b/mio/devices/wirefree.py @@ -5,7 +5,7 @@ import contextlib from dataclasses import dataclass, field from pathlib import Path -from typing import Any, BinaryIO, Literal, Optional, Union, overload +from typing import Any, Literal, Optional, Union, overload import cv2 import numpy as np @@ -14,15 +14,28 @@ from mio import init_logger from mio.devices import DeviceConfig, Miniscope, RecordingCameraMixin from mio.exceptions import EndOfRecordingException, ReadHeaderException +from mio.models import Pipeline from mio.models.data import Frame -from mio.models.sdcard import SDBufferHeader, SDConfig, SDLayout -from mio.types import ConfigSource, Resolution +from mio.models.pipeline import PipelineConfig +from mio.models.sdcard import SDLayout, SDMetadata +from mio.sources import SDFileSource +from mio.types import Resolution + + +class WireFreePipeline(PipelineConfig): + """Base skeleton pipeline for the wirefree miniscope""" + + required_nodes = { + "sdcard": {"type": "sd-file-source"}, + "return": {"type": "return", "config": {"key": "frame"}}, + } class WireFreeConfig(DeviceConfig): """Configuration for wire free miniscope""" - pass + pipeline: WireFreePipeline = "wirefree-pipeline" + layout: SDLayout = "wirefree-sd-layout" @dataclass(kw_only=True) @@ -36,15 +49,15 @@ class WireFreeMiniscope(Miniscope, RecordingCameraMixin): Args: drive (str, :class:`pathlib.Path`): Path to the SD card drive - layout (:class:`.sdcard.SDLayout`): A layout configuration for an SD card + config (:class:`.WireFreeConfig`): Configuration, + including data layout and pipeline configs """ drive: Path """The path to the SD card drive""" - # config: WireFreeConfig - # """Configuration """ - layout: Union[SDLayout, ConfigSource] = "wirefree-sd-layout" + config: WireFreeConfig = field(default_factory=WireFreeConfig) + positions: dict[int, int] = 
field(default_factory=dict) """ A mapping between frame number and byte position in the video that makes for @@ -57,12 +70,14 @@ class WireFreeMiniscope(Miniscope, RecordingCameraMixin): def __post_init__(self) -> None: """post-init create private vars""" - self.layout = SDLayout.from_any(self.layout) + if isinstance(self.config, str): + self.config = WireFreeConfig.from_any(self.config) + + self.layout = SDLayout.from_any(self.config.layout) self.logger = init_logger("WireFreeMiniscope") # Private attributes used when the file reading context is entered - self._config = None # type: Optional[SDConfig] - self._f = None # type: Optional[BinaryIO] + self._metadata = None # type: Optional[SDMetadata] self._frame = None # type: Optional[int] self._frame_count = None # type: Optional[int] self._array = None # type: Optional[np.ndarray] @@ -75,24 +90,24 @@ def __post_init__(self) -> None: # -------------------------------------------------- @property - def config(self) -> SDConfig: + def metadata(self) -> SDMetadata: """ - Read configuration from SD Card + Read metadata from SD Card """ - if self._config is None: + if self._metadata is None: with open(self.drive, "rb") as sd: sd.seek(self.layout.sectors.config_pos, 0) configSectorData = np.frombuffer(sd.read(self.layout.sectors.size), dtype=np.uint32) - self._config = SDConfig( + self._metadata = SDMetadata( **{ k: configSectorData[v] - for k, v in self.layout.config.model_dump().items() + for k, v in self.layout.metadata.model_dump().items() if v is not None } ) - return self._config + return self._metadata @classmethod def configure(cls, drive: Union[str, Path], config: WireFreeConfig) -> None: @@ -107,10 +122,7 @@ def position(self) -> Optional[int]: When entered as context manager, the current position of the internal file descriptor """ - if self._f is None: - return None - - return self._f.tell() + return self._source.position @property def frame(self) -> Optional[int]: @@ -159,246 +171,56 @@ def frame(self, frame: 
int) -> None: self.skip() @property - def frame_count(self) -> int: + def buffers_per_frame(self) -> int: """ - Total number of frames in recording. + Number of buffers per frame! - Inferred from :class:`~.sdcard.SDConfig.n_buffers_recorded` and - reading a single frame to get the number of buffers per frame. + References: + https://github.com/Aharoni-Lab/Miniscope-v4-Wire-Free/blob/786663781a4bece89c89e00fc3ac9d95912faba4/Miniscope-v4-Wire-Free-MCU-Firmware/Miniscope-v4-Wire-Free/Miniscope-v4-Wire-Free/main.c#L680 """ - if self._frame_count is None: - if self._f is None: - with self as self_open: - frame = self_open.read(return_header=True) - headers = frame.headers - - else: - # If we're already open, great, just return to the last frame - last_frame = self.frame - # Go one frame back in case we are at the end of the data - self.frame = max(last_frame - 1, 0) - frame = self.read(return_header=True) - headers = frame.headers - self.frame = last_frame - - self._frame_count = int( - np.ceil( - (self.config.n_buffers_recorded + self.config.n_buffers_dropped) / len(headers) - ) - ) + n_pix = self.metadata.width * self.metadata.height + return int(np.ceil((n_pix + self._source.header_size) / (self.metadata.buffer_size))) - # if we have since read more frames than should be there, we update the - # frame count with a warning - max_pos = np.max(list(self.positions.keys())) - if max_pos > self._frame_count: - self.logger.warning( - "Got more frames than indicated in card header, expected " - f"{self._frame_count} but got {max_pos}" + @property + def frame_count(self) -> int: + """ + Total number of frames in the recording + """ + return int( + np.ceil( + (self.metadata.n_buffers_recorded + self.metadata.n_buffers_dropped) + / self.buffers_per_frame ) - self._frame_count = int(max_pos) + ) - return self._frame_count + @property + def _source(self) -> SDFileSource: + """The SDFileSource node in the pipeline""" + return self.pipeline.nodes["sdcard"] # 
-------------------------------------------------- # Context Manager methods # -------------------------------------------------- def __enter__(self) -> "WireFreeMiniscope": - if self._f is not None: - raise RuntimeError("Cant enter context, and open the file twice!") - - # init private attrs - # create an empty frame to hold our data! - self._array = np.zeros((self.config.width * self.config.height, 1), dtype=np.uint8) - self._pixel_count = 0 - self._last_buffer_n = 0 - self._frame = 0 - - self._f = open(self.drive, "rb") # noqa: SIM115 - this is a context handler - # seek to the start of the data - self._f.seek(self.layout.sectors.data_pos, 0) - # store the 0th frame position - self.positions[0] = self.layout.sectors.data_pos - + self.runner.start() return self def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 - self._f.close() - self._f = None - self._frame = 0 - - # -------------------------------------------------- - # read methods - # -------------------------------------------------- - def _read_data_header(self, sd: BinaryIO) -> SDBufferHeader: - """ - Given an already open file buffer opened in bytes mode, - seeked to the start of a frame, read the data header - """ - - # read one word first, I think to get the size of the rest of the header, - # that sort of breaks the abstraction - # (it assumes the buffer length is always at position 0) - # but we'll roll with it for now - dataHeader = np.frombuffer(sd.read(self.layout.word_size), dtype=np.uint32) - dataHeader = np.append( - dataHeader, - np.frombuffer( - sd.read((dataHeader[self.layout.buffer.length] - 1) * self.layout.word_size), - dtype=np.uint32, - ), - ) - - # use construct because we're already sure these are ints from the numpy casting - # https://docs.pydantic.dev/latest/usage/models/#creating-models-without-validation - try: - header = SDBufferHeader.from_format(dataHeader, self.layout.buffer, construct=True) - except IndexError as e: - raise ReadHeaderException( - "Could not read 
header, expected header to have " - f"{len(self.layout.buffer.model_dump().keys())} fields, " - f"got {len(dataHeader)}. Likely mismatch between specified " - "and actual SD Card layout or reached end of data.\n" - f"Header Data: {dataHeader}" - ) from e - - return header - - def _n_frame_blocks(self, header: SDBufferHeader) -> int: - """ - Compute the number of blocks for a given frame buffer - - Not sure how this works! - """ - n_blocks = int( - ( - header.data_length - + (header.length * self.layout.word_size) - + (self.layout.sectors.size - 1) - ) - / self.layout.sectors.size - ) - return n_blocks - - def _read_size(self, header: SDBufferHeader) -> int: - """ - Compute the number of bytes to read for a given buffer - - Not sure how this works with :meth:`._n_frame_blocks`, but keeping - them separate in case they are separable actions for now - """ - n_blocks = self._n_frame_blocks(header) - read_size = (n_blocks * self.layout.sectors.size) - (header.length * self.layout.word_size) - return read_size - - def _read_buffer(self, sd: BinaryIO, header: SDBufferHeader) -> np.ndarray: - """ - Read a single buffer from a frame. - - Each frame has several buffers, so for a given frame we read them until we - get another that's zero! - """ - data = np.frombuffer(sd.read(self._read_size(header)), dtype=np.uint8) - return data - - def _trim(self, data: np.ndarray, expected_size: int) -> np.ndarray: - """ - Trim or pad an array to match an expected size - """ - if data.shape[0] != expected_size: - self.logger.warning( - f"Frame: {self._frame}: Expected buffer data length: {expected_size}, " - f"got data with shape {data.shape}. " - "Padding to expected length", - stacklevel=1, - ) - - # trim if too long - if data.shape[0] > expected_size: - data = data[0:expected_size] - # pad if too short - else: - data = np.pad(data, (0, expected_size - data.shape[0])) + self.runner.stop() - return data - - @overload - def read(self, return_header: Literal[True] = True) -> Frame: ... 
- - @overload - def read(self, return_header: Literal[False] = False) -> np.ndarray: ... - - def read(self, return_header: bool = False) -> Union[np.ndarray, Frame]: + def read(self) -> Frame: """ - Read a single frame - - Arguments: - return_header (bool): If `True`, return headers from individual buffers - (default `False`) - - Return: - :class:`numpy.ndarray` , - or a tuple(ndarray, List[:class:`~.SDBufferHeader`]) if `return_header` - is `True` + Read a single :class:`.Frame` """ - if self._f is None: - raise RuntimeError( - "File is not open! Try entering the reader context by using it like " - "`with sdcard:`" - ) + if not self.runner.running: + self.logger.debug("Starting runner implicitly by calling read") + self.runner.start() - self._array[:] = 0 - pixel_count = 0 - last_buffer_n = 0 - headers = [] - while True: - # stash position before reading header - last_position = self._f.tell() - try: - header = self._read_data_header(self._f) - except ValueError as e: - if "read length must be non-negative" in str(e): - # end of file! Value error thrown because the dataHeader will be - # blank, and thus have a value of 0 for the header size, and we - # can't read 0 from the card. - self._f.seek(last_position, 0) - raise EndOfRecordingException("Reached the end of the video!") from None - else: - raise e - except IndexError as e: - if "index 0 is out of bounds for axis 0 with size 0" in str(e): - # end of file if we are reading from a disk image without any - # additional space on disk - raise EndOfRecordingException("Reached the end of the video!") from None - else: - raise e - except ReadHeaderException as e: - # if we are on the last frame, normal! signal end of iteration - if self._frame == self.frame_count - 1: - raise EndOfRecordingException("Reached the end of the video!") from None - else: - raise e - - if header.frame_buffer_count == 0 and last_buffer_n > 0: - # we are in the next frame! 
- # rewind to the beginning of the header, and return - # the last_position is the start of the header for this frame - self._f.seek(last_position, 0) - self._frame += 1 - self.positions[self._frame] = last_position - frame = np.reshape(self._array, (self.config.width, self.config.height)) - if return_header: - return Frame.model_construct(frame=frame, headers=headers) - else: - return frame - - # grab buffer data and stash - headers.append(header) - data = self._read_buffer(self._f, header) - data = self._trim(data, header.data_length) - self._array[pixel_count : pixel_count + header.data_length, 0] = data - pixel_count += header.data_length - last_buffer_n = header.frame_buffer_count + res = None + while res is None: + res: Optional[dict[Literal["frame"], Frame]] = self.runner.process() + return res["frame"] # -------------------------------------------------- # Write methods @@ -451,8 +273,8 @@ def to_video( writer = cv2.VideoWriter( str(path), cv2.VideoWriter_fourcc(*fourcc), - self.config.fs, - (self.config.width, self.config.height), + self.metadata.fs, + (self.metadata.width, self.metadata.height), isColor=isColor, ) @@ -465,8 +287,8 @@ def to_video( while True: # this is sort of an awkward stack, should probably make a # generator version of `read` - frame = self.read(return_header=False) - writer.write(frame) + frame = self.read() + writer.write(frame.frame) if progress: pbar.update() @@ -631,16 +453,25 @@ def deinit(self) -> None: def start(self) -> None: """start pipeline""" - raise NotImplementedError() + self.runner.start() def stop(self) -> None: """stop pipeline""" - raise NotImplementedError() + self.runner.stop() def join(self) -> None: """join pipeline""" raise NotImplementedError() + @property + def pipeline(self) -> Pipeline: + """Create pipeline, passing needed values""" + if self._pipeline is None: + self._pipeline = Pipeline.from_config( + self.config.pipeline, passed={"sd_path": self.drive} + ) + return self._pipeline + @property def 
excitation(self) -> float: """LED Excitation""" @@ -649,12 +480,12 @@ def excitation(self) -> float: @property def fps(self) -> int: """FPS""" - return self.config.fs + return self.metadata.fs @property def resolution(self) -> Resolution: """Resolution of recorded video""" - return Resolution(self.config.width, self.config.height) + return Resolution(self.metadata.width, self.metadata.height) def get(self, key: str) -> Any: """get a configuration value by its name""" diff --git a/mio/exceptions.py b/mio/exceptions.py index 1c7343a9..767a8392 100644 --- a/mio/exceptions.py +++ b/mio/exceptions.py @@ -66,3 +66,13 @@ class ConfigurationMismatchError(ConfigurationError): """ Mismatch between the fields in some config model and the fields in the model it is configuring """ + +class PipelineError(Exception): + """ + Base exception type for pipeline errors + """ + +class PipelineRunningError(PipelineError, RuntimeError): + """ + A pipeline has been requested to start, but it is already started! + """ \ No newline at end of file diff --git a/mio/models/__init__.py b/mio/models/__init__.py index 815b6156..2d57b5b5 100644 --- a/mio/models/__init__.py +++ b/mio/models/__init__.py @@ -12,6 +12,7 @@ Node, Pipeline, PipelineConfig, + PipelineMixin, Sink, Source, Transform, @@ -24,6 +25,7 @@ "Node", "Pipeline", "PipelineConfig", + "PipelineMixin", "PipelineModel", "Transform", "Sink", diff --git a/mio/models/mixins.py b/mio/models/mixins.py index fd09a240..3031eb4b 100644 --- a/mio/models/mixins.py +++ b/mio/models/mixins.py @@ -3,6 +3,7 @@ to use composition for functionality and inheritance for semantics. 
""" +import pdb import re import shutil from importlib.metadata import version @@ -11,11 +12,20 @@ from typing import Any, ClassVar, List, Literal, Optional, Type, TypeVar, Union, overload import yaml -from pydantic import BaseModel, Field, ValidationError, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + GetCoreSchemaHandler, + ValidationError, + field_validator, +) +from pydantic_core import core_schema from mio.types import ConfigID, ConfigSource, PythonIdentifier, valid_config_id T = TypeVar("T") +"""Generic type of ConfigYamlMixin subclass""" class YamlDumper(yaml.SafeDumper): @@ -80,7 +90,9 @@ class ConfigYAMLMixin(BaseModel, YAMLMixin): at the top of the file. """ - id: ConfigID + model_config = ConfigDict(validate_default=True) + + id: Optional[ConfigID] = None mio_model: PythonIdentifier = Field(None, validate_default=True) mio_version: str = version("mio") @@ -120,22 +132,26 @@ def from_id(cls: Type[T], id: ConfigID) -> T: """ globs = [src.rglob("*.y*ml") for src in cls.config_sources] + for config_file in chain(*globs): + # if id == "wirefree-pipeline": + # pdb.set_trace() try: file_id = yaml_peek("id", config_file) - if file_id == id: - from mio.logging import init_logger - - init_logger("config").debug( - "Model for %s found at %s", cls._model_name(), config_file - ) - return cls.from_yaml(config_file) except KeyError: continue + if file_id == id: + from mio.logging import init_logger + + init_logger("config").debug( + "Model for %s found at %s", cls._model_name(), config_file + ) + return cls.from_yaml(config_file) + from mio import Config - raise KeyError(f"No config with id {id} found in {Config().config_dir}") + raise KeyError(f"No config with id {id} found in {cls.config_sources}") @classmethod def from_any(cls: Type[T], source: Union[ConfigSource, T]) -> T: @@ -258,6 +274,29 @@ def _complete_header(cls: Type[T], data: dict, file_path: Union[str, Path]) -> d return data + @classmethod + def 
__get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> core_schema.CoreSchema:
+        """
+        Add before_validator to allow instantiation from id
+        """
+
+        def _from_id(value: Any) -> cls:
+            if isinstance(value, str):
+                return cls.from_id(value)
+            else:
+                return value
+
+        return core_schema.no_info_before_validator_function(
+            _from_id,
+            handler(source_type),
+            # TODO: add this when updating pydantic floor to 2.10
+            # json_schema_input_schema=core_schema.union_schema(
+            #     [handler(source_type), handler(ConfigID)]
+            # ),
+        )
+
 
 @overload
 def yaml_peek(
diff --git a/mio/models/pipeline.py b/mio/models/pipeline.py
index 0152a1bf..21b1ac08 100644
--- a/mio/models/pipeline.py
+++ b/mio/models/pipeline.py
@@ -4,30 +4,85 @@
 
 import sys
 from abc import abstractmethod
-from typing import ClassVar, Final, Generic, TypeVar, Union, final
-
-from pydantic import Field
+from dataclasses import dataclass
+from datetime import datetime
+from graphlib import TopologicalSorter
+from typing import (
+    Any,
+    ClassVar,
+    Generic,
+    Optional,
+    TypeVar,
+    Union,
+    final,
+    Protocol,
+    TYPE_CHECKING,
+)
+
+from pydantic import BaseModel, Field, field_validator, model_validator
 
 from mio.exceptions import ConfigurationMismatchError
 from mio.models.models import MiniscopeConfig, PipelineModel
+from mio.models.mixins import ConfigYAMLMixin
 
 if sys.version_info < (3, 11):
-    from typing_extensions import Self
-else:
+    # `typing.Unpack` only exists on Python 3.11+, so take the backport here
+    from typing_extensions import Self, TypedDict, NotRequired, Unpack
+elif sys.version_info < (3, 12):
     from typing import Self
+    from typing_extensions import TypedDict, NotRequired, Unpack
+else:
+    from typing import Self, TypedDict, NotRequired, Unpack
+
+if TYPE_CHECKING:
+    from mio.pipeline.runner import PipelineRunner
+
 
 T = TypeVar("T")
 """
 Input Type typevar
 """
-U = TypeVar("U")
+U = TypeVar("U", bound=dict[str, Any])
 """
 Output Type typevar
 """
 
 
-class NodeConfig(MiniscopeConfig):
-    """Configuration for a single processing node"""
+class Event(TypedDict, 
Generic[U]): + """ + Container for a single value returned from a single :meth:`.Node.process` call + """ + + id: int + """Unique ID for each event""" + timestamp: datetime + """Timestamp of when the event was received by the :class:`.PipelineRunner""" + node_id: str + """ID of node that emitted the event""" + slot: str + """name of the slot that emitted the event""" + value: Any + """Value emitted by the processing node""" + + +class _NodeMap(TypedDict): + source: str + target: str + + +class NodeConfig(TypedDict, total=False): + """ + Abstract parent TypedDict that each node inherits from to define + what fields it needs to be configured. + """ + + +class NodeSpecification(MiniscopeConfig): + """ + Specification for a single processing node within a pipeline .yaml file. + Distinct from a :class:`.NodeConfig`, which is a generic TypedDict that each + node defines to declare its parameterization. + """ type_: str = Field(..., alias="type") """ @@ -35,84 +90,259 @@ class NodeConfig(MiniscopeConfig): Subclasses should override this with a default. """ - id: str """The unique identifier of the node""" - inputs: list[str] = Field(default_factory=list) - """List of Node IDs to be used as input""" - outputs: list[str] = Field(default_factory=list) + outputs: Optional[list[_NodeMap]] = None """List of Node IDs to be used as output""" + config: Optional[dict] = None + """Additional configuration for this node, parameterized by a TypedDict for the class""" + passed: Optional[dict[str, str]] = None + """ + Mapping of config values that must be passed when the pipeline is instantiated. + + Keys are the key in the config dictionary to be filled by passing, and values are a key that + those values should be passed as. + + Examples: + + For a node with config field `height` , one can specify that it must be passed + on instantiation like this: + + .. 
code-block:: yaml + + nodes: + node1: + type: a_node + passed: + height: height_1 + node2: + type: a_node + passed: + height: height_2 + + The pipeline should then be instantiated like: + + .. code-block:: python + + Pipeline.from_config(above_config, passed={'height_1': 1, 'height_2': 2}) + + """ + fill: Optional[dict[str, str]] = None + """ + Values in the node config that should be dynamically filled from other nodes in the pipeline. + + Specified as {node_id}.{attribute}, these specify attributes and properties + on the instantiated node class, not the config values for that node. + + This is useful for accessing some properties that might not be known until runtime + like width and height of an input image. + + Examples: + + For a node class `camera` that has property `frame_width`, + and node class `process` that has config value `width`, + we would fill the config value like this: + + .. code-block:: yaml + + nodes: + cam: + type: camera + proc: + type: process + fill: + width: cam.frame_width + + The Pipeline class will then do something like this on instantiation: + + .. 
code-block:: python
+
+        pipeline = PipelineConfig(**the_above_values)
+
+        cam = CameraNode(config=pipeline.nodes['cam'].config)
+
+        proc_config = pipeline.nodes['proc'].config
+        proc_config['width'] = cam.frame_width
+        proc = ProcessingNode(config=proc_config)
+
+    """
+
+
+class _RequiredNode(TypedDict):
+    type: str
+    """Node type (as determined by its :attr:`.Node.name` attr)"""
+    config: NotRequired[dict[str, Any]]
+    """Require that config values must be set to these values"""
 
 
-class PipelineConfig(MiniscopeConfig):
+class PipelineConfig(MiniscopeConfig, ConfigYAMLMixin):
     """
     Configuration for the nodes within a pipeline
     """
 
-    nodes: dict[str, NodeConfig] = Field(default_factory=dict)
+    required_nodes: ClassVar[Optional[dict[str, _RequiredNode]]] = None
+    """
+    id: type mapping that a subclass can use to require a set of node types
+    with specific IDs be present
+    """
+
+    nodes: dict[str, NodeSpecification] = Field(default_factory=dict)
     """The nodes that this pipeline configures"""
 
+    @model_validator(mode="after")
+    def validate_required_nodes(self) -> Self:
+        """Ensure required nodes are present with the required type and config values"""
+        if self.required_nodes is not None:
+            for id_, required in self.required_nodes.items():
+                assert id_ in self.nodes, f"Required node id {id_} not in {self.nodes.keys()}"
+                assert (
+                    self.nodes[id_].type_ == required["type"]
+                ), f"Node ID {id_} is not of type {required['type']}"
+                config = self.nodes[id_].config or {}  # `config` is optional on a node
+                for key, val in required.get("config", {}).items():
+                    assert config.get(key) == val, (
+                        f"Required node {id_} must have config value {key} set to {val}, "
+                        f"got {config.get(key)} instead."
+                    )
+ return self + + @field_validator("nodes", mode="before") + @classmethod + def fill_node_ids(cls, value: dict[str, dict]) -> dict[str, dict]: + """ + Roll down the `id` from the key in the `nodes` dictionary into the node config + """ + assert isinstance(value, dict) + for id, node in value.items(): + if "id" not in node: + node["id"] = id + return value + + # TODO: Implement these validators + # @field_validator("nodes", mode="after") + # @classmethod + # def valid_passed_and_fill_keys( + # cls, value: dict[str, NodeSpecification] + # ) -> dict[str, NodeSpecification]: + # """ + # Passed and fill keys refer to values within the node's config type + # """ + # + # @field_validator("nodes", mode="after") + # @classmethod + # def unique_passed_values( + # cls, value: dict[str, NodeSpecification] + # ) -> dict[str, NodeSpecification]: + # """ + # All passed values ( + # """ + # + # @field_validator("nodes", mode="after") + # @classmethod + # def fill_sources_present( + # cls, value: dict[str, NodeSpecification] + # # ) -> dict[str, NodeSpecification]: + # """ + # Fill values refer to nodes that are present in the node graph + # """ + # + # @field_validator("nodes", mode="after") + # @classmethod + # def fill_values_dotted( + # cls, value: dict[str, NodeSpecification] + # ) -> dict[str, NodeSpecification]: + # """ + # Fill values refer to a property or attribute of a node (i.e. 
have at least one dot)
+    #     """
+
+    def graph(self) -> TopologicalSorter:
+        """
+        For the node specifications in :attr:`.PipelineConfig.nodes`,
+        produce a :class:`.TopologicalSorter` that accounts for the dependencies between nodes
+        induced by the ``{node_id}.{attribute}`` values of :attr:`.NodeSpecification.fill`
+        """
+        sorter = TopologicalSorter()
+        for node_id, node in self.nodes.items():
+            if node.fill is None:
+                sorter.add(node_id)
+            else:
+                dependents = {v.split(".")[0] for v in node.fill.values()}
+                sorter.add(node_id, *dependents)
+        return sorter
 
 class Node(PipelineModel, Generic[T, U]):
     """A node within a processing pipeline"""
 
-    type_: ClassVar[str]
+    name: ClassVar[str]
     """
     Shortname for this type of node to match configs to node types
     """
 
     id: str
     """Unique identifier of the node"""
-    config: NodeConfig
+    config: Optional[NodeConfig] = None
 
     input_type: ClassVar[type[T]]
-    inputs: dict[str, Union["Source", "Transform"]] = Field(default_factory=dict)
     output_type: ClassVar[type[U]]
-    outputs: dict[str, Union["Sink", "Transform"]] = Field(default_factory=dict)
 
-    @abstractmethod
     def start(self) -> None:
         """
-        Start producing, processing, or receiving data
+        Start producing, processing, or receiving data.
+
+        Default is a no-op.
+        Subclasses do not need to override if they have no initialization logic.
         """
+        pass
 
-    @abstractmethod
     def stop(self) -> None:
         """
         Stop producing, processing, or receiving data
+
+        Default is a no-op.
+        Subclasses do not need to override if they have no deinit logic.
         """
+        pass
+
+    @abstractmethod
+    def process(self, **kwargs: Unpack[T]) -> Optional[U]:
+        """Process some input, emitting it. 
See subclasses for details""" + pass @classmethod - def from_config(cls, config: NodeConfig) -> Self: + def from_specification(cls, config: NodeSpecification) -> Self: """ Create a node from its config """ - return cls(id=config.id, config=config) + return cls(id=config.id, config=config.config) @classmethod @final def node_types(cls) -> dict[str, type["Node"]]: """ - Map of all imported :attr:`.Node.type_` names to node classes + Map of all imported :attr:`.Node.name` names to node classes """ + from mio import sinks, sources, transforms + node_types = {} to_check = cls.__subclasses__() while to_check: node = to_check.pop() - if node.type_ in node_types: + if node not in (Sink, Source, Transform) and node.name in node_types: raise ValueError( - f"Repeated node type_ identifier: {node.type_}, found in:\n" - f"- {node_types[node.type_]}\n- {node}" + f"Repeated node name identifier: {node.name}, found in:\n" + f"- {node_types[node.name]}\n- {node}" ) - node_types[node.type_] = node + to_check.extend(node.__subclasses__()) + if node not in (Sink, Source, Transform): + node_types[node.name] = node return node_types class Source(Node, Generic[T, U]): """A source of data in a processing pipeline""" - inputs: Final[None] = None input_type: ClassVar[None] = None @abstractmethod @@ -134,10 +364,9 @@ class Sink(Node, Generic[T, U]): """A sink of data in a processing pipeline""" output_type: ClassVar[None] = None - outputs: Final[None] = None @abstractmethod - def process(self, data: T) -> None: + def process(self, **kwargs: Unpack[T]) -> None: """ Process some incoming data, returning None @@ -155,7 +384,7 @@ class Transform(Node, Generic[T, U]): """ @abstractmethod - def process(self, data: T) -> U: + def process(self, **kwargs: Unpack[T]) -> U: """ Process some incoming data, yielding a transformed output @@ -168,15 +397,38 @@ def process(self, data: T) -> U: """ +class Edge(PipelineModel): + """ + Directed connection between an output slot a node and an input slot in 
another node + """ + + source_node: Node + source_slot: Optional[str] = None + target_node: Node + target_slot: Optional[str] = None + + class Pipeline(PipelineModel): """ A graph of nodes transforming some input source(s) to some output sink(s) + + The Pipeline model is a container for a set of nodes that are fully instantiated + (e.g. have their "passed" and "fill" keys processed) and connected. + It does not handle running the pipeline -- that is handled by a PipelineRunner. """ nodes: dict[str, Node] = Field(default_factory=dict) """ Dictionary mapping all nodes from their ID to the instantiated node. """ + edges: list[Edge] = Field(default_factory=list) + """ + Edges connecting slots within nodes. + + The nodes within :attr:`.Edge.source_node` and :attr:`.Edge.target_node` must + be the same objects as those in :attr:`.Pipeline.nodes` + (i.e. ``edges[0].source_node is nodes[node_id]`` ). + """ @property def sources(self) -> dict[str, "Source"]: @@ -193,58 +445,228 @@ def sinks(self) -> dict[str, "Sink"]: """All :class:`.Sink` nodes in the processing graph""" return {k: v for k, v in self.nodes.items() if isinstance(v, Sink)} - @abstractmethod - def process(self) -> None: + def graph(self) -> TopologicalSorter: """ - Process one step of data from each of the sources, - passing intermediate data to any subscribed nodes in a chain. - - The `process` method should not return anything except a to-be-implemented - result/status object, as any data intended to be received/processed by - downstream objects should be accessed via a :class:`.Sink` . 
+ Produce a :class:`.TopologicalSorter` based on the graph induced by + :attr:`.Pipeline.nodes` and :attr:`.Pipeline.edges` that yields node ids """ + sorter = TopologicalSorter() + for node_id, node in self.nodes.items(): + in_edges = [e.target_node.id for e in self.edges if e.target_node is node] + sorter.add(node_id, *in_edges) + return sorter - @abstractmethod - def start(self) -> None: + def in_edges(self, node: Union[Node, str]) -> list[Edge]: """ - Start processing data with the pipeline graph + Edges going towards the given node (i.e. the node is the edge's ``target`` ) + + Args: + node (:class:`.Node`, str): Either a node or its id + + Returns: + list[:class:`.Edge`] """ + if isinstance(node, Node): + node = node.id + return [e for e in self.edges if e.target_node.id == node] - @abstractmethod - def stop(self) -> None: + def out_edges(self, node: Union[Node, str]) -> list[Edge]: """ - Stop processing data with the pipeline graph + Edges going away from the given node (i.e. the node is the edge's ``source`` ) + + Args: + node (:class:`.Node`, str): Either a node or its id + + Returns: + list[:class:`.Edge`] """ + if isinstance(node, Node): + node = node.id + return [e for e in self.edges if e.source_node.id == node] @classmethod - def from_config(cls, config: PipelineConfig) -> Self: + def from_config(cls, config: PipelineConfig, passed: Optional[dict[str, Any]] = None) -> Self: """ Instantiate a pipeline model from its configuration + + Args: + config (PipelineConfig): the pipeline config to instantiate + passed (dict[str, Any]): If any nodes in the + """ + cls._validate_passed(config, passed) + + nodes = cls._init_nodes(config, passed) + edges = cls._init_edges(nodes, config.nodes) + + return cls(nodes=nodes, edges=edges) + + @classmethod + def passed_values(cls, config: PipelineConfig) -> dict[str, type]: + """ + Dictionary containing the keys that must be passed as specified by the `passed` field + of a node specification and their types. 
+ + Args: + config (:class:`.PipelineConfig`): Pipeline configuration to get passed values for """ types = Node.node_types() + passed = {} + for node_id, node in config.nodes.items(): + if not node.passed: + continue + + for cfg_key, pass_key in node.passed.items(): + # get type of config key that needs to be passed + config_type = types[node.type_].model_fields["config"].annotation + try: + passed[pass_key] = config_type.__annotations__[cfg_key] + except KeyError as e: + raise ConfigurationMismatchError( + f"Node {node_id} requested {cfg_key} be passed as {pass_key}, " + f"but node type {node.type_}'s config has no field {cfg_key}! " + f"Possible keys: {list(config_type.__annotations__.keys())}" + ) from e + + return passed + + @classmethod + def _init_nodes( + cls, config: PipelineConfig, passed: Optional[dict[str, Any]] = None + ) -> dict[str, Node]: + """ + Initialize nodes, filling in any computed values from `fill` and `passed` + """ + if passed is None: + passed = {} - nodes = {k: types[v.type_].from_config(v) for k, v in config.nodes.items()} - nodes = connect_nodes(nodes) - return cls(nodes=nodes) + types = Node.node_types() + graph = config.graph() + graph.prepare() + nodes = {} + while graph.is_active(): + for node in graph.get_ready(): + complete_cfg = cls._complete_node(config.nodes[node], nodes, passed) + nodes[node] = types[complete_cfg.type_].from_specification(complete_cfg) + graph.done(node) -def connect_nodes(nodes: dict[str, Node]) -> dict[str, Node]: + return nodes + + @classmethod + def _init_edges(cls, nodes: dict[str, Node], spec: dict[str, NodeSpecification]) -> list[Edge]: + edges = [] + for node_id, node_spec in spec.items(): + if not node_spec.outputs: + continue + for output in node_spec.outputs: + # FIXME: Ugly and not DRY + target_parts = output["target"].split(".") + target_id, target_slot = ( + (target_parts[0], target_parts[1]) + if len(target_parts) == 2 + else (target_parts[0], None) + ) + edges.append( + Edge( + 
source_node=nodes[node_id], + target_node=nodes[target_id], + source_slot=output["source"], + target_slot=target_slot, + ) + ) + return edges + + @classmethod + def _complete_node( + cls, node: NodeSpecification, context: dict[str, Node], passed: dict + ) -> NodeSpecification: + """ + Given the context of already-instantiated nodes and passed values, + complete the configuration. + """ + if node.passed: + for cfg_key, passed_key in node.passed.items(): + node.config[cfg_key] = passed[passed_key] + if node.fill: + for cfg_key, fill_key in node.fill.items(): + parts = fill_key.split(".") + val = context[parts[0]] + for part in parts[1:]: + val = getattr(val, part) + node.config[cfg_key] = val + return node + + @classmethod + def _validate_passed(cls, config: PipelineConfig, passed: dict[str, Any]) -> None: + """ + Ensure that the passed values required by the pipeline config are in fact passed + + Raise ConfigurationMismatchError if missing keys, otherwise do nothing + """ + required = cls.passed_values(config) + + for key in required: + if passed is None or key not in passed: + raise ConfigurationMismatchError( + f"Pipeline config requires these values to be passed:\n" + f"{required}\n" + f"But received passed values:\n" + f"{passed}" + ) + + +class _ConfigProtocol(Protocol): """ - Provide references to instantiated nodes + Abstract protocol type to specify that classes consuming the PipelineMixin + must have some config attribute that specifies a pipeline + (without prescribing what that config object must be) """ - for node in nodes.values(): - if node.config.inputs and node.inputs is None: - raise ConfigurationMismatchError( - "inputs found in node configuration, but node type allows no inputs!\n" - f"node: {node.model_dump()}" - ) - if node.config.outputs and not hasattr(node, "outputs"): - raise ConfigurationMismatchError( - "outputs found in node configuration, but node type allows no outputs!\n" - f"node: {node.model_dump()}" - ) + pipeline: 
Optional[PipelineConfig] = None + + +@dataclass(kw_only=True) +class PipelineMixin: + """Mixin for use with models that have pipelines!""" + + config: _ConfigProtocol + + _pipeline: Optional[Pipeline] = None + _runner: Optional["PipelineRunner"] = None + + @property + def sources(self) -> dict[str, "Source"]: + """Convenience method to access :attr:`.Pipeline.sources`""" + return self.pipeline.sources + + @property + def transforms(self) -> dict[str, "Transform"]: + """Convenience method to access :attr:`.Pipeline.transforms`""" + return self.pipeline.transforms + + @property + def sinks(self) -> dict[str, "Sink"]: + """Convenience method to access :attr:`.Pipeline.sinks`""" + return self.pipeline.sinks + + @property + def pipeline(self) -> Pipeline: + """Instantiated Pipeline from pipeline config""" + if self._pipeline is None: + self._pipeline = Pipeline.from_config(self.config.pipeline) + return self._pipeline + + @property + def runner(self) -> "PipelineRunner": + """ + A :class:`.PipelineRunner` that ... runs the :attr:`~.PipelineMixin.pipeline` ! + + By default, creates a :class:`.SynchronousRunner`, + and this property should be overridden by subclasses that want to specialize + runner instantiation. 
+ """ + from mio.pipeline.runner import SynchronousRunner - node.inputs.update({id: nodes[id] for id in node.config.inputs}) - node.outputs.update({id: nodes[id] for id in node.config.outputs}) - return nodes + if self._runner is None: + self._runner = SynchronousRunner(self.pipeline) + return self._runner diff --git a/mio/models/sdcard.py b/mio/models/sdcard.py index 76638c57..2f40370e 100644 --- a/mio/models/sdcard.py +++ b/mio/models/sdcard.py @@ -4,7 +4,7 @@ for consuming code to use a consistent, introspectable API """ -from typing import Optional +from typing import Literal, Optional from mio.models import MiniscopeConfig from mio.models.buffer import BufferHeader, BufferHeaderFormat @@ -19,7 +19,7 @@ class SectorConfig(MiniscopeConfig): Examples: - >>> sectors = SectorConfig(header=1023, config=1024, data=1025, size=512) + >>> sectors = SectorConfig(header=1023, metadata=1024, data=1025, size=512) >>> sectors.header 1023 >>> # should be 1023 * 512 @@ -33,7 +33,7 @@ class SectorConfig(MiniscopeConfig): """ Holds user settings to configure Miniscope and recording """ - config: int = 1024 + metadata: int = 1024 """ Holds final settings of the actual recording """ @@ -46,22 +46,25 @@ class SectorConfig(MiniscopeConfig): The size of an individual sector """ - def __getattr__(self, item: str) -> int: - """ - Get positions by multiplying by sector size - (__getattr__ is only called if the name can't be found, so we don't need to handle - the base case of the existing attributes) - """ - split = item.split("_") - if len(split) == 2 and split[1] == "pos": - return getattr(self, split[0]) * self.size - else: - raise AttributeError() + @property + def header_pos(self) -> int: + """header * sector size""" + return self.header * self.size + @property + def config_pos(self) -> int: + """metadata * sector size""" + return self.metadata * self.size -class ConfigPositions(MiniscopeConfig): + @property + def data_pos(self) -> int: + """data * sector size""" + return self.data 
* self.size + + +class MetadataPositions(MiniscopeConfig): """ - Image acquisition configuration positions + Image acquisition metadata positions """ width: int = 0 @@ -94,7 +97,7 @@ class SDBufferHeaderFormat(BufferHeaderFormat): id: str = "sd-buffer-header" - length: int = 0 + length: Literal[0] = 0 linked_list: int = 1 frame_num: int = 2 buffer_count: int = 3 @@ -114,32 +117,30 @@ class SDLayout(MiniscopeConfig, ConfigYAMLMixin): Used by the :class:`.io.WireFreeMiniscope` class to tell it how data on the SD card is laid out. """ - sectors: SectorConfig - write_key0: int = 0x0D7CBA17 - write_key1: int = 0x0D7CBA17 - write_key2: int = 0x0D7CBA17 - write_key3: int = 0x0D7CBA17 - """ - These don't seem to actually be used in the existing reading/writing code, - but we will leave them here for continuity's sake :) - """ word_size: int = 4 """ - I'm actually not sure what this is, but 4 is hardcoded a few times in the - existing notebook and it appears to be used as a word size when - reading from the SD card. + Size of each header word in bytes """ + sectors: SectorConfig header: SDHeaderPositions = SDHeaderPositions() - config: ConfigPositions = ConfigPositions() + metadata: MetadataPositions = MetadataPositions() buffer: SDBufferHeaderFormat = SDBufferHeaderFormat() + header_dtype: str = "uint32" + """ + String form of the numpy dtype that the global and buffer headers are encoded in + """ + buffer_dtype: str = "uint8" + """ + String form of the numpy dtype that each frame buffer is encoded in + """ -class SDConfig(MiniscopeConfig): +class SDMetadata(MiniscopeConfig): """ - The configuration of a recording taken on this SD card. + The metadata of a recording taken on this SD card. 
- Read from the locations given in :class:`.ConfigPositions` + Read from the locations given in :class:`.MetadataPositions` for an SD card with a given :class:`.SDLayout` """ diff --git a/mio/pipeline/__init__.py b/mio/pipeline/__init__.py new file mode 100644 index 00000000..d2612c56 --- /dev/null +++ b/mio/pipeline/__init__.py @@ -0,0 +1,3 @@ +""" +Runtime classes for pipelines +""" diff --git a/mio/pipeline/runner.py b/mio/pipeline/runner.py new file mode 100644 index 00000000..e66f216c --- /dev/null +++ b/mio/pipeline/runner.py @@ -0,0 +1,263 @@ +""" +Pipeline runners for running pipelines +""" + +from abc import ABC, abstractmethod +from collections.abc import Generator, MutableSequence +from contextlib import contextmanager +from dataclasses import dataclass, field +from datetime import datetime +from itertools import count +from logging import Logger +from typing import TYPE_CHECKING, Any, Optional, Self +from threading import Event + +from mio import init_logger +from mio.exceptions import PipelineRunningError +from mio.models import Pipeline +from mio.models.pipeline import Edge, Event, Node, Source + +if TYPE_CHECKING: + from mio.sinks import Return + + +@dataclass +class EventStore: + """ + Container class for storing and retrieving events by node and slot + """ + + events: MutableSequence = field(default_factory=list) + counter: count = field(default_factory=count) + + def add(self, values: dict[str, Any], node_id: str) -> None: + """ + Add the result of a :meth:`.Node.process` call to the event store. 
+ + Split the dictionary of values into separate :class:`.Event` s, + store along with current timestamp + + Args: + values (dict): Dict emitted by a :meth:`.Node.process` call + node_id (str): ID of the node that emitted the events + """ + if values is None: + return + timestamp = datetime.now() + for slot, value in values.items(): + self.events.append( + Event( + id=next(self.counter), + timestamp=timestamp, + node_id=node_id, + slot=slot, + value=value, + ) + ) + + def get(self, node_id: str, slot: str) -> Optional[Event]: + """ + Get the event with the matching node_id and slot name + + Returns the most recent matching event, as for now we assume that + each combination of `node_id` and `slot` is emitted only once per processing cycle, + and we assume processing cycles are independent (and thus our events are cleared) + + ``None`` in the case that the event has not been emitted + """ + event = [e for e in self.events if e["node_id"] == node_id and e["slot"] == slot] + return None if len(event) == 0 else event[-1] + + def gather(self, edges: list[Edge]) -> Optional[dict]: + """ + Gather events into a form that can be consumed by a :meth:`.Node.process` method, + given the collection of inbound edges (usually from :meth:`.Pipeline.in_edges` ). + + If none of the requested events have been emitted, return ``None``. + + If all of the requested events have been emitted, return a kwarg-like dict + + If some of the requested events are missing but others are present, + return ``None`` for any missing events. + + .. todo:: + + Add an example + + """ + ret = {} + for edge in edges: + event = self.get(edge.source_node.id, edge.source_slot) + value = None if event is None else event["value"] + ret[edge.target_slot] = value + + return None if not ret or all(val is None for val in ret.values()) else ret + + def clear(self) -> None: + """ + Clear events for this round of processing. 
+ + Does not reset the counter (to continue giving unique ids to the next round's events) + """ + self.events = [] + + +@dataclass +class PipelineRunner(ABC): + """ + Abstract parent class for pipeline runners. + + Pipeline runners handle calling the nodes and passing the + events returned by them to each other. Each runner may do so + however it needs to (synchronously, asynchronously, alone or as part of a cluster, etc.) + as long as it satisfies this abstract interface. + """ + + pipeline: Pipeline + store: EventStore = field(default_factory=EventStore) + + _logger: Logger = field(default_factory=lambda: init_logger("pipeline.runner")) + + @abstractmethod + def process(self) -> Optional[dict[str, Any]]: + """ + Process one step of data from each of the sources, + passing intermediate data to any subscribed nodes in a chain. + + The `process` method normally does not return anything, + except when using the special :class:`.ReturnSink` node - + if there are :class:`.ReturnSink` nodes in a :class:`.Pipeline` graph, + then each call to `process` will return a dictionary with one key + (from the :class:`.ReturnSink`'s `key` config value) and one value for each + :class:`.ReturnSink`. + """ + + @abstractmethod + def start(self) -> None: + """ + Start processing data with the pipeline graph. + + Implementations of this method must raise a :class:`.PipelineRunningError` + if the pipeline has already been started and is running, + (i.e. 
:meth:`.stop` has not been called, + or the pipeline has not exhausted itself) + """ + + @abstractmethod + def stop(self) -> None: + """ + Stop processing data with the pipeline graph + """ + + @property + @abstractmethod + def running(self) -> bool: + """ + Whether the pipeline is currently running + """ + pass + + def gather_input(self, node: Node) -> Optional[dict[str, Any]]: + """ + Gather input to give to the passed Node from the :attr:`.PipelineRunner.store` + + Returns: + dict: kwargs to pass to :meth:`.Node.process` if matching events are present + dict: empty dict if Node is a :class:`.Source` + None: if no input is available + """ + if isinstance(node, Source): + return {} + + edges = self.pipeline.in_edges(node) + return self.store.gather(edges) + + def gather_return(self) -> Optional[dict]: + """ + If any :class:`.Return` nodes are in the pipeline, + gather their return values to return from :meth:`.PipelineRunner.process` + + Returns: + dict: of the Return sink's key mapped to the returned value, + None: if there are no :class:`.Return` sinks in the pipeline + """ + ret = {} + for sink in self.pipeline.sinks.values(): + if sink.name != "return": + continue + sink: Return + val = sink.get(keep=False) + ret.update(val) + + if not ret: + return None + else: + return ret + + +class SynchronousRunner(PipelineRunner): + """ + Simple, synchronous pipeline runner. + + Just run the nodes in topological order and return from return nodes. + """ + + def __init__(self): + self._running = Event() + + def __enter__(self) -> Self: + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 + self.stop() + + def start(self) -> Self: + """ + Start processing data with the pipeline graph. 
+ """ + # TODO: lock for re-entry + if self._running.is_set(): + raise PipelineRunningError("Pipeline is already running!") + + self._running.set() + for node in self.pipeline.nodes.values(): + node.start() + return self + + def stop(self) -> None: + """Stop all nodes processing""" + # TODO: lock to ensure we've been started + for node in self.pipeline.nodes.values(): + node.stop() + self._running.clear() + + @property + def running(self) -> bool: + """Whether the pipeline is currently running""" + return self._running.is_set() + + def process(self) -> Optional[dict[str, Any]]: + """ + Iterate through nodes in topological order, + calling their process method and passing events as they are emitted. + """ + self.store.clear() + + graph = self.pipeline.graph() + graph.prepare() + + while graph.is_active(): + for node_id in graph.get_ready(): + node = self.pipeline.nodes[node_id] + node_input = self.gather_input(node) + if node_input is None: + graph.done(node_id) + self._logger.debug(f"Node {node_id} received no input, skipping") + continue + value = node.process(**node_input) + self.store.add(value, node_id) + graph.done(node_id) + self._logger.debug(f"Node {node_id} emitted %s", value) + + return self.gather_return() diff --git a/mio/sinks/__init__.py b/mio/sinks/__init__.py new file mode 100644 index 00000000..f8d4cd7c --- /dev/null +++ b/mio/sinks/__init__.py @@ -0,0 +1,7 @@ +""" +Sink pipeline nodes that receive but do not emit events +""" + +from mio.sinks.return_ import Return + +__all__ = ["Return"] diff --git a/mio/sinks/return_.py b/mio/sinks/return_.py new file mode 100644 index 00000000..c97a603e --- /dev/null +++ b/mio/sinks/return_.py @@ -0,0 +1,56 @@ +""" +Special Return sink that pipeline runners use to return values from :meth:`.PipelineRunner.process` +""" + +import sys +from typing import Any, Optional + +from mio.models.pipeline import Sink, T + +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import 
TypedDict + + +class ReturnConfig(TypedDict): + """ + Config for return nodes + """ + + key: str + """The key to use in the returned dictionary""" + + +class Return(Sink): + """ + Special sink node that returns values from a pipeline runner's `process` method + """ + + name = "return" + input_type = Any + + config: ReturnConfig + + _value: Any = None + + def process(self, value: T) -> None: + """ + Store the incoming value to retrieve later with :meth:`.get` + """ + self._value = value + + def get(self, keep: bool = False) -> Optional[dict[str, T]]: + """ + Get the stored value from the process call + + Args: + keep (bool): If ``True``, keep the stored value, otherwise clear it, consume it + """ + if self._value is None: + return None + else: + val = {self.config["key"]: self._value} + if not keep: + self._value = None + return val diff --git a/mio/sources/__init__.py b/mio/sources/__init__.py new file mode 100644 index 00000000..0e84b150 --- /dev/null +++ b/mio/sources/__init__.py @@ -0,0 +1,8 @@ +""" +Source pipeline nodes that emit but do not receive events +""" + +from mio.sources.file import FileSource, BinaryFileSource, SDFileSource +from mio.sources.opalkelly import okDev + +__all__ = ["BinaryFileSource", "FileSource", "SDFileSource", "okDev"] diff --git a/mio/sources/file.py b/mio/sources/file.py index 0011e9ee..854089e2 100644 --- a/mio/sources/file.py +++ b/mio/sources/file.py @@ -2,7 +2,23 @@ File-based data sources """ +import sys +import os +from pathlib import Path +from typing import BinaryIO, ClassVar, Optional +from threading import Lock + +import numpy as np +from pydantic import Field, PrivateAttr + +from mio.exceptions import EndOfRecordingException, ReadHeaderException from mio.models.pipeline import Source +from mio.models.sdcard import SDBufferHeader, SDLayout + +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict class FileSource(Source): @@ -10,18 +26,74 @@ class FileSource(Source): 
Generic parent class for file sources """ + name = "file-source" -class BinaryLayout: - """Layout for binary files""" - pass +class BinaryFileSource(FileSource): + """ + A FileSource that yields blocks of binary data + """ + name = "binary-file-source" + output_type: ClassVar[bytes] -class BinaryFileSource(FileSource): + path: Path + offset: int = 0 + """ + The offset position from the start of the file from which to consider the "zero point" + """ + block_size: int + """ + Number of bytes to read per processing loop + """ + + _f: Optional[BinaryIO] = None + + def start(self) -> None: + """Open the file, seek to the offset""" + self._f = open(self.path, "rb") # noqa: SIM115 + self._f.seek(self.offset, 0) + + def stop(self) -> None: + """Close the file, remove the reference""" + self._f.close() + self._f = None + + def tell(self) -> int: + """Return the current position in the file""" + if self._f is None: + raise RuntimeError("File has not yet been opened with start") + return self._f.tell() + + @property + def position(self) -> int: + """Property alias for :meth:`.tell` - the current position in the file""" + return self.tell() + + def process(self) -> bytes: + """Return a block of data""" + return self._f.read(self.block_size) + + +class SDFileSourceOutput(TypedDict): + """Output types returned by :meth:`.SDFileSource.process`""" + + header: SDBufferHeader + buffer: np.ndarray + + +class SDFileSourceConfig(TypedDict): + """Config for :class:`.SDFileSource`""" + + path: Path + layout: SDLayout + + +class SDFileSource(FileSource): """ Structured binary file that has - * a global header with config values + * a global header with metadata values * a series of buffers, each containing a * buffer header - with metadata for that buffer and @@ -29,6 +101,191 @@ class BinaryFileSource(FileSource): The source thus has two configurations - * the ``config`` - getter and setter for the actual configuration values of the source + * the ``metadata`` - getter and setter for the 
actual configuration values of the source * the ``layout`` - how the configuration and data are laid out within the file. + + """ + + name = "sd-file-source" + output_type = SDFileSourceOutput + + config: SDFileSourceConfig + + _f: Optional[BinaryIO] = None + _positions: dict[int, int] = PrivateAttr(default_factory=dict) """ + A mapping between frame number and byte position in the video that makes for + faster seeking :) + + As we read, we store the locations of each frame before reading it. Later, + we can assign to `frame` to seek back to those positions. Assigning to `frame` + works without caching position, but has to manually iterate through each frame. + """ + _last_buffer: int = None + _frame: int = 0 + _lock: Lock = PrivateAttr(default_factory=Lock) + + @property + def offset(self) -> int: + """Start point of the data sector""" + return self.config["layout"].sectors.data_pos + + @property + def header_size(self) -> int: + """ + Number of bytes in a buffer header + + .. note:: + + This isn't guaranteed to be accurate, see: + https://github.com/Aharoni-Lab/Miniscope-v4-Wire-Free/issues/64 + """ + return ( + max([v for v in self.config["layout"].buffer.model_dump().values() if v is not None]) + + 1 + ) * self.config["layout"].word_size + + def start(self) -> None: + """Open the file, seek to the offset""" + self._last_buffer = 0 + + self._f = open(self.config["path"], "rb") # noqa: SIM115 + self._f.seek(self.offset, os.SEEK_SET) + + def stop(self) -> None: + """Close the file, remove the reference""" + self._f.close() + self._f = None + + def tell(self) -> int: + """Return the current position in the file""" + if self._f is None: + raise RuntimeError("File has not yet been opened with start") + with self._lock: + return self._f.tell() + + @property + def size(self) -> int: + """ + Size of the file (in bytes) + """ + current_pos = self.tell() + with self._lock: + self._f.seek(self.offset, os.SEEK_END) + end = self._f.tell() + self._f.seek(current_pos, 
os.SEEK_SET) + return end + + @property + def position(self) -> int: + """Property alias for :meth:`.tell` - the current position in the file""" + return self.tell() + + def process(self) -> SDFileSourceOutput: + """ + Read a single data buffer, parsing its header and splitting it from the data + """ + start_position = self.tell() + + header = self._read_header(self._f) + self._last_buffer = header.buffer_count + self._frame = header.frame_num + + if header.frame_num not in self._positions: + self._positions[header.frame_num] = start_position + + buffer = self._read_buffer(self._f, header) + buffer = self._trim(buffer, header.data_length) + return {"header": header, "buffer": buffer} + + def _read_header(self, sd: BinaryIO) -> SDBufferHeader: + """ + Given an already open file buffer opened in bytes mode, + seeked to the start of a frame, read the data header + """ + # Get the length of the header from the first word + try: + with self._lock: + dataHeader = np.frombuffer( + sd.read(self.config["layout"].word_size), + dtype=np.dtype(self.config["layout"].header_dtype), + ) + except IndexError as e: + if "index 0 is out of bounds for axis 0 with size 0" in str(e): + # end of file if we are reading from a disk image without any + # additional space on disk + raise EndOfRecordingException("Reached the end of the video!") from None + else: + raise e + + # Get the rest of the values in the header + try: + with self._lock: + dataHeader = np.append( + dataHeader, + np.frombuffer( + sd.read(int(dataHeader[0]) * self.config["layout"].word_size), + dtype=np.dtype(self.config["layout"].header_dtype), + ), + ) + except ValueError as e: + if "read length must be non-negative" in str(e): + # end of file! Value error thrown because the dataHeader will be + # blank, and thus have a value of 0 for the header size, and we + # can't read 0 from the card. 
+ raise EndOfRecordingException("Reached the end of the video!") from None + else: + raise e + + # use construct because we're already sure these are ints from the numpy casting + # https://docs.pydantic.dev/latest/usage/models/#creating-models-without-validation + try: + return SDBufferHeader.from_format( + dataHeader, self.config["layout"].buffer, construct=True + ) + except IndexError as e: + if self.tell() >= self.size: + raise EndOfRecordingException("Reached the end of the video!") from None + else: + raise ReadHeaderException( + "Could not read header, expected header to have " + f"{len(self.config['layout'].buffer.model_dump().keys())} fields, " + f"got {len(dataHeader)}. Likely mismatch between specified " + "and actual SD Card layout or reached end of data.\n" + f"Header Data: {dataHeader}" + ) from e + + def _read_buffer(self, sd: BinaryIO, header: SDBufferHeader) -> np.ndarray: + with self._lock: + return np.frombuffer(sd.read(self._data_read_size(header)), dtype=np.uint8) + + def _data_read_size(self, header: SDBufferHeader) -> int: + """ + After the header, how many bytes to read for the data in a buffer + """ + # blocks are quantized by sector size, so get min number of blocks that cover the data + n_blocks = np.ceil( + (header.data_length + (header.length * self.config["layout"].word_size)) + / self.config["layout"].sectors.size + ) + # expand back to n bytes + sector_size = n_blocks * self.config["layout"].sectors.size + # subtract length of header + return int(sector_size - (header.length * self.config["layout"].word_size)) + + def _trim(self, data: np.ndarray, expected_size: int) -> np.ndarray: + """ + Trim or pad an array to match an expected size. 
class MergeBuffersConfig(TypedDict):
    """Configuration for :class:`.MergeBuffers`"""

    # length of axis 1 (columns) of the merged frame
    width: int
    # length of axis 0 (rows) of the merged frame
    height: int


class MergeBuffersOutput(TypedDict):
    """Output returned by :meth:`.MergeBuffers.process`"""

    # the completed frame together with the headers of the buffers it was built from
    frame: Frame


class MergeBuffers(Transform):
    """
    Merge sequential frame buffers into a single frame.

    Header/buffer pairs are accumulated until a header whose
    ``frame_buffer_count`` is ``0`` arrives, which marks the start of the next
    frame. At that point everything accumulated so far is concatenated and
    reshaped into one frame, and accumulation restarts with the new pair.

    Note that a frame is only emitted once the first buffer of the *following*
    frame has been received, so the last frame of a stream stays in the
    internal containers unless another ``frame_buffer_count == 0`` pair arrives.
    """

    name = "merge-buffers"
    input_type = tuple[BufferHeader, np.ndarray]
    output_type = MergeBuffersOutput
    config: MergeBuffersConfig

    # headers of the buffers accumulated for the frame currently being built
    _headers: Optional[list[BufferHeader]] = None
    # pixel buffers accumulated for the frame currently being built
    _buffers: Optional[list[np.ndarray]] = None
    # ``frame_buffer_count`` of the most recently received header
    _last_buffer_n: Optional[int] = None

    def start(self) -> None:
        """Init private containers"""
        self._headers = []
        self._buffers = []
        self._last_buffer_n = 0

    def process(self, header: BufferHeader, buffer: np.ndarray) -> Optional[MergeBuffersOutput]:
        """
        Receive a header/buffer pair. If the frame buffer count has cycled back to zero,
        merge the previously accumulated buffers into a completed frame.

        Args:
            header: Header describing ``buffer``; its ``frame_buffer_count``
                is used to detect the start of a new frame.
            buffer: Pixel data belonging to ``header``.

        Returns:
            ``{"frame": Frame}`` when a frame was completed by this call,
            otherwise ``None``.

        Raises:
            ValueError: If the accumulated buffers do not contain exactly
                ``height * width`` elements (raised by the reshape).
        """
        # tolerate being called before ``start`` rather than failing on ``None``
        if self._headers is None or self._buffers is None:
            self.start()

        self._last_buffer_n = header.frame_buffer_count

        # Only merge when a new frame begins AND there is something to merge:
        # the very first buffer of a stream also has count 0, and concatenating
        # an empty list would raise.
        if header.frame_buffer_count != 0 or not self._buffers:
            self._headers.append(header)
            self._buffers.append(buffer)
            return None

        # ``np.concatenate`` rather than ``np.concat``: the latter only exists in numpy >= 2.0.
        # Frames are indexed (row, column), i.e. (height, width).
        frame = np.concatenate(self._buffers).reshape(
            (self.config["height"], self.config["width"])
        )
        headers = self._headers

        # the pair that triggered the merge is the first buffer of the next frame
        self._headers = [header]
        self._buffers = [buffer]
        return {"frame": Frame.model_construct(frame=frame, headers=headers)}
00000000..34d2cf46 --- /dev/null +++ b/tests/data/config/reference_b.yaml @@ -0,0 +1,2 @@ +id: reference-b +value: [1,2,3] \ No newline at end of file diff --git a/tests/fixtures.py b/tests/fixtures.py index 8f162e13..f5e9cc63 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -14,20 +14,20 @@ @pytest.fixture -def wirefree() ->WireFreeMiniscope: +def wirefree() -> WireFreeMiniscope: """ SDCard with wirefree layout pointing to the sample data file """ sd_path = Path(__file__).parent.parent / "data" / "wirefree_example.img" - sdcard = WireFreeMiniscope(drive=sd_path, layout="wirefree-sd-layout") + sdcard = WireFreeMiniscope(drive=sd_path, config="wirefree-default") return sdcard @pytest.fixture def wirefree_battery() -> WireFreeMiniscope: sd_path = Path(__file__).parent.parent / "data" / "wirefree_battery_sample.img" - sdcard = WireFreeMiniscope(drive=sd_path, layout="wirefree-sd-layout-battery") + sdcard = WireFreeMiniscope(drive=sd_path, config="wirefree-sd-layout-battery") return sdcard diff --git a/tests/test_devices/test_wirefree.py b/tests/test_devices/test_wirefree.py index 73074cd8..d1415d35 100644 --- a/tests/test_devices/test_wirefree.py +++ b/tests/test_devices/test_wirefree.py @@ -39,8 +39,8 @@ def test_read(wirefree): # the frame is the right shape assert len(frame.shape) == 2 - assert frame.shape[0] == wirefree.config.height - assert frame.shape[1] == wirefree.config.width + assert frame.shape[0] == wirefree.metadata.height + assert frame.shape[1] == wirefree.metadata.width # assert they're not all zeros - ie. 
def test_yamlmixin_core_schema():
    """
    A config id can be given wherever a :class:`.ConfigYAMLMixin` model is expected
    (via the mixin's ``__get_pydantic_core_schema__``), including for nested fields:
    ``reference-a`` is loaded into ``A``, and the ``reference-b`` id stored in its
    ``config`` key is in turn loaded into ``B``.
    """

    class B(ConfigYAMLMixin):
        value: list[int]

    class A(ConfigYAMLMixin):
        value: list[str]
        config: B

    class Container(BaseModel):
        model: A

    # only the id is passed in - both levels should be resolved from it
    outer = Container(model="reference-a").model
    inner = outer.config

    assert isinstance(inner, B)
    assert outer.value == ["hey", "sup"]
    assert inner.value == [1, 2, 3]
header=np.random.randint(0, 2048), - config=np.random.randint(0, 2048), + metadata=np.random.randint(0, 2048), data=np.random.randint(0, 2048), size=np.random.randint(0, 2048), ) @@ -20,7 +20,7 @@ def test_get_sector_position(random_sectorconfig): """ sectors = random_sectorconfig assert sectors.header_pos == sectors.header * sectors.size - assert sectors.config_pos == sectors.config * sectors.size + assert sectors.config_pos == sectors.metadata * sectors.size assert sectors.data_pos == sectors.data * sectors.size # We should raise an attribute error if we try and get a nonexistent one