From 642983e26bf48f47af55c3bda36a7d072ea32206 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 27 Aug 2025 18:00:32 -0700 Subject: [PATCH 1/6] feat: rename stab to motion --- src/cala/nodes/prep/__init__.py | 2 +- src/cala/nodes/prep/motion.py | 235 ++++++++++++++++++ src/cala/nodes/prep/rigid_stabilization.py | 169 ------------- ..._rigid_stabilization.py => test_motion.py} | 2 +- 4 files changed, 237 insertions(+), 171 deletions(-) create mode 100644 src/cala/nodes/prep/motion.py delete mode 100644 src/cala/nodes/prep/rigid_stabilization.py rename tests/test_prep/{test_rigid_stabilization.py => test_motion.py} (95%) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index e100f748..5a186263 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -2,6 +2,6 @@ from .denoise import denoise from .glow_removal import GlowRemover from .r_estimate import SizeEst -from .rigid_stabilization import RigidStabilizer +from .motion import RigidStabilizer __all__ = [denoise, GlowRemover, remove_background, RigidStabilizer, SizeEst] diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py new file mode 100644 index 00000000..cb98c891 --- /dev/null +++ b/src/cala/nodes/prep/motion.py @@ -0,0 +1,235 @@ +from logging import Logger +from typing import Annotated as A + +import cv2 +import numpy as np +from noob import Name, process_method +from pydantic import BaseModel, Field, ConfigDict + +from cala.assets import Frame +from cala.logging import init_logger +from cala.models import AXIS +from cala.util import package_frame + + +class Shift(BaseModel): + x: float # width + y: float # height + a: float # angle + + +class RigidStabilizer(BaseModel): + drift_speed: float = 1.0 + kwargs: dict = Field(default_factory=dict) + + _anchor_last_applied_on: int = None + anchor_frame_: np.ndarray = None + previous_frame_: np.ndarray = None + previous_keypoints_: np.ndarray = None + motions_: list[Shift] = Field(default_factory=list) + + clahe: cv2.CLAHE = Field(cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)), exclude=True) + logger: Logger = init_logger(__name__) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @process_method + def stabilize(self, frame: Frame) -> A[Frame, Name("frame")]: + arr = frame.array.values.astype(np.uint8) + farr_opt = self.clahe.apply(arr) + + if self.previous_keypoints_ is None: + self.previous_keypoints_ = self._find_keypoints(farr_opt) + + else: + shift = self._generate_shift(farr_opt) + self.motions_.append(shift) + self.previous_keypoints_ = self._find_keypoints(farr_opt) + arr = self._apply_shift(arr, shift) + + self.previous_frame_ = farr_opt + frame = package_frame(arr, index=frame.array[AXIS.frame_coord].item()) + + return frame + + @staticmethod + def _find_keypoints(frame: np.ndarray) -> np.ndarray: + """calculate and save GFTT keypoints for current frame""" + return cv2.goodFeaturesToTrack( + frame, + maxCorners=200, + qualityLevel=0.05, + minDistance=30.0, + blockSize=3, + mask=None, + useHarrisDetector=False, + k=0.04, + ) + + def _generate_shift(self, frame: np.ndarray) -> Shift: + # calculate optical flow using Lucas-Kanade differential method + curr_kps, status, error = cv2.calcOpticalFlowPyrLK( + self.previous_frame_, frame, self.previous_keypoints_, None + ) + + # select only valid keypoints + valid_curr_kps = curr_kps[status == 1] # current + valid_prev_kps = self.previous_keypoints_[status == 1] # previous + + # calculate optimal affine transformation between previous_2_current keypoints + shift = cv2.estimateAffinePartial2D(valid_prev_kps, valid_curr_kps)[0] + + if shift is not None: + # translation in x direction + dx = shift[0, 2] + # translation in y direction + dy = shift[1, 2] + # rotation + da = np.arctan2(shift[1, 0], shift[0, 0]) + else: + dx = dy = da = 0 + + return Shift(x=dx, y=dy, a=da) + + def _apply_shift(self, frame: np.ndarray, shift: Shift) -> np.ndarray: + """ + An internal method that applies affine transformation to the given frame + from previously calculated transformations + """ + # building 2x3 transformation matrix from extracted transformations + shift_op = np.zeros((2, 3), np.float32) + shift_op[0, 0] = np.cos(shift.a) + shift_op[0, 1] = -np.sin(shift.a) + shift_op[1, 0] = np.sin(shift.a) + shift_op[1, 1] = np.cos(shift.a) + shift_op[0, 2] = shift.x + shift_op[1, 2] = shift.y + + # Applying an affine transformation to the frame + return cv2.warpAffine(frame, shift_op, frame.shape[::-1]) + + # def update_anchor(self, frame: xr.DataArray) -> xr.DataArray: + # curr_index = frame[AXIS.frame_coord].item() + # + # return (self.anchor_frame_.array * curr_index + frame) / (curr_index + 1) + + # if self.is_first_frame(frame): + # return frame + # + # curr_frame = frame.array + # + # shift = self.compute_shift(curr_frame) + # shifted_frame = self.apply_shift(curr_frame, shift) + # + # self.previous_frame_ = Frame.from_array(shifted_frame) + # + # if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item(): + # self.anchor_frame_.array = self.update_anchor(shifted_frame) + # + # self.motions_.append(shift) + # + # return Frame.from_array( + # xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords) + # ) + + # def is_first_frame(self, frame: Frame) -> bool: + # if ( + # (self.anchor_frame_ is not None) + # and (self.previous_frame_ is not None) + # and (self.motions_ is not None) + # ): + # return False + # + # elif ( + # (self.anchor_frame_ is None) + # and (self.previous_frame_ is None) + # and (self.motions_ is None) + # ): + # self._anchor_last_applied_on = 0 + # self.anchor_frame_ = frame + # self.previous_frame_ = frame + # self.motions_ = [Shift(width=0, height=0)] + # return True + # + # else: + # raise NotImplementedError( + # f"Undefined State: Only some of the attributes are initialized: " + # f"{self.anchor_frame_ = }, " + # f"{self.previous_frame_ = }, " + # f"{self.motions_ = }" + # ) + + # def compute_shift(self, curr_frame: xr.DataArray) -> Shift: + # """ + # The simplest way to stabilize streaming frames would be to have a single reference frame + # (the first frame) and shift all subsequent frames against this reference frame. + # However, as different sets of neurons are active at different times, frames that are far + # apart in time sometimes have few common objects to lock onto. + # + # To mitigate this, we could stabilize all frames against the previous frame, domino style. + # However, errors stack up and a gradual shift takes place with this strategy. + # + # This algorithm attempts to solve this issue by mixing the two strategies: + # 1. We default to stabilizing against the anchor frame. + # 2. If we begin losing features to lock onto, the anchor shift will explode to + # an unpredictable value. + # 3. In this case, we fall back to the sequential shift. + # 4. During the "anchor mismatch period", the sequential shift will slowly drift. + # 5. And then, as old features surface again, the anchor will lock in again. + # 6. However, the sequential shift will have drifted. + # 7. We try to estimate how fast it would drift with drift_speed. + # 8. Then, the TRUE shift is within the range of sequential_shift +- drift. + # 9. Thus, we assume that if anchor_shift falls within this range, the anchor shift + # has returned to the TRUE shift. + # + # in mathematical notations, this translates to: + # if: + # sequential_shift - drift_speed < anchor_shift < sequential_shift + drift_speed + # + # then: + # true_shift = anchor_shift + # + # the inequality is same as: + # sequential_shift - anchor_shift < drift_speed + # anchor_shift - sequential_shift < drift_speed + # + # which summarizes to: + # if: abs(sequential_shift - anchor_shift) < drift_speed + # then: true_shift = anchor_shift + # """ + # + # anchor_shift, a_error, _ = phase_cross_correlation( + # self.anchor_frame_.array, curr_frame.values, **self.kwargs + # ) + # + # sequent_shift, s_error, _ = phase_cross_correlation( + # self.previous_frame_.array, curr_frame.values, **self.kwargs + # ) + # + # shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) + # + # frame_idx = curr_frame[AXIS.frame_coord].item() + # drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed + # + # if shift_diff > drift_threshold: + # shift = sequent_shift + # else: + # shift = anchor_shift + # self._anchor_last_applied_on = frame_idx + # + # return Shift(height=shift[0], width=shift[1]) + + # def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: + # # Define the affine transformation matrix for translation + # M = np.float32([[1, 0, shift.width], [0, 1, shift.height]]) + # + # shifted_frame = cv2.warpAffine( + # frame.values, + # M, + # (frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]), + # flags=cv2.INTER_LINEAR, + # borderMode=cv2.BORDER_CONSTANT, + # borderValue=np.nan, + # ) + # shifted_frame = np.nan_to_num(shifted_frame, copy=True, nan=0) + # return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords) diff --git a/src/cala/nodes/prep/rigid_stabilization.py b/src/cala/nodes/prep/rigid_stabilization.py deleted file mode 100644 index 0718e608..00000000 --- a/src/cala/nodes/prep/rigid_stabilization.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import Annotated as A - -import cv2 -import numpy as np -import xarray as xr -from noob import Name -from noob.node import Node -from pydantic import BaseModel, Field -from skimage.registration import phase_cross_correlation - -from cala.assets import Frame -from cala.models import AXIS - - -class Shift(BaseModel): - width: float - height: float - - -class RigidStabilizer(Node): - drift_speed: float = 1.0 - kwargs: dict = Field(default_factory=dict) - - _anchor_last_applied_on: int = None - anchor_frame_: Frame = None - previous_frame_: Frame = None - motions_: list[Shift] = None - - def is_first_frame(self, frame: Frame) -> bool: - if ( - (self.anchor_frame_ is not None) - and (self.previous_frame_ is not None) - and (self.motions_ is not None) - ): - return False - - elif ( - (self.anchor_frame_ is None) - and (self.previous_frame_ is None) - and (self.motions_ is None) - ): - self._anchor_last_applied_on = 0 - self.anchor_frame_ = frame - self.previous_frame_ = frame - self.motions_ = [Shift(width=0, height=0)] - return True - - else: - raise NotImplementedError( - f"Undefined State: Only some of the attributes are initialized: " - f"{self.anchor_frame_ = }, " - f"{self.previous_frame_ = }, " - f"{self.motions_ = }" - ) - - def compute_shift(self, curr_frame: xr.DataArray) -> Shift: - """ - The simplest way to stabilize streaming frames would be to have a single reference frame - (the first frame) and shift all subsequent frames against this reference frame. - However, as different sets of neurons are active at different times, frames that are far - apart in time sometimes have few common objects to lock onto. - - To mitigate this, we could stabilize all frames against the previous frame, domino style. - However, errors stack up and a gradual shift takes place with this strategy. - - This algorithm attempts to solve this issue by mixing the two strategies: - 1. We default to stabilizing against the anchor frame. - 2. If we begin losing features to lock onto, the anchor shift will explode to - an unpredictable value. - 3. In this case, we fall back to the sequential shift. - 4. During the "anchor mismatch period", the sequential shift will slowly drift. - 5. And then, as old features surface again, the anchor will lock in again. - 6. However, the sequential shift will have drifted. - 7. We try to estimate how fast it would drift with drift_speed. - 8. Then, the TRUE shift is within the range of sequential_shift +- drift. - 9. Thus, we assume that if anchor_shift falls within this range, the anchor shift - has returned to the TRUE shift. - - in mathematical notations, this translates to: - if: - sequential_shift - drift_speed < anchor_shift < sequential_shift + drift_speed - - then: - true_shift = anchor_shift - - the inequality is same as: - sequential_shift - anchor_shift < drift_speed - anchor_shift - sequential_shift < drift_speed - - which summarizes to: - if: abs(sequential_shift - anchor_shift) < drift_speed - then: true_shift = anchor_shift - """ - - anchor_shift, a_error, _ = phase_cross_correlation( - self.anchor_frame_.array, curr_frame.values, **self.kwargs - ) - - sequent_shift, s_error, _ = phase_cross_correlation( - self.previous_frame_.array, curr_frame.values, **self.kwargs - ) - - shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) - - frame_idx = curr_frame[AXIS.frame_coord].item() - drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed - - if shift_diff > drift_threshold: - shift = sequent_shift - else: - shift = anchor_shift - self._anchor_last_applied_on = frame_idx - - return Shift(height=shift[0], width=shift[1]) - - def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: - # Define the affine transformation matrix for translation - M = np.float32([[1, 0, shift.width], [0, 1, shift.height]]) - - shifted_frame = cv2.warpAffine( - frame.values, - M, - (frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, - borderValue=np.nan, - ) - shifted_frame = np.nan_to_num(shifted_frame, copy=True, nan=0) - return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords) - - def update_anchor(self, frame: xr.DataArray) -> xr.DataArray: - curr_index = frame[AXIS.frame_coord].item() - - return (self.anchor_frame_.array * curr_index + frame) / (curr_index + 1) - - def process(self, frame: Frame) -> A[Frame, Name("frame")]: - if self.is_first_frame(frame): - return frame - - curr_frame = frame.array - - shift = self.compute_shift(curr_frame) - shifted_frame = self.apply_shift(curr_frame, shift) - - self.previous_frame_ = Frame.from_array(shifted_frame) - - if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item(): - self.anchor_frame_.array = self.update_anchor(shifted_frame) - - self.motions_.append(shift) - - return Frame.from_array( - xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords) - ) - - def get_info(self) -> dict: - """Get information about the current state. - - Returns - ------- - dict - Dictionary containing current statistics - """ - return { - "_anchor_last_applied_on": self._anchor_last_applied_on, - "anchor_frame_": self.anchor_frame_, - "previous_frame_": self.previous_frame_, - "motion_": self.motions_, - } diff --git a/tests/test_prep/test_rigid_stabilization.py b/tests/test_prep/test_motion.py similarity index 95% rename from tests/test_prep/test_rigid_stabilization.py rename to tests/test_prep/test_motion.py index 27214169..fb6d7247 100644 --- a/tests/test_prep/test_rigid_stabilization.py +++ b/tests/test_prep/test_motion.py @@ -3,7 +3,7 @@ from noob.node import NodeSpecification from cala.models import AXIS -from cala.nodes.prep.rigid_stabilization import RigidStabilizer, Shift +from cala.nodes.prep.motion import RigidStabilizer, Shift from cala.testing.toy import FrameDims, Position, Toy From ba55a80dc7bf785f403072c970afa7c2e3bf9bce Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 27 Aug 2025 18:11:37 -0700 Subject: [PATCH 2/6] feat: rename rigidstab to stab --- src/cala/nodes/prep/__init__.py | 4 +- src/cala/nodes/prep/motion.py | 317 ++++++++++++-------------------- tests/test_prep/test_motion.py | 8 +- 3 files changed, 126 insertions(+), 203 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 5a186263..e068fab8 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -2,6 +2,6 @@ from .denoise import denoise from .glow_removal import GlowRemover from .r_estimate import SizeEst -from .motion import RigidStabilizer +from .motion import Stabilizer -__all__ = [denoise, GlowRemover, remove_background, RigidStabilizer, SizeEst] +__all__ = [denoise, GlowRemover, remove_background, Stabilizer, SizeEst] diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index cb98c891..27e6a617 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -3,233 +3,158 @@ import cv2 import numpy as np +import xarray as xr from noob import Name, process_method from pydantic import BaseModel, Field, ConfigDict +from skimage.registration import phase_cross_correlation from cala.assets import Frame from cala.logging import init_logger from cala.models import AXIS -from cala.util import package_frame class Shift(BaseModel): - x: float # width - y: float # height - a: float # angle + width: float + height: float -class RigidStabilizer(BaseModel): +class Stabilizer(BaseModel): drift_speed: float = 1.0 kwargs: dict = Field(default_factory=dict) _anchor_last_applied_on: int = None - anchor_frame_: np.ndarray = None - previous_frame_: np.ndarray = None - previous_keypoints_: np.ndarray = None - motions_: list[Shift] = Field(default_factory=list) + anchor_frame_: xr.DataArray = None + previous_frame_: xr.DataArray = None + motions_: list[Shift] = None - clahe: cv2.CLAHE = Field(cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)), exclude=True) logger: Logger = init_logger(__name__) model_config = ConfigDict(arbitrary_types_allowed=True) @process_method def stabilize(self, frame: Frame) -> A[Frame, Name("frame")]: - arr = frame.array.values.astype(np.uint8) - farr_opt = self.clahe.apply(arr) + if self.is_first_frame(frame): + return frame - if self.previous_keypoints_ is None: - self.previous_keypoints_ = self._find_keypoints(farr_opt) + curr_frame = frame.array + + shift = self._compute_shift(curr_frame) + shifted_frame = self._apply_shift(curr_frame, shift) + + self.previous_frame_ = shifted_frame + + if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item(): + self.anchor_frame_ = self._update_anchor(shifted_frame) + + self.motions_.append(shift) + + return Frame.from_array( + xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords) + ) + + def is_first_frame(self, frame: Frame) -> bool: + if ( + (self.anchor_frame_ is not None) + and (self.previous_frame_ is not None) + and (self.motions_ is not None) + ): + return False + + elif ( + (self.anchor_frame_ is None) + and (self.previous_frame_ is None) + and (self.motions_ is None) + ): + self._anchor_last_applied_on = 0 + self.anchor_frame_ = frame.array + self.previous_frame_ = frame.array + self.motions_ = [Shift(width=0, height=0)] + return True else: - shift = self._generate_shift(farr_opt) - self.motions_.append(shift) - self.previous_keypoints_ = self._find_keypoints(farr_opt) - arr = self._apply_shift(arr, shift) - - self.previous_frame_ = farr_opt - frame = package_frame(arr, index=frame.array[AXIS.frame_coord].item()) - - return frame - - @staticmethod - def _find_keypoints(frame: np.ndarray) -> np.ndarray: - """calculate and save GFTT keypoints for current frame""" - return cv2.goodFeaturesToTrack( - frame, - maxCorners=200, - qualityLevel=0.05, - minDistance=30.0, - blockSize=3, - mask=None, - useHarrisDetector=False, - k=0.04, + raise NotImplementedError( + f"Undefined State: Only some of the attributes are initialized: " + f"{self.anchor_frame_ = }, " + f"{self.previous_frame_ = }, " + f"{self.motions_ = }" + ) + + def _compute_shift(self, curr_frame: xr.DataArray) -> Shift: + """ + The simplest way to stabilize streaming frames would be to have a single reference frame + (the first frame) and shift all subsequent frames against this reference frame. + However, as different sets of neurons are active at different times, frames that are far + apart in time sometimes have few common objects to lock onto. + + To mitigate this, we could stabilize all frames against the previous frame, domino style. + However, errors stack up and a gradual shift takes place with this strategy. + + This algorithm attempts to solve this issue by mixing the two strategies: + 1. We default to stabilizing against the anchor frame. + 2. If we begin losing features to lock onto, the anchor shift will explode to + an unpredictable value. + 3. In this case, we fall back to the sequential shift. + 4. During the "anchor mismatch period", the sequential shift will slowly drift. + 5. And then, as old features surface again, the anchor will lock in again. + 6. However, the sequential shift will have drifted. + 7. We try to estimate how fast it would drift with drift_speed. + 8. Then, the TRUE shift is within the range of sequential_shift +- drift. + 9. Thus, we assume that if anchor_shift falls within this range, the anchor shift + has returned to the TRUE shift. + + in mathematical notations, this translates to: + if: + sequential_shift - drift_speed < anchor_shift < sequential_shift + drift_speed + + then: + true_shift = anchor_shift + + the inequality is same as: + sequential_shift - anchor_shift < drift_speed + anchor_shift - sequential_shift < drift_speed + + which summarizes to: + if: abs(sequential_shift - anchor_shift) < drift_speed + then: true_shift = anchor_shift + """ + + anchor_shift, a_error, _ = phase_cross_correlation( + self.anchor_frame_.values, curr_frame.values, **self.kwargs ) - def _generate_shift(self, frame: np.ndarray) -> Shift: - # calculate optical flow using Lucas-Kanade differential method - curr_kps, status, error = cv2.calcOpticalFlowPyrLK( - self.previous_frame_, frame, self.previous_keypoints_, None + sequent_shift, s_error, _ = phase_cross_correlation( + self.previous_frame_.values, curr_frame.values, **self.kwargs ) - # select only valid keypoints - valid_curr_kps = curr_kps[status == 1] # current - valid_prev_kps = self.previous_keypoints_[status == 1] # previous + shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) - # calculate optimal affine transformation between previous_2_current keypoints - shift = cv2.estimateAffinePartial2D(valid_prev_kps, valid_curr_kps)[0] + frame_idx = curr_frame[AXIS.frame_coord].item() + drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed - if shift is not None: - # translation in x direction - dx = shift[0, 2] - # translation in y direction - dy = shift[1, 2] - # rotation - da = np.arctan2(shift[1, 0], shift[0, 0]) + if shift_diff > drift_threshold: + shift = sequent_shift else: - dx = dy = da = 0 + shift = anchor_shift + self._anchor_last_applied_on = frame_idx + + return Shift(height=shift[0], width=shift[1]) + + def _apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: + # Define the affine transformation matrix for translation + M = np.float32([[1, 0, shift.width], [0, 1, shift.height]]) + + shifted_frame = cv2.warpAffine( + frame.values, + M, + (frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=np.nan, + ) + shifted_frame = np.nan_to_num(shifted_frame, copy=True, nan=0) + return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords) - return Shift(x=dx, y=dy, a=da) + def _update_anchor(self, frame: xr.DataArray) -> xr.DataArray: + curr_index = frame[AXIS.frame_coord].item() - def _apply_shift(self, frame: np.ndarray, shift: Shift) -> np.ndarray: - """ - An internal method that applies affine transformation to the given frame - from previously calculated transformations - """ - # building 2x3 transformation matrix from extracted transformations - shift_op = np.zeros((2, 3), np.float32) - shift_op[0, 0] = np.cos(shift.a) - shift_op[0, 1] = -np.sin(shift.a) - shift_op[1, 0] = np.sin(shift.a) - shift_op[1, 1] = np.cos(shift.a) - shift_op[0, 2] = shift.x - shift_op[1, 2] = shift.y - - # Applying an affine transformation to the frame - return cv2.warpAffine(frame, shift_op, frame.shape[::-1]) - - # def update_anchor(self, frame: xr.DataArray) -> xr.DataArray: - # curr_index = frame[AXIS.frame_coord].item() - # - # return (self.anchor_frame_.array * curr_index + frame) / (curr_index + 1) - - # if self.is_first_frame(frame): - # return frame - # - # curr_frame = frame.array - # - # shift = self.compute_shift(curr_frame) - # shifted_frame = self.apply_shift(curr_frame, shift) - # - # self.previous_frame_ = Frame.from_array(shifted_frame) - # - # if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item(): - # self.anchor_frame_.array = self.update_anchor(shifted_frame) - # - # self.motions_.append(shift) - # - # return Frame.from_array( - # xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords) - # ) - - # def is_first_frame(self, frame: Frame) -> bool: - # if ( - # (self.anchor_frame_ is not None) - # and (self.previous_frame_ is not None) - # and (self.motions_ is not None) - # ): - # return False - # - # elif ( - # (self.anchor_frame_ is None) - # and (self.previous_frame_ is None) - # and (self.motions_ is None) - # ): - # self._anchor_last_applied_on = 0 - # self.anchor_frame_ = frame - # self.previous_frame_ = frame - # self.motions_ = [Shift(width=0, height=0)] - # return True - # - # else: - # raise NotImplementedError( - # f"Undefined State: Only some of the attributes are initialized: " - # f"{self.anchor_frame_ = }, " - # f"{self.previous_frame_ = }, " - # f"{self.motions_ = }" - # ) - - # def compute_shift(self, curr_frame: xr.DataArray) -> Shift: - # """ - # The simplest way to stabilize streaming frames would be to have a single reference frame - # (the first frame) and shift all subsequent frames against this reference frame. - # However, as different sets of neurons are active at different times, frames that are far - # apart in time sometimes have few common objects to lock onto. - # - # To mitigate this, we could stabilize all frames against the previous frame, domino style. - # However, errors stack up and a gradual shift takes place with this strategy. - # - # This algorithm attempts to solve this issue by mixing the two strategies: - # 1. We default to stabilizing against the anchor frame. - # 2. If we begin losing features to lock onto, the anchor shift will explode to - # an unpredictable value. - # 3. In this case, we fall back to the sequential shift. - # 4. During the "anchor mismatch period", the sequential shift will slowly drift. - # 5. And then, as old features surface again, the anchor will lock in again. - # 6. However, the sequential shift will have drifted. - # 7. We try to estimate how fast it would drift with drift_speed. - # 8. Then, the TRUE shift is within the range of sequential_shift +- drift. - # 9. Thus, we assume that if anchor_shift falls within this range, the anchor shift - # has returned to the TRUE shift. - # - # in mathematical notations, this translates to: - # if: - # sequential_shift - drift_speed < anchor_shift < sequential_shift + drift_speed - # - # then: - # true_shift = anchor_shift - # - # the inequality is same as: - # sequential_shift - anchor_shift < drift_speed - # anchor_shift - sequential_shift < drift_speed - # - # which summarizes to: - # if: abs(sequential_shift - anchor_shift) < drift_speed - # then: true_shift = anchor_shift - # """ - # - # anchor_shift, a_error, _ = phase_cross_correlation( - # self.anchor_frame_.array, curr_frame.values, **self.kwargs - # ) - # - # sequent_shift, s_error, _ = phase_cross_correlation( - # self.previous_frame_.array, curr_frame.values, **self.kwargs - # ) - # - # shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) - # - # frame_idx = curr_frame[AXIS.frame_coord].item() - # drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed - # - # if shift_diff > drift_threshold: - # shift = sequent_shift - # else: - # shift = anchor_shift - # self._anchor_last_applied_on = frame_idx - # - # return Shift(height=shift[0], width=shift[1]) - - # def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: - # # Define the affine transformation matrix for translation - # M = np.float32([[1, 0, shift.width], [0, 1, shift.height]]) - # - # shifted_frame = cv2.warpAffine( - # frame.values, - # M, - # (frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]), - # flags=cv2.INTER_LINEAR, - # borderMode=cv2.BORDER_CONSTANT, - # borderValue=np.nan, - # ) - # shifted_frame = np.nan_to_num(shifted_frame, copy=True, nan=0) - # return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords) + return (self.anchor_frame_ * curr_index + frame) / (curr_index + 1) diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index fb6d7247..f093d5d1 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -3,16 +3,14 @@ from noob.node import NodeSpecification from cala.models import AXIS -from cala.nodes.prep.motion import RigidStabilizer, Shift +from cala.nodes.prep.motion import Stabilizer, Shift from cala.testing.toy import FrameDims, Position, Toy @pytest.mark.parametrize("params", [{"drift_speed": 1, "kwargs": {"upsample_factor": 100}}]) def test_motion_estimation(params) -> None: - stab = RigidStabilizer.from_specification( - NodeSpecification(id="test", type="cala.nodes.prep.RigidStabilizer", params=params) - ) + stab = Stabilizer(**params) n_frames = 50 @@ -42,7 +40,7 @@ def test_motion_estimation(params) -> None: width=shifts[0].width - shift.width, height=shifts[0].height - shift.height ) } - result.append(stab.process(frame)) + result.append(stab.stabilize(frame)) estimate = -np.array([(m.width, m.height) for m in stab.motions_]) expected = np.array([(m.width, m.height) for m in shifts]) - (shifts[0].width, shifts[0].height) From b9a471ec384995156bdccdb2dfedcb37a4557f5f Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 28 Aug 2025 00:01:08 -0700 Subject: [PATCH 3/6] feat: motion fixed --- src/cala/nodes/prep/motion.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index 27e6a617..98a1cc19 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -1,5 +1,5 @@ from logging import Logger -from typing import Annotated as A +from typing import Annotated as A, ClassVar, Literal, Callable, Any import cv2 import numpy as np @@ -7,6 +7,7 @@ from noob import Name, process_method from pydantic import BaseModel, Field, ConfigDict from skimage.registration import phase_cross_correlation +from skimage.filters import butterworth, difference_of_gaussians, sato, scharr from cala.assets import Frame from cala.logging import init_logger @@ -20,7 +21,11 @@ class Shift(BaseModel): class Stabilizer(BaseModel): drift_speed: float = 1.0 - kwargs: dict = Field(default_factory=dict) + pcc_kwargs: dict = Field(default_factory=dict) + + pcc_filter: Literal["butterworth", "difference_of_gaussians", "sato", "scharr"] + filter_kwargs: dict = Field(default_factory=dict) + filt_fn: Callable = None _anchor_last_applied_on: int = None anchor_frame_: xr.DataArray = None @@ -117,14 +122,20 @@ def _compute_shift(self, curr_frame: xr.DataArray) -> Shift: if: abs(sequential_shift - anchor_shift) < drift_speed then: true_shift = anchor_shift """ + filters = { + "butterworth": butterworth, + "difference_of_gaussians": difference_of_gaussians, + "sato": sato, + "scharr": scharr, + } + filt_fn = filters[self.pcc_filter] - anchor_shift, a_error, _ = phase_cross_correlation( - self.anchor_frame_.values, curr_frame.values, **self.kwargs - ) + curr = filt_fn(curr_frame, **self.filter_kwargs) + prev = filt_fn(self.previous_frame_, **self.filter_kwargs) + anchor = filt_fn(self.anchor_frame_, **self.filter_kwargs) - sequent_shift, s_error, _ = phase_cross_correlation( - self.previous_frame_.values, curr_frame.values, **self.kwargs - ) + anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs) + sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs) shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) @@ -148,10 +159,9 @@ def _apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: M, (frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]), flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, - borderValue=np.nan, + borderMode=cv2.BORDER_REPLICATE, + # borderValue=0, ) - shifted_frame = np.nan_to_num(shifted_frame, copy=True, nan=0) return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords) def _update_anchor(self, frame: xr.DataArray) -> xr.DataArray: From de1d57bc5285f228735cd6315cb0246c534c98b8 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 28 Aug 2025 00:03:01 -0700 Subject: [PATCH 4/6] tests: motion fixed --- tests/test_prep/test_motion.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index f093d5d1..06024370 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -7,7 +7,17 @@ from cala.testing.toy import FrameDims, Position, Toy -@pytest.mark.parametrize("params", [{"drift_speed": 1, "kwargs": {"upsample_factor": 100}}]) +@pytest.mark.parametrize( + "params", + [ + { + "drift_speed": 1, + "pcc_kwargs": {"upsample_factor": 100}, + "pcc_filter": "difference_of_gaussians", + "filter_kwargs": {"low_sigma": 1}, + } + ], +) def test_motion_estimation(params) -> None: stab = Stabilizer(**params) From edae0169f37f63e99158b87bae832774bc5df089 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 28 Aug 2025 00:03:51 -0700 Subject: [PATCH 5/6] tests: motion fixed --- tests/data/pipelines/with_src.yaml | 36 +++++++++++++++++------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 3cdb055a..f0082807 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -64,11 +64,20 @@ nodes: lines: type: cala.nodes.prep.hlines.remove depends: - - frame: glow.frame + - frame: denoise.frame + motion: # needs to take place before glow, after lines + type: cala.nodes.prep.Stabilizer + params: + drift_speed: 0.5 + pcc_filter: difference_of_gaussians + filter_kwargs: + low_sigma: 1 + depends: + - frame: lines.frame glow: type: cala.nodes.prep.GlowRemover depends: - - frame: denoise.frame + - frame: motion.frame smooth: type: cala.nodes.prep.denoise params: @@ -77,16 +86,11 @@ nodes: ksize: [ 7, 7 ] sigmaX: 1.5 depends: - - frame: lines.frame - motion: - type: cala.nodes.prep.RigidStabilizer - params: - drift_speed: 0.5 - depends: - - frame: smooth.frame + - frame: glow.frame size_est: type: cala.nodes.prep.SizeEst params: + hardset_radius: 6 noise_threshold: 2.0 n_frames: 30 log_kwargs: @@ -96,14 +100,14 @@ nodes: threshold: 0.2 overlap: 0.5 depends: - - frame: motion.frame + - frame: smooth.frame cache: type: cala.nodes.buffer.fill_buffer params: size: 100 depends: - buffer: assets.buffer - - frame: motion.frame + - frame: smooth.frame #PREPROCESS ENDS # FRAME UPDATE BEGINS @@ -115,19 +119,19 @@ nodes: depends: - traces: assets.traces - footprints: assets.footprints - - frame: motion.frame + - frame: smooth.frame - overlaps: assets.overlaps pix_frame: type: cala.nodes.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - - frame: motion.frame + - frame: smooth.frame - new_traces: trace_frame.latest_trace comp_frame: type: cala.nodes.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - - frame: motion.frame + - frame: smooth.frame - new_traces: trace_frame.latest_trace footprints_frame: type: cala.nodes.footprints.Footprinter @@ -171,7 +175,7 @@ nodes: nmf: type: cala.nodes.detect.SliceNMF params: - min_frames: 30 + min_frames: 100 detect_thresh: 2.0 reprod_tol: 0.005 depends: @@ -222,4 +226,4 @@ nodes: type: return depends: - raw: frame.value - - prep: motion.frame \ No newline at end of file + - prep: smooth.frame \ No newline at end of file From 437b502a264411348e668f1ea3d45e676ded89f0 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 28 Aug 2025 00:05:43 -0700 Subject: [PATCH 6/6] format: ruff --- src/cala/nodes/prep/__init__.py | 2 +- src/cala/nodes/prep/motion.py | 8 +++++--- tests/test_prep/test_motion.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index e068fab8..f8f8661d 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,7 +1,7 @@ from .background_removal import remove_background from .denoise import denoise from .glow_removal import GlowRemover -from .r_estimate import SizeEst from .motion import Stabilizer +from .r_estimate import SizeEst __all__ = [denoise, GlowRemover, remove_background, Stabilizer, SizeEst] diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index 98a1cc19..8903143f 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -1,13 +1,15 @@ +from collections.abc import Callable from logging import Logger -from typing import Annotated as A, ClassVar, Literal, Callable, Any +from typing import Annotated as A +from typing import Literal import cv2 import numpy as np import xarray as xr from noob import Name, process_method -from pydantic import BaseModel, Field, ConfigDict -from skimage.registration import phase_cross_correlation +from pydantic import BaseModel, ConfigDict, Field from skimage.filters import butterworth, difference_of_gaussians, sato, scharr +from skimage.registration import phase_cross_correlation from cala.assets import Frame from cala.logging import init_logger diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index 06024370..f36c2930 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -3,7 +3,7 @@ from noob.node import NodeSpecification from cala.models import AXIS -from cala.nodes.prep.motion import Stabilizer, Shift +from cala.nodes.prep.motion import Shift, Stabilizer from cala.testing.toy import FrameDims, Position, Toy