diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index e27c1c51..930ef42e 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -21,10 +21,9 @@ Describe the tests that you ran to verify your changes. Provide instructions so that others can reproduce. --> -- [ ] Ran unit tests -- [ ] Ran integration tests -- [ ] Performed manual testing -- [ ] Updated existing tests +- [ ] Unit tests +- [ ] Integration tests +- [ ] Existing tests update ## 🛠️ Dependencies @@ -36,9 +35,5 @@ List any new dependencies added or existing ones updated. ## ✅ Checklist -- [ ] My code follows the project's style guidelines - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have made corresponding changes to the documentation -- [ ] My changes generate no new warnings -- [ ] I have added tests that prove my fix is effective or that my feature works -- [ ] New and existing unit tests pass locally with my changes diff --git a/pdm.lock b/pdm.lock index 59f66f42..ba0b54db 100644 --- a/pdm.lock +++ b/pdm.lock @@ -1811,11 +1811,11 @@ files = [ [[package]] name = "noob" -version = "0.1.1.dev119" +version = "0.1.1.dev121" requires_python = ">=3.11" git = "https://github.com/miniscope/noob.git" ref = "37-tube-resources-for-data-shared-between-nodes" -revision = "5d585f6a09f9cf9f1bdbd383ee99f7fdbfa1a1f1" +revision = "fd131c7704c0485baadb7a59d5d714b529f09d58" summary = "Default template for PDM package" dependencies = [ "PyYAML>=6.0.2", diff --git a/src/cala/assets.py b/src/cala/assets.py index 3d1965d8..077b4063 100644 --- a/src/cala/assets.py +++ b/src/cala/assets.py @@ -86,6 +86,7 @@ class Footprints(Asset): member=Footprint.entity(), group_by=Dims.component, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -134,6 +135,7 @@ def from_array( member=Trace.entity(), group_by=Dims.component, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -145,6 +147,7 @@ class Movie(Asset): member=Frame.entity(), group_by=Dims.frame.value, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -180,6 +183,7 @@ class CompStats(Asset): dims=comp_dims, dtype=float, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -191,6 +195,7 @@ class PixStats(Asset): dims=(Dims.width.value, Dims.height.value, Dims.component.value), dtype=float, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) @@ -202,6 +207,7 @@ class Overlaps(Asset): dims=comp_dims, dtype=bool, checks=[has_no_nan], + allow_extra_coords=False, ) ) @@ -224,5 +230,6 @@ class Residual(Asset): member=Frame.entity(), group_by=Dims.frame.value, checks=[is_non_negative, has_no_nan], + allow_extra_coords=False, ) ) diff --git a/src/cala/models/axis.py b/src/cala/models/axis.py index ace438ad..e4222955 100644 --- a/src/cala/models/axis.py +++ b/src/cala/models/axis.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field -from cala.models.checks import is_unique, is_unit_interval +from cala.models.checks import has_no_nan, is_unique class Axis: @@ -17,7 +17,7 @@ class Axis: id_coord: str = "id_" timestamp_coord: str = "timestamp" - confidence_coord: str = "confidence" + detect_coord: str = "detected_on" frame_coord: str = "frame" width_coord: str = "width" height_coord: str = "height" @@ -37,7 +37,7 @@ def component_rename(self) -> dict[str, str]: return { AXIS.component_dim: f"{AXIS.component_dim}'", AXIS.id_coord: f"{AXIS.id_coord}'", - AXIS.confidence_coord: f"{AXIS.confidence_coord}'", + AXIS.detect_coord: f"{AXIS.detect_coord}'", } @@ -62,11 +62,11 @@ class Coords(Enum): width = Coord(name=AXIS.width_coord, dtype=int, checks=[is_unique]) frame = Coord(name=AXIS.frame_coord, dtype=int, checks=[is_unique]) timestamp = Coord(name=AXIS.timestamp_coord, dtype=str, checks=[is_unique]) - confidence = Coord(name=AXIS.confidence_coord, dtype=float, checks=[is_unit_interval]) + detected = Coord(name=AXIS.detect_coord, dtype=int, checks=[has_no_nan]) class Dims(Enum): width = Dim(name=AXIS.width_dim, coords=[Coords.width.value]) height = Dim(name=AXIS.height_dim, coords=[Coords.height.value]) frame = Dim(name=AXIS.frames_dim, coords=[Coords.frame.value, Coords.timestamp.value]) - component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.confidence.value]) + component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.detected.value]) diff --git a/src/cala/models/entity.py b/src/cala/models/entity.py index 1c589257..d3448a1a 100644 --- a/src/cala/models/entity.py +++ b/src/cala/models/entity.py @@ -18,6 +18,7 @@ class Entity(BaseModel): coords: list[Coord] = Field(default_factory=list) dtype: type checks: list[Callable] = Field(default_factory=list) + allow_extra_coords: bool = True _model: DataArraySchema = PrivateAttr(DataArraySchema()) @@ -47,15 +48,14 @@ def to_schema(self) -> DataArraySchema: checks=self.checks, ) - @staticmethod - def _build_coord_schema(coords: list[Coord]) -> CoordsSchema: + def _build_coord_schema(self, coords: list[Coord]) -> CoordsSchema: spec = dict() for c in coords: dim = DimsSchema((c.dim,)) if c.dim else None spec[c.name] = DataArraySchema(dims=dim, dtype=DTypeSchema(c.dtype), checks=c.checks) - return CoordsSchema(spec) + return CoordsSchema(spec, allow_extra_keys=self.allow_extra_coords) class Group(Entity): diff --git a/src/cala/nodes/cleanup.py b/src/cala/nodes/cleanup.py index 220e3489..b1d4c453 100644 --- a/src/cala/nodes/cleanup.py +++ b/src/cala/nodes/cleanup.py @@ -5,10 +5,31 @@ import xarray as xr from noob import Name -from cala.assets import CompStats, Footprints, Overlaps, PixStats, Traces +from cala.assets import CompStats, Footprints, Overlaps, PixStats, Residual, Traces from cala.models import AXIS +def clear_overestimates( + footprints: Footprints, residuals: Residual, nmf_error: float +) -> A[Footprints, Name("footprints")]: + """ + Remove all sections of the footprints that cause negative residuals. + + This occurs by: + 1. find "significant" negative residual spots that is more than a noise level, and thus + cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) + 2. all footprint values at these spots go to zero. + """ + if residuals.array is None: + return footprints + R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords( + [AXIS.frame_coord, AXIS.timestamp_coord], drop=True + ) + tuned_fp = footprints.array.where(R_min > -nmf_error, 0, drop=False) + + return tuned_fp + + def purge_razed_components( footprints: Footprints, traces: Traces, @@ -101,13 +122,13 @@ def filter_components( comp_stats.array.set_xindex(AXIS.id_coord) .set_xindex(f"{AXIS.id_coord}'") .sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids.values.tolist()}) - .reset_index(AXIS.id_coord) + .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) ) overlaps.array = ( overlaps.array.set_xindex(AXIS.id_coord) .set_xindex(f"{AXIS.id_coord}'") .sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids.values.tolist()}) - .reset_index(AXIS.id_coord) + .reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"]) ) return footprints, traces, pix_stats, comp_stats, overlaps diff --git a/src/cala/nodes/detect/catalog.py b/src/cala/nodes/detect/catalog.py index aacc4cb1..684ee16b 100644 --- a/src/cala/nodes/detect/catalog.py +++ b/src/cala/nodes/detect/catalog.py @@ -1,6 +1,7 @@ from collections.abc import Hashable, Iterable from typing import Annotated as A +import cv2 import numpy as np import xarray as xr from noob import Name @@ -72,21 +73,19 @@ def process( footprints = xr.concat( footprints, dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.confidence_coord], + coords=[AXIS.id_coord, AXIS.detect_coord], combine_attrs=combine_attr_replaces, ) traces = xr.concat( traces, dim=AXIS.component_dim, - coords=[AXIS.id_coord, AXIS.confidence_coord], + coords=[AXIS.id_coord, AXIS.detect_coord], combine_attrs=combine_attr_replaces, ) return Footprints.from_array(footprints), Traces.from_array(traces) - def _register( - self, new_fp: xr.DataArray, new_tr: xr.DataArray, confidence: float = 0.0 - ) -> tuple[Footprint, Trace]: + def _register(self, new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[Footprint, Trace]: new_id = create_id() @@ -95,7 +94,10 @@ def _register( .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [new_id]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: ( + AXIS.component_dim, + [new_tr[AXIS.frame_coord].max().item()], + ), } ) .isel({AXIS.component_dim: 0}) @@ -105,7 +107,10 @@ def _register( .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [new_id]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: ( + AXIS.component_dim, + [new_tr[AXIS.frame_coord].max().item()], + ), } ) .isel({AXIS.component_dim: 0}) @@ -211,8 +216,21 @@ def _merge_matrix( fps_base = fps_base.rename(AXIS.component_rename) trs_base = trs_base.rename(AXIS.component_rename) + fps = self._expand_boundary(fps > 0) + overlaps = fps @ fps_base > 0 # this should later reflect confidence corrs = xr.corr(trs, trs_base, dim=AXIS.frames_dim) > self.merge_threshold return overlaps * corrs + + def _expand_boundary(self, mask: xr.DataArray) -> xr.DataArray: + kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) + return xr.apply_ufunc( + lambda x: cv2.morphologyEx(x, cv2.MORPH_DILATE, kernel, iterations=1), + mask.astype(np.uint8), + input_core_dims=[AXIS.spatial_dims], + output_core_dims=[AXIS.spatial_dims], + vectorize=True, + dask="parallelized", + ) diff --git a/src/cala/nodes/detect/slice_nmf.py b/src/cala/nodes/detect/slice_nmf.py index edb11478..ccadd1da 100644 --- a/src/cala/nodes/detect/slice_nmf.py +++ b/src/cala/nodes/detect/slice_nmf.py @@ -19,6 +19,9 @@ class SliceNMF(Node): """Wait until this number of frames to begin detecting.""" detect_thresh: float """Minimum detection threshold for brightness fluctuation.""" + reprod_tol: float + """Mean pixel value error tolerance for reproduced slice from the new component""" + nmf_kwargs: dict[str, Any] = Field(default_factory=dict) error_: float = Field(None) @@ -43,7 +46,7 @@ def process( fps = [] trs = [] - while np.sqrt(energy.max()).item() > self.detect_thresh: # or use res directly + while energy.max().item() >= self.detect_thresh: # or use res directly # Find and analyze neighborhood of maximum variance slice_ = self._get_max_energy_slice( arr=res, energy_landscape=energy, radius=detect_radius @@ -55,17 +58,17 @@ def process( ) l1_norm = slice_.sum().item() - # l0_norm = np.prod(slice_.shape) # this fails when the residuals are tiny comp_recon = a_new @ c_new energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 - if (self.error_ / l1_norm) <= self._model.tol: + if (self.error_ / l1_norm) <= self.reprod_tol: fps.append(Footprint.from_array(a_new)) trs.append(Trace.from_array(c_new)) res = (res - comp_recon).clip(0) else: - res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0 + l0_norm = np.prod(slice_.shape) + res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = self.error_ / l0_norm return fps, trs @@ -73,7 +76,7 @@ def _get_energy(self, res: xr.DataArray) -> xr.DataArray: pixels_median = res.median(dim=AXIS.frames_dim) V = res - pixels_median - return (V**2).sum(dim=AXIS.frames_dim) + return np.sqrt((V**2).mean(dim=AXIS.frames_dim)) def _get_max_energy_slice( self, diff --git a/src/cala/nodes/footprints.py b/src/cala/nodes/footprints.py index c4ec8228..0d3ff079 100644 --- a/src/cala/nodes/footprints.py +++ b/src/cala/nodes/footprints.py @@ -6,10 +6,12 @@ from noob import Name, process_method from cala.assets import CompStats, Footprints, PixStats +from cala.logging import init_logger from cala.models import AXIS class Footprinter: + logger = init_logger(__name__) def __init__(self, tol: float, max_iter: int | None = None, bep: int | None = None): self.bep = bep @@ -55,7 +57,7 @@ def ingest_frame( kernel = self._expansion_kernel() mask = self._expand_boundary(kernel, mask) - # for _ in range(self.max_iter): + cnt = 0 while True: AM = A.rename(AXIS.component_rename) @ M numerator = W - AM @@ -72,8 +74,15 @@ def ingest_frame( update = numerator / (M_diag + np.finfo(float).tiny) A_new = (mask * (A + update)).clip(min=0) - if (np.abs(A - A_new).sum() / np.prod(A.shape)).item() < self.tol: + step = (np.abs(A - A_new).sum() / np.prod(A.shape)).item() + + cnt += 1 + maxed = self.max_iter and (cnt == self.max_iter) + + if step < self.tol or maxed: footprints.array = A_new + if maxed: + self.logger.debug(msg="max_iter reached before converging.") return footprints else: A = A_new @@ -93,7 +102,9 @@ def _expand_boundary(self, kernel: np.ndarray, mask: xr.DataArray) -> xr.DataArr ) -def ingest_component(footprints: Footprints, new_footprints: Footprints) -> Footprints: +def ingest_component( + footprints: Footprints, new_footprints: Footprints +) -> A[Footprints, Name("footprints")]: if new_footprints.array is None: return footprints diff --git a/src/cala/nodes/io.py b/src/cala/nodes/io.py index 0bc47fda..f7f5cc98 100644 --- a/src/cala/nodes/io.py +++ b/src/cala/nodes/io.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from collections.abc import Iterator +from collections.abc import Generator from pathlib import Path from typing import Protocol @@ -12,7 +12,7 @@ class Stream(Protocol): """Protocol defining the interface for video streams.""" @abstractmethod - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: """Iterate over frames.""" ... @@ -35,7 +35,7 @@ def __init__(self, video_path: Path | str) -> None: if not self._cap.isOpened(): raise ValueError(f"Failed to open video file: {video_path}") - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: """ Yields: NDArray: Next frame from the video @@ -66,7 +66,7 @@ def __init__(self, files: list[Path]) -> None: raise ValueError("TIFF files must be grayscale") self._sample_shape = frame.shape - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: for file in self._files: frame = io.imread(file) if len(frame.shape) != 2: @@ -90,7 +90,7 @@ def __init__(self, video_paths: list[Path]) -> None: self._video_paths = video_paths self._current_stream: OpenCVStream | None = None - def __iter__(self) -> Iterator[NDArray]: + def __iter__(self) -> Generator[NDArray]: """ Iterate over frames from all videos sequentially. @@ -108,7 +108,7 @@ def close(self) -> None: self._current_stream.close() -def stream(files: list[str | Path]) -> Stream: +def stream(files: list[str | Path]) -> Generator[NDArray]: """ Create a video stream from the provided video files. @@ -125,8 +125,8 @@ def stream(files: list[str | Path]) -> Stream: video_format = {".mp4", ".avi", ".webm"} if suffix.issubset(video_format): - return VideoStream(files) + return iter(VideoStream(files)) elif suffix.issubset(image_format): - return ImageStream(files) + return iter(ImageStream(files)) else: raise ValueError(f"Unsupported file format: {suffix}") diff --git a/src/cala/nodes/overlap.py b/src/cala/nodes/overlap.py index 454d3d97..a04ea94d 100644 --- a/src/cala/nodes/overlap.py +++ b/src/cala/nodes/overlap.py @@ -39,7 +39,7 @@ def ingest_component( V = overlaps.array a_new = new_footprints.array.volumize.dim_with_coords( - dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.confidence_coord] + dim=AXIS.component_dim, coords=[AXIS.id_coord, AXIS.detect_coord] ) if a_new[AXIS.id_coord].item() in V[AXIS.id_coord].values: diff --git a/src/cala/nodes/prep/denoise.py b/src/cala/nodes/prep/denoise.py index c012f066..2efe2f33 100644 --- a/src/cala/nodes/prep/denoise.py +++ b/src/cala/nodes/prep/denoise.py @@ -11,18 +11,22 @@ def denoise( - frame: Frame, method: Literal["gaussian", "median", "bilateral"] = "gaussian", **kwargs: Any + frame: Frame, method: Literal["gaussian", "median", "bilateral"], kwargs: dict[str, Any] ) -> A[Frame, Name("frame")]: """Denoise a single frame.""" methods: dict[str, Callable] = { "gaussian": cv2.GaussianBlur, "median": cv2.medianBlur, "bilateral": cv2.bilateralFilter, + "nonlocal": cv2.fastNlMeansDenoising, } _func = methods[method] + frame = frame.array - denoised = _func(frame.values.astype(np.float32), **kwargs).astype(np.float64) + arr = frame.values.astype(np.uint8) if method == "nonlocal" else frame.values.astype(np.float32) + + denoised = _func(arr, **kwargs).astype(float) return Frame.from_array(xr.DataArray(denoised, dims=frame.dims, coords=frame.coords)) diff --git a/src/cala/nodes/prep/hlines.py b/src/cala/nodes/prep/hlines.py new file mode 100644 index 00000000..2d74dfcc --- /dev/null +++ b/src/cala/nodes/prep/hlines.py @@ -0,0 +1,58 @@ +from typing import Annotated as A + +import numpy as np +from noob import Name +from scipy.ndimage import convolve1d +from scipy.signal import firwin, welch + +from cala.assets import Frame + + +def remove( + frame: Frame, distortion_freq: float | None = None, num_taps: int = 65, eps: float = 0.025 +) -> A[Frame, Name("frame")]: + arr = frame.array + + if np.all(frame.array == 0): + return frame + + denoised = _remove_lines( + arr.values, distortion_freq=distortion_freq, num_taps=num_taps, eps=eps + ) + + dmin = denoised.min() + if dmin < 0: + denoised -= dmin + + arr.values = denoised + + return Frame.from_array(arr) + + +def _remove_lines( + image: np.ndarray, distortion_freq: float = None, num_taps: int = 65, eps: float = 0.025 +) -> np.ndarray: + """ + Removes horizontal line artifacts from scanned image. + Args: + image: 2D or 3D array. + distortion_freq: Float, distortion frequency in cycles/pixel, or + `None` to estimate from spectrum. + num_taps: Integer, number of filter taps to use in each dimension. + eps: Small positive param to adjust filters cutoffs (cycles/pixel). + Returns: + Denoised image. + """ + if distortion_freq is None: + distortion_freq = _estimate_distortion_freq(image) + + hpf = firwin(num_taps, distortion_freq - eps, pass_zero="highpass", fs=1) + lpf = firwin(num_taps, eps, pass_zero="lowpass", fs=1) + return image - convolve1d(convolve1d(image, hpf, axis=0), lpf, axis=1) + + +def _estimate_distortion_freq(image: np.ndarray, min_frequency: float = 1 / 25) -> float: + """Estimates distortion frequency as spectral peak in vertical dim.""" + f, pxx = welch(image.sum(axis=1)) + pxx[f < min_frequency] = 0.0 + return f[pxx.argmax()] diff --git a/src/cala/nodes/prep/r_estimate.py b/src/cala/nodes/prep/r_estimate.py index e83335ea..7124c09c 100644 --- a/src/cala/nodes/prep/r_estimate.py +++ b/src/cala/nodes/prep/r_estimate.py @@ -15,6 +15,8 @@ class SizeEst(BaseModel): """if this is set, no learning occurs.""" n_frames: int | None = None """how many first n frames to learn from. if none, keep learning forever""" + noise_threshold: float = 0.0 + log_kwargs: dict[str, Any] = Field(default_factory=dict) sizes_: list[float] = Field(default_factory=list) @@ -31,9 +33,14 @@ def get_median_radius(self, frame: Frame) -> A[int, Name("radius")]: if self.n_frames and self.n_frames < frame.array[AXIS.frame_coord]: return self._est_radius - blobs = blob_log(frame.array, **self.log_kwargs) + blobs = blob_log( + frame.array.where(frame.array > self.noise_threshold, 0, drop=False), **self.log_kwargs + ) + if blobs.size == 0: + return 0 + self.centers_ = [blobs[:-1] for blobs in blobs] self.sizes_ += [blob[-1].item() for blob in blobs] - self._est_radius = int(np.round(np.median(self.sizes_)).item()) + self._est_radius = (np.median(self.sizes_) // 2 + 1).astype(int) return self._est_radius diff --git a/src/cala/nodes/prep/rigid_stabilization.py b/src/cala/nodes/prep/rigid_stabilization.py index c1052ab9..0718e608 100644 --- a/src/cala/nodes/prep/rigid_stabilization.py +++ b/src/cala/nodes/prep/rigid_stabilization.py @@ -131,7 +131,7 @@ def apply_shift(self, frame: xr.DataArray, shift: Shift) -> xr.DataArray: def update_anchor(self, frame: xr.DataArray) -> xr.DataArray: curr_index = frame[AXIS.frame_coord].item() - return (self.anchor_frame_.array * curr_index + frame) / curr_index + 1 + return (self.anchor_frame_.array * curr_index + frame) / (curr_index + 1) def process(self, frame: Frame) -> A[Frame, Name("frame")]: if self.is_first_frame(frame): diff --git a/src/cala/nodes/residual.py b/src/cala/nodes/residual.py index 48f97514..f70d6ea7 100644 --- a/src/cala/nodes/residual.py +++ b/src/cala/nodes/residual.py @@ -1,19 +1,15 @@ from typing import Annotated as A +import numpy as np import xarray as xr from noob import Name -from skimage.restoration import estimate_sigma from cala.assets import Footprints, Movie, Residual, Traces from cala.models import AXIS def build( - frames: Movie, - footprints: Footprints, - traces: Traces, - trigger: bool, - clip_threshold: float | None = None, + residuals: Residual, frames: Movie, footprints: Footprints, traces: Traces ) -> A[Residual, Name("movie")]: """ The computation follows the equation: @@ -48,43 +44,76 @@ def build( # Reshape footprints to (pixels x components) A = footprints.array + R_latest = Y.isel({AXIS.frames_dim: -1}) - (A @ C.isel({AXIS.frames_dim: -1})) + if R_latest.min() < 0: + shifted_tr = _align_overestimates(A, C.isel({AXIS.frames_dim: -1}), R_latest) + C.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr + traces.array.loc[{AXIS.frames_dim: C[AXIS.frame_coord].max()}] = shifted_tr + # Compute residual R = Y - [A,b][C;f] R = Y - (A @ C) + residuals.array = R.clip(min=0) # clipping is for the first n frames - clip_val = _estimate_clip_val(Y, clip_threshold) - footprints.array = _clear_overestimates(A, R, clip_val) - - return Residual.from_array(R.clip(min=0)) + return residuals -def _estimate_clip_val(Y: xr.DataArray, clip_threshold: float | None = None) -> float: +def _align_overestimates( + A: xr.DataArray, C_latest: xr.DataArray, R_latest: xr.DataArray +) -> xr.DataArray: """ - Estimate the threshold of "what is a significant negative residual value?" (above noise level) - :param Y: - :param clip_threshold: - :return: - """ - if clip_threshold: - return -Y.max().item() * clip_threshold - else: - return -estimate_sigma(Y) + Gotta be able to do at least ONE OF splitoff or gradualon. + Negative residuals just need to go. There isn't much you can do with the value...? -def _clear_overestimates(A: xr.DataArray, R: xr.DataArray, clip_val: float) -> xr.DataArray: - """ - Remove all sections of the footprints that cause negative residuals. - This occurs by: - 1. find "significant" negative residual spots that is more than a noise level, and thus - cannot be clipped to zero. !!!! (only of the latest frame, and then go back to trace update..?) - 2. all footprint values at these spots go to zero. + Two cases: (A & B Overlapping) + 1. GradualOn: Know A. B turns ON + -> trace tries to chase (increases) + -> footprint tries to chase + -> residual becomes negative at A-B + -> should just decrease, positive at A^B + -> actually... should decrease (just more steeply) + + 2. SplitOff: Know AB. B turns OFF + -> trace tries to chase (decreases) + -> footprint tries to chase + -> residual becomes positive at A-B + -> should increase, negative at A^B + -> this should just decrease, MORE negative at B-A + -> this going to zero makes sense + OR + keep B, remove A-B + + R = Y - A @ C + What about the past frame residuals after? + + for GradualOn, nothing should go to zero. + for SplitOff, a chunk needs to go to zero. + + So... how about we do something like (if it's been on for a long time, + we become less likely to purge it?) We subsequently clip R minimum to zero, since all significant negative residual spots have been removed, and the remaining negative spots are noise level. + + !!We're assuming there's no completely occluded component. This might be a problem eventually!! """ - R_min = R.min(dim=AXIS.frames_dim) - footprints = A.where(R_min > clip_val, 0, drop=False) + unlayered_footprints = _find_unlayered_footprints(A) + # if unlayered_footprints.max(dim=AXIS.spatial_dims).min() == 0: + # raise ValueError("There are at least one completely occluded components.") + + R_rel = R_latest.where((R_latest < 0) * unlayered_footprints.max(dim=AXIS.component_dim)) + dC = ( + (R_rel / A) + .min(dim=AXIS.spatial_dims) + .reset_coords([AXIS.frame_coord, AXIS.timestamp_coord], drop=True) + ) + + return (C_latest + xr.apply_ufunc(np.nan_to_num, dC, kwargs={"neginf": 0})).clip(min=0) + - return footprints +def _find_unlayered_footprints(A: xr.DataArray) -> xr.DataArray: + A_layer_mask = (A > 0).sum(dim=AXIS.component_dim) + return A.where(A_layer_mask == 1, 0) diff --git a/src/cala/nodes/traces.py b/src/cala/nodes/traces.py index 4964af6b..b56f3ba7 100644 --- a/src/cala/nodes/traces.py +++ b/src/cala/nodes/traces.py @@ -7,6 +7,7 @@ from scipy.sparse.csgraph import connected_components from cala.assets import Footprints, Frame, Movie, Overlaps, PopSnap, Traces +from cala.logging import init_logger from cala.models import AXIS @@ -31,7 +32,7 @@ def initialize(self, footprints: Footprints, frames: Movie) -> Traces: trace_coords = [ AXIS.id_coord, - AXIS.confidence_coord, + AXIS.detect_coord, AXIS.frame_coord, AXIS.timestamp_coord, ] @@ -83,8 +84,11 @@ def _fast_nnls_vector(A: np.ndarray, B: np.ndarray) -> np.ndarray: class FrameUpdate: - def __init__(self, tolerance: float = 1e-3) -> None: - self.tolerance = tolerance + logger = init_logger(__name__) + + def __init__(self, tol: float, max_iter: int | None = None) -> None: + self.tol = tol + self.max_iter = max_iter @process_method def ingest_frame( @@ -131,7 +135,7 @@ def ingest_frame( ) clusters = [np.where(labels == label)[0] for label in np.unique(labels)] - updated_traces = self._update_traces(A, y, c.copy(), clusters, self.tolerance) + updated_traces = self._update_traces(A, y, c.copy(), clusters) traces.array = xr.concat([traces.array, updated_traces], dim=AXIS.frames_dim) @@ -143,14 +147,12 @@ def _update_traces( y: xr.DataArray, c: xr.DataArray, clusters: list[np.ndarray], - eps: float, ) -> xr.DataArray: """ Implementation of the temporal traces update algorithm. - This function implements the core update logic. It uses block coordinate descent - to update temporal traces for overlapping components together while maintaining - non-negativity constraints. + This function uses block coordinate descent to update temporal traces + for overlapping components together while maintaining non-negativity constraints. Args: A (xr.DataArray): Spatial footprints matrix [A, b]. @@ -176,14 +178,13 @@ def _update_traces( # Step 3: Extract diagonal elements for normalization V_diag = np.diag(V) - # Step 4: Initialize previous iteration value - c_old = np.zeros_like(c) + cnt = 0 - # Steps 5-10: Main iteration loop until convergence - while np.linalg.norm(c - c_old) >= eps * np.linalg.norm(c_old): + # Steps 4-9: Main iteration loop until convergence + while True: c_old = c.copy() - # Steps 7-9: Update each group using block coordinate descent + # Steps 6-8: Update each group using block coordinate descent for cluster in clusters: # Update traces for current group (division is pointwise) numerator = u.isel({AXIS.component_dim: cluster}) - ( @@ -191,13 +192,18 @@ def _update_traces( ).rename({f"{AXIS.component_dim}'": AXIS.component_dim}) c.loc[{AXIS.component_dim: cluster}] = np.maximum( - c.isel({AXIS.component_dim: cluster}) + numerator / V_diag[cluster].T, - 0, + c.isel({AXIS.component_dim: cluster}) + numerator / V_diag[cluster].T, 0 ) - return xr.DataArray( - c.values, dims=c.dims, coords=c[AXIS.component_dim].coords - ).assign_coords(y[AXIS.frames_dim].coords) + cnt += 1 + maxed = self.max_iter and (cnt == self.max_iter) + + if np.linalg.norm(c - c_old) >= self.tol * np.linalg.norm(c_old) or maxed: + if maxed: + self.logger.debug(msg="max_iter reached before converging.") + return xr.DataArray( + c.values, dims=c.dims, coords=c[AXIS.component_dim].coords + ).assign_coords(y[AXIS.frames_dim].coords) def ingest_component(traces: Traces, new_traces: Traces) -> Traces: @@ -226,7 +232,7 @@ def ingest_component(traces: Traces, new_traces: Traces) -> Traces: coords=c.isel({AXIS.component_dim: 0}).coords, ) c_new[AXIS.id_coord] = c_det[AXIS.id_coord] - c_new[AXIS.confidence_coord] = c_det[AXIS.confidence_coord] + c_new[AXIS.detect_coord] = c_det[AXIS.detect_coord] c_new.loc[{AXIS.frames_dim: c_det[AXIS.frame_coord]}] = c_det else: diff --git a/src/cala/testing/__init__.py b/src/cala/testing/__init__.py index 1bad55fc..9d866edc 100644 --- a/src/cala/testing/__init__.py +++ b/src/cala/testing/__init__.py @@ -1,3 +1,19 @@ -from .nodes import single_cell_source, two_cells_source, two_overlapping_source +from .nodes import ( + ConnectedSource, + GradualOnSource, + SeparateSource, + SingleCellSource, + SplitOffSource, + TwoCellsSource, + TwoOverlappingSource, +) -__all__ = [single_cell_source, two_cells_source, two_overlapping_source] +__all__ = [ + "SingleCellSource", + "TwoCellsSource", + "TwoOverlappingSource", + "SeparateSource", + "ConnectedSource", + "GradualOnSource", + "SplitOffSource", +] diff --git a/src/cala/testing/nodes.py b/src/cala/testing/nodes.py index 12431541..a168cede 100644 --- a/src/cala/testing/nodes.py +++ b/src/cala/testing/nodes.py @@ -1,77 +1,219 @@ from collections.abc import Generator from typing import Annotated as A +from typing import Self import numpy as np -from noob import Name +from noob import Name, process_method +from pydantic import BaseModel, PrivateAttr, model_validator from cala.assets import Frame from cala.testing.toy import FrameDims, Position, Toy -def single_cell_source( - n_frames: int = 30, - frame_dims: dict = None, - cell_radii: int = 30, - positions: list[dict] = None, -) -> Generator[A[Frame, Name("frame")]]: - frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims) - traces = [np.array(range(0, n_frames))] - if positions is None: - positions = [Position(width=256, height=256)] - else: - positions = [Position(**position) for position in positions] - - toy = Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=positions, - cell_traces=traces, - ) - return toy.movie_gen() - - -def two_cells_source( - n_frames: int = 30, - frame_dims: dict = None, - cell_radii: int = 30, - positions: list[dict] = None, -) -> Generator[A[Frame, Name("frame")]]: - frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims) - traces = [np.array(range(0, n_frames)), np.array([0, *range(n_frames - 1, 0, -1)])] - if positions is None: - positions = [Position(width=206, height=206), Position(width=306, height=306)] - else: - positions = [Position(**position) for position in positions] - - toy = Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=positions, - cell_traces=traces, - ) - return toy.movie_gen() - - -def two_overlapping_source( - n_frames: int = 30, - frame_dims: dict = None, - cell_radii: int = 30, - positions: list[dict] = None, -) -> Generator[A[Frame, Name("frame")]]: - frame_dims = FrameDims(width=512, height=512) if frame_dims is None else FrameDims(**frame_dims) - traces = [np.array(range(0, n_frames)), np.array([0, *range(n_frames - 1, 0, -1)])] - if positions is None: - positions = [Position(width=236, height=236), Position(width=276, height=276)] - else: - positions = [Position(**position) for position in positions] - - toy = Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=positions, - cell_traces=traces, - ) - return toy.movie_gen() +class MovieSource(BaseModel): + n_frames: int = 50 + frame_dims: FrameDims | dict[str, int] | None = None + cell_radii: int = 30 + positions: list[dict | Position] | None = None + _toy: Toy = PrivateAttr(None) + _traces: list[np.ndarray] = PrivateAttr(None) + + def _build_toy(self) -> Toy: + return Toy( + n_frames=self.n_frames, + frame_dims=self.frame_dims, + cell_radii=self.cell_radii, + cell_positions=self.positions, + cell_traces=self._traces, + ) + + @process_method + def process(self) -> Generator[A[Frame, Name("frame")]]: + yield from self._toy.movie_gen() + + +class SingleCellSource(MovieSource): + @model_validator(mode="after") + def complete_model(self) -> Self: + self.frame_dims = ( + FrameDims(width=512, height=512) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self._traces = [np.array(range(0, self.n_frames), dtype=float)] + + if self.positions is None: + self.positions = [Position(width=256, height=256)] + else: + self.positions = [Position(**position) for position in self.positions] + + self._toy = self._build_toy() + return self + + +class TwoCellsSource(MovieSource): + @model_validator(mode="after") + def complete_model(self) -> Self: + self.frame_dims = ( + FrameDims(width=512, height=512) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + + self._traces = [ + np.array(range(0, self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), + ] + if self.positions is None: + self.positions = [Position(width=206, height=206), Position(width=306, height=306)] + else: + self.positions = [Position(**position) for position in self.positions] + + self._toy = self._build_toy() + return self + + +class TwoOverlappingSource(MovieSource): + @model_validator(mode="after") + def complete_model(self) -> Self: + self.frame_dims = ( + FrameDims(width=512, height=512) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self._traces = [ + np.array(range(0, self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), + ] + + if self.positions is None: + self.positions = [Position(width=236, height=236), Position(width=276, height=276)] + else: + self.positions = [Position(**position) for position in self.positions] + + self._toy = self._build_toy() + return self + + +class SeparateSource(MovieSource): + @model_validator(mode="after") + def complete_model(self) -> Self: + self.cell_radii = 3 + self.frame_dims = ( + FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + ] + self._traces = [ + np.zeros(self.n_frames, dtype=float), + np.ones(self.n_frames, dtype=float), + np.array(range(self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), + ] + + self._toy = self._build_toy() + return self + + +class ConnectedSource(MovieSource): + @model_validator(mode="after") + def complete_model(self) -> Self: + self.cell_radii = 8 + self.frame_dims = ( + FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + ] + self._traces = [ + np.zeros(self.n_frames, dtype=float), + np.ones(self.n_frames, dtype=float), + np.array(range(self.n_frames), dtype=float), + np.array([0, *range(self.n_frames - 1, 0, -1)], dtype=float), + ] + + self._toy = self._build_toy() + return self + + +class GradualOnSource(MovieSource): + @model_validator(mode="after") + def complete_model(self) -> Self: + self.n_frames = 100 + self.cell_radii = 8 + self.frame_dims = ( + FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=15, height=15), + Position(width=15, height=35), + Position(width=25, height=25), + Position(width=35, height=35), + Position(width=35, height=15), + ] + gap = 20 + decr = np.array(range(self.n_frames - 1, 0, -1), dtype=float) + sine = np.abs(np.sin(np.linspace(0, 2 * np.pi, self.n_frames - gap)) * self.n_frames) + incr = np.array(range(self.n_frames - gap * 2), dtype=float) + expo = ( + np.linspace(0, np.exp(3), self.n_frames - gap * 3) + * np.exp(-np.linspace(0, np.exp(2), self.n_frames - gap * 3)) + * self.n_frames + ) + tanh = np.tanh(np.linspace(0, 5, self.n_frames - gap * 4)) * self.n_frames + + self._traces = [ + np.pad(decr, (1, 0), mode="constant", constant_values=0), + np.pad(sine, (gap, 0), mode="constant", constant_values=0), + np.pad(incr, (gap * 2, 0), mode="constant", constant_values=0), + np.pad(expo, (gap * 3, 0), mode="constant", constant_values=0), + np.pad(tanh, (gap * 4, 0), mode="constant", constant_values=0), + ] + + self._toy = self._build_toy() + return self + + +class SplitOffSource(MovieSource): + @model_validator(mode="after") + def complete_model(self) -> Self: + self.n_frames = 100 + self.cell_radii = 8 + self.frame_dims = ( + FrameDims(width=50, height=50) + if self.frame_dims is None + else FrameDims(**self.frame_dims) + ) + self.positions = [ + Position(width=20, height=20), + Position(width=30, height=30), + ] + self._traces = [ + np.array( + [0, *range(1, int(self.n_frames / 2)), *range(int(self.n_frames / 2), 0, -1)], + dtype=float, + ), + np.array( + [ + *range(int(self.n_frames / 4)), + *range(int(self.n_frames / 4), 0, -1), + *range(int(self.n_frames / 2)), + ], + dtype=float, + ), + ] + self._toy = self._build_toy() + return self diff --git a/src/cala/testing/toy.py b/src/cala/testing/toy.py index b5c48d7c..4f65339b 100644 --- a/src/cala/testing/toy.py +++ b/src/cala/testing/toy.py @@ -54,7 +54,7 @@ class Toy(BaseModel): cell_traces: list[np.ndarray] cell_ids: list[str] """If none, auto populated as cell_{idx}.""" - confidences: list[float] + detected_ons: list[int] _footprints: xr.DataArray = PrivateAttr(init=False) _traces: xr.DataArray = PrivateAttr(init=False) @@ -74,9 +74,9 @@ def fill_ids(self) -> Self: return self @model_validator(mode="before") - def fill_confidences(self) -> Self: - if self.get("confidences", None) is None: - self["confidences"] = [0.0] * len(self["cell_positions"]) + def fill_detected_ons(self) -> Self: + if self.get("detected_ons", None) is None: + self["detected_ons"] = [0] * len(self["cell_positions"]) return self @model_validator(mode="before") @@ -124,7 +124,7 @@ def _build_movie_template(self) -> xr.DataArray: ) def _generate_footprint( - self, radius: int, position: Position, id_: str, confidence: float + self, radius: int, position: Position, id_: str, detected_on: int ) -> xr.DataArray: footprint = xr.DataArray( np.zeros((self.frame_dims.height, self.frame_dims.width)), @@ -141,7 +141,7 @@ def _generate_footprint( return footprint.expand_dims(AXIS.component_dim).assign_coords( { AXIS.id_coord: (AXIS.component_dim, [id_]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: (AXIS.component_dim, [detected_on]), **{ax: footprint[ax] for ax in AXIS.spatial_dims}, } ) @@ -149,13 +149,13 @@ def _generate_footprint( def _build_footprints(self) -> xr.DataArray: footprints = [] for radius, position, id_, confid in zip( - self.cell_radii, self.cell_positions, self.cell_ids, self.confidences + self.cell_radii, self.cell_positions, self.cell_ids, self.detected_ons ): footprints.append(self._generate_footprint(radius, position, id_, confid)) return xr.concat(footprints, dim=AXIS.component_dim) - def _format_trace(self, trace: np.ndarray, id_: str, confidence: float) -> xr.DataArray: + def _format_trace(self, trace: np.ndarray, id_: str, detected_on: int) -> xr.DataArray: return ( xr.DataArray( trace, @@ -165,7 +165,7 @@ def _format_trace(self, trace: np.ndarray, id_: str, confidence: float) -> xr.Da .assign_coords( { AXIS.id_coord: (AXIS.component_dim, [id_]), - AXIS.confidence_coord: (AXIS.component_dim, [confidence]), + AXIS.detect_coord: (AXIS.component_dim, [detected_on]), AXIS.frames_dim: range(trace.size), } ) @@ -173,7 +173,7 @@ def _format_trace(self, trace: np.ndarray, id_: str, confidence: float) -> xr.Da def _build_traces(self) -> xr.DataArray: traces = [] - for trace, id_, confid in zip(self.cell_traces, self.cell_ids, self.confidences): + for trace, id_, confid in zip(self.cell_traces, self.cell_ids, self.detected_ons): traces.append(self._format_trace(trace, id_, confid)) return xr.concat(traces, dim=AXIS.component_dim).assign_coords( @@ -198,12 +198,12 @@ def make_movie(self) -> Movie: return Movie.from_array(movie) def add_cell( - self, position: Position, radius: int, trace: np.ndarray, id_: str, confidence: float = 0.0 + self, position: Position, radius: int, trace: np.ndarray, id_: str, detected_on: int = 0 ) -> None: - new_footprint = self._generate_footprint(radius, position, id_, confidence) + new_footprint = self._generate_footprint(radius, position, id_, detected_on) self._footprints = xr.concat([self._footprints, new_footprint], dim=AXIS.component_dim) - new_trace = self._format_trace(trace, id_, confidence) + new_trace = self._format_trace(trace, id_, detected_on) self._traces = xr.concat([self._traces, new_trace], dim=AXIS.component_dim) def drop_cell(self, id_: str | Iterable[str]) -> None: diff --git a/src/cala/testing/util.py b/src/cala/testing/util.py index ac37b24e..8801343c 100644 --- a/src/cala/testing/util.py +++ b/src/cala/testing/util.py @@ -1,3 +1,5 @@ +import cv2 +import numpy as np import xarray as xr @@ -11,3 +13,20 @@ def assert_scalar_multiple_arrays(a: xr.DataArray, b: xr.DataArray, /, rtol: flo aabb = a.dot(a) * b.dot(b) assert abab > aabb * (1 - rtol) + + +def generate_text_image( + text: str, + frame_dims: tuple[int, int] = (256, 256), + org: tuple[int, int] = None, + color: tuple[int, int, int] = (255, 255, 255), + thickness: int = 2, + font_scale: int = 1, +) -> np.ndarray: + image = np.zeros(frame_dims, np.uint8) + font = cv2.FONT_HERSHEY_SIMPLEX + + if org is None: + org = (frame_dims[0] // 2, frame_dims[1] // 2) + + return cv2.putText(image, text, org, font, font_scale, color, thickness, cv2.LINE_AA) diff --git a/src/cala/util.py b/src/cala/util.py index fa6883ca..32be5074 100644 --- a/src/cala/util.py +++ b/src/cala/util.py @@ -1,16 +1,24 @@ -from collections.abc import Sequence +from collections.abc import Generator, Sequence from datetime import datetime +from itertools import count +from typing import Annotated as A from uuid import uuid4 import numpy as np import xarray as xr +from noob import Name +from cala.assets import Frame from cala.models import AXIS -def package_frame( - frame: np.ndarray, index: int, timestamp: datetime | str | None = None -) -> xr.DataArray: +def counter(start: int = 0, limit: int = 1e7) -> A[Generator[int], Name("idx")]: + cnt = count(start=start) + while (val := next(cnt)) < limit: + yield val + + +def package_frame(frame: np.ndarray, index: int, timestamp: datetime | str | None = None) -> Frame: """Transform a 2D numpy frame into an xarray DataArray. Args: @@ -37,12 +45,13 @@ def package_frame( name="frame", ) - return frame.assign_coords( + da = frame.assign_coords( { AXIS.width_dim: range(frame.sizes[AXIS.width_dim]), AXIS.height_dim: range(frame.sizes[AXIS.height_dim]), } ) + return Frame.from_array(da.astype(float)) def create_id() -> str: diff --git a/tests/data/pipelines/odl.yaml b/tests/data/pipelines/odl.yaml index d84f0301..569a2d51 100644 --- a/tests/data/pipelines/odl.yaml +++ b/tests/data/pipelines/odl.yaml @@ -25,78 +25,98 @@ assets: overlaps: type: cala.assets.Overlaps scope: session - -# Add refiners in nodes + residuals: + type: cala.assets.Residual + scope: session nodes: source: - type: cala.testing.single_cell_source - params: - n_frames: 30 + type: cala.testing.SingleCellSource denoise: type: cala.nodes.prep.denoise params: - ksize: - - 3 - - 3 - sigmaX: 1.5 + method: gaussian + kwargs: + ksize: [3, 3] + sigmaX: 1.5 depends: - frame: source.frame glow: type: cala.nodes.prep.GlowRemover depends: - frame: denoise.frame - motion: - type: cala.nodes.prep.RigidStabilizer - params: - drift_speed: 1.0 - depends: - - frame: glow.frame +# motion: +# type: cala.nodes.prep.RigidStabilizer +# params: +# drift_speed: 0.5 +# depends: +# - frame: glow.frame size_est: type: cala.nodes.prep.SizeEst params: - hardset_radius: 10 + log_kwargs: + min_sigma: 3 + max_sigma: 10 + num_sigma: 10 + threshold: 0.2 + overlap: 0.5 depends: - - frame: motion.frame + - frame: glow.frame cache: type: cala.nodes.buffer.fill_buffer params: size: 100 depends: - buffer: assets.buffer - - frame: motion.frame + - frame: glow.frame trace_frame: type: cala.nodes.traces.FrameUpdate params: - tolerance: 0.001 + tol: 0.001 + max_iter: 100 depends: - traces: assets.traces - footprints: assets.footprints - - frame: motion.frame + - frame: glow.frame - overlaps: assets.overlaps pix_frame: type: cala.nodes.pixel_stats.ingest_frame depends: - pixel_stats: assets.pix_stats - - frame: motion.frame + - frame: glow.frame - new_traces: trace_frame.latest_trace comp_frame: type: cala.nodes.component_stats.ingest_frame depends: - component_stats: assets.comp_stats - - frame: motion.frame + - frame: glow.frame - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.footprints.Footprinter + params: + bep: 1 + tol: 0.0001 + max_iter: 100 + depends: + - footprints: assets.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value + facetune: + type: cala.nodes.cleanup.clear_overestimates + params: + nmf_error: 1.0 + depends: + - footprints: footprints_frame.footprints + - residuals: assets.residuals residual: type: cala.nodes.residual.build - params: - clip_threshold: 0.001 depends: - - trigger: trace_frame.latest_trace - frames: assets.buffer - - footprints: assets.footprints + - footprints: footprints_frame.footprints - traces: assets.traces + - residuals: assets.residuals cleanup: type: cala.nodes.cleanup.purge_razed_components params: @@ -115,6 +135,7 @@ nodes: params: min_frames: 10 detect_thresh: 1.0 + reprod_tol: 0.001 depends: - residuals: residual.movie - detect_radius: size_est.radius @@ -153,23 +174,13 @@ nodes: - new_traces: catalog.new_traces # DETECT ENDS - footprints_frame: - type: cala.nodes.footprints.Footprinter - params: - bep: 1 - tol: 0.0000001 - depends: - - footprints: assets.footprints - - pixel_stats: pix_component.value - - component_stats: comp_component.value - overlaps_update: type: cala.nodes.overlap.initialize depends: - overlaps: assets.overlaps - - footprints: footprints_frame.footprints + - footprints: footprint_component.footprints return: type: return depends: - - motion.frame \ No newline at end of file + - glow.frame \ No newline at end of file diff --git a/tests/data/pipelines/with_src.yaml b/tests/data/pipelines/with_src.yaml new file mode 100644 index 00000000..3cdb055a --- /dev/null +++ b/tests/data/pipelines/with_src.yaml @@ -0,0 +1,225 @@ +noob_id: cala-io +noob_model: noob.tube.TubeSpecification +noob_version: 0.1.1.dev118+g64d81b7 + +assets: + buffer: + type: cala.assets.Movie + scope: session + depends: + - cache.buffer + footprints: + type: cala.assets.Footprints + scope: session + traces: + type: cala.assets.Traces + scope: session + pix_stats: + type: cala.assets.PixStats + scope: session + comp_stats: + type: cala.assets.CompStats + scope: session + overlaps: + type: cala.assets.Overlaps + scope: session + residuals: + type: cala.assets.Residual + scope: session + + +nodes: + source: + type: cala.nodes.io.stream + params: + files: + - tests/data/movies/msCam1.avi + counter: + type: cala.util.counter + frame: + type: cala.util.package_frame + depends: + - frame: source.value + - index: counter.idx + + #PREPROCESS BEGINS + saltpepper: + type: cala.nodes.prep.denoise + params: + method: median + kwargs: + ksize: 3 + depends: + - frame: frame.value + denoise: + type: cala.nodes.prep.denoise + params: + method: nonlocal + kwargs: + h: 4 + templateWindowSize: 7 + searchWindowSize: 21 + depends: + - frame: saltpepper.frame + lines: + type: cala.nodes.prep.hlines.remove + depends: + - frame: glow.frame + glow: + type: cala.nodes.prep.GlowRemover + depends: + - frame: denoise.frame + smooth: + type: cala.nodes.prep.denoise + params: + method: gaussian + kwargs: + ksize: [ 7, 7 ] + sigmaX: 1.5 + depends: + - frame: lines.frame + motion: + type: cala.nodes.prep.RigidStabilizer + params: + drift_speed: 0.5 + depends: + - frame: smooth.frame + size_est: + type: cala.nodes.prep.SizeEst + params: + noise_threshold: 2.0 + n_frames: 30 + log_kwargs: + min_sigma: 3 + max_sigma: 10 + num_sigma: 10 + threshold: 0.2 + overlap: 0.5 + depends: + - frame: motion.frame + cache: + type: cala.nodes.buffer.fill_buffer + params: + size: 100 + depends: + - buffer: assets.buffer + - frame: motion.frame + #PREPROCESS ENDS + + # FRAME UPDATE BEGINS + trace_frame: + type: cala.nodes.traces.FrameUpdate + params: + tol: 0.001 + max_iter: 100 + depends: + - traces: assets.traces + - footprints: assets.footprints + - frame: motion.frame + - overlaps: assets.overlaps + pix_frame: + type: cala.nodes.pixel_stats.ingest_frame + depends: + - pixel_stats: assets.pix_stats + - frame: motion.frame + - new_traces: trace_frame.latest_trace + comp_frame: + type: cala.nodes.component_stats.ingest_frame + depends: + - component_stats: assets.comp_stats + - frame: motion.frame + - new_traces: trace_frame.latest_trace + footprints_frame: + type: cala.nodes.footprints.Footprinter + params: + bep: 1 + tol: 0.0001 + max_iter: 100 + depends: + - footprints: assets.footprints + - pixel_stats: pix_frame.value + - component_stats: comp_frame.value + facetune: + type: cala.nodes.cleanup.clear_overestimates + params: + nmf_error: 1.0 + depends: + - footprints: footprints_frame.footprints + - residuals: assets.residuals + + residual: + type: cala.nodes.residual.build + depends: + - frames: assets.buffer + - footprints: footprints_frame.footprints + - traces: assets.traces + - residuals: assets.residuals + cleanup: + type: cala.nodes.cleanup.purge_razed_components + params: + min_thicc: 3 + depends: + - footprints: assets.footprints + - traces: assets.traces + - pix_stats: assets.pix_stats + - comp_stats: assets.comp_stats + - overlaps: assets.overlaps + - trigger: residual.movie + # FRAME UPDATE ENDS + + # DETECT BEGINS + nmf: + type: cala.nodes.detect.SliceNMF + params: + min_frames: 30 + detect_thresh: 2.0 + reprod_tol: 0.005 + depends: + - residuals: residual.movie + - detect_radius: size_est.radius + catalog: + type: cala.nodes.detect.Cataloger + params: + merge_threshold: 0.8 + depends: + - new_fps: nmf.new_fps + - new_trs: nmf.new_trs + - existing_fp: assets.footprints + - existing_tr: assets.traces + + trace_component: + type: cala.nodes.traces.ingest_component + depends: + - traces: assets.traces + - new_traces: catalog.new_traces + footprint_component: + type: cala.nodes.footprints.ingest_component + depends: + - footprints: assets.footprints + - new_footprints: catalog.new_footprints + pix_component: + type: cala.nodes.pixel_stats.ingest_component + depends: + - pixel_stats: assets.pix_stats + - frames: assets.buffer + - new_traces: catalog.new_traces + - traces: assets.traces + comp_component: + type: cala.nodes.component_stats.ingest_component + depends: + - component_stats: assets.comp_stats + - traces: assets.traces + - new_traces: catalog.new_traces + + overlaps_update: + type: cala.nodes.overlap.initialize + depends: + - overlaps: assets.overlaps + - footprints: footprint_component.footprints + # DETECT ENDS + + return: + type: return + depends: + - raw: frame.value + - prep: motion.frame \ No newline at end of file diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index e08267e8..ae2fb313 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1,4 +1,5 @@ from .config import ( + cwd_to_pytest_base, set_config, set_dotenv, set_env, @@ -9,7 +10,7 @@ yaml_config, ) from .meta import monkeypatch_session -from .sims import connected_cells, separate_cells +from .toys import connected_cells, separate_cells, single_cell __all__ = [ "monkeypatch_session", @@ -21,6 +22,8 @@ "tmp_config_source", "tmp_cwd", "yaml_config", + "cwd_to_pytest_base", + "single_cell", "separate_cells", "connected_cells", ] diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index 792a6be9..47c9361f 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -1,4 +1,5 @@ -from collections.abc import Callable, MutableMapping +import os +from collections.abc import Callable, Generator, MutableMapping from pathlib import Path from typing import Any @@ -156,3 +157,10 @@ def _flatten(d: MutableMapping, parent_key: str = "", separator: str = "__") -> else: items.append((new_key, value)) return dict(items) + + +@pytest.fixture +def cwd_to_pytest_base(request: pytest.FixtureRequest) -> Generator[None, Any, None]: + os.chdir(request.config.rootdir) + yield + os.chdir(request.config.invocation_params.dir) diff --git a/tests/fixtures/sims.py b/tests/fixtures/sims.py deleted file mode 100644 index ca6c0731..00000000 --- a/tests/fixtures/sims.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np -import pytest - -from cala.testing.toy import FrameDims, Position, Toy - - -@pytest.fixture -def separate_cells() -> Toy: - n_frames = 50 - - return Toy( - n_frames=n_frames, - frame_dims=FrameDims(width=50, height=50), - cell_radii=3, - cell_positions=[ - Position(width=15, height=15), - Position(width=15, height=35), - Position(width=25, height=25), - Position(width=35, height=35), - ], - cell_traces=[ - np.zeros(n_frames, dtype=float), - np.ones(n_frames, dtype=float), - np.array(range(n_frames), dtype=float), - np.array(range(n_frames - 1, -1, -1), dtype=float), - ], - ) - - -@pytest.fixture -def connected_cells() -> Toy: - n_frames = 50 - - return Toy( - n_frames=n_frames, - frame_dims=FrameDims(width=50, height=50), - cell_radii=8, - cell_positions=[ - Position(width=15, height=15), - Position(width=15, height=35), - Position(width=25, height=25), - Position(width=35, height=35), - ], - cell_traces=[ - np.zeros(n_frames, dtype=float), - np.ones(n_frames, dtype=float), - np.array(range(n_frames), dtype=float), - np.array(range(n_frames - 1, -1, -1), dtype=float), - ], - ) diff --git a/tests/fixtures/toys.py b/tests/fixtures/toys.py new file mode 100644 index 00000000..fd9a8507 --- /dev/null +++ b/tests/fixtures/toys.py @@ -0,0 +1,22 @@ +import pytest + +from cala.testing import ConnectedSource, SeparateSource, SingleCellSource +from cala.testing.toy import Toy + + +@pytest.fixture +def single_cell() -> Toy: + source = SingleCellSource() + return source._toy + + +@pytest.fixture +def separate_cells() -> Toy: + source = SeparateSource() + return source._toy + + +@pytest.fixture +def connected_cells() -> Toy: + source = ConnectedSource() + return source._toy diff --git a/tests/test_io.py b/tests/test_io.py index 5a9be4ce..8c042b3f 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -7,20 +7,7 @@ from skimage import io from cala.nodes.io import stream - - -def generate_text_image( - text: str, - frame_dims: tuple[int, int] = (256, 256), - org: tuple[int, int] = (50, 50), - color: tuple[int, int, int] = (255, 255, 255), - thickness: int = 2, - font_scale: int = 1, -) -> np.ndarray: - image = np.zeros(frame_dims, np.uint8) - font = cv2.FONT_HERSHEY_SIMPLEX - - return cv2.putText(image, text, org, font, font_scale, color, thickness, cv2.LINE_AA) +from cala.testing.util import generate_text_image def save_tiff(filename: Path, frame: np.ndarray) -> None: @@ -50,7 +37,7 @@ def test_tiff_stream(tmp_path): save_tiff(tmp_path / f"{i}.tif", image) media = sorted(glob(str(tmp_path / "*.tif"))) - s = iter(stream(media)) + s = stream(media) for idx, res in enumerate(s): np.testing.assert_array_equal(res, generate_text_image(str(idx))) @@ -65,7 +52,7 @@ def test_video_stream(tmp_path): save_movie(tmp_path / "video.mp4", video) media = sorted(glob(str(tmp_path / "*.mp4"))) - s = iter(stream(media)) + s = stream(media) for idx, res in enumerate(s): np.testing.assert_allclose( diff --git a/tests/test_iter/test_cleanup.py b/tests/test_iter/test_cleanup.py new file mode 100644 index 00000000..6281b449 --- /dev/null +++ b/tests/test_iter/test_cleanup.py @@ -0,0 +1,16 @@ +from cala.assets import Residual +from cala.models import AXIS +from cala.nodes.cleanup import clear_overestimates + + +def test_clear_overestimates(single_cell) -> None: + residual = Residual.from_array(single_cell.make_movie().array) + residual.array.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] *= -1 + + result = clear_overestimates( + footprints=single_cell.footprints, residuals=residual, nmf_error=-1.0 + ) + expected = single_cell.footprints.array.copy() + expected.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] = 0 + + assert result.equals(expected) diff --git a/tests/test_iter/test_detect.py b/tests/test_iter/test_detect.py index 93a00e7e..db6ed22d 100644 --- a/tests/test_iter/test_detect.py +++ b/tests/test_iter/test_detect.py @@ -5,36 +5,16 @@ from cala.assets import AXIS, Footprints, Residual, Traces from cala.nodes.detect import Cataloger, SliceNMF -from cala.testing.toy import FrameDims, Position, Toy from cala.testing.util import assert_scalar_multiple_arrays -@pytest.fixture(autouse=True, scope="module") -def toy(): - n_frames = 30 - frame_dims = FrameDims(width=512, height=512) - cell_positions = [Position(width=256, height=256)] - cell_radii = 30 - cell_traces = [np.array(range(n_frames), dtype=float)] - confidences = [0.8] - - return Toy( - n_frames=n_frames, - frame_dims=frame_dims, - cell_radii=cell_radii, - cell_positions=cell_positions, - cell_traces=cell_traces, - confidences=confidences, - ) - - @pytest.fixture(scope="class") def slice_nmf(): return SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", type="cala.nodes.detect.SliceNMF", - params={"min_frames": 10, "detect_thresh": 1}, + params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.00001}, ) ) @@ -49,47 +29,51 @@ def cataloger(): class TestSliceNMF: - def test_process(self, slice_nmf, toy): + def test_process(self, slice_nmf, single_cell): new_component = slice_nmf.process( - Residual.from_array(toy.make_movie().array), - detect_radius=toy.cell_radii[0] * 2, + Residual.from_array(single_cell.make_movie().array), + detect_radius=single_cell.cell_radii[0] * 2, ) if new_component: new_fp, new_tr = new_component else: raise AssertionError("Failed to detect a new component") - for new, old in zip([new_fp[0], new_tr[0]], [toy.footprints, toy.traces]): + for new, old in zip([new_fp[0], new_tr[0]], [single_cell.footprints, single_cell.traces]): assert_scalar_multiple_arrays(new.array, old.array) - def test_chunks(self, toy): + def test_chunks(self, single_cell): nmf = SliceNMF.from_specification( spec=NodeSpecification( id="test_slice_nmf", type="cala.nodes.detect.SliceNMF", - params={"min_frames": 10, "detect_thresh": 1}, + params={"min_frames": 10, "detect_thresh": 1, "reprod_tol": 0.001}, ) ) - fpts, trcs = nmf.process(Residual.from_array(toy.make_movie().array), detect_radius=10) + fpts, trcs = nmf.process( + Residual.from_array(single_cell.make_movie().array), detect_radius=10 + ) if not fpts or not trcs: raise AssertionError("Failed to detect a new component") fpt_arr = xr.concat([f.array for f in fpts], dim="component") - expected = toy.footprints.array[0] + expected = single_cell.footprints.array[0] result = (fpt_arr.sum(dim="component") > 0).astype(int) assert np.array_equal(expected, result) for trc in trcs: - assert_scalar_multiple_arrays(trc.array, toy.traces.array[0]) + assert_scalar_multiple_arrays(trc.array, single_cell.traces.array[0]) class TestCataloger: @pytest.fixture(scope="function") - def new_component(self, slice_nmf, toy): - return slice_nmf.process(Residual.from_array(toy.make_movie().array), detect_radius=60) + def new_component(self, slice_nmf, single_cell): + return slice_nmf.process( + Residual.from_array(single_cell.make_movie().array), detect_radius=60 + ) - def test_register(self, cataloger, new_component, toy): + def test_register(self, cataloger, new_component): new_fp, new_tr = new_component fp, tr = cataloger._register( new_fp=new_fp[0].array, @@ -99,22 +83,28 @@ def test_register(self, cataloger, new_component, toy): assert np.array_equal(fp.array, new_fp[0].array) assert np.array_equal(tr.array, new_tr[0].array) - def test_merge_with(self, slice_nmf, cataloger, toy): + def test_merge_with(self, slice_nmf, cataloger, single_cell): new_component = slice_nmf.process( - Residual.from_array(toy.make_movie().array), detect_radius=10 + Residual.from_array(single_cell.make_movie().array), detect_radius=10 ) new_fp, new_tr = new_component fp, tr = cataloger._merge_with( - new_fp[0].array, new_tr[0].array, toy.footprints.array, toy.traces.array, ["cell_0"] + new_fp[0].array, + new_tr[0].array, + single_cell.footprints.array, + single_cell.traces.array, + ["cell_0"], ) movie_result = (fp.array @ tr.array).reset_coords( - [AXIS.id_coord, AXIS.confidence_coord], drop=True + [AXIS.id_coord, AXIS.detect_coord], drop=True ) movie_new_comp = new_fp[0].array @ new_tr[0].array - movie_expected = (toy.make_movie().array + movie_new_comp).transpose(*movie_result.dims) + movie_expected = (single_cell.make_movie().array + movie_new_comp).transpose( + *movie_result.dims + ) xr.testing.assert_allclose(movie_result, movie_expected) @@ -167,11 +157,11 @@ def test_process_connected(self, slice_nmf, cataloger, connected_cells): # we're forcing a double-detection in this test new_fps, new_trs = cataloger.process(fps, trs, Footprints(), Traces()) - result = (new_fps.array @ new_trs.array).where(new_fps.array.max(dim=AXIS.component_dim), 0) - expected = movie.where(new_fps.array.max(dim=AXIS.component_dim), 0) + result = new_fps.array @ new_trs.array + expected = movie * (new_fps.array.max(dim=AXIS.component_dim) > 1e-3) assert new_fps.array is not None # 1. the footprints do not overlap assert np.all(np.triu(new_fps.array @ new_fps.array.rename(AXIS.component_rename), 1) == 0) # 2. the trace and footprint values are accurate (where they do exist) - xr.testing.assert_allclose(result, expected.transpose(*result.dims), atol=1e-5) + xr.testing.assert_allclose(result, expected.transpose(*result.dims), atol=1e-3) diff --git a/tests/test_iter/test_footprints.py b/tests/test_iter/test_footprints.py index 0fe9e717..14e85359 100644 --- a/tests/test_iter/test_footprints.py +++ b/tests/test_iter/test_footprints.py @@ -81,7 +81,7 @@ def test_ingest_frame(fpter, toy, request): footprints=toy.footprints, pixel_stats=pixstats, component_stats=compstats ) - expected = toy.footprints.copy() + expected = toy.footprints.model_copy() xr.testing.assert_allclose(result.array, expected.array) diff --git a/tests/test_iter/test_residual.py b/tests/test_iter/test_residual.py index 890bb697..a136ce2d 100644 --- a/tests/test_iter/test_residual.py +++ b/tests/test_iter/test_residual.py @@ -1,11 +1,11 @@ import numpy as np import pytest +import xarray as xr from noob.node import Node, NodeSpecification from cala.assets import Residual from cala.models.axis import AXIS -from cala.nodes.residual import _clear_overestimates -from cala.testing.toy import FrameDims, Position, Toy +from cala.nodes.residual import _align_overestimates, _find_unlayered_footprints @pytest.fixture(scope="function") @@ -17,34 +17,42 @@ def init() -> Node: def test_init(init, separate_cells) -> None: result = init.process( + residuals=Residual(), footprints=separate_cells.footprints, traces=separate_cells.traces, frames=separate_cells.make_movie(), - trigger=True, ) assert np.all(result.array == 0) -@pytest.fixture -def one_cell() -> Toy: - n_frames = 50 +def test_align_overestimates(single_cell) -> None: + """ + grab the last frame of the residual. assume part of the footprint masked area is negative + traces needs to proportionally decrease, until the recalculated residual is zero - return Toy( - n_frames=n_frames, - frame_dims=FrameDims(width=50, height=50), - cell_radii=3, - cell_positions=[Position(width=25, height=25)], - cell_traces=[np.array(range(n_frames), dtype=float)], - ) + Eventually, this probably can be absorbed straight into trace frame_ingest as a constraint. + """ + movie = single_cell.make_movie() + last_frame = movie.array.isel({AXIS.frames_dim: -1}) + + last_res = xr.zeros_like(last_frame) + last_res.loc[{AXIS.width_coord: slice(single_cell.cell_positions[0].width, None)}] = -1 + last_res = last_res.where(single_cell.footprints.array[0].values, 0) + + last_trace = single_cell.traces.array.isel({AXIS.frames_dim: -1}) + + footprints = single_cell.footprints.array + + adjusted_traces = _align_overestimates(A=footprints, R_latest=last_res, C_latest=last_trace) + result = (footprints @ adjusted_traces).values + expected = movie.array.isel({AXIS.frames_dim: -2}).values -def test_clear_overestimates(one_cell) -> None: - residual = Residual.from_array(one_cell.make_movie().array) - residual.array.loc[{AXIS.width_coord: slice(one_cell.cell_positions[0].width, None)}] *= -1 + np.testing.assert_array_equal(result, expected) - result = _clear_overestimates(A=one_cell.footprints.array, R=residual.array, clip_val=-1.0) - expected = one_cell.footprints.array.copy() - expected.loc[{AXIS.width_coord: slice(one_cell.cell_positions[0].width, None)}] = 0 - assert result.equals(expected) +def test_find_exposed_footprints(connected_cells) -> None: + footprints = connected_cells.footprints + result = _find_unlayered_footprints(footprints.array) + assert result.sum(dim=AXIS.component_dim).max().item() == footprints.array.max().item() diff --git a/tests/test_iter/test_traces.py b/tests/test_iter/test_traces.py index 4e6122a9..b65105f9 100644 --- a/tests/test_iter/test_traces.py +++ b/tests/test_iter/test_traces.py @@ -26,7 +26,7 @@ def test_init(init, toy, request) -> None: def frame_update() -> Node: return Node.from_specification( spec=NodeSpecification( - id="frame_test", type="cala.nodes.traces.FrameUpdate", params={"tolerance": 1e-3} + id="frame_test", type="cala.nodes.traces.FrameUpdate", params={"tol": 1e-3} ) ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1014d9e2..66ebfc9e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import xarray as xr from noob import Cube, SynchronousRunner, Tube @@ -6,16 +7,26 @@ from cala.models import AXIS -@pytest.fixture(params=["single_cell_source", "two_cells_source", "two_overlapping_source"]) -def tube(request): - tube = Tube.from_specification("cala-odl") - source = Node.from_specification( - NodeSpecification( - id="source", type=f"cala.testing.{request.param}", params={"n_frames": 50} - ) +@pytest.fixture( + params=[ + "SingleCellSource", + "TwoCellsSource", + "SeparateSource", + "TwoOverlappingSource", + "GradualOnSource", + "SplitOffSource", + ] +) +def source(request): + return Node.from_specification( + NodeSpecification(id="source", type=f"cala.testing.{request.param}") ) - tube.nodes["source"] = source + +@pytest.fixture +def tube(source): + tube = Tube.from_specification("cala-odl") + tube.nodes["source"] = source return tube @@ -37,40 +48,43 @@ def test_process(runner) -> None: assert runner.cube.assets["buffer"].obj.array.size > 0 -def test_iter(runner) -> None: - gen = runner.iter(n=runner.tube.nodes["source"].spec.params["n_frames"]) +@pytest.mark.xfail(raises=NotImplementedError) +def test_odl(runner, source) -> None: + gen = runner.iter(n=source.instance.n_frames) + src_name = source.spec.type_.split(".")[-1] + toy = source.instance._toy.model_copy() - movie = [] - for _, exp in enumerate(gen): - movie.append(exp[0].array) + preprocessed_frames = [] + for fr in gen: + preprocessed_frames.append(fr[0].array) fps = runner.cube.assets["footprints"].obj trs = runner.cube.assets["traces"].obj - expected = xr.concat(movie, dim=AXIS.frames_dim) - result = (fps.array @ trs.array).transpose(*expected.dims) - - if runner.tube.nodes["source"].fn.__name__ == "two_overlapping_source": - diff = expected - result - for d_fr, e_fr in zip(diff, expected): - assert d_fr.max() <= e_fr.quantile(0.98) * 2e-2 - else: - xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) - - -@pytest.mark.xfail -def test_run(runner) -> None: - result = runner.run(n=5) - - assert result - + # Correct component count + if src_name not in ["SeparateSource", "SplitOffSource"]: + assert toy.traces.array.sizes[AXIS.component_dim] == trs.array.sizes[AXIS.component_dim] + elif src_name == "SeparateSource": + # 2 is the # of discoverable cells (non-constant) for SeparateSource + assert trs.array.sizes[AXIS.component_dim] == 2 + elif src_name == "SplitOffSource": + # 3 because one should be deprecated + assert trs.array.sizes[AXIS.component_dim] == 3 + + if src_name in ["TwoOverlappingSource", "GradualOnSource"]: + # Traces are reasonably similar + tr_corr = xr.corr( + toy.traces.array, trs.array.rename(AXIS.component_rename), dim=AXIS.frame_coord + ) + for corr in tr_corr: + assert np.isclose(corr.max(), 1, atol=1e-2) -@pytest.mark.xfail -def test_combined_footprint() -> None: - """Start with two footprints combined""" - raise AssertionError("Not implemented") + elif src_name in ["SingleCellSource", "TwoCellsSource", "SeparateSource"]: + expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) + result = (fps.array @ trs.array).transpose(*expected.dims) + xr.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) -@pytest.mark.xfail -def test_redundant_footprint() -> None: - """start with redundant footprints""" - raise AssertionError("Not implemented") + elif src_name == "SplitOffSource": + expected = xr.concat(preprocessed_frames, dim=AXIS.frame_coord) + result = (fps.array @ trs.array).transpose(*expected.dims) + raise NotImplementedError("Deprecation not implemented") diff --git a/tests/test_prep/test_denoise.py b/tests/test_prep/test_denoise.py index 116e9718..87a7c503 100644 --- a/tests/test_prep/test_denoise.py +++ b/tests/test_prep/test_denoise.py @@ -46,7 +46,7 @@ def test_denoise( results = [] for frame in iter(gen): - results.append(denoise(frame=frame, method=method, **params)) + results.append(denoise(frame=frame, method=method, kwargs=params)) for exp, res in zip(expected, results): np.testing.assert_allclose(exp.values, res.array.values) diff --git a/tests/test_prep/test_hlines.py b/tests/test_prep/test_hlines.py new file mode 100644 index 00000000..eacb7aa8 --- /dev/null +++ b/tests/test_prep/test_hlines.py @@ -0,0 +1,23 @@ +import numpy as np +from skimage.metrics import structural_similarity + +from cala.nodes.prep.hlines import remove +from cala.testing.util import generate_text_image +from cala.util import package_frame + + +def test_remove_lines(): + img = generate_text_image( + "8", frame_dims=(256, 256), org=(25, 230), thickness=20, font_scale=10 + ) + + noise_amp = 40 + noise = np.tile(np.random.randint(0, noise_amp, img.shape[0]), (img.shape[1], 1)).T + + noisy_img = img // 1.5 + noise + + frame = package_frame(noisy_img, 0) + + result = remove(frame) + + assert structural_similarity(img.astype(int), result.array.values.astype(int)) == 1 diff --git a/tests/test_prep/test_r_estimate.py b/tests/test_prep/test_r_estimate.py index b254dfd5..9461bff1 100644 --- a/tests/test_prep/test_r_estimate.py +++ b/tests/test_prep/test_r_estimate.py @@ -1,4 +1,3 @@ -from cala.assets import Frame from cala.models import AXIS from cala.nodes.prep.r_estimate import SizeEst from cala.testing.toy import Position @@ -18,7 +17,7 @@ def test_size_estim(separate_cells): max_proj = package_frame( separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=1 ) - result = node.get_median_radius(Frame.from_array(max_proj)) + result = node.get_median_radius(max_proj) expected = separate_cells.cell_radii[0] - 1 assert result == expected @@ -27,9 +26,9 @@ def test_size_estim(separate_cells): max_proj = package_frame( separate_cells.make_movie().array.max(dim=AXIS.frames_dim).values, index=3 ) - result = node.get_median_radius(Frame.from_array(max_proj)) + result = node.get_median_radius(max_proj) - assert result == expected + assert result == expected // 2 + 1 assert len(node.sizes_) == 3 for center in node.centers_: diff --git a/tests/test_util.py b/tests/test_util.py index 0ab6a9b9..7791e3c8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,7 +2,6 @@ import numpy as np -from cala.assets import Frame from cala.util import package_frame @@ -13,6 +12,6 @@ def test_package_frame(): timestamp = datetime(2023, 4, 8, 12, 0, 0) # Transform the frame - dataarray = package_frame(frame, index, timestamp) + result = package_frame(frame, index, timestamp) - assert Frame.from_array(dataarray) + assert np.array_equal(result.array.values, frame)