Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cala/nodes/prep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .background_removal import remove_background
from .denoise import denoise
from .glow_removal import GlowRemover
from .motion import Stabilizer
from .r_estimate import SizeEst
from .rigid_stabilization import RigidStabilizer

__all__ = [denoise, GlowRemover, remove_background, RigidStabilizer, SizeEst]
__all__ = [denoise, GlowRemover, remove_background, Stabilizer, SizeEst]
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from collections.abc import Callable
from logging import Logger
from typing import Annotated as A
from typing import Literal

import cv2
import numpy as np
import xarray as xr
from noob import Name
from noob.node import Node
from pydantic import BaseModel, Field
from noob import Name, process_method
from pydantic import BaseModel, ConfigDict, Field
from skimage.filters import butterworth, difference_of_gaussians, sato, scharr
from skimage.registration import phase_cross_correlation

from cala.assets import Frame
from cala.logging import init_logger
from cala.models import AXIS


Expand All @@ -17,15 +21,44 @@ class Shift(BaseModel):
height: float


class RigidStabilizer(Node):
class Stabilizer(BaseModel):
drift_speed: float = 1.0
kwargs: dict = Field(default_factory=dict)
pcc_kwargs: dict = Field(default_factory=dict)

pcc_filter: Literal["butterworth", "difference_of_gaussians", "sato", "scharr"]
filter_kwargs: dict = Field(default_factory=dict)
filt_fn: Callable = None

_anchor_last_applied_on: int = None
anchor_frame_: Frame = None
previous_frame_: Frame = None
anchor_frame_: xr.DataArray = None
previous_frame_: xr.DataArray = None
motions_: list[Shift] = None

logger: Logger = init_logger(__name__)

model_config = ConfigDict(arbitrary_types_allowed=True)

@process_method
def stabilize(self, frame: Frame) -> A[Frame, Name("frame")]:
if self.is_first_frame(frame):
return frame

curr_frame = frame.array

shift = self._compute_shift(curr_frame)
shifted_frame = self._apply_shift(curr_frame, shift)

self.previous_frame_ = shifted_frame

if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item():
self.anchor_frame_ = self._update_anchor(shifted_frame)

self.motions_.append(shift)

return Frame.from_array(
xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords)
)

def is_first_frame(self, frame: Frame) -> bool:
if (
(self.anchor_frame_ is not None)
Expand All @@ -40,8 +73,8 @@ def is_first_frame(self, frame: Frame) -> bool:
and (self.motions_ is None)
):
self._anchor_last_applied_on = 0
self.anchor_frame_ = frame
self.previous_frame_ = frame
self.anchor_frame_ = frame.array
self.previous_frame_ = frame.array
self.motions_ = [Shift(width=0, height=0)]
return True

Expand All @@ -53,7 +86,7 @@ def is_first_frame(self, frame: Frame) -> bool:
f"{self.motions_ = }"
)

def compute_shift(self, curr_frame: xr.DataArray) -> Shift:
def _compute_shift(self, curr_frame: xr.DataArray) -> Shift:
"""
The simplest way to stabilize streaming frames would be to have a single reference frame
(the first frame) and shift all subsequent frames against this reference frame.
Expand Down Expand Up @@ -91,14 +124,20 @@ def compute_shift(self, curr_frame: xr.DataArray) -> Shift:
if: abs(sequential_shift - anchor_shift) < drift_speed
then: true_shift = anchor_shift
"""
filters = {
"butterworth": butterworth,
"difference_of_gaussians": difference_of_gaussians,
"sato": sato,
"scharr": scharr,
}
filt_fn = filters[self.pcc_filter]

anchor_shift, a_error, _ = phase_cross_correlation(
self.anchor_frame_.array, curr_frame.values, **self.kwargs
)
curr = filt_fn(curr_frame, **self.filter_kwargs)
prev = filt_fn(self.previous_frame_, **self.filter_kwargs)
anchor = filt_fn(self.anchor_frame_, **self.filter_kwargs)

sequent_shift, s_error, _ = phase_cross_correlation(
self.previous_frame_.array, curr_frame.values, **self.kwargs
)
anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs)
sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs)

shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift))

Expand All @@ -113,7 +152,7 @@ def compute_shift(self, curr_frame: xr.DataArray) -> Shift:

return Shift(height=shift[0], width=shift[1])

def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray:
def _apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray:
# Define the affine transformation matrix for translation
M = np.float32([[1, 0, shift.width], [0, 1, shift.height]])

Expand All @@ -122,48 +161,12 @@ def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray:
M,
(frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=np.nan,
borderMode=cv2.BORDER_REPLICATE,
# borderValue=0,
)
shifted_frame = np.nan_to_num(shifted_frame, copy=True, nan=0)
return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords)

def update_anchor(self, frame: xr.DataArray) -> xr.DataArray:
def _update_anchor(self, frame: xr.DataArray) -> xr.DataArray:
curr_index = frame[AXIS.frame_coord].item()

return (self.anchor_frame_.array * curr_index + frame) / (curr_index + 1)

def process(self, frame: Frame) -> A[Frame, Name("frame")]:
if self.is_first_frame(frame):
return frame

curr_frame = frame.array

shift = self.compute_shift(curr_frame)
shifted_frame = self.apply_shift(curr_frame, shift)

self.previous_frame_ = Frame.from_array(shifted_frame)

if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item():
self.anchor_frame_.array = self.update_anchor(shifted_frame)

self.motions_.append(shift)

return Frame.from_array(
xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords)
)

def get_info(self) -> dict:
"""Get information about the current state.

Returns
-------
dict
Dictionary containing current statistics
"""
return {
"_anchor_last_applied_on": self._anchor_last_applied_on,
"anchor_frame_": self.anchor_frame_,
"previous_frame_": self.previous_frame_,
"motion_": self.motions_,
}
return (self.anchor_frame_ * curr_index + frame) / (curr_index + 1)
36 changes: 20 additions & 16 deletions tests/data/pipelines/with_src.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,20 @@ nodes:
lines:
type: cala.nodes.prep.hlines.remove
depends:
- frame: glow.frame
- frame: denoise.frame
motion: # needs to take place before glow, after lines
type: cala.nodes.prep.Stabilizer
params:
drift_speed: 0.5
pcc_filter: difference_of_gaussians
filter_kwargs:
low_sigma: 1
depends:
- frame: lines.frame
glow:
type: cala.nodes.prep.GlowRemover
depends:
- frame: denoise.frame
- frame: motion.frame
smooth:
type: cala.nodes.prep.denoise
params:
Expand All @@ -77,16 +86,11 @@ nodes:
ksize: [ 7, 7 ]
sigmaX: 1.5
depends:
- frame: lines.frame
motion:
type: cala.nodes.prep.RigidStabilizer
params:
drift_speed: 0.5
depends:
- frame: smooth.frame
- frame: glow.frame
size_est:
type: cala.nodes.prep.SizeEst
params:
hardset_radius: 6
noise_threshold: 2.0
n_frames: 30
log_kwargs:
Expand All @@ -96,14 +100,14 @@ nodes:
threshold: 0.2
overlap: 0.5
depends:
- frame: motion.frame
- frame: smooth.frame
cache:
type: cala.nodes.buffer.fill_buffer
params:
size: 100
depends:
- buffer: assets.buffer
- frame: motion.frame
- frame: smooth.frame
#PREPROCESS ENDS

# FRAME UPDATE BEGINS
Expand All @@ -115,19 +119,19 @@ nodes:
depends:
- traces: assets.traces
- footprints: assets.footprints
- frame: motion.frame
- frame: smooth.frame
- overlaps: assets.overlaps
pix_frame:
type: cala.nodes.pixel_stats.ingest_frame
depends:
- pixel_stats: assets.pix_stats
- frame: motion.frame
- frame: smooth.frame
- new_traces: trace_frame.latest_trace
comp_frame:
type: cala.nodes.component_stats.ingest_frame
depends:
- component_stats: assets.comp_stats
- frame: motion.frame
- frame: smooth.frame
- new_traces: trace_frame.latest_trace
footprints_frame:
type: cala.nodes.footprints.Footprinter
Expand Down Expand Up @@ -171,7 +175,7 @@ nodes:
nmf:
type: cala.nodes.detect.SliceNMF
params:
min_frames: 30
min_frames: 100
detect_thresh: 2.0
reprod_tol: 0.005
depends:
Expand Down Expand Up @@ -222,4 +226,4 @@ nodes:
type: return
depends:
- raw: frame.value
- prep: motion.frame
- prep: smooth.frame
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,24 @@
from noob.node import NodeSpecification

from cala.models import AXIS
from cala.nodes.prep.rigid_stabilization import RigidStabilizer, Shift
from cala.nodes.prep.motion import Shift, Stabilizer
from cala.testing.toy import FrameDims, Position, Toy


@pytest.mark.parametrize("params", [{"drift_speed": 1, "kwargs": {"upsample_factor": 100}}])
@pytest.mark.parametrize(
"params",
[
{
"drift_speed": 1,
"pcc_kwargs": {"upsample_factor": 100},
"pcc_filter": "difference_of_gaussians",
"filter_kwargs": {"low_sigma": 1},
}
],
)
def test_motion_estimation(params) -> None:

stab = RigidStabilizer.from_specification(
NodeSpecification(id="test", type="cala.nodes.prep.RigidStabilizer", params=params)
)
stab = Stabilizer(**params)

n_frames = 50

Expand Down Expand Up @@ -42,7 +50,7 @@ def test_motion_estimation(params) -> None:
width=shifts[0].width - shift.width, height=shifts[0].height - shift.height
)
}
result.append(stab.process(frame))
result.append(stab.stabilize(frame))

estimate = -np.array([(m.width, m.height) for m in stab.motions_])
expected = np.array([(m.width, m.height) for m in shifts]) - (shifts[0].width, shifts[0].height)
Expand Down
Loading