diff --git a/src/cala/assets.py b/src/cala/assets.py index 077b4063..a9e751e0 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -1,15 +1,17 @@ +import contextlib +import shutil from copy import deepcopy from pathlib import Path -from typing import ClassVar, TypeVar +from typing import Any, ClassVar, TypeVar import xarray as xr -from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator +from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator from cala.models.axis import AXIS, Coords, Dims from cala.models.checks import has_no_nan, is_non_negative from cala.models.entity import Entity, Group -AssetType = TypeVar("AssetType", xr.DataArray, None) +AssetType = TypeVar("AssetType", xr.DataArray, Path, None) class Asset(BaseModel): @@ -30,6 +32,9 @@ def array(self, value: xr.DataArray) -> None: def from_array(cls, array: xr.DataArray) -> "Asset": return cls(array_=array) + def reset(self) -> None: + self.array_ = None + def __eq__(self, other: "Asset") -> bool: return self.array.equals(other.array) @@ -103,13 +108,23 @@ class Traces(Asset): @property def array(self) -> xr.DataArray: if self.zarr_path: - return ( - xr.open_zarr(self.zarr_path) - .isel({AXIS.frames_dim: slice(-self.peek_size, None)}) - .to_dataarray() - .isel({"variable": 0}) # not sure why it automatically makes this coordinate - .reset_coords("variable", drop=True) - ) + try: + da = ( + xr.open_zarr(self.zarr_path) + .isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + .to_dataarray() + .drop_vars(["variable"]) + .isel(variable=0) + ) + return da.assign_coords( + { + AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), + AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), + } + ).compute() + + except FileNotFoundError: + return self.array_ else: return self.array_ @@ -120,14 +135,32 @@ def array(self, array: xr.DataArray) -> None: else: self.array_ = array - def append(self, array: xr.DataArray, dim: str | list[str]) -> None: - array.to_zarr(self.zarr_path, append_dim=dim) + def update(self, array: xr.DataArray, **kwargs: Any) -> None: + self.validate_array_schema(array) + array.to_zarr(self.zarr_path, **kwargs) + + def reset(self) -> None: + self.array_ = None + if self.zarr_path: + path = Path(self.zarr_path) + try: + shutil.rmtree(path) + except FileNotFoundError: + contextlib.suppress(FileNotFoundError) @classmethod def from_array( cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None ) -> "Traces": - return cls(array_=array, zarr_path=zarr_path, peek_size=peek_size) + new_cls = cls(zarr_path=zarr_path, peek_size=peek_size) + new_cls.array = array + return new_cls + + @model_validator(mode="after") + def check_zarr_setting(self) -> "Traces": + if self.zarr_path: + assert self.peek_size, "peek_size must be set for zarr." + return self _entity: ClassVar[Entity] = PrivateAttr( Group( diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 8fca0add..d6f9e5bf 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -96,11 +96,11 @@ def _filter_components( A[Overlaps, Name("overlaps")], ]: if len(keep_ids) == 0 or footprints.array is None: - footprints.array = None - traces.array = None - pix_stats.array = None - comp_stats.array = None - overlaps.array = None + footprints.reset() + traces.reset() + pix_stats.reset() + comp_stats.reset() + overlaps.reset() elif footprints.array[AXIS.id_coord].values.tolist() != keep_ids: footprints.array = ( diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index f8f8661d..6837b571 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,7 +1,19 @@ from .background_removal import remove_background -from .denoise import denoise +from .denoise import Restore, blur +from .flatten import butter from .glow_removal import GlowRemover +from .lines import remove_freq, remove_mean from .motion import Stabilizer from .r_estimate import SizeEst -__all__ = [denoise, GlowRemover, remove_background, Stabilizer, SizeEst] +__all__ = [ + "blur", + "GlowRemover", + "remove_background", + "Stabilizer", + "SizeEst", + "butter", + "remove_mean", + "remove_freq", + "Restore", +] diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index 2efe2f33..6c1b985d 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -1,24 +1,56 @@ from collections.abc import Callable +from functools import partial from typing import Annotated as A from typing import Any, Literal import cv2 import numpy as np import xarray as xr -from noob import Name +from noob import Name, process_method +from pydantic import BaseModel +from skimage.restoration import calibrate_denoiser from cala.assets import Frame -def denoise( - frame: Frame, method: Literal["gaussian", "median", "bilateral"], kwargs: dict[str, Any] +def _bilateral(arr: np.ndarray, **kwargs: Any) -> np.ndarray: + arr = arr.astype(np.float32) + return cv2.bilateralFilter(arr, **kwargs) + + +class Restore(BaseModel): + kwargs: dict[str, Any] | None = None + model: Callable = None + + @process_method + def denoise(self, frame: Frame) -> A[Frame, Name("frame")]: + arr = frame.array + if self.model is None: + if not self.kwargs: + param_matrix = { + "d": list(range(1, 20)), + "sigmaColor": [10, 50, 100, 200, 250], + "sigmaSpace": [10, 50, 100, 200, 250], + } + self.model = calibrate_denoiser(arr, _bilateral, param_matrix) + else: + self.model = partial(cv2.bilateralFilter, **self.kwargs) + + denoised = self.model(arr) + return Frame.from_array(xr.DataArray(denoised, dims=arr.dims, coords=arr.coords)) + + +def blur( + frame: Frame, + method: Literal["gaussian", "median", "bilateral", "nonlocal"], + kwargs: dict[str, Any], ) -> A[Frame, Name("frame")]: """Denoise a single frame.""" methods: dict[str, Callable] = { "gaussian": cv2.GaussianBlur, "median": cv2.medianBlur, "bilateral": cv2.bilateralFilter, - "nonlocal": cv2.fastNlMeansDenoising, + "nonlocal": cv2.fastNlMeansDenoising, # really slow. ~40 ms. } _func = methods[method] diff --git a/src/cala/nodes/prep/flatten.py b/src/cala/nodes/prep/flatten.py new file mode 100644 index 00000000..011f8aab --- /dev/null +++ b/src/cala/nodes/prep/flatten.py @@ -0,0 +1,32 @@ +from typing import Annotated as A +from typing import Any + +import xarray as xr +from noob import Name +from skimage.filters import butterworth +from skimage.restoration import rolling_ball + +from cala.assets import Frame + + +def butter(frame: Frame, kwargs: dict[str, Any]) -> A[Frame, Name("frame")]: + """ + butterworth filter centers the image to zero. this causes two images with same intensity ratio + across pixels to be indistinguishable. + To recover the absolute brightness, we shift the filtered image by the + mean brightness of the original frame. + """ + arr = butterworth(frame.array, **kwargs) + frame.array.mean().item() + + return Frame.from_array(xr.DataArray(arr, dims=frame.array.dims, coords=frame.array.coords)) + + +def ball(frame: Frame, kwargs: dict[str, Any]) -> Frame: + """ + takes a VERY long time. also not as good as butterworth at handling clustered cells (all bright + region) + """ + bg = rolling_ball(frame.array, **kwargs) + frame.array -= bg + + return frame diff --git a/src/cala/nodes/prep/glow_removal.py b/src/cala/nodes/prep/glow_removal.py index 87fadf74..c5cd1d8d 100644 --- a/src/cala/nodes/prep/glow_removal.py +++ b/src/cala/nodes/prep/glow_removal.py @@ -20,19 +20,6 @@ def process(self, frame: Frame) -> A[Frame, Name("frame")]: self.base_brightness_ = np.minimum(frame.values, self.base_brightness_) self._learn_count += 1 - return Frame.from_array( - xr.DataArray(frame - self.base_brightness_, dims=frame.dims, coords=frame.coords) - ) - - def get_info(self) -> dict: - """Get information about the current state. - - Returns - ------- - dict - Dictionary containing current statistics - """ - return { - "base_brightness_": self.base_brightness_, - "learn_count": self._learn_count, - } + shifted = (frame - self.base_brightness_).values + + return Frame.from_array(xr.DataArray(shifted, dims=frame.dims, coords=frame.coords)) diff --git a/src/cala/nodes/prep/hlines.py b/src/cala/nodes/prep/lines.py similarity index 55% rename from src/cala/nodes/prep/hlines.py rename to src/cala/nodes/prep/lines.py index 2d74dfcc..232a2115 100644 --- a/src/cala/nodes/prep/hlines.py +++ b/src/cala/nodes/prep/lines.py @@ -1,4 +1,5 @@ from typing import Annotated as A +from typing import Any, Literal import numpy as np from noob import Name @@ -6,19 +7,48 @@ from scipy.signal import firwin, welch from cala.assets import Frame +from cala.models import AXIS -def remove( - frame: Frame, distortion_freq: float | None = None, num_taps: int = 65, eps: float = 0.025 +def remove_mean(frame: Frame, orient: Literal["horiz", "vert", "both"]) -> A[Frame, Name("frame")]: + arr = frame.array + + if orient == "horiz": + denoised = arr - arr.mean(dim=AXIS.width_dim) + elif orient == "vert": + denoised = arr - arr.mean(dim=AXIS.height_dim) + elif orient == "both": + horiz_dn = arr - arr.mean(dim=AXIS.width_dim) + denoised = horiz_dn - horiz_dn.mean(dim=AXIS.height_dim) + else: + raise ValueError(f"Unknown orientation {orient}") + + # diff should be frame.mean - denoised.mean, but denoised.mean is always 0 by definition + diff = frame.array.mean() + + return Frame.from_array(denoised + diff) + + +def remove_freq( + frame: Frame, + orient: Literal["horiz", "vert", "both"], + kwargs: dict[str, Any] | None = None, ) -> A[Frame, Name("frame")]: + if kwargs is None: + kwargs = {} + arr = frame.array if np.all(frame.array == 0): return frame - denoised = _remove_lines( - arr.values, distortion_freq=distortion_freq, num_taps=num_taps, eps=eps - ) + if orient == "horiz": + denoised = _remove_lines(arr.values, **kwargs) + elif orient == "vert": + denoised = _remove_lines(arr.values.T, **kwargs).T + elif orient == "both": + horiz_dn = _remove_lines(arr.values, **kwargs) + denoised = _remove_lines(horiz_dn.T, **kwargs).T dmin = denoised.min() if dmin < 0: diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index 8903143f..2cdd9d42 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -8,7 +8,7 @@ import xarray as xr from noob import Name, process_method from pydantic import BaseModel, ConfigDict, Field -from skimage.filters import butterworth, difference_of_gaussians, sato, scharr +from skimage.filters import difference_of_gaussians from skimage.registration import phase_cross_correlation from cala.assets import Frame @@ -124,22 +124,14 @@ def _compute_shift(self, curr_frame: xr.DataArray) -> Shift: if: abs(sequential_shift - anchor_shift) < drift_speed then: true_shift = anchor_shift """ - filters = { - "butterworth": butterworth, - "difference_of_gaussians": difference_of_gaussians, - "sato": sato, - "scharr": scharr, - } - filt_fn = filters[self.pcc_filter] - - curr = filt_fn(curr_frame, **self.filter_kwargs) - prev = filt_fn(self.previous_frame_, **self.filter_kwargs) - anchor = filt_fn(self.anchor_frame_, **self.filter_kwargs) + curr = difference_of_gaussians(curr_frame, **self.filter_kwargs) + prev = difference_of_gaussians(self.previous_frame_, **self.filter_kwargs) + anchor = difference_of_gaussians(self.anchor_frame_, **self.filter_kwargs) anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs) sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs) - shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) + shift_diff = np.linalg.norm(anchor_shift - sequent_shift) frame_idx = curr_frame[AXIS.frame_coord].item() drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index b56f3ba7..4f3a8a14 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -134,10 +134,20 @@ def ingest_frame( csgraph=overlaps.array.data, directed=False, return_labels=True ) clusters = [np.where(labels == label)[0] for label in np.unique(labels)] - - updated_traces = self._update_traces(A, y, c.copy(), clusters) - - traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) + updated_traces = self._update_traces(A, y, c, clusters) + + if traces.zarr_path: + updated_tr = updated_traces.expand_dims(AXIS.frames_dim).assign_coords( + { + AXIS.timestamp_coord: ( + AXIS.frames_dim, + [updated_traces[AXIS.timestamp_coord].values], + ) + } + ) + traces.update(updated_tr, append_dim=AXIS.frames_dim) + else: + traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) return PopSnap.from_array(updated_traces) @@ -187,6 +197,7 @@ def _update_traces( # Steps 6-8: Update each group using block coordinate descent for cluster in clusters: # Update traces for current group (division is pointwise) + numerator = u.isel({AXIS.component_dim: cluster}) - ( V.isel({f"{AXIS.component_dim}'": cluster}) @ c ).rename({f"{AXIS.component_dim}'": AXIS.component_dim}) @@ -241,8 +252,22 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: merged_ids = c_det.attrs.get("replaces") if merged_ids: intact_ids = [id_ for id_ in c[AXIS.id_coord].values if id_ not in merged_ids] - c = c.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: intact_ids}).reset_index(AXIS.id_coord) - - traces.array = xr.concat([c, c_new], dim=AXIS.component_dim, combine_attrs="drop") + if traces.zarr_path: + traces.array = ( + traces.array.set_xindex(AXIS.id_coord) + .sel({AXIS.id_coord: intact_ids}) + .reset_index(AXIS.id_coord) + ) + else: + c = ( + c.set_xindex(AXIS.id_coord) + .sel({AXIS.id_coord: intact_ids}) + .reset_index(AXIS.id_coord) + ) + + if traces.zarr_path: + traces.update(c_new, append_dim=AXIS.component_dim) + else: + traces.array = xr.concat([c, c_new], dim=AXIS.component_dim, combine_attrs="drop") return traces diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 569a2d51..d3dc1cb2 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -15,6 +15,9 @@ assets: scope: session traces: type: cala.assets.Traces + params: + zarr_path: data/traces + peek_size: 100 scope: session pix_stats: type: cala.assets.PixStats @@ -33,7 +36,7 @@ nodes: source: type: cala.testing.SingleCellSource denoise: - type: cala.nodes.prep.denoise + type: cala.nodes.prep.blur params: method: gaussian kwargs: diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index f0082807..f440702c 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -43,50 +43,44 @@ nodes: - index: counter.idx #PREPROCESS BEGINS - saltpepper: - type: cala.nodes.prep.denoise + hotpix: # needs to happen first + type: cala.nodes.prep.blur params: method: median kwargs: ksize: 3 depends: - frame: frame.value - denoise: - type: cala.nodes.prep.denoise + flatten: + type: cala.nodes.prep.butter params: - method: nonlocal kwargs: - h: 4 - templateWindowSize: 7 - searchWindowSize: 21 + cutoff_frequency_ratio: 0.005 depends: - - frame: saltpepper.frame - lines: - type: cala.nodes.prep.hlines.remove + - frame: hotpix.frame + lines: # needs to happen after flatten + type: cala.nodes.prep.remove_freq + params: + orient: both depends: - - frame: denoise.frame + - frame: flatten.frame + denoise: # needs to happen after lines + type: cala.nodes.prep.Restore + depends: + - frame: lines.frame motion: # needs to take place before glow, after lines type: cala.nodes.prep.Stabilizer params: drift_speed: 0.5 pcc_filter: difference_of_gaussians filter_kwargs: - low_sigma: 1 + low_sigma: 4 depends: - - frame: lines.frame + - frame: denoise.frame glow: type: cala.nodes.prep.GlowRemover depends: - frame: motion.frame - smooth: - type: cala.nodes.prep.denoise - params: - method: gaussian - kwargs: - ksize: [ 7, 7 ] - sigmaX: 1.5 - depends: - - frame: glow.frame size_est: type: cala.nodes.prep.SizeEst params: @@ -100,14 +94,14 @@ nodes: threshold: 0.2 overlap: 0.5 depends: - - frame: smooth.frame + - frame: glow.frame cache: type: cala.nodes.buffer.fill_buffer params: size: 100 depends: - buffer: assets.buffer - - frame: smooth.frame + - frame: motion.frame #PREPROCESS ENDS # FRAME UPDATE BEGINS @@ -119,19 +113,19 @@ nodes: depends: - traces: assets.traces - footprints: assets.footprints - - frame: smooth.frame + - frame: motion.frame - overlaps: assets.overlaps pix_frame: type: cala.nodes.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - - frame: smooth.frame + - frame: motion.frame - new_traces: trace_frame.latest_trace comp_frame: type: cala.nodes.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - - frame: smooth.frame + - frame: motion.frame - new_traces: trace_frame.latest_trace footprints_frame: type: cala.nodes.footprints.Footprinter @@ -226,4 +220,4 @@ nodes: type: return depends: - raw: frame.value - - prep: smooth.frame \ No newline at end of file + - prep: motion.frame \ No newline at end of file diff --git a/tests/test_assets.py b/tests/test_assets.py new file mode 100644 index 00000000..1ea62415 --- /dev/null +++ b/tests/test_assets.py @@ -0,0 +1,72 @@ +import os +from pathlib import Path + +import pytest + +from cala.assets import Traces +from cala.models import AXIS + + +@pytest.fixture +def path(tmp_path: str) -> Path: + return Path(tmp_path) / "assets" + + +def test_assign_zarr(path, connected_cells): + zarr_traces = Traces(zarr_path=path, peek_size=100) + traces = connected_cells.traces.array + zarr_traces.array = traces + print(os.listdir(zarr_traces.zarr_path)) + assert zarr_traces.array_ is None # not in memory + assert zarr_traces.array.equals(traces) + + +def test_from_array(connected_cells, path): + traces = connected_cells.traces.array + zarr_traces = Traces.from_array(traces, path, peek_size=connected_cells.n_frames) + assert zarr_traces.array_ is None + assert zarr_traces.array.equals(traces) + + +@pytest.mark.parametrize("peek_shift", [-1, 0, 1]) +def test_peek(connected_cells, path, peek_shift): + traces = connected_cells.traces.array + zarr_traces = Traces.from_array(traces, path, peek_size=connected_cells.n_frames + peek_shift) + if peek_shift >= 0: + assert zarr_traces.array.equals(traces) + else: + with pytest.raises(AssertionError): + assert zarr_traces.array.equals(traces) + + +def test_ingest_frame(path, connected_cells): + traces = connected_cells.traces.array + old_traces = traces.isel({AXIS.frames_dim: slice(None, -1)}) + zarr_traces = Traces.from_array(old_traces, path, peek_size=connected_cells.n_frames) + new_traces = connected_cells.traces.array.isel({AXIS.frames_dim: [-1]}) + + zarr_traces.update(new_traces, append_dim=AXIS.frames_dim) + # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.frames_dim) + + assert zarr_traces.array.equals(traces) + + +def test_ingest_component(connected_cells, path): + traces = connected_cells.traces.array + old_traces = traces.isel({AXIS.component_dim: slice(None, -1)}) + zarr_traces = Traces.from_array(old_traces, path, peek_size=connected_cells.n_frames) + new_traces = connected_cells.traces.array.isel({AXIS.component_dim: [-1]}) + + zarr_traces.update(new_traces, append_dim=AXIS.component_dim) + # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.component_dim) + + assert zarr_traces.array.equals(traces) + + +def test_overwrite(connected_cells, separate_cells, path): + conn_traces = connected_cells.traces.array + zarr_traces = Traces.from_array(conn_traces, path, peek_size=connected_cells.n_frames) + + sep_traces = separate_cells.traces.array + zarr_traces.array = sep_traces + assert zarr_traces.array.equals(sep_traces) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 66ebfc9e..b206ca1f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytest import xarray as xr @@ -31,8 +33,11 @@ def tube(source): @pytest.fixture -def cube(): - return Cube.from_specification("cala-odl") +def cube(tmp_path: Path): + cube = Cube.from_specification("cala-odl") + cube.assets["traces"].spec.params["zarr_path"] = tmp_path / "traces" + cube.assets["traces"].params["zarr_path"] = tmp_path / "traces" + return cube @pytest.fixture diff --git a/tests/test_prep/test_denoise.py b/tests/test_prep/test_denoise.py index 87a7c503..8f745190 100644 --- a/tests/test_prep/test_denoise.py +++ b/tests/test_prep/test_denoise.py @@ -7,7 +7,7 @@ import xarray as xr from cala.models import AXIS -from cala.nodes.prep.denoise import denoise +from cala.nodes.prep.denoise import blur from cala.testing.toy import FrameDims, Position, Toy @@ -46,7 +46,7 @@ def test_denoise( results = [] for frame in iter(gen): - results.append(denoise(frame=frame, method=method, kwargs=params)) + results.append(blur(frame=frame, method=method, kwargs=params)) for exp, res in zip(expected, results): np.testing.assert_allclose(exp.values, res.array.values) diff --git a/tests/test_prep/test_hlines.py b/tests/test_prep/test_lines.py similarity index 85% rename from tests/test_prep/test_hlines.py rename to tests/test_prep/test_lines.py index eacb7aa8..32c04d28 100644 --- a/tests/test_prep/test_hlines.py +++ b/tests/test_prep/test_lines.py @@ -1,7 +1,7 @@ import numpy as np from skimage.metrics import structural_similarity -from cala.nodes.prep.hlines import remove +from cala.nodes.prep.lines import remove_freq from cala.testing.util import generate_text_image from cala.util import package_frame @@ -18,6 +18,6 @@ def test_remove_lines(): frame = package_frame(noisy_img, 0) - result = remove(frame) + result = remove_freq(frame, orient="horiz") assert structural_similarity(img.astype(int), result.array.values.astype(int)) == 1