From a213c890501d75164264e50bad1d71ecb9965d93 Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 14 Aug 2025 13:21:58 -0700 Subject: [PATCH 1/3] feat: learn cell radius in detection --- src/cala/nodes/detect/slice_nmf.py | 9 +-- src/cala/nodes/prep/__init__.py | 3 +- src/cala/nodes/prep/r_estimate.py | 41 +++++++++++ tests/test_iter/test_detect.py | 109 +++++++++++------------------ tests/test_prep/test_r_estimate.py | 38 ++++++++++ 5 files changed, 128 insertions(+), 72 deletions(-) create mode 100644 src/cala/nodes/prep/r_estimate.py create mode 100644 tests/test_prep/test_r_estimate.py diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index d54e1346..ffda4a25 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -14,7 +14,6 @@ class SliceNMF(Node): - cell_radius: int nmf_kwargs: dict[str, Any] = Field(default_factory=dict) errors_: list[float] = Field(default_factory=list) @@ -28,7 +27,7 @@ def model_post_init(self, context: Any, /) -> None: self._model = NMF(**self.nmf_kwargs) def process( - self, residuals: Residual, energy: xr.DataArray + self, residuals: Residual, energy: xr.DataArray, detect_radius: int ) -> tuple[A[list[Footprint], Name("new_fps")], A[list[Trace], Name("new_trs")]]: residuals = residuals.array.copy() @@ -38,7 +37,9 @@ def process( if energy.size > 1: while np.sqrt(energy.max()).item() > self.nmf_kwargs["tol"]: # Find and analyze neighborhood of maximum variance - slice_ = self._get_max_energy_slice(arr=residuals, energy_landscape=energy) + slice_ = self._get_max_energy_slice( + arr=residuals, energy_landscape=energy, radius=detect_radius + ) a_new, c_new = self._local_nmf( slice_=slice_, @@ -66,13 +67,13 @@ def _get_max_energy_slice( self, arr: xr.DataArray, energy_landscape: xr.DataArray, + radius: int, ) -> xr.DataArray: """Find neighborhood around point of maximum variance.""" # Find maximum point max_coords = energy_landscape.argmax(dim=AXIS.spatial_dims) # Define neighborhood - radius = int(np.round(self.cell_radius)) window = { ax: slice( max(0, pos.values - radius), diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 192ff3da..361cb0cd 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -2,5 +2,6 @@ from .denoise import denoise from .glow_removal import GlowRemover from .rigid_stabilization import RigidStabilizer +from .r_estimate import SizeEst -__all__ = [denoise, GlowRemover, remove_background, RigidStabilizer] +__all__ = [denoise, GlowRemover, remove_background, RigidStabilizer, SizeEst] diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py new file mode 100644 index 00000000..e6c4a231 --- /dev/null +++ b/src/cala/nodes/prep/r_estimate.py @@ -0,0 +1,41 @@ +from typing import Annotated as A, Any +import numpy as np +from noob import process_method, Name +from pydantic import BaseModel, Field, ConfigDict, PrivateAttr, model_validator + +from skimage.feature import blob_log +from cala.assets import Frame +from cala.models import AXIS + + +class SizeEst(BaseModel): + hardset_radius: int | None = None + """if this is set, no learning occurs.""" + n_frames: int | None = None + """how many first n frames to learn from. if none, keep learning forever""" + log_kwargs: dict[str, Any] = Field(default_factory=dict) + + sizes_: list[float] = Field(default_factory=list) + centers_: list[np.ndarray] = Field(default_factory=list) + _est_radius: int = PrivateAttr(None) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @model_validator(mode="after") + def validity_check(self): + assert self.hardset_radius or self.n_frames + + @process_method + def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]: + if self.hardset_radius: + return self.hardset_radius + + if self.n_frames and self.n_frames < frame.array[AXIS.frame_coord]: + return self._est_radius + + blobs = blob_log(frame.array, **self.log_kwargs) + self.centers_ = [blobs[:-1] for blobs in blobs] + self.sizes_ += [blob[-1].item() for blob in blobs] + self._est_radius = int(np.round(np.median(self.sizes_)).item()) + + return self._est_radius diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index bfcb8bc6..3dde5e43 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -28,11 +28,6 @@ def toy(): ) -@pytest.fixture(autouse=True, scope="class") -def single_cell_video(toy): - return toy.make_movie() - - @pytest.fixture(scope="class") def energy(): return Energy.from_specification( @@ -45,17 +40,16 @@ def energy(): @pytest.fixture(scope="function") -def energy_shape(energy, single_cell_video): - return energy.process(Residual.from_array(single_cell_video.array), trigger=True) +def energy_shape(energy, toy): + return energy.process(Residual.from_array(toy.make_movie().array), trigger=True) @pytest.fixture(scope="class") -def slice_nmf(toy): +def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", type="cala.nodes.detect.SliceNMF", - params={"cell_radius": 2 * toy.cell_radii[0]}, ) ) @@ -70,29 +64,33 @@ def cataloger(): class TestEnergy: - def test_estimate_gaussian_noise(self, energy, single_cell_video): - noise_level = energy._estimate_gaussian_noise(single_cell_video.array) + def test_estimate_gaussian_noise(self, energy, toy): + noise_level = energy._estimate_gaussian_noise(toy.make_movie().array) print(f"\nNoise Level: {noise_level}") - def test_center_to_median(self, energy, single_cell_video): - centered_video = energy._center_to_median(single_cell_video.array) - assert centered_video.max() < single_cell_video.array.max() + def test_center_to_median(self, energy, toy): + centered_video = energy._center_to_median(toy.make_movie().array) + assert centered_video.max() < toy.make_movie().array.max() - def test_process(self, energy, single_cell_video): + def test_process(self, energy, toy): energy_landscape = energy.process( - residuals=Residual.from_array(single_cell_video.array), trigger=True + residuals=Residual.from_array(toy.make_movie().array), trigger=True ) - assert energy_landscape.sizes == single_cell_video.array[0].sizes + assert energy_landscape.sizes == toy.make_movie().array[0].sizes assert np.all(energy_landscape >= 0) class TestSliceNMF: - def test_get_max_energy_slice(self, slice_nmf, single_cell_video, energy_shape): - slice_ = slice_nmf._get_max_energy_slice(single_cell_video.array, energy_shape) + def test_get_max_energy_slice(self, slice_nmf, toy, energy_shape): + slice_ = slice_nmf._get_max_energy_slice( + toy.make_movie().array, energy_shape, radius=toy.cell_radii[0] * 2 + ) return slice_ - def test_local_nmf(self, slice_nmf, single_cell_video, energy_shape, toy): - slice_ = slice_nmf._get_max_energy_slice(single_cell_video.array, energy_shape) + def test_local_nmf(self, slice_nmf, toy, energy_shape): + slice_ = slice_nmf._get_max_energy_slice( + toy.make_movie().array, energy_shape, radius=toy.cell_radii[0] * 2 + ) footprint, trace = slice_nmf._local_nmf( slice_, toy.frame_dims.model_dump(), @@ -100,9 +98,11 @@ def test_local_nmf(self, slice_nmf, single_cell_video, energy_shape, toy): assert_scalar_multiple_arrays(footprint, toy.footprints.array) - def test_process(self, slice_nmf, single_cell_video, energy_shape, toy): + def test_process(self, slice_nmf, toy, energy_shape): new_component = slice_nmf.process( - Residual.from_array(single_cell_video.array), energy_shape + Residual.from_array(toy.make_movie().array), + energy_shape, + detect_radius=toy.cell_radii[0] * 2, ) if new_component: new_fp, new_tr = new_component @@ -112,15 +112,16 @@ def test_process(self, slice_nmf, single_cell_video, energy_shape, toy): for new, old in zip([new_fp[0], new_tr[0]], [toy.footprints, toy.traces]): assert_scalar_multiple_arrays(new.array, old.array) - def test_chunks(self, single_cell_video, energy_shape, toy): + def test_chunks(self, energy_shape, toy): nmf = SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", type="cala.nodes.detect.SliceNMF", - params={"cell_radius": 10}, ) ) - fpts, trcs = nmf.process(Residual.from_array(single_cell_video.array), energy_shape) + fpts, trcs = nmf.process( + Residual.from_array(toy.make_movie().array), energy_shape, detect_radius=10 + ) if not fpts or not trcs: raise AssertionError("Failed to detect a new component") @@ -136,14 +137,10 @@ def test_chunks(self, single_cell_video, energy_shape, toy): class TestCataloger: @pytest.fixture(scope="function") - def new_component(self, single_cell_video, energy_shape): - return SliceNMF.from_specification( - spec=NodeSpecification( - id="test_slice_nmf", - type="cala.nodes.detect.slice_nmf.SliceNMF", - params={"cell_radius": 60}, - ) - ).process(Residual.from_array(single_cell_video.array), energy_shape) + def new_component(self, slice_nmf, toy, energy_shape): + return slice_nmf.process( + Residual.from_array(toy.make_movie().array), energy_shape, detect_radius=60 + ) def test_register(self, cataloger, new_component, toy): new_fp, new_tr = new_component @@ -155,14 +152,10 @@ def test_register(self, cataloger, new_component, toy): assert np.array_equal(fp.array, new_fp[0].array) assert np.array_equal(tr.array, new_tr[0].array) - def test_merge_with(self, cataloger, toy, single_cell_video, energy_shape): - new_component = SliceNMF.from_specification( - spec=NodeSpecification( - id="test_slice_nmf", - type="cala.nodes.detect.slice_nmf.SliceNMF", - params={"cell_radius": 10}, - ) - ).process(Residual.from_array(single_cell_video.array), energy_shape) + def test_merge_with(self, slice_nmf, cataloger, toy, energy_shape): + new_component = slice_nmf.process( + Residual.from_array(toy.make_movie().array), energy_shape, detect_radius=10 + ) new_fp, new_tr = new_component fp, tr = cataloger._merge_with( @@ -174,23 +167,17 @@ def test_merge_with(self, cataloger, toy, single_cell_video, energy_shape): ) movie_new_comp = new_fp[0].array @ new_tr[0].array - movie_expected = (single_cell_video.array + movie_new_comp).transpose(*movie_result.dims) + movie_expected = (toy.make_movie().array + movie_new_comp).transpose(*movie_result.dims) xr.testing.assert_allclose(movie_result, movie_expected) - def test_process_ideal(self, cataloger, separate_cells, energy): + def test_process_ideal(self, slice_nmf, cataloger, separate_cells, energy): """ test cataloging separate cells. ideal case with cell_radius=5 """ movie = separate_cells.make_movie().array ener = energy.process(Residual.from_array(movie), trigger=True) - fps, trs = SliceNMF.from_specification( - spec=NodeSpecification( - id="test_slice_nmf", - type="cala.nodes.detect.slice_nmf.SliceNMF", - params={"cell_radius": 5}, - ) - ).process(Residual.from_array(movie), ener) + fps, trs = slice_nmf.process(Residual.from_array(movie), ener, detect_radius=5) # NOTE: by manually putting in separate_cells, we're forcing a double-detection in this test new_fps, new_trs = cataloger.process( @@ -209,19 +196,13 @@ def test_process_ideal(self, cataloger, separate_cells, energy): assert new_fps.array.attrs.get("replaces") == detected xr.testing.assert_allclose(result, expected) - def test_process_fail(self, cataloger, separate_cells, energy): + def test_process_fail(self, slice_nmf, cataloger, separate_cells, energy): """ test cataloging separate cells. nmf supposed to fail with radius=25 (grabs too many cells) """ movie = separate_cells.make_movie().array ener = energy.process(Residual.from_array(movie), trigger=True) - fps, trs = SliceNMF.from_specification( - spec=NodeSpecification( - id="test_slice_nmf", - type="cala.nodes.detect.slice_nmf.SliceNMF", - params={"cell_radius": 25}, - ) - ).process(Residual.from_array(movie), ener) + fps, trs = slice_nmf.process(Residual.from_array(movie), ener, detect_radius=25) # NOTE: by manually putting in separate_cells, we're forcing a double-detection in this test new_fps, new_trs = cataloger.process( @@ -230,19 +211,13 @@ def test_process_fail(self, cataloger, separate_cells, energy): assert new_fps.array is None and new_trs.array is None - def test_process_connected(self, cataloger, connected_cells, energy): + def test_process_connected(self, slice_nmf, cataloger, connected_cells, energy): """ trial with connected cells 🙏 """ movie = connected_cells.make_movie().array ener = energy.process(Residual.from_array(movie), trigger=True) - fps, trs = SliceNMF.from_specification( - spec=NodeSpecification( - id="test_slice_nmf", - type="cala.nodes.detect.slice_nmf.SliceNMF", - params={"cell_radius": 4}, - ) - ).process(Residual.from_array(movie), ener) + fps, trs = slice_nmf.process(Residual.from_array(movie), ener, detect_radius=4) # NOTE: by manually putting in connected_cells, # we're forcing a double-detection in this test diff --git a/tests/test_prep/test_r_estimate.py b/tests/test_prep/test_r_estimate.py new file mode 100644 index 00000000..b254dfd5 --- /dev/null +++ b/tests/test_prep/test_r_estimate.py @@ -0,0 +1,38 @@ +from cala.assets import Frame +from cala.models import AXIS +from cala.nodes.prep.r_estimate import SizeEst +from cala.testing.toy import Position +from cala.util import package_frame + + +def test_size_estim(separate_cells): + kwargs = { + "min_sigma": 1, + "max_sigma": 10, + "num_sigma": 10, + "threshold": 0.1, + "overlap": 0.5, + } + node = SizeEst(n_frames=1, log_kwargs=kwargs) + + max_proj = package_frame( + separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=1 + ) + result = node.get_median_radius(Frame.from_array(max_proj)) + + expected = separate_cells.cell_radii[0] - 1 + assert result == expected + assert len(node.sizes_) == 3 + + max_proj = package_frame( + separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=3 + ) + result = node.get_median_radius(Frame.from_array(max_proj)) + + assert result == expected + assert len(node.sizes_) == 3 + + for center in node.centers_: + height = center[0].astype(int).item() + width = center[1].astype(int).item() + assert Position(width=width, height=height) in separate_cells.cell_positions From 500c5444b74231a86441c9057a7ac6189af34fbc Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 14 Aug 2025 13:22:39 -0700 Subject: [PATCH 2/3] feat: unify test yaml files --- src/cala/testing/nodes.py | 2 +- .../{two_overlap_cells.yaml => odl.yaml} | 15 +- tests/data/pipelines/single_cell.yaml | 165 ----------------- tests/data/pipelines/two_cells.yaml | 166 ------------------ tests/test_pipeline.py | 15 +- 5 files changed, 23 insertions(+), 340 deletions(-) rename tests/data/pipelines/{two_overlap_cells.yaml => odl.yaml} (94%) delete mode 100644 tests/data/pipelines/single_cell.yaml delete mode 100644 tests/data/pipelines/two_cells.yaml diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 182a3950..12431541 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -38,7 +38,7 @@ def two_cells_source( positions: list[dict] = None, ) -> Generator[A[Frame, Name("frame")]]: frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims) - traces = [np.array(range(0, n_frames)), np.array([0, *range(30 - 1, 0, -1)])] + traces = [np.array(range(0, n_frames)), np.array([0, *range(n_frames - 1, 0, -1)])] if positions is None: positions = [Position(width=206, height=206), Position(width=306, height=306)] else: diff --git a/tests/data/pipelines/two_overlap_cells.yaml b/tests/data/pipelines/odl.yaml similarity index 94% rename from tests/data/pipelines/two_overlap_cells.yaml rename to tests/data/pipelines/odl.yaml index 8576ceac..98419231 100644 --- a/tests/data/pipelines/two_overlap_cells.yaml +++ b/tests/data/pipelines/odl.yaml @@ -1,4 +1,4 @@ -noob_id: cala-two-overlap-cells +noob_id: cala-odl noob_model: noob.cube.CubeSpecification noob_version: 0.1.1.dev118+g64d81b7 @@ -30,9 +30,9 @@ assets: nodes: source: - type: cala.testing.two_overlapping_source + type: cala.testing.single_cell_source params: - n_frames: 50 + n_frames: 30 denoise: type: cala.nodes.prep.denoise params: @@ -52,6 +52,12 @@ nodes: drift_speed: 1.0 depends: - frame: glow.frame + size_est: + type: cala.nodes.prep.SizeEst + params: + hardset_radius: 10 + depends: + - frame: motion.frame cache: type: cala.nodes.buffer.fill_buffer params: @@ -121,11 +127,10 @@ nodes: - trigger: trace_frame.latest_trace nmf: type: cala.nodes.detect.SliceNMF - params: - cell_radius: 10 depends: - residuals: residual.movie - energy: energy.energy + - detect_radius: size_est.radius catalog: type: cala.nodes.detect.Cataloger params: diff --git a/tests/data/pipelines/single_cell.yaml b/tests/data/pipelines/single_cell.yaml deleted file mode 100644 index 9d6dc79f..00000000 --- a/tests/data/pipelines/single_cell.yaml +++ /dev/null @@ -1,165 +0,0 @@ -noob_id: cala-single-cell -noob_model: noob.cube.CubeSpecification -noob_version: 0.1.1.dev118+g64d81b7 - -# Add GUI server in assets - -assets: - buffer: - type: cala.assets.Movie - scope: session - depends: - - cache.buffer - footprints: - type: cala.assets.Footprints - scope: session - traces: - type: cala.assets.Traces - scope: session - pix_stats: - type: cala.assets.PixStats - scope: session - comp_stats: - type: cala.assets.CompStats - scope: session - overlaps: - type: cala.assets.Overlaps - scope: session - -# Add refiners in nodes - -nodes: - source: - type: cala.testing.single_cell_source - params: - n_frames: 30 - denoise: - type: cala.nodes.prep.denoise - params: - ksize: - - 3 - - 3 - sigmaX: 1.5 - depends: - - frame: source.frame - glow: - type: cala.nodes.prep.GlowRemover - depends: - - frame: denoise.frame - motion: - type: cala.nodes.prep.RigidStabilizer - params: - drift_speed: 1.0 - depends: - - frame: glow.frame - cache: - type: cala.nodes.buffer.fill_buffer - params: - size: 100 - depends: - - buffer: assets.buffer - - frame: motion.frame - - trace_frame: - type: cala.nodes.traces.FrameUpdate - params: - tolerance: 0.001 - depends: - - traces: assets.traces - - footprints: assets.footprints - - frame: motion.frame - - overlaps: assets.overlaps - - pix_frame: - type: cala.nodes.pixel_stats.ingest_frame - depends: - - pixel_stats: assets.pix_stats - - frame: motion.frame - - new_traces: trace_frame.latest_trace - comp_frame: - type: cala.nodes.component_stats.ingest_frame - depends: - - component_stats: assets.comp_stats - - frame: motion.frame - - new_traces: trace_frame.latest_trace - - residual: - type: cala.nodes.residual.build - params: - clip_threshold: 0.001 - depends: - - trigger: trace_frame.latest_trace - - frames: assets.buffer - - footprints: assets.footprints - - traces: assets.traces - -# DETECT BEGINS - energy: - type: cala.nodes.detect.Energy - params: - min_frames: 10 - depends: - - residuals: residual.movie - - trigger: trace_frame.latest_trace - nmf: - type: cala.nodes.detect.SliceNMF - params: - cell_radius: 10 - depends: - - residuals: residual.movie - - energy: energy.energy - catalog: - type: cala.nodes.detect.Cataloger - params: - merge_threshold: 0.8 - depends: - - new_fps: nmf.new_fps - - new_trs: nmf.new_trs - - existing_fp: assets.footprints - - existing_tr: assets.traces - - trace_component: - type: cala.nodes.traces.ingest_component - depends: - - traces: assets.traces - - new_traces: catalog.new_traces - footprint_component: - type: cala.nodes.footprints.ingest_component - depends: - - footprints: assets.footprints - - new_footprints: catalog.new_footprints - pix_component: - type: cala.nodes.pixel_stats.ingest_component - depends: - - pixel_stats: assets.pix_stats - - frames: assets.buffer - - new_traces: catalog.new_traces - - traces: assets.traces - comp_component: - type: cala.nodes.component_stats.ingest_component - depends: - - component_stats: assets.comp_stats - - traces: assets.traces - - new_traces: catalog.new_traces -# DETECT ENDS - - footprints_frame: - type: cala.nodes.footprints.Footprinter - params: - bep: 1 - tol: 0.0000001 - depends: - - footprints: assets.footprints - - pixel_stats: pix_component.value - - component_stats: comp_component.value - - overlaps_update: - type: cala.nodes.overlap.ingest_frame - depends: - - overlaps: assets.overlaps - - footprints: footprints_frame.footprints - - return: - type: return - depends: - - motion.frame \ No newline at end of file diff --git a/tests/data/pipelines/two_cells.yaml b/tests/data/pipelines/two_cells.yaml deleted file mode 100644 index 33c5c117..00000000 --- a/tests/data/pipelines/two_cells.yaml +++ /dev/null @@ -1,166 +0,0 @@ -noob_id: cala-two-cells -noob_model: noob.cube.CubeSpecification -noob_version: 0.1.1.dev118+g64d81b7 - -# Add GUI server in assets - -assets: - buffer: - type: cala.assets.Movie - scope: session - depends: - - cache.buffer - footprints: - type: cala.assets.Footprints - scope: session - traces: - type: cala.assets.Traces - scope: session - pix_stats: - type: cala.assets.PixStats - scope: session - comp_stats: - type: cala.assets.CompStats - scope: session - overlaps: - type: cala.assets.Overlaps - scope: session - -# Add refiners in nodes - -nodes: - source: - type: cala.testing.two_cells_source - params: - n_frames: 30 - - denoise: - type: cala.nodes.prep.denoise - params: - ksize: - - 3 - - 3 - sigmaX: 1.5 - depends: - - frame: source.frame - glow: - type: cala.nodes.prep.GlowRemover - depends: - - frame: denoise.frame - motion: - type: cala.nodes.prep.RigidStabilizer - params: - drift_speed: 1.0 - depends: - - frame: glow.frame - cache: - type: cala.nodes.buffer.fill_buffer - params: - size: 100 - depends: - - buffer: assets.buffer - - frame: motion.frame - - trace_frame: - type: cala.nodes.traces.FrameUpdate - params: - tolerance: 0.001 - depends: - - traces: assets.traces - - footprints: assets.footprints - - frame: motion.frame - - overlaps: assets.overlaps - - pix_frame: - type: cala.nodes.pixel_stats.ingest_frame - depends: - - pixel_stats: assets.pix_stats - - frame: motion.frame - - new_traces: trace_frame.latest_trace - comp_frame: - type: cala.nodes.component_stats.ingest_frame - depends: - - component_stats: assets.comp_stats - - frame: motion.frame - - new_traces: trace_frame.latest_trace - - residual: - type: cala.nodes.residual.build - params: - clip_threshold: 0.001 - depends: - - trigger: trace_frame.latest_trace - - frames: assets.buffer - - footprints: assets.footprints - - traces: assets.traces - - # DETECT BEGINS - energy: - type: cala.nodes.detect.Energy - params: - min_frames: 10 - depends: - - residuals: residual.movie - - trigger: trace_frame.latest_trace - nmf: - type: cala.nodes.detect.SliceNMF - params: - cell_radius: 10 - depends: - - residuals: residual.movie - - energy: energy.energy - catalog: - type: cala.nodes.detect.Cataloger - params: - merge_threshold: 0.8 - depends: - - new_fps: nmf.new_fps - - new_trs: nmf.new_trs - - existing_fp: assets.footprints - - existing_tr: assets.traces - - trace_component: - type: cala.nodes.traces.ingest_component - depends: - - traces: assets.traces - - new_traces: catalog.new_traces - footprint_component: - type: cala.nodes.footprints.ingest_component - depends: - - footprints: assets.footprints - - new_footprints: catalog.new_footprints - pix_component: - type: cala.nodes.pixel_stats.ingest_component - depends: - - pixel_stats: assets.pix_stats - - frames: assets.buffer - - new_traces: catalog.new_traces - - traces: assets.traces - comp_component: - type: cala.nodes.component_stats.ingest_component - depends: - - component_stats: assets.comp_stats - - traces: assets.traces - - new_traces: catalog.new_traces - # DETECT ENDS - - footprints_frame: - type: cala.nodes.footprints.Footprinter - params: - bep: 1 - tol: 0.0000001 - depends: - - footprints: assets.footprints - - pixel_stats: pix_component.value - - component_stats: comp_component.value - - overlaps_update: - type: cala.nodes.overlap.ingest_frame - depends: - - overlaps: assets.overlaps - - footprints: footprints_frame.footprints - - return: - type: return - depends: - - motion.frame \ No newline at end of file diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 2113cdaa..dfacc16f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,18 +1,27 @@ import pytest import xarray as xr from noob import Cube, SynchronousRunner, Tube +from noob.node import NodeSpecification, Node from cala.models import AXIS -@pytest.fixture(params=["cala-single-cell", "cala-two-cells", "cala-two-overlap-cells"]) +@pytest.fixture(params=["single_cell_source", "two_cells_source", "two_overlapping_source"]) def tube(request): - return Tube.from_specification(request.param) + tube = Tube.from_specification("cala-odl") + source = Node.from_specification( + NodeSpecification( + id="source", type=f"cala.testing.{request.param}", params={"n_frames": 50} + ) + ) + tube.nodes["source"] = source + + return tube @pytest.fixture def cube(): - return Cube.from_specification("cala-single-cell") + return Cube.from_specification("cala-odl") @pytest.fixture From 3b9f53985bdfb16fe54dbffb82262a72aee93cec Mon Sep 17 00:00:00 2001 From: Raymond Date: Thu, 14 Aug 2025 13:26:30 -0700 Subject: [PATCH 3/3] format: ruff --- src/cala/nodes/prep/__init__.py | 2 +- src/cala/nodes/prep/r_estimate.py | 14 ++++++-------- tests/test_pipeline.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/cala/nodes/prep/__init__.py b/src/cala/nodes/prep/__init__.py index 361cb0cd..e100f748 100644 --- a/src/cala/nodes/prep/__init__.py +++ b/src/cala/nodes/prep/__init__.py @@ -1,7 +1,7 @@ from .background_removal import remove_background from .denoise import denoise from .glow_removal import GlowRemover -from .rigid_stabilization import RigidStabilizer from .r_estimate import SizeEst +from .rigid_stabilization import RigidStabilizer __all__ = [denoise, GlowRemover, remove_background, RigidStabilizer, SizeEst] diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index e6c4a231..e83335ea 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -1,9 +1,11 @@ -from typing import Annotated as A, Any -import numpy as np -from noob import process_method, Name -from pydantic import BaseModel, Field, ConfigDict, PrivateAttr, model_validator +from typing import Annotated as A +from typing import Any +import numpy as np +from noob import Name, process_method +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from skimage.feature import blob_log + from cala.assets import Frame from cala.models import AXIS @@ -21,10 +23,6 @@ class SizeEst(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - @model_validator(mode="after") - def validity_check(self): - assert self.hardset_radius or self.n_frames - @process_method def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]: if self.hardset_radius: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index dfacc16f..ac9cc396 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,7 +1,7 @@ import pytest import xarray as xr from noob import Cube, SynchronousRunner, Tube -from noob.node import NodeSpecification, Node +from noob.node import Node, NodeSpecification from cala.models import AXIS