diff --git a/src/cala/assets.py b/src/cala/assets.py index a9e751e0..200f8492 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -78,7 +78,7 @@ class Frame(Asset): Entity( name="frame", dims=(Dims.width.value, Dims.height.value), - dtype=float, + dtype=None, # np.number, # gets converted to float64 in xarray-validate checks=[is_non_negative, has_no_nan], ) ) diff --git a/src/cala/gui/components/encoder.py b/src/cala/gui/components/encoder.py index ff0b8eab..af3b06f6 100644 --- a/src/cala/gui/components/encoder.py +++ b/src/cala/gui/components/encoder.py @@ -5,8 +5,7 @@ import av import numpy as np from av.video import VideoStream -from noob import process_method -from pydantic import BaseModel +from noob.node import Node from cala.assets import Frame from cala.config import config @@ -26,14 +25,13 @@ def __str__(self) -> str: return "Encoding failed." -class Encoder(BaseModel): - grid_id: str +class Encoder(Node): frame_rate: int _stream: VideoStream | None = None _container: av.container.OutputContainer | None = None def model_post_init(self, context: Any, /) -> None: - encode_dir = config.runtime_dir / self.grid_id + encode_dir = config.runtime_dir / self.id encode_dir.mkdir(parents=True, exist_ok=True) clear_dir(encode_dir) hls_manifest = encode_dir / "stream.m3u8" @@ -51,8 +49,7 @@ def model_post_init(self, context: Any, /) -> None: self._stream = self._container.add_stream("h264", rate=self.frame_rate) self._stream.pix_fmt = "yuv420p" - @process_method - def save(self, frame: Frame) -> None: + def process(self, frame: Frame) -> None: frame = frame.array.astype(np.uint8) self._stream.width = frame.sizes[AXIS.width_dim] self._stream.height = frame.sizes[AXIS.height_dim] diff --git a/src/cala/gui/plots.py b/src/cala/gui/plots.py index aae58911..cb1a8142 100644 --- a/src/cala/gui/plots.py +++ b/src/cala/gui/plots.py @@ -1,14 +1,9 @@ from pathlib import Path import cv2 -import imageio.v2 as imageio -import matplotlib.pyplot as plt import numpy as np -import seaborn as sns import xarray as xr -sns.set_theme(style="whitegrid", context="notebook", font_scale=1.2, palette="deep") - def write_movie(video: xr.DataArray, path: str | Path) -> None: """Test visualization of stabilized calcium video to verify motion stabilization.""" @@ -23,76 +18,3 @@ def write_movie(video: xr.DataArray, path: str | Path) -> None: out.write(frame_bgr) out.release() - - -def write_gif( - videos: xr.DataArray | list[xr.DataArray], - path: str | Path, - n_cols: int | None = None, -) -> None: - """ - Save video frames with optional processing function. Can handle single or multiple videos. - - Parameters: - ----------- - videos : Union[xr.DataArray, List[Tuple[xr.DataArray, str]]] - Either a single video DataArray or list of (video, title) tuples for comparison - n_cols : Optional[int] - Number of columns when displaying multiple videos. If None, tries to make square grid - """ - # Handle single video case - if isinstance(videos, xr.DataArray): - videos = [videos] - - # Verify all videos have same number of frames - n_frames = len(videos[0][0]) - if not all(len(video) == n_frames for video in videos): - raise ValueError("All videos must have the same number of frames") - - n_videos = len(videos) - if n_cols is None: - n_cols = int(np.ceil(np.sqrt(n_videos))) if n_videos > 1 else 1 - n_rows = int(np.ceil(n_videos / n_cols)) - - # Get global min/max for consistent scaling - vmin = np.min([np.min(video) for video in videos]) - vmax = np.max([np.max(video) for video in videos]) - - for frame_idx in range(n_frames): - if n_videos == 1: - fig, ax = plt.subplots(figsize=(8, 8)) - axes = [[ax]] - else: - fig, axes = plt.subplots( - n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False - ) - - for vid_idx, (video, title) in enumerate(videos): - last_row = vid_idx // n_cols - remn_col = vid_idx % n_cols - ax = axes[last_row][remn_col] - - frame = video[frame_idx] - - ax.imshow(frame, cmap="gray", vmin=vmin, vmax=vmax) - if title: - ax.set_title(f"{title}\nFrame {frame_idx}") - else: - ax.set_title(f"Frame {frame_idx}") - ax.axis("off") - - # Hide empty subplots - if n_videos > 1: - for idx in range(n_videos, n_rows * n_cols): - last_row = idx // n_cols - remn_col = idx % n_cols - axes[last_row][remn_col].set_visible(False) - - plt.tight_layout() - plt.savefig(path / f"{frame_idx:04d}.png", dpi=150, bbox_inches="tight") - - # Create gif - frames = [] - for i in range(n_frames): - frames.append(imageio.imread(path / f"{i:04d}.png")) - imageio.mimsave(path, frames, fps=30) diff --git a/src/cala/gui/spec.yaml b/src/cala/gui/spec.yaml index 171426c8..56eedf4d 100644 --- a/src/cala/gui/spec.yaml +++ b/src/cala/gui/spec.yaml @@ -1,13 +1,18 @@ nodes: - prep_movie: + raw_movie: type: cala.gui.components.Encoder params: - grid_id: prep_movie frame_rate: 30 depends: - frame.value - component_count: - type: cala.gui.components.component_counter + prep_movie: + type: cala.gui.components.Encoder + params: + frame_rate: 30 depends: - - index: counter.idx - - traces: assets.traces \ No newline at end of file + - flatten.frame +# component_count: +# type: cala.gui.components.component_counter +# depends: +# - index: counter.idx +# - traces: assets.traces \ No newline at end of file diff --git a/src/cala/models/entity.py b/src/cala/models/entity.py index d3448a1a..c21cfc21 100644 --- a/src/cala/models/entity.py +++ b/src/cala/models/entity.py @@ -16,7 +16,7 @@ class Entity(BaseModel): name: str dims: tuple[Dim, ...] coords: list[Coord] = Field(default_factory=list) - dtype: type + dtype: type | None checks: list[Callable] = Field(default_factory=list) allow_extra_coords: bool = True @@ -44,7 +44,7 @@ def to_schema(self) -> DataArraySchema: return DataArraySchema( dims=DimsSchema(tuple(dim.name for dim in self.dims), ordered=False), coords=coords_schema, - dtype=DTypeSchema(self.dtype), + dtype=DTypeSchema(self.dtype) if self.dtype else None, checks=self.checks, ) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 6837b571..75930ca7 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -3,14 +3,14 @@ from .flatten import butter from .glow_removal import GlowRemover from .lines import remove_freq, remove_mean -from .motion import Stabilizer +from .motion import Anchor from .r_estimate import SizeEst __all__ = [ "blur", "GlowRemover", "remove_background", - "Stabilizer", + "Anchor", "SizeEst", "butter", "remove_mean", diff --git a/src/cala/nodes/prep/flatten.py b/src/cala/nodes/prep/flatten.py index 5997b73a..8db07ba9 100644 --- a/src/cala/nodes/prep/flatten.py +++ b/src/cala/nodes/prep/flatten.py @@ -11,16 +11,18 @@ def butter(frame: Frame, kwargs: dict[str, Any]) -> A[Frame, Name("frame")]: """ - butterworth filter centers the image to zero. this causes two images with same intensity ratio - across pixels to be indistinguishable. - To recover the absolute brightness, we shift the filtered image by the - mean brightness of the original frame. + Butterworth filter centers the image to zero. This is due to the constant term (the mean) + being expressed as the 0th term in the fourier series. + Since the absolute background activity does not matter (all that is left is the high-frequency + signal), we simply add half of the 8-bit pixel max so that the total cannot exceed the + 0-255 range. + + The filter can also be used to reduce the scattering and the glow! (inspired by Marcel Brosche) + This helps remove overlap between cells (with higher cutoff_frequency_ratio) """ - arr = butterworth(frame.array, **kwargs) + frame.array.mean().item() + arr = butterworth(frame.array, **kwargs) + 2**7 - return Frame.from_array( - xr.DataArray(arr.clip(0), dims=frame.array.dims, coords=frame.array.coords) - ) + return Frame.from_array(xr.DataArray(arr, dims=frame.array.dims, coords=frame.array.coords)) def ball(frame: Frame, kwargs: dict[str, Any]) -> Frame: diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index 2cdd9d42..b6daf8cb 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -1,164 +1,445 @@ +# adapted from SIMA (https://github.com/losonczylab) and the +# scikit-image (http://scikit-image.org/) package. +# +# +# Unless otherwise specified by LICENSE.txt files in individual +# directories, all code is +# +# Copyright (C) 2011, the scikit-image team +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of skimage nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, +# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + + +import functools from collections.abc import Callable -from logging import Logger from typing import Annotated as A -from typing import Literal import cv2 import numpy as np import xarray as xr from noob import Name, process_method -from pydantic import BaseModel, ConfigDict, Field +from numpy.fft import ifftshift +from numpydantic import NDArray +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator from skimage.filters import difference_of_gaussians -from skimage.registration import phase_cross_correlation from cala.assets import Frame -from cala.logging import init_logger from cala.models import AXIS +from cala.testing.util import shift_by class Shift(BaseModel): - width: float height: float + width: float + @classmethod + def from_arr(cls, array: NDArray) -> "Shift": + assert array.shape == (2,) + return Shift(height=array[0], width=array[1]) -class Stabilizer(BaseModel): - drift_speed: float = 1.0 - pcc_kwargs: dict = Field(default_factory=dict) - - pcc_filter: Literal["butterworth", "difference_of_gaussians", "sato", "scharr"] - filter_kwargs: dict = Field(default_factory=dict) - filt_fn: Callable = None - - _anchor_last_applied_on: int = None - anchor_frame_: xr.DataArray = None - previous_frame_: xr.DataArray = None - motions_: list[Shift] = None - - logger: Logger = init_logger(__name__) - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @process_method - def stabilize(self, frame: Frame) -> A[Frame, Name("frame")]: - if self.is_first_frame(frame): - return frame - - curr_frame = frame.array + def __add__(self, other: "Shift") -> "Shift": + return Shift(height=self.height + other.height, width=self.width + other.width) - shift = self._compute_shift(curr_frame) - shifted_frame = self._apply_shift(curr_frame, shift) - self.previous_frame_ = shifted_frame +class Anchor(BaseModel): + max_shifts: tuple[float, float] = (50, 50) + upsample_factor: int = 10 - if self._anchor_last_applied_on == shifted_frame[AXIS.frame_coord].item(): - self.anchor_frame_ = self._update_anchor(shifted_frame) + dog_kwargs: dict = Field(default_factory=dict, validate_default=True) + gauss_kwargs: dict = Field(default_factory=dict, validate_default=True) - self.motions_.append(shift) + _reg_shift: Callable = PrivateAttr(None) + """A callable used to find the shift""" + _local: xr.DataArray = PrivateAttr(None) + """local anchor - processed and ready for comparison""" + _global: xr.DataArray = PrivateAttr(None) + """global anchor - processed and ready for comparison""" + _history: list[Shift] = PrivateAttr(default_factory=list) - return Frame.from_array( - xr.DataArray(shifted_frame, dims=frame.array.dims, coords=frame.array.coords) - ) + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - def is_first_frame(self, frame: Frame) -> bool: - if ( - (self.anchor_frame_ is not None) - and (self.previous_frame_ is not None) - and (self.motions_ is not None) - ): - return False - - elif ( - (self.anchor_frame_ is None) - and (self.previous_frame_ is None) - and (self.motions_ is None) - ): - self._anchor_last_applied_on = 0 - self.anchor_frame_ = frame.array - self.previous_frame_ = frame.array - self.motions_ = [Shift(width=0, height=0)] - return True + @field_validator("dog_kwargs", mode="before") + @classmethod + def default_dog(cls, value: dict) -> dict: + if not value: + return {"low_sigma": 3} + else: + return value + @field_validator("gauss_kwargs", mode="before") + @classmethod + def default_gauss(cls, value: dict) -> dict: + if not value: + return {"ksize": (11, 11), "sigmaX": 20} else: - raise NotImplementedError( - f"Undefined State: Only some of the attributes are initialized: " - f"{self.anchor_frame_ = }, " - f"{self.previous_frame_ = }, " - f"{self.motions_ = }" - ) + return value - def _compute_shift(self, curr_frame: xr.DataArray) -> Shift: + @process_method + def stabilize(self, frame: Frame) -> A[Frame, Name("frame")]: """ - The simplest way to stabilize streaming frames would be to have a single reference frame - (the first frame) and shift all subsequent frames against this reference frame. - However, as different sets of neurons are active at different times, frames that are far - apart in time sometimes have few common objects to lock onto. - - To mitigate this, we could stabilize all frames against the previous frame, domino style. - However, errors stack up and a gradual shift takes place with this strategy. - - This algorithm attempts to solve this issue by mixing the two strategies: - 1. We default to stabilizing against the anchor frame. - 2. If we begin losing features to lock onto, the anchor shift will explode to - an unpredictable value. - 3. In this case, we fall back to the sequential shift. - 4. During the "anchor mismatch period", the sequential shift will slowly drift. - 5. And then, as old features surface again, the anchor will lock in again. - 6. However, the sequential shift will have drifted. - 7. We try to estimate how fast it would drift with drift_speed. - 8. Then, the TRUE shift is within the range of sequential_shift +- drift. - 9. Thus, we assume that if anchor_shift falls within this range, the anchor shift - has returned to the TRUE shift. - - in mathematical notations, this translates to: - if: - sequential_shift - drift_speed < anchor_shift < sequential_shift + drift_speed - - then: - true_shift = anchor_shift - - the inequality is same as: - sequential_shift - anchor_shift < drift_speed - anchor_shift - sequential_shift < drift_speed - - which summarizes to: - if: abs(sequential_shift - anchor_shift) < drift_speed - then: true_shift = anchor_shift + --- image, prepped, local --- + image: original image. only shifted and outputted + prepped: processed image. only used to find the shift and then discarded + local: shifted prepped from the last iteration to be used as a template + + Steps: + 1. raw prepped gets shifted to last local anchor. We save the shift + 2. the shifted prepped gets shifted to global anchor. We add to the shift + 3. we apply total shift to image. We also save the total shifted prepped + as the anchor for the next frame """ - curr = difference_of_gaussians(curr_frame, **self.filter_kwargs) - prev = difference_of_gaussians(self.previous_frame_, **self.filter_kwargs) - anchor = difference_of_gaussians(self.anchor_frame_, **self.filter_kwargs) + arr = frame.array + prepped = prepare(arr, dog_kwargs=self.dog_kwargs, gauss_kwargs=self.gauss_kwargs) + if not self._has_prereqs: + self._init(prepped) + return frame - anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs) - sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs) + total = Shift(height=0, width=0) + for template in [self._local, self._global]: + shift_arr, _, _ = self._reg_shift(template.values, prepped.values) + shift = Shift.from_arr(shift_arr) + total += shift + prepped = apply_shift(prepped, shift) + self._get_ready_for_next(prepped) + self._history.append(total) + + result = apply_shift(arr, total) + return Frame.from_array(result) + + @property + def _has_prereqs(self) -> bool: + return self._local is not None + + def _init(self, image: xr.DataArray) -> None: + self._local = image + self._global = image + self._reg_shift = functools.partial( + register_shift, upsample_factor=self.upsample_factor, max_shifts=self.max_shifts + ) - shift_diff = np.linalg.norm(anchor_shift - sequent_shift) + def _calculate_shift(self, shift: xr.DataArray) -> xr.DataArray: ... + + def _get_ready_for_next(self, prepped: xr.DataArray) -> None: + self._local = prepped + + # global learns the local + curr_idx = prepped[AXIS.frame_coord].item() + self._global = (self._global * curr_idx + self._local) / (curr_idx + 1) + + +def prepare(image: xr.DataArray, dog_kwargs: dict, gauss_kwargs: dict) -> xr.DataArray: + tmp = difference_of_gaussians(image, **dog_kwargs) + tmp = cv2.normalize(tmp, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1) + result = cv2.GaussianBlur(tmp.astype(float), **gauss_kwargs) + return xr.DataArray(result, dims=image.dims, coords=image.coords) + + +def register_shift( + src_image: np.ndarray, + target_image: np.ndarray, + max_shifts: tuple[float, float] = (20, 20), + upsample_factor: int = 1, +) -> tuple[np.ndarray, np.ndarray, float]: + """ + This code gives the same precision as the FFT upsampled cross-correlation + in a fraction of the computation time and with reduced memory requirements. + It obtains an initial estimate of the cross-correlation peak by an FFT and + then refines the shift estimation by upsampling the DFT only in a small + neighborhood of that estimate by means of a matrix-multiply DFT. + + Args: + src_image : ndarray + Reference image. + + target_image : ndarray + Image to register. Must be same dimensionality as ``src_image``. + + Returns: + shifts : ndarray + Shift vector (in pixels) required to register ``target_image`` with + ``src_image``. Axis ordering is consistent with numpy (e.g. Z, Y, X) + + error : float + Translation invariant normalized RMS error between ``src_image`` and + ``target_image``. + + Raises: + NotImplementedError "Error: register_translation only supports " + "subpixel registration for 2D images" + + ValueError "Error: images must really be same size for " + "register_translation" + + ValueError "Error: register_translation only knows the \"real\" " + "and \"fourier\" values for the ``space`` argument." + + References: + [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, + "Efficient subpixel image registration algorithms," + Optics Letters 33, 156-158 (2008). + """ + + src_freq_1 = cv2.dft(src_image, flags=cv2.DFT_COMPLEX_OUTPUT + cv2.DFT_SCALE) + src_freq = src_freq_1[:, :, 0] + 1j * src_freq_1[:, :, 1] + src_freq = np.array(src_freq, dtype=np.complex128, copy=False) + target_freq_1 = cv2.dft(target_image, flags=cv2.DFT_COMPLEX_OUTPUT + cv2.DFT_SCALE) + target_freq = target_freq_1[:, :, 0] + 1j * target_freq_1[:, :, 1] + target_freq = np.array(target_freq, dtype=np.complex128, copy=False) + + # Whole-pixel shift - Compute cross-correlation by an IFFT + shape = src_freq.shape + image_product = src_freq * target_freq.conj() + image_product_cv = np.dstack([np.real(image_product), np.imag(image_product)]) + cross_correlation = cv2.dft(image_product_cv, flags=cv2.DFT_INVERSE + cv2.DFT_SCALE) + cross_correlation = cross_correlation[:, :, 0] + 1j * cross_correlation[:, :, 1] + + # Locate maximum + new_cross_corr = np.abs(cross_correlation) + + new_cross_corr[max_shifts[0] : -max_shifts[0], :] = 0 + new_cross_corr[:, max_shifts[1] : -max_shifts[1]] = 0 + + maxima = np.unravel_index(np.argmax(new_cross_corr), cross_correlation.shape) + midpoints = np.array([np.fix(axis_size // 2) for axis_size in shape]) + + shifts = np.array(maxima, dtype=np.float64) + shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints] + + if upsample_factor == 1: + CCmax = cross_correlation.max() + # If upsampling > 1, then refine estimate with matrix multiply DFT + else: + # Initial shift estimate in upsampled grid + shifts = np.round(shifts * upsample_factor) / upsample_factor + upsampled_region_size = np.ceil(upsample_factor * 1.5) + # Center of output array at dftshift + 1 + dftshift = np.fix(upsampled_region_size / 2.0) + upsample_factor = np.array(upsample_factor, dtype=np.float64) + normalization = src_freq.size * upsample_factor**2 + # Matrix multiply DFT around the current shift estimate + sample_region_offset = dftshift - shifts * upsample_factor + + cross_correlation = _upsampled_dft( + image_product.conj(), upsampled_region_size, upsample_factor, sample_region_offset + ).conj() + cross_correlation /= normalization + # Locate maximum and map back to original pixel grid + maxima = np.array( + np.unravel_index(np.argmax(np.abs(cross_correlation)), cross_correlation.shape), + dtype=np.float64, + ) + maxima -= dftshift + shifts = shifts + (maxima / upsample_factor) + CCmax = cross_correlation.max() + src_amp = _upsampled_dft(src_freq * src_freq.conj(), 1, upsample_factor)[0, 0] + src_amp /= normalization + target_amp = _upsampled_dft(target_freq * target_freq.conj(), 1, upsample_factor)[0, 0] + target_amp /= normalization + + # If its only one row or column the shift along that dimension has no + # effect. We set to zero. + for dim in range(src_freq.ndim): + if shape[dim] == 1: + shifts[dim] = 0 + + return shifts, src_freq, _compute_phasediff(CCmax) + + +def _upsampled_dft( + data: np.ndarray, + upsampled_region_size: list[np.float64], + upsample_factor: int = 1, + axis_offsets: np.ndarray = None, +) -> np.ndarray: + """ + Upsampled DFT by matrix multiplication. + + This code is intended to provide the same result as if the following + operations were performed: + - Embed the array "data" in an array that is ``upsample_factor`` times + larger in each dimension. ifftshift to bring the center of the + image to (1,1). + - Take the FFT of the larger array. + - Extract an ``[upsampled_region_size]`` region of the result, starting + with the ``[axis_offsets+1]`` element. + + It achieves this result by computing the DFT in the output array without + the need to zeropad. Much faster and memory efficient than the zero-padded + FFT approach if ``upsampled_region_size`` is much smaller than + ``data.size * upsample_factor``. + + Args: + data : 2D ndarray + The input data array (DFT of original data) to upsample. + + upsampled_region_size : integer or tuple of integers, optional + The size of the region to be sampled. If one integer is provided, it + is duplicated up to the dimensionality of ``data``. + + upsample_factor : integer, optional + The upsampling factor. Defaults to 1. + + axis_offsets : tuple of integers, optional + The offsets of the region to be sampled. Defaults to None (uses + image center) + + Returns: + output : 2D ndarray + The upsampled DFT of the specified region. + """ + # if people pass in an integer, expand it to a list of equal-sized sections + if not hasattr(upsampled_region_size, "__iter__"): + upsampled_region_size = [ + upsampled_region_size, + ] * data.ndim + else: + if len(upsampled_region_size) != data.ndim: + raise ValueError( + "shape of upsampled region sizes must be equal to " + "input data's number of dimensions." + ) - frame_idx = curr_frame[AXIS.frame_coord].item() - drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed + if axis_offsets is None: + axis_offsets = [ + 0, + ] * data.ndim + else: + if len(axis_offsets) != data.ndim: + raise ValueError( + "number of axis offsets must be equal to input data's number of dimensions." + ) - if shift_diff > drift_threshold: - shift = sequent_shift - else: - shift = anchor_shift - self._anchor_last_applied_on = frame_idx - - return Shift(height=shift[0], width=shift[1]) - - def _apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: - # Define the affine transformation matrix for translation - M = np.float32([[1, 0, shift.width], [0, 1, shift.height]]) - - shifted_frame = cv2.warpAffine( - frame.values, - M, - (frame.sizes[AXIS.width_dim], frame.sizes[AXIS.height_dim]), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REPLICATE, - # borderValue=0, + col_kernel = np.exp( + (-1j * 2 * np.pi / (data.shape[1] * upsample_factor)) + * (ifftshift(np.arange(data.shape[1]))[:, None] - np.floor(data.shape[1] // 2)).dot( + np.arange(upsampled_region_size[1])[None, :] - axis_offsets[1] ) - return xr.DataArray(shifted_frame, dims=frame.dims, coords=frame.coords) + ) + row_kernel = np.exp( + (-1j * 2 * np.pi / (data.shape[0] * upsample_factor)) + * (np.arange(upsampled_region_size[0])[:, None] - axis_offsets[0]).dot( + ifftshift(np.arange(data.shape[0]))[None, :] - np.floor(data.shape[0] // 2) + ) + ) - def _update_anchor(self, frame: xr.DataArray) -> xr.DataArray: - curr_index = frame[AXIS.frame_coord].item() + if data.ndim > 2: + pln_kernel = np.exp( + (-1j * 2 * np.pi / (data.shape[2] * upsample_factor)) + * (np.arange(upsampled_region_size[2])[:, None] - axis_offsets[2]).dot( + ifftshift(np.arange(data.shape[2]))[None, :] - np.floor(data.shape[2] // 2) + ) + ) - return (self.anchor_frame_ * curr_index + frame) / (curr_index + 1) + # output = np.tensordot(np.tensordot(row_kernel,data,axes=[1,0]),col_kernel,axes=[1,0]) + output = np.tensordot(row_kernel, data, axes=[1, 0]) + output = np.tensordot(output, col_kernel, axes=[1, 0]) + + if data.ndim > 2: + output = np.tensordot(output, pln_kernel, axes=[1, 1]) + # output = row_kernel.dot(data).dot(col_kernel) + return output + + +def _compute_phasediff(cross_correlation_max: np.complex128) -> np.float64: + """ + Compute global phase difference between the two images + (should be zero if images are non-negative). + + Args: + cross_correlation_max : complex + The complex value of the cross correlation at its maximum point. + """ + return np.arctan2(cross_correlation_max.imag, cross_correlation_max.real) + + +def apply_shifts_dft( + src_freq: np.ndarray, shifts: np.ndarray, diffphase: float, is_freq: bool = True +) -> np.ndarray: + """ + Args: + apply shifts using inverse dft + src_freq: ndarray + if is_freq it is fourier transform image else original image + shifts: shifts to apply + diffphase: comes from the register_translation output + """ + + if not is_freq: + + src_freq = np.dstack([np.real(src_freq), np.imag(src_freq)]) + src_freq = cv2.dft(src_freq, flags=cv2.DFT_COMPLEX_OUTPUT + cv2.DFT_SCALE) + src_freq = src_freq[:, :, 0] + 1j * src_freq[:, :, 1] + src_freq = np.array(src_freq, dtype=np.complex128, copy=False) + + nr, nc = src_freq.shape + Nr = ifftshift(np.arange(-np.fix(nr / 2.0), np.ceil(nr / 2.0))) + Nc = ifftshift(np.arange(-np.fix(nc / 2.0), np.ceil(nc / 2.0))) + Nc, Nr = np.meshgrid(Nc, Nr) + Greg = src_freq * np.exp(1j * 2 * np.pi * (-shifts[0] * Nr / nr - shifts[1] * Nc / nc)) + + Greg = Greg.dot(np.exp(1j * diffphase)) + Greg = np.dstack([np.real(Greg), np.imag(Greg)]) + new_img = cv2.idft(Greg)[:, :, 0] + + max_w, max_h, min_w, min_h = 0, 0, 0, 0 + max_h, max_w = np.ceil(np.maximum((max_h, max_w), shifts[:2])).astype(int) + min_h, min_w = np.floor(np.minimum((min_h, min_w), shifts[:2])).astype(int) + + new_img[:max_h] = new_img[max_h] + if min_h < 0: + new_img[min_h:] = new_img[min_h - 1] + if max_w > 0: + new_img[:, :max_w] = new_img[:, max_w, np.newaxis] + if min_w < 0: + new_img[:, min_w:] = new_img[:, min_w - 1, np.newaxis] + + return new_img + + +def apply_shift(image: xr.DataArray, shift: Shift) -> xr.DataArray: + M = np.float32([[1, 0, shift.width], [0, 1, shift.height]]) + + shifted_frame = cv2.warpAffine( + image.values, + M, + (image.sizes[AXIS.width_dim], image.sizes[AXIS.height_dim]), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REPLICATE, + ) + return xr.DataArray(shifted_frame, dims=image.dims, coords=image.coords) + + +def check_shift_validity( + source: xr.DataArray, target: xr.DataArray, shift: np.ndarray, threshold: float +) -> bool: + expected = np.array([5, 5]) + tester = shift_by(source.values, *expected) + total, _, _ = register_shift(tester, target.values) + result = total - shift + error = np.linalg.norm(result - expected) + return error < threshold diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index f1298b77..5b8bf27d 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -55,9 +55,9 @@ class Toy(BaseModel): cell_radii: int | list[int] cell_positions: list[Position] cell_traces: list[np.ndarray] - cell_ids: list[str] + cell_ids: list[str] = None """If none, auto populated as cell_{idx}.""" - detected_ons: list[int] + detected_ons: list[int] = None emit_frames: bool = False _footprints: xr.DataArray = PrivateAttr(init=False) diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 8801343c..4a1687a9 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -30,3 +30,8 @@ def generate_text_image( org = (frame_dims[0] // 2, frame_dims[1] // 2) return cv2.putText(image, text, org, font, font_scale, color, thickness, cv2.LINE_AA) + + +def shift_by(img: np.ndarray, right_pix: int, down_pix: int) -> np.ndarray: + M = np.float32([[1, 0, right_pix], [0, 1, down_pix]]) + return cv2.warpAffine(img, M, (img.shape[1], img.shape[0]), borderMode=cv2.BORDER_REPLICATE) diff --git a/src/cala/util.py b/src/cala/util.py index 32be5074..01f673c6 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -51,7 +51,7 @@ def package_frame(frame: np.ndarray, index: int, timestamp: datetime | str | Non AXIS.height_dim: range(frame.sizes[AXIS.height_dim]), } ) - return Frame.from_array(da.astype(float)) + return Frame.from_array(da) def create_id() -> str: diff --git a/tests/data/pipelines/prep.yaml b/tests/data/pipelines/prep.yaml new file mode 100644 index 00000000..17d23394 --- /dev/null +++ b/tests/data/pipelines/prep.yaml @@ -0,0 +1,64 @@ +noob_id: cala-prep +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev118+g64d81b7 + +nodes: + source: + type: cala.nodes.io.stream + params: + files: + - cala/msCam1.avi + - cala/msCam2.avi + - cala/msCam3.avi + - cala/msCam4.avi + - cala/msCam5.avi + - cala/msCam6.avi + - cala/msCam7.avi + - cala/msCam8.avi + - cala/msCam9.avi + - cala/msCam10.avi + counter: + type: cala.util.counter + frame: + type: cala.util.package_frame + depends: + - frame: source.value + - index: counter.idx + + #PREPROCESS BEGINS + hotpix: + type: cala.nodes.prep.blur + params: + method: median + kwargs: + ksize: 3 + depends: + - frame: frame.value + flatten: + type: cala.nodes.prep.butter + params: + kwargs: + cutoff_frequency_ratio: 0.005 + depends: + - frame: hotpix.frame + lines: + type: cala.nodes.prep.remove_freq + params: + orient: both + depends: + - frame: flatten.frame + motion: + type: cala.nodes.prep.Stabilizer + params: + drift_speed: 0.5 + depends: + - frame: lines.frame + +# denoise: # needs to happen after lines +# type: cala.nodes.prep.Restore +# depends: +# - frame: lines.frame +# glow: +# type: cala.nodes.prep.GlowRemover +# depends: +# - frame: motion.frame \ No newline at end of file diff --git a/tests/test_prep/test_anchor.py b/tests/test_prep/test_anchor.py new file mode 100644 index 00000000..8acb4695 --- /dev/null +++ b/tests/test_prep/test_anchor.py @@ -0,0 +1,131 @@ +# These tests are applicable when there is real data. +# Naturally, this is harder to implement on a CI/CD environment, +# and thus have been left as comments for possible future uses. + + +# from collections.abc import Generator +# +# import cv2 +# import numpy as np +# import pytest +# import xarray as xr +# from skimage.filters import difference_of_gaussians +# +# from cala.config import config +# from cala.nodes.io import stream +# from cala.nodes.prep import blur, butter, remove_mean +# from cala.nodes.prep.motion import Anchor, Shift, apply_shift, register_shift +# from cala.util import package_frame +# +# +# @pytest.fixture +# def real() -> Generator[np.ndarray]: +# return stream( +# [ +# "cala/msCam1.avi", +# "cala/msCam2.avi", +# "cala/msCam3.avi", +# "cala/msCam4.avi", +# "cala/msCam5.avi", +# "cala/msCam6.avi", +# "cala/msCam7.avi", +# "cala/msCam8.avi", +# "cala/msCam9.avi", +# "cala/msCam10.avi", +# ] +# ) +# +# +# def median(img: xr.DataArray) -> xr.DataArray: +# tmp = difference_of_gaussians(img, low_sigma=3) # nothing: 1.3 min +# tmp = cv2.normalize( +# tmp, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1 +# ) +# res = cv2.medianBlur(tmp, 11) # 2 mins +# +# return xr.DataArray(res, dims=img.dims, coords=img.coords) +# +# +# def nlm(img: xr.DataArray) -> xr.DataArray: +# tmp = difference_of_gaussians(img, low_sigma=3) # nothing: 1.3 min +# tmp = cv2.normalize( +# tmp, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1 +# ) +# res = cv2.fastNlMeansDenoising(tmp, None, 7, 7, 21) # 3 mins +# +# return xr.DataArray(res, dims=img.dims, coords=img.coords) +# +# +# def gauss(img: xr.DataArray) -> xr.DataArray: +# # tmp = img[100:300, 500:700] +# tmp = difference_of_gaussians(img, low_sigma=3) # nothing: 1.3 min +# tmp = cv2.normalize( +# tmp, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1 +# ) +# res = cv2.GaussianBlur(tmp.astype(float), (11, 11), 20) # 1.5 mins +# +# return xr.DataArray(res, dims=img.dims, coords=img.coords) +# +# +# def test_real(real): +# fourcc = cv2.VideoWriter_fourcc(*"mp4v") +# out = cv2.VideoWriter(config.video_dir / "mc_prev.avi", fourcc, 24.0, (1504, 960)) +# prev = None +# +# for i, arr in enumerate(real): +# frame = package_frame(arr, i) +# frame = blur(frame, method="median", kwargs={"ksize": 3}) +# frame = butter(frame, kwargs={}) +# frame = remove_mean(frame, orient="both") # all of these takes 30 ish seconds +# +# subplots = [] +# for func in [median, gauss, nlm]: +# if prev is None: +# prev = func(frame.array) +# tmpl = prev +# break +# +# prepped = func(frame.array) +# drift, _, _ = register_shift( +# prev.values.astype(float), prepped.values.astype(float), upsample_factor=10 +# ) +# drift = Shift(height=drift[0], width=drift[1]) +# +# corrected = apply_shift(frame.array, drift) +# prev = apply_shift(prepped, drift) +# +# if i % 1 == 0: +# slow_drift, _, _ = register_shift( # THIS WAS CORRECTED< NOT PREV +# tmpl.values.astype(float), prev.values.astype(float), upsample_factor=10 +# ) +# slow_drift = Shift(height=slow_drift[0], width=slow_drift[1]) +# corrected = apply_shift(corrected, slow_drift) +# prev = apply_shift(prev, slow_drift) +# +# subplots.append(corrected) +# +# if subplots: +# left_row = np.concat([frame.array, subplots[0]]) +# right_row = np.concat([subplots[1], subplots[2]]) +# view = np.concat([left_row, right_row], axis=1) +# frame_bgr = cv2.cvtColor(view.astype(np.uint8), cv2.COLOR_GRAY2BGR) +# out.write(frame_bgr) +# +# +# def test_motion2(real): +# fourcc = cv2.VideoWriter_fourcc(*"mp4v") +# out = cv2.VideoWriter(config.video_dir / "mc_lockon.avi", fourcc, 24.0, (752, 960)) +# anchor = Anchor() +# +# for i, arr in enumerate(real): +# frame = package_frame(arr, i) +# frame = blur(frame, method="median", kwargs={"ksize": 3}) +# frame = butter(frame, kwargs={}) +# pre_mc = remove_mean(frame, orient="both") +# frame = anchor.stabilize(pre_mc) +# +# frame_bgr = cv2.cvtColor( +# np.concat([pre_mc.array.astype(np.uint8), frame.array.astype(np.uint8)]), +# cv2.COLOR_GRAY2BGR, +# ) +# out.write(frame_bgr) diff --git a/tests/test_prep/test_motion.py b/tests/test_prep/test_motion.py index a9e9d2e9..3dfeaeeb 100644 --- a/tests/test_prep/test_motion.py +++ b/tests/test_prep/test_motion.py @@ -2,7 +2,8 @@ import pytest from cala.models import AXIS -from cala.nodes.prep.motion import Shift, Stabilizer +from cala.nodes.prep import blur +from cala.nodes.prep.motion import Anchor, Shift from cala.testing.toy import FrameDims, Position, Toy @@ -10,25 +11,32 @@ "params", [ { - "drift_speed": 1, - "pcc_kwargs": {"upsample_factor": 100}, - "pcc_filter": "difference_of_gaussians", - "filter_kwargs": {"low_sigma": 1}, + "upsample_factor": 10, + "dog_kwargs": {"low_sigma": 3}, + "gauss_kwargs": {"ksize": (11, 11), "sigmaX": 20}, } ], ) def test_motion_estimation(params) -> None: - stab = Stabilizer(**params) + stab = Anchor(**params) n_frames = 50 toy = Toy( n_frames=n_frames, - frame_dims=FrameDims(width=50, height=50), - cell_radii=3, - cell_positions=[Position(width=15, height=15), Position(width=35, height=35)], - cell_traces=[np.array(range(n_frames)), np.array(range(n_frames, 0, -1))], + frame_dims=FrameDims(width=100, height=100), + cell_radii=6, + cell_positions=[ + Position(width=30, height=30), + Position(width=50, height=60), + Position(width=70, height=70), + ], + cell_traces=[ + np.array(range(n_frames)), + np.array(range(n_frames, 0, -1)), + np.array([25] * n_frames), + ], emit_frames=True, ) @@ -50,13 +58,14 @@ def test_motion_estimation(params) -> None: width=shifts[0].width - shift.width, height=shifts[0].height - shift.height ) } + frame = blur(frame, method="gaussian", kwargs={"ksize": (5, 5), "sigmaX": 2}) result.append(stab.stabilize(frame)) - estimate = -np.array([(m.width, m.height) for m in stab.motions_]) + estimate = -np.array([(m.width, m.height) for m in stab._history]) expected = np.array([(m.width, m.height) for m in shifts]) - (shifts[0].width, shifts[0].height) # Allow 1 pixel absolute tolerance - np.testing.assert_allclose(estimate, expected, atol=1.0) + np.testing.assert_allclose(estimate, expected[1:], atol=1.0) def test_rigid_translator_preserves_neuron_traces(): ...