diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 2119c767..b4b441ff 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -9,13 +9,15 @@ class Footprinter: - def __init__(self, boundary_expansion_pixels: int | None = None, tolerance: float = 1e-7): - self.bep = boundary_expansion_pixels + def __init__(self, tol: float, max_iter: int | None = None, bep: int | None = None): + self.bep = bep """ Number of pixels to explore the boundary of the footprint outside of the current footprint. """ - self.tol = tolerance + self.tol = tol + + self.max_iter = max_iter @process_method def ingest_frame( @@ -45,20 +47,14 @@ def ingest_frame( M = component_stats.array W = pixel_stats.array - converged = False - expanded = False - kernel = None - - while not converged: - mask = A > 0 - - if self.bep: - kernel = kernel if kernel else self._expansion_kernel() + mask = A > 0 - if not expanded: - mask = self._expand_boundary(kernel, mask) - expanded = True + if self.bep: + kernel = self._expansion_kernel() + mask = self._expand_boundary(kernel, mask) + # for _ in range(self.max_iter): + while True: AM = A.rename(AXIS.component_rename) @ M numerator = W - AM @@ -76,30 +72,25 @@ def ingest_frame( A_new = mask * (A + update) A_new = xr.where(A_new > 0, A_new, 0) - if abs((A - A_new).sum() / np.prod(A.shape)) < self.tol: + if (np.abs(A - A_new).sum() / np.prod(A.shape)).item() < self.tol: A = A_new - converged = True + break else: A = A_new + mask = A > 0 footprints.array = A return footprints def _expansion_kernel(self) -> np.ndarray: - return cv2.getStructuringElement( - cv2.MORPH_CROSS, - ( - self.bep * 2 + 1, - self.bep * 2 + 1, - ), - ) # faster than np.ones + return cv2.getStructuringElement(cv2.MORPH_CROSS, (self.bep * 2 + 1, self.bep * 2 + 1)) def _expand_boundary(self, kernel: np.ndarray, mask: xr.DataArray) -> xr.DataArray: return xr.apply_ufunc( lambda x: cv2.morphologyEx(x, cv2.MORPH_DILATE, kernel, iterations=1), mask.astype(np.uint8), - input_core_dims=[[*AXIS.spatial_dims]], - output_core_dims=[[*AXIS.spatial_dims]], + input_core_dims=[AXIS.spatial_dims], + output_core_dims=[AXIS.spatial_dims], vectorize=True, dask="parallelized", ) diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index bae08f67..b843add4 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -163,8 +163,8 @@ nodes: footprints_frame: type: cala.nodes.footprints.Footprinter params: - boundary_expansion_pixels: 1 - tolerance: 1e-7 + bep: 1 + tol: 1e-7 depends: - footprints: assets.footprints - pixel_stats: pix_component.value diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index e91339b6..dc07cb6a 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -1,7 +1,11 @@ import numpy as np import pytest +import xarray as xr from noob.node import Node, NodeSpecification +from scipy.ndimage import binary_dilation, generate_binary_structure, grey_erosion +from cala.assets import Footprints +from cala.models.axis import AXIS from cala.testing.toy import FrameDims, Position, Toy @@ -34,7 +38,7 @@ def fpter() -> Node: NodeSpecification( id="test_footprinter", type="cala.nodes.footprints.Footprinter", - params={"boundary_expansion_pixels": None, "tolerance": 1e-7}, + params={"bep": None, "tol": 1e-7}, ) ) @@ -57,3 +61,71 @@ def test_ingest_frame(fpter, toy, request): expected = toy.footprints.copy() assert result == expected + + +@pytest.fixture +def xpander() -> Node: + return Node.from_specification( + NodeSpecification( + id="test_footprinter", + type="cala.nodes.footprints.Footprinter", + params={"bep": 2, "tol": 1e-7}, + ) + ) + + +@pytest.mark.parametrize("toy", ["separate_cells"]) +def test_expand_boundary(xpander, toy, request): + """ + what would be the circumstances of needing boundary expansion: + existing footprint is too small. + does not affect component_stats + pixel_stats may care. if the correlation with a component and a pixel is high, + the pixel_stat would be high. + boundary-expansion would be literally trying to add pixel_stats (normalized) around the + boundary of the current footprint. (basically how many times pixel and trace coincided) + the thing is, pixel_stat never goes below zero. so you're always sort of adding the boundary + pixels. + this means this phenomenon needs to be regulated by another mechanism, i.e. pixel + value going to zero somehow. + it would be pretty hard to ensure the cell boundary does not forever expand, since + the longer the video, the more coincidences with any pixel and any trace will occur, + so expansion is almost guaranteed every single loop. + this means after the expansion, we need to rely on removal of "overexpanded" pixels. + does that occur naturally at W - AM? + """ + toy = request.getfixturevalue(toy) + + pixstats = Node.from_specification( + NodeSpecification(id="test_pixstats", type="cala.nodes.pixel_stats.initialize") + ).process(traces=toy.traces, frames=toy.make_movie()) + compstats = Node.from_specification( + NodeSpecification(id="test_compstats", type="cala.nodes.component_stats.initialize") + ).process(traces=toy.traces) + + footprint = generate_binary_structure(2, 1) + + eroded_fps = xr.apply_ufunc( + grey_erosion, + toy.footprints.array, + kwargs={"footprint": footprint}, + vectorize=True, + input_core_dims=[AXIS.spatial_dims], + output_core_dims=[AXIS.spatial_dims], + ) + + result = xpander.process( + footprints=Footprints.from_array(eroded_fps), + pixel_stats=pixstats, + component_stats=compstats, + ) + + # import matplotlib.pyplot as plt + # + # for idx, fps in enumerate(zip(eroded_fps, toy.footprints.array, result.array)): + # ero, exp, res = fps + # plt.imsave(f"eroded_{idx}.png", ero) + # plt.imsave(f"expect_{idx}.png", exp) + # plt.imsave(f"result_{idx}.png", res) + + assert result == toy.footprints diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 8f8a6c12..c31e0187 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,6 +1,5 @@ import pytest from noob import Cube, SynchronousRunner, Tube -from scipy.ndimage import binary_dilation, binary_erosion @pytest.fixture