Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
1b6aa87
feat: update test settings
raymondwjang Oct 7, 2025
7d2feaa
feat: asset validate / sparsify optional
raymondwjang Oct 7, 2025
9640834
feat: asset validate / sparsify optional
raymondwjang Oct 7, 2025
7a2d692
perf: optimize residual to only update the latest, or zero out if rec…
raymondwjang Oct 8, 2025
1ca5763
feat: assert_scalar_multiple_arrays error message
raymondwjang Oct 8, 2025
8dc78f7
test: rank1nmf test
raymondwjang Oct 8, 2025
72245e4
test: assert_scalar_multiple_arrays for numpy
raymondwjang Oct 8, 2025
c830f8c
feat: fix trace update
raymondwjang Oct 8, 2025
b0deebb
feat: using rank1nmf instead of sklearn nmf
raymondwjang Oct 9, 2025
9ba0838
feat: optimize footprint update
raymondwjang Oct 9, 2025
5a61f7e
feat: footprint save as sparse
raymondwjang Oct 9, 2025
55de160
feat: footprint save as sparse
raymondwjang Oct 9, 2025
0e7bbce
feat: allow making empty assets with from_array
raymondwjang Oct 9, 2025
8016051
chore: simplify indexing in residual
raymondwjang Oct 9, 2025
efffcf0
perf: no assets within nodes methods
raymondwjang Oct 9, 2025
6154c61
feat: register batch instead of loop
raymondwjang Oct 9, 2025
e939cc1
test: refit tests
raymondwjang Oct 9, 2025
dbbd349
test: refit test yamls
raymondwjang Oct 9, 2025
832c3c2
feat: more optimizations / simplifications
raymondwjang Oct 9, 2025
e7b12c3
format: ruff
raymondwjang Oct 9, 2025
961ef20
test: no more boundary expansion
raymondwjang Oct 9, 2025
339d79d
debug: residual size limit
raymondwjang Oct 9, 2025
0fb2841
debug: overlap update in case of multiple new comps
raymondwjang Oct 9, 2025
9e91d44
test: refit residual test
raymondwjang Oct 9, 2025
cbd1497
test: performance test setup
raymondwjang Oct 13, 2025
ebb433c
feat: rename residual asset to buffer
raymondwjang Oct 13, 2025
6f1665c
feat: build buffer class for concat / query speed improvement (x50 fa…
raymondwjang Oct 14, 2025
3f14987
feat: build buffer class for concat / query speed improvement (x50 fa…
raymondwjang Oct 14, 2025
05afc8f
feat: cover residual array none case
raymondwjang Oct 15, 2025
8345226
tests: refit tests
raymondwjang Oct 15, 2025
a24a78f
feat: residuals node uses buffer
raymondwjang Oct 15, 2025
ceee071
test: refit yaml
raymondwjang Oct 15, 2025
5e368bd
feat: frame indexing is easier with separate coord vs. dim names
raymondwjang Oct 15, 2025
293b1ca
perf: optimize residual and overlaps updates
raymondwjang Oct 15, 2025
389f63c
chore: rename to Tracer for easier perf check
raymondwjang Oct 15, 2025
4bd678b
perf: optimize residual - avoid numba bootup
raymondwjang Oct 15, 2025
0aec0b1
perf: optimize footprint - avoid coo indexing
raymondwjang Oct 16, 2025
14cc810
perf: optimize unlayered footprint search
raymondwjang Oct 16, 2025
13a07e4
feat: footprint remove tiny pixels
raymondwjang Oct 20, 2025
093a58a
feat: perf update on residual align_overestimate
raymondwjang Oct 20, 2025
05de590
feat: residual learns instead of one-shot flooring
raymondwjang Oct 20, 2025
f24f038
feat: remove symmetry from toy footprints now we're moving away from …
raymondwjang Oct 20, 2025
7815b4f
tests: remove hotpix in tests - its messing up the frame values with …
raymondwjang Oct 20, 2025
556195b
tests: refit tests to work with new toy
raymondwjang Oct 21, 2025
15c8c20
feat: relax nmf acceptance condition for both l0 and l1
raymondwjang Oct 21, 2025
82c4648
perf: optimize residual speed
raymondwjang Oct 21, 2025
1f668f3
format
raymondwjang Oct 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ tests = [
"snakeviz>=2.2.2",
"pytest-profiling>=1.8.1",
"pytest-timeout>=2.3.1",
"yappi>=1.6.10",
"tuna>=0.5.11",
]
docs = [
"sphinx>=8.2.3",
Expand Down
111 changes: 91 additions & 20 deletions src/cala/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
from copy import deepcopy
from pathlib import Path
from typing import Any, ClassVar, TypeVar
from typing import Any, ClassVar, Self, TypeVar

import numpy as np
import xarray as xr
Expand All @@ -19,6 +19,7 @@


class Asset(BaseModel):
validate_schema: bool = False
array_: AssetType = None
_entity: ClassVar[Entity]

Expand Down Expand Up @@ -46,13 +47,12 @@ def __eq__(self, other: "Asset") -> bool:
def entity(cls) -> Entity:
return cls._entity

@field_validator("array_", mode="after")
@classmethod
def validate_array_schema(cls, value: xr.DataArray) -> AssetType:
if value is not None:
value.validate.against_schema(cls._entity.model)
@model_validator(mode="after")
def validate_array_schema(self) -> Self:
if self.validate_schema and self.array_ is not None:
self.array_.validate.against_schema(self._entity.model)

return value
return self


class Footprint(Asset):
Expand All @@ -66,8 +66,8 @@ class Footprint(Asset):
)

@classmethod
def from_array(cls, array: xr.DataArray) -> "Footprint":
if isinstance(array.data, np.ndarray):
def from_array(cls, array: xr.DataArray, sparsify: bool = True) -> "Footprint":
if sparsify and isinstance(array.data, np.ndarray):
array.data = COO.from_numpy(array.data)
return cls(array_=array)

Expand Down Expand Up @@ -105,9 +105,17 @@ class Footprints(Asset):
)
)

@Asset.array.setter
def array(self, array: xr.DataArray | None) -> None:
    """
    Store a new footprints array, sparsifying dense data on the way in.

    ``None`` is accepted and stored as-is (no validation, no conversion).
    For a non-``None`` array:

    * if ``self.validate_schema`` is set, the array is validated against
      ``self._entity.model`` first;
    * if its backing data is a ``np.ndarray``, the data is replaced with
      ``COO.from_numpy(...)``. Note this rebinds ``array.data`` on the object
      that was passed in, so the caller's array is modified too.
    """
    if array is not None:
        # Validate only when there is something to validate; previously a
        # ``None`` array with validation enabled raised AttributeError here
        # even though the rest of the setter tolerates ``None``.
        if self.validate_schema:
            array.validate.against_schema(self._entity.model)
        if isinstance(array.data, np.ndarray):
            array.data = COO.from_numpy(array.data)
    self.array_ = array

@classmethod
def from_array(cls, array: xr.DataArray) -> "Footprints":
if isinstance(array.data, np.ndarray):
if array is not None and isinstance(array.data, np.ndarray):
array.data = COO.from_numpy(array.data)
return cls(array_=array)

Expand All @@ -130,7 +138,8 @@ def array(self) -> xr.DataArray:
@array.setter
def array(self, array: xr.DataArray) -> None:
if self.zarr_path:
self.validate_array_schema(array)
if self.validate_schema:
array.validate.against_schema(self._entity.model)
array.to_zarr(self.zarr_path, mode="w") # need to make sure it can overwrite
else:
self.array_ = array
Expand Down Expand Up @@ -173,7 +182,8 @@ def load_zarr(self, isel_filter: dict = None, sel_filter: dict = None) -> xr.Dat
)

def update(self, array: xr.DataArray, **kwargs: Any) -> None:
self.validate_array_schema(array)
if self.validate_schema:
array.validate.against_schema(self._entity.model)
array.to_zarr(self.zarr_path, **kwargs)

@classmethod
Expand Down Expand Up @@ -283,16 +293,12 @@ class Overlaps(Asset):
)


class Residual(Asset):
class Buffer(Asset):
"""
Computes and maintains a buffer of residual signals.
Implements a fake ring buffer to avoid expensive copying that occurs with
numpy concat, append, and stack.

This method implements the residual computation by subtracting the
reconstructed signal from the original data. It maintains only the
most recent frames as specified by the buffer length.

The residual buffer contains the recent history of unexplained variance
in the data after accounting for known components.
Works by preallocating a space twice the desired size.
"""

_entity: ClassVar[Entity] = PrivateAttr(
Expand All @@ -304,3 +310,68 @@ class Residual(Asset):
allow_extra_coords=False,
)
)

validate_schema: bool = False
"""Validation currently does not play nicely with this class."""

size: int
_full: bool = PrivateAttr(False)
_next: int = PrivateAttr(default=0)

def append(self, array: xr.DataArray) -> None:
    """
    Write one frame into the buffer at the current write position.

    The frame is stored twice: at index ``_next`` and at ``_next + size``.
    The frame and timestamp coordinates are written to the same two slots;
    each is read from ``array`` with ``.item()``, so it must be a single value.
    Afterwards the write position advances by one and wraps at ``size``.
    """
    slot = self._next
    mirror = slot + self.size

    backing = self.array_
    backing.data[slot] = array.data
    backing.data[mirror] = array.data

    for coord_name in (AXIS.frame_coord, AXIS.timestamp_coord):
        coord_values = backing[coord_name].data
        coord_values[slot] = array[coord_name].item()
        coord_values[mirror] = array[coord_name].item()

    self._next = (slot + 1) % self.size
    # Once full, the buffer stays full; otherwise it becomes full exactly
    # when the write position wraps back to zero.
    self._full = self._full or self._next == 0

@property
def array(self) -> xr.DataArray | None:
    """
    Return the currently valid window of the backing array.

    Returns ``None`` when no backing array has been set. When the buffer is
    full, the window is the ``size`` frames starting at ``_next``; otherwise
    it is everything before ``_next``.
    """
    backing = self.array_
    if backing is None:
        return None

    if self._full:
        window = slice(self._next, self._next + self.size)
    else:
        window = slice(None, self._next)

    # NOTE: the frame coordinate is returned as stored. Casting it back to
    # int (assign_coords + astype(int)) was considered but judged expensive.
    return backing.isel({AXIS.frames_dim: window})

@array.setter
def array(self, array: xr.DataArray) -> None:
    """
    Build a new buffer array from ``array``.

    If ``array`` already has the frames dimension, only its last ``size``
    frames are kept. Otherwise it is given a frames dimension through
    ``volumize.dim_with_coords``. The result is padded with zero-valued
    filler frames (frame coord ``nan``, timestamp ``""``) up to ``size``
    frames, and the padded block is stored twice back-to-back, so the
    backing array holds ``2 * size`` frames.

    ``_full`` and ``_next`` are reset to match the number of real frames.
    """
    if AXIS.frames_dim in array.dims:
        # Keep only the most recent ``size`` frames.
        array = array.isel({AXIS.frames_dim: slice(-self.size, None)})
    else:
        array = array.volumize.dim_with_coords(
            dim=AXIS.frames_dim, coords=[AXIS.frame_coord, AXIS.timestamp_coord]
        )

    n_frames = array.sizes[AXIS.frames_dim]
    n_pad = self.size - n_frames

    # Same dim order as ``array``; only the frames axis is resized.
    pad_sizes = dict(array.sizes)
    pad_sizes[AXIS.frames_dim] = n_pad
    filler = xr.DataArray(
        np.zeros(list(pad_sizes.values())),
        dims=array.dims,
        coords={
            AXIS.frame_coord: (AXIS.frames_dim, [np.nan] * n_pad),
            AXIS.timestamp_coord: (AXIS.frames_dim, [""] * n_pad),
        },
    )
    # Layout: [real frames, padding, real frames, padding]
    buffer = xr.concat([array, filler] * 2, dim=AXIS.frames_dim)

    self._full = n_frames >= self.size
    # Builtin ``min`` keeps ``_next`` a plain ``int`` as declared; ``np.min``
    # on a tuple returned a numpy integer scalar instead.
    self._next = min(n_frames, self.size) % self.size
    self.array_ = buffer

@classmethod
def from_array(cls, array: xr.DataArray, size: int) -> Self:
    """
    Alternate constructor: create a buffer of ``size`` and load ``array``.

    The array is assigned through the ``array`` property setter rather than
    passed to the constructor.
    """
    instance = cls(size=size)
    instance.array = array
    return instance
2 changes: 1 addition & 1 deletion src/cala/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@cli.command()
def main(spec: str, gui: Annotated[bool, typer.Option()] = False) -> None:
if gui:
uvicorn.run("cala.main:app", reload=True, reload_dirs=[Path(__file__).parent])
uvicorn.run("cala.main:app", reload=False, reload_dirs=[Path(__file__).parent])
else:
tube = Tube.from_specification(spec)
runner = SynchronousRunner(tube=tube)
Expand Down
2 changes: 1 addition & 1 deletion src/cala/models/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Axis:
id_coord: str = "id_"
timestamp_coord: str = "timestamp"
detect_coord: str = "detected_on"
frame_coord: str = "frame"
frame_coord: str = "frame_idx"
width_coord: str = "width"
height_coord: str = "height"

Expand Down
17 changes: 5 additions & 12 deletions src/cala/nodes/buffer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
from typing import Annotated as A

import xarray as xr
from noob import Name

from cala.assets import Frame, Movie
from cala.assets import Buffer, Frame
from cala.models import AXIS


def fill_buffer(size: int, buffer: Movie, frame: Frame) -> A[Movie, Name("buffer")]:
def fill_buffer(buffer: Buffer, frame: Frame) -> A[Buffer, Name("buffer")]:
if buffer.array is None:
buffer.array = frame.array.expand_dims(AXIS.frames_dim).assign_coords(
{AXIS.timestamp_coord: (AXIS.frames_dim, [frame.array[AXIS.timestamp_coord].item()])}
buffer.array = frame.array.volumize.dim_with_coords(
dim=AXIS.frames_dim, coords=[AXIS.timestamp_coord]
)
return buffer

buffered = (
buffer.array.transpose(AXIS.frames_dim, ...)[-size + 1 :]
if buffer.array.sizes[AXIS.frames_dim] >= size
else buffer.array
)

buffer.array = xr.concat([buffered, frame.array], dim=AXIS.frames_dim)
buffer.append(frame.array)
return buffer
4 changes: 2 additions & 2 deletions src/cala/nodes/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import xarray as xr
from noob import Name

from cala.assets import CompStats, Footprints, Overlaps, PixStats, Residual, Traces
from cala.assets import Buffer, CompStats, Footprints, Overlaps, PixStats, Traces
from cala.models import AXIS


def clear_overestimates(
footprints: Footprints, residuals: Residual, nmf_error: float
footprints: Footprints, residuals: Buffer, nmf_error: float
) -> A[Footprints, Name("footprints")]:
"""
Remove all sections of the footprints that cause negative residuals.
Expand Down
2 changes: 1 addition & 1 deletion src/cala/nodes/component_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def ingest_frame(component_stats: CompStats, frame: Frame, new_traces: PopSnap)
return component_stats

# Compute scaling factors
frame_idx = frame.array.coords[AXIS.frame_coord].item()
frame_idx = frame.array[AXIS.frame_coord].item()
prev_scale = frame_idx / (frame_idx + 1)
new_scale = 1 / (frame_idx + 1)

Expand Down
Loading