Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions src/cala/assets.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import contextlib
import shutil
from copy import deepcopy
from pathlib import Path
from typing import ClassVar, TypeVar
from typing import Any, ClassVar, TypeVar

import xarray as xr
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator

from cala.models.axis import AXIS, Coords, Dims
from cala.models.checks import has_no_nan, is_non_negative
from cala.models.entity import Entity, Group

AssetType = TypeVar("AssetType", xr.DataArray, None)
AssetType = TypeVar("AssetType", xr.DataArray, Path, None)


class Asset(BaseModel):
Expand All @@ -30,6 +32,9 @@ def array(self, value: xr.DataArray) -> None:
def from_array(cls, array: xr.DataArray) -> "Asset":
return cls(array_=array)

def reset(self) -> None:
self.array_ = None

def __eq__(self, other: "Asset") -> bool:
return self.array.equals(other.array)

Expand Down Expand Up @@ -103,13 +108,23 @@ class Traces(Asset):
@property
def array(self) -> xr.DataArray:
if self.zarr_path:
return (
xr.open_zarr(self.zarr_path)
.isel({AXIS.frames_dim: slice(-self.peek_size, None)})
.to_dataarray()
.isel({"variable": 0}) # not sure why it automatically makes this coordinate
.reset_coords("variable", drop=True)
)
try:
da = (
xr.open_zarr(self.zarr_path)
.isel({AXIS.frames_dim: slice(-self.peek_size, None)})
.to_dataarray()
.drop_vars(["variable"])
.isel(variable=0)
)
return da.assign_coords(
{
AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str),
AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str),
}
).compute()

except FileNotFoundError:
return self.array_
else:
return self.array_

Expand All @@ -120,14 +135,32 @@ def array(self, array: xr.DataArray) -> None:
else:
self.array_ = array

def append(self, array: xr.DataArray, dim: str | list[str]) -> None:
array.to_zarr(self.zarr_path, append_dim=dim)
def update(self, array: xr.DataArray, **kwargs: Any) -> None:
self.validate_array_schema(array)
array.to_zarr(self.zarr_path, **kwargs)

def reset(self) -> None:
self.array_ = None
if self.zarr_path:
path = Path(self.zarr_path)
try:
shutil.rmtree(path)
except FileNotFoundError:
contextlib.suppress(FileNotFoundError)

@classmethod
def from_array(
cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None
) -> "Traces":
return cls(array_=array, zarr_path=zarr_path, peek_size=peek_size)
new_cls = cls(zarr_path=zarr_path, peek_size=peek_size)
new_cls.array = array
return new_cls

@model_validator(mode="after")
def check_zarr_setting(self) -> "Traces":
if self.zarr_path:
assert self.peek_size, "peek_size must be set for zarr."
return self

_entity: ClassVar[Entity] = PrivateAttr(
Group(
Expand Down
10 changes: 5 additions & 5 deletions src/cala/nodes/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def _filter_components(
A[Overlaps, Name("overlaps")],
]:
if len(keep_ids) == 0 or footprints.array is None:
footprints.array = None
traces.array = None
pix_stats.array = None
comp_stats.array = None
overlaps.array = None
footprints.reset()
traces.reset()
pix_stats.reset()
comp_stats.reset()
overlaps.reset()

elif footprints.array[AXIS.id_coord].values.tolist() != keep_ids:
footprints.array = (
Expand Down
16 changes: 14 additions & 2 deletions src/cala/nodes/prep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from .background_removal import remove_background
from .denoise import denoise
from .denoise import Restore, blur
from .flatten import butter
from .glow_removal import GlowRemover
from .lines import remove_freq, remove_mean
from .motion import Stabilizer
from .r_estimate import SizeEst

__all__ = [denoise, GlowRemover, remove_background, Stabilizer, SizeEst]
__all__ = [
"blur",
"GlowRemover",
"remove_background",
"Stabilizer",
"SizeEst",
"butter",
"remove_mean",
"remove_freq",
"Restore",
]
40 changes: 36 additions & 4 deletions src/cala/nodes/prep/denoise.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,56 @@
from collections.abc import Callable
from functools import partial
from typing import Annotated as A
from typing import Any, Literal

import cv2
import numpy as np
import xarray as xr
from noob import Name
from noob import Name, process_method
from pydantic import BaseModel
from skimage.restoration import calibrate_denoiser

from cala.assets import Frame


def denoise(
frame: Frame, method: Literal["gaussian", "median", "bilateral"], kwargs: dict[str, Any]
def _bilateral(arr: np.ndarray, **kwargs: Any) -> np.ndarray:
arr = arr.astype(np.float32)
return cv2.bilateralFilter(arr, **kwargs)


class Restore(BaseModel):
kwargs: dict[str, Any] | None = None
model: Callable = None

@process_method
def denoise(self, frame: Frame) -> A[Frame, Name("frame")]:
arr = frame.array
if self.model is None:
if not self.kwargs:
param_matrix = {
"d": list(range(1, 20)),
"sigmaColor": [10, 50, 100, 200, 250],
"sigmaSpace": [10, 50, 100, 200, 250],
}
self.model = calibrate_denoiser(arr, _bilateral, param_matrix)
else:
self.model = partial(cv2.bilateralFilter, **self.kwargs)

denoised = self.model(arr)
return Frame.from_array(xr.DataArray(denoised, dims=arr.dims, coords=arr.coords))


def blur(
frame: Frame,
method: Literal["gaussian", "median", "bilateral", "nonlocal"],
kwargs: dict[str, Any],
) -> A[Frame, Name("frame")]:
"""Denoise a single frame."""
methods: dict[str, Callable] = {
"gaussian": cv2.GaussianBlur,
"median": cv2.medianBlur,
"bilateral": cv2.bilateralFilter,
"nonlocal": cv2.fastNlMeansDenoising,
"nonlocal": cv2.fastNlMeansDenoising, # really slow. ~40 ms.
}

_func = methods[method]
Expand Down
32 changes: 32 additions & 0 deletions src/cala/nodes/prep/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Annotated as A
from typing import Any

import xarray as xr
from noob import Name
from skimage.filters import butterworth
from skimage.restoration import rolling_ball

from cala.assets import Frame


def butter(frame: Frame, kwargs: dict[str, Any]) -> A[Frame, Name("frame")]:
"""
butterworth filter centers the image to zero. this causes two images with same intensity ratio
across pixels to be indistinguishable.
To recover the absolute brightness, we shift the filtered image by the
mean brightness of the original frame.
"""
arr = butterworth(frame.array, **kwargs) + frame.array.mean().item()

return Frame.from_array(xr.DataArray(arr, dims=frame.array.dims, coords=frame.array.coords))


def ball(frame: Frame, kwargs: dict[str, Any]) -> Frame:
"""
takes a VERY long time. also not as good as butterworth at handling clustered cells (all bright
region)
"""
bg = rolling_ball(frame.array, **kwargs)
frame.array -= bg

return frame
19 changes: 3 additions & 16 deletions src/cala/nodes/prep/glow_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,6 @@ def process(self, frame: Frame) -> A[Frame, Name("frame")]:
self.base_brightness_ = np.minimum(frame.values, self.base_brightness_)
self._learn_count += 1

return Frame.from_array(
xr.DataArray(frame - self.base_brightness_, dims=frame.dims, coords=frame.coords)
)

def get_info(self) -> dict:
"""Get information about the current state.

Returns
-------
dict
Dictionary containing current statistics
"""
return {
"base_brightness_": self.base_brightness_,
"learn_count": self._learn_count,
}
shifted = (frame - self.base_brightness_).values

return Frame.from_array(xr.DataArray(shifted, dims=frame.dims, coords=frame.coords))
40 changes: 35 additions & 5 deletions src/cala/nodes/prep/hlines.py β†’ src/cala/nodes/prep/lines.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
from typing import Annotated as A
from typing import Any, Literal

import numpy as np
from noob import Name
from scipy.ndimage import convolve1d
from scipy.signal import firwin, welch

from cala.assets import Frame
from cala.models import AXIS


def remove(
frame: Frame, distortion_freq: float | None = None, num_taps: int = 65, eps: float = 0.025
def remove_mean(frame: Frame, orient: Literal["horiz", "vert", "both"]) -> A[Frame, Name("frame")]:
arr = frame.array

if orient == "horiz":
denoised = arr - arr.mean(dim=AXIS.width_dim)
elif orient == "vert":
denoised = arr - arr.mean(dim=AXIS.height_dim)
elif orient == "both":
horiz_dn = arr - arr.mean(dim=AXIS.width_dim)
denoised = horiz_dn - horiz_dn.mean(dim=AXIS.height_dim)
else:
raise ValueError(f"Unknown orientation {orient}")

# diff should be frame.mean - denoised.mean, but denoised.mean is always 0 by definition
diff = frame.array.mean()

return Frame.from_array(denoised + diff)


def remove_freq(
frame: Frame,
orient: Literal["horiz", "vert", "both"],
kwargs: dict[str, Any] | None = None,
) -> A[Frame, Name("frame")]:
if kwargs is None:
kwargs = {}

arr = frame.array

if np.all(frame.array == 0):
return frame

denoised = _remove_lines(
arr.values, distortion_freq=distortion_freq, num_taps=num_taps, eps=eps
)
if orient == "horiz":
denoised = _remove_lines(arr.values, **kwargs)
elif orient == "vert":
denoised = _remove_lines(arr.values.T, **kwargs).T
elif orient == "both":
horiz_dn = _remove_lines(arr.values, **kwargs)
denoised = _remove_lines(horiz_dn.T, **kwargs).T

dmin = denoised.min()
if dmin < 0:
Expand Down
18 changes: 5 additions & 13 deletions src/cala/nodes/prep/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import xarray as xr
from noob import Name, process_method
from pydantic import BaseModel, ConfigDict, Field
from skimage.filters import butterworth, difference_of_gaussians, sato, scharr
from skimage.filters import difference_of_gaussians
from skimage.registration import phase_cross_correlation

from cala.assets import Frame
Expand Down Expand Up @@ -124,22 +124,14 @@ def _compute_shift(self, curr_frame: xr.DataArray) -> Shift:
if: abs(sequential_shift - anchor_shift) < drift_speed
then: true_shift = anchor_shift
"""
filters = {
"butterworth": butterworth,
"difference_of_gaussians": difference_of_gaussians,
"sato": sato,
"scharr": scharr,
}
filt_fn = filters[self.pcc_filter]

curr = filt_fn(curr_frame, **self.filter_kwargs)
prev = filt_fn(self.previous_frame_, **self.filter_kwargs)
anchor = filt_fn(self.anchor_frame_, **self.filter_kwargs)
curr = difference_of_gaussians(curr_frame, **self.filter_kwargs)
prev = difference_of_gaussians(self.previous_frame_, **self.filter_kwargs)
anchor = difference_of_gaussians(self.anchor_frame_, **self.filter_kwargs)

anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs)
sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs)

shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift))
shift_diff = np.linalg.norm(anchor_shift - sequent_shift)

frame_idx = curr_frame[AXIS.frame_coord].item()
drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed
Expand Down
Loading
Loading