From b8a68733b71651dafb2f3639809da3b09f403816 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 29 Aug 2025 16:07:46 -0700 Subject: [PATCH 01/17] feat: flatten frame with butterworth --- src/cala/nodes/prep/__init__.py | 12 +++++++++++- src/cala/nodes/prep/flatten.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 src/cala/nodes/prep/flatten.py diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index f8f8661d..92a910c8 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,7 +1,17 @@ from .background_removal import remove_background from .denoise import denoise +from .flatten import butter from .glow_removal import GlowRemover from .motion import Stabilizer from .r_estimate import SizeEst -__all__ = [denoise, GlowRemover, remove_background, Stabilizer, SizeEst] +__all__ = [ + "denoise", + "GlowRemover", + "remove_background", + "Stabilizer", + "SizeEst", + "butter", + "remove_mean", + "Restore", +] diff --git a/src/cala/nodes/prep/flatten.py b/src/cala/nodes/prep/flatten.py new file mode 100644 index 00000000..011f8aab --- /dev/null +++ b/src/cala/nodes/prep/flatten.py @@ -0,0 +1,32 @@ +from typing import Annotated as A +from typing import Any + +import xarray as xr +from noob import Name +from skimage.filters import butterworth +from skimage.restoration import rolling_ball + +from cala.assets import Frame + + +def butter(frame: Frame, kwargs: dict[str, Any]) -> A[Frame, Name("frame")]: + """ + butterworth filter centers the image to zero. this causes two images with same intensity ratio + across pixels to be indistinguishable. + To recover the absolute brightness, we shift the filtered image by the + mean brightness of the original frame. + """ + arr = butterworth(frame.array, **kwargs) + frame.array.mean().item() + + return Frame.from_array(xr.DataArray(arr, dims=frame.array.dims, coords=frame.array.coords)) + + +def ball(frame: Frame, kwargs: dict[str, Any]) -> Frame: + """ + takes a VERY long time. also not as good as butterworth at handling clustered cells (all bright + region) + """ + bg = rolling_ball(frame.array, **kwargs) + frame.array -= bg + + return frame From cafefbf036e92fe9be2853abc32a6017c6a5f9cc Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 29 Aug 2025 16:08:32 -0700 Subject: [PATCH 02/17] feat: remove getinfo from glow --- src/cala/nodes/prep/glow_removal.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/src/cala/nodes/prep/glow_removal.py b/src/cala/nodes/prep/glow_removal.py index 87fadf74..c5cd1d8d 100644 --- a/src/cala/nodes/prep/glow_removal.py +++ b/src/cala/nodes/prep/glow_removal.py @@ -20,19 +20,6 @@ def process(self, frame: Frame) -> A[Frame, Name("frame")]: self.base_brightness_ = np.minimum(frame.values, self.base_brightness_) self._learn_count += 1 - return Frame.from_array( - xr.DataArray(frame - self.base_brightness_, dims=frame.dims, coords=frame.coords) - ) - - def get_info(self) -> dict: - """Get information about the current state. - - Returns - ------- - dict - Dictionary containing current statistics - """ - return { - "base_brightness_": self.base_brightness_, - "learn_count": self._learn_count, - } + shifted = (frame - self.base_brightness_).values + + return Frame.from_array(xr.DataArray(shifted, dims=frame.dims, coords=frame.coords)) From dd2f06b0c75346524347b0791e81edd82a55732f Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 29 Aug 2025 16:08:52 -0700 Subject: [PATCH 03/17] feat: line removal with mean --- src/cala/nodes/prep/{hlines.py => lines.py} | 24 +++++++++++++++++++-- tests/test_prep/test_hlines.py | 4 ++-- 2 files changed, 24 insertions(+), 4 deletions(-) rename src/cala/nodes/prep/{hlines.py => lines.py} (69%) diff --git a/src/cala/nodes/prep/hlines.py b/src/cala/nodes/prep/lines.py similarity index 69% rename from src/cala/nodes/prep/hlines.py rename to src/cala/nodes/prep/lines.py index 2d74dfcc..02f9c285 100644 --- a/src/cala/nodes/prep/hlines.py +++ b/src/cala/nodes/prep/lines.py @@ -1,4 +1,4 @@ -from typing import Annotated as A +from typing import Annotated as A, Literal import numpy as np from noob import Name @@ -6,9 +6,29 @@ from scipy.signal import firwin, welch from cala.assets import Frame +from cala.models import AXIS -def remove( +def remove_mean(frame: Frame, orient: Literal["horiz", "vert", "both"]) -> A[Frame, Name("frame")]: + arr = frame.array + + if orient == "horiz": + denoised = arr - arr.mean(dim=AXIS.width_dim) + elif orient == "vert": + denoised = arr - arr.mean(dim=AXIS.height_dim) + elif orient == "both": + horiz_dn = arr - arr.mean(dim=AXIS.width_dim) + denoised = horiz_dn - horiz_dn.mean(dim=AXIS.height_dim) + else: + raise ValueError(f"Unknown orientation {orient}") + + # diff should be frame.mean - denoised.mean, but denoised.mean is always 0 by definition + diff = frame.array.mean() + + return Frame.from_array(denoised + diff) + + +def remove_freq( frame: Frame, distortion_freq: float | None = None, num_taps: int = 65, eps: float = 0.025 ) -> A[Frame, Name("frame")]: arr = frame.array diff --git a/tests/test_prep/test_hlines.py b/tests/test_prep/test_hlines.py index eacb7aa8..4ac90284 100644 --- a/tests/test_prep/test_hlines.py +++ b/tests/test_prep/test_hlines.py @@ -1,7 +1,7 @@ import numpy as np from skimage.metrics import structural_similarity -from cala.nodes.prep.hlines import remove +from cala.nodes.prep.lines import remove_freq from cala.testing.util import generate_text_image from cala.util import package_frame @@ -18,6 +18,6 @@ def test_remove_lines(): frame = package_frame(noisy_img, 0) - result = remove(frame) + result = remove_freq(frame) assert structural_similarity(img.astype(int), result.array.values.astype(int)) == 1 From 3f911395a016c33bc0e9c0fc7580944bfdf3a141 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 29 Aug 2025 16:09:20 -0700 Subject: [PATCH 04/17] feat: auto denoising calibration with first frame --- src/cala/nodes/prep/denoise.py | 39 ++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index 2efe2f33..a6568c13 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -1,24 +1,55 @@ from collections.abc import Callable +from functools import partial from typing import Annotated as A from typing import Any, Literal import cv2 import numpy as np import xarray as xr -from noob import Name - +from noob import Name, process_method +from pydantic import BaseModel +from skimage.restoration import calibrate_denoiser from cala.assets import Frame +def _bilateral(arr, **kwargs) -> np.ndarray: + arr = arr.astype(np.float32) + return cv2.bilateralFilter(arr, **kwargs) + + +class Restore(BaseModel): + kwargs: dict[str, Any] | None = None + model: Callable = None + + @process_method + def denoise(self, frame: Frame) -> A[Frame, Name("frame")]: + arr = frame.array + if self.model is None: + if not self.kwargs: + param_matrix = { + "d": list(range(1, 20)), + "sigmaColor": [10, 50, 100, 200, 250], + "sigmaSpace": [10, 50, 100, 200, 250], + } + self.model = calibrate_denoiser(arr, _bilateral, param_matrix) + else: + self.model = partial(cv2.bilateralFilter, **self.kwargs) + + denoised = self.model(arr) + return Frame.from_array(xr.DataArray(denoised, dims=arr.dims, coords=arr.coords)) + + def denoise( - frame: Frame, method: Literal["gaussian", "median", "bilateral"], kwargs: dict[str, Any] + frame: Frame, + method: Literal["gaussian", "median", "bilateral", "nonlocal"], + kwargs: dict[str, Any], ) -> A[Frame, Name("frame")]: """Denoise a single frame.""" methods: dict[str, Callable] = { "gaussian": cv2.GaussianBlur, "median": cv2.medianBlur, "bilateral": cv2.bilateralFilter, - "nonlocal": cv2.fastNlMeansDenoising, + "nonlocal": cv2.fastNlMeansDenoising, # really slow. ~40 ms. } _func = methods[method] From 289cddc21a3e2050b97812b2ed33446e64e630f9 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 29 Aug 2025 16:10:22 -0700 Subject: [PATCH 05/17] feat: yaml reflects new modules --- src/cala/nodes/prep/__init__.py | 3 +- tests/data/pipelines/with_src.yaml | 50 +++++++++++++----------------- 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 92a910c8..601c5bbc 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,7 +1,8 @@ from .background_removal import remove_background -from .denoise import denoise +from .denoise import denoise, Restore from .flatten import butter from .glow_removal import GlowRemover +from .lines import remove_mean from .motion import Stabilizer from .r_estimate import SizeEst diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index f0082807..5e0f7d17 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -43,7 +43,7 @@ nodes: - index: counter.idx #PREPROCESS BEGINS - saltpepper: + hotpix: # needs to happen first type: cala.nodes.prep.denoise params: method: median @@ -51,42 +51,36 @@ nodes: ksize: 3 depends: - frame: frame.value - denoise: - type: cala.nodes.prep.denoise + flatten: + type: cala.nodes.prep.butter params: - method: nonlocal kwargs: - h: 4 - templateWindowSize: 7 - searchWindowSize: 21 + cutoff_frequency_ratio: 0.005 depends: - - frame: saltpepper.frame - lines: - type: cala.nodes.prep.hlines.remove + - frame: hotpix.frame + lines: # needs to happen after flatten + type: cala.nodes.prep.remove_mean + params: + orient: both depends: - - frame: denoise.frame + - frame: flatten.frame + denoise: # needs to happen after lines + type: cala.nodes.prep.Restore + depends: + - frame: lines.frame motion: # needs to take place before glow, after lines type: cala.nodes.prep.Stabilizer params: drift_speed: 0.5 pcc_filter: difference_of_gaussians filter_kwargs: - low_sigma: 1 + low_sigma: 4 depends: - - frame: lines.frame + - frame: denoise.frame glow: type: cala.nodes.prep.GlowRemover depends: - frame: motion.frame - smooth: - type: cala.nodes.prep.denoise - params: - method: gaussian - kwargs: - ksize: [ 7, 7 ] - sigmaX: 1.5 - depends: - - frame: glow.frame size_est: type: cala.nodes.prep.SizeEst params: @@ -100,14 +94,14 @@ nodes: threshold: 0.2 overlap: 0.5 depends: - - frame: smooth.frame + - frame: glow.frame cache: type: cala.nodes.buffer.fill_buffer params: size: 100 depends: - buffer: assets.buffer - - frame: smooth.frame + - frame: motion.frame #PREPROCESS ENDS # FRAME UPDATE BEGINS @@ -119,19 +113,19 @@ nodes: depends: - traces: assets.traces - footprints: assets.footprints - - frame: smooth.frame + - frame: motion.frame - overlaps: assets.overlaps pix_frame: type: cala.nodes.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - - frame: smooth.frame + - frame: motion.frame - new_traces: trace_frame.latest_trace comp_frame: type: cala.nodes.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - - frame: smooth.frame + - frame: motion.frame - new_traces: trace_frame.latest_trace footprints_frame: type: cala.nodes.footprints.Footprinter @@ -226,4 +220,4 @@ nodes: type: return depends: - raw: frame.value - - prep: smooth.frame \ No newline at end of file + - prep: motion.frame \ No newline at end of file From b0a0bf2491ad0b44ce6f42bdb17e18ad056a0e56 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 29 Aug 2025 16:24:29 -0700 Subject: [PATCH 06/17] feat: frobenius norm does not need abs --- src/cala/nodes/prep/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index 8903143f..f32ad813 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -139,7 +139,7 @@ def _compute_shift(self, curr_frame: xr.DataArray) -> Shift: anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs) sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs) - shift_diff = abs(np.linalg.norm(anchor_shift - sequent_shift)) + shift_diff = np.linalg.norm(anchor_shift - sequent_shift) frame_idx = curr_frame[AXIS.frame_coord].item() drift_threshold = (frame_idx - self._anchor_last_applied_on) * self.drift_speed From 3814569f0d03d81431f647d4e7a7de6853c7ca60 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 29 Aug 2025 16:58:39 -0700 Subject: [PATCH 07/17] feat: rename denoise func to blur --- src/cala/nodes/prep/__init__.py | 7 ++++--- src/cala/nodes/prep/denoise.py | 2 +- tests/data/pipelines/with_src.yaml | 2 +- tests/test_prep/test_denoise.py | 4 ++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 601c5bbc..6549903c 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,18 +1,19 @@ from .background_removal import remove_background -from .denoise import denoise, Restore +from .denoise import blur, Restore from .flatten import butter from .glow_removal import GlowRemover -from .lines import remove_mean +from .lines import remove_mean, remove_freq from .motion import Stabilizer from .r_estimate import SizeEst __all__ = [ - "denoise", + "blur", "GlowRemover", "remove_background", "Stabilizer", "SizeEst", "butter", "remove_mean", + "remove_freq", "Restore", ] diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index a6568c13..a73a3313 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -39,7 +39,7 @@ def denoise(self, frame: Frame) -> A[Frame, Name("frame")]: return Frame.from_array(xr.DataArray(denoised, dims=arr.dims, coords=arr.coords)) -def denoise( +def blur( frame: Frame, method: Literal["gaussian", "median", "bilateral", "nonlocal"], kwargs: dict[str, Any], diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 5e0f7d17..20f1527c 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -44,7 +44,7 @@ nodes: #PREPROCESS BEGINS hotpix: # needs to happen first - type: cala.nodes.prep.denoise + type: cala.nodes.prep.blur params: method: median kwargs: diff --git a/tests/test_prep/test_denoise.py b/tests/test_prep/test_denoise.py index 87a7c503..8f745190 100644 --- a/tests/test_prep/test_denoise.py +++ b/tests/test_prep/test_denoise.py @@ -7,7 +7,7 @@ import xarray as xr from cala.models import AXIS -from cala.nodes.prep.denoise import denoise +from cala.nodes.prep.denoise import blur from cala.testing.toy import FrameDims, Position, Toy @@ -46,7 +46,7 @@ def test_denoise( results = [] for frame in iter(gen): - results.append(denoise(frame=frame, method=method, kwargs=params)) + results.append(blur(frame=frame, method=method, kwargs=params)) for exp, res in zip(expected, results): np.testing.assert_allclose(exp.values, res.array.values) From b36a444c372ec9aedd7e2207e702766c39ee572a Mon Sep 17 00:00:00 2001 From: Raymond Date: Sat, 30 Aug 2025 04:03:48 -0700 Subject: [PATCH 08/17] feat: remove_mean does not work nearly as well. --- src/cala/nodes/prep/lines.py | 19 ++++++++++++++----- tests/data/pipelines/with_src.yaml | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/cala/nodes/prep/lines.py b/src/cala/nodes/prep/lines.py index 02f9c285..7f0ad795 100644 --- a/src/cala/nodes/prep/lines.py +++ b/src/cala/nodes/prep/lines.py @@ -1,4 +1,4 @@ -from typing import Annotated as A, Literal +from typing import Annotated as A, Literal, Any import numpy as np from noob import Name @@ -29,16 +29,25 @@ def remove_mean(frame: Frame, orient: Literal["horiz", "vert", "both"]) -> A[Fra def remove_freq( - frame: Frame, distortion_freq: float | None = None, num_taps: int = 65, eps: float = 0.025 + frame: Frame, + orient: Literal["horiz", "vert", "both"], + kwargs: dict[str, Any] | None = None, ) -> A[Frame, Name("frame")]: + if kwargs is None: + kwargs = {} + arr = frame.array if np.all(frame.array == 0): return frame - denoised = _remove_lines( - arr.values, distortion_freq=distortion_freq, num_taps=num_taps, eps=eps - ) + if orient == "horiz": + denoised = _remove_lines(arr.values, **kwargs) + elif orient == "vert": + denoised = _remove_lines(arr.values.T, **kwargs).T + elif orient == "both": + horiz_dn = _remove_lines(arr.values, **kwargs) + denoised = _remove_lines(horiz_dn.T, **kwargs).T dmin = denoised.min() if dmin < 0: diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 20f1527c..f440702c 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -59,7 +59,7 @@ nodes: depends: - frame: hotpix.frame lines: # needs to happen after flatten - type: cala.nodes.prep.remove_mean + type: cala.nodes.prep.remove_freq params: orient: both depends: From ba62a8abc4dbb2fb7c4f18f882a5809c4a572540 Mon Sep 17 00:00:00 2001 From: Raymond Date: Sat, 30 Aug 2025 04:04:15 -0700 Subject: [PATCH 09/17] feat: only use dog for motion --- src/cala/nodes/prep/motion.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/cala/nodes/prep/motion.py b/src/cala/nodes/prep/motion.py index f32ad813..2cdd9d42 100644 --- a/src/cala/nodes/prep/motion.py +++ b/src/cala/nodes/prep/motion.py @@ -8,7 +8,7 @@ import xarray as xr from noob import Name, process_method from pydantic import BaseModel, ConfigDict, Field -from skimage.filters import butterworth, difference_of_gaussians, sato, scharr +from skimage.filters import difference_of_gaussians from skimage.registration import phase_cross_correlation from cala.assets import Frame @@ -124,17 +124,9 @@ def _compute_shift(self, curr_frame: xr.DataArray) -> Shift: if: abs(sequential_shift - anchor_shift) < drift_speed then: true_shift = anchor_shift """ - filters = { - "butterworth": butterworth, - "difference_of_gaussians": difference_of_gaussians, - "sato": sato, - "scharr": scharr, - } - filt_fn = filters[self.pcc_filter] - - curr = filt_fn(curr_frame, **self.filter_kwargs) - prev = filt_fn(self.previous_frame_, **self.filter_kwargs) - anchor = filt_fn(self.anchor_frame_, **self.filter_kwargs) + curr = difference_of_gaussians(curr_frame, **self.filter_kwargs) + prev = difference_of_gaussians(self.previous_frame_, **self.filter_kwargs) + anchor = difference_of_gaussians(self.anchor_frame_, **self.filter_kwargs) anchor_shift, _, _ = phase_cross_correlation(anchor, curr, **self.pcc_kwargs) sequent_shift, _, _ = phase_cross_correlation(prev, curr, **self.pcc_kwargs) From 0911bec799dc246d1a82baf21132fdfde6ee4100 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 08:35:11 -0700 Subject: [PATCH 10/17] feat: zarr management for traces --- src/cala/assets.py | 11 +++++-- src/cala/nodes/traces.py | 10 +++++-- tests/test_assets.py | 63 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 tests/test_assets.py diff --git a/src/cala/assets.py b/src/cala/assets.py index 077b4063..e4df904e 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -120,14 +120,19 @@ def array(self, array: xr.DataArray) -> None: else: self.array_ = array - def append(self, array: xr.DataArray, dim: str | list[str]) -> None: - array.to_zarr(self.zarr_path, append_dim=dim) + def update(self, array: xr.DataArray, **kwargs) -> None: + self.validate_array_schema(array) + array.to_zarr(self.zarr_path, **kwargs) @classmethod def from_array( cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None ) -> "Traces": - return cls(array_=array, zarr_path=zarr_path, peek_size=peek_size) + if zarr_path: + assert peek_size, "peek_size must be set for zarr." + new_cls = cls(zarr_path=zarr_path, peek_size=peek_size) + new_cls.array = array + return new_cls _entity: ClassVar[Entity] = PrivateAttr( Group( diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index b56f3ba7..f2e90bf0 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -137,7 +137,10 @@ def ingest_frame( updated_traces = self._update_traces(A, y, c.copy(), clusters) - traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) + if traces.zarr_path: + traces.update(updated_traces, dim=AXIS.frames_dim) + else: + traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) return PopSnap.from_array(updated_traces) @@ -243,6 +246,9 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: intact_ids = [id_ for id_ in c[AXIS.id_coord].values if id_ not in merged_ids] c = c.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: intact_ids}).reset_index(AXIS.id_coord) - traces.array = xr.concat([c, c_new], dim=AXIS.component_dim, combine_attrs="drop") + if traces.zarr_path: + traces.update(c_new, append_dim=AXIS.component_dim) + else: + traces.array = xr.concat([c, c_new], dim=AXIS.component_dim, combine_attrs="drop") return traces diff --git a/tests/test_assets.py b/tests/test_assets.py new file mode 100644 index 00000000..72d06f13 --- /dev/null +++ b/tests/test_assets.py @@ -0,0 +1,63 @@ +import os +from pathlib import Path + +import pytest + +from cala.assets import Traces +from cala.models import AXIS + + +@pytest.fixture +def path(tmp_path: str) -> Path: + return Path(tmp_path) / "assets" + + +def test_assign_zarr(path, connected_cells): + zarr_traces = Traces(zarr_path=path, peek_size=100) + traces = connected_cells.traces.array + zarr_traces.array = traces + print(os.listdir(zarr_traces.zarr_path)) + assert zarr_traces.array_ is None # not in memory + assert zarr_traces.array.equals(traces) + + +def test_from_array(connected_cells, path): + traces = connected_cells.traces.array + zarr_traces = Traces.from_array(traces, path, peek_size=connected_cells.n_frames) + assert zarr_traces.array_ is None + assert zarr_traces.array.equals(traces) + + +@pytest.mark.parametrize("peek_shift", [-1, 0, 1]) +def test_peek(connected_cells, path, peek_shift): + traces = connected_cells.traces.array + zarr_traces = Traces.from_array(traces, path, peek_size=connected_cells.n_frames + peek_shift) + if peek_shift >= 0: + assert zarr_traces.array.equals(traces) + else: + with pytest.raises(AssertionError): + assert zarr_traces.array.equals(traces) + + +def test_ingest_frame(path, connected_cells): + traces = connected_cells.traces.array + old_traces = traces.isel({AXIS.frames_dim: slice(None, -1)}) + zarr_traces = Traces.from_array(old_traces, path, peek_size=connected_cells.n_frames) + new_traces = connected_cells.traces.array.isel({AXIS.frames_dim: [-1]}) + + zarr_traces.update(new_traces, append_dim=AXIS.frames_dim) + # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.frames_dim) + + assert zarr_traces.array.equals(traces) + + +def test_ingest_component(connected_cells, path): + traces = connected_cells.traces.array + old_traces = traces.isel({AXIS.component_dim: slice(None, -1)}) + zarr_traces = Traces.from_array(old_traces, path, peek_size=connected_cells.n_frames) + new_traces = connected_cells.traces.array.isel({AXIS.component_dim: [-1]}) + + zarr_traces.update(new_traces, append_dim=AXIS.component_dim) + # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.component_dim) + + assert zarr_traces.array.equals(traces) From 1f581bf50beb653d9f0d5b27047b6b6472c3b9f0 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 16:58:24 -0700 Subject: [PATCH 11/17] tests: add orient to line --- tests/test_prep/{test_hlines.py => test_lines.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/test_prep/{test_hlines.py => test_lines.py} (92%) diff --git a/tests/test_prep/test_hlines.py b/tests/test_prep/test_lines.py similarity index 92% rename from tests/test_prep/test_hlines.py rename to tests/test_prep/test_lines.py index 4ac90284..32c04d28 100644 --- a/tests/test_prep/test_hlines.py +++ b/tests/test_prep/test_lines.py @@ -18,6 +18,6 @@ def test_remove_lines(): frame = package_frame(noisy_img, 0) - result = remove_freq(frame) + result = remove_freq(frame, orient="horiz") assert structural_similarity(img.astype(int), result.array.values.astype(int)) == 1 From f4570c77ae1e2821dab8304e6604a770746b0738 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 16:58:53 -0700 Subject: [PATCH 12/17] tests: add zarr option to cube traces --- tests/test_pipeline.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 66ebfc9e..b206ca1f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytest import xarray as xr @@ -31,8 +33,11 @@ def tube(source): @pytest.fixture -def cube(): - return Cube.from_specification("cala-odl") +def cube(tmp_path: Path): + cube = Cube.from_specification("cala-odl") + cube.assets["traces"].spec.params["zarr_path"] = tmp_path / "traces" + cube.assets["traces"].params["zarr_path"] = tmp_path / "traces" + return cube @pytest.fixture From 43da503e002fffe1d6c2d57a43c8a2d3c680551f Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 17:22:12 -0700 Subject: [PATCH 13/17] feat: traces returns computed array instead of dask feat: traces has reset method, which removes the actual files --- src/cala/assets.py | 54 +++++++++++++++++++++++++-------- src/cala/nodes/cleanup.py | 10 +++--- src/cala/nodes/prep/__init__.py | 4 +-- 3 files changed, 48 insertions(+), 20 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index e4df904e..a9e751e0 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -1,15 +1,17 @@ +import contextlib +import shutil from copy import deepcopy from pathlib import Path -from typing import ClassVar, TypeVar +from typing import Any, ClassVar, TypeVar import xarray as xr -from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator +from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator from cala.models.axis import AXIS, Coords, Dims from cala.models.checks import has_no_nan, is_non_negative from cala.models.entity import Entity, Group -AssetType = TypeVar("AssetType", xr.DataArray, None) +AssetType = TypeVar("AssetType", xr.DataArray, Path, None) class Asset(BaseModel): @@ -30,6 +32,9 @@ def array(self, value: xr.DataArray) -> None: def from_array(cls, array: xr.DataArray) -> "Asset": return cls(array_=array) + def reset(self) -> None: + self.array_ = None + def __eq__(self, other: "Asset") -> bool: return self.array.equals(other.array) @@ -103,13 +108,23 @@ class Traces(Asset): @property def array(self) -> xr.DataArray: if self.zarr_path: - return ( - xr.open_zarr(self.zarr_path) - .isel({AXIS.frames_dim: slice(-self.peek_size, None)}) - .to_dataarray() - .isel({"variable": 0}) # not sure why it automatically makes this coordinate - .reset_coords("variable", drop=True) - ) + try: + da = ( + xr.open_zarr(self.zarr_path) + .isel({AXIS.frames_dim: slice(-self.peek_size, None)}) + .to_dataarray() + .drop_vars(["variable"]) + .isel(variable=0) + ) + return da.assign_coords( + { + AXIS.id_coord: lambda ds: da[AXIS.id_coord].astype(str), + AXIS.timestamp_coord: lambda ds: da[AXIS.timestamp_coord].astype(str), + } + ).compute() + + except FileNotFoundError: + return self.array_ else: return self.array_ @@ -120,20 +135,33 @@ def array(self, array: xr.DataArray) -> None: else: self.array_ = array - def update(self, array: xr.DataArray, **kwargs) -> None: + def update(self, array: xr.DataArray, **kwargs: Any) -> None: self.validate_array_schema(array) array.to_zarr(self.zarr_path, **kwargs) + def reset(self) -> None: + self.array_ = None + if self.zarr_path: + path = Path(self.zarr_path) + try: + shutil.rmtree(path) + except FileNotFoundError: + contextlib.suppress(FileNotFoundError) + @classmethod def from_array( cls, array: xr.DataArray, zarr_path: Path | str | None = None, peek_size: int | None = None ) -> "Traces": - if zarr_path: - assert peek_size, "peek_size must be set for zarr." new_cls = cls(zarr_path=zarr_path, peek_size=peek_size) new_cls.array = array return new_cls + @model_validator(mode="after") + def check_zarr_setting(self) -> "Traces": + if self.zarr_path: + assert self.peek_size, "peek_size must be set for zarr." + return self + _entity: ClassVar[Entity] = PrivateAttr( Group( name="trace-group", diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 8fca0add..d6f9e5bf 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -96,11 +96,11 @@ def _filter_components( A[Overlaps, Name("overlaps")], ]: if len(keep_ids) == 0 or footprints.array is None: - footprints.array = None - traces.array = None - pix_stats.array = None - comp_stats.array = None - overlaps.array = None + footprints.reset() + traces.reset() + pix_stats.reset() + comp_stats.reset() + overlaps.reset() elif footprints.array[AXIS.id_coord].values.tolist() != keep_ids: footprints.array = ( diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 6549903c..6837b571 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,8 +1,8 @@ from .background_removal import remove_background -from .denoise import blur, Restore +from .denoise import Restore, blur from .flatten import butter from .glow_removal import GlowRemover -from .lines import remove_mean, remove_freq +from .lines import remove_freq, remove_mean from .motion import Stabilizer from .r_estimate import SizeEst From c94eb6efb7ea6824e4a7bf398fe4563558411cb5 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 17:23:23 -0700 Subject: [PATCH 14/17] feat: trace updates for zarr storage --- src/cala/nodes/prep/denoise.py | 2 +- src/cala/nodes/prep/lines.py | 3 ++- src/cala/nodes/traces.py | 27 +++++++++++++++++++++++---- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index a73a3313..8697f17f 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -12,7 +12,7 @@ from cala.assets import Frame -def _bilateral(arr, **kwargs) -> np.ndarray: +def _bilateral(arr: np.ndarray, **kwargs: Any) -> np.ndarray: arr = arr.astype(np.float32) return cv2.bilateralFilter(arr, **kwargs) diff --git a/src/cala/nodes/prep/lines.py b/src/cala/nodes/prep/lines.py index 7f0ad795..232a2115 100644 --- a/src/cala/nodes/prep/lines.py +++ b/src/cala/nodes/prep/lines.py @@ -1,4 +1,5 @@ -from typing import Annotated as A, Literal, Any +from typing import Annotated as A +from typing import Any, Literal import numpy as np from noob import Name diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index f2e90bf0..4f3a8a14 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -134,11 +134,18 @@ def ingest_frame( csgraph=overlaps.array.data, directed=False, return_labels=True ) clusters = [np.where(labels == label)[0] for label in np.unique(labels)] - - updated_traces = self._update_traces(A, y, c.copy(), clusters) + updated_traces = self._update_traces(A, y, c, clusters) if traces.zarr_path: - traces.update(updated_traces, dim=AXIS.frames_dim) + updated_tr = updated_traces.expand_dims(AXIS.frames_dim).assign_coords( + { + AXIS.timestamp_coord: ( + AXIS.frames_dim, + [updated_traces[AXIS.timestamp_coord].values], + ) + } + ) + traces.update(updated_tr, append_dim=AXIS.frames_dim) else: traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) @@ -190,6 +197,7 @@ def _update_traces( # Steps 6-8: Update each group using block coordinate descent for cluster in clusters: # Update traces for current group (division is pointwise) + numerator = u.isel({AXIS.component_dim: cluster}) - ( V.isel({f"{AXIS.component_dim}'": cluster}) @ c ).rename({f"{AXIS.component_dim}'": AXIS.component_dim}) @@ -244,7 +252,18 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: merged_ids = c_det.attrs.get("replaces") if merged_ids: intact_ids = [id_ for id_ in c[AXIS.id_coord].values if id_ not in merged_ids] - c = c.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: intact_ids}).reset_index(AXIS.id_coord) + if traces.zarr_path: + traces.array = ( + traces.array.set_xindex(AXIS.id_coord) + .sel({AXIS.id_coord: intact_ids}) + .reset_index(AXIS.id_coord) + ) + else: + c = ( + c.set_xindex(AXIS.id_coord) + .sel({AXIS.id_coord: intact_ids}) + .reset_index(AXIS.id_coord) + ) if traces.zarr_path: traces.update(c_new, append_dim=AXIS.component_dim) From 9cc58aec2df722e72da00753be26422cab5d2824 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 17:23:53 -0700 Subject: [PATCH 15/17] test: zarr overwrite test --- tests/test_assets.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_assets.py b/tests/test_assets.py index 72d06f13..1ea62415 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -61,3 +61,12 @@ def test_ingest_component(connected_cells, path): # new_traces.to_zarr(zarr_traces.zarr_path, append_dim=AXIS.component_dim) assert zarr_traces.array.equals(traces) + + +def test_overwrite(connected_cells, separate_cells, path): + conn_traces = connected_cells.traces.array + zarr_traces = Traces.from_array(conn_traces, path, peek_size=connected_cells.n_frames) + + sep_traces = separate_cells.traces.array + zarr_traces.array = sep_traces + assert zarr_traces.array.equals(sep_traces) From 0a197bcbf368b4d36c1c042b77fde608720b4564 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 17:24:09 -0700 Subject: [PATCH 16/17] test: zarr storage setting for cube --- tests/data/pipelines/odl.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 569a2d51..d3dc1cb2 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -15,6 +15,9 @@ assets: scope: session traces: type: cala.assets.Traces + params: + zarr_path: data/traces + peek_size: 100 scope: session pix_stats: type: cala.assets.PixStats @@ -33,7 +36,7 @@ nodes: source: type: cala.testing.SingleCellSource denoise: - type: cala.nodes.prep.denoise + type: cala.nodes.prep.blur params: method: gaussian kwargs: From f5580ca702889ed554ac17a7b56298db801f8365 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 2 Sep 2025 17:27:51 -0700 Subject: [PATCH 17/17] test: zarr storage setting for cube --- src/cala/nodes/prep/denoise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index 8697f17f..6c1b985d 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -9,6 +9,7 @@ from noob import Name, process_method from pydantic import BaseModel from skimage.restoration import calibrate_denoiser + from cala.assets import Frame