Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cala/models/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Coords(Enum):
height = Coord(name=AXIS.height_coord, dtype=int, checks=[is_unique])
width = Coord(name=AXIS.width_coord, dtype=int, checks=[is_unique])
frame = Coord(name=AXIS.frame_coord, dtype=int, checks=[is_unique])
timestamp = Coord(name=AXIS.timestamp_coord, dtype=str)
timestamp = Coord(name=AXIS.timestamp_coord, dtype=str, checks=[is_unique])
confidence = Coord(name=AXIS.confidence_coord, dtype=float, checks=[is_unit_interval])


Expand Down
21 changes: 15 additions & 6 deletions src/cala/nodes/component_stats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import xarray as xr

from cala.assets import CompStats, Frame, PopSnap, Trace, Traces
Expand Down Expand Up @@ -89,16 +90,24 @@ def ingest_component(component_stats: CompStats, traces: Traces, new_trace: Trac
if new_trace.array is None:
return component_stats

if component_stats.array is None:
component_stats.array = initialize(traces).array
return component_stats

# Get current frame index (starting with 1)
t = new_trace.array[AXIS.frame_coord].max().item() + 1

M = component_stats.array
c_new = new_trace.array
c_new = new_trace.array.volumize.dim_with_coords(
dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.confidence_coord]
)
c_buf = traces.array
M = component_stats.array

if M is None or M.size == 1:
component_stats.array = initialize(traces).array
return component_stats

if c_new[AXIS.id_coord].item() in M[AXIS.id_coord].values:
dim_idx = np.where(M[AXIS.id_coord].values == c_new[AXIS.id_coord].item())[0].tolist()
M = M.drop_sel({AXIS.component_dim: dim_idx, f"{AXIS.component_dim}": dim_idx})

# think i also have to remove the ID from c_buf?

# Compute cross-correlation between buffer and new components
# C_buf^T c_new
Expand Down
46 changes: 28 additions & 18 deletions src/cala/nodes/detect/slice_nmf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Hashable, Mapping
from typing import Annotated as A
from typing import Any

import numpy as np
import xarray as xr
from noob import Name
from noob.node import Node
from pydantic import Field
from sklearn.decomposition import NMF

from cala.assets import Footprint, Residual, Trace
Expand All @@ -13,30 +15,34 @@

class SliceNMF(Node):
cell_radius: int
nmf_kwargs: dict[str, Any] = Field(default_factory=dict)
validity_threshold: float

def model_post_init(self, context: Any, /) -> None:
self.nmf_kwargs.update({"n_components": 1, "init": "nndsvd"})
if not self.nmf_kwargs.get("tol", None):
self.nmf_kwargs["tol"] = 1e-4

def process(
self, residuals: Residual, energy: xr.DataArray
) -> tuple[A[Footprint | None, Name("new_fp")], A[Trace | None, Name("new_tr")]]:
if energy.size == 1:
return Footprint(), Trace()

# Find and analyze neighborhood of maximum variance
) -> tuple[A[Footprint, Name("new_fp")], A[Trace, Name("new_tr")]]:
residuals = residuals.array

slice_ = self._get_max_energy_slice(arr=residuals, energy_landscape=energy)
if energy.size > 1 and residuals.max().item() > self.nmf_kwargs["tol"]:
# Find and analyze neighborhood of maximum variance
slice_ = self._get_max_energy_slice(arr=residuals, energy_landscape=energy)

a_new, c_new = self._local_nmf(
slice_=slice_,
spatial_sizes={k: v for k, v in residuals.sizes.items() if k in AXIS.spatial_dims},
)
a_new, c_new = self._local_nmf(
slice_=slice_,
spatial_sizes={k: v for k, v in residuals.sizes.items() if k in AXIS.spatial_dims},
)

# eventually we should just log this value instead of throwing out the component
# otherwise we keep coming back to this energy max point
if self._check_validity(a_new, residuals):
return Footprint.from_array(a_new), Trace.from_array(c_new)
else:
return None, None
# eventually we should just log this value instead of throwing out the component
# otherwise we keep coming back to this energy max point
if self._check_validity(a_new, residuals):
return Footprint.from_array(a_new), Trace.from_array(c_new)

return Footprint(), Trace()

def _get_max_energy_slice(
self,
Expand Down Expand Up @@ -87,10 +93,10 @@ def _local_nmf(
R = slice_.stack(space=AXIS.spatial_dims).transpose(AXIS.frames_dim, "space")

# Apply NMF (check how long nndsvd takes vs random)
model = NMF(n_components=1, init="nndsvd", tol=1e-4, max_iter=200)
model = NMF(**self.nmf_kwargs)

# when residual is negative, the error becomes massive...
c = model.fit_transform(R.clip(0)) # temporal component
c = model.fit_transform(R) # temporal component
a = model.components_ # spatial component

# Convert back to xarray with proper dimensions and coordinates
Expand Down Expand Up @@ -122,6 +128,10 @@ def _local_nmf(
return a_new, c_new

def _check_validity(self, a_new: xr.DataArray, residuals: xr.DataArray) -> bool:
"""
Think this is redundant with NMF.reconstruction_err_
"""

# not sure if this step is necessary or even makes sense
# how would a rank-1 nmf be not similar to the mean, unless the nmf error was massive?
# and if the error is big, maybe it just means it's partially overlapping with another
Expand Down
11 changes: 10 additions & 1 deletion src/cala/nodes/footprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,14 @@ def ingest_component(footprints: Footprints, new_footprint: Footprint | Footprin
)
return footprints

footprints.array = xr.concat([footprints.array, new_footprint.array], dim=AXIS.component_dim)
if new_footprint.array[AXIS.id_coord].item() in footprints.array[AXIS.id_coord].values:
# if replacing (post-merging in catalog)
footprints.array.set_xindex(AXIS.id_coord).loc[
{AXIS.id_coord: new_footprint.array[AXIS.id_coord].item()}
] = new_footprint.array
else:
# if new
footprints.array = xr.concat(
[footprints.array, new_footprint.array], dim=AXIS.component_dim
)
return footprints
6 changes: 4 additions & 2 deletions src/cala/nodes/overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ def ingest_component(
if new_footprints.array is None:
return overlaps

elif overlaps.array is None:
elif overlaps.array is None or overlaps.array.size == 1:
overlaps.array = initialize(footprints).array
return overlaps

A = footprints.array
a_new = new_footprints.array
a_new = new_footprints.array.volumize.dim_with_coords(
dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.confidence_coord]
)

# Compute spatial overlaps between new and existing components
bottom_left_overlap = A @ a_new.rename(AXIS.component_rename)
Expand Down
9 changes: 7 additions & 2 deletions src/cala/nodes/pixel_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ def ingest_component(
# (1/t)Y_buf c_new^T
new_stats = scale * (frames.array @ new_trace.array)

# Concatenate with existing pixel stats along component axis
pixel_stats.array = xr.concat([pixel_stats.array, new_stats], dim=AXIS.component_dim)
if new_stats[AXIS.id_coord].item() in pixel_stats.array[AXIS.id_coord].values:
pixel_stats.array.set_xindex(AXIS.id_coord).loc[
{AXIS.id_coord: new_stats[AXIS.id_coord].item()}
] = new_stats
else:
# Concatenate with existing pixel stats along component axis
pixel_stats.array = xr.concat([pixel_stats.array, new_stats], dim=AXIS.component_dim)

return pixel_stats
1 change: 0 additions & 1 deletion src/cala/testing/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,5 @@ def single_cell_source(
cell_radii=cell_radii,
cell_positions=positions,
cell_traces=traces,
confidences=[],
)
return toy.movie_gen()
71 changes: 50 additions & 21 deletions src/cala/testing/toy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import Generator, Iterable
from datetime import datetime, timedelta
from typing import Self

import numpy as np
import xarray as xr
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator
from skimage.morphology import disk

from cala.assets import Footprints, Frame, Movie, Traces
Expand Down Expand Up @@ -51,40 +52,68 @@ class Toy(BaseModel):
cell_radii: int | list[int]
cell_positions: list[Position]
cell_traces: list[np.ndarray]
cell_ids: list[str] | None = None
cell_ids: list[str]
"""If none, auto populated as cell_{idx}."""
confidences: list[float] = Field(default_factory=list)
confidences: list[float]

_footprints: xr.DataArray = PrivateAttr(init=False)
_traces: xr.DataArray = PrivateAttr(init=False)

model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

def model_post_init(self, __context: None = None) -> None:
assert self.n_frames > 0
@field_validator("n_frames", mode="after")
@classmethod
def natural_num(cls, value: int) -> int:
assert value >= 0, "n_frames must be positive."
return value

@model_validator(mode="before")
def fill_ids(self) -> Self:
if self.get("cell_ids", None) is None:
self["cell_ids"] = [f"cell_{idx}" for idx, _ in enumerate(self["cell_positions"])]
return self

@model_validator(mode="before")
def fill_confidences(self) -> Self:
if self.get("confidences", None) is None:
self["confidences"] = [0.0] * len(self["cell_positions"])
return self

@model_validator(mode="before")
def fill_radii(self) -> Self:
self["cell_radii"] = (
[self["cell_radii"]] * len(self["cell_positions"])
if isinstance(self["cell_radii"], int)
else self["cell_radii"]
)
return self

self.cell_radii = (
[self.cell_radii] * len(self.cell_positions)
if isinstance(self.cell_radii, int)
else self.cell_radii
@model_validator(mode="after")
def consistent_n_cells(self) -> Self:
for cell_trace in self.cell_traces:
assert self.n_frames == len(
cell_trace
), "inconsistent n_frames between n_frames and cell_traces"
return self

@model_validator(mode="after")
def consistent_n_frames(self) -> Self:
assert len(self.cell_positions) == len(self.cell_traces), (
f"inconsistent cell counts. "
f"positions: {len(self.cell_positions)}, "
f"traces: {len(self.cell_traces)}"
)
return self

@model_validator(mode="after")
def cells_within_bounds(self) -> Self:
for position, radius in zip(self.cell_positions, self.cell_radii):
assert np.min([position.width, position.height]) - radius > 0
assert position.width + radius < self.frame_dims.width
assert position.height + radius < self.frame_dims.height
return self

assert len(self.cell_positions) == len(self.cell_traces)

for cell_trace in self.cell_traces:
assert self.n_frames == len(cell_trace)

if self.cell_ids is None:
self.cell_ids = [f"cell_{idx}" for idx, _ in enumerate(self.cell_positions)]
assert len(self.cell_ids) == len(self.cell_traces)

if not self.confidences:
self.confidences = [0.0] * len(self.cell_ids)

def model_post_init(self, __context: None = None) -> None:
self._footprints = self._build_footprints()
self._traces = self._build_traces()

Expand Down
7 changes: 6 additions & 1 deletion tests/data/pipelines/odl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,13 @@ nodes:
type: cala.nodes.footprints.Footprinter
params:
bep: 1
tol: 1e-7
tol: 0.0000001
depends:
- footprints: assets.footprints
- pixel_stats: pix_component.value
- component_stats: comp_component.value

return:
type: return
depends:
- motion.frame
4 changes: 2 additions & 2 deletions tests/test_iter/test_component_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from noob.node import Node, NodeSpecification

from cala.assets import Frame, PopSnap, Traces
from cala.assets import Frame, PopSnap, Trace, Traces
from cala.models import AXIS


Expand Down Expand Up @@ -78,7 +78,7 @@ def test_ingest_component(init, comp_update, separate_cells):
traces=Traces.from_array(
separate_cells.traces.array.isel({AXIS.component_dim: slice(None, -1)})
),
new_trace=Traces.from_array(separate_cells.traces.array.isel({AXIS.component_dim: [-1]})),
new_trace=Trace.from_array(separate_cells.traces.array.isel({AXIS.component_dim: -1})),
)

expected = init.process(separate_cells.traces)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_iter/test_footprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@ def test_boundary_morph(xpander, defect, toy, request):

# expansion breaks when a trace is all-zero and overlaps with another component.
# we don't know why. all-zero trace is somewhat unlikely, but we probably need a solution.
xr.testing.assert_allclose(result.array, toy.footprints.array, rtol=1e-3)
xr.testing.assert_allclose(result.array, toy.footprints.array, atol=1e-3)
4 changes: 2 additions & 2 deletions tests/test_iter/test_overlaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from noob.node import Node, NodeSpecification

from cala.assets import Footprints
from cala.assets import Footprint, Footprints
from cala.models import AXIS


Expand Down Expand Up @@ -37,7 +37,7 @@ def comp_update() -> Node:
def test_ingest_component(init, comp_update, toy, request) -> None:
toy = request.getfixturevalue(toy)
base = Footprints.from_array(toy.footprints.array.isel({AXIS.component_dim: slice(None, -1)}))
new = Footprints.from_array(toy.footprints.array.isel({AXIS.component_dim: [-1]}))
new = Footprint.from_array(toy.footprints.array.isel({AXIS.component_dim: -1}))

pre_ingest = init.process(footprints=base)

Expand Down
17 changes: 13 additions & 4 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pytest
import xarray as xr
from noob import Cube, SynchronousRunner, Tube

from cala.models import AXIS


@pytest.fixture
def odl_tube():
Expand All @@ -25,13 +28,19 @@ def test_process(odl_runner) -> None:
assert odl_runner.cube.assets["buffer"].obj.array.size > 0


@pytest.mark.xfail
def test_iter(odl_runner) -> None:
gen = odl_runner.iter()
gen = odl_runner.iter(n=30)

result = next(gen)
movie = []
for _, exp in enumerate(gen):
movie.append(exp[0].array)
fps = odl_runner.cube.assets["footprints"].obj
trs = odl_runner.cube.assets["traces"].obj

assert result
expected = xr.concat(movie, dim=AXIS.frames_dim)
result = (fps.array @ trs.array).transpose(*expected.dims)

xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5)


@pytest.mark.xfail
Expand Down
Loading