Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/cala/nodes/detect/slice_nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


class SliceNMF(Node):
cell_radius: int
nmf_kwargs: dict[str, Any] = Field(default_factory=dict)

errors_: list[float] = Field(default_factory=list)
Expand All @@ -28,7 +27,7 @@ def model_post_init(self, context: Any, /) -> None:
self._model = NMF(**self.nmf_kwargs)

def process(
self, residuals: Residual, energy: xr.DataArray
self, residuals: Residual, energy: xr.DataArray, detect_radius: int
) -> tuple[A[list[Footprint], Name("new_fps")], A[list[Trace], Name("new_trs")]]:
residuals = residuals.array.copy()

Expand All @@ -38,7 +37,9 @@ def process(
if energy.size > 1:
while np.sqrt(energy.max()).item() > self.nmf_kwargs["tol"]:
# Find and analyze neighborhood of maximum variance
slice_ = self._get_max_energy_slice(arr=residuals, energy_landscape=energy)
slice_ = self._get_max_energy_slice(
arr=residuals, energy_landscape=energy, radius=detect_radius
)

a_new, c_new = self._local_nmf(
slice_=slice_,
Expand Down Expand Up @@ -66,13 +67,13 @@ def _get_max_energy_slice(
self,
arr: xr.DataArray,
energy_landscape: xr.DataArray,
radius: int,
) -> xr.DataArray:
"""Find neighborhood around point of maximum variance."""
# Find maximum point
max_coords = energy_landscape.argmax(dim=AXIS.spatial_dims)

# Define neighborhood
radius = int(np.round(self.cell_radius))
window = {
ax: slice(
max(0, pos.values - radius),
Expand Down
3 changes: 2 additions & 1 deletion src/cala/nodes/prep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .background_removal import remove_background
from .denoise import denoise
from .glow_removal import GlowRemover
from .r_estimate import SizeEst
from .rigid_stabilization import RigidStabilizer

__all__ = [denoise, GlowRemover, remove_background, RigidStabilizer]
__all__ = [denoise, GlowRemover, remove_background, RigidStabilizer, SizeEst]
39 changes: 39 additions & 0 deletions src/cala/nodes/prep/r_estimate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Annotated as A
from typing import Any

import numpy as np
from noob import Name, process_method
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from skimage.feature import blob_log

from cala.assets import Frame
from cala.models import AXIS


class SizeEst(BaseModel):
hardset_radius: int | None = None
"""if this is set, no learning occurs."""
n_frames: int | None = None
"""how many first n frames to learn from. if none, keep learning forever"""
log_kwargs: dict[str, Any] = Field(default_factory=dict)

sizes_: list[float] = Field(default_factory=list)
centers_: list[np.ndarray] = Field(default_factory=list)
_est_radius: int = PrivateAttr(None)

model_config = ConfigDict(arbitrary_types_allowed=True)

@process_method
def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]:
if self.hardset_radius:
return self.hardset_radius

if self.n_frames and self.n_frames < frame.array[AXIS.frame_coord]:
return self._est_radius

blobs = blob_log(frame.array, **self.log_kwargs)
self.centers_ = [blobs[:-1] for blobs in blobs]
self.sizes_ += [blob[-1].item() for blob in blobs]
self._est_radius = int(np.round(np.median(self.sizes_)).item())

return self._est_radius
2 changes: 1 addition & 1 deletion src/cala/testing/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def two_cells_source(
positions: list[dict] = None,
) -> Generator[A[Frame, Name("frame")]]:
frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims)
traces = [np.array(range(0, n_frames)), np.array([0, *range(30 - 1, 0, -1)])]
traces = [np.array(range(0, n_frames)), np.array([0, *range(n_frames - 1, 0, -1)])]
if positions is None:
positions = [Position(width=206, height=206), Position(width=306, height=306)]
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
noob_id: cala-two-overlap-cells
noob_id: cala-odl
noob_model: noob.cube.CubeSpecification
noob_version: 0.1.1.dev118+g64d81b7

Expand Down Expand Up @@ -30,9 +30,9 @@ assets:

nodes:
source:
type: cala.testing.two_overlapping_source
type: cala.testing.single_cell_source
params:
n_frames: 50
n_frames: 30
denoise:
type: cala.nodes.prep.denoise
params:
Expand All @@ -52,6 +52,12 @@ nodes:
drift_speed: 1.0
depends:
- frame: glow.frame
size_est:
type: cala.nodes.prep.SizeEst
params:
hardset_radius: 10
depends:
- frame: motion.frame
cache:
type: cala.nodes.buffer.fill_buffer
params:
Expand Down Expand Up @@ -121,11 +127,10 @@ nodes:
- trigger: trace_frame.latest_trace
nmf:
type: cala.nodes.detect.SliceNMF
params:
cell_radius: 10
depends:
- residuals: residual.movie
- energy: energy.energy
- detect_radius: size_est.radius
catalog:
type: cala.nodes.detect.Cataloger
params:
Expand Down
165 changes: 0 additions & 165 deletions tests/data/pipelines/single_cell.yaml

This file was deleted.

Loading
Loading