Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
5371279
tests: fixture cleanup
raymondwjang Aug 18, 2025
5f0ef80
env: update noob
raymondwjang Aug 18, 2025
162ab9a
debug: small errors
raymondwjang Aug 18, 2025
cfebd1a
tests: add cell size est
raymondwjang Aug 18, 2025
ac24b74
tests: try gradual on test
raymondwjang Aug 19, 2025
69c0aa6
feat: swap confident score with detected_on frame idx
raymondwjang Aug 19, 2025
ab17b39
feat: max_iter implementation for footprints and traces
raymondwjang Aug 19, 2025
57d2a5c
feat: catalog merge with "touching" components too
raymondwjang Aug 19, 2025
902a9e7
feat: gradual-on source has more realistic, smooth sources
raymondwjang Aug 19, 2025
4226a4c
debug: energy is frobenuis-normed instead of summing V -> energy refl…
raymondwjang Aug 19, 2025
1aecde5
test: omit movie stabil temporarily its causing issues
raymondwjang Aug 19, 2025
6673f9a
format: ruff
raymondwjang Aug 19, 2025
bb4a866
tests: refit tests
raymondwjang Aug 19, 2025
0ebbc07
tests: traces are floats!!
raymondwjang Aug 19, 2025
910f182
BREAKING: footprint frame ingestion before residual
raymondwjang Aug 19, 2025
a700d1a
tests: improve pipeline test parametrization
raymondwjang Aug 19, 2025
8603dc1
tests: add split off test
raymondwjang Aug 20, 2025
0469c80
feat: asset validation - no extra coordinates for certain assets
raymondwjang Aug 20, 2025
a2f9131
debug: remove extra coords in R_min
raymondwjang Aug 20, 2025
b3d2166
feat: reproduction tolerance param for nmf
raymondwjang Aug 21, 2025
10a7eb5
feat: footprint mask tuning post footprint frame ingestion
raymondwjang Aug 21, 2025
274e9a8
tests: longer splitoff test
raymondwjang Aug 21, 2025
efa89fd
feat: overlap trace adjustment instead footprint purge
raymondwjang Aug 21, 2025
5abcc84
tests: add splitoff test
raymondwjang Aug 21, 2025
145758a
feat: smaller spotlight size. we can merge more easily
raymondwjang Aug 21, 2025
efc8d6b
debug: stream outputs a generator, not an iterator
raymondwjang Aug 22, 2025
16bd169
feat: add counter for frame idx
raymondwjang Aug 22, 2025
7d00b74
feat: add nonlocal method to denoise
raymondwjang Aug 22, 2025
a11dec3
feat: add source movie using pipe
raymondwjang Aug 22, 2025
becbeda
debug: fix parameters in denoise with odl.yaml
raymondwjang Aug 22, 2025
ae0db27
feat: blob size detection ignores noise
raymondwjang Aug 26, 2025
bb57e18
feat: use l0 norm to leave error in residuals. this is to leave room …
raymondwjang Aug 26, 2025
ed72075
tests: generate_text_image comes out to util module
raymondwjang Aug 26, 2025
529640b
feat: hline removal in prep
raymondwjang Aug 26, 2025
a973d28
feat: add hline removal, gaussian smoothing, increase reprod-toleranc…
raymondwjang Aug 26, 2025
599f561
debug: clip trace after addition
raymondwjang Aug 26, 2025
ba02168
docs: simplify pr template
raymondwjang Aug 26, 2025
ef03b46
docs: reduce r_estimate
raymondwjang Aug 27, 2025
023f3bb
tests: add cwd_to_pytest_base
raymondwjang Aug 27, 2025
0da8085
tests: update denoise test
raymondwjang Aug 27, 2025
a1937a6
tests: update package_frame
raymondwjang Aug 27, 2025
850f3df
tests: xfail for splitoff while deprecation is in the works
raymondwjang Aug 27, 2025
dad862b
format: ruff
raymondwjang Aug 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ Describe the tests that you ran to verify your changes.
Provide instructions so that others can reproduce.
-->

- [ ] Ran unit tests
- [ ] Ran integration tests
- [ ] Performed manual testing
- [ ] Updated existing tests
- [ ] Unit tests
- [ ] Integration tests
- [ ] Existing tests update


## 🛠️ Dependencies
Expand All @@ -36,9 +35,5 @@ List any new dependencies added or existing ones updated.

## ✅ Checklist

- [ ] My code follows the project's style guidelines
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
4 changes: 2 additions & 2 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions src/cala/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class Footprints(Asset):
member=Footprint.entity(),
group_by=Dims.component,
checks=[is_non_negative, has_no_nan],
allow_extra_coords=False,
)
)

Expand Down Expand Up @@ -134,6 +135,7 @@ def from_array(
member=Trace.entity(),
group_by=Dims.component,
checks=[is_non_negative, has_no_nan],
allow_extra_coords=False,
)
)

Expand All @@ -145,6 +147,7 @@ class Movie(Asset):
member=Frame.entity(),
group_by=Dims.frame.value,
checks=[is_non_negative, has_no_nan],
allow_extra_coords=False,
)
)

Expand Down Expand Up @@ -180,6 +183,7 @@ class CompStats(Asset):
dims=comp_dims,
dtype=float,
checks=[is_non_negative, has_no_nan],
allow_extra_coords=False,
)
)

Expand All @@ -191,6 +195,7 @@ class PixStats(Asset):
dims=(Dims.width.value, Dims.height.value, Dims.component.value),
dtype=float,
checks=[is_non_negative, has_no_nan],
allow_extra_coords=False,
)
)

Expand All @@ -202,6 +207,7 @@ class Overlaps(Asset):
dims=comp_dims,
dtype=bool,
checks=[has_no_nan],
allow_extra_coords=False,
)
)

Expand All @@ -224,5 +230,6 @@ class Residual(Asset):
member=Frame.entity(),
group_by=Dims.frame.value,
checks=[is_non_negative, has_no_nan],
allow_extra_coords=False,
)
)
10 changes: 5 additions & 5 deletions src/cala/models/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, Field

from cala.models.checks import is_unique, is_unit_interval
from cala.models.checks import has_no_nan, is_unique


class Axis:
Expand All @@ -17,7 +17,7 @@ class Axis:

id_coord: str = "id_"
timestamp_coord: str = "timestamp"
confidence_coord: str = "confidence"
detect_coord: str = "detected_on"
frame_coord: str = "frame"
width_coord: str = "width"
height_coord: str = "height"
Expand All @@ -37,7 +37,7 @@ def component_rename(self) -> dict[str, str]:
return {
AXIS.component_dim: f"{AXIS.component_dim}'",
AXIS.id_coord: f"{AXIS.id_coord}'",
AXIS.confidence_coord: f"{AXIS.confidence_coord}'",
AXIS.detect_coord: f"{AXIS.detect_coord}'",
}


Expand All @@ -62,11 +62,11 @@ class Coords(Enum):
width = Coord(name=AXIS.width_coord, dtype=int, checks=[is_unique])
frame = Coord(name=AXIS.frame_coord, dtype=int, checks=[is_unique])
timestamp = Coord(name=AXIS.timestamp_coord, dtype=str, checks=[is_unique])
confidence = Coord(name=AXIS.confidence_coord, dtype=float, checks=[is_unit_interval])
detected = Coord(name=AXIS.detect_coord, dtype=int, checks=[has_no_nan])


class Dims(Enum):
width = Dim(name=AXIS.width_dim, coords=[Coords.width.value])
height = Dim(name=AXIS.height_dim, coords=[Coords.height.value])
frame = Dim(name=AXIS.frames_dim, coords=[Coords.frame.value, Coords.timestamp.value])
component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.confidence.value])
component = Dim(name=AXIS.component_dim, coords=[Coords.id.value, Coords.detected.value])
6 changes: 3 additions & 3 deletions src/cala/models/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Entity(BaseModel):
coords: list[Coord] = Field(default_factory=list)
dtype: type
checks: list[Callable] = Field(default_factory=list)
allow_extra_coords: bool = True

_model: DataArraySchema = PrivateAttr(DataArraySchema())

Expand Down Expand Up @@ -47,15 +48,14 @@ def to_schema(self) -> DataArraySchema:
checks=self.checks,
)

@staticmethod
def _build_coord_schema(coords: list[Coord]) -> CoordsSchema:
def _build_coord_schema(self, coords: list[Coord]) -> CoordsSchema:
spec = dict()

for c in coords:
dim = DimsSchema((c.dim,)) if c.dim else None
spec[c.name] = DataArraySchema(dims=dim, dtype=DTypeSchema(c.dtype), checks=c.checks)

return CoordsSchema(spec)
return CoordsSchema(spec, allow_extra_keys=self.allow_extra_coords)


class Group(Entity):
Expand Down
27 changes: 24 additions & 3 deletions src/cala/nodes/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,31 @@
import xarray as xr
from noob import Name

from cala.assets import CompStats, Footprints, Overlaps, PixStats, Traces
from cala.assets import CompStats, Footprints, Overlaps, PixStats, Residual, Traces
from cala.models import AXIS


def clear_overestimates(
    footprints: Footprints, residuals: Residual, nmf_error: float
) -> A[Footprints, Name("footprints")]:
    """
    Zero out the footprint pixels that cause significantly negative residuals.

    Only the most recent residual frame is inspected:

    1. Find the pixels whose latest residual is at or below ``-nmf_error``, i.e.
       too negative to be treated as noise and simply clipped to zero.
    2. Set every footprint value at those pixels to zero.

    :param footprints: Footprints to tune. Its ``array`` is replaced with the
        tuned array.
    :param residuals: Residual buffer whose last frame is used as the reference.
    :param nmf_error: Non-negative tolerance; residuals greater than
        ``-nmf_error`` leave the footprints untouched.
    :return: The ``footprints`` asset. It is returned unchanged when either the
        residual or the footprint array is still ``None``.
    """
    if residuals.array is None or footprints.array is None:
        return footprints

    # Latest frame only. The frame-specific coordinates are dropped so that they
    # are not carried over onto the footprints by the ``where`` below.
    R_min = residuals.array.isel({AXIS.frames_dim: -1}).reset_coords(
        [AXIS.frame_coord, AXIS.timestamp_coord], drop=True
    )

    # Store the result back on the asset so both exits return a ``Footprints``
    # (previously this path returned the bare array).
    footprints.array = footprints.array.where(R_min > -nmf_error, 0, drop=False)

    return footprints


def purge_razed_components(
footprints: Footprints,
traces: Traces,
Expand Down Expand Up @@ -101,13 +122,13 @@ def filter_components(
comp_stats.array.set_xindex(AXIS.id_coord)
.set_xindex(f"{AXIS.id_coord}'")
.sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids.values.tolist()})
.reset_index(AXIS.id_coord)
.reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"])
)
overlaps.array = (
overlaps.array.set_xindex(AXIS.id_coord)
.set_xindex(f"{AXIS.id_coord}'")
.sel({AXIS.id_coord: keep_ids, f"{AXIS.id_coord}'": keep_ids.values.tolist()})
.reset_index(AXIS.id_coord)
.reset_index([AXIS.id_coord, f"{AXIS.id_coord}'"])
)

return footprints, traces, pix_stats, comp_stats, overlaps
32 changes: 25 additions & 7 deletions src/cala/nodes/detect/catalog.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Hashable, Iterable
from typing import Annotated as A

import cv2
import numpy as np
import xarray as xr
from noob import Name
Expand Down Expand Up @@ -72,21 +73,19 @@ def process(
footprints = xr.concat(
footprints,
dim=AXIS.component_dim,
coords=[AXIS.id_coord, AXIS.confidence_coord],
coords=[AXIS.id_coord, AXIS.detect_coord],
combine_attrs=combine_attr_replaces,
)
traces = xr.concat(
traces,
dim=AXIS.component_dim,
coords=[AXIS.id_coord, AXIS.confidence_coord],
coords=[AXIS.id_coord, AXIS.detect_coord],
combine_attrs=combine_attr_replaces,
)

return Footprints.from_array(footprints), Traces.from_array(traces)

def _register(
self, new_fp: xr.DataArray, new_tr: xr.DataArray, confidence: float = 0.0
) -> tuple[Footprint, Trace]:
def _register(self, new_fp: xr.DataArray, new_tr: xr.DataArray) -> tuple[Footprint, Trace]:

new_id = create_id()

Expand All @@ -95,7 +94,10 @@ def _register(
.assign_coords(
{
AXIS.id_coord: (AXIS.component_dim, [new_id]),
AXIS.confidence_coord: (AXIS.component_dim, [confidence]),
AXIS.detect_coord: (
AXIS.component_dim,
[new_tr[AXIS.frame_coord].max().item()],
),
}
)
.isel({AXIS.component_dim: 0})
Expand All @@ -105,7 +107,10 @@ def _register(
.assign_coords(
{
AXIS.id_coord: (AXIS.component_dim, [new_id]),
AXIS.confidence_coord: (AXIS.component_dim, [confidence]),
AXIS.detect_coord: (
AXIS.component_dim,
[new_tr[AXIS.frame_coord].max().item()],
),
}
)
.isel({AXIS.component_dim: 0})
Expand Down Expand Up @@ -211,8 +216,21 @@ def _merge_matrix(
fps_base = fps_base.rename(AXIS.component_rename)
trs_base = trs_base.rename(AXIS.component_rename)

fps = self._expand_boundary(fps > 0)

overlaps = fps @ fps_base > 0
# this should later reflect confidence
corrs = xr.corr(trs, trs_base, dim=AXIS.frames_dim) > self.merge_threshold

return overlaps * corrs

def _expand_boundary(self, mask: xr.DataArray) -> xr.DataArray:
    """
    Dilate a mask once over its spatial dimensions.

    The mask is cast to ``uint8`` and each spatial plane is dilated for a single
    iteration with a 3x3 cross-shaped structuring element.

    :param mask: Mask spanning ``AXIS.spatial_dims`` (any further dimensions are
        looped over).
    :return: The dilated mask, as produced by ``cv2.morphologyEx`` on the
        ``uint8`` input.
    """
    cross = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))

    def _dilate(plane):
        # Single-step dilation of one 2D plane.
        return cv2.morphologyEx(plane, cv2.MORPH_DILATE, cross, iterations=1)

    as_uint8 = mask.astype(np.uint8)
    return xr.apply_ufunc(
        _dilate,
        as_uint8,
        input_core_dims=[AXIS.spatial_dims],
        output_core_dims=[AXIS.spatial_dims],
        vectorize=True,
        dask="parallelized",
    )
13 changes: 8 additions & 5 deletions src/cala/nodes/detect/slice_nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class SliceNMF(Node):
"""Wait until this number of frames to begin detecting."""
detect_thresh: float
"""Minimum detection threshold for brightness fluctuation."""
reprod_tol: float
"""Mean pixel value error tolerance for reproduced slice from the new component"""

nmf_kwargs: dict[str, Any] = Field(default_factory=dict)

error_: float = Field(None)
Expand All @@ -43,7 +46,7 @@ def process(
fps = []
trs = []

while np.sqrt(energy.max()).item() > self.detect_thresh: # or use res directly
while energy.max().item() >= self.detect_thresh: # or use res directly
# Find and analyze neighborhood of maximum variance
slice_ = self._get_max_energy_slice(
arr=res, energy_landscape=energy, radius=detect_radius
Expand All @@ -55,25 +58,25 @@ def process(
)

l1_norm = slice_.sum().item()
# l0_norm = np.prod(slice_.shape) # this fails when the residuals are tiny
comp_recon = a_new @ c_new

energy.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0

if (self.error_ / l1_norm) <= self._model.tol:
if (self.error_ / l1_norm) <= self.reprod_tol:
fps.append(Footprint.from_array(a_new))
trs.append(Trace.from_array(c_new))
res = (res - comp_recon).clip(0)
else:
res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = 0
l0_norm = np.prod(slice_.shape)
res.loc[{ax: slice_.coords[ax] for ax in AXIS.spatial_dims}] = self.error_ / l0_norm

return fps, trs

def _get_energy(self, res: xr.DataArray) -> xr.DataArray:
pixels_median = res.median(dim=AXIS.frames_dim)
V = res - pixels_median

return (V**2).sum(dim=AXIS.frames_dim)
return np.sqrt((V**2).mean(dim=AXIS.frames_dim))

def _get_max_energy_slice(
self,
Expand Down
17 changes: 14 additions & 3 deletions src/cala/nodes/footprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from noob import Name, process_method

from cala.assets import CompStats, Footprints, PixStats
from cala.logging import init_logger
from cala.models import AXIS


class Footprinter:
logger = init_logger(__name__)

def __init__(self, tol: float, max_iter: int | None = None, bep: int | None = None):
self.bep = bep
Expand Down Expand Up @@ -55,7 +57,7 @@ def ingest_frame(
kernel = self._expansion_kernel()
mask = self._expand_boundary(kernel, mask)

# for _ in range(self.max_iter):
cnt = 0
while True:
AM = A.rename(AXIS.component_rename) @ M
numerator = W - AM
Expand All @@ -72,8 +74,15 @@ def ingest_frame(
update = numerator / (M_diag + np.finfo(float).tiny)
A_new = (mask * (A + update)).clip(min=0)

if (np.abs(A - A_new).sum() / np.prod(A.shape)).item() < self.tol:
step = (np.abs(A - A_new).sum() / np.prod(A.shape)).item()

cnt += 1
maxed = self.max_iter and (cnt == self.max_iter)

if step < self.tol or maxed:
footprints.array = A_new
if maxed:
self.logger.debug(msg="max_iter reached before converging.")
return footprints
else:
A = A_new
Expand All @@ -93,7 +102,9 @@ def _expand_boundary(self, kernel: np.ndarray, mask: xr.DataArray) -> xr.DataArr
)


def ingest_component(footprints: Footprints, new_footprints: Footprints) -> Footprints:
def ingest_component(
footprints: Footprints, new_footprints: Footprints
) -> A[Footprints, Name("footprints")]:
if new_footprints.array is None:
return footprints

Expand Down
Loading
Loading