diff --git a/src/cala/models/axis.py b/src/cala/models/axis.py index 2c753096..ace438ad 100644 --- a/src/cala/models/axis.py +++ b/src/cala/models/axis.py @@ -61,7 +61,7 @@ class Coords(Enum): height = Coord(name=AXIS.height_coord, dtype=int, checks=[is_unique]) width = Coord(name=AXIS.width_coord, dtype=int, checks=[is_unique]) frame = Coord(name=AXIS.frame_coord, dtype=int, checks=[is_unique]) - timestamp = Coord(name=AXIS.timestamp_coord, dtype=str) + timestamp = Coord(name=AXIS.timestamp_coord, dtype=str, checks=[is_unique]) confidence = Coord(name=AXIS.confidence_coord, dtype=float, checks=[is_unit_interval]) diff --git a/src/cala/nodes/component_stats.py b/src/cala/nodes/component_stats.py index faa00208..72f9f0bf 100644 --- a/src/cala/nodes/component_stats.py +++ b/src/cala/nodes/component_stats.py @@ -1,3 +1,4 @@ +import numpy as np import xarray as xr from cala.assets import CompStats, Frame, PopSnap, Trace, Traces @@ -89,16 +90,24 @@ def ingest_component(component_stats: CompStats, traces: Traces, new_trace: Trac if new_trace.array is None: return component_stats - if component_stats.array is None: - component_stats.array = initialize(traces).array - return component_stats - # Get current frame index (starting with 1) t = new_trace.array[AXIS.frame_coord].max().item() + 1 - M = component_stats.array - c_new = new_trace.array + c_new = new_trace.array.volumize.dim_with_coords( + dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.confidence_coord] + ) c_buf = traces.array + M = component_stats.array + + if M is None or M.size == 1: + component_stats.array = initialize(traces).array + return component_stats + + if c_new[AXIS.id_coord].item() in M[AXIS.id_coord].values: + dim_idx = np.where(M[AXIS.id_coord].values == c_new[AXIS.id_coord].item())[0].tolist() + M = M.drop_sel({AXIS.component_dim: dim_idx, f"{AXIS.component_dim}": dim_idx}) + + # think i also have to remove the ID from c_buf? # Compute cross-correlation between buffer and new components # C_buf^T c_new diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index 954a6130..c3e45aab 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -1,10 +1,12 @@ from collections.abc import Hashable, Mapping from typing import Annotated as A +from typing import Any import numpy as np import xarray as xr from noob import Name from noob.node import Node +from pydantic import Field from sklearn.decomposition import NMF from cala.assets import Footprint, Residual, Trace @@ -13,30 +15,34 @@ class SliceNMF(Node): cell_radius: int + nmf_kwargs: dict[str, Any] = Field(default_factory=dict) validity_threshold: float + def model_post_init(self, context: Any, /) -> None: + self.nmf_kwargs.update({"n_components": 1, "init": "nndsvd"}) + if not self.nmf_kwargs.get("tol", None): + self.nmf_kwargs["tol"] = 1e-4 + def process( self, residuals: Residual, energy: xr.DataArray - ) -> tuple[A[Footprint | None, Name("new_fp")], A[Trace | None, Name("new_tr")]]: - if energy.size == 1: - return Footprint(), Trace() - - # Find and analyze neighborhood of maximum variance + ) -> tuple[A[Footprint, Name("new_fp")], A[Trace, Name("new_tr")]]: residuals = residuals.array - slice_ = self._get_max_energy_slice(arr=residuals, energy_landscape=energy) + if energy.size > 1 and residuals.max().item() > self.nmf_kwargs["tol"]: + # Find and analyze neighborhood of maximum variance + slice_ = self._get_max_energy_slice(arr=residuals, energy_landscape=energy) - a_new, c_new = self._local_nmf( - slice_=slice_, - spatial_sizes={k: v for k, v in residuals.sizes.items() if k in AXIS.spatial_dims}, - ) + a_new, c_new = self._local_nmf( + slice_=slice_, + spatial_sizes={k: v for k, v in residuals.sizes.items() if k in AXIS.spatial_dims}, + ) - # eventually we should just log this value instead of throwing out the component - # otherwise we keep coming back to this energy max point - if self._check_validity(a_new, residuals): - return Footprint.from_array(a_new), Trace.from_array(c_new) - else: - return None, None + # eventually we should just log this value instead of throwing out the component + # otherwise we keep coming back to this energy max point + if self._check_validity(a_new, residuals): + return Footprint.from_array(a_new), Trace.from_array(c_new) + + return Footprint(), Trace() def _get_max_energy_slice( self, @@ -87,10 +93,10 @@ def _local_nmf( R = slice_.stack(space=AXIS.spatial_dims).transpose(AXIS.frames_dim, "space") # Apply NMF (check how long nndsvd takes vs random) - model = NMF(n_components=1, init="nndsvd", tol=1e-4, max_iter=200) + model = NMF(**self.nmf_kwargs) # when residual is negative, the error becomes massive... - c = model.fit_transform(R.clip(0)) # temporal component + c = model.fit_transform(R) # temporal component a = model.components_ # spatial component # Convert back to xarray with proper dimensions and coordinates @@ -122,6 +128,10 @@ def _local_nmf( return a_new, c_new def _check_validity(self, a_new: xr.DataArray, residuals: xr.DataArray) -> bool: + """ + Think this is redundant with NMF.reconstruction_err_ + """ + # not sure if this step is necessary or even makes sense # how would a rank-1 nmf be not similar to the mean, unless the nmf error was massive? # and if the error is big, maybe it just means it's partially overlapping with another diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index 35000749..e249f4a4 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -101,5 +101,14 @@ def ingest_component(footprints: Footprints, new_footprint: Footprint | Footprin ) return footprints - footprints.array = xr.concat([footprints.array, new_footprint.array], dim=AXIS.component_dim) + if new_footprint.array[AXIS.id_coord].item() in footprints.array[AXIS.id_coord].values: + # if replacing (post-merging in catalog) + footprints.array.set_xindex(AXIS.id_coord).loc[ + {AXIS.id_coord: new_footprint.array[AXIS.id_coord].item()} + ] = new_footprint.array + else: + # if new + footprints.array = xr.concat( + [footprints.array, new_footprint.array], dim=AXIS.component_dim + ) return footprints diff --git a/src/cala/nodes/overlap.py b/src/cala/nodes/overlap.py index 9bf24aca..37453ad8 100644 --- a/src/cala/nodes/overlap.py +++ b/src/cala/nodes/overlap.py @@ -40,12 +40,14 @@ def ingest_component( if new_footprints.array is None: return overlaps - elif overlaps.array is None: + elif overlaps.array is None or overlaps.array.size == 1: overlaps.array = initialize(footprints).array return overlaps A = footprints.array - a_new = new_footprints.array + a_new = new_footprints.array.volumize.dim_with_coords( + dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.confidence_coord] + ) # Compute spatial overlaps between new and existing components bottom_left_overlap = A @ a_new.rename(AXIS.component_rename) diff --git a/src/cala/nodes/pixel_stats.py b/src/cala/nodes/pixel_stats.py index f615d187..25c9cc61 100644 --- a/src/cala/nodes/pixel_stats.py +++ b/src/cala/nodes/pixel_stats.py @@ -119,7 +119,12 @@ def ingest_component( # (1/t)Y_buf c_new^T new_stats = scale * (frames.array @ new_trace.array) - # Concatenate with existing pixel stats along component axis - pixel_stats.array = xr.concat([pixel_stats.array, new_stats], dim=AXIS.component_dim) + if new_stats[AXIS.id_coord].item() in pixel_stats.array[AXIS.id_coord].values: + pixel_stats.array.set_xindex(AXIS.id_coord).loc[ + {AXIS.id_coord: new_stats[AXIS.id_coord].item()} + ] = new_stats + else: + # Concatenate with existing pixel stats along component axis + pixel_stats.array = xr.concat([pixel_stats.array, new_stats], dim=AXIS.component_dim) return pixel_stats diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 4df4f25d..97a3b5c7 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -46,6 +46,5 @@ def single_cell_source( cell_radii=cell_radii, cell_positions=positions, cell_traces=traces, - confidences=[], ) return toy.movie_gen() diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index be089b62..b5c48d7c 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -1,9 +1,10 @@ from collections.abc import Generator, Iterable from datetime import datetime, timedelta +from typing import Self import numpy as np import xarray as xr -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator from skimage.morphology import disk from cala.assets import Footprints, Frame, Movie, Traces @@ -51,40 +52,68 @@ class Toy(BaseModel): cell_radii: int | list[int] cell_positions: list[Position] cell_traces: list[np.ndarray] - cell_ids: list[str] | None = None + cell_ids: list[str] """If none, auto populated as cell_{idx}.""" - confidences: list[float] = Field(default_factory=list) + confidences: list[float] _footprints: xr.DataArray = PrivateAttr(init=False) _traces: xr.DataArray = PrivateAttr(init=False) model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) - def model_post_init(self, __context: None = None) -> None: - assert self.n_frames > 0 + @field_validator("n_frames", mode="after") + @classmethod + def natural_num(cls, value: int) -> int: + assert value >= 0, "n_frames must be positive." + return value + + @model_validator(mode="before") + def fill_ids(self) -> Self: + if self.get("cell_ids", None) is None: + self["cell_ids"] = [f"cell_{idx}" for idx, _ in enumerate(self["cell_positions"])] + return self + + @model_validator(mode="before") + def fill_confidences(self) -> Self: + if self.get("confidences", None) is None: + self["confidences"] = [0.0] * len(self["cell_positions"]) + return self + + @model_validator(mode="before") + def fill_radii(self) -> Self: + self["cell_radii"] = ( + [self["cell_radii"]] * len(self["cell_positions"]) + if isinstance(self["cell_radii"], int) + else self["cell_radii"] + ) + return self - self.cell_radii = ( - [self.cell_radii] * len(self.cell_positions) - if isinstance(self.cell_radii, int) - else self.cell_radii + @model_validator(mode="after") + def consistent_n_cells(self) -> Self: + for cell_trace in self.cell_traces: + assert self.n_frames == len( + cell_trace + ), "inconsistent n_frames between n_frames and cell_traces" + return self + + @model_validator(mode="after") + def consistent_n_frames(self) -> Self: + assert len(self.cell_positions) == len(self.cell_traces), ( + f"inconsistent cell counts. " + f"positions: {len(self.cell_positions)}, " + f"traces: {len(self.cell_traces)}" ) + return self + + @model_validator(mode="after") + def cells_within_bounds(self) -> Self: for position, radius in zip(self.cell_positions, self.cell_radii): assert np.min([position.width, position.height]) - radius > 0 assert position.width + radius < self.frame_dims.width assert position.height + radius < self.frame_dims.height + return self - assert len(self.cell_positions) == len(self.cell_traces) - - for cell_trace in self.cell_traces: - assert self.n_frames == len(cell_trace) - - if self.cell_ids is None: - self.cell_ids = [f"cell_{idx}" for idx, _ in enumerate(self.cell_positions)] - assert len(self.cell_ids) == len(self.cell_traces) - - if not self.confidences: - self.confidences = [0.0] * len(self.cell_ids) - + def model_post_init(self, __context: None = None) -> None: self._footprints = self._build_footprints() self._traces = self._build_traces() diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index b843add4..4890d6d7 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -164,8 +164,13 @@ nodes: type: cala.nodes.footprints.Footprinter params: bep: 1 - tol: 1e-7 + tol: 0.0000001 depends: - footprints: assets.footprints - pixel_stats: pix_component.value - component_stats: comp_component.value + + return: + type: return + depends: + - motion.frame \ No newline at end of file diff --git a/tests/test_iter/test_component_stats.py b/tests/test_iter/test_component_stats.py index 54ccfb50..bc8d5a86 100644 --- a/tests/test_iter/test_component_stats.py +++ b/tests/test_iter/test_component_stats.py @@ -2,7 +2,7 @@ import pytest from noob.node import Node, NodeSpecification -from cala.assets import Frame, PopSnap, Traces +from cala.assets import Frame, PopSnap, Trace, Traces from cala.models import AXIS @@ -78,7 +78,7 @@ def test_ingest_component(init, comp_update, separate_cells): traces=Traces.from_array( separate_cells.traces.array.isel({AXIS.component_dim: slice(None, -1)}) ), - new_trace=Traces.from_array(separate_cells.traces.array.isel({AXIS.component_dim: [-1]})), + new_trace=Trace.from_array(separate_cells.traces.array.isel({AXIS.component_dim: -1})), ) expected = init.process(separate_cells.traces) diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index aee8245c..0fe9e717 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -151,4 +151,4 @@ def test_boundary_morph(xpander, defect, toy, request): # expansion breaks when a trace is all-zero and overlaps with another component. # we don't know why. all-zero trace is somewhat unlikely, but we probably need a solution. - xr.testing.assert_allclose(result.array, toy.footprints.array, rtol=1e-3) + xr.testing.assert_allclose(result.array, toy.footprints.array, atol=1e-3) diff --git a/tests/test_iter/test_overlaps.py b/tests/test_iter/test_overlaps.py index e5f7a23b..7a77c11b 100644 --- a/tests/test_iter/test_overlaps.py +++ b/tests/test_iter/test_overlaps.py @@ -2,7 +2,7 @@ import pytest from noob.node import Node, NodeSpecification -from cala.assets import Footprints +from cala.assets import Footprint, Footprints from cala.models import AXIS @@ -37,7 +37,7 @@ def comp_update() -> Node: def test_ingest_component(init, comp_update, toy, request) -> None: toy = request.getfixturevalue(toy) base = Footprints.from_array(toy.footprints.array.isel({AXIS.component_dim: slice(None, -1)})) - new = Footprints.from_array(toy.footprints.array.isel({AXIS.component_dim: [-1]})) + new = Footprint.from_array(toy.footprints.array.isel({AXIS.component_dim: -1})) pre_ingest = init.process(footprints=base) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 64480f14..3df2c6b1 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,6 +1,9 @@ import pytest +import xarray as xr from noob import Cube, SynchronousRunner, Tube +from cala.models import AXIS + @pytest.fixture def odl_tube(): @@ -25,13 +28,19 @@ def test_process(odl_runner) -> None: assert odl_runner.cube.assets["buffer"].obj.array.size > 0 -@pytest.mark.xfail def test_iter(odl_runner) -> None: - gen = odl_runner.iter() + gen = odl_runner.iter(n=30) - result = next(gen) + movie = [] + for _, exp in enumerate(gen): + movie.append(exp[0].array) + fps = odl_runner.cube.assets["footprints"].obj + trs = odl_runner.cube.assets["traces"].obj - assert result + expected = xr.concat(movie, dim=AXIS.frames_dim) + result = (fps.array @ trs.array).transpose(*expected.dims) + + xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) @pytest.mark.xfail