Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions src/cala/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator

from cala.models.axis import AXIS, Coords, Dims
from cala.models.checks import is_non_negative
from cala.models.checks import has_no_nan, is_non_negative
from cala.models.entity import Entity, Group


Expand Down Expand Up @@ -49,14 +49,19 @@ class Footprint(Asset):
name="footprint",
dims=(Dims.width.value, Dims.height.value),
dtype=float,
checks=[is_non_negative],
checks=[is_non_negative, has_no_nan],
)
)


class Trace(Asset):
_entity: ClassVar[Entity] = PrivateAttr(
Entity(name="trace", dims=(Dims.frame.value,), dtype=float, checks=[is_non_negative])
Entity(
name="trace",
dims=(Dims.frame.value,),
dtype=float,
checks=[is_non_negative, has_no_nan],
)
)


Expand All @@ -66,7 +71,7 @@ class Frame(Asset):
name="frame",
dims=(Dims.width.value, Dims.height.value),
dtype=float,
checks=[is_non_negative],
checks=[is_non_negative, has_no_nan],
)
)

Expand All @@ -77,7 +82,7 @@ class Footprints(Asset):
name="footprint-group",
member=Footprint.entity(),
group_by=Dims.component,
checks=[is_non_negative],
checks=[is_non_negative, has_no_nan],
)
)

Expand Down Expand Up @@ -125,7 +130,7 @@ def from_array(
name="trace-group",
member=Trace.entity(),
group_by=Dims.component,
checks=[is_non_negative],
checks=[is_non_negative, has_no_nan],
)
)

Expand All @@ -136,7 +141,7 @@ class Movie(Asset):
name="movie",
member=Frame.entity(),
group_by=Dims.frame.value,
checks=[is_non_negative],
checks=[is_non_negative, has_no_nan],
)
)

Expand All @@ -154,6 +159,7 @@ class PopSnap(Asset):
dims=(Dims.component.value,),
dtype=float,
coords=[Coords.frame.value, Coords.timestamp.value],
checks=[is_non_negative, has_no_nan],
)
)

Expand All @@ -170,7 +176,7 @@ class CompStats(Asset):
name="comp-stat",
dims=comp_dims,
dtype=float,
checks=[],
checks=[is_non_negative, has_no_nan],
)
)

Expand All @@ -181,7 +187,7 @@ class PixStats(Asset):
name="pix-stat",
dims=(Dims.width.value, Dims.height.value, Dims.component.value),
dtype=float,
checks=[],
checks=[is_non_negative, has_no_nan],
)
)

Expand All @@ -192,7 +198,7 @@ class Overlaps(Asset):
name="overlap",
dims=comp_dims,
dtype=bool,
checks=[],
checks=[has_no_nan],
)
)

Expand All @@ -214,6 +220,6 @@ class Residual(Asset):
name="frame",
member=Frame.entity(),
group_by=Dims.frame.value,
checks=[is_non_negative],
checks=[is_non_negative, has_no_nan],
)
)
5 changes: 5 additions & 0 deletions src/cala/models/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ def is_unique(da: xr.DataArray) -> None:
def is_unit_interval(da: xr.DataArray) -> None:
if da.min() < 0 or da.max() > 1:
raise ValueError("The values in DataArray are not unit interval.")


def has_no_nan(da: xr.DataArray) -> None:
if np.isnan(da).any():
raise ValueError("The DataArray has nan values.")
15 changes: 5 additions & 10 deletions src/cala/nodes/footprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def ingest_frame(
"""
Update spatial footprints using sufficient statistics.

Γƒ[p, i] = max(Γƒ[p, i] + (W[p, i] - Γƒ[p, :]M[i, :])/M[i, i], 0)
Γƒ[p, i] = max(Γƒ[p, i] + (W[p, i] - Γƒ[p, :]M[:, i])/M[i, i], 0)

where:
- Γƒ is the spatial footprints matrix
Expand Down Expand Up @@ -67,21 +67,16 @@ def ingest_frame(
dask="allowed",
)

# Apply update equation with masking
update = numerator / M_diag
A_new = mask * (A + update)
A_new = xr.where(A_new > 0, A_new, 0)
update = numerator / (M_diag + np.finfo(float).tiny)
A_new = (mask * (A + update)).clip(min=0)

if (np.abs(A - A_new).sum() / np.prod(A.shape)).item() < self.tol:
A = A_new
break
footprints.array = A_new
return footprints
else:
A = A_new
mask = A > 0

footprints.array = A
return footprints

def _expansion_kernel(self) -> np.ndarray:
return cv2.getStructuringElement(cv2.MORPH_CROSS, (self.bep * 2 + 1, self.bep * 2 + 1))

Expand Down
57 changes: 40 additions & 17 deletions tests/test_iter/test_footprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import xarray as xr
from noob.node import Node, NodeSpecification
from scipy.ndimage import binary_dilation, generate_binary_structure, grey_erosion
from scipy.ndimage import generate_binary_structure, grey_dilation, grey_erosion

from cala.assets import Footprints
from cala.models.axis import AXIS
Expand Down Expand Up @@ -32,6 +32,29 @@ def separate_cells() -> Toy:
)


@pytest.fixture
def connected_cells() -> Toy:
n_frames = 50

return Toy(
n_frames=n_frames,
frame_dims=FrameDims(width=50, height=50),
cell_radii=8,
cell_positions=[
Position(width=15, height=15),
Position(width=15, height=35),
Position(width=25, height=25),
Position(width=35, height=35),
],
cell_traces=[
np.random.randint(low=0, high=n_frames, size=n_frames).astype(float),
np.abs(np.sin(np.linspace(-np.pi, np.pi, n_frames)) * n_frames).astype(float),
np.array(range(n_frames), dtype=float),
np.array(range(n_frames - 1, -1, -1), dtype=float),
],
)


@pytest.fixture
def fpter() -> Node:
return Node.from_specification(
Expand All @@ -43,7 +66,7 @@ def fpter() -> Node:
)


@pytest.mark.parametrize("toy", ["separate_cells"])
@pytest.mark.parametrize("toy", ["separate_cells", "connected_cells"])
def test_ingest_frame(fpter, toy, request):
toy = request.getfixturevalue(toy)

Expand All @@ -60,7 +83,7 @@ def test_ingest_frame(fpter, toy, request):

expected = toy.footprints.copy()

assert result == expected
xr.testing.assert_allclose(result.array, expected.array)


@pytest.fixture
Expand All @@ -74,8 +97,9 @@ def xpander() -> Node:
)


@pytest.mark.parametrize("toy", ["separate_cells"])
def test_expand_boundary(xpander, toy, request):
@pytest.mark.parametrize("defect", [grey_erosion, grey_dilation])
@pytest.mark.parametrize("toy", ["separate_cells", "connected_cells"])
def test_boundary_morph(xpander, defect, toy, request):
"""
what would be the circumstances of needing boundary expansion:
existing footprint is too small.
Expand All @@ -93,6 +117,11 @@ def test_expand_boundary(xpander, toy, request):
so expansion is almost guaranteed every single loop.
this means after the expansion, we need to rely on removal of "overexpanded" pixels.
does that occur naturally at W - AM?
Not exactly.

W: width height comp, A: width height comp M: comp comp
M: avg dot product of traces
AM: footprint x how correlated other cells are
"""
toy = request.getfixturevalue(toy)

Expand All @@ -105,8 +134,8 @@ def test_expand_boundary(xpander, toy, request):

footprint = generate_binary_structure(2, 1)

eroded_fps = xr.apply_ufunc(
grey_erosion,
modded_fps = xr.apply_ufunc(
defect,
toy.footprints.array,
kwargs={"footprint": footprint},
vectorize=True,
Expand All @@ -115,17 +144,11 @@ def test_expand_boundary(xpander, toy, request):
)

result = xpander.process(
footprints=Footprints.from_array(eroded_fps),
footprints=Footprints.from_array(modded_fps),
pixel_stats=pixstats,
component_stats=compstats,
)

# import matplotlib.pyplot as plt
#
# for idx, fps in enumerate(zip(eroded_fps, toy.footprints.array, result.array)):
# ero, exp, res = fps
# plt.imsave(f"eroded_{idx}.png", ero)
# plt.imsave(f"expect_{idx}.png", exp)
# plt.imsave(f"result_{idx}.png", res)

assert result == toy.footprints
# expansion breaks when a trace is all-zero and overlaps with another component.
# we don't know why. all-zero trace is somewhat unlikely, but we probably need a solution.
xr.testing.assert_allclose(result.array, toy.footprints.array, rtol=1e-3)
8 changes: 4 additions & 4 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,22 @@ def test_run(odl_runner) -> None:
@pytest.mark.xfail
def test_combined_footprint() -> None:
"""Start with two footprints combined"""
assert NotImplementedError()
raise AssertionError("Not implemented")


@pytest.mark.xfail
def test_dilating_footprint() -> None:
"""start with binary-eroded footprints"""
assert NotImplementedError()
raise AssertionError("Not implemented")


@pytest.mark.xfail
def test_eroding_footprint() -> None:
"""start with binary-dilated footprints"""
assert NotImplementedError()
raise AssertionError("Not implemented")


@pytest.mark.xfail
def test_redundant_footprint() -> None:
"""start with redundant footprints"""
assert NotImplementedError()
raise AssertionError("Not implemented")
Loading