Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 17 additions & 26 deletions src/cala/nodes/footprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@

class Footprinter:

def __init__(self, boundary_expansion_pixels: int | None = None, tolerance: float = 1e-7):
self.bep = boundary_expansion_pixels
def __init__(self, tol: float, max_iter: int | None = None, bep: int | None = None):
self.bep = bep
"""
Number of pixels to explore the boundary of the footprint outside of the current footprint.
"""

self.tol = tolerance
self.tol = tol

self.max_iter = max_iter

@process_method
def ingest_frame(
Expand Down Expand Up @@ -45,20 +47,14 @@ def ingest_frame(
M = component_stats.array
W = pixel_stats.array

converged = False
expanded = False
kernel = None

while not converged:
mask = A > 0

if self.bep:
kernel = kernel if kernel else self._expansion_kernel()
mask = A > 0

if not expanded:
mask = self._expand_boundary(kernel, mask)
expanded = True
if self.bep:
kernel = self._expansion_kernel()
mask = self._expand_boundary(kernel, mask)

# for _ in range(self.max_iter):
while True:
AM = A.rename(AXIS.component_rename) @ M
numerator = W - AM

Expand All @@ -76,30 +72,25 @@ def ingest_frame(
A_new = mask * (A + update)
A_new = xr.where(A_new > 0, A_new, 0)

if abs((A - A_new).sum() / np.prod(A.shape)) < self.tol:
if (np.abs(A - A_new).sum() / np.prod(A.shape)).item() < self.tol:
A = A_new
converged = True
break
else:
A = A_new
mask = A > 0

footprints.array = A
return footprints

def _expansion_kernel(self) -> np.ndarray:
return cv2.getStructuringElement(
cv2.MORPH_CROSS,
(
self.bep * 2 + 1,
self.bep * 2 + 1,
),
) # faster than np.ones
return cv2.getStructuringElement(cv2.MORPH_CROSS, (self.bep * 2 + 1, self.bep * 2 + 1))

def _expand_boundary(self, kernel: np.ndarray, mask: xr.DataArray) -> xr.DataArray:
return xr.apply_ufunc(
lambda x: cv2.morphologyEx(x, cv2.MORPH_DILATE, kernel, iterations=1),
mask.astype(np.uint8),
input_core_dims=[[*AXIS.spatial_dims]],
output_core_dims=[[*AXIS.spatial_dims]],
input_core_dims=[AXIS.spatial_dims],
output_core_dims=[AXIS.spatial_dims],
vectorize=True,
dask="parallelized",
)
Expand Down
4 changes: 2 additions & 2 deletions tests/data/pipelines/odl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ nodes:
footprints_frame:
type: cala.nodes.footprints.Footprinter
params:
boundary_expansion_pixels: 1
tolerance: 1e-7
bep: 1
tol: 1e-7
depends:
- footprints: assets.footprints
- pixel_stats: pix_component.value
Expand Down
74 changes: 73 additions & 1 deletion tests/test_iter/test_footprints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import pytest
import xarray as xr
from noob.node import Node, NodeSpecification
from scipy.ndimage import binary_dilation, generate_binary_structure, grey_erosion

from cala.assets import Footprints
from cala.models.axis import AXIS
from cala.testing.toy import FrameDims, Position, Toy


Expand Down Expand Up @@ -34,7 +38,7 @@ def fpter() -> Node:
NodeSpecification(
id="test_footprinter",
type="cala.nodes.footprints.Footprinter",
params={"boundary_expansion_pixels": None, "tolerance": 1e-7},
params={"bep": None, "tol": 1e-7},
)
)

Expand All @@ -57,3 +61,71 @@ def test_ingest_frame(fpter, toy, request):
expected = toy.footprints.copy()

assert result == expected


@pytest.fixture
def xpander() -> Node:
return Node.from_specification(
NodeSpecification(
id="test_footprinter",
type="cala.nodes.footprints.Footprinter",
params={"bep": 2, "tol": 1e-7},
)
)


@pytest.mark.parametrize("toy", ["separate_cells"])
def test_expand_boundary(xpander, toy, request):
"""
what would be the circumstances of needing boundary expansion:
existing footprint is too small.
does not affect component_stats
pixel_stats may care. if the correlation with a component and a pixel is high,
the pixel_stat would be high.
boundary-expansion would be literally trying to add pixel_stats (normalized) around the
boundary of the current footprint. (basically how many times pixel and trace coincided)
the thing is, pixel_stat never goes below zero. so you're always sort of adding the boundary
pixels.
this means this phenomenon needs to be regulated by another mechanism, i.e. pixel
value going to zero somehow.
it would be pretty hard to ensure the cell boundary does not forever expand, since
the longer the video, the more coincidences with any pixel and any trace will occur,
so expansion is almost guaranteed every single loop.
this means after the expansion, we need to rely on removal of "overexpanded" pixels.
does that occur naturally at W - AM?
"""
toy = request.getfixturevalue(toy)

pixstats = Node.from_specification(
NodeSpecification(id="test_pixstats", type="cala.nodes.pixel_stats.initialize")
).process(traces=toy.traces, frames=toy.make_movie())
compstats = Node.from_specification(
NodeSpecification(id="test_compstats", type="cala.nodes.component_stats.initialize")
).process(traces=toy.traces)

footprint = generate_binary_structure(2, 1)

eroded_fps = xr.apply_ufunc(
grey_erosion,
toy.footprints.array,
kwargs={"footprint": footprint},
vectorize=True,
input_core_dims=[AXIS.spatial_dims],
output_core_dims=[AXIS.spatial_dims],
)

result = xpander.process(
footprints=Footprints.from_array(eroded_fps),
pixel_stats=pixstats,
component_stats=compstats,
)

# import matplotlib.pyplot as plt
#
# for idx, fps in enumerate(zip(eroded_fps, toy.footprints.array, result.array)):
# ero, exp, res = fps
# plt.imsave(f"eroded_{idx}.png", ero)
# plt.imsave(f"expect_{idx}.png", exp)
# plt.imsave(f"result_{idx}.png", res)

assert result == toy.footprints
1 change: 0 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
from noob import Cube, SynchronousRunner, Tube
from scipy.ndimage import binary_dilation, binary_erosion


@pytest.fixture
Expand Down
Loading