From 530279bb15f79fe681257058741fab9e2855d3c0 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 7 Aug 2025 16:51:17 -0700 Subject: [PATCH] feat: footprint update diagnosis --- src/cala/assets.py | 28 +++++++++------ src/cala/models/checks.py | 5 +++ src/cala/nodes/footprints.py | 15 +++----- tests/test_iter/test_footprints.py | 57 +++++++++++++++++++++--------- tests/test_pipeline.py | 8 ++--- 5 files changed, 71 insertions(+), 42 deletions(-) diff --git a/src/cala/assets.py b/src/cala/assets.py index d5dfd4b9..c39f8945 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator from cala.models.axis import AXIS, Coords, Dims -from cala.models.checks import is_non_negative +from cala.models.checks import has_no_nan, is_non_negative from cala.models.entity import Entity, Group @@ -49,14 +49,19 @@ class Footprint(Asset): name="footprint", dims=(Dims.width.value, Dims.height.value), dtype=float, - checks=[is_non_negative], + checks=[is_non_negative, has_no_nan], ) ) class Trace(Asset): _entity: ClassVar[Entity] = PrivateAttr( - Entity(name="trace", dims=(Dims.frame.value,), dtype=float, checks=[is_non_negative]) + Entity( + name="trace", + dims=(Dims.frame.value,), + dtype=float, + checks=[is_non_negative, has_no_nan], + ) ) @@ -66,7 +71,7 @@ class Frame(Asset): name="frame", dims=(Dims.width.value, Dims.height.value), dtype=float, - checks=[is_non_negative], + checks=[is_non_negative, has_no_nan], ) ) @@ -77,7 +82,7 @@ class Footprints(Asset): name="footprint-group", member=Footprint.entity(), group_by=Dims.component, - checks=[is_non_negative], + checks=[is_non_negative, has_no_nan], ) ) @@ -125,7 +130,7 @@ def from_array( name="trace-group", member=Trace.entity(), group_by=Dims.component, - checks=[is_non_negative], + checks=[is_non_negative, has_no_nan], ) ) @@ -136,7 +141,7 @@ class Movie(Asset): name="movie", member=Frame.entity(), group_by=Dims.frame.value, - checks=[is_non_negative], + checks=[is_non_negative, has_no_nan], ) ) @@ -154,6 +159,7 @@ class PopSnap(Asset): dims=(Dims.component.value,), dtype=float, coords=[Coords.frame.value, Coords.timestamp.value], + checks=[is_non_negative, has_no_nan], ) ) @@ -170,7 +176,7 @@ class CompStats(Asset): name="comp-stat", dims=comp_dims, dtype=float, - checks=[], + checks=[is_non_negative, has_no_nan], ) ) @@ -181,7 +187,7 @@ class PixStats(Asset): name="pix-stat", dims=(Dims.width.value, Dims.height.value, Dims.component.value), dtype=float, - checks=[], + checks=[is_non_negative, has_no_nan], ) ) @@ -192,7 +198,7 @@ class Overlaps(Asset): name="overlap", dims=comp_dims, dtype=bool, - checks=[], + checks=[has_no_nan], ) ) @@ -214,6 +220,6 @@ class Residual(Asset): name="frame", member=Frame.entity(), group_by=Dims.frame.value, - checks=[is_non_negative], + checks=[is_non_negative, has_no_nan], ) ) diff --git a/src/cala/models/checks.py b/src/cala/models/checks.py index d010dd0d..ea3815e4 100644 --- a/src/cala/models/checks.py +++ b/src/cala/models/checks.py @@ -16,3 +16,8 @@ def is_unique(da: xr.DataArray) -> None: def is_unit_interval(da: xr.DataArray) -> None: if da.min() < 0 or da.max() > 1: raise ValueError("The values in DataArray are not unit interval.") + + +def has_no_nan(da: xr.DataArray) -> None: + if np.isnan(da).any(): + raise ValueError("The DataArray has nan values.") diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index b4b441ff..35000749 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -26,7 +26,7 @@ def ingest_frame( """ Update spatial footprints using sufficient statistics. - Ã[p, i] = max(Ã[p, i] + (W[p, i] - Ã[p, :]M[i, :])/M[i, i], 0) + Ã[p, i] = max(Ã[p, i] + (W[p, i] - Ã[p, :]M[:, i])/M[i, i], 0) where: - Ã is the spatial footprints matrix @@ -67,21 +67,16 @@ def ingest_frame( dask="allowed", ) - # Apply update equation with masking - update = numerator / M_diag - A_new = mask * (A + update) - A_new = xr.where(A_new > 0, A_new, 0) + update = numerator / (M_diag + np.finfo(float).tiny) + A_new = (mask * (A + update)).clip(min=0) if (np.abs(A - A_new).sum() / np.prod(A.shape)).item() < self.tol: - A = A_new - break + footprints.array = A_new + return footprints else: A = A_new mask = A > 0 - footprints.array = A - return footprints - def _expansion_kernel(self) -> np.ndarray: return cv2.getStructuringElement(cv2.MORPH_CROSS, (self.bep * 2 + 1, self.bep * 2 + 1)) diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index dc07cb6a..aee8245c 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -2,7 +2,7 @@ import pytest import xarray as xr from noob.node import Node, NodeSpecification -from scipy.ndimage import binary_dilation, generate_binary_structure, grey_erosion +from scipy.ndimage import generate_binary_structure, grey_dilation, grey_erosion from cala.assets import Footprints from cala.models.axis import AXIS @@ -32,6 +32,29 @@ def separate_cells() -> Toy: ) +@pytest.fixture +def connected_cells() -> Toy: + n_frames = 50 + + return Toy( + n_frames=n_frames, + frame_dims=FrameDims(width=50, height=50), + cell_radii=8, + cell_positions=[ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + ], + cell_traces=[ + np.random.randint(low=0, high=n_frames, size=n_frames).astype(float), + np.abs(np.sin(np.linspace(-np.pi, np.pi, n_frames)) * n_frames).astype(float), + np.array(range(n_frames), dtype=float), + np.array(range(n_frames - 1, -1, -1), dtype=float), + ], + ) + + @pytest.fixture def fpter() -> Node: return Node.from_specification( @@ -43,7 +66,7 @@ def fpter() -> Node: ) -@pytest.mark.parametrize("toy", ["separate_cells"]) +@pytest.mark.parametrize("toy", ["separate_cells", "connected_cells"]) def test_ingest_frame(fpter, toy, request): toy = request.getfixturevalue(toy) @@ -60,7 +83,7 @@ def test_ingest_frame(fpter, toy, request): expected = toy.footprints.copy() - assert result == expected + xr.testing.assert_allclose(result.array, expected.array) @pytest.fixture @@ -74,8 +97,9 @@ def xpander() -> Node: ) -@pytest.mark.parametrize("toy", ["separate_cells"]) -def test_expand_boundary(xpander, toy, request): +@pytest.mark.parametrize("defect", [grey_erosion, grey_dilation]) +@pytest.mark.parametrize("toy", ["separate_cells", "connected_cells"]) +def test_boundary_morph(xpander, defect, toy, request): """ what would be the circumstances of needing boundary expansion: existing footprint is too small. @@ -93,6 +117,11 @@ def test_expand_boundary(xpander, toy, request): so expansion is almost guaranteed every single loop. this means after the expansion, we need to rely on removal of "overexpanded" pixels. does that occur naturally at W - AM? + Not exactly. + + W: width height comp, A: width height comp M: comp comp + M: avg dot product of traces + AM: footprint x how correlated other cells are """ toy = request.getfixturevalue(toy) @@ -105,8 +134,8 @@ def test_expand_boundary(xpander, toy, request): footprint = generate_binary_structure(2, 1) - eroded_fps = xr.apply_ufunc( - grey_erosion, + modded_fps = xr.apply_ufunc( + defect, toy.footprints.array, kwargs={"footprint": footprint}, vectorize=True, @@ -115,17 +144,11 @@ def test_expand_boundary(xpander, toy, request): ) result = xpander.process( - footprints=Footprints.from_array(eroded_fps), + footprints=Footprints.from_array(modded_fps), pixel_stats=pixstats, component_stats=compstats, ) - # import matplotlib.pyplot as plt - # - # for idx, fps in enumerate(zip(eroded_fps, toy.footprints.array, result.array)): - # ero, exp, res = fps - # plt.imsave(f"eroded_{idx}.png", ero) - # plt.imsave(f"expect_{idx}.png", exp) - # plt.imsave(f"result_{idx}.png", res) - - assert result == toy.footprints + # expansion breaks when a trace is all-zero and overlaps with another component. + # we don't know why. all-zero trace is somewhat unlikely, but we probably need a solution. + xr.testing.assert_allclose(result.array, toy.footprints.array, rtol=1e-3) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c31e0187..64480f14 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -44,22 +44,22 @@ def test_run(odl_runner) -> None: @pytest.mark.xfail def test_combined_footprint() -> None: """Start with two footprints combined""" - assert NotImplementedError() + raise AssertionError("Not implemented") @pytest.mark.xfail def test_dilating_footprint() -> None: """start with binary-eroded footprints""" - assert NotImplementedError() + raise AssertionError("Not implemented") @pytest.mark.xfail def test_eroding_footprint() -> None: """start with binary-dilated footprints""" - assert NotImplementedError() + raise AssertionError("Not implemented") @pytest.mark.xfail def test_redundant_footprint() -> None: """start with redundant footprints""" - assert NotImplementedError() + raise AssertionError("Not implemented")