From 5371279bb107114fcab98372ed72131654b86af2 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 18 Aug 2025 14:39:44 -0700 Subject: [PATCH 01/43] tests: fixture cleanup --- src/cala/testing/__init__.py | 10 +- src/cala/testing/nodes.py | 205 +++++++++++++++++++++++------------ tests/fixtures/__init__.py | 2 +- tests/fixtures/sims.py | 50 --------- tests/fixtures/toys.py | 16 +++ tests/test_pipeline.py | 41 +++---- 6 files changed, 184 insertions(+), 140 deletions(-) delete mode 100644 tests/fixtures/sims.py create mode 100644 tests/fixtures/toys.py diff --git a/src/cala/testing/__init__.py b/src/cala/testing/__init__.py index 1bad55fc..f69b35ad 100644 --- a/src/cala/testing/__init__.py +++ b/src/cala/testing/__init__.py @@ -1,3 +1,9 @@ -from .nodes import single_cell_source, two_cells_source, two_overlapping_source +from .nodes import ( + SingleCellSource, + TwoCellsSource, + TwoOverlappingSource, + SeparateSource, + ConnectedSource, +) -__all__ = [single_cell_source, two_cells_source, two_overlapping_source] +__all__ = [SingleCellSource, TwoCellsSource, TwoOverlappingSource, SeparateSource, ConnectedSource] diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 12431541..7daf1e8e 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -2,76 +2,145 @@ from typing import Annotated as A import numpy as np -from noob import Name +from noob import Name, process_method +from pydantic import BaseModel, PrivateAttr, model_validator from cala.assets import Frame from cala.testing.toy import FrameDims, Position, Toy -def single_cell_source( - n_frames: int = 30, - frame_dims: dict = None, - cell_radii: int = 30, - positions: list[dict] = None, -) -> Generator[A[Frame, Name("frame")]]: - frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims) - traces = [np.array(range(0, n_frames))] - if positions is None: - positions = [Position(width=256, height=256)] - else: - positions = [Position(**position) 
for position in positions] - - toy = Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=positions, - cell_traces=traces, - ) - return toy.movie_gen() - - -def two_cells_source( - n_frames: int = 30, - frame_dims: dict = None, - cell_radii: int = 30, - positions: list[dict] = None, -) -> Generator[A[Frame, Name("frame")]]: - frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims) - traces = [np.array(range(0, n_frames)), np.array([0, *range(n_frames - 1, 0, -1)])] - if positions is None: - positions = [Position(width=206, height=206), Position(width=306, height=306)] - else: - positions = [Position(**position) for position in positions] - - toy = Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=positions, - cell_traces=traces, - ) - return toy.movie_gen() - - -def two_overlapping_source( - n_frames: int = 30, - frame_dims: dict = None, - cell_radii: int = 30, - positions: list[dict] = None, -) -> Generator[A[Frame, Name("frame")]]: - frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims) - traces = [np.array(range(0, n_frames)), np.array([0, *range(n_frames - 1, 0, -1)])] - if positions is None: - positions = [Position(width=236, height=236), Position(width=276, height=276)] - else: - positions = [Position(**position) for position in positions] - - toy = Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=positions, - cell_traces=traces, - ) - return toy.movie_gen() +class MovieSource(BaseModel): + n_frames: int = 50 + frame_dims: FrameDims | dict[str, int] | None = None + cell_radii: int = 30 + positions: list[dict | Position] | None = None + _toy: Toy = PrivateAttr(None) + _traces: list[np.ndarray] = PrivateAttr(None) + + def _build_toy(self) -> Toy: + return Toy( + n_frames=self.n_frames, + frame_dims=self.frame_dims, + cell_radii=self.cell_radii, + 
cell_positions=self.positions, + cell_traces=self._traces, + ) + + @process_method + def process(self) -> Generator[A[Frame, Name("frame")]]: + yield from self._toy.movie_gen() + + +class SingleCellSource(MovieSource): + @model_validator(mode="after") + def complete_model(self): + self.frame_dims = ( + FrameDims(width=512, height=512) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self._traces = [np.array(range(0, self.n_frames))] + + if self.positions is None: + self.positions = [Position(width=256, height=256)] + else: + self.positions = [Position(**position) for position in self.positions] + + self._toy = self._build_toy() + return self + + +class TwoCellsSource(MovieSource): + @model_validator(mode="after") + def complete_model(self): + self.frame_dims = ( + FrameDims(width=512, height=512) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + + self._traces = [ + np.array(range(0, self.n_frames)), + np.array([0, *range(self.n_frames - 1, 0, -1)]), + ] + if self.positions is None: + self.positions = [Position(width=206, height=206), Position(width=306, height=306)] + else: + self.positions = [Position(**position) for position in self.positions] + + self._toy = self._build_toy() + return self + + +class TwoOverlappingSource(MovieSource): + @model_validator(mode="after") + def complete_model(self): + self.frame_dims = ( + FrameDims(width=512, height=512) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self._traces = [ + np.array(range(0, self.n_frames)), + np.array([0, *range(self.n_frames - 1, 0, -1)]), + ] + + if self.positions is None: + self.positions = [Position(width=236, height=236), Position(width=276, height=276)] + else: + self.positions = [Position(**position) for position in self.positions] + + self._toy = self._build_toy() + return self + + +class SeparateSource(MovieSource): + @model_validator(mode="after") + def complete_model(self): + self.cell_radii = 3 + self.frame_dims = ( + 
FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + ] + self._traces = [ + np.zeros(self.n_frames, dtype=float), + np.ones(self.n_frames, dtype=float), + np.array(range(self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), + ] + + self._toy = self._build_toy() + return self + + +class ConnectedSource(MovieSource): + @model_validator(mode="after") + def complete_model(self): + self.cell_radii = 8 + self.frame_dims = ( + FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + ] + self._traces = [ + np.zeros(self.n_frames, dtype=float), + np.ones(self.n_frames, dtype=float), + np.array(range(self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), + ] + + self._toy = self._build_toy() + return self diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index e08267e8..0bfc51ca 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -9,7 +9,7 @@ yaml_config, ) from .meta import monkeypatch_session -from .sims import connected_cells, separate_cells +from .toys import connected_cells, separate_cells __all__ = [ "monkeypatch_session", diff --git a/tests/fixtures/sims.py b/tests/fixtures/sims.py deleted file mode 100644 index ca6c0731..00000000 --- a/tests/fixtures/sims.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np -import pytest - -from cala.testing.toy import FrameDims, Position, Toy - - -@pytest.fixture -def separate_cells() -> Toy: - n_frames = 50 - - return Toy( - n_frames=n_frames, - frame_dims=FrameDims(width=50, height=50), - cell_radii=3, - 
cell_positions=[ - Position(width=15, height=15), - Position(width=15, height=35), - Position(width=25, height=25), - Position(width=35, height=35), - ], - cell_traces=[ - np.zeros(n_frames, dtype=float), - np.ones(n_frames, dtype=float), - np.array(range(n_frames), dtype=float), - np.array(range(n_frames - 1, -1, -1), dtype=float), - ], - ) - - -@pytest.fixture -def connected_cells() -> Toy: - n_frames = 50 - - return Toy( - n_frames=n_frames, - frame_dims=FrameDims(width=50, height=50), - cell_radii=8, - cell_positions=[ - Position(width=15, height=15), - Position(width=15, height=35), - Position(width=25, height=25), - Position(width=35, height=35), - ], - cell_traces=[ - np.zeros(n_frames, dtype=float), - np.ones(n_frames, dtype=float), - np.array(range(n_frames), dtype=float), - np.array(range(n_frames - 1, -1, -1), dtype=float), - ], - ) diff --git a/tests/fixtures/toys.py b/tests/fixtures/toys.py new file mode 100644 index 00000000..b81c4284 --- /dev/null +++ b/tests/fixtures/toys.py @@ -0,0 +1,16 @@ +import pytest + +from cala.testing.nodes import SeparateSource, ConnectedSource +from cala.testing.toy import Toy + + +@pytest.fixture +def separate_cells() -> Toy: + source = SeparateSource() + return source._toy + + +@pytest.fixture +def connected_cells() -> Toy: + source = ConnectedSource() + return source._toy diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1014d9e2..1337ecca 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -6,13 +6,19 @@ from cala.models import AXIS -@pytest.fixture(params=["single_cell_source", "two_cells_source", "two_overlapping_source"]) +@pytest.fixture( + params=[ + "SingleCellSource", + "TwoCellsSource", + "TwoOverlappingSource", + "SeparateSource", + "ConnectedSource", + ] +) def tube(request): tube = Tube.from_specification("cala-odl") source = Node.from_specification( - NodeSpecification( - id="source", type=f"cala.testing.{request.param}", params={"n_frames": 50} - ) + 
NodeSpecification(id="source", type=f"cala.testing.{request.param}") ) tube.nodes["source"] = source @@ -38,7 +44,7 @@ def test_process(runner) -> None: def test_iter(runner) -> None: - gen = runner.iter(n=runner.tube.nodes["source"].spec.params["n_frames"]) + gen = runner.iter(n=runner.tube.nodes["source"].instance.n_frames) movie = [] for _, exp in enumerate(gen): @@ -49,28 +55,25 @@ def test_iter(runner) -> None: expected = xr.concat(movie, dim=AXIS.frames_dim) result = (fps.array @ trs.array).transpose(*expected.dims) - if runner.tube.nodes["source"].fn.__name__ == "two_overlapping_source": + src_node = runner.tube.nodes["source"].spec.type_.split(".")[-1] + + if src_node == "TwoOverlappingSource": diff = expected - result for d_fr, e_fr in zip(diff, expected): assert d_fr.max() <= e_fr.quantile(0.98) * 2e-2 else: xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) + n_discoverable = { + "SingleCellSource": 1, + "TwoCellsSource": 2, + "TwoOverlappingSource": 2, + "SeparateSource": 2, + } + assert fps.array.sizes[AXIS.component_dim] == n_discoverable[src_node] + -@pytest.mark.xfail def test_run(runner) -> None: result = runner.run(n=5) assert result - - -@pytest.mark.xfail -def test_combined_footprint() -> None: - """Start with two footprints combined""" - raise AssertionError("Not implemented") - - -@pytest.mark.xfail -def test_redundant_footprint() -> None: - """start with redundant footprints""" - raise AssertionError("Not implemented") From 5f0ef806f304326ec8c4018966b407c71b3901ac Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 18 Aug 2025 14:39:55 -0700 Subject: [PATCH 02/43] env: update noob --- pdm.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pdm.lock b/pdm.lock index 59f66f42..ba0b54db 100644 --- a/pdm.lock +++ b/pdm.lock @@ -1811,11 +1811,11 @@ files = [ [[package]] name = "noob" -version = "0.1.1.dev119" +version = "0.1.1.dev121" requires_python = ">=3.11" git = "https://github.com/miniscope/noob.git" ref 
= "37-tube-resources-for-data-shared-between-nodes" -revision = "5d585f6a09f9cf9f1bdbd383ee99f7fdbfa1a1f1" +revision = "fd131c7704c0485baadb7a59d5d714b529f09d58" summary = "Default template for PDM package" dependencies = [ "PyYAML>=6.0.2", From 162ab9a2aa70e1f965a67caa2a9a1674ddc71451 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 18 Aug 2025 14:40:49 -0700 Subject: [PATCH 03/43] debug: small errors --- src/cala/nodes/cleanup.py | 4 ++-- src/cala/nodes/prep/r_estimate.py | 3 +++ src/cala/nodes/prep/rigid_stabilization.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 220e3489..63196499 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -101,13 +101,13 @@ def filter_components( comp_stats.array.set_xindex(AXIS.id_coord) .set_xindex(f"{AXIS.id_coord}'") .sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids.values.tolist()}) - .reset_index(AXIS.id_coord) + .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) ) overlaps.array = ( overlaps.array.set_xindex(AXIS.id_coord) .set_xindex(f"{AXIS.id_coord}'") .sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids.values.tolist()}) - .reset_index(AXIS.id_coord) + .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) ) return footprints, traces, pix_stats, comp_stats, overlaps diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index e83335ea..7fb0f4f1 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -32,6 +32,9 @@ def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]: return self._est_radius blobs = blob_log(frame.array, **self.log_kwargs) + if blobs.size == 0: + return 0 + self.centers_ = [blobs[:-1] for blobs in blobs] self.sizes_ += [blob[-1].item() for blob in blobs] self._est_radius = int(np.round(np.median(self.sizes_)).item()) diff --git a/src/cala/nodes/prep/rigid_stabilization.py 
b/src/cala/nodes/prep/rigid_stabilization.py index c1052ab9..0718e608 100644 --- a/src/cala/nodes/prep/rigid_stabilization.py +++ b/src/cala/nodes/prep/rigid_stabilization.py @@ -131,7 +131,7 @@ def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: def update_anchor(self, frame: xr.DataArray) -> xr.DataArray: curr_index = frame[AXIS.frame_coord].item() - return (self.anchor_frame_.array * curr_index + frame) / curr_index + 1 + return (self.anchor_frame_.array * curr_index + frame) / (curr_index + 1) def process(self, frame: Frame) -> A[Frame, Name("frame")]: if self.is_first_frame(frame): From cfebd1a8d858be4c6a8ff1406b53085bfb122a2a Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 18 Aug 2025 14:41:06 -0700 Subject: [PATCH 04/43] tests: add cell size est --- tests/data/pipelines/odl.yaml | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index d84f0301..8895d7cb 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -30,9 +30,7 @@ assets: nodes: source: - type: cala.testing.single_cell_source - params: - n_frames: 30 + type: cala.testing.SingleCellSource denoise: type: cala.nodes.prep.denoise params: @@ -49,13 +47,19 @@ nodes: motion: type: cala.nodes.prep.RigidStabilizer params: - drift_speed: 1.0 + drift_speed: 0.5 depends: - frame: glow.frame size_est: type: cala.nodes.prep.SizeEst params: - hardset_radius: 10 + # hardset_radius: 10 + log_kwargs: + min_sigma: 1 + max_sigma: 10 + num_sigma: 10 + threshold: 0.2 + overlap: 0.5 depends: - frame: motion.frame cache: From ac24b743ff1bd94f9335885010f9bd431ba4c131 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 18 Aug 2025 17:07:09 -0700 Subject: [PATCH 05/43] tests: try gradual on test --- src/cala/testing/__init__.py | 10 +++++- src/cala/testing/nodes.py | 41 +++++++++++++++++++++++ tests/data/pipelines/odl.yaml | 2 +- tests/fixtures/toys.py | 8 ++++- 
tests/test_iter/test_detect.py | 60 ++++++++++++++-------------------- tests/test_pipeline.py | 2 +- 6 files changed, 84 insertions(+), 39 deletions(-) diff --git a/src/cala/testing/__init__.py b/src/cala/testing/__init__.py index f69b35ad..a552a45f 100644 --- a/src/cala/testing/__init__.py +++ b/src/cala/testing/__init__.py @@ -4,6 +4,14 @@ TwoOverlappingSource, SeparateSource, ConnectedSource, + GradualOnSource, ) -__all__ = [SingleCellSource, TwoCellsSource, TwoOverlappingSource, SeparateSource, ConnectedSource] +__all__ = [ + "SingleCellSource", + "TwoCellsSource", + "TwoOverlappingSource", + "SeparateSource", + "ConnectedSource", + "GradualOnSource", +] diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 7daf1e8e..5d2a7834 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -144,3 +144,44 @@ def complete_model(self): self._toy = self._build_toy() return self + + +class GradualOnSource(MovieSource): + @model_validator(mode="after") + def complete_model(self): + self.n_frames = 100 + self.cell_radii = 8 + self.frame_dims = ( + FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + Position(width=35, height=15), + ] + gap = 20 + decr = np.array(range(self.n_frames - 1, 0, -1), dtype=float) + sine = np.abs(np.sin(np.linspace(0, 2 * np.pi, self.n_frames - gap)) * self.n_frames) + incr = np.array(range(self.n_frames - gap * 2), dtype=float) + rand = np.random.random(self.n_frames - gap * 3) * self.n_frames + const = np.ones(self.n_frames - gap * 4, dtype=float) * self.n_frames + + self._traces = [ + np.pad(decr, (1, 0), mode="constant", constant_values=0), + np.pad(sine, (gap, 0), mode="constant", constant_values=0), + np.pad(incr, (gap * 2, 0), mode="constant", constant_values=0), + np.pad(rand, (gap * 3, 0), mode="constant", 
constant_values=0), + np.pad(const, (gap * 4, 0), mode="constant", constant_values=0), + ] + + self._toy = self._build_toy() + return self + + +class SplitOffSource(MovieSource): + @model_validator(mode="after") + def complete_model(self): ... diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 8895d7cb..4934d601 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -55,7 +55,7 @@ nodes: params: # hardset_radius: 10 log_kwargs: - min_sigma: 1 + min_sigma: 3 max_sigma: 10 num_sigma: 10 threshold: 0.2 diff --git a/tests/fixtures/toys.py b/tests/fixtures/toys.py index b81c4284..31141b78 100644 --- a/tests/fixtures/toys.py +++ b/tests/fixtures/toys.py @@ -1,9 +1,15 @@ import pytest -from cala.testing.nodes import SeparateSource, ConnectedSource +from cala.testing import SingleCellSource, SeparateSource, ConnectedSource from cala.testing.toy import Toy +@pytest.fixture +def single_cell() -> Toy: + source = SingleCellSource() + return source._toy + + @pytest.fixture def separate_cells() -> Toy: source = SeparateSource() diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 93a00e7e..6b10cad3 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -5,29 +5,9 @@ from cala.assets import AXIS, Footprints, Residual, Traces from cala.nodes.detect import Cataloger, SliceNMF -from cala.testing.toy import FrameDims, Position, Toy from cala.testing.util import assert_scalar_multiple_arrays -@pytest.fixture(autouse=True, scope="module") -def toy(): - n_frames = 30 - frame_dims = FrameDims(width=512, height=512) - cell_positions = [Position(width=256, height=256)] - cell_radii = 30 - cell_traces = [np.array(range(n_frames), dtype=float)] - confidences = [0.8] - - return Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=cell_positions, - cell_traces=cell_traces, - confidences=confidences, - ) - - @pytest.fixture(scope="class") 
def slice_nmf(): return SliceNMF.from_specification( @@ -49,20 +29,20 @@ def cataloger(): class TestSliceNMF: - def test_process(self, slice_nmf, toy): + def test_process(self, slice_nmf, single_cell): new_component = slice_nmf.process( - Residual.from_array(toy.make_movie().array), - detect_radius=toy.cell_radii[0] * 2, + Residual.from_array(single_cell.make_movie().array), + detect_radius=single_cell.cell_radii[0] * 2, ) if new_component: new_fp, new_tr = new_component else: raise AssertionError("Failed to detect a new component") - for new, old in zip([new_fp[0], new_tr[0]], [toy.footprints, toy.traces]): + for new, old in zip([new_fp[0], new_tr[0]], [single_cell.footprints, single_cell.traces]): assert_scalar_multiple_arrays(new.array, old.array) - def test_chunks(self, toy): + def test_chunks(self, single_cell): nmf = SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", @@ -70,26 +50,30 @@ def test_chunks(self, toy): params={"min_frames": 10, "detect_thresh": 1}, ) ) - fpts, trcs = nmf.process(Residual.from_array(toy.make_movie().array), detect_radius=10) + fpts, trcs = nmf.process( + Residual.from_array(single_cell.make_movie().array), detect_radius=10 + ) if not fpts or not trcs: raise AssertionError("Failed to detect a new component") fpt_arr = xr.concat([f.array for f in fpts], dim="component") - expected = toy.footprints.array[0] + expected = single_cell.footprints.array[0] result = (fpt_arr.sum(dim="component") > 0).astype(int) assert np.array_equal(expected, result) for trc in trcs: - assert_scalar_multiple_arrays(trc.array, toy.traces.array[0]) + assert_scalar_multiple_arrays(trc.array, single_cell.traces.array[0]) class TestCataloger: @pytest.fixture(scope="function") - def new_component(self, slice_nmf, toy): - return slice_nmf.process(Residual.from_array(toy.make_movie().array), detect_radius=60) + def new_component(self, slice_nmf, single_cell): + return slice_nmf.process( + Residual.from_array(single_cell.make_movie().array), 
detect_radius=60 + ) - def test_register(self, cataloger, new_component, toy): + def test_register(self, cataloger, new_component): new_fp, new_tr = new_component fp, tr = cataloger._register( new_fp=new_fp[0].array, @@ -99,14 +83,18 @@ def test_register(self, cataloger, new_component, toy): assert np.array_equal(fp.array, new_fp[0].array) assert np.array_equal(tr.array, new_tr[0].array) - def test_merge_with(self, slice_nmf, cataloger, toy): + def test_merge_with(self, slice_nmf, cataloger, single_cell): new_component = slice_nmf.process( - Residual.from_array(toy.make_movie().array), detect_radius=10 + Residual.from_array(single_cell.make_movie().array), detect_radius=10 ) new_fp, new_tr = new_component fp, tr = cataloger._merge_with( - new_fp[0].array, new_tr[0].array, toy.footprints.array, toy.traces.array, ["cell_0"] + new_fp[0].array, + new_tr[0].array, + single_cell.footprints.array, + single_cell.traces.array, + ["cell_0"], ) movie_result = (fp.array @ tr.array).reset_coords( @@ -114,7 +102,9 @@ def test_merge_with(self, slice_nmf, cataloger, toy): ) movie_new_comp = new_fp[0].array @ new_tr[0].array - movie_expected = (toy.make_movie().array + movie_new_comp).transpose(*movie_result.dims) + movie_expected = (single_cell.make_movie().array + movie_new_comp).transpose( + *movie_result.dims + ) xr.testing.assert_allclose(movie_result, movie_expected) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1337ecca..34b6a037 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,7 +12,7 @@ "TwoCellsSource", "TwoOverlappingSource", "SeparateSource", - "ConnectedSource", + "GradualOnSource", ] ) def tube(request): From 69c0aa6c27d47bebe07fe604a6f7dae0a3d4e8f5 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 01:16:01 -0700 Subject: [PATCH 06/43] feat: swap confident score with detected_on frame idx --- src/cala/models/axis.py | 10 +++++----- src/cala/nodes/detect/catalog.py | 18 +++++++++++------- src/cala/nodes/overlap.py 
| 2 +- src/cala/nodes/traces.py | 4 ++-- src/cala/testing/toy.py | 4 ++-- tests/test_iter/test_detect.py | 2 +- 6 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/cala/models/axis.py b/src/cala/models/axis.py index ace438ad..0b5967d5 100644 --- a/src/cala/models/axis.py +++ b/src/cala/models/axis.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field -from cala.models.checks import is_unique, is_unit_interval +from cala.models.checks import is_unique, is_unit_interval, has_no_nan class Axis: @@ -17,7 +17,7 @@ class Axis: id_coord: str = "id_" timestamp_coord: str = "timestamp" - confidence_coord: str = "confidence" + detect_coord: str = "detected_on" frame_coord: str = "frame" width_coord: str = "width" height_coord: str = "height" @@ -37,7 +37,7 @@ def component_rename(self) -> dict[str, str]: return { AXIS.component_dim: f"{AXIS.component_dim}'", AXIS.id_coord: f"{AXIS.id_coord}'", - AXIS.confidence_coord: f"{AXIS.confidence_coord}'", + AXIS.detect_coord: f"{AXIS.detect_coord}'", } @@ -62,11 +62,11 @@ class Coords(Enum): width = Coord(name=AXIS.width_coord, dtype=int, checks=[is_unique]) frame = Coord(name=AXIS.frame_coord, dtype=int, checks=[is_unique]) timestamp = Coord(name=AXIS.timestamp_coord, dtype=str, checks=[is_unique]) - confidence = Coord(name=AXIS.confidence_coord, dtype=float, checks=[is_unit_interval]) + detected = Coord(name=AXIS.detect_coord, dtype=int, checks=[has_no_nan]) class Dims(Enum): width = Dim(name=AXIS.width_dim, coords=[Coords.width.value]) height = Dim(name=AXIS.height_dim, coords=[Coords.height.value]) frame = Dim(name=AXIS.frames_dim, coords=[Coords.frame.value, Coords.timestamp.value]) - component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.confidence.value]) + component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.detected.value]) diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/detect/catalog.py index aacc4cb1..ca7d816b 100644 --- 
a/src/cala/nodes/detect/catalog.py +++ b/src/cala/nodes/detect/catalog.py @@ -72,21 +72,19 @@ def process( footprints = xr.concat( footprints, dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.confidence_coord], + coords=[AXIS.id_coord, AXIS.detect_coord], combine_attrs=combine_attr_replaces, ) traces = xr.concat( traces, dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.confidence_coord], + coords=[AXIS.id_coord, AXIS.detect_coord], combine_attrs=combine_attr_replaces, ) return Footprints.from_array(footprints), Traces.from_array(traces) - def _register( - self, new_fp: xr.DataArray, new_tr: xr.DataArray, confidence: float = 0.0 - ) -> tuple[Footprint, Trace]: + def _register(self, new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[Footprint, Trace]: new_id = create_id() @@ -95,7 +93,10 @@ def _register( .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [new_id]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: ( + AXIS.component_dim, + [new_tr[AXIS.frame_coord].max().item()], + ), } ) .isel({AXIS.component_dim: 0}) @@ -105,7 +106,10 @@ def _register( .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [new_id]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: ( + AXIS.component_dim, + [new_tr[AXIS.frame_coord].max().item()], + ), } ) .isel({AXIS.component_dim: 0}) diff --git a/src/cala/nodes/overlap.py b/src/cala/nodes/overlap.py index 454d3d97..a04ea94d 100644 --- a/src/cala/nodes/overlap.py +++ b/src/cala/nodes/overlap.py @@ -39,7 +39,7 @@ def ingest_component( V = overlaps.array a_new = new_footprints.array.volumize.dim_with_coords( - dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.confidence_coord] + dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.detect_coord] ) if a_new[AXIS.id_coord].item() in V[AXIS.id_coord].values: diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index 4964af6b..a6f0cc74 100644 --- a/src/cala/nodes/traces.py +++ 
b/src/cala/nodes/traces.py @@ -31,7 +31,7 @@ def initialize(self, footprints: Footprints, frames: Movie) -> Traces: trace_coords = [ AXIS.id_coord, - AXIS.confidence_coord, + AXIS.detect_coord, AXIS.frame_coord, AXIS.timestamp_coord, ] @@ -226,7 +226,7 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: coords=c.isel({AXIS.component_dim: 0}).coords, ) c_new[AXIS.id_coord] = c_det[AXIS.id_coord] - c_new[AXIS.confidence_coord] = c_det[AXIS.confidence_coord] + c_new[AXIS.detect_coord] = c_det[AXIS.detect_coord] c_new.loc[{AXIS.frames_dim: c_det[AXIS.frame_coord]}] = c_det else: diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index b5c48d7c..2130c12d 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -141,7 +141,7 @@ def _generate_footprint( return footprint.expand_dims(AXIS.component_dim).assign_coords( { AXIS.id_coord: (AXIS.component_dim, [id_]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: (AXIS.component_dim, [confidence]), **{ax: footprint[ax] for ax in AXIS.spatial_dims}, } ) @@ -165,7 +165,7 @@ def _format_trace(self, trace: np.ndarray, id_: str, confidence: float) -> xr.Da .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [id_]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: (AXIS.component_dim, [confidence]), AXIS.frames_dim: range(trace.size), } ) diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 6b10cad3..88587789 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -98,7 +98,7 @@ def test_merge_with(self, slice_nmf, cataloger, single_cell): ) movie_result = (fp.array @ tr.array).reset_coords( - [AXIS.id_coord, AXIS.confidence_coord], drop=True + [AXIS.id_coord, AXIS.detect_coord], drop=True ) movie_new_comp = new_fp[0].array @ new_tr[0].array From ab17b390465a8bf121affb9fea2220ac1a72083d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 01:17:58 -0700 
Subject: [PATCH 07/43] feat: max_iter implementation for footprints and traces --- src/cala/nodes/footprints.py | 13 ++++++++++-- src/cala/nodes/traces.py | 40 +++++++++++++++++++++--------------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index c4ec8228..4cf9d511 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -7,9 +7,11 @@ from cala.assets import CompStats, Footprints, PixStats from cala.models import AXIS +from cala.logging import init_logger class Footprinter: + logger = init_logger(__name__) def __init__(self, tol: float, max_iter: int | None = None, bep: int | None = None): self.bep = bep @@ -55,7 +57,7 @@ def ingest_frame( kernel = self._expansion_kernel() mask = self._expand_boundary(kernel, mask) - # for _ in range(self.max_iter): + cnt = 0 while True: AM = A.rename(AXIS.component_rename) @ M numerator = W - AM @@ -72,8 +74,15 @@ def ingest_frame( update = numerator / (M_diag + np.finfo(float).tiny) A_new = (mask * (A + update)).clip(min=0) - if (np.abs(A - A_new).sum() / np.prod(A.shape)).item() < self.tol: + step = (np.abs(A - A_new).sum() / np.prod(A.shape)).item() + + cnt += 1 + maxed = self.max_iter and (cnt == self.max_iter) + + if step < self.tol or maxed: footprints.array = A_new + if maxed: + self.logger.debug(msg="max_iter reached before converging.") return footprints else: A = A_new diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index a6f0cc74..b56f3ba7 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -7,6 +7,7 @@ from scipy.sparse.csgraph import connected_components from cala.assets import Footprints, Frame, Movie, Overlaps, PopSnap, Traces +from cala.logging import init_logger from cala.models import AXIS @@ -83,8 +84,11 @@ def _fast_nnls_vector(A: np.ndarray, B: np.ndarray) -> np.ndarray: class FrameUpdate: - def __init__(self, tolerance: float = 1e-3) -> None: - self.tolerance = tolerance 
+ logger = init_logger(__name__) + + def __init__(self, tol: float, max_iter: int | None = None) -> None: + self.tol = tol + self.max_iter = max_iter @process_method def ingest_frame( @@ -131,7 +135,7 @@ def ingest_frame( ) clusters = [np.where(labels == label)[0] for label in np.unique(labels)] - updated_traces = self._update_traces(A, y, c.copy(), clusters, self.tolerance) + updated_traces = self._update_traces(A, y, c.copy(), clusters) traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) @@ -143,14 +147,12 @@ def _update_traces( y: xr.DataArray, c: xr.DataArray, clusters: list[np.ndarray], - eps: float, ) -> xr.DataArray: """ Implementation of the temporal traces update algorithm. - This function implements the core update logic. It uses block coordinate descent - to update temporal traces for overlapping components together while maintaining - non-negativity constraints. + This function uses block coordinate descent to update temporal traces + for overlapping components together while maintaining non-negativity constraints. Args: A (xr.DataArray): Spatial footprints matrix [A, b]. 
@@ -176,14 +178,13 @@ def _update_traces( # Step 3: Extract diagonal elements for normalization V_diag = np.diag(V) - # Step 4: Initialize previous iteration value - c_old = np.zeros_like(c) + cnt = 0 - # Steps 5-10: Main iteration loop until convergence - while np.linalg.norm(c - c_old) >= eps * np.linalg.norm(c_old): + # Steps 4-9: Main iteration loop until convergence + while True: c_old = c.copy() - # Steps 7-9: Update each group using block coordinate descent + # Steps 6-8: Update each group using block coordinate descent for cluster in clusters: # Update traces for current group (division is pointwise) numerator = u.isel({AXIS.component_dim: cluster}) - ( @@ -191,13 +192,18 @@ def _update_traces( ).rename({f"{AXIS.component_dim}'": AXIS.component_dim}) c.loc[{AXIS.component_dim: cluster}] = np.maximum( - c.isel({AXIS.component_dim: cluster}) + numerator / V_diag[cluster].T, - 0, + c.isel({AXIS.component_dim: cluster}) + numerator / V_diag[cluster].T, 0 ) - return xr.DataArray( - c.values, dims=c.dims, coords=c[AXIS.component_dim].coords - ).assign_coords(y[AXIS.frames_dim].coords) + cnt += 1 + maxed = self.max_iter and (cnt == self.max_iter) + + if np.linalg.norm(c - c_old) >= self.tol * np.linalg.norm(c_old) or maxed: + if maxed: + self.logger.debug(msg="max_iter reached before converging.") + return xr.DataArray( + c.values, dims=c.dims, coords=c[AXIS.component_dim].coords + ).assign_coords(y[AXIS.frames_dim].coords) def ingest_component(traces: Traces, new_traces: Traces) -> Traces: From 57d2a5c7470681819551232a3c9989c0c2ff17e6 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 01:18:53 -0700 Subject: [PATCH 08/43] feat: catalog merge with "touching" components too --- src/cala/nodes/detect/catalog.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/detect/catalog.py index ca7d816b..684ee16b 100644 --- a/src/cala/nodes/detect/catalog.py +++ b/src/cala/nodes/detect/catalog.py @@ 
-1,6 +1,7 @@ from collections.abc import Hashable, Iterable from typing import Annotated as A +import cv2 import numpy as np import xarray as xr from noob import Name @@ -215,8 +216,21 @@ def _merge_matrix( fps_base = fps_base.rename(AXIS.component_rename) trs_base = trs_base.rename(AXIS.component_rename) + fps = self._expand_boundary(fps > 0) + overlaps = fps @ fps_base > 0 # this should later reflect confidence corrs = xr.corr(trs, trs_base, dim=AXIS.frames_dim) > self.merge_threshold return overlaps * corrs + + def _expand_boundary(self, mask: xr.DataArray) -> xr.DataArray: + kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) + return xr.apply_ufunc( + lambda x: cv2.morphologyEx(x, cv2.MORPH_DILATE, kernel, iterations=1), + mask.astype(np.uint8), + input_core_dims=[AXIS.spatial_dims], + output_core_dims=[AXIS.spatial_dims], + vectorize=True, + dask="parallelized", + ) From 902a9e72b7f9b23d0335a4671e842bbbbcddaf9c Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 01:19:49 -0700 Subject: [PATCH 09/43] feat: gradual-on source has more realistic, smooth sources --- src/cala/testing/nodes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 5d2a7834..465afd6d 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -167,15 +167,19 @@ def complete_model(self): decr = np.array(range(self.n_frames - 1, 0, -1), dtype=float) sine = np.abs(np.sin(np.linspace(0, 2 * np.pi, self.n_frames - gap)) * self.n_frames) incr = np.array(range(self.n_frames - gap * 2), dtype=float) - rand = np.random.random(self.n_frames - gap * 3) * self.n_frames - const = np.ones(self.n_frames - gap * 4, dtype=float) * self.n_frames + expo = ( + np.linspace(0, np.exp(3), self.n_frames - gap * 3) + * np.exp(-np.linspace(0, np.exp(2), self.n_frames - gap * 3)) + * self.n_frames + ) + tanh = np.tanh(np.linspace(0, 5, self.n_frames - gap * 4)) * self.n_frames self._traces = [ 
np.pad(decr, (1, 0), mode="constant", constant_values=0), np.pad(sine, (gap, 0), mode="constant", constant_values=0), np.pad(incr, (gap * 2, 0), mode="constant", constant_values=0), - np.pad(rand, (gap * 3, 0), mode="constant", constant_values=0), - np.pad(const, (gap * 4, 0), mode="constant", constant_values=0), + np.pad(expo, (gap * 3, 0), mode="constant", constant_values=0), + np.pad(tanh, (gap * 4, 0), mode="constant", constant_values=0), ] self._toy = self._build_toy() From 4226a4c38af476c5cc50254617f1e50f1c562d95 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 01:28:50 -0700 Subject: [PATCH 10/43] debug: energy is frobenuis-normed instead of summing V -> energy reflects actual brightness value --- src/cala/nodes/detect/slice_nmf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index edb11478..303cd450 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -43,7 +43,7 @@ def process( fps = [] trs = [] - while np.sqrt(energy.max()).item() > self.detect_thresh: # or use res directly + while energy.max().item() >= self.detect_thresh: # or use res directly # Find and analyze neighborhood of maximum variance slice_ = self._get_max_energy_slice( arr=res, energy_landscape=energy, radius=detect_radius @@ -65,7 +65,7 @@ def process( trs.append(Trace.from_array(c_new)) res = (res - comp_recon).clip(0) else: - res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 + res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 1e-7 return fps, trs @@ -73,7 +73,7 @@ def _get_energy(self, res: xr.DataArray) -> xr.DataArray: pixels_median = res.median(dim=AXIS.frames_dim) V = res - pixels_median - return (V**2).sum(dim=AXIS.frames_dim) + return np.sqrt((V**2).mean(dim=AXIS.frames_dim)) def _get_max_energy_slice( self, From 1aecde581b3e70cc00ffc0b7ab2a5a9f0de2d710 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 
Aug 2025 01:29:21 -0700 Subject: [PATCH 11/43] test: omit movie stabil temporarily its causing issues --- tests/data/pipelines/odl.yaml | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 4934d601..abc21fc7 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -44,12 +44,12 @@ nodes: type: cala.nodes.prep.GlowRemover depends: - frame: denoise.frame - motion: - type: cala.nodes.prep.RigidStabilizer - params: - drift_speed: 0.5 - depends: - - frame: glow.frame +# motion: +# type: cala.nodes.prep.RigidStabilizer +# params: +# drift_speed: 0.5 +# depends: +# - frame: glow.frame size_est: type: cala.nodes.prep.SizeEst params: @@ -61,41 +61,42 @@ nodes: threshold: 0.2 overlap: 0.5 depends: - - frame: motion.frame + - frame: glow.frame cache: type: cala.nodes.buffer.fill_buffer params: size: 100 depends: - buffer: assets.buffer - - frame: motion.frame + - frame: glow.frame trace_frame: type: cala.nodes.traces.FrameUpdate params: - tolerance: 0.001 + tol: 0.001 + max_iter: 100 depends: - traces: assets.traces - footprints: assets.footprints - - frame: motion.frame + - frame: glow.frame - overlaps: assets.overlaps pix_frame: type: cala.nodes.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - - frame: motion.frame + - frame: glow.frame - new_traces: trace_frame.latest_trace comp_frame: type: cala.nodes.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - - frame: motion.frame + - frame: glow.frame - new_traces: trace_frame.latest_trace residual: type: cala.nodes.residual.build params: - clip_threshold: 0.001 + clip_threshold: 0.1 depends: - trigger: trace_frame.latest_trace - frames: assets.buffer @@ -161,7 +162,8 @@ nodes: type: cala.nodes.footprints.Footprinter params: bep: 1 - tol: 0.0000001 + tol: 0.0001 + max_iter: 100 depends: - footprints: assets.footprints - pixel_stats: 
pix_component.value @@ -176,4 +178,4 @@ nodes: return: type: return depends: - - motion.frame \ No newline at end of file + - glow.frame \ No newline at end of file From 6673f9a62f0511936e64cb37c788aed06b6c5b51 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 02:21:06 -0700 Subject: [PATCH 12/43] format: ruff --- src/cala/models/axis.py | 2 +- src/cala/nodes/footprints.py | 2 +- src/cala/testing/__init__.py | 6 +++--- src/cala/testing/nodes.py | 15 ++++++++------- tests/fixtures/toys.py | 2 +- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/cala/models/axis.py b/src/cala/models/axis.py index 0b5967d5..e4222955 100644 --- a/src/cala/models/axis.py +++ b/src/cala/models/axis.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field -from cala.models.checks import is_unique, is_unit_interval, has_no_nan +from cala.models.checks import has_no_nan, is_unique class Axis: diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 4cf9d511..f04f9481 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -6,8 +6,8 @@ from noob import Name, process_method from cala.assets import CompStats, Footprints, PixStats -from cala.models import AXIS from cala.logging import init_logger +from cala.models import AXIS class Footprinter: diff --git a/src/cala/testing/__init__.py b/src/cala/testing/__init__.py index a552a45f..b32c0a73 100644 --- a/src/cala/testing/__init__.py +++ b/src/cala/testing/__init__.py @@ -1,10 +1,10 @@ from .nodes import ( + ConnectedSource, + GradualOnSource, + SeparateSource, SingleCellSource, TwoCellsSource, TwoOverlappingSource, - SeparateSource, - ConnectedSource, - GradualOnSource, ) __all__ = [ diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 465afd6d..6395028e 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -1,5 +1,6 @@ from collections.abc import Generator from typing import Annotated as A +from typing import Self import numpy as 
np from noob import Name, process_method @@ -33,7 +34,7 @@ def process(self) -> Generator[A[Frame, Name("frame")]]: class SingleCellSource(MovieSource): @model_validator(mode="after") - def complete_model(self): + def complete_model(self) -> Self: self.frame_dims = ( FrameDims(width=512, height=512) if self.frame_dims is None @@ -52,7 +53,7 @@ def complete_model(self): class TwoCellsSource(MovieSource): @model_validator(mode="after") - def complete_model(self): + def complete_model(self) -> Self: self.frame_dims = ( FrameDims(width=512, height=512) if self.frame_dims is None @@ -74,7 +75,7 @@ def complete_model(self): class TwoOverlappingSource(MovieSource): @model_validator(mode="after") - def complete_model(self): + def complete_model(self) -> Self: self.frame_dims = ( FrameDims(width=512, height=512) if self.frame_dims is None @@ -96,7 +97,7 @@ def complete_model(self): class SeparateSource(MovieSource): @model_validator(mode="after") - def complete_model(self): + def complete_model(self) -> Self: self.cell_radii = 3 self.frame_dims = ( FrameDims(width=50, height=50) @@ -122,7 +123,7 @@ def complete_model(self): class ConnectedSource(MovieSource): @model_validator(mode="after") - def complete_model(self): + def complete_model(self) -> Self: self.cell_radii = 8 self.frame_dims = ( FrameDims(width=50, height=50) @@ -148,7 +149,7 @@ def complete_model(self): class GradualOnSource(MovieSource): @model_validator(mode="after") - def complete_model(self): + def complete_model(self) -> Self: self.n_frames = 100 self.cell_radii = 8 self.frame_dims = ( @@ -188,4 +189,4 @@ def complete_model(self): class SplitOffSource(MovieSource): @model_validator(mode="after") - def complete_model(self): ... + def complete_model(self) -> Self: ... 
diff --git a/tests/fixtures/toys.py b/tests/fixtures/toys.py index 31141b78..fd9a8507 100644 --- a/tests/fixtures/toys.py +++ b/tests/fixtures/toys.py @@ -1,6 +1,6 @@ import pytest -from cala.testing import SingleCellSource, SeparateSource, ConnectedSource +from cala.testing import ConnectedSource, SeparateSource, SingleCellSource from cala.testing.toy import Toy From bb4a8668943bbb0c3c16b6972e23d654f3f657b8 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 12:06:36 -0700 Subject: [PATCH 13/43] tests: refit tests --- src/cala/testing/nodes.py | 2 +- src/cala/testing/toy.py | 26 +++++++++++++------------- tests/fixtures/__init__.py | 3 ++- tests/test_iter/test_detect.py | 4 ++-- tests/test_iter/test_traces.py | 2 +- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 6395028e..24905dfa 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -40,7 +40,7 @@ def complete_model(self) -> Self: if self.frame_dims is None else FrameDims(**self.frame_dims) ) - self._traces = [np.array(range(0, self.n_frames))] + self._traces = [np.array(range(0, self.n_frames), dtype=float)] if self.positions is None: self.positions = [Position(width=256, height=256)] diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index 2130c12d..4f65339b 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -54,7 +54,7 @@ class Toy(BaseModel): cell_traces: list[np.ndarray] cell_ids: list[str] """If none, auto populated as cell_{idx}.""" - confidences: list[float] + detected_ons: list[int] _footprints: xr.DataArray = PrivateAttr(init=False) _traces: xr.DataArray = PrivateAttr(init=False) @@ -74,9 +74,9 @@ def fill_ids(self) -> Self: return self @model_validator(mode="before") - def fill_confidences(self) -> Self: - if self.get("confidences", None) is None: - self["confidences"] = [0.0] * len(self["cell_positions"]) + def fill_detected_ons(self) -> Self: + if 
self.get("detected_ons", None) is None: + self["detected_ons"] = [0] * len(self["cell_positions"]) return self @model_validator(mode="before") @@ -124,7 +124,7 @@ def _build_movie_template(self) -> xr.DataArray: ) def _generate_footprint( - self, radius: int, position: Position, id_: str, confidence: float + self, radius: int, position: Position, id_: str, detected_on: int ) -> xr.DataArray: footprint = xr.DataArray( np.zeros((self.frame_dims.height, self.frame_dims.width)), @@ -141,7 +141,7 @@ def _generate_footprint( return footprint.expand_dims(AXIS.component_dim).assign_coords( { AXIS.id_coord: (AXIS.component_dim, [id_]), - AXIS.detect_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: (AXIS.component_dim, [detected_on]), **{ax: footprint[ax] for ax in AXIS.spatial_dims}, } ) @@ -149,13 +149,13 @@ def _generate_footprint( def _build_footprints(self) -> xr.DataArray: footprints = [] for radius, position, id_, confid in zip( - self.cell_radii, self.cell_positions, self.cell_ids, self.confidences + self.cell_radii, self.cell_positions, self.cell_ids, self.detected_ons ): footprints.append(self._generate_footprint(radius, position, id_, confid)) return xr.concat(footprints, dim=AXIS.component_dim) - def _format_trace(self, trace: np.ndarray, id_: str, confidence: float) -> xr.DataArray: + def _format_trace(self, trace: np.ndarray, id_: str, detected_on: int) -> xr.DataArray: return ( xr.DataArray( trace, @@ -165,7 +165,7 @@ def _format_trace(self, trace: np.ndarray, id_: str, confidence: float) -> xr.Da .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [id_]), - AXIS.detect_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: (AXIS.component_dim, [detected_on]), AXIS.frames_dim: range(trace.size), } ) @@ -173,7 +173,7 @@ def _format_trace(self, trace: np.ndarray, id_: str, confidence: float) -> xr.Da def _build_traces(self) -> xr.DataArray: traces = [] - for trace, id_, confid in zip(self.cell_traces, self.cell_ids, self.confidences): + 
for trace, id_, confid in zip(self.cell_traces, self.cell_ids, self.detected_ons): traces.append(self._format_trace(trace, id_, confid)) return xr.concat(traces, dim=AXIS.component_dim).assign_coords( @@ -198,12 +198,12 @@ def make_movie(self) -> Movie: return Movie.from_array(movie) def add_cell( - self, position: Position, radius: int, trace: np.ndarray, id_: str, confidence: float = 0.0 + self, position: Position, radius: int, trace: np.ndarray, id_: str, detected_on: int = 0 ) -> None: - new_footprint = self._generate_footprint(radius, position, id_, confidence) + new_footprint = self._generate_footprint(radius, position, id_, detected_on) self._footprints = xr.concat([self._footprints, new_footprint], dim=AXIS.component_dim) - new_trace = self._format_trace(trace, id_, confidence) + new_trace = self._format_trace(trace, id_, detected_on) self._traces = xr.concat([self._traces, new_trace], dim=AXIS.component_dim) def drop_cell(self, id_: str | Iterable[str]) -> None: diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 0bfc51ca..f4a4c7c3 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -9,7 +9,7 @@ yaml_config, ) from .meta import monkeypatch_session -from .toys import connected_cells, separate_cells +from .toys import connected_cells, separate_cells, single_cell __all__ = [ "monkeypatch_session", @@ -21,6 +21,7 @@ "tmp_config_source", "tmp_cwd", "yaml_config", + "single_cell", "separate_cells", "connected_cells", ] diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 88587789..e8c6c110 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -157,8 +157,8 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): # we're forcing a double-detection in this test new_fps, new_trs = cataloger.process(fps, trs, Footprints(), Traces()) - result = (new_fps.array @ new_trs.array).where(new_fps.array.max(dim=AXIS.component_dim), 0) - expected = 
movie.where(new_fps.array.max(dim=AXIS.component_dim), 0) + result = new_fps.array @ new_trs.array + expected = movie * (new_fps.array.max(dim=AXIS.component_dim) > 1e-4) assert new_fps.array is not None # 1. the footprints do not overlap diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index 4e6122a9..b65105f9 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -26,7 +26,7 @@ def test_init(init, toy, request) -> None: def frame_update() -> Node: return Node.from_specification( spec=NodeSpecification( - id="frame_test", type="cala.nodes.traces.FrameUpdate", params={"tolerance": 1e-3} + id="frame_test", type="cala.nodes.traces.FrameUpdate", params={"tol": 1e-3} ) ) From 0ebbc07bd5311e45293390d65a0d4a5615bba6d7 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 16:02:21 -0700 Subject: [PATCH 14/43] tests: traces are floats!! --- src/cala/testing/nodes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 24905dfa..2e767b6d 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -61,8 +61,8 @@ def complete_model(self) -> Self: ) self._traces = [ - np.array(range(0, self.n_frames)), - np.array([0, *range(self.n_frames - 1, 0, -1)]), + np.array(range(0, self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), ] if self.positions is None: self.positions = [Position(width=206, height=206), Position(width=306, height=306)] @@ -82,8 +82,8 @@ def complete_model(self) -> Self: else FrameDims(**self.frame_dims) ) self._traces = [ - np.array(range(0, self.n_frames)), - np.array([0, *range(self.n_frames - 1, 0, -1)]), + np.array(range(0, self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), ] if self.positions is None: From 910f182055a5d58d909834b2eb80760de56e471d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 16:03:06 
-0700 Subject: [PATCH 15/43] BREAKING: footprint frame ingestion before residual --- src/cala/nodes/footprints.py | 4 +++- src/cala/nodes/residual.py | 1 - tests/data/pipelines/odl.yaml | 26 ++++++++++++-------------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index f04f9481..0d3ff079 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -102,7 +102,9 @@ def _expand_boundary(self, kernel: np.ndarray, mask: xr.DataArray) -> xr.DataArr ) -def ingest_component(footprints: Footprints, new_footprints: Footprints) -> Footprints: +def ingest_component( + footprints: Footprints, new_footprints: Footprints +) -> A[Footprints, Name("footprints")]: if new_footprints.array is None: return footprints diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 48f97514..a6023221 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -12,7 +12,6 @@ def build( frames: Movie, footprints: Footprints, traces: Traces, - trigger: bool, clip_threshold: float | None = None, ) -> A[Residual, Name("movie")]: """ diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index abc21fc7..450b3cee 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -92,15 +92,24 @@ nodes: - component_stats: assets.comp_stats - frame: glow.frame - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.footprints.Footprinter + params: + bep: 1 + tol: 0.0001 + max_iter: 100 + depends: + - footprints: assets.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value residual: type: cala.nodes.residual.build params: clip_threshold: 0.1 depends: - - trigger: trace_frame.latest_trace - frames: assets.buffer - - footprints: assets.footprints + - footprints: footprints_frame.footprints - traces: assets.traces cleanup: type: cala.nodes.cleanup.purge_razed_components @@ -158,22 +167,11 @@ 
nodes: - new_traces: catalog.new_traces # DETECT ENDS - footprints_frame: - type: cala.nodes.footprints.Footprinter - params: - bep: 1 - tol: 0.0001 - max_iter: 100 - depends: - - footprints: assets.footprints - - pixel_stats: pix_component.value - - component_stats: comp_component.value - overlaps_update: type: cala.nodes.overlap.initialize depends: - overlaps: assets.overlaps - - footprints: footprints_frame.footprints + - footprints: footprint_component.footprints return: type: return From a700d1a7092b36e6f4fb334f2471942662267f88 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 19 Aug 2025 16:03:46 -0700 Subject: [PATCH 16/43] tests: improve pipeline test parametrization --- src/cala/nodes/residual.py | 5 +--- tests/data/pipelines/odl.yaml | 1 - tests/test_iter/test_footprints.py | 2 +- tests/test_iter/test_residual.py | 1 - tests/test_pipeline.py | 48 +++++++++++++++++++----------- 5 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index a6023221..81aabf9b 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -9,10 +9,7 @@ def build( - frames: Movie, - footprints: Footprints, - traces: Traces, - clip_threshold: float | None = None, + frames: Movie, footprints: Footprints, traces: Traces, clip_threshold: float | None = None ) -> A[Residual, Name("movie")]: """ The computation follows the equation: diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 450b3cee..a6e63d23 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -53,7 +53,6 @@ nodes: size_est: type: cala.nodes.prep.SizeEst params: - # hardset_radius: 10 log_kwargs: min_sigma: 3 max_sigma: 10 diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index 0fe9e717..14e85359 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -81,7 +81,7 @@ def test_ingest_frame(fpter, toy, request): 
footprints=toy.footprints, pixel_stats=pixstats, component_stats=compstats ) - expected = toy.footprints.copy() + expected = toy.footprints.model_copy() xr.testing.assert_allclose(result.array, expected.array) diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 890bb697..c76fd324 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -20,7 +20,6 @@ def test_init(init, separate_cells) -> None: footprints=separate_cells.footprints, traces=separate_cells.traces, frames=separate_cells.make_movie(), - trigger=True, ) assert np.all(result.array == 0) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 34b6a037..c27cc662 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import xarray as xr from noob import Cube, SynchronousRunner, Tube @@ -10,18 +11,21 @@ params=[ "SingleCellSource", "TwoCellsSource", - "TwoOverlappingSource", "SeparateSource", + "TwoOverlappingSource", "GradualOnSource", ] ) -def tube(request): - tube = Tube.from_specification("cala-odl") - source = Node.from_specification( +def source(request): + return Node.from_specification( NodeSpecification(id="source", type=f"cala.testing.{request.param}") ) - tube.nodes["source"] = source + +@pytest.fixture +def tube(source): + tube = Tube.from_specification("cala-odl") + tube.nodes["source"] = source return tube @@ -43,25 +47,32 @@ def test_process(runner) -> None: assert runner.cube.assets["buffer"].obj.array.size > 0 -def test_iter(runner) -> None: - gen = runner.iter(n=runner.tube.nodes["source"].instance.n_frames) +def test_iter(runner, source) -> None: + gen = runner.iter(n=source.instance.n_frames) + src_name = source.spec.type_.split(".")[-1] + toy = source.instance._toy.model_copy() - movie = [] - for _, exp in enumerate(gen): - movie.append(exp[0].array) + preprocessed_frames = [] + for fr in gen: + preprocessed_frames.append(fr[0].array) fps = 
runner.cube.assets["footprints"].obj trs = runner.cube.assets["traces"].obj - expected = xr.concat(movie, dim=AXIS.frames_dim) - result = (fps.array @ trs.array).transpose(*expected.dims) + if src_name in ["TwoOverlappingSource", "GradualOnSource"]: + # Correct component count + assert toy.traces.array.sizes[AXIS.component_dim] == trs.array.sizes[AXIS.component_dim] - src_node = runner.tube.nodes["source"].spec.type_.split(".")[-1] + # Traces are reasonably similar + tr_corr = xr.corr( + toy.traces.array, trs.array.rename(AXIS.component_rename), dim=AXIS.frame_coord + ) + for corr in tr_corr: + assert np.isclose(corr.max(), 1, atol=1e-2) - if src_node == "TwoOverlappingSource": - diff = expected - result - for d_fr, e_fr in zip(diff, expected): - assert d_fr.max() <= e_fr.quantile(0.98) * 2e-2 else: + expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) + result = (fps.array @ trs.array).transpose(*expected.dims) + xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) n_discoverable = { @@ -69,8 +80,9 @@ def test_iter(runner) -> None: "TwoCellsSource": 2, "TwoOverlappingSource": 2, "SeparateSource": 2, + "GradualOnSource": 5, } - assert fps.array.sizes[AXIS.component_dim] == n_discoverable[src_node] + assert fps.array.sizes[AXIS.component_dim] == n_discoverable[src_name] def test_run(runner) -> None: From 8603dc11a00b7f2a92f30e27b845167a4dd34748 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 20 Aug 2025 11:23:43 -0700 Subject: [PATCH 17/43] tests: add split off test --- src/cala/testing/__init__.py | 2 ++ src/cala/testing/nodes.py | 21 ++++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/cala/testing/__init__.py b/src/cala/testing/__init__.py index b32c0a73..d3bdc480 100644 --- a/src/cala/testing/__init__.py +++ b/src/cala/testing/__init__.py @@ -5,6 +5,7 @@ SingleCellSource, TwoCellsSource, TwoOverlappingSource, + SplitOffSource, ) __all__ = [ @@ -14,4 +15,5 @@ "SeparateSource", "ConnectedSource", 
"GradualOnSource", + "SplitOffSource", ] diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 2e767b6d..128838a0 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -189,4 +189,23 @@ def complete_model(self) -> Self: class SplitOffSource(MovieSource): @model_validator(mode="after") - def complete_model(self) -> Self: ... + def complete_model(self) -> Self: + self.cell_radii = 8 + self.frame_dims = ( + FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=20, height=20), + Position(width=30, height=30), + ] + self._traces = [ + np.array( + [0, *range(1, int(self.n_frames / 2)), *range(int(self.n_frames / 2), 0, -1)], + dtype=float, + ), + np.array(range(self.n_frames), dtype=float), + ] + self._toy = self._build_toy() + return self From 0469c80d6438919fa858fc9123cd9050cc8afbe4 Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 20 Aug 2025 11:47:39 -0700 Subject: [PATCH 18/43] feat: asset validation - no extra coordinates for certain assets --- src/cala/assets.py | 7 +++++++ src/cala/models/entity.py | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index 3d1965d8..077b4063 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -86,6 +86,7 @@ class Footprints(Asset): member=Footprint.entity(), group_by=Dims.component, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -134,6 +135,7 @@ def from_array( member=Trace.entity(), group_by=Dims.component, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -145,6 +147,7 @@ class Movie(Asset): member=Frame.entity(), group_by=Dims.frame.value, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -180,6 +183,7 @@ class CompStats(Asset): dims=comp_dims, dtype=float, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -191,6 +195,7 @@ class PixStats(Asset): 
dims=(Dims.width.value, Dims.height.value, Dims.component.value), dtype=float, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -202,6 +207,7 @@ class Overlaps(Asset): dims=comp_dims, dtype=bool, checks=[has_no_nan], + allow_extra_coords=False, ) ) @@ -224,5 +230,6 @@ class Residual(Asset): member=Frame.entity(), group_by=Dims.frame.value, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) diff --git a/src/cala/models/entity.py b/src/cala/models/entity.py index 1c589257..d3448a1a 100644 --- a/src/cala/models/entity.py +++ b/src/cala/models/entity.py @@ -18,6 +18,7 @@ class Entity(BaseModel): coords: list[Coord] = Field(default_factory=list) dtype: type checks: list[Callable] = Field(default_factory=list) + allow_extra_coords: bool = True _model: DataArraySchema = PrivateAttr(DataArraySchema()) @@ -47,15 +48,14 @@ def to_schema(self) -> DataArraySchema: checks=self.checks, ) - @staticmethod - def _build_coord_schema(coords: list[Coord]) -> CoordsSchema: + def _build_coord_schema(self, coords: list[Coord]) -> CoordsSchema: spec = dict() for c in coords: dim = DimsSchema((c.dim,)) if c.dim else None spec[c.name] = DataArraySchema(dims=dim, dtype=DTypeSchema(c.dtype), checks=c.checks) - return CoordsSchema(spec) + return CoordsSchema(spec, allow_extra_keys=self.allow_extra_coords) class Group(Entity): From a2f913159247bbbffecbd1336ee80e623255a76f Mon Sep 17 00:00:00 2001 From: Raymond Date: Wed, 20 Aug 2025 11:47:58 -0700 Subject: [PATCH 19/43] debug: remove extra coords in R_min --- src/cala/nodes/residual.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 81aabf9b..45eec275 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -80,7 +80,9 @@ def _clear_overestimates(A: xr.DataArray, R: xr.DataArray, clip_val: float) -> x have been removed, and the remaining negative spots are noise level. 
""" - R_min = R.min(dim=AXIS.frames_dim) + R_min = R.isel({AXIS.frames_dim: -1}).reset_coords( + [AXIS.frame_coord, AXIS.timestamp_coord], drop=True + ) # .min(dim=AXIS.frames_dim) footprints = A.where(R_min > clip_val, 0, drop=False) return footprints From b3d2166b1d6799b696ac326b4cd96f30d5c48a5d Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 21 Aug 2025 08:42:12 -0700 Subject: [PATCH 20/43] feat: reproduction tolerance param for nmf --- src/cala/nodes/detect/slice_nmf.py | 5 ++++- src/cala/testing/__init__.py | 2 +- tests/data/pipelines/odl.yaml | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 303cd450..985d0383 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -19,6 +19,9 @@ class SliceNMF(Node): """Wait until this number of frames to begin detecting.""" detect_thresh: float """Minimum detection threshold for brightness fluctuation.""" + reprod_tol: float + """Mean pixel value error tolerance for reproduced slice from the new component""" + nmf_kwargs: dict[str, Any] = Field(default_factory=dict) error_: float = Field(None) @@ -60,7 +63,7 @@ def process( energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 - if (self.error_ / l1_norm) <= self._model.tol: + if (self.error_ / l1_norm) <= self.reprod_tol: fps.append(Footprint.from_array(a_new)) trs.append(Trace.from_array(c_new)) res = (res - comp_recon).clip(0) diff --git a/src/cala/testing/__init__.py b/src/cala/testing/__init__.py index d3bdc480..9d866edc 100644 --- a/src/cala/testing/__init__.py +++ b/src/cala/testing/__init__.py @@ -3,9 +3,9 @@ GradualOnSource, SeparateSource, SingleCellSource, + SplitOffSource, TwoCellsSource, TwoOverlappingSource, - SplitOffSource, ) __all__ = [ diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index a6e63d23..cf721202 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml 
@@ -128,6 +128,7 @@ nodes: params: min_frames: 10 detect_thresh: 1.0 + reprod_tol: 0.001 depends: - residuals: residual.movie - detect_radius: size_est.radius From 10a7eb55068c01e8efedda4aaf971506d2299115 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 21 Aug 2025 14:52:50 -0700 Subject: [PATCH 21/43] feat: footprint mask tuning post footprint frame ingestion --- src/cala/nodes/cleanup.py | 23 ++++++++++++++++++++++- tests/data/pipelines/odl.yaml | 13 +++++++++++-- tests/test_iter/test_cleanup.py | 16 ++++++++++++++++ 3 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 tests/test_iter/test_cleanup.py diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 63196499..ffc427f6 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -5,10 +5,31 @@ import xarray as xr from noob import Name -from cala.assets import CompStats, Footprints, Overlaps, PixStats, Traces +from cala.assets import CompStats, Footprints, Overlaps, PixStats, Traces, Residual from cala.models import AXIS +def clear_overestimates( + footprints: Footprints, residuals: Residual, nmf_error: float +) -> A[Footprints, Name("footprints")]: + """ + Remove all sections of the footprints that cause negative residuals. + + This occurs by: + 1. find "significant" negative residual spots that is more than a noise level, and thus + cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) + 2. all footprint values at these spots go to zero. 
+ """ + if residuals.array is None: + return footprints + R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords( + [AXIS.frame_coord, AXIS.timestamp_coord], drop=True + ) + tuned_fp = footprints.array.where(R_min > -nmf_error, 0, drop=False) + + return tuned_fp + + def purge_razed_components( footprints: Footprints, traces: Traces, diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index cf721202..0218e762 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -25,6 +25,9 @@ assets: overlaps: type: cala.assets.Overlaps scope: session + residuals: + type: cala.assets.Residual + scope: session # Add refiners in nodes @@ -101,15 +104,21 @@ nodes: - footprints: assets.footprints - pixel_stats: pix_frame.value - component_stats: comp_frame.value + facetune: + type: cala.nodes.cleanup.clear_overestimates + params: + nmf_error: 1.0 + depends: + - footprints: footprints_frame.footprints + - residuals: assets.residuals residual: type: cala.nodes.residual.build - params: - clip_threshold: 0.1 depends: - frames: assets.buffer - footprints: footprints_frame.footprints - traces: assets.traces + - residuals: assets.residuals cleanup: type: cala.nodes.cleanup.purge_razed_components params: diff --git a/tests/test_iter/test_cleanup.py b/tests/test_iter/test_cleanup.py new file mode 100644 index 00000000..6281b449 --- /dev/null +++ b/tests/test_iter/test_cleanup.py @@ -0,0 +1,16 @@ +from cala.assets import Residual +from cala.models import AXIS +from cala.nodes.cleanup import clear_overestimates + + +def test_clear_overestimates(single_cell) -> None: + residual = Residual.from_array(single_cell.make_movie().array) + residual.array.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] *= -1 + + result = clear_overestimates( + footprints=single_cell.footprints, residuals=residual, nmf_error=-1.0 + ) + expected = single_cell.footprints.array.copy() + expected.loc[{AXIS.width_coord: 
slice(single_cell.cell_positions[0].width, None)}] = 0 + + assert result.equals(expected) From 274e9a8081501b1808967d1548b1b4a626b255fb Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 21 Aug 2025 16:31:23 -0700 Subject: [PATCH 22/43] tests: longer splitoff test --- src/cala/testing/nodes.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 128838a0..a168cede 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -190,6 +190,7 @@ def complete_model(self) -> Self: class SplitOffSource(MovieSource): @model_validator(mode="after") def complete_model(self) -> Self: + self.n_frames = 100 self.cell_radii = 8 self.frame_dims = ( FrameDims(width=50, height=50) @@ -205,7 +206,14 @@ def complete_model(self) -> Self: [0, *range(1, int(self.n_frames / 2)), *range(int(self.n_frames / 2), 0, -1)], dtype=float, ), - np.array(range(self.n_frames), dtype=float), + np.array( + [ + *range(int(self.n_frames / 4)), + *range(int(self.n_frames / 4), 0, -1), + *range(int(self.n_frames / 2)), + ], + dtype=float, + ), ] self._toy = self._build_toy() return self From efa89fddaa7c970bb413f5f30e1288c1f10572ba Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 21 Aug 2025 16:32:31 -0700 Subject: [PATCH 23/43] feat: overlap trace adjustment instead footprint purge --- src/cala/nodes/residual.py | 82 +++++++++++++++++++++----------- tests/test_iter/test_residual.py | 46 ++++++++++-------- 2 files changed, 79 insertions(+), 49 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 45eec275..04cc6432 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -2,14 +2,13 @@ import xarray as xr from noob import Name -from skimage.restoration import estimate_sigma - +import numpy as np from cala.assets import Footprints, Movie, Residual, Traces from cala.models import AXIS def build( - frames: Movie, footprints: Footprints, traces: Traces, 
clip_threshold: float | None = None + residuals: Residual, frames: Movie, footprints: Footprints, traces: Traces ) -> A[Residual, Name("movie")]: """ The computation follows the equation: @@ -44,45 +43,70 @@ def build( # Reshape footprints to (pixels x components) A = footprints.array + R_latest = Y.isel({AXIS.frames_dim: -1}) - (A @ C.isel({AXIS.frames_dim: -1})) + if R_latest.min() < 0: + shifted_tr = _align_overestimates(A, C.isel({AXIS.frames_dim: -1}), R_latest) + C.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr + traces.array.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr + # Compute residual R = Y - [A,b][C;f] R = Y - (A @ C) + residuals.array = R.clip(min=0) # clipping is for the first n frames - clip_val = _estimate_clip_val(Y, clip_threshold) - footprints.array = _clear_overestimates(A, R, clip_val) + return residuals - return Residual.from_array(R.clip(min=0)) - -def _estimate_clip_val(Y: xr.DataArray, clip_threshold: float | None = None) -> float: +def _align_overestimates( + A: xr.DataArray, C_latest: xr.DataArray, R_latest: xr.DataArray +) -> xr.DataArray: """ - Estimate the threshold of "what is a significant negative residual value?" (above noise level) - :param Y: - :param clip_threshold: - :return: - """ - if clip_threshold: - return -Y.max().item() * clip_threshold - else: - return -estimate_sigma(Y) + Gotta be able to do at least ONE OF splitoff or gradualon. + Negative residuals just need to go. There isn't much you can do with the value...? -def _clear_overestimates(A: xr.DataArray, R: xr.DataArray, clip_val: float) -> xr.DataArray: - """ - Remove all sections of the footprints that cause negative residuals. - This occurs by: - 1. find "significant" negative residual spots that is more than a noise level, and thus - cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) - 2. all footprint values at these spots go to zero. + Two cases: (A & B Overlapping) + 1. 
GradualOn: Know A. B turns ON + -> trace tries to chase (increases) + -> footprint tries to chase + -> residual becomes negative at A-B -> should just decrease, positive at A^B -> actually... should decrease (just more steeply) + + 2. SplitOff: Know AB. B turns OFF + -> trace tries to chase (decreases) + -> footprint tries to chase + -> residual becomes positive at A-B -> should increase, negative at A^B -> this should just decrease, MORE negative at B-A -> this going to zero makes sense + OR + keep B, remove A-B + + R = Y - A @ C + What about the past frame residuals after? + + for GradualOn, nothing should go to zero. + for SplitOff, a chunk needs to go to zero. + + So... how about we do something like (if it's been on for a long time, we become less likely to purge it?) We subsequently clip R minimum to zero, since all significant negative residual spots have been removed, and the remaining negative spots are noise level. + + !!We're assuming there's no completely occluded component. This might be a problem eventually!! 
""" - R_min = R.isel({AXIS.frames_dim: -1}).reset_coords( - [AXIS.frame_coord, AXIS.timestamp_coord], drop=True - ) # .min(dim=AXIS.frames_dim) - footprints = A.where(R_min > clip_val, 0, drop=False) + unlayered_footprints = _find_unlayered_footprints(A) + # if unlayered_footprints.max(dim=AXIS.spatial_dims).min() == 0: + # raise ValueError("There are at least one completely occluded components.") + + R_rel = R_latest.where((R_latest < 0) * unlayered_footprints.max(dim=AXIS.component_dim)) + dC = ( + (R_rel / A) + .min(dim=AXIS.spatial_dims) + .reset_coords([AXIS.frame_coord, AXIS.timestamp_coord], drop=True) + ) + + return C_latest + xr.apply_ufunc(np.nan_to_num, dC, kwargs={"neginf": 0}).clip(min=0) + - return footprints +def _find_unlayered_footprints(A: xr.DataArray) -> xr.DataArray: + A_layer_mask = (A > 0).sum(dim=AXIS.component_dim) + return A.where(A_layer_mask == 1, 0) diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index c76fd324..45b47c73 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -2,10 +2,9 @@ import pytest from noob.node import Node, NodeSpecification -from cala.assets import Residual +import xarray as xr from cala.models.axis import AXIS -from cala.nodes.residual import _clear_overestimates -from cala.testing.toy import FrameDims, Position, Toy +from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints @pytest.fixture(scope="function") @@ -25,25 +24,32 @@ def test_init(init, separate_cells) -> None: assert np.all(result.array == 0) -@pytest.fixture -def one_cell() -> Toy: - n_frames = 50 +def test_align_overestimates(single_cell) -> None: + """ + grab the last frame of the residual. 
assume part of the footprint masked area is negative + traces needs to proportionally decrease, until the recalculated residual is zero - return Toy( - n_frames=n_frames, - frame_dims=FrameDims(width=50, height=50), - cell_radii=3, - cell_positions=[Position(width=25, height=25)], - cell_traces=[np.array(range(n_frames), dtype=float)], - ) + Eventually, this probably can be absorbed straight into trace frame_ingest as a constraint. + """ + movie = single_cell.make_movie() + last_frame = movie.array.isel({AXIS.frames_dim: -1}) + + last_res = xr.zeros_like(last_frame) + last_res.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] = -1 + last_res = last_res.where(single_cell.footprints.array[0].values, 0) + + last_trace = single_cell.traces.array.isel({AXIS.frames_dim: -1}) + footprints = single_cell.footprints.array -def test_clear_overestimates(one_cell) -> None: - residual = Residual.from_array(one_cell.make_movie().array) - residual.array.loc[{AXIS.width_coord: slice(one_cell.cell_positions[0].width, None)}] *= -1 + adjusted_traces = _align_overestimates(A=footprints, R_latest=last_res, C_latest=last_trace) + + np.testing.assert_array_equal( + (footprints @ adjusted_traces).values, movie.array.isel({AXIS.frames_dim: -2}).values + ) - result = _clear_overestimates(A=one_cell.footprints.array, R=residual.array, clip_val=-1.0) - expected = one_cell.footprints.array.copy() - expected.loc[{AXIS.width_coord: slice(one_cell.cell_positions[0].width, None)}] = 0 - assert result.equals(expected) +def test_find_exposed_footprints(connected_cells) -> None: + footprints = connected_cells.footprints + result = _find_unlayered_footprints(footprints.array) + assert result.sum(dim=AXIS.component_dim).max().item() == footprints.max().item() From 5abcc84e1064ae690b5fae2c3855b36461f73396 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 21 Aug 2025 16:32:42 -0700 Subject: [PATCH 24/43] tests: add splitoff test --- tests/test_pipeline.py | 26 
++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c27cc662..ada5269e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,6 +14,7 @@ "SeparateSource", "TwoOverlappingSource", "GradualOnSource", + "SplitOffSource", ] ) def source(request): @@ -47,7 +48,7 @@ def test_process(runner) -> None: assert runner.cube.assets["buffer"].obj.array.size > 0 -def test_iter(runner, source) -> None: +def test_odl(runner, source) -> None: gen = runner.iter(n=source.instance.n_frames) src_name = source.spec.type_.split(".")[-1] toy = source.instance._toy.model_copy() @@ -58,10 +59,14 @@ def test_iter(runner, source) -> None: fps = runner.cube.assets["footprints"].obj trs = runner.cube.assets["traces"].obj - if src_name in ["TwoOverlappingSource", "GradualOnSource"]: - # Correct component count + # Correct component count + if src_name != "SeparateSource": assert toy.traces.array.sizes[AXIS.component_dim] == trs.array.sizes[AXIS.component_dim] + else: + # 2 is the # of discoverable cells (non-constant) for SeparateSource + assert trs.array.sizes[AXIS.component_dim] == 2 + if src_name in ["TwoOverlappingSource", "GradualOnSource", "SplitOffSource"]: # Traces are reasonably similar tr_corr = xr.corr( toy.traces.array, trs.array.rename(AXIS.component_rename), dim=AXIS.frame_coord @@ -74,18 +79,3 @@ def test_iter(runner, source) -> None: result = (fps.array @ trs.array).transpose(*expected.dims) xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) - - n_discoverable = { - "SingleCellSource": 1, - "TwoCellsSource": 2, - "TwoOverlappingSource": 2, - "SeparateSource": 2, - "GradualOnSource": 5, - } - assert fps.array.sizes[AXIS.component_dim] == n_discoverable[src_name] - - -def test_run(runner) -> None: - result = runner.run(n=5) - - assert result From 145758a1c46c8e335a3c937fb131d5d400b35387 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 21 Aug 2025 16:33:25 
-0700 Subject: [PATCH 25/43] feat: smaller spotlight size. we can merge more easily --- src/cala/nodes/prep/r_estimate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index 7fb0f4f1..03782489 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -37,6 +37,6 @@ def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]: self.centers_ = [blobs[:-1] for blobs in blobs] self.sizes_ += [blob[-1].item() for blob in blobs] - self._est_radius = int(np.round(np.median(self.sizes_)).item()) + self._est_radius = int(np.round(np.median(self.sizes_) / 2).item()) return self._est_radius From efc8d6b3db32b65b90c3f5a7ff56c82b51099d5c Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 22 Aug 2025 12:15:51 -0700 Subject: [PATCH 26/43] debug: stream outputs a generator, not an iterator --- src/cala/nodes/io.py | 17 ++++++++--------- tests/test_io.py | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index 0bc47fda..bd8b21df 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -1,7 +1,6 @@ from abc import abstractmethod -from collections.abc import Iterator from pathlib import Path -from typing import Protocol +from typing import Protocol, Generator import cv2 from numpy.typing import NDArray @@ -12,7 +11,7 @@ class Stream(Protocol): """Protocol defining the interface for video streams.""" @abstractmethod - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: """Iterate over frames.""" ... 
@@ -35,7 +34,7 @@ def __init__(self, video_path: Path | str) -> None: if not self._cap.isOpened(): raise ValueError(f"Failed to open video file: {video_path}") - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: """ Yields: NDArray: Next frame from the video @@ -66,7 +65,7 @@ def __init__(self, files: list[Path]) -> None: raise ValueError("TIFF files must be grayscale") self._sample_shape = frame.shape - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: for file in self._files: frame = io.imread(file) if len(frame.shape) != 2: @@ -90,7 +89,7 @@ def __init__(self, video_paths: list[Path]) -> None: self._video_paths = video_paths self._current_stream: OpenCVStream | None = None - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: """ Iterate over frames from all videos sequentially. @@ -108,7 +107,7 @@ def close(self) -> None: self._current_stream.close() -def stream(files: list[str | Path]) -> Stream: +def stream(files: list[str | Path]) -> Generator[NDArray]: """ Create a video stream from the provided video files. 
@@ -125,8 +124,8 @@ def stream(files: list[str | Path]) -> Stream: video_format = {".mp4", ".avi", ".webm"} if suffix.issubset(video_format): - return VideoStream(files) + return iter(VideoStream(files)) elif suffix.issubset(image_format): - return ImageStream(files) + return iter(ImageStream(files)) else: raise ValueError(f"Unsupported file format: {suffix}") diff --git a/tests/test_io.py b/tests/test_io.py index 5a9be4ce..55cb5f1c 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -50,7 +50,7 @@ def test_tiff_stream(tmp_path): save_tiff(tmp_path / f"{i}.tif", image) media = sorted(glob(str(tmp_path / "*.tif"))) - s = iter(stream(media)) + s = stream(media) for idx, res in enumerate(s): np.testing.assert_array_equal(res, generate_text_image(str(idx))) @@ -65,7 +65,7 @@ def test_video_stream(tmp_path): save_movie(tmp_path / "video.mp4", video) media = sorted(glob(str(tmp_path / "*.mp4"))) - s = iter(stream(media)) + s = stream(media) for idx, res in enumerate(s): np.testing.assert_allclose( From 16bd16912445a43cde956b9d12c243112050fad4 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 22 Aug 2025 12:16:27 -0700 Subject: [PATCH 27/43] feat: add counter for frame idx feat: package_frame outputs a frame instead of dataarray --- src/cala/util.py | 19 ++++++++++++++----- tests/test_prep/test_r_estimate.py | 5 ++--- tests/test_util.py | 5 ++--- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/cala/util.py b/src/cala/util.py index fa6883ca..4e2a7f90 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -1,16 +1,24 @@ -from collections.abc import Sequence +from collections.abc import Sequence, Generator from datetime import datetime +from itertools import count +from typing import Annotated as A from uuid import uuid4 import numpy as np import xarray as xr +from noob import Name +from cala.assets import Frame from cala.models import AXIS -def package_frame( - frame: np.ndarray, index: int, timestamp: datetime | str | None = None -) -> 
xr.DataArray: +def counter(start: int = 0, limit: int = 1e7) -> A[Generator[int], Name("idx")]: + cnt = count(start=start) + while (val := next(cnt)) < limit: + yield val + + +def package_frame(frame: np.ndarray, index: int, timestamp: datetime | str | None = None) -> Frame: """Transform a 2D numpy frame into an xarray DataArray. Args: @@ -37,12 +45,13 @@ def package_frame( name="frame", ) - return frame.assign_coords( + da = frame.assign_coords( { AXIS.width_dim: range(frame.sizes[AXIS.width_dim]), AXIS.height_dim: range(frame.sizes[AXIS.height_dim]), } ) + return Frame.from_array(da.astype(float)) def create_id() -> str: diff --git a/tests/test_prep/test_r_estimate.py b/tests/test_prep/test_r_estimate.py index b254dfd5..bdd2bc64 100644 --- a/tests/test_prep/test_r_estimate.py +++ b/tests/test_prep/test_r_estimate.py @@ -1,4 +1,3 @@ -from cala.assets import Frame from cala.models import AXIS from cala.nodes.prep.r_estimate import SizeEst from cala.testing.toy import Position @@ -18,7 +17,7 @@ def test_size_estim(separate_cells): max_proj = package_frame( separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=1 ) - result = node.get_median_radius(Frame.from_array(max_proj)) + result = node.get_median_radius(max_proj) expected = separate_cells.cell_radii[0] - 1 assert result == expected @@ -27,7 +26,7 @@ def test_size_estim(separate_cells): max_proj = package_frame( separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=3 ) - result = node.get_median_radius(Frame.from_array(max_proj)) + result = node.get_median_radius(max_proj) assert result == expected assert len(node.sizes_) == 3 diff --git a/tests/test_util.py b/tests/test_util.py index 0ab6a9b9..1c119ec6 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,7 +2,6 @@ import numpy as np -from cala.assets import Frame from cala.util import package_frame @@ -13,6 +12,6 @@ def test_package_frame(): timestamp = datetime(2023, 4, 8, 12, 0, 0) # Transform the frame - 
dataarray = package_frame(frame, index, timestamp) + frame = package_frame(frame, index, timestamp) - assert Frame.from_array(dataarray) + assert np.array_equal(frame.array.values, frame) From 7d00b746cc860b8f96292608de5cbd17380226a4 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 22 Aug 2025 16:46:42 -0700 Subject: [PATCH 28/43] feat: add nonlocal method to denoise --- src/cala/nodes/prep/denoise.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index c012f066..03f1e1fb 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -11,18 +11,25 @@ def denoise( - frame: Frame, method: Literal["gaussian", "median", "bilateral"] = "gaussian", **kwargs: Any + frame: Frame, method: Literal["gaussian", "median", "bilateral"], kwargs: dict[str, Any] ) -> A[Frame, Name("frame")]: """Denoise a single frame.""" methods: dict[str, Callable] = { "gaussian": cv2.GaussianBlur, "median": cv2.medianBlur, "bilateral": cv2.bilateralFilter, + "nonlocal": cv2.fastNlMeansDenoising, } _func = methods[method] + frame = frame.array - denoised = _func(frame.values.astype(np.float32), **kwargs).astype(np.float64) + if method == "nonlocal": + arr = frame.values.astype(np.uint8) + else: + arr = frame.values.astype(np.float32) + + denoised = _func(arr, **kwargs).astype(float) return Frame.from_array(xr.DataArray(denoised, dims=frame.dims, coords=frame.coords)) From a11dec3f731bc4b0966a78bc6e103295942019b1 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 22 Aug 2025 16:47:21 -0700 Subject: [PATCH 29/43] feat: add source movie using pipe --- tests/data/pipelines/with_src.yaml | 211 +++++++++++++++++++++++++++++ tests/test_pipeline.py | 17 +++ 2 files changed, 228 insertions(+) create mode 100644 tests/data/pipelines/with_src.yaml diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml new file mode 100644 index 00000000..02eabcce --- /dev/null +++ 
b/tests/data/pipelines/with_src.yaml @@ -0,0 +1,211 @@ +noob_id: cala-io +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev118+g64d81b7 + +assets: + buffer: + type: cala.assets.Movie + scope: session + depends: + - cache.buffer + footprints: + type: cala.assets.Footprints + scope: session + traces: + type: cala.assets.Traces + scope: session + pix_stats: + type: cala.assets.PixStats + scope: session + comp_stats: + type: cala.assets.CompStats + scope: session + overlaps: + type: cala.assets.Overlaps + scope: session + residuals: + type: cala.assets.Residual + scope: session + + +nodes: + source: + type: cala.nodes.io.stream + params: + files: + - data/movies/msCam1.avi + counter: + type: cala.util.counter + frame: + type: cala.util.package_frame + depends: + - frame: source.value + - index: counter.idx + + #PREPROCESS BEGINS + saltpepper: + type: cala.nodes.prep.denoise + params: + method: median + kwargs: + ksize: 3 + depends: + - frame: frame.value + denoise: + type: cala.nodes.prep.denoise + params: + method: nonlocal + kwargs: + h: 4 + depends: + - frame: saltpepper.frame + glow: + type: cala.nodes.prep.GlowRemover + depends: + - frame: denoise.frame + filter: + type: + + motion: + type: cala.nodes.prep.RigidStabilizer + params: + drift_speed: 0.5 + depends: + - frame: glow.frame + size_est: + type: cala.nodes.prep.SizeEst + params: + log_kwargs: + min_sigma: 3 + max_sigma: 10 + num_sigma: 10 + threshold: 0.2 + overlap: 0.5 + depends: + - frame: motion.frame + cache: + type: cala.nodes.buffer.fill_buffer + params: + size: 100 + depends: + - buffer: assets.buffer + - frame: motion.frame + #PREPROCESS ENDS + + # FRAME UPDATE BEGINS + trace_frame: + type: cala.nodes.traces.FrameUpdate + params: + tol: 0.001 + max_iter: 100 + depends: + - traces: assets.traces + - footprints: assets.footprints + - frame: motion.frame + - overlaps: assets.overlaps + pix_frame: + type: cala.nodes.pixel_stats.ingest_frame + depends: + - pixel_stats: assets.pix_stats + - 
frame: motion.frame + - new_traces: trace_frame.latest_trace + comp_frame: + type: cala.nodes.component_stats.ingest_frame + depends: + - component_stats: assets.comp_stats + - frame: motion.frame + - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.footprints.Footprinter + params: + bep: 1 + tol: 0.0001 + max_iter: 100 + depends: + - footprints: assets.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value + facetune: + type: cala.nodes.cleanup.clear_overestimates + params: + nmf_error: 1.0 + depends: + - footprints: footprints_frame.footprints + - residuals: assets.residuals + + residual: + type: cala.nodes.residual.build + depends: + - frames: assets.buffer + - footprints: footprints_frame.footprints + - traces: assets.traces + - residuals: assets.residuals + cleanup: + type: cala.nodes.cleanup.purge_razed_components + params: + min_thicc: 3 + depends: + - footprints: assets.footprints + - traces: assets.traces + - pix_stats: assets.pix_stats + - comp_stats: assets.comp_stats + - overlaps: assets.overlaps + - trigger: residual.movie + # FRAME UPDATE ENDS + + # DETECT BEGINS + nmf: + type: cala.nodes.detect.SliceNMF + params: + min_frames: 30 + detect_thresh: 2.0 + reprod_tol: 0.001 + depends: + - residuals: residual.movie + - detect_radius: size_est.radius + catalog: + type: cala.nodes.detect.Cataloger + params: + merge_threshold: 0.8 + depends: + - new_fps: nmf.new_fps + - new_trs: nmf.new_trs + - existing_fp: assets.footprints + - existing_tr: assets.traces + + trace_component: + type: cala.nodes.traces.ingest_component + depends: + - traces: assets.traces + - new_traces: catalog.new_traces + footprint_component: + type: cala.nodes.footprints.ingest_component + depends: + - footprints: assets.footprints + - new_footprints: catalog.new_footprints + pix_component: + type: cala.nodes.pixel_stats.ingest_component + depends: + - pixel_stats: assets.pix_stats + - frames: assets.buffer + - new_traces: 
catalog.new_traces + - traces: assets.traces + comp_component: + type: cala.nodes.component_stats.ingest_component + depends: + - component_stats: assets.comp_stats + - traces: assets.traces + - new_traces: catalog.new_traces + + overlaps_update: + type: cala.nodes.overlap.initialize + depends: + - overlaps: assets.overlaps + - footprints: footprint_component.footprints + # DETECT ENDS + + return: + type: return + depends: + - raw: frame.value + - prep: motion.frame \ No newline at end of file diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ada5269e..f2b31631 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -79,3 +79,20 @@ def test_odl(runner, source) -> None: result = (fps.array @ trs.array).transpose(*expected.dims) xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) + + +def test_with_avi() -> None: + cube = Cube.from_specification("cala-io") + tube = Tube.from_specification("cala-io") + + runner = SynchronousRunner(tube=tube, cube=cube) + + import matplotlib.pyplot as plt + + gen = runner.iter() + for i, fr in enumerate(gen): + raw = fr["raw"].array + prep = fr["prep"].array / fr["prep"].array.max() * raw.max() + plt.imsave(f"out{i}.png", xr.concat([raw, prep], dim=AXIS.height_dim)) + + assert runner.cube.assets["buffer"].obj.array.size > 0 From becbedae9aedfad1bbf709c8a850e38e287fb282 Mon Sep 17 00:00:00 2001 From: Raymond Date: Fri, 22 Aug 2025 16:48:00 -0700 Subject: [PATCH 30/43] debug: fix parameters in denoise with odl.yaml --- tests/data/pipelines/odl.yaml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index 0218e762..569a2d51 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -29,18 +29,16 @@ assets: type: cala.assets.Residual scope: session -# Add refiners in nodes - nodes: source: type: cala.testing.SingleCellSource denoise: type: cala.nodes.prep.denoise params: - ksize: - - 3 - 
- 3 - sigmaX: 1.5 + method: gaussian + kwargs: + ksize: [3, 3] + sigmaX: 1.5 depends: - frame: source.frame glow: From ae0db27702bb4b46f015a17ca29ea87a1ac7d113 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 25 Aug 2025 20:02:02 -0700 Subject: [PATCH 31/43] feat: blob size detection ignores noise --- src/cala/nodes/prep/r_estimate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index 03782489..f77b313e 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -15,6 +15,8 @@ class SizeEst(BaseModel): """if this is set, no learning occurs.""" n_frames: int | None = None """how many first n frames to learn from. if none, keep learning forever""" + noise_threshold: float = 0.0 + log_kwargs: dict[str, Any] = Field(default_factory=dict) sizes_: list[float] = Field(default_factory=list) @@ -31,7 +33,9 @@ def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]: if self.n_frames and self.n_frames < frame.array[AXIS.frame_coord]: return self._est_radius - blobs = blob_log(frame.array, **self.log_kwargs) + blobs = blob_log( + frame.array.where(frame.array > self.noise_threshold, 0, drop=False), **self.log_kwargs + ) if blobs.size == 0: return 0 From bb57e18b47422798b530ac939546b1979b62c02b Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 25 Aug 2025 20:03:55 -0700 Subject: [PATCH 32/43] feat: use l0 norm to leave error in residuals. this is to leave room for overlap, while not increasing the error by too much.
not sure if this is a good decision and i should just hardcode some error value --- src/cala/nodes/detect/slice_nmf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 985d0383..ccadd1da 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -58,7 +58,6 @@ def process( ) l1_norm = slice_.sum().item() - # l0_norm = np.prod(slice_.shape) # this fails when the residuals are tiny comp_recon = a_new @ c_new energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 @@ -68,7 +67,8 @@ def process( trs.append(Trace.from_array(c_new)) res = (res - comp_recon).clip(0) else: - res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 1e-7 + l0_norm = np.prod(slice_.shape) + res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = self.error_ / l0_norm return fps, trs From ed7207539386ff8d2459ec2574847e7f22faf182 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 25 Aug 2025 20:04:29 -0700 Subject: [PATCH 33/43] tests: generate_text_image comes out to util module tests: hline removal test --- src/cala/testing/util.py | 19 +++++++++++++++++++ tests/test_io.py | 15 +-------------- tests/test_prep/test_hlines.py | 23 +++++++++++++++++++++++ 3 files changed, 43 insertions(+), 14 deletions(-) create mode 100644 tests/test_prep/test_hlines.py diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index ac37b24e..8801343c 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -1,3 +1,5 @@ +import cv2 +import numpy as np import xarray as xr @@ -11,3 +13,20 @@ def assert_scalar_multiple_arrays(a: xr.DataArray, b: xr.DataArray, /, rtol: flo aabb = a.dot(a) * b.dot(b) assert abab > aabb * (1 - rtol) + + +def generate_text_image( + text: str, + frame_dims: tuple[int, int] = (256, 256), + org: tuple[int, int] = None, + color: tuple[int, int, int] = (255, 255, 255), + thickness: int = 2, + font_scale: int = 1, 
+) -> np.ndarray: + image = np.zeros(frame_dims, np.uint8) + font = cv2.FONT_HERSHEY_SIMPLEX + + if org is None: + org = (frame_dims[0] // 2, frame_dims[1] // 2) + + return cv2.putText(image, text, org, font, font_scale, color, thickness, cv2.LINE_AA) diff --git a/tests/test_io.py b/tests/test_io.py index 55cb5f1c..8c042b3f 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -7,20 +7,7 @@ from skimage import io from cala.nodes.io import stream - - -def generate_text_image( - text: str, - frame_dims: tuple[int, int] = (256, 256), - org: tuple[int, int] = (50, 50), - color: tuple[int, int, int] = (255, 255, 255), - thickness: int = 2, - font_scale: int = 1, -) -> np.ndarray: - image = np.zeros(frame_dims, np.uint8) - font = cv2.FONT_HERSHEY_SIMPLEX - - return cv2.putText(image, text, org, font, font_scale, color, thickness, cv2.LINE_AA) +from cala.testing.util import generate_text_image def save_tiff(filename: Path, frame: np.ndarray) -> None: diff --git a/tests/test_prep/test_hlines.py b/tests/test_prep/test_hlines.py new file mode 100644 index 00000000..eacb7aa8 --- /dev/null +++ b/tests/test_prep/test_hlines.py @@ -0,0 +1,23 @@ +import numpy as np +from skimage.metrics import structural_similarity + +from cala.nodes.prep.hlines import remove +from cala.testing.util import generate_text_image +from cala.util import package_frame + + +def test_remove_lines(): + img = generate_text_image( + "8", frame_dims=(256, 256), org=(25, 230), thickness=20, font_scale=10 + ) + + noise_amp = 40 + noise = np.tile(np.random.randint(0, noise_amp, img.shape[0]), (img.shape[1], 1)).T + + noisy_img = img // 1.5 + noise + + frame = package_frame(noisy_img, 0) + + result = remove(frame) + + assert structural_similarity(img.astype(int), result.array.values.astype(int)) == 1 From 529640b7833cbc785197a7896bc3333707d548dd Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 25 Aug 2025 20:04:38 -0700 Subject: [PATCH 34/43] feat: hline removal in prep --- src/cala/nodes/prep/hlines.py | 56 
+++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 src/cala/nodes/prep/hlines.py diff --git a/src/cala/nodes/prep/hlines.py b/src/cala/nodes/prep/hlines.py new file mode 100644 index 00000000..aeb4145d --- /dev/null +++ b/src/cala/nodes/prep/hlines.py @@ -0,0 +1,56 @@ +from typing import Annotated as A +import numpy as np + +from noob import Name +from scipy.ndimage import convolve1d +from scipy.signal import firwin, welch + +from cala.assets import Frame + + +def remove( + frame: Frame, distortion_freq: float | None = None, num_taps: int = 65, eps: float = 0.025 +) -> A[Frame, Name("frame")]: + arr = frame.array + + if np.all(frame.array == 0): + return frame + + denoised = _remove_lines( + arr.values, distortion_freq=distortion_freq, num_taps=num_taps, eps=eps + ) + + dmin = denoised.min() + if dmin < 0: + denoised -= dmin + + arr.values = denoised + + return Frame.from_array(arr) + + +def _remove_lines(image, distortion_freq: float = None, num_taps: int = 65, eps: float = 0.025): + """ + Removes horizontal line artifacts from scanned image. + Args: + image: 2D or 3D array. + distortion_freq: Float, distortion frequency in cycles/pixel, or + `None` to estimate from spectrum. + num_taps: Integer, number of filter taps to use in each dimension. + eps: Small positive param to adjust filters cutoffs (cycles/pixel). + Returns: + Denoised image. 
+ """ + if distortion_freq is None: + distortion_freq = _estimate_distortion_freq(image) + + hpf = firwin(num_taps, distortion_freq - eps, pass_zero="highpass", fs=1) + lpf = firwin(num_taps, eps, pass_zero="lowpass", fs=1) + return image - convolve1d(convolve1d(image, hpf, axis=0), lpf, axis=1) + + +def _estimate_distortion_freq(image, min_frequency=1 / 25): + """Estimates distortion frequency as spectral peak in vertical dim.""" + f, pxx = welch(image.sum(axis=1)) + pxx[f < min_frequency] = 0.0 + return f[pxx.argmax()] From a973d284b0a4832b34bc70baa19bf6ad2558d7d5 Mon Sep 17 00:00:00 2001 From: Raymond Date: Mon, 25 Aug 2025 20:06:36 -0700 Subject: [PATCH 35/43] feat: add hline removal, gaussian smoothing, increase reprod-tolerance to account for general noise --- tests/data/pipelines/with_src.yaml | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index 02eabcce..b2b8404e 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -57,24 +57,38 @@ nodes: method: nonlocal kwargs: h: 4 + templateWindowSize: 7 + searchWindowSize: 21 depends: - frame: saltpepper.frame + lines: + type: cala.nodes.prep.hlines.remove + depends: + - frame: glow.frame glow: type: cala.nodes.prep.GlowRemover depends: - frame: denoise.frame - filter: - type: - + smooth: + type: cala.nodes.prep.denoise + params: + method: gaussian + kwargs: + ksize: [ 7, 7 ] + sigmaX: 1.5 + depends: + - frame: lines.frame motion: type: cala.nodes.prep.RigidStabilizer params: drift_speed: 0.5 depends: - - frame: glow.frame + - frame: smooth.frame size_est: type: cala.nodes.prep.SizeEst params: + noise_threshold: 2.0 + n_frames: 30 log_kwargs: min_sigma: 3 max_sigma: 10 @@ -159,7 +173,7 @@ nodes: params: min_frames: 30 detect_thresh: 2.0 - reprod_tol: 0.001 + reprod_tol: 0.005 depends: - residuals: residual.movie - detect_radius: size_est.radius From 
599f561176f8f92cbdc28c5a001a7a5dd1365462 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 16:16:09 -0700 Subject: [PATCH 36/43] debug: clip trace after addition --- src/cala/nodes/residual.py | 2 +- tests/test_iter/test_detect.py | 8 ++++---- tests/test_iter/test_residual.py | 12 ++++++++---- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 04cc6432..d5b520e3 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -104,7 +104,7 @@ def _align_overestimates( .reset_coords([AXIS.frame_coord, AXIS.timestamp_coord], drop=True) ) - return C_latest + xr.apply_ufunc(np.nan_to_num, dC, kwargs={"neginf": 0}).clip(min=0) + return (C_latest + xr.apply_ufunc(np.nan_to_num, dC, kwargs={"neginf": 0})).clip(min=0) def _find_unlayered_footprints(A: xr.DataArray) -> xr.DataArray: diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index e8c6c110..db6ed22d 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -14,7 +14,7 @@ def slice_nmf(): spec=NodeSpecification( id="test_slice_nmf", type="cala.nodes.detect.SliceNMF", - params={"min_frames": 10, "detect_thresh": 1}, + params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.00001}, ) ) @@ -47,7 +47,7 @@ def test_chunks(self, single_cell): spec=NodeSpecification( id="test_slice_nmf", type="cala.nodes.detect.SliceNMF", - params={"min_frames": 10, "detect_thresh": 1}, + params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) fpts, trcs = nmf.process( @@ -158,10 +158,10 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): new_fps, new_trs = cataloger.process(fps, trs, Footprints(), Traces()) result = new_fps.array @ new_trs.array - expected = movie * (new_fps.array.max(dim=AXIS.component_dim) > 1e-4) + expected = movie * (new_fps.array.max(dim=AXIS.component_dim) > 1e-3) assert new_fps.array is not None # 1. 
the footprints do not overlap assert np.all(np.triu(new_fps.array @ new_fps.array.rename(AXIS.component_rename), 1) == 0) # 2. the trace and footprint values are accurate (where they do exist) - xr.testing.assert_allclose(result, expected.transpose(*result.dims), atol=1e-5) + xr.testing.assert_allclose(result, expected.transpose(*result.dims), atol=1e-3) diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 45b47c73..47e6c21d 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -3,6 +3,8 @@ from noob.node import Node, NodeSpecification import xarray as xr + +from cala.assets import Residual from cala.models.axis import AXIS from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints @@ -16,6 +18,7 @@ def init() -> Node: def test_init(init, separate_cells) -> None: result = init.process( + residuals=Residual(), footprints=separate_cells.footprints, traces=separate_cells.traces, frames=separate_cells.make_movie(), @@ -44,12 +47,13 @@ def test_align_overestimates(single_cell) -> None: adjusted_traces = _align_overestimates(A=footprints, R_latest=last_res, C_latest=last_trace) - np.testing.assert_array_equal( - (footprints @ adjusted_traces).values, movie.array.isel({AXIS.frames_dim: -2}).values - ) + result = (footprints @ adjusted_traces).values + expected = movie.array.isel({AXIS.frames_dim: -2}).values + + np.testing.assert_array_equal(result, expected) def test_find_exposed_footprints(connected_cells) -> None: footprints = connected_cells.footprints result = _find_unlayered_footprints(footprints.array) - assert result.sum(dim=AXIS.component_dim).max().item() == footprints.max().item() + assert result.sum(dim=AXIS.component_dim).max().item() == footprints.array.max().item() From ba02168b55a8c8a6e2c869db310deefdd744eece Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 16:27:18 -0700 Subject: [PATCH 37/43] docs: simplify pr template --- 
.github/pull_request_template.md | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index e27c1c51..930ef42e 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -21,10 +21,9 @@ Describe the tests that you ran to verify your changes. Provide instructions so that others can reproduce. --> -- [ ] Ran unit tests -- [ ] Ran integration tests -- [ ] Performed manual testing -- [ ] Updated existing tests +- [ ] Unit tests +- [ ] Integration tests +- [ ] Existing tests update ## 🛠️ Dependencies @@ -36,9 +35,5 @@ List any new dependencies added or existing ones updated. ## ✅ Checklist -- [ ] My code follows the project's style guidelines - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have made corresponding changes to the documentation -- [ ] My changes generate no new warnings -- [ ] I have added tests that prove my fix is effective or that my feature works -- [ ] New and existing unit tests pass locally with my changes From ef03b466beb93c437c020dfd42d5c832be1fec4e Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 17:16:37 -0700 Subject: [PATCH 38/43] docs: reduce r_estimate --- src/cala/nodes/prep/r_estimate.py | 2 +- tests/test_prep/test_r_estimate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index f77b313e..7124c09c 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -41,6 +41,6 @@ def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]: self.centers_ = [blobs[:-1] for blobs in blobs] self.sizes_ += [blob[-1].item() for blob in blobs] - self._est_radius = int(np.round(np.median(self.sizes_) / 2).item()) + self._est_radius = (np.median(self.sizes_) // 2 + 1).astype(int) return self._est_radius diff --git a/tests/test_prep/test_r_estimate.py 
b/tests/test_prep/test_r_estimate.py index bdd2bc64..9461bff1 100644 --- a/tests/test_prep/test_r_estimate.py +++ b/tests/test_prep/test_r_estimate.py @@ -28,7 +28,7 @@ def test_size_estim(separate_cells): ) result = node.get_median_radius(max_proj) - assert result == expected + assert result == expected // 2 + 1 assert len(node.sizes_) == 3 for center in node.centers_: From 023f3bbe9c93162aabc89d379804803b6c1bfa4d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 17:17:17 -0700 Subject: [PATCH 39/43] tests: add cwd_to_pytest_base --- tests/data/pipelines/with_src.yaml | 2 +- tests/fixtures/__init__.py | 2 ++ tests/fixtures/config.py | 8 ++++++++ tests/test_pipeline.py | 2 +- 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml index b2b8404e..3cdb055a 100644 --- a/tests/data/pipelines/with_src.yaml +++ b/tests/data/pipelines/with_src.yaml @@ -33,7 +33,7 @@ nodes: type: cala.nodes.io.stream params: files: - - data/movies/msCam1.avi + - tests/data/movies/msCam1.avi counter: type: cala.util.counter frame: diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index f4a4c7c3..e54798cc 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -7,6 +7,7 @@ tmp_config_source, tmp_cwd, yaml_config, + cwd_to_pytest_base, ) from .meta import monkeypatch_session from .toys import connected_cells, separate_cells, single_cell @@ -21,6 +22,7 @@ "tmp_config_source", "tmp_cwd", "yaml_config", + "cwd_to_pytest_base", "single_cell", "separate_cells", "connected_cells", diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index 792a6be9..d251a3c6 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -1,3 +1,4 @@ +import os from collections.abc import Callable, MutableMapping from pathlib import Path from typing import Any @@ -156,3 +157,10 @@ def _flatten(d: MutableMapping, parent_key: str = "", separator: str = "__") -> else: 
items.append((new_key, value)) return dict(items) + + +@pytest.fixture +def cwd_to_pytest_base(request): + os.chdir(request.config.rootdir) + yield + os.chdir(request.config.invocation_params.dir) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f2b31631..e2034dcf 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -81,7 +81,7 @@ def test_odl(runner, source) -> None: xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) -def test_with_avi() -> None: +def test_with_avi(cwd_to_pytest_base) -> None: cube = Cube.from_specification("cala-io") tube = Tube.from_specification("cala-io") From 0da80858455cda51f428c2b1b92b6b126db7aa19 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 17:17:30 -0700 Subject: [PATCH 40/43] tests: update denoise test --- tests/test_prep/test_denoise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_prep/test_denoise.py b/tests/test_prep/test_denoise.py index 116e9718..87a7c503 100644 --- a/tests/test_prep/test_denoise.py +++ b/tests/test_prep/test_denoise.py @@ -46,7 +46,7 @@ def test_denoise( results = [] for frame in iter(gen): - results.append(denoise(frame=frame, method=method, **params)) + results.append(denoise(frame=frame, method=method, kwargs=params)) for exp, res in zip(expected, results): np.testing.assert_allclose(exp.values, res.array.values) From a1937a69bf7bfbacf24f44833c9238791ef57cf5 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 17:17:49 -0700 Subject: [PATCH 41/43] tests: update package_frame --- tests/test_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 1c119ec6..7791e3c8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,6 +12,6 @@ def test_package_frame(): timestamp = datetime(2023, 4, 8, 12, 0, 0) # Transform the frame - frame = package_frame(frame, index, timestamp) + result = package_frame(frame, index, timestamp) - assert 
np.array_equal(frame.array.values, frame) + assert np.array_equal(result.array.values, frame) From 850f3df32c61c7230c80198db0fd90436fbf354d Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 18:35:01 -0700 Subject: [PATCH 42/43] tests: xfail for splitoff while deprecation is in the works --- tests/test_pipeline.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e2034dcf..cdf6ceca 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -48,6 +48,7 @@ def test_process(runner) -> None: assert runner.cube.assets["buffer"].obj.array.size > 0 +@pytest.mark.xfail(raises=NotImplementedError) def test_odl(runner, source) -> None: gen = runner.iter(n=source.instance.n_frames) src_name = source.spec.type_.split(".")[-1] @@ -60,13 +61,16 @@ def test_odl(runner, source) -> None: trs = runner.cube.assets["traces"].obj # Correct component count - if src_name != "SeparateSource": + if src_name not in ["SeparateSource", "SplitOffSource"]: assert toy.traces.array.sizes[AXIS.component_dim] == trs.array.sizes[AXIS.component_dim] - else: + elif src_name == "SeparateSource": # 2 is the # of discoverable cells (non-constant) for SeparateSource assert trs.array.sizes[AXIS.component_dim] == 2 + elif src_name == "SplitOffSource": + # 3 because one should be deprecated + assert trs.array.sizes[AXIS.component_dim] == 3 - if src_name in ["TwoOverlappingSource", "GradualOnSource", "SplitOffSource"]: + if src_name in ["TwoOverlappingSource", "GradualOnSource"]: # Traces are reasonably similar tr_corr = xr.corr( toy.traces.array, trs.array.rename(AXIS.component_rename), dim=AXIS.frame_coord @@ -74,12 +78,17 @@ def test_odl(runner, source) -> None: for corr in tr_corr: assert np.isclose(corr.max(), 1, atol=1e-2) - else: + elif src_name in ["SingleCellSource", "TwoCellsSource", "SeparateSource"]: expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) result = (fps.array @ 
trs.array).transpose(*expected.dims) xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) + elif src_name == "SplitOffSource": + expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) + result = (fps.array @ trs.array).transpose(*expected.dims) + raise NotImplementedError("Deprecation not implemented") + def test_with_avi(cwd_to_pytest_base) -> None: cube = Cube.from_specification("cala-io") From dad862bf5e66c20a649d494377461d4696822a10 Mon Sep 17 00:00:00 2001 From: Raymond Date: Tue, 26 Aug 2025 18:39:37 -0700 Subject: [PATCH 43/43] format: ruff --- src/cala/nodes/cleanup.py | 2 +- src/cala/nodes/io.py | 3 ++- src/cala/nodes/prep/denoise.py | 5 +---- src/cala/nodes/prep/hlines.py | 8 +++++--- src/cala/nodes/residual.py | 15 +++++++++++---- src/cala/util.py | 2 +- tests/fixtures/__init__.py | 2 +- tests/fixtures/config.py | 4 ++-- tests/test_iter/test_residual.py | 3 +-- tests/test_pipeline.py | 17 ----------------- 10 files changed, 25 insertions(+), 36 deletions(-) diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index ffc427f6..b1d4c453 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -5,7 +5,7 @@ import xarray as xr from noob import Name -from cala.assets import CompStats, Footprints, Overlaps, PixStats, Traces, Residual +from cala.assets import CompStats, Footprints, Overlaps, PixStats, Residual, Traces from cala.models import AXIS diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index bd8b21df..f7f5cc98 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -1,6 +1,7 @@ from abc import abstractmethod +from collections.abc import Generator from pathlib import Path -from typing import Protocol, Generator +from typing import Protocol import cv2 from numpy.typing import NDArray diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index 03f1e1fb..2efe2f33 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -25,10 +25,7 @@ def 
denoise( frame = frame.array - if method == "nonlocal": - arr = frame.values.astype(np.uint8) - else: - arr = frame.values.astype(np.float32) + arr = frame.values.astype(np.uint8) if method == "nonlocal" else frame.values.astype(np.float32) denoised = _func(arr, **kwargs).astype(float) diff --git a/src/cala/nodes/prep/hlines.py b/src/cala/nodes/prep/hlines.py index aeb4145d..2d74dfcc 100644 --- a/src/cala/nodes/prep/hlines.py +++ b/src/cala/nodes/prep/hlines.py @@ -1,6 +1,6 @@ from typing import Annotated as A -import numpy as np +import numpy as np from noob import Name from scipy.ndimage import convolve1d from scipy.signal import firwin, welch @@ -29,7 +29,9 @@ def remove( return Frame.from_array(arr) -def _remove_lines(image, distortion_freq: float = None, num_taps: int = 65, eps: float = 0.025): +def _remove_lines( + image: np.ndarray, distortion_freq: float = None, num_taps: int = 65, eps: float = 0.025 +) -> np.ndarray: """ Removes horizontal line artifacts from scanned image. Args: @@ -49,7 +51,7 @@ def _remove_lines(image, distortion_freq: float = None, num_taps: int = 65, eps: return image - convolve1d(convolve1d(image, hpf, axis=0), lpf, axis=1) -def _estimate_distortion_freq(image, min_frequency=1 / 25): +def _estimate_distortion_freq(image: np.ndarray, min_frequency: float = 1 / 25) -> float: """Estimates distortion frequency as spectral peak in vertical dim.""" f, pxx = welch(image.sum(axis=1)) pxx[f < min_frequency] = 0.0 diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index d5b520e3..f70d6ea7 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -1,8 +1,9 @@ from typing import Annotated as A +import numpy as np import xarray as xr from noob import Name -import numpy as np + from cala.assets import Footprints, Movie, Residual, Traces from cala.models import AXIS @@ -69,12 +70,17 @@ def _align_overestimates( 1. GradualOn: Know A. 
B turns ON -> trace tries to chase (increases) -> footprint tries to chase - -> residual becomes negative at A-B -> should just decrease, positive at A^B -> actually... should decrease (just more steeply) + -> residual becomes negative at A-B + -> should just decrease, positive at A^B + -> actually... should decrease (just more steeply) 2. SplitOff: Know AB. B turns OFF -> trace tries to chase (decreases) -> footprint tries to chase - -> residual becomes positive at A-B -> should increase, negative at A^B -> this should just decrease, MORE negative at B-A -> this going to zero makes sense + -> residual becomes positive at A-B + -> should increase, negative at A^B + -> this should just decrease, MORE negative at B-A + -> this going to zero makes sense OR keep B, remove A-B @@ -85,7 +91,8 @@ def _align_overestimates( for GradualOn, nothing should go to zero. for SplitOff, a chunk needs to go to zero. - So... how about we do something like (if it's been on for a long time, we become less likely to purge it?) + So... how about we do something like (if it's been on for a long time, + we become less likely to purge it?) We subsequently clip R minimum to zero, since all significant negative residual spots have been removed, and the remaining negative spots are noise level. 
diff --git a/src/cala/util.py b/src/cala/util.py index 4e2a7f90..32be5074 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence, Generator +from collections.abc import Generator, Sequence from datetime import datetime from itertools import count from typing import Annotated as A diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index e54798cc..ae2fb313 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1,4 +1,5 @@ from .config import ( + cwd_to_pytest_base, set_config, set_dotenv, set_env, @@ -7,7 +8,6 @@ tmp_config_source, tmp_cwd, yaml_config, - cwd_to_pytest_base, ) from .meta import monkeypatch_session from .toys import connected_cells, separate_cells, single_cell diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index d251a3c6..47c9361f 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -1,5 +1,5 @@ import os -from collections.abc import Callable, MutableMapping +from collections.abc import Callable, Generator, MutableMapping from pathlib import Path from typing import Any @@ -160,7 +160,7 @@ def _flatten(d: MutableMapping, parent_key: str = "", separator: str = "__") -> @pytest.fixture -def cwd_to_pytest_base(request): +def cwd_to_pytest_base(request: pytest.FixtureRequest) -> Generator[None, Any, None]: os.chdir(request.config.rootdir) yield os.chdir(request.config.invocation_params.dir) diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 47e6c21d..a136ce2d 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -1,8 +1,7 @@ import numpy as np import pytest -from noob.node import Node, NodeSpecification - import xarray as xr +from noob.node import Node, NodeSpecification from cala.assets import Residual from cala.models.axis import AXIS diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index cdf6ceca..66ebfc9e 100644 --- a/tests/test_pipeline.py +++ 
b/tests/test_pipeline.py @@ -88,20 +88,3 @@ def test_odl(runner, source) -> None: expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) result = (fps.array @ trs.array).transpose(*expected.dims) raise NotImplementedError("Deprecation not implemented") - - -def test_with_avi(cwd_to_pytest_base) -> None: - cube = Cube.from_specification("cala-io") - tube = Tube.from_specification("cala-io") - - runner = SynchronousRunner(tube=tube, cube=cube) - - import matplotlib.pyplot as plt - - gen = runner.iter() - for i, fr in enumerate(gen): - raw = fr["raw"].array - prep = fr["prep"].array / fr["prep"].array.max() * raw.max() - plt.imsave(f"out{i}.png", xr.concat([raw, prep], dim=AXIS.height_dim)) - - assert runner.cube.assets["buffer"].obj.array.size > 0