diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index e100f748..f8f8661d 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,7 +1,7 @@ from .background_removal import remove_background from .denoise import denoise from .glow_removal import GlowRemover +from .motion import Stabilizer from .r_estimate import SizeEst -from .rigid_stabilization import RigidStabilizer -__all__ = [denoise, GlowRemover, remove_background, RigidStabilizer, SizeEst] +__all__ = [denoise, GlowRemover, remove_background, Stabilizer, SizeEst] diff --git a/src/cala/nodes/prep/rigid_stabilization.py b/src/cala/nodes/prep/motion.py similarity index 67% rename from src/cala/nodes/prep/rigid_stabilization.py rename to src/cala/nodes/prep/motion.py index 0718e608..8903143f 100644 --- a/src/cala/nodes/prep/rigid_stabilization.py +++ b/src/cala/nodes/prep/motion.py @@ -1,14 +1,18 @@ +from collections.abc import Callable +from logging import Logger from typing import Annotated as A +from typing import Literal import cv2 import numpy as np import xarray as xr -from noob import Name -from noob.node import Node -from pydantic import BaseModel, Field +from noob import Name, process_method +from pydantic import BaseModel, ConfigDict, Field +from skimage.filters import butterworth, difference_of_gaussians, sato, scharr from skimage.registration import phase_cross_correlation from cala.assets import Frame +from cala.logging import init_logger from cala.models import AXIS @@ -17,15 +21,44 @@ class Shift(BaseModel): height: float -class RigidStabilizer(Node): +class Stabilizer(BaseModel): drift_speed: float = 1.0 - kwargs: dict = Field(default_factory=dict) + pcc_kwargs: dict = Field(default_factory=dict) + + pcc_filter: Literal["butterworth", "difference_of_gaussians", "sato", "scharr"] + filter_kwargs: dict = Field(default_factory=dict) + filt_fn: Callable = None _anchor_last_applied_on: int = None - anchor_frame_: Frame = None - previous_frame_: Frame = None + anchor_frame_: xr.DataArray = None + previous_frame_: xr.DataArray = None motions_: list[Shift] = None + logger: Logger = init_logger(__name__) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @process_method + def stabilize(self, frame: Frame) -> A[Frame, Name("frame")]: + if self.is_first_frame(frame): + return frame + + curr_frame = frame.array + + shift = self._compute_shift(curr_frame) + shifted_frame = self._apply_shift(curr_frame, shift) + + self.previous_frame_ = shifted_frame + + if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item(): + self.anchor_frame_ = self._update_anchor(shifted_frame) + + self.motions_.append(shift) + + return Frame.from_array( + xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords) + ) + def is_first_frame(self, frame: Frame) -> bool: if ( (self.anchor_frame_ is not None) @@ -40,8 +73,8 @@ def is_first_frame(self, frame: Frame) -> bool: and (self.motions_ is None) ): self._anchor_last_applied_on = 0 - self.anchor_frame_ = frame - self.previous_frame_ = frame + self.anchor_frame_ = frame.array + self.previous_frame_ = frame.array self.motions_ = [Shift(width=0, height=0)] return True @@ -53,7 +86,7 @@ def is_first_frame(self, frame: Frame) -> bool: f"{self.motions_ = }" ) - def compute_shift(self, curr_frame: xr.DataArray) -> Shift: + def _compute_shift(self, curr_frame: xr.DataArray) -> Shift: """ The simplest way to stabilize streaming frames would be to have a single reference frame (the first frame) and shift all subsequent frames against this reference frame. @@ -91,14 +124,20 @@ def compute_shift(self, curr_frame: xr.DataArray) -> Shift: if: abs(sequential_shift - anchor_shift) < drift_speed then: true_shift = anchor_shift """ + filters = { + "butterworth": butterworth, + "difference_of_gaussians": difference_of_gaussians, + "sato": sato, + "scharr": scharr, + } + filt_fn = filters[self.pcc_filter] - anchor_shift, a_error, _ = phase_cross_correlation( - self.anchor_frame_.array, curr_frame.values, **self.kwargs - ) + curr = filt_fn(curr_frame, **self.filter_kwargs) + prev = filt_fn(self.previous_frame_, **self.filter_kwargs) + anchor = filt_fn(self.anchor_frame_, **self.filter_kwargs) - sequent_shift, s_error, _ = phase_cross_correlation( - self.previous_frame_.array, curr_frame.values, **self.kwargs - ) + anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs) + sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs) shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) @@ -113,7 +152,7 @@ def compute_shift(self, curr_frame: xr.DataArray) -> Shift: return Shift(height=shift[0], width=shift[1]) - def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: + def _apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: # Define the affine transformation matrix for translation M = np.float32([[1, 0, shift.width], [0, 1, shift.height]]) @@ -122,48 +161,12 @@ def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: M, (frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]), flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, - borderValue=np.nan, + borderMode=cv2.BORDER_REPLICATE, + # borderValue=0, ) - shifted_frame = np.nan_to_num(shifted_frame, copy=True, nan=0) return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords) - def update_anchor(self, frame: xr.DataArray) -> xr.DataArray: + def _update_anchor(self, frame: xr.DataArray) -> xr.DataArray: curr_index = frame[AXIS.frame_coord].item() - return (self.anchor_frame_.array * curr_index + frame) / (curr_index + 1) - - def process(self, frame: Frame) -> A[Frame, Name("frame")]: - if self.is_first_frame(frame): - return frame - - curr_frame = frame.array - - shift = self.compute_shift(curr_frame) - shifted_frame = self.apply_shift(curr_frame, shift) - - self.previous_frame_ = Frame.from_array(shifted_frame) - - if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item(): - self.anchor_frame_.array = self.update_anchor(shifted_frame) - - self.motions_.append(shift) - - return Frame.from_array( - xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords) - ) - - def get_info(self) -> dict: - """Get information about the current state. - - Returns - ------- - dict - Dictionary containing current statistics - """ - return { - "_anchor_last_applied_on": self._anchor_last_applied_on, - "anchor_frame_": self.anchor_frame_, - "previous_frame_": self.previous_frame_, - "motion_": self.motions_, - } + return (self.anchor_frame_ * curr_index + frame) / (curr_index + 1) diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 3cdb055a..f0082807 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -64,11 +64,20 @@ nodes: lines: type: cala.nodes.prep.hlines.remove depends: - - frame: glow.frame + - frame: denoise.frame + motion: # needs to take place before glow, after lines + type: cala.nodes.prep.Stabilizer + params: + drift_speed: 0.5 + pcc_filter: difference_of_gaussians + filter_kwargs: + low_sigma: 1 + depends: + - frame: lines.frame glow: type: cala.nodes.prep.GlowRemover depends: - - frame: denoise.frame + - frame: motion.frame smooth: type: cala.nodes.prep.denoise params: @@ -77,16 +86,11 @@ nodes: ksize: [ 7, 7 ] sigmaX: 1.5 depends: - - frame: lines.frame - motion: - type: cala.nodes.prep.RigidStabilizer - params: - drift_speed: 0.5 - depends: - - frame: smooth.frame + - frame: glow.frame size_est: type: cala.nodes.prep.SizeEst params: + hardset_radius: 6 noise_threshold: 2.0 n_frames: 30 log_kwargs: @@ -96,14 +100,14 @@ nodes: threshold: 0.2 overlap: 0.5 depends: - - frame: motion.frame + - frame: smooth.frame cache: type: cala.nodes.buffer.fill_buffer params: size: 100 depends: - buffer: assets.buffer - - frame: motion.frame + - frame: smooth.frame #PREPROCESS ENDS # FRAME UPDATE BEGINS @@ -115,19 +119,19 @@ nodes: depends: - traces: assets.traces - footprints: assets.footprints - - frame: motion.frame + - frame: smooth.frame - overlaps: assets.overlaps pix_frame: type: cala.nodes.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - - frame: motion.frame + - frame: smooth.frame - new_traces: trace_frame.latest_trace comp_frame: type: cala.nodes.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - - frame: motion.frame + - frame: smooth.frame - new_traces: trace_frame.latest_trace footprints_frame: type: cala.nodes.footprints.Footprinter @@ -171,7 +175,7 @@ nodes: nmf: type: cala.nodes.detect.SliceNMF params: - min_frames: 30 + min_frames: 100 detect_thresh: 2.0 reprod_tol: 0.005 depends: @@ -222,4 +226,4 @@ nodes: type: return depends: - raw: frame.value - - prep: motion.frame \ No newline at end of file + - prep: smooth.frame \ No newline at end of file diff --git a/tests/test_prep/test_rigid_stabilization.py b/tests/test_prep/test_motion.py similarity index 78% rename from tests/test_prep/test_rigid_stabilization.py rename to tests/test_prep/test_motion.py index 27214169..f36c2930 100644 --- a/tests/test_prep/test_rigid_stabilization.py +++ b/tests/test_prep/test_motion.py @@ -3,16 +3,24 @@ from noob.node import NodeSpecification from cala.models import AXIS -from cala.nodes.prep.rigid_stabilization import RigidStabilizer, Shift +from cala.nodes.prep.motion import Shift, Stabilizer from cala.testing.toy import FrameDims, Position, Toy -@pytest.mark.parametrize("params", [{"drift_speed": 1, "kwargs": {"upsample_factor": 100}}]) +@pytest.mark.parametrize( + "params", + [ + { + "drift_speed": 1, + "pcc_kwargs": {"upsample_factor": 100}, + "pcc_filter": "difference_of_gaussians", + "filter_kwargs": {"low_sigma": 1}, + } + ], +) def test_motion_estimation(params) -> None: - stab = RigidStabilizer.from_specification( - NodeSpecification(id="test", type="cala.nodes.prep.RigidStabilizer", params=params) - ) + stab = Stabilizer(**params) n_frames = 50 @@ -42,7 +50,7 @@ def test_motion_estimation(params) -> None: width=shifts[0].width - shift.width, height=shifts[0].height - shift.height ) } - result.append(stab.process(frame)) + result.append(stab.stabilize(frame)) estimate = -np.array([(m.width, m.height) for m in stab.motions_]) expected = np.array([(m.width, m.height) for m in shifts]) - (shifts[0].width, shifts[0].height)