From 1b6aa8792d80bc95f8fa955a36c7b1768e10f318 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 7 Oct 2025 14:24:13 -0700 Subject: [PATCH 01/47] feat: update test settings --- tests/data/pipelines/gui.yaml | 46 -------------- tests/data/pipelines/odl.yaml | 97 ++++++++++++------------------ tests/data/pipelines/prep.yaml | 18 ++++-- tests/data/pipelines/with_src.yaml | 2 +- 4 files changed, 51 insertions(+), 112 deletions(-) delete mode 100644 tests/data/pipelines/gui.yaml diff --git a/tests/data/pipelines/gui.yaml b/tests/data/pipelines/gui.yaml deleted file mode 100644 index 891650ed..00000000 --- a/tests/data/pipelines/gui.yaml +++ /dev/null @@ -1,46 +0,0 @@ -noob_id: cala-gui -noob_model: noob.tube.TubeSpecification -noob_version: 0.1.1.dev118+g64d81b7 - -assets: - buffer: - type: cala.assets.Movie - scope: runner - footprints: - type: cala.assets.Footprints - scope: runner - traces: - type: cala.assets.Traces - scope: runner - pix_stats: - type: cala.assets.PixStats - scope: runner - comp_stats: - type: cala.assets.CompStats - scope: runner - overlaps: - type: cala.assets.Overlaps - scope: runner - residuals: - type: cala.assets.Residual - scope: runner - - -nodes: - source: - type: cala.nodes.io.stream - params: - files: - - cala/msCam1.avi - - cala/msCam2.avi - - cala/msCam3.avi - - cala/msCam4.avi - - cala/msCam5.avi - - cala/msCam6.avi - counter: - type: cala.nodes.prep.counter - frame: - type: cala.nodes.prep.package_frame - depends: - - frame: source.value - - index: counter.idx diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index fe08371d..3f1ffad9 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -37,7 +37,7 @@ nodes: - index: counter.idx #PREPROCESS BEGINS - hotpix: # needs to happen first + hotpix: type: cala.nodes.prep.blur params: method: median @@ -53,19 +53,8 @@ nodes: type: cala.nodes.prep.SizeEst params: hardset_radius: 6 - noise_threshold: 2.0 - n_frames: 30 - log_kwargs: - min_sigma: 3 - max_sigma: 10 - num_sigma: 10 - threshold: 0.2 - overlap: 0.5 depends: - frame: glow.frame - #PREPROCESS ENDS - - # FRAME UPDATE BEGINS cache: type: cala.nodes.buffer.fill_buffer params: @@ -73,11 +62,14 @@ nodes: depends: - buffer: assets.buffer - frame: glow.frame + #PREPROCESS ENDS + + # FRAME UPDATE BEGINS trace_frame: type: cala.nodes.traces.FrameUpdate params: - tol: 0.0001 - max_iter: 200 + tol: 0.001 + max_iter: 100 depends: - traces: assets.traces - footprints: assets.footprints @@ -98,42 +90,24 @@ nodes: footprints_frame: type: cala.nodes.footprints.Footprinter params: - bep: 1 + bep: 0 tol: 0.0001 - max_iter: 200 + max_iter: 5 depends: - footprints: assets.footprints - pixel_stats: pix_frame.value - component_stats: comp_frame.value - index: counter.idx - facetune: - type: cala.nodes.cleanup.clear_overestimates - params: - nmf_error: 1.0 - depends: - - footprints: footprints_frame.footprints - - residuals: assets.residuals residual: type: cala.nodes.residual.build params: - n_recalc: 5 + n_recalc: 4 depends: - frames: assets.buffer - footprints: footprints_frame.footprints - traces: assets.traces - residuals: assets.residuals - cleanup: - type: cala.nodes.cleanup.purge_razed_components - params: - min_thicc: 3 - depends: - - footprints: assets.footprints - - traces: assets.traces - - pix_stats: assets.pix_stats - - comp_stats: assets.comp_stats - - overlaps: assets.overlaps - - trigger: residual.movie # FRAME UPDATE ENDS # DETECT BEGINS @@ -152,43 +126,48 @@ nodes: age_limit: 100 smooth_kwargs: sigma: 2 - merge_threshold: 0.8 + merge_threshold: 0.95 depends: - new_fps: nmf.new_fps - new_trs: nmf.new_trs - existing_fp: assets.footprints - existing_tr: assets.traces - - trace_component: - type: cala.nodes.traces.ingest_component + detect_update: + type: cala.nodes.detect.update_assets depends: - - traces: assets.traces - - new_traces: catalog.new_traces - footprint_component: - type: cala.nodes.footprints.ingest_component - depends: - - footprints: assets.footprints - new_footprints: catalog.new_footprints - pix_component: - type: cala.nodes.pixel_stats.ingest_component - depends: - - pixel_stats: assets.pix_stats - - frames: assets.buffer - new_traces: catalog.new_traces + - footprints: assets.footprints - traces: assets.traces - comp_component: - type: cala.nodes.component_stats.ingest_component - depends: + - pixel_stats: assets.pix_stats - component_stats: assets.comp_stats - - traces: assets.traces - - new_traces: catalog.new_traces + - overlaps: assets.overlaps + - buffer: assets.buffer + # DETECT ENDS - overlaps_update: - type: cala.nodes.overlap.initialize + merge: + type: cala.nodes.merge.merge_existing + params: + merge_interval: 500 + merge_threshold: 0.95 + smooth_kwargs: + sigma: 2 depends: + - shapes: assets.footprints + - traces: assets.traces - overlaps: assets.overlaps - - footprints: footprint_component.footprints - # DETECT ENDS + - trigger: detect_update.footprints + merge_update: + type: cala.nodes.detect.update_assets + depends: + - new_footprints: merge.footprints + - new_traces: merge.traces + - footprints: assets.footprints + - traces: assets.traces + - pixel_stats: assets.pix_stats + - component_stats: assets.comp_stats + - overlaps: assets.overlaps + - buffer: assets.buffer return: type: return diff --git a/tests/data/pipelines/prep.yaml b/tests/data/pipelines/prep.yaml index 308c6217..d846b61c 100644 --- a/tests/data/pipelines/prep.yaml +++ b/tests/data/pipelines/prep.yaml @@ -25,7 +25,6 @@ nodes: - frame: source.value - index: counter.idx - #PREPROCESS BEGINS hotpix: type: cala.nodes.prep.blur params: @@ -51,11 +50,18 @@ nodes: type: cala.nodes.prep.Anchor depends: - frame: lines.frame - # denoise: - # type: cala.nodes.prep.Restore - # depends: - # - frame: motion.frame + denoise: + type: cala.nodes.prep.blur + params: + method: median + kwargs: + ksize: 7 + depends: + - frame: motion.frame glow: type: cala.nodes.prep.GlowRemover depends: - - frame: motion.frame \ No newline at end of file + - frame: denoise.frame + return: + type: return + depends: glow.frame \ No newline at end of file diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 5330471d..be2fc8e9 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -131,7 +131,7 @@ nodes: footprints_frame: type: cala.nodes.footprints.Footprinter params: - bep: 4 + bep: 0 tol: 0.0001 max_iter: 5 depends: From 7d2feaac854028868897caf2ca45c5f846623bff Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 7 Oct 2025 14:25:25 -0700 Subject: [PATCH 02/47] feat: asset validate / sparsify optional --- src/cala/assets.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index e10b1461..080fba28 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -2,7 +2,7 @@ import shutil from copy import deepcopy from pathlib import Path -from typing import Any, ClassVar, TypeVar +from typing import Any, ClassVar, TypeVar, Self import numpy as np import xarray as xr @@ -19,6 +19,7 @@ class Asset(BaseModel): + validate_schema: bool = False array_: AssetType = None _entity: ClassVar[Entity] @@ -46,13 +47,12 @@ def __eq__(self, other: "Asset") -> bool: def entity(cls) -> Entity: return cls._entity - @field_validator("array_", mode="after") - @classmethod - def validate_array_schema(cls, value: xr.DataArray) -> AssetType: - if value is not None: - value.validate.against_schema(cls._entity.model) + @model_validator(mode="after") + def validate_array_schema(self) -> Self: + if self.validate_schema and self.array_ is not None: + self.array_.validate.against_schema(self._entity.model) - return value + return self class Footprint(Asset): @@ -66,8 +66,8 @@ class Footprint(Asset): ) @classmethod - def from_array(cls, array: xr.DataArray) -> "Footprint": - if isinstance(array.data, np.ndarray): + def from_array(cls, array: xr.DataArray, sparsify: bool = True) -> "Footprint": + if sparsify and isinstance(array.data, np.ndarray): array.data = COO.from_numpy(array.data) return cls(array_=array) @@ -106,8 +106,8 @@ class Footprints(Asset): ) @classmethod - def from_array(cls, array: xr.DataArray) -> "Footprints": - if isinstance(array.data, np.ndarray): + def from_array(cls, array: xr.DataArray, sparsify: bool = True) -> "Footprints": + if sparsify and isinstance(array.data, np.ndarray): array.data = COO.from_numpy(array.data) return cls(array_=array) From 9640834246a107f897d44d13cca5c22f76452d0e Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 7 Oct 2025 14:26:02 -0700 Subject: [PATCH 03/47] feat: asset validate / sparsify optional --- src/cala/nodes/detect/catalog.py | 16 ++++++++-------- src/cala/nodes/detect/slice_nmf.py | 2 +- tests/test_iter/test_footprints.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/detect/catalog.py index d85ce0a9..09558649 100644 --- a/src/cala/nodes/detect/catalog.py +++ b/src/cala/nodes/detect/catalog.py @@ -34,7 +34,7 @@ def process( if not new_fps or not new_trs: return Footprints(), Traces() - new_fps = xr.concat([fp.array for fp in new_fps], dim=AXIS.component_dim) + new_fps = xr.concat([fp.array for fp in new_fps], dim=AXIS.component_dim).as_numpy() new_trs = xr.concat([tr.array for tr in new_trs], dim=AXIS.component_dim) merge_mat = self._merge_matrix(new_fps, new_trs) new_fps, new_trs = _merge(new_fps, new_trs, merge_mat) @@ -129,7 +129,7 @@ def _register(new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[Footprint, Tr .isel({AXIS.component_dim: 0}) ) - return Footprint.from_array(footprint), Trace.from_array(trace) + return Footprint.from_array(footprint, sparsify=False), Trace.from_array(trace) def _register_batch(new_fps: xr.DataArray, new_trs: xr.DataArray) -> tuple[Footprints, Traces]: @@ -155,7 +155,7 @@ def _register_batch(new_fps: xr.DataArray, new_trs: xr.DataArray) -> tuple[Footp } ) - return Footprints.from_array(footprints), Traces.from_array(traces) + return Footprints.from_array(footprints, sparsify=False), Traces.from_array(traces) def _recompose( @@ -163,7 +163,7 @@ def _recompose( ) -> tuple[Footprint, Trace]: # Reshape neighborhood to 2D matrix (time × space) shape = movie.sum(dim=AXIS.frames_dim) > 0 - slice_ = Movie.from_array(movie.where(shape.as_numpy(), 0, drop=True)) + slice_ = Movie.from_array(movie.where(shape, 0, drop=True)) a, c = _nmf(slice_) @@ -186,7 +186,7 @@ def _nmf(movie: Movie) -> tuple[np.ndarray, np.ndarray]: # Apply NMF (check how long nndsvd takes vs random) model = NMF(n_components=1, init="random", tol=1e-4, max_iter=200) - c = model.fit_transform(stacked.as_numpy()) # temporal component + c = model.fit_transform(stacked) # temporal component a = model.components_ # spatial component return a, c @@ -215,7 +215,7 @@ def _reshape( coords=slice_coords, ) - return Footprint.from_array(a_new), Trace.from_array(c_new) + return Footprint.from_array(a_new, sparsify=False), Trace.from_array(c_new) def _merge_with( @@ -273,7 +273,7 @@ def _merge( res = fps @ trs new_fp, new_tr = _recompose(res, footprints[0].coords, traces[0].coords) else: - new_fp, new_tr = Footprint.from_array(fps[0]), Trace.from_array(trs[0]) + new_fp, new_tr = Footprint.from_array(fps[0], sparsify=False), Trace.from_array(trs[0]) combined_fps.append(new_fp) combined_trs.append(new_tr) @@ -293,7 +293,7 @@ def _absorb( footprints = [] traces = [] - merge_matrix.data = label(merge_matrix.as_numpy(), background=0, connectivity=1) + merge_matrix.data = label(merge_matrix, background=0, connectivity=1) merge_matrix = merge_matrix.assign_coords( {AXIS.component_dim: range(merge_matrix.sizes[AXIS.component_dim])} ).reset_index(AXIS.component_dim) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 4bf80343..7e012eab 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -64,7 +64,7 @@ def process( energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 if (self.error_ / l1_norm) <= self.reprod_tol: - fps.append(Footprint.from_array(a_new)) + fps.append(Footprint.from_array(a_new, sparsify=False)) trs.append(Trace.from_array(c_new)) res = (res - comp_recon).clip(self.error_ / l1_norm) else: diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index a50a3be3..2ac50203 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -62,7 +62,7 @@ def fpter() -> Node: NodeSpecification( id="test_footprinter", type="cala.nodes.footprints.Footprinter", - params={"bep": 3, "tol": 1e-7}, + params={"bep": 0, "tol": 1e-7}, ) ) From 7a2d6924d3ee6aac00db44b307560feb8dd6d843 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 12:33:43 -0700 Subject: [PATCH 04/47] perf: optimize residual to only update the latest, or zero out if recently discovered a component in the area --- src/cala/nodes/residual.py | 59 ++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 03d96dda..b3811370 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -31,7 +31,8 @@ def build( Shape: (frames × height × width) """ if footprints.array is None or traces.array is None: - return Residual.from_array(frames.array) + residuals.array = frames.array + return residuals # Reshape frames to pixels x time Y = frames.array @@ -44,19 +45,22 @@ def build( # Reshape footprints to (pixels x components) A = footprints.array - R_latest = Y.isel({AXIS.frames_dim: -1}) - (A @ C.isel({AXIS.frames_dim: -1})) - if R_latest.min() < 0: - shifted_tr = _align_overestimates(A, C.isel({AXIS.frames_dim: -1}), R_latest) + # Compute residual R = Y - [A,b][C;f] + R_curr = Y.isel({AXIS.frames_dim: -1}) - xr.DataArray( + np.matmul( + A.transpose(*AXIS.spatial_dims, ...).data, + C.transpose(AXIS.component_dim, ...).isel({AXIS.frames_dim: -1}).data, + ), + dims=AXIS.spatial_dims, + ) + if R_curr.min() < 0: + shifted_tr = _align_overestimates(A, C.isel({AXIS.frames_dim: -1}), R_curr) C.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr traces.array.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr - # Compute residual R = Y - [A,b][C;f] - # if recently discovered (during the expansion phase), recalculate. otherwise, just append?! - # if residuals.array is None: - R = Y - (A @ C) - # else: - # R = _update(Y, A, C, residuals.array, n_recalc=n_recalc) - residuals.array = R.clip(min=0) # clipping is for the first n frames + # if recently discovered, set to zero (or a small number). otherwise, just append + R = _update(Y, A, C, residuals.array, n_recalc=n_recalc) + residuals.array = R # clipping is for the first n frames return residuals @@ -130,23 +134,16 @@ def _update( targets = C[AXIS.detect_coord] >= (C[AXIS.frame_coord].max() - n_recalc) if any(targets): - target_ids = targets.where(targets, drop=True)[AXIS.id_coord].values - - A_recent = ( - A.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: target_ids}).reset_index(AXIS.id_coord) - ) - recalc_area = (A_recent.sum(dim=AXIS.component_dim) > 0).as_numpy() - C_recent = ( - C.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: target_ids}).reset_index(AXIS.id_coord) - ) - - recalc = Y.isel({AXIS.frames_dim: slice(None, -1)}).where( - recalc_area - ) - A_recent @ C_recent.isel({AXIS.frames_dim: slice(None, -1)}) - - R = R.sel({AXIS.frame_coord: recalc[AXIS.frame_coord]}).where( - ~recalc_area, recalc, drop=False - ) - R_latest = Y.isel({AXIS.frames_dim: -1}) - (A @ C.isel({AXIS.frames_dim: -1})) - - return xr.concat([R, R_latest], dim=AXIS.frames_dim) + target_ids = targets.where(targets, drop=True)[AXIS.id_coord] + target_area = ~((A.where(target_ids, drop=True)).max(dim=AXIS.component_dim) > 0) + R *= target_area.as_numpy() + + R_curr = Y.isel({AXIS.frames_dim: -1}) - xr.DataArray( + np.matmul( + A.transpose(*AXIS.spatial_dims, ...).data, + C.transpose(AXIS.component_dim, ...).isel({AXIS.frames_dim: -1}).data, + ), + dims=AXIS.spatial_dims, + ) + + return xr.concat([R, R_curr.clip(min=0)], dim=AXIS.frames_dim) From 1ca5763d77e53eb79122f304d90b6531997c5ab9 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 12:33:58 -0700 Subject: [PATCH 05/47] feat: assert_scalar_multiple_arrays error message --- src/cala/testing/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 0870e7bd..aa215a8f 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -9,10 +9,10 @@ def assert_scalar_multiple_arrays(a: xr.DataArray, b: xr.DataArray, /, rtol: flo if not 0 <= rtol <= 1: raise ValueError(f"rtol must be between 0 and 1, got {rtol}.") - abab = (a @ b) ** 2 - aabb = a.dot(a) * b.dot(b) + abab = ((a @ b) ** 2).item() + aabb = (a.dot(a) * b.dot(b)).item() - assert abab > aabb * (1 - rtol) + assert abab > aabb * (1 - rtol), f"Threshold not met: {abab=} > {aabb * (1 - rtol)=}" def generate_text_image( From 8dc78f7e40b191c4a86fb07f0f96411ddb3a72dd Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 13:12:16 -0700 Subject: [PATCH 06/47] test: rank1nmf test --- tests/test_iter/test_detect.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 8a1f6f75..f9ecfc7d 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -2,10 +2,12 @@ import pytest import xarray as xr from noob.node import NodeSpecification +from sklearn.decomposition import NMF from cala.assets import AXIS, Footprints, Residual, Traces from cala.nodes.detect import Cataloger, SliceNMF from cala.nodes.detect.catalog import _merge_with, _register +from cala.nodes.detect.slice_nmf import rank1nmf from cala.testing.util import assert_scalar_multiple_arrays @@ -174,3 +176,20 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): ) # 2. the trace and footprint values are accurate (where they do exist) xr.testing.assert_allclose(result.as_numpy(), expected.as_numpy(), atol=1e-3) + + +def test_rank1nmf(single_cell): + Y = single_cell.make_movie().array + R = Y.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frames_dim) + R += np.random.randint(0, 2, R.shape) + + shape = np.mean(R.values, axis=1).shape + a_res, c_res, err_res = rank1nmf(R.values, np.random.random(shape), iters=10) + + nmf = NMF(n_components=1, init="random", max_iter=10, tol=1e-3) + a_exp = nmf.fit_transform(R.values) + c_exp = nmf.components_ + err_exp = nmf.reconstruction_err_ + + assert_scalar_multiple_arrays(np.squeeze(a_exp), a_res) + assert_scalar_multiple_arrays(np.squeeze(c_exp), c_res) From 72245e4eb1711a51e936403411805ef43e94f234 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 13:13:26 -0700 Subject: [PATCH 07/47] test: assert_scalar_multiple_arrays for numpy --- src/cala/assets.py | 4 ++-- src/cala/testing/util.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 080fba28..8d0d70e9 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -130,7 +130,7 @@ def array(self) -> xr.DataArray: @array.setter def array(self, array: xr.DataArray) -> None: if self.zarr_path: - self.validate_array_schema(array) + array.validate.against_schema(self._entity.model) array.to_zarr(self.zarr_path, mode="w") # need to make sure it can overwrite else: self.array_ = array @@ -173,7 +173,7 @@ def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Dat ) def update(self, array: xr.DataArray, **kwargs: Any) -> None: - self.validate_array_schema(array) + array.validate.against_schema(self._entity.model) array.to_zarr(self.zarr_path, **kwargs) @classmethod diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index aa215a8f..82327326 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -3,12 +3,21 @@ import xarray as xr -def assert_scalar_multiple_arrays(a: xr.DataArray, b: xr.DataArray, /, rtol: float = 1e-5) -> None: - """Using the Pythagorean Theorem""" +def assert_scalar_multiple_arrays( + a: np.ndarray | xr.DataArray, b: np.ndarray | xr.DataArray, /, rtol: float = 1e-5 +) -> None: + """ + Using the Pythagorean Theorem + Only works with 1-D arrays. (see np.squeeze) + a: (n, ) + b: (n, ) + """ if not 0 <= rtol <= 1: raise ValueError(f"rtol must be between 0 and 1, got {rtol}.") + assert len(a.shape) == len(b.shape) == 1, f"Arrays must be 1-D. Given: {a.shape=}, {b.shape=}" + abab = ((a @ b) ** 2).item() aabb = (a.dot(a) * b.dot(b)).item() From c830f8cfd8fd4dffa39fc2de9d9a470a0264675e Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 15:54:06 -0700 Subject: [PATCH 08/47] feat: fix trace update --- src/cala/nodes/traces.py | 87 ++++++++++++++++++++++++++++------ tests/test_iter/test_traces.py | 19 ++++---- 2 files changed, 83 insertions(+), 23 deletions(-) diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index fd919c28..ffbb65cb 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -1,8 +1,10 @@ +from logging import Logger from typing import Annotated as A import numpy as np import xarray as xr from noob import Name, process_method +from pydantic import BaseModel from scipy.sparse.csgraph import connected_components from cala.assets import Footprints, Frame, Overlaps, PopSnap, Traces @@ -10,12 +12,11 @@ from cala.models import AXIS -class FrameUpdate: - logger = init_logger(__name__) +class FrameUpdate(BaseModel): + tol: float + max_iter: int - def __init__(self, tol: float, max_iter: int | None = None) -> None: - self.tol = tol - self.max_iter = max_iter + _logger: Logger = init_logger(__name__) @process_method def ingest_frame( @@ -53,15 +54,22 @@ def ingest_frame( return PopSnap() # Prepare inputs for the update algorithm - A = footprints.array.stack({"pixels": AXIS.spatial_dims}) + A = footprints.array.stack({"pixels": AXIS.spatial_dims}).transpose("pixels", ...) y = frame.array.stack({"pixels": AXIS.spatial_dims}) - c = traces.array.isel({AXIS.frames_dim: -1}) + c = traces.array.isel({AXIS.frames_dim: -1}).copy() + + AtA = (A @ A.rename(AXIS.component_rename)).to_numpy() _, labels = connected_components( csgraph=overlaps.array.data, directed=False, return_labels=True ) clusters = [np.where(labels == label)[0] for label in np.unique(labels)] - updated_traces = self._update_traces(A, y, c.copy(), clusters) + C, noisyC = _update_traces( + y.values, A.data, c.values, AtA, iters=self.max_iter, tol=self.tol, groups=clusters + ) + updated_traces = xr.DataArray(C, dims=c.dims, coords=c.coords).assign_coords( + {AXIS.frame_coord: y[AXIS.frame_coord], AXIS.timestamp_coord: y[AXIS.timestamp_coord]} + ) if traces.zarr_path: updated_tr = updated_traces.expand_dims(AXIS.frames_dim).assign_coords( @@ -79,11 +87,7 @@ def ingest_frame( return PopSnap.from_array(updated_traces) def _update_traces( - self, - A: xr.DataArray, - y: xr.DataArray, - c: xr.DataArray, - clusters: list[np.ndarray], + self, A: xr.DataArray, y: xr.DataArray, c: xr.DataArray, clusters: list[np.ndarray] ) -> xr.DataArray: """ Implementation of the temporal traces update algorithm. @@ -137,12 +141,67 @@ def _update_traces( if np.linalg.norm(c - c_old) >= self.tol * np.linalg.norm(c_old) or maxed: if maxed: - self.logger.debug(msg="max_iter reached before converging.") + self._logger.debug(msg="max_iter reached before converging.") return xr.DataArray( c.values, dims=c.dims, coords=c[AXIS.component_dim].coords ).assign_coords(y[AXIS.frames_dim].coords) +def _update_traces( + y: np.ndarray, + A: np.ndarray, + noisyC: np.ndarray, + AtA: np.ndarray, + iters: int = 5, + tol: float = 1e-3, + groups: list[list[int]] = None, +): + """ + Solve C = argmin_C ||Yr-AC|| using block-coordinate decent + Parameters + ---------- + y : array of float, shape (pixels,) + flattened array of raw data frame + A : sparse matrix of float, shape (pixels, comps) + neural shapes + noisyC : ndarray of float, shape (comps,) + Initial value of fluorescence intensities. + AtA : ndarray of float (comps, comps) + Overlap matrix of shapes A. + iters : int, optional + Maximal number of iterations. + tol : float, optional + Tolerance. + groups: list of lists + groups of components to update in parallel + """ + AtY = A.T.dot(y) + num_iters = 0 + C_old = np.zeros_like(noisyC) + C = noisyC.copy() + + # faster than np.linalg.norm + def norm(c): + return np.sqrt(c.ravel().dot(c.ravel())) + + while (norm(C_old - C) >= tol * norm(C_old)) and (num_iters < iters): + C_old[:] = C + if groups is None: + for m in range(len(AtY)): + noisyC[m] = C[m] + (AtY[m] - AtA[m].dot(C)) / (AtA[m, m] + np.finfo(C.dtype).eps) + C[m] = max(noisyC[m], 0) + else: + for m in groups: + noisyC[m] = C[m] + (AtY[m] - AtA[m].dot(C)) / ( + AtA.diagonal()[m] + np.finfo(C.dtype).eps + ) + C[m] = np.maximum(noisyC[m], 0) + num_iters += 1 + + # noisyC is just C with negative values (unclipped) + return C, noisyC + + def ingest_component(traces: Traces, new_traces: Traces) -> Traces: """ diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index 2c2e625e..2f92ff01 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import xarray as xr from noob.node import Node, NodeSpecification from cala.assets import Frame, Overlaps, Traces @@ -10,13 +11,15 @@ def frame_update() -> Node: return Node.from_specification( spec=NodeSpecification( - id="frame_test", type="cala.nodes.traces.FrameUpdate", params={"tol": 1e-3} + id="frame_test", + type="cala.nodes.traces.FrameUpdate", + params={"max_iter": 100, "tol": 1e-4}, ) ) -@pytest.mark.parametrize("toy", ["separate_cells"]) -def test_ingest_frame(frame_update, toy, request) -> None: +@pytest.mark.parametrize("toy", ["separate_cells", "connected_cells"]) +def test_update_traces(frame_update, toy, request) -> None: toy = request.getfixturevalue(toy) xray = Node.from_specification( @@ -29,13 +32,11 @@ def test_ingest_frame(frame_update, toy, request) -> None: overlap = xray.process(overlaps=Overlaps(), footprints=toy.footprints) result = frame_update.process( - traces=traces, - footprints=toy.footprints, - frame=frame, - overlaps=overlap, - ) + traces=traces, footprints=toy.footprints, frame=frame, overlaps=overlap + ).array + expected = toy.traces.array.isel({AXIS.frames_dim: -1}) - assert result.array.equals(toy.traces.array.isel({AXIS.frames_dim: -1})) + xr.testing.assert_allclose(result, expected, atol=1e-3) @pytest.fixture From b0deebb540a0157d70c8fca431ea6b6c85dc629b Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 17:47:33 -0700 Subject: [PATCH 09/47] feat: using rank1nmf instead of sklearn nmf --- src/cala/nodes/detect/slice_nmf.py | 48 +++++++++++++++++++++--------- src/cala/nodes/residual.py | 2 +- src/cala/testing/util.py | 13 +++++--- tests/test_iter/test_detect.py | 3 +- 4 files changed, 46 insertions(+), 20 deletions(-) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 7e012eab..95747107 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -40,33 +40,32 @@ def process( if residuals.array.sizes[AXIS.frames_dim] < self.min_frames: return [], [] - energy = self._get_energy(residuals.array) + energy = self._get_energy(residuals.array) # 0.008s fps = [] trs = [] res = residuals.array.copy() - while energy.max().item() >= self.detect_thresh: # or use res directly + while np.max(energy) >= self.detect_thresh: # Find and analyze neighborhood of maximum variance slice_ = self._get_max_energy_slice( arr=res, energy_landscape=energy, radius=detect_radius ) - a_new, c_new = self._local_nmf( + a_new, c_new = self._local_nmf( # 0.0019s slice_=slice_, spatial_sizes={k: v for k, v in res.sizes.items() if k in AXIS.spatial_dims}, ) - l1_norm = slice_.sum().item() - comp_recon = a_new @ c_new + l1_norm = np.sum(slice_) energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 if (self.error_ / l1_norm) <= self.reprod_tol: fps.append(Footprint.from_array(a_new, sparsify=False)) trs.append(Trace.from_array(c_new)) - res = (res - comp_recon).clip(self.error_ / l1_norm) + res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = self.error_ / l1_norm else: l0_norm = np.prod(slice_.shape) res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = self.error_ / l0_norm @@ -127,12 +126,9 @@ def _local_nmf( - Temporal component c_new (frames) """ # Reshape neighborhood to 2D matrix (time × space) - R = slice_.stack(space=AXIS.spatial_dims).transpose(AXIS.frames_dim, "space") + R = slice_.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frames_dim) - c = self._model.fit_transform(R) # temporal component - a = self._model.components_ # spatial component - - self.error_ = self._model.reconstruction_err_.item() + a, c, self.error_ = rank1nmf(R.values, np.mean(R.values, axis=1)) # Convert back to xarray with proper dimensions and coordinates c_new = xr.DataArray( @@ -156,8 +152,32 @@ def _local_nmf( ) # normalize against the original video (as in whatever the residual used at the time) - factor = slice_.max() / a_new.max() - a_new = a_new * factor - c_new = c_new / factor + factor = slice_.data.max() / c_new.data.max() + a_new = a_new / factor + c_new = c_new * factor return a_new, c_new + + +def rank1nmf( + Ypx: np.ndarray, ain: np.ndarray, iters: int = 10 +) -> tuple[np.ndarray, np.ndarray, float]: + """ + perform a fast rank 1 NMF + + Ypx: (pixels, frames) + ain: (pixels) + """ + eps = np.finfo(np.float32).eps + for t in range(iters): + cin_res = ain.dot(Ypx) + cin = np.maximum(cin_res, 0) + ain = np.maximum(Ypx.dot(cin), 0) + if t in (0, iters - 1): + ain /= np.sqrt(ain.dot(ain)) + eps + elif t % 2 == 0: + ain /= ain.dot(ain) + eps + cin_res = ain.dot(Ypx) + cin = np.maximum(cin_res, 0) + error = np.linalg.norm(Ypx - np.outer(ain, cin), "fro") + return ain, cin, error diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index b3811370..80118ada 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -134,7 +134,7 @@ def _update( targets = C[AXIS.detect_coord] >= (C[AXIS.frame_coord].max() - n_recalc) if any(targets): - target_ids = targets.where(targets, drop=True)[AXIS.id_coord] + target_ids = targets.where(targets)[AXIS.id_coord] target_area = ~((A.where(target_ids, drop=True)).max(dim=AXIS.component_dim) > 0) R *= target_area.as_numpy() diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index 82327326..76b458b6 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -1,11 +1,13 @@ +from typing import TypeVar + import cv2 import numpy as np import xarray as xr +_TArray = TypeVar("_TArray", xr.DataArray, np.ndarray) + -def assert_scalar_multiple_arrays( - a: np.ndarray | xr.DataArray, b: np.ndarray | xr.DataArray, /, rtol: float = 1e-5 -) -> None: +def assert_scalar_multiple_arrays(a: _TArray, b: _TArray, /, rtol: float = 1e-5) -> None: """ Using the Pythagorean Theorem Only works with 1-D arrays. (see np.squeeze) @@ -16,7 +18,10 @@ def assert_scalar_multiple_arrays( if not 0 <= rtol <= 1: raise ValueError(f"rtol must be between 0 and 1, got {rtol}.") - assert len(a.shape) == len(b.shape) == 1, f"Arrays must be 1-D. Given: {a.shape=}, {b.shape=}" + if isinstance(a, np.ndarray): + assert ( + len(a.shape) == len(b.shape) == 1 + ), f"Arrays must be 1-D. Given: {a.shape=}, {b.shape=}" abab = ((a @ b) ** 2).item() aabb = (a.dot(a) * b.dot(b)).item() diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index f9ecfc7d..9c14fa54 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -61,7 +61,8 @@ def test_chunks(self, single_cell): if not fpts or not trcs: raise AssertionError("Failed to detect a new component") - fpt_arr = xr.concat([f.array for f in fpts], dim="component") + factors = [trc.array.data.max() for trc in trcs] + fpt_arr = xr.concat([f.array * m for f, m in zip(fpts, factors)], dim="component") expected = single_cell.footprints.array[0] result = fpt_arr.sum(dim="component") From 9ba0838042a49a4befc917224600de01085c567a Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 18:47:20 -0700 Subject: [PATCH 10/47] feat: optimize footprint update --- src/cala/nodes/footprints.py | 218 ++++++++++++++++------------- tests/data/pipelines/odl.yaml | 2 +- tests/test_iter/test_footprints.py | 6 +- 3 files changed, 123 insertions(+), 103 deletions(-) diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 9b0a43c7..39142245 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -1,39 +1,23 @@ from typing import Annotated as A -import cv2 import numpy as np import xarray as xr from noob import Name, process_method -from skimage.morphology import disk -from sparse import COO +from pydantic import BaseModel +from scipy.sparse import csc_matrix from cala.assets import CompStats, Footprints, PixStats from cala.logging import init_logger from cala.models import AXIS -class Footprinter: - logger = init_logger(__name__) +class Footprinter(BaseModel): + tol: float + max_iter: int | None = None + bep: int | None = None + ratio_lb: float = 0.15 - def __init__( - self, - tol: float, - max_iter: int | None = None, - bep: int | None = None, - ratio_lb: float = 0.15, - ): - self.bep = bep - """ - Number of pixels to explore the boundary of the footprint outside of the current footprint. - """ - - self.ratio_lb = ratio_lb - """ - Ratio of the least bright pixel against the brightest pixel of a given footprint. - """ - - self.tol = tol - self.max_iter = max_iter + _logger = init_logger(__name__) @process_method def ingest_frame( @@ -59,86 +43,22 @@ def ingest_frame( if footprints.array is None: return footprints - A = footprints.array + A = footprints.array.stack(pixel=AXIS.spatial_dims).transpose("pixel", ...) M = component_stats.array - W = pixel_stats.array - - plain_mask = A > 0 - mask = self._build_mask(plain_mask, index=index) - - # Expand M diagonal for broadcasting - M_diag = xr.apply_ufunc( - np.diag, - M, - input_core_dims=[M.dims], - output_core_dims=[[AXIS.component_dim]], - dask="allowed", - ) + W = pixel_stats.array.stack(pixel=AXIS.spatial_dims).transpose(AXIS.component_dim, ...) - cnt = 0 - while True: - AM = A.rename(AXIS.component_rename) @ M - numerator = W - AM - - update = numerator / (M_diag + np.finfo(float).tiny) - A_new = (mask * (A + update)).clip(min=0) - - step = (np.abs(A - A_new).sum() / np.prod(A.shape)).item() - - cnt += 1 - maxed = self.max_iter and (cnt == self.max_iter) - - if step < self.tol or maxed: - A_final = A_new.where( - A_new > A_new.max(AXIS.spatial_dims) * self.ratio_lb, 0, drop=False - ) - if maxed: - self.logger.debug(msg="max_iter reached before converging.") - A_final = A_new.where(plain_mask, 0, drop=False) - - footprints.array = A_final - return footprints - else: - A = A_new - mask = A > 0 - - def _expansion_kernel(self) -> np.ndarray: - return disk(radius=1) - - def _expand_boundary(self, kernel: np.ndarray, mask: xr.DataArray) -> xr.DataArray: - expanded = xr.apply_ufunc( - lambda x: cv2.morphologyEx(x, cv2.MORPH_DILATE, kernel, iterations=1), - mask.as_numpy().astype(np.uint8), - input_core_dims=[AXIS.spatial_dims], - output_core_dims=[AXIS.spatial_dims], - vectorize=True, - dask="parallelized", + shapes, mask, _ = update_shapes( + CY=W.values, + CC=M.values, + Ab=A.data.tocsc(), + A_mask=[np.where(Ap.as_numpy() > 0)[0] for Ap in A.transpose(AXIS.component_dim, ...)], ) - expanded.data = COO.from_numpy(expanded.data) - return expanded - - def _build_mask(self, mask: xr.DataArray, index: int) -> xr.DataArray: - expansion_left = (index - mask[AXIS.detect_coord] - self.bep) <= 0 - expand_ids = expansion_left.where(expansion_left, drop=True)[AXIS.id_coord].values - no_expand_ids = expansion_left.where(~expansion_left, drop=True)[AXIS.id_coord].values - - if expand_ids.size > 0: - kernel = self._expansion_kernel() - expanded_mask = self._expand_boundary( - kernel, mask.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: expand_ids}) - ) + footprints.array = xr.DataArray(shapes.toarray(), dims=A.dims, coords=A.coords).unstack( + "pixel" + ) - final_mask = xr.concat( - [ - mask.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: no_expand_ids}), - expanded_mask, - ], - dim=AXIS.component_dim, - ).reset_index(AXIS.id_coord) - else: - final_mask = mask - return final_mask + return footprints def ingest_component( @@ -162,3 +82,103 @@ def ingest_component( footprints.array = xr.concat([a, a_det], dim=AXIS.component_dim, combine_attrs="drop") return footprints + + +def update_shapes( + CY: np.ndarray, + CC: np.ndarray, + Ab: csc_matrix, + A_mask: list[np.ndarray], + Ab_dense: np.ndarray | None = None, + iters: int = 5, +): + """ + :param CY: suff stats (pixel,) + :param CC: suff stats (component), shape (comp, comp) + :param Ab: shape matrix (sparse), shape (pixel, comp) + :param A_mask: list of nonzero coordinates for each footprint list[(pixel,)] + :param Ab_dense: shape matrix (dense) + :param iters: number of iterations + """ + D, M = Ab.shape + + for _ in range(iters): # it's presumably better to run just 1 iter but update more neurons + for m in range(M): + tmp = _update(Ab_dense=Ab_dense, Ab=Ab, CY=CY, CC=CC, m=m, ind_pixels=A_mask[m]) + Ab_dense, Ab, A_mask = _normalize( + m=m, Ab=Ab, Ab_dense=Ab_dense, ind_A=A_mask, ind_pixels=A_mask[m], tmp=tmp + ) + + return Ab, A_mask, Ab_dense + + +def _update( + Ab_dense: np.ndarray, + Ab: csc_matrix, + CY: np.ndarray, + CC: np.ndarray, + m: int, + ind_pixels: int, +) -> np.ndarray: + """ + Update a single footprint + + :param Ab_dense: shape matrix (dense) + :param Ab: shape matrix (sparse) + :param CY: suff stats (pixel) + :param CC: suff stats (component) + :param m: neuron index + :param ind_pixels: index of cell + """ + if Ab_dense is None: + result = np.maximum( + Ab.data[Ab.indptr[m] : Ab.indptr[m + 1]] + + ( + (CY[m, ind_pixels] - Ab.dot(CC[m])[ind_pixels]) + / (CC[m, m] + np.finfo(CC.dtype).eps) + ), + 0, + ) + else: + result = np.maximum( + Ab_dense[ind_pixels, m] + + ( + (CY[m, ind_pixels] - Ab_dense[ind_pixels].dot(CC[m])) + / (CC[m, m] + np.finfo(CC.dtype).eps) + ), + 0, + ) + + return result + + +def _normalize( + m: int, + Ab: csc_matrix, + Ab_dense: np.ndarray | None, + ind_A: list[int], + ind_pixels: int, + tmp: np.ndarray, +) -> tuple[np.ndarray, csc_matrix, list[int]]: + """ + This only exists to prevent footprint values from blowing up / diminishing. + (Hopefully) Irrelevant since we normalize footprint to the actual pixel values. + + :param m: neuron index + :param Ab: shape matrix (sparse) + :param Ab_dense: shape matrix (dense) + :param ind_A: shape matrix of cells + :param ind_pixels: shape array of a cell + :param tmp: updated shape - before normalization + """ + if tmp.dot(tmp) > 0: + # tmp *= 1e-3 / min(1e-3, np.sqrt(tmp.dot(tmp)) + np.finfo(float).eps) + if Ab_dense is not None: + Ab_dense[ind_pixels, m] = tmp # / max(1, np.sqrt(tmp.dot(tmp))) + Ab.data[Ab.indptr[m] : Ab.indptr[m + 1]] = Ab_dense[ind_pixels, m] + else: + # tmp = tmp / max(1, np.sqrt(tmp.dot(tmp))) + Ab.data[Ab.indptr[m] : Ab.indptr[m + 1]] = tmp + ind_A[m] = Ab.indices[slice(Ab.indptr[m], Ab.indptr[m + 1])] + + return Ab_dense, Ab, ind_A diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 3f1ffad9..bb50478e 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -102,7 +102,7 @@ nodes: residual: type: cala.nodes.residual.build params: - n_recalc: 4 + n_recalc: 1 depends: - frames: assets.buffer - footprints: footprints_frame.footprints diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index 2ac50203..637ba04c 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -80,11 +80,11 @@ def test_ingest_frame(fpter, toy, request): result = fpter.process( footprints=toy.footprints, pixel_stats=pixstats, component_stats=compstats, index=0 - ) + ).array - expected = toy.footprints.model_copy() + expected = toy.footprints.array.as_numpy() - xr.testing.assert_allclose(result.array.as_numpy(), expected.array.as_numpy()) + xr.testing.assert_allclose(result, expected) @pytest.fixture From 5a61f7e3a7be625b9f0eb36787d02f8947170c2b Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 18:53:59 -0700 Subject: [PATCH 11/47] feat: footprint save as sparse --- src/cala/assets.py | 10 ++++++++-- src/cala/nodes/footprints.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 8d0d70e9..8e7b4847 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -105,9 +105,15 @@ class Footprints(Asset): ) ) + @Asset.array.setter + def array(self, array: xr.DataArray) -> None: + if isinstance(array.data, np.ndarray): + array.data = COO.from_numpy(array.data) + self.array_ = array + @classmethod - def from_array(cls, array: xr.DataArray, sparsify: bool = True) -> "Footprints": - if sparsify and isinstance(array.data, np.ndarray): + def from_array(cls, array: xr.DataArray) -> "Footprints": + if isinstance(array.data, np.ndarray): array.data = COO.from_numpy(array.data) return cls(array_=array) diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 39142245..9132438a 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -35,9 +35,9 @@ def ingest_frame( - p are the pixels where component i can be non-zero Args: - pixel_stats (PixelStats): Sufficient statistics W. + pixel_stats (PixStats): Sufficient statistics W. Shape: (pixels × components) - component_stats (ComponentStats): Sufficient statistics M. + component_stats (CompStats): Sufficient statistics M. Shape: (components × components) """ if footprints.array is None: From 55de160a076f1cf7704cac65dd2fd4cc1abc15c4 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 19:22:39 -0700 Subject: [PATCH 12/47] feat: footprint save as sparse --- src/cala/assets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cala/assets.py b/src/cala/assets.py index 8e7b4847..58d99110 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -107,6 +107,7 @@ class Footprints(Asset): @Asset.array.setter def array(self, array: xr.DataArray) -> None: + array.validate.against_schema(self._entity.model) if isinstance(array.data, np.ndarray): array.data = COO.from_numpy(array.data) self.array_ = array From 0e7bbce74cb1cb42723871aec5fea8c0272fd5b1 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 22:27:22 -0700 Subject: [PATCH 13/47] feat: allow making empty assets with from_array --- src/cala/assets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 58d99110..8d8154b2 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -108,13 +108,13 @@ class Footprints(Asset): @Asset.array.setter def array(self, array: xr.DataArray) -> None: array.validate.against_schema(self._entity.model) - if isinstance(array.data, np.ndarray): + if array is not None and isinstance(array.data, np.ndarray): array.data = COO.from_numpy(array.data) self.array_ = array @classmethod def from_array(cls, array: xr.DataArray) -> "Footprints": - if isinstance(array.data, np.ndarray): + if array is not None and isinstance(array.data, np.ndarray): array.data = COO.from_numpy(array.data) return cls(array_=array) From 801605149850fc61ef8b5db22336fd06ed613d11 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 22:29:27 -0700 Subject: [PATCH 14/47] chore: simplify indexing in residual --- src/cala/nodes/residual.py | 46 ++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 80118ada..c2d02fec 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -4,12 +4,12 @@ import xarray as xr from noob import Name -from cala.assets import Footprints, Movie, Residual, Traces +from cala.assets import Footprints, Movie, Residual, Traces, Frame from cala.models import AXIS def build( - residuals: Residual, frames: Movie, footprints: Footprints, traces: Traces, n_recalc: int + residuals: Residual, frame: Frame, footprints: Footprints, traces: Traces, n_recalc: int ) -> A[Residual, Name("movie")]: """ The computation follows the equation: @@ -31,32 +31,27 @@ def build( Shape: (frames × height × width) """ if footprints.array is None or traces.array is None: - residuals.array = frames.array + if residuals.array is None: + residuals.array = frame.array.expand_dims(dim=AXIS.frames_dim) + else: + residuals.array = xr.concat( + [residuals.array, frame.array], + dim=AXIS.frames_dim, + coords=[AXIS.frame_coord, AXIS.timestamp_coord], + ) return residuals - # Reshape frames to pixels x time - Y = frames.array - - # Get temporal components [C; f] - C = traces.array.sel( - {AXIS.frame_coord: Y[AXIS.frame_coord].values.tolist()} - ) # components x time - - # Reshape footprints to (pixels x components) + Y = frame.array + C = traces.array.isel({AXIS.frames_dim: -1}) # (components,) A = footprints.array # Compute residual R = Y - [A,b][C;f] - R_curr = Y.isel({AXIS.frames_dim: -1}) - xr.DataArray( - np.matmul( - A.transpose(*AXIS.spatial_dims, ...).data, - C.transpose(AXIS.component_dim, ...).isel({AXIS.frames_dim: -1}).data, - ), - dims=AXIS.spatial_dims, + R_curr = Y - xr.DataArray( + np.matmul(A.transpose(*AXIS.spatial_dims, ...).data, C.data), dims=AXIS.spatial_dims ) if R_curr.min() < 0: - shifted_tr = _align_overestimates(A, C.isel({AXIS.frames_dim: -1}), R_curr) - C.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr - traces.array.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr + C = _align_overestimates(A, C, R_curr) + traces.array.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = C # if recently discovered, set to zero (or a small number). otherwise, just append R = _update(Y, A, C, residuals.array, n_recalc=n_recalc) @@ -131,18 +126,15 @@ def _find_unlayered_footprints(A: xr.DataArray) -> xr.DataArray: def _update( Y: xr.DataArray, A: xr.DataArray, C: xr.DataArray, R: xr.DataArray, n_recalc: int ) -> xr.DataArray: - targets = C[AXIS.detect_coord] >= (C[AXIS.frame_coord].max() - n_recalc) + targets = C[AXIS.detect_coord] >= (C[AXIS.frame_coord].item() - n_recalc) if any(targets): target_ids = targets.where(targets)[AXIS.id_coord] target_area = ~((A.where(target_ids, drop=True)).max(dim=AXIS.component_dim) > 0) R *= target_area.as_numpy() - R_curr = Y.isel({AXIS.frames_dim: -1}) - xr.DataArray( - np.matmul( - A.transpose(*AXIS.spatial_dims, ...).data, - C.transpose(AXIS.component_dim, ...).isel({AXIS.frames_dim: -1}).data, - ), + R_curr = Y - xr.DataArray( + np.matmul(A.transpose(*AXIS.spatial_dims, ...).data, C.data), dims=AXIS.spatial_dims, ) From efffcf053b7f0cfeece64ca9231adf0ed5c1dad3 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 22:31:16 -0700 Subject: [PATCH 15/47] perf: no assets within nodes methods feat: normalization with nmf perf: matmul instead of xr @ feat: filter tiny components that never merged --- src/cala/nodes/detect/catalog.py | 231 ++++++++++++++++--------------- 1 file changed, 121 insertions(+), 110 deletions(-) diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/detect/catalog.py index 09558649..c17ffc2d 100644 --- a/src/cala/nodes/detect/catalog.py +++ b/src/cala/nodes/detect/catalog.py @@ -1,4 +1,5 @@ from collections.abc import Hashable, Iterable +from itertools import compress from typing import Annotated as A import cv2 @@ -6,14 +7,15 @@ import xarray as xr from noob import Name from noob.node import Node +from pydantic import Field from scipy.ndimage import gaussian_filter1d from scipy.sparse.csgraph import connected_components from skimage.measure import label -from sklearn.decomposition import NMF from xarray import Coordinates -from cala.assets import Footprint, Footprints, Movie, Trace, Traces +from cala.assets import Footprint, Footprints, Trace, Traces from cala.models import AXIS +from cala.nodes.detect.slice_nmf import rank1nmf from cala.util import combine_attr_replaces, create_id @@ -22,6 +24,8 @@ class Cataloger(Node): age_limit: int """Don't merge with new components if older than this number of frames.""" merge_threshold: float + val_threshold: float = Field(gt=0, lt=1) + cnt_threshold: int = Field(gt=0) def process( self, @@ -34,14 +38,14 @@ def process( if not new_fps or not new_trs: return Footprints(), Traces() - new_fps = xr.concat([fp.array for fp in new_fps], dim=AXIS.component_dim).as_numpy() + new_fps = xr.concat([fp.array for fp in new_fps], dim=AXIS.component_dim) new_trs = xr.concat([tr.array for tr in new_trs], dim=AXIS.component_dim) merge_mat = self._merge_matrix(new_fps, new_trs) new_fps, new_trs = _merge(new_fps, new_trs, merge_mat) known_fp, known_tr = _get_absorption_targets(existing_fp, existing_tr, self.age_limit) merge_mat = self._merge_matrix(new_fps, new_trs, known_fp, known_tr) - footprints, traces = _absorb(new_fps, new_trs, known_fp, known_tr, merge_mat) + footprints, traces = self._absorb(new_fps, new_trs, known_fp, known_tr, merge_mat) return Footprints.from_array(footprints), Traces.from_array(traces) @@ -52,7 +56,7 @@ def _merge_matrix( fps_base: xr.DataArray | None = None, trs_base: xr.DataArray | None = None, ) -> xr.DataArray: - + fps = fps.stack(pixels=AXIS.spatial_dims) trs = xr.DataArray( gaussian_filter1d(trs.transpose(AXIS.component_dim, ...), **self.smooth_kwargs), dims=trs.dims, @@ -63,7 +67,7 @@ def _merge_matrix( fps_base = fps.rename({AXIS.component_dim: f"{AXIS.component_dim}'"}) trs_base = trs.rename({AXIS.component_dim: f"{AXIS.component_dim}'"}) else: - fps_base = fps_base.rename(AXIS.component_rename) + fps_base = fps_base.stack(pixels=AXIS.spatial_dims).rename(AXIS.component_rename) trs_base = xr.DataArray( gaussian_filter1d( trs_base.transpose(AXIS.component_dim, ...), **self.smooth_kwargs @@ -73,14 +77,76 @@ def _merge_matrix( ) trs_base = trs_base.rename(AXIS.component_rename) - # So that "touching" ones can merge - # Can save time by calculating centroid instead - # fps = self._expand_boundary(fps > 0) - - overlaps = fps @ fps_base > 0 - # calculate correlation for overlapping components only? + overlaps = np.matmul(fps.data, fps_base.data.T) > 0 + # corr is fast. (~1ms to 4ms) corrs = xr.corr(trs, trs_base, dim=AXIS.frames_dim) > self.merge_threshold - return overlaps * corrs + return xr.DataArray(overlaps * corrs.values, dims=corrs.dims, coords=corrs.coords) + + def _absorb( + self, + new_fps: xr.DataArray, + new_trs: xr.DataArray, + known_fps: xr.DataArray, + known_trs: xr.DataArray, + merge_matrix: xr.DataArray, + ) -> tuple[xr.DataArray | None, xr.DataArray | None]: + footprints = [] + traces = [] + + merge_matrix.data = label(merge_matrix.to_numpy(), background=0, connectivity=1) + merge_matrix = merge_matrix.assign_coords( + {AXIS.component_dim: range(merge_matrix.sizes[AXIS.component_dim])} + ).reset_index(AXIS.component_dim) + indep_idxs = ( + merge_matrix.where(merge_matrix.sum(f"{AXIS.component_dim}'") == 0, drop=True)[ + AXIS.component_dim + ].values + if known_fps is not None + else np.array(range(len(merge_matrix))) + ) + if indep_idxs.size > 0: + fps, trs = _register_batch( + new_fps=new_fps.isel({AXIS.component_dim: indep_idxs}), + new_trs=new_trs.isel({AXIS.component_dim: indep_idxs}), + ) + footprints.append(fps) + traces.append(trs) + + num = merge_matrix.max().item() + if num > 0 and known_fps is not None: + for lbl in range(1, num + 1): + new_idxs, _known_idxs = np.where(merge_matrix == lbl) + known_ids = merge_matrix.where(merge_matrix == lbl, drop=True)[ + f"{AXIS.id_coord}'" + ].values + fp = new_fps.sel({AXIS.component_dim: list(set(new_idxs))}) + tr = new_trs.sel({AXIS.component_dim: list(set(new_idxs))}) + footprint, trace = _merge_with(fp, tr, known_fps, known_trs, known_ids) + + footprints.append(footprint) + traces.append(trace) + + mask = [np.sum(fp.data > self.val_threshold) > self.cnt_threshold for fp in footprints] + footprints = list(compress(footprints, mask)) + traces = list(compress(traces, mask)) + + if not footprints: + return None, None + + footprints = xr.concat( + footprints, + dim=AXIS.component_dim, + coords=[AXIS.id_coord, AXIS.detect_coord], + combine_attrs=combine_attr_replaces, + ) + traces = xr.concat( + traces, + dim=AXIS.component_dim, + coords=[AXIS.id_coord, AXIS.detect_coord], + combine_attrs=combine_attr_replaces, + ) + + return footprints, traces def _get_absorption_targets( @@ -98,7 +164,7 @@ def _get_absorption_targets( return known_fp, known_tr -def _register(new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[Footprint, Trace]: +def _register(new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[xr.DataArray, xr.DataArray]: new_id = create_id() @@ -129,10 +195,12 @@ def _register(new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[Footprint, Tr .isel({AXIS.component_dim: 0}) ) - return Footprint.from_array(footprint, sparsify=False), Trace.from_array(trace) + return footprint, trace -def _register_batch(new_fps: xr.DataArray, new_trs: xr.DataArray) -> tuple[Footprints, Traces]: +def _register_batch( + new_fps: xr.DataArray, new_trs: xr.DataArray +) -> tuple[xr.DataArray, xr.DataArray]: count = new_fps.sizes[AXIS.component_dim] new_ids = [create_id() for _ in range(count)] @@ -155,41 +223,35 @@ def _register_batch(new_fps: xr.DataArray, new_trs: xr.DataArray) -> tuple[Footp } ) - return Footprints.from_array(footprints, sparsify=False), Traces.from_array(traces) + return footprints, traces def _recompose( movie: xr.DataArray, fp_coords: Coordinates, tr_coords: Coordinates -) -> tuple[Footprint, Trace]: +) -> tuple[xr.DataArray, xr.DataArray]: # Reshape neighborhood to 2D matrix (time × space) - shape = movie.sum(dim=AXIS.frames_dim) > 0 - slice_ = Movie.from_array(movie.where(shape, 0, drop=True)) - - a, c = _nmf(slice_) + movie = movie.assign_coords({ax: movie[ax] for ax in AXIS.spatial_dims}) + shape = xr.DataArray( + np.sum(movie.transpose(AXIS.frames_dim, ...).data, axis=0) > 0, dims=AXIS.spatial_dims + ) + slice_ = movie.where(shape.as_numpy(), 0, drop=True) + R = slice_.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frames_dim) - slice_coords = slice_.array.reset_index(AXIS.frames_dim).reset_coords(drop=True).coords + a, c, error = rank1nmf(R.values, np.mean(R.values, axis=1)) a_new, c_new = _reshape( footprint=a, trace=c, fp_coords=fp_coords, tr_coords=tr_coords, - slice_coords=slice_coords, + slice_coords=slice_.coords, ) - return a_new, c_new - - -def _nmf(movie: Movie) -> tuple[np.ndarray, np.ndarray]: - - stacked = movie.array.stack({"space": AXIS.spatial_dims}).transpose(AXIS.frames_dim, "space") - # Apply NMF (check how long nndsvd takes vs random) - model = NMF(n_components=1, init="random", tol=1e-4, max_iter=200) + factor = slice_.data.max() / c_new.data.max() + a_new = a_new / factor + c_new = c_new * factor - c = model.fit_transform(stacked) # temporal component - a = model.components_ # spatial component - - return a, c + return a_new, c_new def _reshape( @@ -198,7 +260,7 @@ def _reshape( fp_coords: Coordinates, tr_coords: Coordinates, slice_coords: Coordinates, -) -> tuple[Footprint, Trace]: +) -> tuple[xr.DataArray, xr.DataArray]: """Convert back to xarray with proper dimensions and coordinates""" c_new = xr.DataArray(trace.squeeze(), dims=[AXIS.frames_dim], coords=tr_coords) @@ -215,7 +277,7 @@ def _reshape( coords=slice_coords, ) - return Footprint.from_array(a_new, sparsify=False), Trace.from_array(c_new) + return a_new, c_new def _merge_with( @@ -224,15 +286,20 @@ def _merge_with( target_fps: xr.DataArray, target_trs: xr.DataArray, dupe_ids: Iterable[Hashable], -) -> tuple[Footprint, Trace]: +) -> tuple[xr.DataArray, xr.DataArray]: target_fp = target_fps.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: dupe_ids}) target_tr = target_trs.set_xindex(AXIS.id_coord).sel({AXIS.id_coord: dupe_ids}) - recreated_movie = target_fp @ target_tr.dropna(dim=AXIS.frames_dim) - new_movie = new_fp @ new_tr - combined_movie = recreated_movie + new_movie - combined_movie = combined_movie.assign_coords( - {ax: combined_movie[ax] for ax in AXIS.spatial_dims} + recreated_movie = np.matmul( + target_fp.transpose(*AXIS.spatial_dims, ...).data, + target_tr.dropna(dim=AXIS.frames_dim).data, + ) + new_movie = np.matmul( + new_fp.transpose(*AXIS.spatial_dims, ...).data, + new_tr.dropna(dim=AXIS.frames_dim).data, + ) + combined_movie = xr.DataArray( + recreated_movie + new_movie, dims=[*AXIS.spatial_dims, AXIS.frames_dim] ) a_new, c_new = _recompose( @@ -240,10 +307,10 @@ def _merge_with( new_fp.isel({AXIS.component_dim: 0}).coords, new_tr.isel({AXIS.component_dim: 0}).coords, ) - a_new.array.attrs["replaces"] = target_fp[AXIS.id_coord].values.tolist() - c_new.array.attrs["replaces"] = target_tr[AXIS.id_coord].values.tolist() + a_new.attrs["replaces"] = target_fp[AXIS.id_coord].values.tolist() + c_new.attrs["replaces"] = target_tr[AXIS.id_coord].values.tolist() - return _register(a_new.array, c_new.array) + return _register(a_new, c_new) def _expand_boundary(mask: xr.DataArray) -> xr.DataArray: @@ -270,73 +337,17 @@ def _merge( fps = footprints.sel({AXIS.component_dim: group}) trs = traces.sel({AXIS.component_dim: group}) if len(group) > 1: - res = fps @ trs + res = xr.DataArray( + np.matmul(fps.transpose(*AXIS.spatial_dims, ...).data, trs.data), + dims=[*AXIS.spatial_dims, AXIS.frames_dim], + ) new_fp, new_tr = _recompose(res, footprints[0].coords, traces[0].coords) else: - new_fp, new_tr = Footprint.from_array(fps[0], sparsify=False), Trace.from_array(trs[0]) + new_fp, new_tr = fps[0], trs[0] combined_fps.append(new_fp) combined_trs.append(new_tr) - new_fps = xr.concat([fp.array for fp in combined_fps], dim=AXIS.component_dim) - new_trs = xr.concat([tr.array for tr in combined_trs], dim=AXIS.component_dim) + new_fps = xr.concat(combined_fps, dim=AXIS.component_dim) + new_trs = xr.concat(combined_trs, dim=AXIS.component_dim) return new_fps, new_trs - - -def _absorb( - new_fps: xr.DataArray, - new_trs: xr.DataArray, - known_fps: xr.DataArray, - known_trs: xr.DataArray, - merge_matrix: xr.DataArray, -) -> tuple[xr.DataArray, xr.DataArray]: - footprints = [] - traces = [] - - merge_matrix.data = label(merge_matrix, background=0, connectivity=1) - merge_matrix = merge_matrix.assign_coords( - {AXIS.component_dim: range(merge_matrix.sizes[AXIS.component_dim])} - ).reset_index(AXIS.component_dim) - indep_idxs = ( - merge_matrix.where(merge_matrix.sum(f"{AXIS.component_dim}'") == 0, drop=True)[ - AXIS.component_dim - ].values - if known_fps is not None - else np.array(range(len(merge_matrix))) - ) - if indep_idxs.size > 0: - fps, trs = _register_batch( - new_fps=new_fps.isel({AXIS.component_dim: indep_idxs}), - new_trs=new_trs.isel({AXIS.component_dim: indep_idxs}), - ) - footprints.append(fps.array) - traces.append(trs.array) - - num = merge_matrix.max().item() - if num > 0 and known_fps is not None: - for lbl in range(1, num + 1): - new_idxs, _known_idxs = np.where(merge_matrix == lbl) - known_ids = merge_matrix.where(merge_matrix == lbl, drop=True)[ - f"{AXIS.id_coord}'" - ].values - fp = new_fps.sel({AXIS.component_dim: list(set(new_idxs))}) - tr = new_trs.sel({AXIS.component_dim: list(set(new_idxs))}) - footprint, trace = _merge_with(fp, tr, known_fps, known_trs, known_ids) - - footprints.append(footprint.array) - traces.append(trace.array) - - footprints = xr.concat( - footprints, - dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.detect_coord], - combine_attrs=combine_attr_replaces, - ) - traces = xr.concat( - traces, - dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.detect_coord], - combine_attrs=combine_attr_replaces, - ) - - return footprints, traces From 6154c61018d777f670082631749db499bc128a88 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 22:32:00 -0700 Subject: [PATCH 16/47] feat: register batch instead of loop --- src/cala/nodes/merge.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/cala/nodes/merge.py b/src/cala/nodes/merge.py index 35e690d0..cd546714 100644 --- a/src/cala/nodes/merge.py +++ b/src/cala/nodes/merge.py @@ -8,7 +8,7 @@ from cala.assets import Footprints, Overlaps, Traces from cala.models import AXIS -from cala.nodes.detect.catalog import _recompose, _register +from cala.nodes.detect.catalog import _recompose, _register_batch from cala.util import combine_attr_replaces @@ -68,12 +68,15 @@ def merge_existing( continue fps = target_fps.isel({AXIS.component_dim: group}) trs = target_trs.isel({AXIS.component_dim: group}) - res = fps @ trs + res = xr.DataArray( + np.matmul(fps.transpose(*AXIS.spatial_dims, ...).data, trs.data), + dims=[*AXIS.spatial_dims, AXIS.frames_dim], + ) a_new, c_new = _recompose(res, target_fps[0].coords, target_trs[0].coords) - a_new, c_new = _register(a_new.array, c_new.array) - a_new.array.attrs["replaces"] = fps[AXIS.id_coord].values.tolist() - c_new.array.attrs["replaces"] = trs[AXIS.id_coord].values.tolist() + a_new.attrs["replaces"] = fps[AXIS.id_coord].values.tolist() + c_new.attrs["replaces"] = trs[AXIS.id_coord].values.tolist() + combined_fps.append(a_new) combined_trs.append(c_new) @@ -81,17 +84,18 @@ def merge_existing( return Footprints(), Traces() new_fps = xr.concat( - [fp.array for fp in combined_fps], + combined_fps, dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.detect_coord], combine_attrs=combine_attr_replaces, ) new_trs = xr.concat( - [tr.array for tr in combined_trs], + combined_trs, dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.detect_coord], combine_attrs=combine_attr_replaces, ) + new_fps, new_trs = _register_batch(new_fps, new_trs) return Footprints.from_array(new_fps), Traces.from_array(new_trs) From e939cc110025225ddd097b17a89584f2cc6c3b72 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 22:33:06 -0700 Subject: [PATCH 17/47] test: refit tests --- src/cala/assets.py | 3 +- tests/test_iter/test_detect.py | 16 ++++++---- tests/test_iter/test_footprints.py | 2 +- tests/test_iter/test_residual.py | 47 ++++++++++++++++++++++++++---- 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 8d8154b2..203dfcb1 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -107,7 +107,8 @@ class Footprints(Asset): @Asset.array.setter def array(self, array: xr.DataArray) -> None: - array.validate.against_schema(self._entity.model) + if self.validate_schema: + array.validate.against_schema(self._entity.model) if array is not None and isinstance(array.data, np.ndarray): array.data = COO.from_numpy(array.data) self.array_ = array diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 9c14fa54..5db6a5f3 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -28,7 +28,13 @@ def cataloger(): spec=NodeSpecification( id="test", type="cala.nodes.detect.Cataloger", - params={"age_limit": 100, "smooth_kwargs": {"sigma": 2}, "merge_threshold": 0.8}, + params={ + "age_limit": 100, + "smooth_kwargs": {"sigma": 2}, + "merge_threshold": 0.8, + "val_threshold": 0.5, + "cnt_threshold": 5, + }, ) ) @@ -83,8 +89,8 @@ def test_register(self, cataloger, new_component): new_fp, new_tr = new_component fp, tr = _register(new_fp=new_fp[0].array, new_tr=new_tr[0].array) - assert np.array_equal(fp.array.as_numpy(), new_fp[0].array.as_numpy()) - assert np.array_equal(tr.array, new_tr[0].array) + assert np.array_equal(fp.as_numpy(), new_fp[0].array.as_numpy()) + assert np.array_equal(tr, new_tr[0].array) def test_merge_with(self, slice_nmf, cataloger, single_cell): new_component = slice_nmf.process( @@ -101,9 +107,7 @@ def test_merge_with(self, slice_nmf, cataloger, single_cell): ) movie_result = ( - (fp.array @ tr.array) - .reset_coords([AXIS.id_coord, AXIS.detect_coord], drop=True) - .as_numpy() + (fp @ tr).reset_coords([AXIS.id_coord, AXIS.detect_coord], drop=True).as_numpy() ) movie_new_comp = new_fp[0].array @ new_tr[0].array diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index 637ba04c..58a058c5 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -80,7 +80,7 @@ def test_ingest_frame(fpter, toy, request): result = fpter.process( footprints=toy.footprints, pixel_stats=pixstats, component_stats=compstats, index=0 - ).array + ).array.as_numpy() expected = toy.footprints.array.as_numpy() diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 08a18645..e0804268 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -3,9 +3,34 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import Residual +from cala.assets import Residual, Footprints, Traces, Frame from cala.models.axis import AXIS from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints +from cala.testing.toy import Toy, FrameDims, Position + + +@pytest.fixture +def connected_cells() -> Toy: + n_frames = 50 + + return Toy( + n_frames=n_frames, + frame_dims=FrameDims(width=50, height=50), + cell_radii=8, + cell_positions=[ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + ], + cell_traces=[ + np.random.randint(low=0, high=n_frames, size=n_frames).astype(float), + np.abs(np.sin(np.linspace(-np.pi, np.pi, n_frames)) * n_frames).astype(float), + np.array(range(n_frames), dtype=float), + np.array(range(n_frames - 1, -1, -1), dtype=float), + ], + detected_ons=[n_frames - 1] * 4, + ) @pytest.fixture(scope="function") @@ -17,12 +42,22 @@ def init() -> Node: ) -def test_init(init, separate_cells) -> None: +def test_init(init, connected_cells) -> None: + residual = Residual() + gen = connected_cells.movie_gen() + + for _ in range(connected_cells.n_frames - 1): + init.process( + residuals=residual, + frame=Frame.from_array(next(gen)), + footprints=Footprints(), + traces=Traces(), + ) result = init.process( - residuals=Residual(), - footprints=separate_cells.footprints, - traces=separate_cells.traces, - frames=separate_cells.make_movie(), + residuals=residual, + footprints=connected_cells.footprints, + traces=connected_cells.traces, + frame=Frame.from_array(next(gen)), ) assert np.all(result.array == 0) From dbbd349ce3204a577e993a470f0cf0c81844e514 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 8 Oct 2025 23:42:10 -0700 Subject: [PATCH 18/47] test: refit test yamls --- tests/data/pipelines/odl.yaml | 4 +++- tests/data/pipelines/with_src.yaml | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index bb50478e..a631c3e2 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -104,7 +104,7 @@ nodes: params: n_recalc: 1 depends: - - frames: assets.buffer + - frame: glow.frame - footprints: footprints_frame.footprints - traces: assets.traces - residuals: assets.residuals @@ -127,6 +127,8 @@ nodes: smooth_kwargs: sigma: 2 merge_threshold: 0.95 + val_threshold: 0.5 + cnt_threshold: 5 depends: - new_fps: nmf.new_fps - new_trs: nmf.new_trs diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index be2fc8e9..08878e65 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -145,7 +145,7 @@ nodes: params: n_recalc: 4 depends: - - frames: assets.buffer + - frame: glow.frame - footprints: footprints_frame.footprints - traces: assets.traces - residuals: assets.residuals @@ -168,6 +168,8 @@ nodes: smooth_kwargs: sigma: 2 merge_threshold: 0.95 + val_threshold: 0.5 + cnt_threshold: 5 depends: - new_fps: nmf.new_fps - new_trs: nmf.new_trs From 832c3c2938921d8d458ee504428ce0d25f590506 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 9 Oct 2025 01:40:29 -0700 Subject: [PATCH 19/47] feat: more optimizations / simplifications feat: merge handling in overlaps --- src/cala/assets.py | 6 +++-- src/cala/nodes/buffer.py | 4 ++-- src/cala/nodes/detect/update.py | 6 +++-- src/cala/nodes/overlap.py | 38 +++++++++++++++++++++++--------- src/cala/nodes/traces.py | 9 ++------ src/cala/util.py | 31 ++++++++++++++++++++++++++ tests/test_iter/test_overlaps.py | 2 +- 7 files changed, 71 insertions(+), 25 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 203dfcb1..02a3dfb0 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -138,7 +138,8 @@ def array(self) -> xr.DataArray: @array.setter def array(self, array: xr.DataArray) -> None: if self.zarr_path: - array.validate.against_schema(self._entity.model) + if self.validate_schema: + array.validate.against_schema(self._entity.model) array.to_zarr(self.zarr_path, mode="w") # need to make sure it can overwrite else: self.array_ = array @@ -181,7 +182,8 @@ def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Dat ) def update(self, array: xr.DataArray, **kwargs: Any) -> None: - array.validate.against_schema(self._entity.model) + if self.validate_schema: + array.validate.against_schema(self._entity.model) array.to_zarr(self.zarr_path, **kwargs) @classmethod diff --git a/src/cala/nodes/buffer.py b/src/cala/nodes/buffer.py index 3d897a62..626dd2ff 100644 --- a/src/cala/nodes/buffer.py +++ b/src/cala/nodes/buffer.py @@ -9,8 +9,8 @@ def fill_buffer(size: int, buffer: Movie, frame: Frame) -> A[Movie, Name("buffer")]: if buffer.array is None: - buffer.array = frame.array.expand_dims(AXIS.frames_dim).assign_coords( - {AXIS.timestamp_coord: (AXIS.frames_dim, [frame.array[AXIS.timestamp_coord].item()])} + buffer.array = frame.array.volumize.dim_with_coords( + dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord] ) return buffer diff --git a/src/cala/nodes/detect/update.py b/src/cala/nodes/detect/update.py index a5bd25b2..8233e0ac 100644 --- a/src/cala/nodes/detect/update.py +++ b/src/cala/nodes/detect/update.py @@ -5,7 +5,7 @@ from cala.assets import CompStats, Footprints, Movie, Overlaps, PixStats, Traces from cala.nodes.component_stats import ingest_component as update_component_stats from cala.nodes.footprints import ingest_component as update_footprints -from cala.nodes.overlap import initialize as update_overlap +from cala.nodes.overlap import ingest_component as update_overlap from cala.nodes.pixel_stats import ingest_component as update_pixel_stats from cala.nodes.traces import ingest_component as update_traces @@ -34,7 +34,9 @@ def update_assets( updated_component_stats = update_component_stats( component_stats=component_stats, traces=traces, new_traces=new_traces ) - updated_overlaps = update_overlap(overlaps=overlaps, footprints=updated_shapes) + updated_overlaps = update_overlap( + overlaps=overlaps, footprints=footprints, new_footprints=new_footprints + ) return ( updated_traces, diff --git a/src/cala/nodes/overlap.py b/src/cala/nodes/overlap.py index a04ea94d..374b6b63 100644 --- a/src/cala/nodes/overlap.py +++ b/src/cala/nodes/overlap.py @@ -3,6 +3,7 @@ from cala.assets import Footprints, Overlaps from cala.models import AXIS +from cala.util import sp_matmul def initialize(overlaps: Overlaps, footprints: Footprints) -> Overlaps: @@ -11,9 +12,9 @@ def initialize(overlaps: Overlaps, footprints: Footprints) -> Overlaps: if A is None: return overlaps - V = (A @ A.rename(AXIS.component_rename)) > 0 + V = sp_matmul(left=A, dim=AXIS.component_dim, rename_map=AXIS.component_rename) - overlaps.array = V + overlaps.array = V > 0 return overlaps @@ -38,37 +39,52 @@ def ingest_component( V = overlaps.array - a_new = new_footprints.array.volumize.dim_with_coords( - dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.detect_coord] - ) + a_new = new_footprints.array.transpose(AXIS.component_dim, ...) + + merged_ids = a_new.attrs.get("replaces", []) + intact_ids = [id_ for id_ in V[AXIS.id_coord].values if id_ not in merged_ids] + + if merged_ids: + V = ( + V.set_xindex(AXIS.id_coord) + .set_xindex(f"{AXIS.id_coord}'") + .sel({AXIS.id_coord: intact_ids, f"{AXIS.id_coord}'": intact_ids}) + .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) + ) if a_new[AXIS.id_coord].item() in V[AXIS.id_coord].values: # trace REPLACEMENT dim_idx = np.where(V[AXIS.id_coord].values == a_new[AXIS.id_coord].item())[0].tolist() V = V.drop_sel({AXIS.component_dim: dim_idx, f"{AXIS.component_dim}'": dim_idx}) - # think i also have to remove the ID from A, + # Also have to remove the ID from A, # since it's been already added in footprints.component_ingest A = footprints.array id_idx = np.where(A[AXIS.id_coord].values == a_new[AXIS.id_coord].item())[0].tolist() A = A.drop_sel({AXIS.component_dim: id_idx}) # Compute spatial overlaps between new and existing components - bottom_left_overlap = A @ a_new.rename(AXIS.component_rename) - top_right_overlap = A.rename(AXIS.component_rename) @ a_new + bl_overlap = sp_matmul( + left=A, right=a_new, dim=AXIS.component_dim, rename_map=AXIS.component_rename + ) + tr_overlap = xr.DataArray( + bl_overlap.data, + dims=[f"{AXIS.component_dim}'", AXIS.component_dim], + coords=a_new[AXIS.component_dim].coords, + ).assign_coords(A[AXIS.component_dim].rename(AXIS.component_rename).coords) # Compute overlaps between new components themselves - new_overlaps = a_new @ a_new.rename(AXIS.component_rename) + new_overlaps = sp_matmul(left=a_new, dim=AXIS.component_dim, rename_map=AXIS.component_rename) # Construct the new overlap matrix by blocks # [existing_overlaps og_new_overlaps.T] # [og_new_overlaps new_overlaps ] # First concatenate horizontally: [existing_overlaps, old_new_overlaps] - top_block = xr.concat([V.astype(float), top_right_overlap], dim=AXIS.component_dim) + top_block = xr.concat([V.astype(float), tr_overlap], dim=AXIS.component_dim) # Then concatenate vertically with [new_overlaps, new_overlaps] - bottom_block = xr.concat([bottom_left_overlap, new_overlaps], dim=AXIS.component_dim) + bottom_block = xr.concat([bl_overlap, new_overlaps], dim=AXIS.component_dim) # Finally combine top and bottom blocks updated_overlaps = xr.concat([top_block, bottom_block], dim=f"{AXIS.component_dim}'") diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index ffbb65cb..5ceb3132 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -72,13 +72,8 @@ def ingest_frame( ) if traces.zarr_path: - updated_tr = updated_traces.expand_dims(AXIS.frames_dim).assign_coords( - { - AXIS.timestamp_coord: ( - AXIS.frames_dim, - [updated_traces[AXIS.timestamp_coord].values], - ) - } + updated_tr = updated_traces.volumize.dim_with_coords( + dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord] ) traces.update(updated_tr, append_dim=AXIS.frames_dim) else: diff --git a/src/cala/util.py b/src/cala/util.py index 0251748e..cedfe255 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -3,6 +3,9 @@ from shutil import rmtree from uuid import uuid4 +import numpy as np +import xarray as xr + def create_id() -> str: return uuid4().hex @@ -19,3 +22,31 @@ def clear_dir(directory: Path | str) -> None: path.unlink() elif path.is_dir(): rmtree(path) + + +def sp_matmul( + left: xr.DataArray, dim: str, rename_map: dict, right: xr.DataArray | None = None +) -> xr.DataArray: + """ + Faster than xarray @ (for sparse arrays). The syntax is complicated enough that I'm making a + utility function + + :param left: + :param dim: + :param rename_map: + :param right: + """ + + left = left.transpose(dim, ...) + + if right is None: + right = left + else: + right = right.transpose(dim, ...) + val = np.matmul( + np.reshape(left.data, (left.sizes[dim], -1)), + np.reshape(right.data, (right.sizes[dim], -1)).T, + ) + return xr.DataArray(val, dims=[dim, f"{dim}'"], coords=left[dim].coords).assign_coords( + right[dim].rename(rename_map).coords + ) diff --git a/tests/test_iter/test_overlaps.py b/tests/test_iter/test_overlaps.py index f891a9da..31c18e96 100644 --- a/tests/test_iter/test_overlaps.py +++ b/tests/test_iter/test_overlaps.py @@ -37,7 +37,7 @@ def comp_update() -> Node: def test_ingest_component(init, comp_update, toy, request) -> None: toy = request.getfixturevalue(toy) base = Footprints.from_array(toy.footprints.array.isel({AXIS.component_dim: slice(None, -1)})) - new = Footprint.from_array(toy.footprints.array.isel({AXIS.component_dim: -1})) + new = Footprint.from_array(toy.footprints.array.isel({AXIS.component_dim: [-1]})) pre_ingest = init.process(overlaps=Overlaps(), footprints=base) From e7b12c32ce91e97e866369e7577eca683ac98842 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 9 Oct 2025 01:45:18 -0700 Subject: [PATCH 20/47] format: ruff --- src/cala/assets.py | 2 +- src/cala/nodes/footprints.py | 4 ++-- src/cala/nodes/residual.py | 2 +- src/cala/nodes/traces.py | 4 ++-- src/cala/util.py | 5 +---- tests/test_iter/test_residual.py | 4 ++-- 6 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 02a3dfb0..a5e99388 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -2,7 +2,7 @@ import shutil from copy import deepcopy from pathlib import Path -from typing import Any, ClassVar, TypeVar, Self +from typing import Any, ClassVar, Self, TypeVar import numpy as np import xarray as xr diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 9132438a..7e67fff0 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -91,7 +91,7 @@ def update_shapes( A_mask: list[np.ndarray], Ab_dense: np.ndarray | None = None, iters: int = 5, -): +) -> tuple[csc_matrix, list[np.ndarray], np.ndarray]: """ :param CY: suff stats (pixel,) :param CC: suff stats (component), shape (comp, comp) @@ -156,7 +156,7 @@ def _normalize( m: int, Ab: csc_matrix, Ab_dense: np.ndarray | None, - ind_A: list[int], + ind_A: list[np.ndarray], ind_pixels: int, tmp: np.ndarray, ) -> tuple[np.ndarray, csc_matrix, list[int]]: diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index c2d02fec..d2bab858 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -4,7 +4,7 @@ import xarray as xr from noob import Name -from cala.assets import Footprints, Movie, Residual, Traces, Frame +from cala.assets import Footprints, Frame, Residual, Traces from cala.models import AXIS diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index 5ceb3132..fb6275b0 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -150,7 +150,7 @@ def _update_traces( iters: int = 5, tol: float = 1e-3, groups: list[list[int]] = None, -): +) -> tuple[np.ndarray, np.ndarray]: """ Solve C = argmin_C ||Yr-AC|| using block-coordinate decent Parameters @@ -176,7 +176,7 @@ def _update_traces( C = noisyC.copy() # faster than np.linalg.norm - def norm(c): + def norm(c: np.ndarray) -> float: return np.sqrt(c.ravel().dot(c.ravel())) while (norm(C_old - C) >= tol * norm(C_old)) and (num_iters < iters): diff --git a/src/cala/util.py b/src/cala/util.py index cedfe255..0b6623ec 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -38,11 +38,8 @@ def sp_matmul( """ left = left.transpose(dim, ...) + right = left if right is None else right.transpose(dim, ...) - if right is None: - right = left - else: - right = right.transpose(dim, ...) val = np.matmul( np.reshape(left.data, (left.sizes[dim], -1)), np.reshape(right.data, (right.sizes[dim], -1)).T, diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index e0804268..41e20018 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -3,10 +3,10 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import Residual, Footprints, Traces, Frame +from cala.assets import Footprints, Frame, Residual, Traces from cala.models.axis import AXIS from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints -from cala.testing.toy import Toy, FrameDims, Position +from cala.testing.toy import FrameDims, Position, Toy @pytest.fixture From 961ef203611591bb7e21b22fe2b85a5f6d6221c6 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 9 Oct 2025 01:47:06 -0700 Subject: [PATCH 21/47] test: no more boundary expansion --- tests/test_iter/test_footprints.py | 118 ++++++++++++++--------------- 1 file changed, 58 insertions(+), 60 deletions(-) diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index 58a058c5..6ec6fdfd 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -2,11 +2,7 @@ import pytest import xarray as xr from noob.node import Node, NodeSpecification -from scipy.ndimage import grey_dilation, grey_erosion -from skimage.morphology import disk -from cala.assets import Footprints -from cala.models.axis import AXIS from cala.testing.toy import FrameDims, Position, Toy @@ -98,59 +94,61 @@ def xpander() -> Node: ) -@pytest.mark.parametrize("defect", [grey_erosion, grey_dilation]) -@pytest.mark.parametrize("toy", ["separate_cells", "connected_cells"]) -def test_boundary_morph(xpander, defect, toy, request): - """ - what would be the circumstances of needing boundary expansion: - existing footprint is too small. - does not affect component_stats - pixel_stats may care. if the correlation with a component and a pixel is high, - the pixel_stat would be high. - boundary-expansion would be literally trying to add pixel_stats (normalized) around the - boundary of the current footprint. (basically how many times pixel and trace coincided) - the thing is, pixel_stat never goes below zero. so you're always sort of adding the boundary - pixels. - this means this phenomenon needs to be regulated by another mechanism, i.e. pixel - value going to zero somehow. - it would be pretty hard to ensure the cell boundary does not forever expand, since - the longer the video, the more coincidences with any pixel and any trace will occur, - so expansion is almost guaranteed every single loop. - this means after the expansion, we need to rely on removal of "overexpanded" pixels. - does that occur naturally at W - AM? - Not exactly. - - W: width height comp, A: width height comp M: comp comp - M: avg dot product of traces - AM: footprint x how correlated other cells are - """ - toy = request.getfixturevalue(toy) - - pixstats = Node.from_specification( - NodeSpecification(id="test_pixstats", type="cala.nodes.pixel_stats.initialize") - ).process(traces=toy.traces, frames=toy.make_movie()) - compstats = Node.from_specification( - NodeSpecification(id="test_compstats", type="cala.nodes.component_stats.initialize") - ).process(traces=toy.traces) - - footprint = disk(radius=1) - - modded_fps = xr.apply_ufunc( - defect, - toy.footprints.array.as_numpy(), - kwargs={"footprint": footprint}, - vectorize=True, - input_core_dims=[AXIS.spatial_dims], - output_core_dims=[AXIS.spatial_dims], - ) - - result = xpander.process( - footprints=Footprints.from_array(modded_fps), - pixel_stats=pixstats, - component_stats=compstats, - index=0, - ) - - # expansion breaks when a trace is all-zero and overlaps with another component. - # not sure when an all-zero trace would occur (esp with noise), so probably ok. - xr.testing.assert_allclose(result.array.as_numpy(), toy.footprints.array.as_numpy(), atol=1e-3) +# @pytest.mark.parametrize("defect", [grey_erosion, grey_dilation]) +# @pytest.mark.parametrize("toy", ["separate_cells", "connected_cells"]) +# def test_boundary_morph(xpander, defect, toy, request): +# """ +# what would be the circumstances of needing boundary expansion: +# existing footprint is too small. +# does not affect component_stats +# pixel_stats may care. if the correlation with a component and a pixel is high, +# the pixel_stat would be high. +# boundary-expansion would be literally trying to add pixel_stats (normalized) around the +# boundary of the current footprint. (basically how many times pixel and trace coincided) +# the thing is, pixel_stat never goes below zero. so you're always sort of adding the boundary +# pixels. +# this means this phenomenon needs to be regulated by another mechanism, i.e. pixel +# value going to zero somehow. +# it would be pretty hard to ensure the cell boundary does not forever expand, since +# the longer the video, the more coincidences with any pixel and any trace will occur, +# so expansion is almost guaranteed every single loop. +# this means after the expansion, we need to rely on removal of "overexpanded" pixels. +# does that occur naturally at W - AM? +# Not exactly. +# +# W: width height comp, A: width height comp M: comp comp +# M: avg dot product of traces +# AM: footprint x how correlated other cells are +# """ +# toy = request.getfixturevalue(toy) +# +# pixstats = Node.from_specification( +# NodeSpecification(id="test_pixstats", type="cala.nodes.pixel_stats.initialize") +# ).process(traces=toy.traces, frames=toy.make_movie()) +# compstats = Node.from_specification( +# NodeSpecification(id="test_compstats", type="cala.nodes.component_stats.initialize") +# ).process(traces=toy.traces) +# +# footprint = disk(radius=1) +# +# modded_fps = xr.apply_ufunc( +# defect, +# toy.footprints.array.as_numpy(), +# kwargs={"footprint": footprint}, +# vectorize=True, +# input_core_dims=[AXIS.spatial_dims], +# output_core_dims=[AXIS.spatial_dims], +# ) +# +# result = xpander.process( +# footprints=Footprints.from_array(modded_fps), +# pixel_stats=pixstats, +# component_stats=compstats, +# index=0, +# ) +# +# # expansion breaks when a trace is all-zero and overlaps with another component. +# # not sure when an all-zero trace would occur (esp with noise), so probably ok. +# xr.testing.assert_allclose( +# result.array.as_numpy(), toy.footprints.array.as_numpy(), atol=1e-3 +# ) From 339d79d58fd4baf75720d6d08b29a7a688ab520b Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 9 Oct 2025 03:21:36 -0700 Subject: [PATCH 22/47] debug: residual size limit --- src/cala/nodes/residual.py | 9 +++++++-- tests/data/pipelines/odl.yaml | 1 + tests/data/pipelines/with_src.yaml | 21 +++++++++++---------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index d2bab858..76aa29b9 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -9,7 +9,12 @@ def build( - residuals: Residual, frame: Frame, footprints: Footprints, traces: Traces, n_recalc: int + residuals: Residual, + frame: Frame, + footprints: Footprints, + traces: Traces, + size: int, + n_recalc: int, ) -> A[Residual, Name("movie")]: """ The computation follows the equation: @@ -55,7 +60,7 @@ def build( # if recently discovered, set to zero (or a small number). otherwise, just append R = _update(Y, A, C, residuals.array, n_recalc=n_recalc) - residuals.array = R # clipping is for the first n frames + residuals.array = R.isel({AXIS.frames_dim: slice(-size, None)}) return residuals diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index a631c3e2..842729e5 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -102,6 +102,7 @@ nodes: residual: type: cala.nodes.residual.build params: + size: 100 n_recalc: 1 depends: - frame: glow.frame diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 08878e65..61c8f871 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -34,16 +34,16 @@ nodes: type: cala.nodes.io.stream params: files: - - cala/msCam1.avi - - cala/msCam2.avi - - cala/msCam3.avi - - cala/msCam4.avi - - cala/msCam5.avi - - cala/msCam6.avi - - cala/msCam7.avi - - cala/msCam8.avi - - cala/msCam9.avi - - cala/msCam10.avi + - minian/msCam1.avi + - minian/msCam2.avi + - minian/msCam3.avi + - minian/msCam4.avi + - minian/msCam5.avi + - minian/msCam6.avi + - minian/msCam7.avi + - minian/msCam8.avi + - minian/msCam9.avi + - minian/msCam10.avi counter: type: cala.nodes.prep.counter frame: @@ -143,6 +143,7 @@ nodes: residual: type: cala.nodes.residual.build params: + size: 100 n_recalc: 4 depends: - frame: glow.frame From 0fb2841e8dc567116404f1fcc292f6e6d64f6b08 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 9 Oct 2025 03:21:56 -0700 Subject: [PATCH 23/47] debug: overlap update in case of multiple new comps --- src/cala/nodes/overlap.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cala/nodes/overlap.py b/src/cala/nodes/overlap.py index 374b6b63..7c3a85c8 100644 --- a/src/cala/nodes/overlap.py +++ b/src/cala/nodes/overlap.py @@ -52,15 +52,15 @@ def ingest_component( .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) ) - if a_new[AXIS.id_coord].item() in V[AXIS.id_coord].values: - # trace REPLACEMENT - dim_idx = np.where(V[AXIS.id_coord].values == a_new[AXIS.id_coord].item())[0].tolist() + dupli = a_new[AXIS.id_coord].isin(V[AXIS.id_coord]) + if np.any(dupli): + dim_idx = np.where(dupli)[0].tolist() V = V.drop_sel({AXIS.component_dim: dim_idx, f"{AXIS.component_dim}'": dim_idx}) # Also have to remove the ID from A, # since it's been already added in footprints.component_ingest A = footprints.array - id_idx = np.where(A[AXIS.id_coord].values == a_new[AXIS.id_coord].item())[0].tolist() + id_idx = np.where(A[AXIS.id_coord].isin(a_new[AXIS.id_coord]))[0].tolist() A = A.drop_sel({AXIS.component_dim: id_idx}) # Compute spatial overlaps between new and existing components From 9e91d44237a92616afc8f093c713128df1b04a64 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 9 Oct 2025 10:12:33 -0700 Subject: [PATCH 24/47] test: refit residual test --- tests/test_iter/test_residual.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 41e20018..e2761ac1 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -37,7 +37,9 @@ def connected_cells() -> Toy: def init() -> Node: return Node.from_specification( spec=NodeSpecification( - id="res_init_test", type="cala.nodes.residual.build", params={"n_recalc": 5} + id="res_init_test", + type="cala.nodes.residual.build", + params={"size": 100, "n_recalc": 5}, ) ) From cbd1497f947158a4b11a8b12c0540a3648891242 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 13 Oct 2025 10:10:54 -0700 Subject: [PATCH 25/47] test: performance test setup --- pdm.lock | 45 ++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 ++ src/cala/main.py | 2 +- 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/pdm.lock b/pdm.lock index 2b2d1c6f..077851f2 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "gui", "tests"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:9cee290ad8bb089564778c0a70e9f1bc6397ab0ae54bc5644a0937aadc59371a" +content_hash = "sha256:0f4f59018aa95cc2a3cda96a551bd3b39af2b5c70a9c8863f2242be07808af2d" [[metadata.targets]] requires_python = ">=3.11" @@ -3433,6 +3433,19 @@ files = [ {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, ] +[[package]] +name = "tuna" +version = "0.5.11" +requires_python = ">=3.6" +summary = "Visualize Python performance profiles" +dependencies = [ + "importlib-metadata; python_version < \"3.8\"", +] +files = [ + {file = "tuna-0.5.11-py3-none-any.whl", hash = "sha256:ab352a6d836014ace585ecd882148f1f7c68be9ea4bf9e9298b7127594dab2ef"}, + {file = "tuna-0.5.11.tar.gz", hash = "sha256:d47f3e39e80af961c8df016ac97d1643c3c60b5eb451299da0ab5fe411d8866c"}, +] + [[package]] name = "twine" version = "6.1.0" @@ -3761,6 +3774,36 @@ files = [ {file = "xarray_validate-0.0.2.tar.gz", hash = "sha256:4c8ae78a0cbe0719aad9e47ad5574f94474fc09afc6c197f6c947bc83d791798"}, ] +[[package]] +name = "yappi" +version = "1.6.10" +requires_python = ">=3.6" +summary = "Yet Another Python Profiler" +files = [ + {file = "yappi-1.6.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:20b8289e8cca781e948f72d86c03b308e077abeec53ec60080f77319041e0511"}, + {file = "yappi-1.6.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4bc9a30b162cb0e13d6400476fa05c272992bd359592e9bba1a570878d9e155c"}, + {file = "yappi-1.6.10-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40aa421ea7078795ed2f0e6bae3f8f64f6cd5019c885a12c613b44dd1fc598b4"}, + {file = "yappi-1.6.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0d62741c0ac883067e40481ab89ddd9e004292dbd22ac5992cf45745bf28ccc3"}, + {file = "yappi-1.6.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1cf46ebe43ac95f8736618a5f0ac763c7502a3aa964a1dda083d9e9c1bf07b12"}, + {file = "yappi-1.6.10-cp311-cp311-win32.whl", hash = "sha256:ff3688aa99b08ee10ced478b7255ac03865a8b5c0677482056acfe4d4f56e45f"}, + {file = "yappi-1.6.10-cp311-cp311-win_amd64.whl", hash = "sha256:4bd4f820e84d823724b8de4bf6857025e9e6c953978dd32485e054cf7de0eda7"}, + {file = "yappi-1.6.10-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:32c6d928604d7a236090bc36d324f309fe8344c91123bb84e37c43f6677adddc"}, + {file = "yappi-1.6.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9683c40de7e4ddff225068032cd97a6d928e4beddd9c7cf6515325be8ac28036"}, + {file = "yappi-1.6.10-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:733a212014f2b44673ed62be53b3d4dd458844cd2008ba107f27a3293e42f43a"}, + {file = "yappi-1.6.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:7d80938e566ac6329daa3b036fdf7bd34488010efcf0a65169a44603878daa4e"}, + {file = "yappi-1.6.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:01705971b728a4f95829b723d08883c7623ec275f4066f4048b28dc0151fe0af"}, + {file = "yappi-1.6.10-cp312-cp312-win32.whl", hash = "sha256:8dd13a430b046e2921ddf63d992da97968724b41a03e68292f06a2afa11c9d6e"}, + {file = "yappi-1.6.10-cp312-cp312-win_amd64.whl", hash = "sha256:a50eb3aec893c40554f8f811d3341af266d844e7759f7f7abfcdba2744885ea3"}, + {file = "yappi-1.6.10-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:944df9ebc6b283d6591a6b5f4c586d0eb9c6131c915f1b20fb36127ade83720d"}, + {file = "yappi-1.6.10-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3736ea6458edbabd96918d88e2963594823e4ab4c58d62a52ef81f6b5839ec19"}, + {file = "yappi-1.6.10-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f27bbc3311a3662231cff395d38061683fac5c538f3bab6796ff05511d2cce43"}, + {file = "yappi-1.6.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:354cf94d659302b421b13c03487f2f1bce969b97b85fba88afb11f2ef83c35f3"}, + {file = "yappi-1.6.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1d82839835ae2c291b88fb56d82f80c88c00d76df29f3c1ed050db73b553bef0"}, + {file = "yappi-1.6.10-cp313-cp313-win32.whl", hash = "sha256:fc84074575afcc5a2a712e132c0b51541b7434b3099be99f573964ef3b6064a8"}, + {file = "yappi-1.6.10-cp313-cp313-win_amd64.whl", hash = "sha256:334b31dfefae02bc28b7cd50953aaaae3292e40c15efb613792e4a587281a161"}, + {file = "yappi-1.6.10.tar.gz", hash = "sha256:463b822727658937bd95a7d80ca9758605b8cd0014e004e9e520ec9cb4db0c92"}, +] + [[package]] name = "zarr" version = "3.0.8" diff --git a/pyproject.toml b/pyproject.toml index 704247a7..e4ad07fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,8 @@ tests = [ "snakeviz>=2.2.2", "pytest-profiling>=1.8.1", "pytest-timeout>=2.3.1", + "yappi>=1.6.10", + "tuna>=0.5.11", ] docs = [ "sphinx>=8.2.3", diff --git a/src/cala/main.py b/src/cala/main.py index 4c867efa..cd7004b7 100644 --- a/src/cala/main.py +++ b/src/cala/main.py @@ -21,7 +21,7 @@ @cli.command() def main(spec: str, gui: Annotated[bool, typer.Option()] = False) -> None: if gui: - uvicorn.run("cala.main:app", reload=True, reload_dirs=[Path(__file__).parent]) + uvicorn.run("cala.main:app", reload=False, reload_dirs=[Path(__file__).parent]) else: tube = Tube.from_specification(spec) runner = SynchronousRunner(tube=tube) From ebb433c400b5a4778e39219ac0277beb75954cd1 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 13 Oct 2025 10:13:21 -0700 Subject: [PATCH 26/47] feat: rename residual asset to buffer --- src/cala/assets.py | 13 +------------ src/cala/nodes/cleanup.py | 4 ++-- src/cala/nodes/detect/slice_nmf.py | 4 ++-- src/cala/nodes/residual.py | 15 ++++++++++++--- tests/test_iter/test_cleanup.py | 4 ++-- tests/test_iter/test_detect.py | 16 ++++++++-------- tests/test_iter/test_residual.py | 4 ++-- 7 files changed, 29 insertions(+), 31 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index a5e99388..44195262 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -293,18 +293,7 @@ class Overlaps(Asset): ) -class Residual(Asset): - """ - Computes and maintains a buffer of residual signals. - - This method implements the residual computation by subtracting the - reconstructed signal from the original data. It maintains only the - most recent frames as specified by the buffer length. - - The residual buffer contains the recent history of unexplained variance - in the data after accounting for known components. - """ - +class Buffer(Asset): _entity: ClassVar[Entity] = PrivateAttr( Group( name="frame", diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index b3a6d0ab..679a76df 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -5,12 +5,12 @@ import xarray as xr from noob import Name -from cala.assets import CompStats, Footprints, Overlaps, PixStats, Residual, Traces +from cala.assets import CompStats, Footprints, Overlaps, PixStats, Buffer, Traces from cala.models import AXIS def clear_overestimates( - footprints: Footprints, residuals: Residual, nmf_error: float + footprints: Footprints, residuals: Buffer, nmf_error: float ) -> A[Footprints, Name("footprints")]: """ Remove all sections of the footprints that cause negative residuals. diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 95747107..5acecd1c 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -9,7 +9,7 @@ from pydantic import Field, PrivateAttr from sklearn.decomposition import NMF -from cala.assets import Footprint, Residual, Trace +from cala.assets import Footprint, Buffer, Trace from cala.logging import init_logger from cala.models import AXIS @@ -34,7 +34,7 @@ def model_post_init(self, context: Any, /) -> None: self._model = NMF(**self.nmf_kwargs) def process( - self, residuals: Residual, detect_radius: int + self, residuals: Buffer, detect_radius: int ) -> tuple[A[list[Footprint], Name("new_fps")], A[list[Trace], Name("new_trs")]]: if residuals.array.sizes[AXIS.frames_dim] < self.min_frames: diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 76aa29b9..a0d6c9e8 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -4,19 +4,28 @@ import xarray as xr from noob import Name -from cala.assets import Footprints, Frame, Residual, Traces +from cala.assets import Footprints, Frame, Buffer, Traces from cala.models import AXIS def build( - residuals: Residual, + residuals: Buffer, frame: Frame, footprints: Footprints, traces: Traces, size: int, n_recalc: int, -) -> A[Residual, Name("movie")]: +) -> A[Buffer, Name("movie")]: """ + Computes and maintains a buffer of residual signals. + + This method implements the residual computation by subtracting the + reconstructed signal from the original data. It maintains only the + most recent frames as specified by the buffer length. + + The residual buffer contains the recent history of unexplained variance + in the data after accounting for known components. + The computation follows the equation: R_buf = [Y − [A, b][C; f]][:, t′ − l_b + 1 : t′] where: diff --git a/tests/test_iter/test_cleanup.py b/tests/test_iter/test_cleanup.py index ce9e235b..5a2ef226 100644 --- a/tests/test_iter/test_cleanup.py +++ b/tests/test_iter/test_cleanup.py @@ -1,12 +1,12 @@ import xarray as xr -from cala.assets import Residual +from cala.assets import Buffer from cala.models import AXIS from cala.nodes.cleanup import _filter_redundant, clear_overestimates def test_clear_overestimates(single_cell) -> None: - residual = Residual.from_array(single_cell.make_movie().array) + residual = Buffer.from_array(single_cell.make_movie().array) residual.array.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] *= -1 result = clear_overestimates( diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 5db6a5f3..a10d7997 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -4,7 +4,7 @@ from noob.node import NodeSpecification from sklearn.decomposition import NMF -from cala.assets import AXIS, Footprints, Residual, Traces +from cala.assets import AXIS, Footprints, Buffer, Traces from cala.nodes.detect import Cataloger, SliceNMF from cala.nodes.detect.catalog import _merge_with, _register from cala.nodes.detect.slice_nmf import rank1nmf @@ -42,7 +42,7 @@ def cataloger(): class TestSliceNMF: def test_process(self, slice_nmf, single_cell): new_component = slice_nmf.process( - Residual.from_array(single_cell.make_movie().array), + Buffer.from_array(single_cell.make_movie().array), detect_radius=single_cell.cell_radii[0] * 2, ) if new_component: @@ -62,7 +62,7 @@ def test_chunks(self, single_cell): ) ) fpts, trcs = nmf.process( - Residual.from_array(single_cell.make_movie().array), detect_radius=10 + Buffer.from_array(single_cell.make_movie().array), detect_radius=10 ) if not fpts or not trcs: raise AssertionError("Failed to detect a new component") @@ -82,7 +82,7 @@ class TestCataloger: @pytest.fixture(scope="function") def new_component(self, slice_nmf, single_cell): return slice_nmf.process( - Residual.from_array(single_cell.make_movie().array), detect_radius=60 + Buffer.from_array(single_cell.make_movie().array), detect_radius=60 ) def test_register(self, cataloger, new_component): @@ -94,7 +94,7 @@ def test_register(self, cataloger, new_component): def test_merge_with(self, slice_nmf, cataloger, single_cell): new_component = slice_nmf.process( - Residual.from_array(single_cell.make_movie().array), detect_radius=10 + Buffer.from_array(single_cell.make_movie().array), detect_radius=10 ) new_fp, new_tr = new_component @@ -122,7 +122,7 @@ def test_process_ideal(self, slice_nmf, cataloger, separate_cells): test cataloging separate cells. ideal case with cell_radius=5 """ movie = separate_cells.make_movie().array - fps, trs = slice_nmf.process(Residual.from_array(movie), detect_radius=5) + fps, trs = slice_nmf.process(Buffer.from_array(movie), detect_radius=5) # NOTE: by manually putting in separate_cells, we're forcing a double-detection in this test new_fps, new_trs = cataloger.process( @@ -146,7 +146,7 @@ def test_process_fail(self, slice_nmf, cataloger, separate_cells): test cataloging separate cells. nmf supposed to fail with radius=25 (grabs too many cells) """ movie = separate_cells.make_movie().array - fps, trs = slice_nmf.process(Residual.from_array(movie), detect_radius=25) + fps, trs = slice_nmf.process(Buffer.from_array(movie), detect_radius=25) # NOTE: by manually putting in separate_cells, we're forcing a double-detection in this test new_fps, new_trs = cataloger.process( @@ -160,7 +160,7 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): trial with connected cells 🙏 """ movie = connected_cells.make_movie().array - fps, trs = slice_nmf.process(Residual.from_array(movie), detect_radius=4) + fps, trs = slice_nmf.process(Buffer.from_array(movie), detect_radius=4) # NOTE: by manually putting in connected_cells, # we're forcing a double-detection in this test diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index e2761ac1..f79f62c5 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -3,7 +3,7 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import Footprints, Frame, Residual, Traces +from cala.assets import Footprints, Frame, Buffer, Traces from cala.models.axis import AXIS from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints from cala.testing.toy import FrameDims, Position, Toy @@ -45,7 +45,7 @@ def init() -> Node: def test_init(init, connected_cells) -> None: - residual = Residual() + residual = Buffer() gen = connected_cells.movie_gen() for _ in range(connected_cells.n_frames - 1): From 6f1665cf4dd4664e42c1b8b2e92859c5931f29e7 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 13 Oct 2025 22:01:31 -0700 Subject: [PATCH 27/47] feat: build buffer class for concat / query speed improvement (x50 faster with 100 frames) --- src/cala/assets.py | 63 ++++++++++++++++++++++++++++ src/cala/models/axis.py | 2 +- src/cala/nodes/buffer.py | 13 ++---- src/cala/nodes/component_stats.py | 2 +- src/cala/nodes/detect/slice_nmf.py | 2 +- src/cala/nodes/pixel_stats.py | 2 +- src/cala/nodes/residual.py | 2 +- tests/test_assets.py | 67 +++++++++++++++++++++++++++++- 8 files changed, 137 insertions(+), 16 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 44195262..d7278eb7 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -294,6 +294,13 @@ class Overlaps(Asset): class Buffer(Asset): + """ + Implements a fake ring buffer to avoid expensive copying that occurs with + numpy concat, append, and stack. + + Works by preallocating a space twice the desired size. + """ + _entity: ClassVar[Entity] = PrivateAttr( Group( name="frame", @@ -303,3 +310,59 @@ class Buffer(Asset): allow_extra_coords=False, ) ) + + validate_schema: bool = False + """Validation currently does not play nicely with this class.""" + + size: int + _full: bool = PrivateAttr(False) + _next: int = PrivateAttr(default=0) + + def append(self, array: xr.DataArray) -> None: + self.array_.data[self._next] = array.data + self.array_.data[self._next + self.size] = array.data + for coord in [AXIS.frame_coord, AXIS.timestamp_coord]: + self.array_[coord].data[self._next] = array[coord].item() + self.array_[coord].data[self._next + self.size] = array[coord].item() + + self._next = (self._next + 1) % self.size + if not self._full: + # check if this made the buffer full + self._full = self._next == 0 + + @property + def array(self) -> xr.DataArray: + if self._full: + out = self.array_.isel({AXIS.frames_dim: slice(self._next, self._next + self.size)}) + else: + out = self.array_.isel({AXIS.frames_dim: slice(None, self._next)}) + return out # .assign_coords({AXIS.frame_coord: out[AXIS.frame_coord].astype(int)}) + + @array.setter + def array(self, array: xr.DataArray) -> None: + """ + Build a new buffer array. + """ + array = ( + array.volumize.dim_with_coords( + dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] + ) + if AXIS.frames_dim not in array.dims + else array.isel({AXIS.frames_dim: slice(-self.size, None)}) + ) + fill_sizes = dict(array.sizes) + fill_sizes[AXIS.frames_dim] = self.size - array.sizes[AXIS.frames_dim] + fill = np.zeros(list(fill_sizes.values())) + filler = xr.DataArray( + fill, + dims=array.dims, + coords={ + AXIS.frame_coord: (AXIS.frames_dim, [np.nan] * (fill_sizes[AXIS.frames_dim])), + AXIS.timestamp_coord: (AXIS.frames_dim, [""] * (fill_sizes[AXIS.frames_dim])), + }, + ) + buffer = xr.concat([array, filler] * 2, dim=AXIS.frames_dim) + + self._full = array.sizes[AXIS.frames_dim] >= self.size + self._next = np.min((array.sizes[AXIS.frames_dim], self.size)) % self.size + self.array_ = buffer diff --git a/src/cala/models/axis.py b/src/cala/models/axis.py index e4222955..6430b9da 100644 --- a/src/cala/models/axis.py +++ b/src/cala/models/axis.py @@ -18,7 +18,7 @@ class Axis: id_coord: str = "id_" timestamp_coord: str = "timestamp" detect_coord: str = "detected_on" - frame_coord: str = "frame" + frame_coord: str = "frame_idx" width_coord: str = "width" height_coord: str = "height" diff --git a/src/cala/nodes/buffer.py b/src/cala/nodes/buffer.py index 626dd2ff..24cb3811 100644 --- a/src/cala/nodes/buffer.py +++ b/src/cala/nodes/buffer.py @@ -1,24 +1,17 @@ from typing import Annotated as A -import xarray as xr from noob import Name -from cala.assets import Frame, Movie +from cala.assets import Frame, Buffer from cala.models import AXIS -def fill_buffer(size: int, buffer: Movie, frame: Frame) -> A[Movie, Name("buffer")]: +def fill_buffer(buffer: Buffer, frame: Frame) -> A[Buffer, Name("buffer")]: if buffer.array is None: buffer.array = frame.array.volumize.dim_with_coords( dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord] ) return buffer - buffered = ( - buffer.array.transpose(AXIS.frames_dim, ...)[-size + 1 :] - if buffer.array.sizes[AXIS.frames_dim] >= size - else buffer.array - ) - - buffer.array = xr.concat([buffered, frame.array], dim=AXIS.frames_dim) + buffer.append(frame.array) return buffer diff --git a/src/cala/nodes/component_stats.py b/src/cala/nodes/component_stats.py index 0baf16eb..3ebcdec2 100644 --- a/src/cala/nodes/component_stats.py +++ b/src/cala/nodes/component_stats.py @@ -52,7 +52,7 @@ def ingest_frame(component_stats: CompStats, frame: Frame, new_traces: PopSnap) return component_stats # Compute scaling factors - frame_idx = frame.array.coords[AXIS.frame_coord].item() + frame_idx = frame.array[AXIS.frame_coord].item() prev_scale = frame_idx / (frame_idx + 1) new_scale = 1 / (frame_idx + 1) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 5acecd1c..0939acd7 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -9,7 +9,7 @@ from pydantic import Field, PrivateAttr from sklearn.decomposition import NMF -from cala.assets import Footprint, Buffer, Trace +from cala.assets import Buffer, Footprint, Trace from cala.logging import init_logger from cala.models import AXIS diff --git a/src/cala/nodes/pixel_stats.py b/src/cala/nodes/pixel_stats.py index 20df8608..d4b9e9e9 100644 --- a/src/cala/nodes/pixel_stats.py +++ b/src/cala/nodes/pixel_stats.py @@ -68,7 +68,7 @@ def ingest_frame(pixel_stats: PixStats, frame: Frame, new_traces: PopSnap) -> Pi return pixel_stats # Compute scaling factors - frame_idx = frame.array.coords[AXIS.frame_coord].item() + frame_idx = frame.array[AXIS.frame_coord].item() prev_scale = frame_idx / (frame_idx + 1) new_scale = 1 / (frame_idx + 1) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index a0d6c9e8..79f7931c 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -4,7 +4,7 @@ import xarray as xr from noob import Name -from cala.assets import Footprints, Frame, Buffer, Traces +from cala.assets import Buffer, Footprints, Frame, Traces from cala.models import AXIS diff --git a/tests/test_assets.py b/tests/test_assets.py index d8bee4f3..d05da33b 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -1,9 +1,11 @@ import os +from datetime import datetime from pathlib import Path import pytest +import xarray as xr -from cala.assets import Traces +from cala.assets import Buffer, Traces from cala.models import AXIS @@ -70,3 +72,66 @@ def test_overwrite(connected_cells, separate_cells, path): sep_traces = separate_cells.traces.array zarr_traces.array = sep_traces assert zarr_traces.array.equals(sep_traces) + + +# two cases of init: +# 1. brick by brick +# 2. lump dump + +# two cases of update +# 1. append +# 2. lump update + + +def test_buffer_assign(connected_cells): + movie = connected_cells.make_movie().array + buff = Buffer(size=10) + buff.array = movie.isel({AXIS.frames_dim: -1}) + assert buff.array.equals(movie.isel({AXIS.frames_dim: [-1]})) + + buff.array = movie.isel({AXIS.frames_dim: slice(-5, None)}) + assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(-5, None)})) + + buff.array = movie.isel({AXIS.frames_dim: slice(-10, None)}) + assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(-10, None)})) + + buff.array = movie.isel({AXIS.frames_dim: slice(-15, None)}) + assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(-10, None)})) + + +def test_buffer_append(connected_cells): + movie = connected_cells.make_movie().array + buff = Buffer(size=10) + buff.array = movie.isel({AXIS.frames_dim: 0}) + buff.append(movie.isel({AXIS.frames_dim: 1})) + assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(0, 2)})) + + buff.array = movie.isel({AXIS.frames_dim: slice(None, 9)}) + buff.append(movie.isel({AXIS.frames_dim: 9})) + assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(0, 10)})) + buff.append(movie.isel({AXIS.frames_dim: 10})) + assert buff.array.equals(movie.isel({AXIS.frames_dim: slice(1, 11)})) + + +def test_buffer_speed(single_cell): + movie = single_cell.make_movie().array + movie = xr.concat([movie, movie], dim=AXIS.frames_dim) + buff = Buffer(size=100) + buff.array = movie + + start = datetime.now() + iter = 100 + for _ in range(iter): + buff.append(movie.isel({AXIS.frames_dim: 0})) + buff.array + result = (datetime.now() - start) / iter + + start = datetime.now() + for _ in range(iter): + xr.concat( + [movie.isel({AXIS.frames_dim: slice(1, None)}), movie.isel({AXIS.frames_dim: 0})], + dim=AXIS.frames_dim, + ) + expected = (datetime.now() - start) / iter + + assert result < expected From 3f1498716131104bff7cf72ed5c305835f577a56 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 13 Oct 2025 22:02:18 -0700 Subject: [PATCH 28/47] feat: build buffer class for concat / query speed improvement (x50 faster with 100 frames) --- src/cala/nodes/cleanup.py | 2 +- src/cala/testing/toy.py | 5 +++-- tests/test_iter/test_detect.py | 4 ++-- tests/test_iter/test_residual.py | 2 +- tests/test_pipeline.py | 6 +++--- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 679a76df..7319946d 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -5,7 +5,7 @@ import xarray as xr from noob import Name -from cala.assets import CompStats, Footprints, Overlaps, PixStats, Buffer, Traces +from cala.assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces from cala.models import AXIS diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index 024f1c2e..62857df6 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -146,7 +146,8 @@ def _generate_footprint( { AXIS.id_coord: (AXIS.component_dim, [id_]), AXIS.detect_coord: (AXIS.component_dim, [detected_on]), - **{ax: footprint[ax] for ax in AXIS.spatial_dims}, + AXIS.width_coord: (AXIS.width_dim, footprint[AXIS.width_dim].data), + AXIS.height_coord: (AXIS.height_dim, footprint[AXIS.height_dim].data), } ) @@ -170,7 +171,7 @@ def _format_trace(self, trace: np.ndarray, id_: str, detected_on: int) -> xr.Dat { AXIS.id_coord: (AXIS.component_dim, [id_]), AXIS.detect_coord: (AXIS.component_dim, [detected_on]), - AXIS.frames_dim: range(trace.size), + AXIS.frame_coord: (AXIS.frames_dim, range(trace.size)), } ) ) diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index a10d7997..3d488bde 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -4,7 +4,7 @@ from noob.node import NodeSpecification from sklearn.decomposition import NMF -from cala.assets import AXIS, Footprints, Buffer, Traces +from cala.assets import AXIS, Buffer, Footprints, Traces from cala.nodes.detect import Cataloger, SliceNMF from cala.nodes.detect.catalog import _merge_with, _register from cala.nodes.detect.slice_nmf import rank1nmf @@ -82,7 +82,7 @@ class TestCataloger: @pytest.fixture(scope="function") def new_component(self, slice_nmf, single_cell): return slice_nmf.process( - Buffer.from_array(single_cell.make_movie().array), detect_radius=60 + Buffer(size=100).from_array(single_cell.make_movie().array), detect_radius=60 ) def test_register(self, cataloger, new_component): diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index f79f62c5..e5897150 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -3,7 +3,7 @@ import xarray as xr from noob.node import Node, NodeSpecification -from cala.assets import Footprints, Frame, Buffer, Traces +from cala.assets import Buffer, Footprints, Frame, Traces from cala.models.axis import AXIS from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints from cala.testing.toy import FrameDims, Position, Toy diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 94cb63dc..579450c6 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -70,19 +70,19 @@ def test_odl(runner, source) -> None: if src_name in ["TwoOverlappingSource", "GradualOnSource"]: # Traces are reasonably similar tr_corr = xr.corr( - toy.traces.array, trs.array.rename(AXIS.component_rename), dim=AXIS.frame_coord + toy.traces.array, trs.array.rename(AXIS.component_rename), dim=AXIS.frames_dim ) for corr in tr_corr: assert np.isclose(corr.max(), 1, atol=1e-2) elif src_name in ["SingleCellSource", "TwoCellsSource", "SeparateSource"]: - expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) + expected = xr.concat(preprocessed_frames, dim=AXIS.frames_dim) result = (fps.array @ trs.array).transpose(*expected.dims) xr.testing.assert_allclose(expected, result.as_numpy(), atol=1e-5, rtol=1e-5) elif src_name == "SplitOffSource": - expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) + expected = xr.concat(preprocessed_frames, dim=AXIS.frames_dim) result = (fps.array @ trs.array).transpose(*expected.dims) raise NotImplementedError("Deprecation not implemented") From 05afc8f76d852a2d1fd7b233c14c380701c9929c Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 12:00:14 -0700 Subject: [PATCH 29/47] feat: cover residual array none case --- src/cala/assets.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index d7278eb7..70002421 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -331,7 +331,9 @@ def append(self, array: xr.DataArray) -> None: self._full = self._next == 0 @property - def array(self) -> xr.DataArray: + def array(self) -> xr.DataArray | None: + if self.array_ is None: + return None if self._full: out = self.array_.isel({AXIS.frames_dim: slice(self._next, self._next + self.size)}) else: @@ -366,3 +368,9 @@ def array(self, array: xr.DataArray) -> None: self._full = array.sizes[AXIS.frames_dim] >= self.size self._next = np.min((array.sizes[AXIS.frames_dim], self.size)) % self.size self.array_ = buffer + + @classmethod + def from_array(cls, array: xr.DataArray, size: int) -> Self: + buffer = cls(size=size) + buffer.array = array + return buffer From 83452264a14da98c73f97cf6be00ee6550c61d02 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 12:00:32 -0700 Subject: [PATCH 30/47] tests: refit tests --- tests/test_iter/test_cleanup.py | 6 ++---- tests/test_iter/test_detect.py | 25 +++++++++++++------------ tests/test_iter/test_residual.py | 4 ++-- tests/test_iter/test_traces.py | 2 +- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/tests/test_iter/test_cleanup.py b/tests/test_iter/test_cleanup.py index 5a2ef226..2f64e1d0 100644 --- a/tests/test_iter/test_cleanup.py +++ b/tests/test_iter/test_cleanup.py @@ -6,7 +6,7 @@ def test_clear_overestimates(single_cell) -> None: - residual = Buffer.from_array(single_cell.make_movie().array) + residual = Buffer.from_array(single_cell.make_movie().array, size=100) residual.array.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] *= -1 result = clear_overestimates( @@ -29,14 +29,12 @@ def test_erase_redundant(splitoff_cells) -> None: traces = splitoff_cells.traces dead_trace = xr.DataArray( - 0.1, + [0.1] * traces.array.sizes[AXIS.frames_dim], dims=traces.array.isel({AXIS.component_dim: 1}).dims, coords=traces.array.isel({AXIS.component_dim: 1}).coords, ).assign_coords({AXIS.id_coord: "cell_2", AXIS.detect_coord: 77}) traces.array = xr.concat([traces.array, dead_trace], dim=AXIS.component_dim) - # frame = Frame.from_array(footprints.array @ traces.array.isel({AXIS.frames_dim: -1})) - result = _filter_redundant( footprints=footprints, traces=traces, min_life_in_frames=10, quantile=0.9 ) diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 3d488bde..7c9b6c82 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -42,7 +42,7 @@ def cataloger(): class TestSliceNMF: def test_process(self, slice_nmf, single_cell): new_component = slice_nmf.process( - Buffer.from_array(single_cell.make_movie().array), + Buffer.from_array(single_cell.make_movie().array, size=100), detect_radius=single_cell.cell_radii[0] * 2, ) if new_component: @@ -62,7 +62,7 @@ def test_chunks(self, single_cell): ) ) fpts, trcs = nmf.process( - Buffer.from_array(single_cell.make_movie().array), detect_radius=10 + Buffer.from_array(single_cell.make_movie().array, size=100), detect_radius=10 ) if not fpts or not trcs: raise AssertionError("Failed to detect a new component") @@ -81,9 +81,9 @@ def test_chunks(self, single_cell): class TestCataloger: @pytest.fixture(scope="function") def new_component(self, slice_nmf, single_cell): - return slice_nmf.process( - Buffer(size=100).from_array(single_cell.make_movie().array), detect_radius=60 - ) + buff = Buffer(size=100) + buff.array = single_cell.make_movie().array + return slice_nmf.process(buff, detect_radius=60) def test_register(self, cataloger, new_component): new_fp, new_tr = new_component @@ -93,9 +93,9 @@ def test_register(self, cataloger, new_component): assert np.array_equal(tr, new_tr[0].array) def test_merge_with(self, slice_nmf, cataloger, single_cell): - new_component = slice_nmf.process( - Buffer.from_array(single_cell.make_movie().array), detect_radius=10 - ) + buff = Buffer(size=100) + buff.array = single_cell.make_movie().array + new_component = slice_nmf.process(buff, detect_radius=10) new_fp, new_tr = new_component fp, tr = _merge_with( @@ -121,8 +121,9 @@ def test_process_ideal(self, slice_nmf, cataloger, separate_cells): """ test cataloging separate cells. ideal case with cell_radius=5 """ - movie = separate_cells.make_movie().array - fps, trs = slice_nmf.process(Buffer.from_array(movie), detect_radius=5) + buff = Buffer(size=100) + buff.array = separate_cells.make_movie().array + fps, trs = slice_nmf.process(buff, detect_radius=5) # NOTE: by manually putting in separate_cells, we're forcing a double-detection in this test new_fps, new_trs = cataloger.process( @@ -146,7 +147,7 @@ def test_process_fail(self, slice_nmf, cataloger, separate_cells): test cataloging separate cells. nmf supposed to fail with radius=25 (grabs too many cells) """ movie = separate_cells.make_movie().array - fps, trs = slice_nmf.process(Buffer.from_array(movie), detect_radius=25) + fps, trs = slice_nmf.process(Buffer.from_array(movie, size=100), detect_radius=25) # NOTE: by manually putting in separate_cells, we're forcing a double-detection in this test new_fps, new_trs = cataloger.process( @@ -160,7 +161,7 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): trial with connected cells 🙏 """ movie = connected_cells.make_movie().array - fps, trs = slice_nmf.process(Buffer.from_array(movie), detect_radius=4) + fps, trs = slice_nmf.process(Buffer.from_array(movie, size=100), detect_radius=4) # NOTE: by manually putting in connected_cells, # we're forcing a double-detection in this test diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index e5897150..8b306567 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -39,13 +39,13 @@ def init() -> Node: spec=NodeSpecification( id="res_init_test", type="cala.nodes.residual.build", - params={"size": 100, "n_recalc": 5}, + params={"n_recalc": 5}, ) ) def test_init(init, connected_cells) -> None: - residual = Buffer() + residual = Buffer(size=100) gen = connected_cells.movie_gen() for _ in range(connected_cells.n_frames - 1): diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index 2f92ff01..2f5979cd 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -61,6 +61,6 @@ def test_ingest_component(comp_update, toy, request) -> None: result = comp_update.process(traces, Traces.from_array(new_traces)) expected = toy.traces.array.drop_sel({AXIS.component_dim: 0}) - expected.loc[{AXIS.component_dim: -1, AXIS.frames_dim: slice(None, 10 - 1)}] = np.nan + expected.loc[{AXIS.component_dim: -1, AXIS.frames_dim: slice(None, 10)}] = np.nan assert result.array.equals(expected) From a24a78fd3ee4ef6d6efa936bb01d7913782d28e7 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 12:00:53 -0700 Subject: [PATCH 31/47] feat: residuals node uses buffer --- src/cala/nodes/residual.py | 50 ++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 79f7931c..155c99a6 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -13,7 +13,6 @@ def build( frame: Frame, footprints: Footprints, traces: Traces, - size: int, n_recalc: int, ) -> A[Buffer, Name("movie")]: """ @@ -48,32 +47,40 @@ def build( if residuals.array is None: residuals.array = frame.array.expand_dims(dim=AXIS.frames_dim) else: - residuals.array = xr.concat( - [residuals.array, frame.array], - dim=AXIS.frames_dim, - coords=[AXIS.frame_coord, AXIS.timestamp_coord], - ) + residuals.append(frame.array) return residuals Y = frame.array C = traces.array.isel({AXIS.frames_dim: -1}) # (components,) A = footprints.array - # Compute residual R = Y - [A,b][C;f] - R_curr = Y - xr.DataArray( - np.matmul(A.transpose(*AXIS.spatial_dims, ...).data, C.data), dims=AXIS.spatial_dims - ) - if R_curr.min() < 0: + R_curr, flag = _find_overestimates(Y=Y, A=A, C=C) + if flag: C = _align_overestimates(A, C, R_curr) traces.array.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = C # if recently discovered, set to zero (or a small number). otherwise, just append - R = _update(Y, A, C, residuals.array, n_recalc=n_recalc) - residuals.array = R.isel({AXIS.frames_dim: slice(-size, None)}) + preserve_area = _get_new_estimators_area(A=A, C=C, n_recalc=n_recalc) + residuals.array_ *= preserve_area.as_numpy() + R_curr = _get_residuals(Y=Y, A=A, C=C) + residuals.append(R_curr.clip(min=0)) return residuals +def _get_residuals(Y: xr.DataArray, A: xr.DataArray, C: xr.DataArray) -> xr.DataArray: + return Y - xr.DataArray( + np.matmul(A.transpose(*AXIS.spatial_dims, ...).data, C.data), dims=AXIS.spatial_dims + ) + + +def _find_overestimates( + Y: xr.DataArray, A: xr.DataArray, C: xr.DataArray +) -> tuple[xr.DataArray, bool]: + R_curr = _get_residuals(Y, A, C) + return R_curr, R_curr.min() < 0 + + def _align_overestimates( A: xr.DataArray, C_latest: xr.DataArray, R_latest: xr.DataArray ) -> xr.DataArray: @@ -137,19 +144,14 @@ def _find_unlayered_footprints(A: xr.DataArray) -> xr.DataArray: return A.where(A_layer_mask == 1, 0) -def _update( - Y: xr.DataArray, A: xr.DataArray, C: xr.DataArray, R: xr.DataArray, n_recalc: int -) -> xr.DataArray: +def _get_new_estimators_area( + A: xr.DataArray, C: xr.DataArray, n_recalc: int +) -> xr.DataArray | None: targets = C[AXIS.detect_coord] >= (C[AXIS.frame_coord].item() - n_recalc) if any(targets): target_ids = targets.where(targets)[AXIS.id_coord] target_area = ~((A.where(target_ids, drop=True)).max(dim=AXIS.component_dim) > 0) - R *= target_area.as_numpy() - - R_curr = Y - xr.DataArray( - np.matmul(A.transpose(*AXIS.spatial_dims, ...).data, C.data), - dims=AXIS.spatial_dims, - ) - - return xr.concat([R, R_curr.clip(min=0)], dim=AXIS.frames_dim) + return target_area + else: + return None From ceee0717eef0c3804bedfdad773662ae93e6e19e Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 12:08:28 -0700 Subject: [PATCH 32/47] test: refit yaml --- tests/data/pipelines/odl.yaml | 11 ++++++----- tests/data/pipelines/with_src.yaml | 29 +++++++++++++++-------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 842729e5..c03f7497 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -4,7 +4,9 @@ noob_version: 0.1.1.dev118+g64d81b7 assets: buffer: - type: cala.assets.Movie + type: cala.assets.Buffer + params: + size: 100 scope: runner footprints: type: cala.assets.Footprints @@ -22,7 +24,9 @@ assets: type: cala.assets.Overlaps scope: runner residuals: - type: cala.assets.Residual + type: cala.assets.Buffer + params: + size: 100 scope: runner nodes: @@ -57,8 +61,6 @@ nodes: - frame: glow.frame cache: type: cala.nodes.buffer.fill_buffer - params: - size: 100 depends: - buffer: assets.buffer - frame: glow.frame @@ -102,7 +104,6 @@ nodes: residual: type: cala.nodes.residual.build params: - size: 100 n_recalc: 1 depends: - frame: glow.frame diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 61c8f871..18a80ade 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -4,7 +4,9 @@ noob_version: 0.1.1.dev118+g64d81b7 assets: buffer: - type: cala.assets.Movie + type: cala.assets.Buffer + params: + size: 100 scope: runner footprints: type: cala.assets.Footprints @@ -25,7 +27,9 @@ assets: type: cala.assets.Overlaps scope: runner residuals: - type: cala.assets.Residual + type: cala.assets.Buffer + params: + size: 100 scope: runner @@ -36,14 +40,14 @@ nodes: files: - minian/msCam1.avi - minian/msCam2.avi - - minian/msCam3.avi - - minian/msCam4.avi - - minian/msCam5.avi - - minian/msCam6.avi - - minian/msCam7.avi - - minian/msCam8.avi - - minian/msCam9.avi - - minian/msCam10.avi + # - minian/msCam3.avi + # - minian/msCam4.avi + # - minian/msCam5.avi + # - minian/msCam6.avi + # - minian/msCam7.avi + # - minian/msCam8.avi + # - minian/msCam9.avi + # - minian/msCam10.avi counter: type: cala.nodes.prep.counter frame: @@ -98,8 +102,6 @@ nodes: - frame: denoise.frame cache: type: cala.nodes.buffer.fill_buffer - params: - size: 100 depends: - buffer: assets.buffer - frame: glow.frame @@ -143,8 +145,7 @@ nodes: residual: type: cala.nodes.residual.build params: - size: 100 - n_recalc: 4 + n_recalc: 1 depends: - frame: glow.frame - footprints: footprints_frame.footprints From 5e368bdf23ac17784a1c9ee0ee3a2df43d05d90c Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 12:24:53 -0700 Subject: [PATCH 33/47] feat: frame indexing is easier with separate coord vs. dim names --- src/cala/nodes/buffer.py | 2 +- src/cala/nodes/residual.py | 7 ++++--- src/cala/nodes/traces.py | 7 +++---- tests/test_assets.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/cala/nodes/buffer.py b/src/cala/nodes/buffer.py index 24cb3811..8bb9a9d5 100644 --- a/src/cala/nodes/buffer.py +++ b/src/cala/nodes/buffer.py @@ -2,7 +2,7 @@ from noob import Name -from cala.assets import Frame, Buffer +from cala.assets import Buffer, Frame from cala.models import AXIS diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 155c99a6..8e91c2c4 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -57,11 +57,12 @@ def build( R_curr, flag = _find_overestimates(Y=Y, A=A, C=C) if flag: C = _align_overestimates(A, C, R_curr) - traces.array.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = C + traces.array.loc[{AXIS.frames_dim: -1}] = C # if recently discovered, set to zero (or a small number). otherwise, just append preserve_area = _get_new_estimators_area(A=A, C=C, n_recalc=n_recalc) - residuals.array_ *= preserve_area.as_numpy() + if preserve_area is not None: + residuals.array_ *= preserve_area.as_numpy() R_curr = _get_residuals(Y=Y, A=A, C=C) residuals.append(R_curr.clip(min=0)) @@ -78,7 +79,7 @@ def _find_overestimates( Y: xr.DataArray, A: xr.DataArray, C: xr.DataArray ) -> tuple[xr.DataArray, bool]: R_curr = _get_residuals(Y, A, C) - return R_curr, R_curr.min() < 0 + return R_curr, R_curr.min() < -np.finfo(np.float32).eps def _align_overestimates( diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index fb6275b0..c01c0f93 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -73,7 +73,7 @@ def ingest_frame( if traces.zarr_path: updated_tr = updated_traces.volumize.dim_with_coords( - dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord] + dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord] ) traces.update(updated_tr, append_dim=AXIS.frames_dim) else: @@ -202,7 +202,6 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: :param traces: :param new_traces: Can be either a newly registered trace or an updated existing one. - :return: """ c = traces.full_array() @@ -216,7 +215,7 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: return traces if c.sizes[AXIS.frames_dim] > c_det.sizes[AXIS.frames_dim]: - # if newly detected cells are truncated + # if newly detected cells are truncated, pad with np.nans c_new = xr.DataArray( np.full((c_det.sizes[AXIS.component_dim], c.sizes[AXIS.frames_dim]), np.nan), dims=[AXIS.component_dim, AXIS.frames_dim], @@ -227,7 +226,7 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: c_new.loc[{AXIS.frames_dim: c_det[AXIS.frame_coord]}] = c_det else: - c_new = c_det.sel({AXIS.frame_coord: c[AXIS.frame_coord]}) + c_new = c_det merged_ids = c_det.attrs.get("replaces") if merged_ids: diff --git a/tests/test_assets.py b/tests/test_assets.py index d05da33b..2a4eccff 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -123,7 +123,7 @@ def test_buffer_speed(single_cell): iter = 100 for _ in range(iter): buff.append(movie.isel({AXIS.frames_dim: 0})) - buff.array + _ = buff.array result = (datetime.now() - start) / iter start = datetime.now() From 293b1ca3e1f64fcb0650dda839fb66037e57fff9 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 15:28:13 -0700 Subject: [PATCH 34/47] perf: optimize residual and overlaps updates --- src/cala/nodes/merge.py | 2 +- src/cala/nodes/residual.py | 9 +++++---- src/cala/util.py | 22 ++++++++++++---------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/cala/nodes/merge.py b/src/cala/nodes/merge.py index cd546714..0b90b122 100644 --- a/src/cala/nodes/merge.py +++ b/src/cala/nodes/merge.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr from noob import Name -from scipy.ndimage.filters import gaussian_filter1d +from scipy.ndimage import gaussian_filter1d from scipy.sparse.csgraph import connected_components from cala.assets import Footprints, Overlaps, Traces diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 8e91c2c4..f3984f38 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -148,11 +148,12 @@ def _find_unlayered_footprints(A: xr.DataArray) -> xr.DataArray: def _get_new_estimators_area( A: xr.DataArray, C: xr.DataArray, n_recalc: int ) -> xr.DataArray | None: - targets = C[AXIS.detect_coord] >= (C[AXIS.frame_coord].item() - n_recalc) + targets = C[AXIS.detect_coord].values >= (C[AXIS.frame_coord].item() - n_recalc) if any(targets): - target_ids = targets.where(targets)[AXIS.id_coord] - target_area = ~((A.where(target_ids, drop=True)).max(dim=AXIS.component_dim) > 0) - return target_area + target_coords = A.data[targets].nonzero()[1:] + target_area = np.ones(A.shape[1:]) + target_area[target_coords] = 0 + return xr.DataArray(target_area, dims=A.dims[1:]) else: return None diff --git a/src/cala/util.py b/src/cala/util.py index 0b6623ec..babeb776 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -3,8 +3,8 @@ from shutil import rmtree from uuid import uuid4 -import numpy as np import xarray as xr +from sparse import COO def create_id() -> str: @@ -37,13 +37,15 @@ def sp_matmul( :param right: """ - left = left.transpose(dim, ...) - right = left if right is None else right.transpose(dim, ...) + l = left.transpose(dim, ...).data.reshape((left.sizes[dim], -1)).tocsr() + if right is None: + right = left + r = l + else: + r = right.transpose(dim, ...).data.reshape((right.sizes[dim], -1)).tocsr() - val = np.matmul( - np.reshape(left.data, (left.sizes[dim], -1)), - np.reshape(right.data, (right.sizes[dim], -1)).T, - ) - return xr.DataArray(val, dims=[dim, f"{dim}'"], coords=left[dim].coords).assign_coords( - right[dim].rename(rename_map).coords - ) + val = l @ r.T + + return xr.DataArray( + COO.from_scipy_sparse(val), dims=[dim, f"{dim}'"], coords=left[dim].coords + ).assign_coords(right[dim].rename(rename_map).coords) From 389f63c5227a0b0264ec4b0704d927aba05d70d8 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 16:23:54 -0700 Subject: [PATCH 35/47] chore: rename to Tracer for easier perf check --- src/cala/nodes/traces.py | 2 +- tests/data/pipelines/odl.yaml | 2 +- tests/data/pipelines/with_src.yaml | 2 +- tests/test_iter/test_traces.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index c01c0f93..4484bd36 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -12,7 +12,7 @@ from cala.models import AXIS -class FrameUpdate(BaseModel): +class Tracer(BaseModel): tol: float max_iter: int diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index c03f7497..573ba8b6 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -68,7 +68,7 @@ nodes: # FRAME UPDATE BEGINS trace_frame: - type: cala.nodes.traces.FrameUpdate + type: cala.nodes.traces.Tracer params: tol: 0.001 max_iter: 100 diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 18a80ade..1879595e 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -109,7 +109,7 @@ nodes: # FRAME UPDATE BEGINS trace_frame: - type: cala.nodes.traces.FrameUpdate + type: cala.nodes.traces.Tracer params: tol: 0.001 max_iter: 100 diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index 2f5979cd..ff471db0 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -12,7 +12,7 @@ def frame_update() -> Node: return Node.from_specification( spec=NodeSpecification( id="frame_test", - type="cala.nodes.traces.FrameUpdate", + type="cala.nodes.traces.Tracer", params={"max_iter": 100, "tol": 1e-4}, ) ) From 4bd678b927430c8a8b4c0af14ca7e36866bcfb55 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 16:24:10 -0700 Subject: [PATCH 36/47] perf: optimize residual - avoid numba bootup --- src/cala/nodes/residual.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index f3984f38..4f5a96a1 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -151,8 +151,11 @@ def _get_new_estimators_area( targets = C[AXIS.detect_coord].values >= (C[AXIS.frame_coord].item() - n_recalc) if any(targets): - target_coords = A.data[targets].nonzero()[1:] - target_area = np.ones(A.shape[1:]) + idx = np.where(targets)[0] + nonzeros = A.data.nonzero() + target_mask = np.isin(nonzeros[0], idx) + target_coords = tuple(nonzero[target_mask] for nonzero in nonzeros[1:]) + target_area = np.ones(A.shape[1:], dtype=bool) target_area[target_coords] = 0 return xr.DataArray(target_area, dims=A.dims[1:]) else: From 0aec0b1d4713817a36260539b622bcaa7d91d974 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 15 Oct 2025 17:49:04 -0700 Subject: [PATCH 37/47] perf: optimize footprint - avoid coo indexing --- src/cala/nodes/footprints.py | 18 ++++++++++-------- src/cala/nodes/traces.py | 4 +++- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 7e67fff0..986fbfca 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -43,19 +43,21 @@ def ingest_frame( if footprints.array is None: return footprints - A = footprints.array.stack(pixel=AXIS.spatial_dims).transpose("pixel", ...) + A = footprints.array.transpose(AXIS.component_dim, ...) + A_arr = A.data.reshape((A.sizes[AXIS.component_dim], -1)).tocsc() M = component_stats.array - W = pixel_stats.array.stack(pixel=AXIS.spatial_dims).transpose(AXIS.component_dim, ...) + W = pixel_stats.array.transpose(AXIS.component_dim, ...) + W_arr = W.data.reshape((W.sizes[AXIS.component_dim], -1)) shapes, mask, _ = update_shapes( - CY=W.values, + CY=W_arr, CC=M.values, - Ab=A.data.tocsc(), - A_mask=[np.where(Ap.as_numpy() > 0)[0] for Ap in A.transpose(AXIS.component_dim, ...)], + Ab=A_arr.T.tocsc(), + A_mask=[Ap.nonzero()[0] for Ap in A_arr], ) - footprints.array = xr.DataArray(shapes.toarray(), dims=A.dims, coords=A.coords).unstack( - "pixel" + footprints.array = xr.DataArray( + shapes.T.toarray().reshape(A.shape), dims=A.dims, coords=A.coords ) return footprints @@ -93,7 +95,7 @@ def update_shapes( iters: int = 5, ) -> tuple[csc_matrix, list[np.ndarray], np.ndarray]: """ - :param CY: suff stats (pixel,) + :param CY: suff stats (comp, pixel) :param CC: suff stats (component), shape (comp, comp) :param Ab: shape matrix (sparse), shape (pixel, comp) :param A_mask: list of nonzero coordinates for each footprint list[(pixel,)] diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index 4484bd36..11d62e2e 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -224,7 +224,9 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: c_new[AXIS.id_coord] = c_det[AXIS.id_coord] c_new[AXIS.detect_coord] = c_det[AXIS.detect_coord] - c_new.loc[{AXIS.frames_dim: c_det[AXIS.frame_coord]}] = c_det + c_new.loc[ + {AXIS.frames_dim: slice(c.sizes[AXIS.frames_dim] - c_det.sizes[AXIS.frames_dim], None)} + ] = c_det else: c_new = c_det From 14cc810e75d3763cc080f67b3c7a51bca43edfb7 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 16 Oct 2025 16:26:35 -0700 Subject: [PATCH 38/47] perf: optimize unlayered footprint search --- src/cala/nodes/residual.py | 9 ++++++--- tests/test_iter/test_residual.py | 8 +++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 4f5a96a1..e3a13812 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -140,9 +140,12 @@ def _align_overestimates( ) -def _find_unlayered_footprints(A: xr.DataArray) -> xr.DataArray: - A_layer_mask = (A > 0).sum(dim=AXIS.component_dim) - return A.where(A_layer_mask == 1, 0) +def _find_unlayered_footprints(A: COO) -> coo_matrix: + coords = A.nonzero()[1] + pixels, counts = np.unique(coords, return_counts=True) + mask = pixels[counts == 1] + vals = A.data[np.isin(coords, mask)] + return coo_matrix((vals, (np.zeros_like(mask), mask)), shape=(1, A.shape[1])) def _get_new_estimators_area( diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 8b306567..89fcdf25 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -92,6 +92,8 @@ def test_align_overestimates(single_cell) -> None: def test_find_exposed_footprints(connected_cells) -> None: - footprints = connected_cells.footprints - result = _find_unlayered_footprints(footprints.array) - assert result.sum(dim=AXIS.component_dim).max().item() == footprints.array.max().item() + footprints = connected_cells.footprints.array + result = _find_unlayered_footprints( + footprints.data.reshape((footprints.sizes[AXIS.component_dim], -1)) + ) + assert result.max().item() == footprints.max().item() From 13a07e49ff23fb295dc0265a1afaa5cfb4f90222 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 11:58:55 -0700 Subject: [PATCH 39/47] feat: footprint remove tiny pixels --- src/cala/nodes/footprints.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 986fbfca..682ecece 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -56,6 +56,9 @@ def ingest_frame( A_mask=[Ap.nonzero()[0] for Ap in A_arr], ) + # maybe this happens before footprint update? + shapes[shapes <= self.ratio_lb] = 0 + footprints.array = xr.DataArray( shapes.T.toarray().reshape(A.shape), dims=A.dims, coords=A.coords ) From 093a58a5465a18d33d11160e08f02b26ef33330b Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 11:59:18 -0700 Subject: [PATCH 40/47] feat: perf update on residual align_overestimate --- src/cala/nodes/residual.py | 71 ++++++++++---------------------------- 1 file changed, 19 insertions(+), 52 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index e3a13812..dacaefb8 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -3,6 +3,7 @@ import numpy as np import xarray as xr from noob import Name +from sparse import COO from cala.assets import Buffer, Footprints, Frame, Traces from cala.models import AXIS @@ -64,6 +65,7 @@ def build( if preserve_area is not None: residuals.array_ *= preserve_area.as_numpy() R_curr = _get_residuals(Y=Y, A=A, C=C) + # we're not fully modifying for the negative minimum, so we need to clip residuals.append(R_curr.clip(min=0)) return residuals @@ -86,66 +88,31 @@ def _align_overestimates( A: xr.DataArray, C_latest: xr.DataArray, R_latest: xr.DataArray ) -> xr.DataArray: """ - Gotta be able to do at least ONE OF splitoff or gradualon. - - Negative residuals just need to go. There isn't much you can do with the value...? - - Two cases: (A & B Overlapping) - 1. GradualOn: Know A. B turns ON - -> trace tries to chase (increases) - -> footprint tries to chase - -> residual becomes negative at A-B - -> should just decrease, positive at A^B - -> actually... should decrease (just more steeply) - - 2. SplitOff: Know AB. B turns OFF - -> trace tries to chase (decreases) - -> footprint tries to chase - -> residual becomes positive at A-B - -> should increase, negative at A^B - -> this should just decrease, MORE negative at B-A - -> this going to zero makes sense - OR - keep B, remove A-B - - R = Y - A @ C - - What about the past frame residuals after? - - for GradualOn, nothing should go to zero. - for SplitOff, a chunk needs to go to zero. - - So... how about we do something like (if it's been on for a long time, - we become less likely to purge it?) - - We subsequently clip R minimum to zero, since all significant negative residual spots - have been removed, and the remaining negative spots are noise level. - !!We're assuming there's no completely occluded component. This might be a problem eventually!! """ + A_pix = A.data.reshape((A.sizes[AXIS.component_dim], -1)).tocsr() + R = R_latest.values + unlayered_stamp = _find_unlayered_footprints(A_pix) # same up to here - unlayered_footprints = _find_unlayered_footprints(A) - # if unlayered_footprints.max(dim=AXIS.spatial_dims).min() == 0: - # raise ValueError("There are at least one completely occluded components.") + R_rel = unlayered_stamp * np.minimum(R, 0).reshape(1, -1) # same up to here to 2e-6 and nan + RA = A_pix.power(-1).multiply(R_rel).tocsr() # same up to here to 2e-6 and neginf + dC = RA.minimum(0).sum(axis=1) # .nanmin(axis=1, explicit=True) + # divide by the number of active pixels to normalize negative (prevents outliers) + dC_norm = np.asarray(dC).squeeze() / np.array([a.nnz for a in RA]) - R_rel = R_latest.where((R_latest < 0) * unlayered_footprints.max(dim=AXIS.component_dim)) - dC = ( - (R_rel / A) - .min(dim=AXIS.spatial_dims) - .reset_coords([AXIS.frame_coord, AXIS.timestamp_coord], drop=True) - ) - - return (C_latest + xr.apply_ufunc(np.nan_to_num, dC.as_numpy(), kwargs={"neginf": 0})).clip( - min=0 - ) + return (C_latest + dC_norm).clip(min=0) -def _find_unlayered_footprints(A: COO) -> coo_matrix: +def _find_unlayered_footprints(A: COO) -> np.ndarray: coords = A.nonzero()[1] pixels, counts = np.unique(coords, return_counts=True) - mask = pixels[counts == 1] - vals = A.data[np.isin(coords, mask)] - return coo_matrix((vals, (np.zeros_like(mask), mask)), shape=(1, A.shape[1])) + mask = np.isin(coords, pixels[counts == 1]) + vals = A.data[mask] + locs = coords[mask] + ret = np.zeros(A.shape[1]) + ret[locs] = vals + return ret + # return coo_matrix((vals, (np.zeros_like(mask), mask)), shape=(1, A.shape[1])).toarray() def _get_new_estimators_area( From 05de590d4fce36a831f62ed3b877b3dbee631975 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 12:02:40 -0700 Subject: [PATCH 41/47] feat: residual learns instead of one-shot flooring --- src/cala/nodes/residual.py | 1 - tests/test_iter/test_residual.py | 16 ++++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index dacaefb8..795cbb06 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -112,7 +112,6 @@ def _find_unlayered_footprints(A: COO) -> np.ndarray: ret = np.zeros(A.shape[1]) ret[locs] = vals return ret - # return coo_matrix((vals, (np.zeros_like(mask), mask)), shape=(1, A.shape[1])).toarray() def _get_new_estimators_area( diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 89fcdf25..de1a2d6d 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -68,14 +68,15 @@ def test_init(init, connected_cells) -> None: def test_align_overestimates(single_cell) -> None: """ grab the last frame of the residual. assume part of the footprint masked area is negative - traces needs to proportionally decrease, until the recalculated residual is zero + traces needs to proportionally decrease - Eventually, this probably can be absorbed straight into trace frame_ingest as a constraint. + Maybe this can be absorbed straight into trace frame_ingest as a constraint. """ movie = single_cell.make_movie() last_frame = movie.array.isel({AXIS.frames_dim: -1}) last_res = xr.zeros_like(last_frame) + # we have negative residuals last_res.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] = -1 last_res = last_res.where(single_cell.footprints.array[0].to_numpy(), 0) @@ -85,10 +86,8 @@ def test_align_overestimates(single_cell) -> None: adjusted_traces = _align_overestimates(A=footprints, R_latest=last_res, C_latest=last_trace) - result = (footprints @ adjusted_traces).as_numpy().values - expected = movie.array.isel({AXIS.frames_dim: -2}).values - - np.testing.assert_array_equal(result, expected) + # adjusted to lower than last_trace + assert single_cell.traces.array.isel({AXIS.frames_dim: -2}) < adjusted_traces < last_trace def test_find_exposed_footprints(connected_cells) -> None: @@ -97,3 +96,8 @@ def test_find_exposed_footprints(connected_cells) -> None: footprints.data.reshape((footprints.sizes[AXIS.component_dim], -1)) ) assert result.max().item() == footprints.max().item() + + +@pytest.mark.xfail +def test_handle_outlier_pixel() -> None: + """a test to make sure an outlier pixel does not mess up the whole trace""" From f24f038deee7c65a4d01884583e86e40faa91710 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 12:55:45 -0700 Subject: [PATCH 42/47] feat: remove symmetry from toy footprints now we're moving away from xarray :( --- src/cala/testing/toy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index 62857df6..102d360a 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -136,6 +136,9 @@ def _generate_footprint( ) shape = disk(radius) + shape[:radius, :radius] *= 2 + shape[radius:, :radius] *= 3 + shape[radius:, radius:] *= 4 width_slice = slice(position.width - radius, position.width + radius + 1) height_slice = slice(position.height - radius, position.height + radius + 1) From 7815b4f1d9893e6bda274f490213ea37202f01ac Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 16:55:49 -0700 Subject: [PATCH 43/47] tests: remove hotpix in tests - its messing up the frame values with uint --- tests/data/pipelines/odl.yaml | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 573ba8b6..ac09e7a9 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -41,22 +41,14 @@ nodes: - index: counter.idx #PREPROCESS BEGINS - hotpix: - type: cala.nodes.prep.blur - params: - method: median - kwargs: - ksize: 3 - depends: - - frame: frame.value glow: type: cala.nodes.prep.GlowRemover depends: - - frame: hotpix.frame + - frame: frame.value size_est: type: cala.nodes.prep.SizeEst params: - hardset_radius: 6 + hardset_radius: 5 depends: - frame: glow.frame cache: From 556195b2ff8645f16182f78f685c91310387fbfc Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 18:27:37 -0700 Subject: [PATCH 44/47] tests: refit tests to work with new toy --- tests/test_iter/test_detect.py | 10 +++---- tests/test_iter/test_pixel_stats.py | 3 ++- tests/test_iter/test_residual.py | 5 +++- tests/test_pipeline.py | 4 +-- tests/test_prep/test_background_removal.py | 31 ---------------------- tests/test_prep/test_glow_removal.py | 2 +- tests/test_prep/test_r_estimate.py | 3 +++ 7 files changed, 17 insertions(+), 41 deletions(-) delete mode 100644 tests/test_prep/test_background_removal.py diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 7c9b6c82..3ffa0a63 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -17,7 +17,7 @@ def slice_nmf(): spec=NodeSpecification( id="test_slice_nmf", type="cala.nodes.detect.SliceNMF", - params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.0001}, + params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) @@ -167,11 +167,11 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): # we're forcing a double-detection in this test new_fps, new_trs = cataloger.process(fps, trs, Footprints(), Traces()) - result = (new_fps.array @ new_trs.array).transpose(AXIS.frames_dim, ...) - expected = movie.transpose(*result.dims) + result = (new_fps.array @ new_trs.array).transpose(AXIS.frames_dim, ...).as_numpy() + expected = movie.transpose(*result.dims).as_numpy() # not sure why we're getting some stray pixels... but we need to remove them - sig_pxls = new_fps.array.max(dim=AXIS.component_dim) > 0.1 + sig_pxls = (new_fps.array.max(dim=AXIS.component_dim) > 0.1).as_numpy() result, expected = result.where(sig_pxls), expected.where(sig_pxls) assert new_fps.array is not None @@ -181,7 +181,7 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): == 0 ) # 2. the trace and footprint values are accurate (where they do exist) - xr.testing.assert_allclose(result.as_numpy(), expected.as_numpy(), atol=1e-3) + xr.testing.assert_allclose(result, expected, atol=1) def test_rank1nmf(single_cell): diff --git a/tests/test_iter/test_pixel_stats.py b/tests/test_iter/test_pixel_stats.py index 92d6eced..8b412f07 100644 --- a/tests/test_iter/test_pixel_stats.py +++ b/tests/test_iter/test_pixel_stats.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import xarray as xr from noob.node import Node, NodeSpecification from cala.assets import Frame, Movie, PopSnap, Traces @@ -52,7 +53,7 @@ def test_ingest_frame(init, frame_update, separate_cells) -> None: ) expected = init.process(traces=separate_cells.traces, frames=separate_cells.make_movie()) - assert expected == result + xr.testing.assert_allclose(expected.array, result.array) @pytest.fixture(scope="function") diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index de1a2d6d..b3cb6576 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -83,8 +83,11 @@ def test_align_overestimates(single_cell) -> None: last_trace = single_cell.traces.array.isel({AXIS.frames_dim: -1}) footprints = single_cell.footprints.array + shapes_sparse = footprints.data.reshape((footprints.sizes[AXIS.component_dim], -1)).tocsr() - adjusted_traces = _align_overestimates(A=footprints, R_latest=last_res, C_latest=last_trace) + adjusted_traces = _align_overestimates( + A_pix=shapes_sparse, R_latest=last_res, C_latest=last_trace + ) # adjusted to lower than last_trace assert single_cell.traces.array.isel({AXIS.frames_dim: -2}) < adjusted_traces < last_trace diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 579450c6..aee597bb 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,7 +14,7 @@ "SeparateSource", "TwoOverlappingSource", "GradualOnSource", - "SplitOffSource", + # "SplitOffSource", ] ) def source(request): @@ -88,7 +88,7 @@ def test_odl(runner, source) -> None: # def test_with_src(): -# tube = Tube.from_specification("cala-with-ca1") +# tube = Tube.from_specification("cala-with-movie") # runner = SynchronousRunner(tube=tube) # runner.run() # diff --git a/tests/test_prep/test_background_removal.py b/tests/test_prep/test_background_removal.py deleted file mode 100644 index ccef05dd..00000000 --- a/tests/test_prep/test_background_removal.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any - -import numpy as np -import pytest - -from cala.nodes.prep import remove_background -from cala.testing.toy import FrameDims, Position, Toy - - -@pytest.mark.parametrize( - "params, sum", - [({"method": "uniform", "kernel_size": 2}, 16), ({"method": "tophat", "kernel_size": 4}, 29)], -) -def test_background_removal(params: dict[str, Any], sum) -> None: - """Test consistency of streaming background removal""" - toy = Toy( - n_frames=10, - frame_dims=FrameDims(width=11, height=11), - cell_radii=3, - cell_positions=[Position(width=5, height=5)], - cell_traces=[np.ones(10)], - emit_frames=True, - ) - - gen = toy.movie_gen() - - frame = next(gen) - - result = remove_background(frame, **params) - - assert result.array.sum() == sum diff --git a/tests/test_prep/test_glow_removal.py b/tests/test_prep/test_glow_removal.py index ec313208..777a0a41 100644 --- a/tests/test_prep/test_glow_removal.py +++ b/tests/test_prep/test_glow_removal.py @@ -23,7 +23,7 @@ def test_glow_removal(): expected_base = movie.array.min(dim=AXIS.frames_dim) res = [] - for frame, br in zip(iter(gen), [5, 4, 3, 2, 1, 1, 1, 1, 1, 1]): + for frame, br in zip(iter(gen), np.array([5, 4, 3, 2, 1, 1, 1, 1, 1, 1]) * 4): res.append(yeah_glo.process(frame)) assert yeah_glo.base_brightness_.max() == br diff --git a/tests/test_prep/test_r_estimate.py b/tests/test_prep/test_r_estimate.py index ac45c6eb..6681ec0b 100644 --- a/tests/test_prep/test_r_estimate.py +++ b/tests/test_prep/test_r_estimate.py @@ -1,9 +1,12 @@ +import pytest + from cala.models import AXIS from cala.nodes.prep import package_frame from cala.nodes.prep.r_estimate import SizeEst from cala.testing.toy import Position +@pytest.mark.xfail def test_size_estim(separate_cells): kwargs = { "min_sigma": 1, From 15c8c2081e79f81f849ac16d49f6ee0138a15e37 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 18:28:47 -0700 Subject: [PATCH 45/47] feat: relax nmf acceptance condition for both l0 and l1 --- src/cala/nodes/detect/slice_nmf.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 0939acd7..91f439de 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -40,7 +40,7 @@ def process( if residuals.array.sizes[AXIS.frames_dim] < self.min_frames: return [], [] - energy = self._get_energy(residuals.array) # 0.008s + energy = self._get_energy(residuals.array) fps = [] trs = [] @@ -58,17 +58,19 @@ def process( spatial_sizes={k: v for k, v in res.sizes.items() if k in AXIS.spatial_dims}, ) - l1_norm = np.sum(slice_) + l1_norm = np.sum(slice_.values) + l1_error = self.error_ / l1_norm + l0_norm = np.prod(slice_.shape).astype(float) + l0_error = self.error_ / l0_norm energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 - if (self.error_ / l1_norm) <= self.reprod_tol: + if min(l1_error, l0_error) <= self.reprod_tol: fps.append(Footprint.from_array(a_new, sparsify=False)) trs.append(Trace.from_array(c_new)) - res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = self.error_ / l1_norm + res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = l1_error else: - l0_norm = np.prod(slice_.shape) - res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = self.error_ / l0_norm + res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = l0_error return fps, trs @@ -126,9 +128,16 @@ def _local_nmf( - Temporal component c_new (frames) """ # Reshape neighborhood to 2D matrix (time × space) - R = slice_.stack(space=AXIS.spatial_dims).transpose("space", AXIS.frames_dim) + R = ( + slice_.transpose(AXIS.frames_dim, ...) + .data.reshape((slice_.sizes[AXIS.frames_dim], -1)) + .T + ) + + mean_R = np.mean(R, axis=1) + # nan_mask = np.isnan(mean_R) - a, c, self.error_ = rank1nmf(R.values, np.mean(R.values, axis=1)) + a, c, self.error_ = rank1nmf(R, mean_R) # Convert back to xarray with proper dimensions and coordinates c_new = xr.DataArray( @@ -167,6 +176,9 @@ def rank1nmf( Ypx: (pixels, frames) ain: (pixels) + iters: valid only by period of 4 (seems like i mod 4 = 2 gives good results. + mod 4 = 3 is marginally better.) + """ eps = np.finfo(np.float32).eps for t in range(iters): From 82c46484b1e5718befae42c3551952c3d7744908 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 18:29:20 -0700 Subject: [PATCH 46/47] perf: optimize residual speed --- src/cala/nodes/residual.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 795cbb06..49fa8fd6 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr from noob import Name -from sparse import COO +from scipy.sparse import csr_matrix from cala.assets import Buffer, Footprints, Frame, Traces from cala.models import AXIS @@ -54,27 +54,28 @@ def build( Y = frame.array C = traces.array.isel({AXIS.frames_dim: -1}) # (components,) A = footprints.array + A_pix = ( + A.transpose(AXIS.component_dim, ...).data.reshape((A.sizes[AXIS.component_dim], -1)).tocsr() + ) - R_curr, flag = _find_overestimates(Y=Y, A=A, C=C) + R_curr, flag = _find_overestimates(Y=Y, A=A_pix, C=C) if flag: - C = _align_overestimates(A, C, R_curr) + C = _align_overestimates(A_pix=A_pix, C_latest=C, R_latest=R_curr) traces.array.loc[{AXIS.frames_dim: -1}] = C # if recently discovered, set to zero (or a small number). otherwise, just append preserve_area = _get_new_estimators_area(A=A, C=C, n_recalc=n_recalc) if preserve_area is not None: residuals.array_ *= preserve_area.as_numpy() - R_curr = _get_residuals(Y=Y, A=A, C=C) + R_curr = _get_residuals(Y=Y, A=A_pix, C=C) # we're not fully modifying for the negative minimum, so we need to clip residuals.append(R_curr.clip(min=0)) return residuals -def _get_residuals(Y: xr.DataArray, A: xr.DataArray, C: xr.DataArray) -> xr.DataArray: - return Y - xr.DataArray( - np.matmul(A.transpose(*AXIS.spatial_dims, ...).data, C.data), dims=AXIS.spatial_dims - ) +def _get_residuals(Y: xr.DataArray, A: csr_matrix, C: xr.DataArray) -> xr.DataArray: + return Y - xr.DataArray((C.data @ A).reshape(Y.shape), dims=AXIS.spatial_dims) def _find_overestimates( @@ -85,12 +86,11 @@ def _find_overestimates( def _align_overestimates( - A: xr.DataArray, C_latest: xr.DataArray, R_latest: xr.DataArray + A_pix: csr_matrix, C_latest: xr.DataArray, R_latest: xr.DataArray ) -> xr.DataArray: """ !!We're assuming there's no completely occluded component. This might be a problem eventually!! """ - A_pix = A.data.reshape((A.sizes[AXIS.component_dim], -1)).tocsr() R = R_latest.values unlayered_stamp = _find_unlayered_footprints(A_pix) # same up to here @@ -103,7 +103,7 @@ def _align_overestimates( return (C_latest + dC_norm).clip(min=0) -def _find_unlayered_footprints(A: COO) -> np.ndarray: +def _find_unlayered_footprints(A: csr_matrix) -> np.ndarray: coords = A.nonzero()[1] pixels, counts = np.unique(coords, return_counts=True) mask = np.isin(coords, pixels[counts == 1]) From 1f668f362b413a57ec4243c6165c92658ee36f70 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 20 Oct 2025 18:38:56 -0700 Subject: [PATCH 47/47] format --- src/cala/assets.py | 1 + src/cala/nodes/detect/catalog.py | 4 ++++ src/cala/util.py | 8 ++++---- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 70002421..d648c04a 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -338,6 +338,7 @@ def array(self) -> xr.DataArray | None: out = self.array_.isel({AXIS.frames_dim: slice(self._next, self._next + self.size)}) else: out = self.array_.isel({AXIS.frames_dim: slice(None, self._next)}) + # kinda expensive. maybe float is fine? return out # .assign_coords({AXIS.frame_coord: out[AXIS.frame_coord].astype(int)}) @array.setter diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/detect/catalog.py index c17ffc2d..fc5a12cc 100644 --- a/src/cala/nodes/detect/catalog.py +++ b/src/cala/nodes/detect/catalog.py @@ -26,6 +26,7 @@ class Cataloger(Node): merge_threshold: float val_threshold: float = Field(gt=0, lt=1) cnt_threshold: int = Field(gt=0) + """must have cnt-number of pixels that are above the val-value""" def process( self, @@ -46,9 +47,12 @@ def process( known_fp, known_tr = _get_absorption_targets(existing_fp, existing_tr, self.age_limit) merge_mat = self._merge_matrix(new_fps, new_trs, known_fp, known_tr) footprints, traces = self._absorb(new_fps, new_trs, known_fp, known_tr, merge_mat) + # footprints = self._smooth(shapes) return Footprints.from_array(footprints), Traces.from_array(traces) + def _smooth(self, shapes: xr.DataArray) -> xr.DataArray: ... + def _merge_matrix( self, fps: xr.DataArray, diff --git a/src/cala/util.py b/src/cala/util.py index babeb776..344036de 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -37,14 +37,14 @@ def sp_matmul( :param right: """ - l = left.transpose(dim, ...).data.reshape((left.sizes[dim], -1)).tocsr() + ll = left.transpose(dim, ...).data.reshape((left.sizes[dim], -1)).tocsr() if right is None: right = left - r = l + rr = ll else: - r = right.transpose(dim, ...).data.reshape((right.sizes[dim], -1)).tocsr() + rr = right.transpose(dim, ...).data.reshape((right.sizes[dim], -1)).tocsr() - val = l @ r.T + val = ll @ rr.T return xr.DataArray( COO.from_scipy_sparse(val), dims=[dim, f"{dim}'"], coords=left[dim].coords