Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/cala/nodes/remove.py β†’ src/cala/nodes/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,37 @@
from cala.models import AXIS


def get_razed_ids(
footprints: Footprints, min_thicc: int, trigger: bool
) -> A[xr.DataArray, Name("keep_ids")]:
def purge_razed_components(
footprints: Footprints,
traces: Traces,
pix_stats: PixStats,
comp_stats: CompStats,
overlaps: Overlaps,
min_thicc: int,
trigger: bool,
) -> tuple[
A[Footprints, Name("footprints")],
A[Traces, Name("traces")],
A[PixStats, Name("pix_stats")],
A[CompStats, Name("comp_stats")],
A[Overlaps, Name("overlaps")],
]:
keep_ids = _get_razed_ids(footprints=footprints, min_thicc=min_thicc)
return filter_components(
footprints=footprints,
traces=traces,
pix_stats=pix_stats,
comp_stats=comp_stats,
overlaps=overlaps,
keep_ids=keep_ids,
)


def _get_razed_ids(footprints: Footprints, min_thicc: int) -> A[xr.DataArray, Name("keep_ids")]:
"""
:param min_thicc: minimum number of pixel thickness to keep the cell
:return:
"""
A = footprints.array

if A is None:
Expand Down
3 changes: 2 additions & 1 deletion src/cala/nodes/component_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def ingest_component(component_stats: CompStats, traces: Traces, new_traces: Tra

if merged_ids:
M = (
M.set_xindex([AXIS.id_coord, f"{AXIS.id_coord}'"])
M.set_xindex(AXIS.id_coord)
.set_xindex(f"{AXIS.id_coord}'")
.sel({AXIS.id_coord: intact_ids, f"{AXIS.id_coord}'": intact_ids})
.reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"])
)
Expand Down
3 changes: 1 addition & 2 deletions src/cala/nodes/detect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .catalog import Cataloger
from .energy import Energy
from .slice_nmf import SliceNMF

__all__ = [Energy, SliceNMF, Cataloger]
__all__ = [SliceNMF, Cataloger]
10 changes: 5 additions & 5 deletions src/cala/nodes/detect/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def process(
new_fps = xr.concat([fp.array for fp in new_fps], dim=AXIS.component_dim)
new_trs = xr.concat([tr.array for tr in new_trs], dim=AXIS.component_dim)

conn_mat = self._connection_matrix(new_fps, new_trs)
num, label = connected_components(conn_mat)
merge_mat = self._merge_matrix(new_fps, new_trs)
num, label = connected_components(merge_mat)
combined_fps = []
combined_trs = []

Expand All @@ -52,12 +52,12 @@ def process(
new_fps = xr.concat([fp.array for fp in combined_fps], dim=AXIS.component_dim)
new_trs = xr.concat([tr.array for tr in combined_trs], dim=AXIS.component_dim)

conn_mat = self._connection_matrix(new_fps, new_trs, existing_fp, existing_tr)
merge_mat = self._merge_matrix(new_fps, new_trs, existing_fp, existing_tr)
footprints = []
traces = []

# we're not doing connected components because it's not square matrix
for i, dupes in enumerate(conn_mat.transpose(AXIS.component_dim, ...)):
for i, dupes in enumerate(merge_mat.transpose(AXIS.component_dim, ...)):
if not any(dupes) or existing_fp is None or existing_tr is None:
footprint, trace = self._register(new_fps[i], new_trs[i])
else:
Expand Down Expand Up @@ -197,7 +197,7 @@ def _reshape(

return Footprint.from_array(a_new), Trace.from_array(c_new)

def _connection_matrix(
def _merge_matrix(
self,
fps: xr.DataArray,
trs: xr.DataArray,
Expand Down
52 changes: 0 additions & 52 deletions src/cala/nodes/detect/energy.py

This file was deleted.

91 changes: 51 additions & 40 deletions src/cala/nodes/detect/slice_nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,58 +10,70 @@
from sklearn.decomposition import NMF

from cala.assets import Footprint, Residual, Trace
from cala.logging import init_logger
from cala.models import AXIS


class SliceNMF(Node):
min_frames: int
"""Wait until this number of frames to begin detecting."""
detect_thresh: float
"""Minimum detection threshold for brightness fluctuation."""
nmf_kwargs: dict[str, Any] = Field(default_factory=dict)

errors_: list[float] = Field(default_factory=list)
error_: float = Field(None)
_model: NMF = PrivateAttr(None)

_logger = init_logger(__name__)

def model_post_init(self, context: Any, /) -> None:
self.nmf_kwargs.update({"n_components": 1, "init": "nndsvd"})
if not self.nmf_kwargs.get("tol", None):
self.nmf_kwargs["tol"] = 1e-4

self._model = NMF(**self.nmf_kwargs)

def process(
self, residuals: Residual, energy: xr.DataArray, detect_radius: int
self, residuals: Residual, detect_radius: int
) -> tuple[A[list[Footprint], Name("new_fps")], A[list[Trace], Name("new_trs")]]:
residuals = residuals.array.copy()

fpts = []
trcs = []

if energy.size > 1:
while np.sqrt(energy.max()).item() > self.nmf_kwargs["tol"]:
# Find and analyze neighborhood of maximum variance
slice_ = self._get_max_energy_slice(
arr=residuals, energy_landscape=energy, radius=detect_radius
)

a_new, c_new = self._local_nmf(
slice_=slice_,
spatial_sizes={
k: v for k, v in residuals.sizes.items() if k in AXIS.spatial_dims
},
)

l1_norm = slice_.sum().item()
comp_recon = a_new @ c_new
shift = (comp_recon).median(dim=AXIS.frames_dim)
comp_energy = ((comp_recon - shift) ** 2).sum(dim=AXIS.frames_dim)
energy -= comp_energy

if (self.errors_[-1] / l1_norm) <= self.nmf_kwargs["tol"]:
fpts.append(Footprint.from_array(a_new))
trcs.append(Trace.from_array(c_new))
residuals = (residuals - a_new @ c_new).clip(0)
else:
energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0
residuals.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0
return fpts, trcs
res = residuals.array.copy()

if res.sizes[AXIS.frames_dim] < self.min_frames:
return [], []

energy = self._get_energy(res)

fps = []
trs = []

while np.sqrt(energy.max()).item() > self.detect_thresh: # or use res directly
# Find and analyze neighborhood of maximum variance
slice_ = self._get_max_energy_slice(
arr=res, energy_landscape=energy, radius=detect_radius
)

a_new, c_new = self._local_nmf(
slice_=slice_,
spatial_sizes={k: v for k, v in res.sizes.items() if k in AXIS.spatial_dims},
)

l1_norm = slice_.sum().item()
# l0_norm = np.prod(slice_.shape) # this fails when the residuals are tiny
comp_recon = a_new @ c_new

energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0

if (self.error_ / l1_norm) <= self._model.tol:
fps.append(Footprint.from_array(a_new))
trs.append(Trace.from_array(c_new))
res = (res - comp_recon).clip(0)
else:
res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0

return fps, trs

def _get_energy(self, res: xr.DataArray) -> xr.DataArray:
pixels_median = res.median(dim=AXIS.frames_dim)
V = res - pixels_median

return (V**2).sum(dim=AXIS.frames_dim)

def _get_max_energy_slice(
self,
Expand Down Expand Up @@ -114,8 +126,7 @@ def _local_nmf(
c = self._model.fit_transform(R) # temporal component
a = self._model.components_ # spatial component

err = self._model.reconstruction_err_.item()
self.errors_.append(err)
self.error_ = self._model.reconstruction_err_.item()

# Convert back to xarray with proper dimensions and coordinates
c_new = xr.DataArray(
Expand Down
25 changes: 7 additions & 18 deletions src/cala/nodes/overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,17 @@
from cala.models import AXIS


def initialize(
footprints: Footprints,
) -> Overlaps:
"""
Sparse matrix of component footprint overlaps.

Args:
footprints (Footprints): Current temporal component c_t.
"""
def initialize(overlaps: Overlaps, footprints: Footprints) -> Overlaps:
A = footprints.array

# Use matrix multiplication with broadcasting to compute overlaps
data = (A @ A.rename(AXIS.component_rename)) > 0
if A is None:
return overlaps

return Overlaps.from_array(data)
V = (A @ A.rename(AXIS.component_rename)) > 0

overlaps.array = V

def ingest_frame(overlaps: Overlaps, footprints: Footprints) -> Overlaps:
if footprints.array is None:
return overlaps
return initialize(footprints)
return overlaps


def ingest_component(
Expand All @@ -44,8 +34,7 @@ def ingest_component(
return overlaps

elif overlaps.array is None or overlaps.array.size == 1:
overlaps.array = initialize(footprints).array
return overlaps
return initialize(overlaps, footprints)

V = overlaps.array

Expand Down
6 changes: 3 additions & 3 deletions src/cala/nodes/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces:
:return:
"""

if not new_traces:
return traces

c = traces.array
c_det = new_traces.array

if c_det is None:
return traces

if c is None:
traces.array = c_det
return traces
Expand Down
27 changes: 7 additions & 20 deletions tests/data/pipelines/odl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ nodes:
- footprints: assets.footprints
- frame: motion.frame
- overlaps: assets.overlaps

pix_frame:
type: cala.nodes.pixel_stats.ingest_frame
depends:
Expand All @@ -98,38 +97,26 @@ nodes:
- frames: assets.buffer
- footprints: assets.footprints
- traces: assets.traces

over_est:
type: cala.nodes.remove.get_razed_ids
cleanup:
type: cala.nodes.cleanup.purge_razed_components
params:
min_thicc: 3
depends:
- footprints: assets.footprints
- trigger: residual.movie

clean:
type: cala.nodes.remove.filter_components
depends:
- footprints: assets.footprints
- traces: assets.traces
- pix_stats: assets.pix_stats
- comp_stats: assets.comp_stats
- overlaps: assets.overlaps
- keep_ids: over_est.keep_ids
- trigger: residual.movie

# DETECT BEGINS
energy:
type: cala.nodes.detect.Energy
params:
min_frames: 10
depends:
- residuals: residual.movie
- trigger: trace_frame.latest_trace
nmf:
type: cala.nodes.detect.SliceNMF
params:
min_frames: 10
detect_thresh: 1.0
depends:
- residuals: residual.movie
- energy: energy.energy
- detect_radius: size_est.radius
catalog:
type: cala.nodes.detect.Cataloger
Expand Down Expand Up @@ -177,7 +164,7 @@ nodes:
- component_stats: comp_component.value

overlaps_update:
type: cala.nodes.overlap.ingest_frame
type: cala.nodes.overlap.initialize
depends:
- overlaps: assets.overlaps
- footprints: footprints_frame.footprints
Expand Down
Loading
Loading