Skip to content
2 changes: 1 addition & 1 deletion src/cala/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Frame(Asset):
Entity(
name="frame",
dims=(Dims.width.value, Dims.height.value),
dtype=float,
dtype=None, # np.number, # gets converted to float64 in xarray-validate
checks=[is_non_negative, has_no_nan],
)
)
Expand Down
11 changes: 4 additions & 7 deletions src/cala/gui/components/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import av
import numpy as np
from av.video import VideoStream
from noob import process_method
from pydantic import BaseModel
from noob.node import Node

from cala.assets import Frame
from cala.config import config
Expand All @@ -26,14 +25,13 @@ def __str__(self) -> str:
return "Encoding failed."


class Encoder(BaseModel):
grid_id: str
class Encoder(Node):
frame_rate: int
_stream: VideoStream | None = None
_container: av.container.OutputContainer | None = None

def model_post_init(self, context: Any, /) -> None:
encode_dir = config.runtime_dir / self.grid_id
encode_dir = config.runtime_dir / self.id
encode_dir.mkdir(parents=True, exist_ok=True)
clear_dir(encode_dir)
hls_manifest = encode_dir / "stream.m3u8"
Expand All @@ -51,8 +49,7 @@ def model_post_init(self, context: Any, /) -> None:
self._stream = self._container.add_stream("h264", rate=self.frame_rate)
self._stream.pix_fmt = "yuv420p"

@process_method
def save(self, frame: Frame) -> None:
def process(self, frame: Frame) -> None:
frame = frame.array.astype(np.uint8)
self._stream.width = frame.sizes[AXIS.width_dim]
self._stream.height = frame.sizes[AXIS.height_dim]
Expand Down
78 changes: 0 additions & 78 deletions src/cala/gui/plots.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from pathlib import Path

import cv2
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr

sns.set_theme(style="whitegrid", context="notebook", font_scale=1.2, palette="deep")


def write_movie(video: xr.DataArray, path: str | Path) -> None:
"""Test visualization of stabilized calcium video to verify motion stabilization."""
Expand All @@ -23,76 +18,3 @@ def write_movie(video: xr.DataArray, path: str | Path) -> None:
out.write(frame_bgr)

out.release()


def write_gif(
videos: xr.DataArray | list[xr.DataArray],
path: str | Path,
n_cols: int | None = None,
) -> None:
"""
Save video frames with optional processing function. Can handle single or multiple videos.

Parameters:
-----------
videos : Union[xr.DataArray, List[Tuple[xr.DataArray, str]]]
Either a single video DataArray or list of (video, title) tuples for comparison
n_cols : Optional[int]
Number of columns when displaying multiple videos. If None, tries to make square grid
"""
# Handle single video case
if isinstance(videos, xr.DataArray):
videos = [videos]

# Verify all videos have same number of frames
n_frames = len(videos[0][0])
if not all(len(video) == n_frames for video in videos):
raise ValueError("All videos must have the same number of frames")

n_videos = len(videos)
if n_cols is None:
n_cols = int(np.ceil(np.sqrt(n_videos))) if n_videos > 1 else 1
n_rows = int(np.ceil(n_videos / n_cols))

# Get global min/max for consistent scaling
vmin = np.min([np.min(video) for video in videos])
vmax = np.max([np.max(video) for video in videos])

for frame_idx in range(n_frames):
if n_videos == 1:
fig, ax = plt.subplots(figsize=(8, 8))
axes = [[ax]]
else:
fig, axes = plt.subplots(
n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
)

for vid_idx, (video, title) in enumerate(videos):
last_row = vid_idx // n_cols
remn_col = vid_idx % n_cols
ax = axes[last_row][remn_col]

frame = video[frame_idx]

ax.imshow(frame, cmap="gray", vmin=vmin, vmax=vmax)
if title:
ax.set_title(f"{title}\nFrame {frame_idx}")
else:
ax.set_title(f"Frame {frame_idx}")
ax.axis("off")

# Hide empty subplots
if n_videos > 1:
for idx in range(n_videos, n_rows * n_cols):
last_row = idx // n_cols
remn_col = idx % n_cols
axes[last_row][remn_col].set_visible(False)

plt.tight_layout()
plt.savefig(path / f"{frame_idx:04d}.png", dpi=150, bbox_inches="tight")

# Create gif
frames = []
for i in range(n_frames):
frames.append(imageio.imread(path / f"{i:04d}.png"))
imageio.mimsave(path, frames, fps=30)
17 changes: 11 additions & 6 deletions src/cala/gui/spec.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
nodes:
prep_movie:
raw_movie:
type: cala.gui.components.Encoder
params:
grid_id: prep_movie
frame_rate: 30
depends:
- frame.value
component_count:
type: cala.gui.components.component_counter
prep_movie:
type: cala.gui.components.Encoder
params:
frame_rate: 30
depends:
- index: counter.idx
- traces: assets.traces
- flatten.frame
# component_count:
# type: cala.gui.components.component_counter
# depends:
# - index: counter.idx
# - traces: assets.traces
4 changes: 2 additions & 2 deletions src/cala/models/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Entity(BaseModel):
name: str
dims: tuple[Dim, ...]
coords: list[Coord] = Field(default_factory=list)
dtype: type
dtype: type | None
checks: list[Callable] = Field(default_factory=list)
allow_extra_coords: bool = True

Expand Down Expand Up @@ -44,7 +44,7 @@ def to_schema(self) -> DataArraySchema:
return DataArraySchema(
dims=DimsSchema(tuple(dim.name for dim in self.dims), ordered=False),
coords=coords_schema,
dtype=DTypeSchema(self.dtype),
dtype=DTypeSchema(self.dtype) if self.dtype else None,
checks=self.checks,
)

Expand Down
4 changes: 2 additions & 2 deletions src/cala/nodes/prep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from .flatten import butter
from .glow_removal import GlowRemover
from .lines import remove_freq, remove_mean
from .motion import Stabilizer
from .motion import Anchor
from .r_estimate import SizeEst

__all__ = [
"blur",
"GlowRemover",
"remove_background",
"Stabilizer",
"Anchor",
"SizeEst",
"butter",
"remove_mean",
Expand Down
18 changes: 10 additions & 8 deletions src/cala/nodes/prep/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@

def butter(frame: Frame, kwargs: dict[str, Any]) -> A[Frame, Name("frame")]:
"""
butterworth filter centers the image to zero. this causes two images with same intensity ratio
across pixels to be indistinguishable.
To recover the absolute brightness, we shift the filtered image by the
mean brightness of the original frame.
Butterworth filter centers the image to zero. This is due to the constant term (the mean)
being expressed as the 0th term in the fourier series.
Since the absolute background activity does not matter (all that is left is the high-frequency
signal), we simply add half of the 8-bit pixel max so that the total cannot exceed the
0-255 range.

The filter can also be used to reduce the scattering and the glow! (inspired by Marcel Brosche)
This helps remove overlap between cells (with higher cutoff_frequency_ratio)
"""
arr = butterworth(frame.array, **kwargs) + frame.array.mean().item()
arr = butterworth(frame.array, **kwargs) + 2**7

return Frame.from_array(
xr.DataArray(arr.clip(0), dims=frame.array.dims, coords=frame.array.coords)
)
return Frame.from_array(xr.DataArray(arr, dims=frame.array.dims, coords=frame.array.coords))


def ball(frame: Frame, kwargs: dict[str, Any]) -> Frame:
Expand Down
Loading
Loading