Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 75 additions & 2 deletions arcos4py/tools/_detect_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,41 @@
self.coordinates = []
self.prev_cluster_ids = []

def overwrite_cluster_ids(self, seed_coordinates: np.ndarray, seed_cluster_ids: np.ndarray):
"""Adds seed clusters to the memory.

Arguments:
seed_coordinates (np.ndarray): Array of shape (N, D) of seed coordinates.
seed_cluster_ids (np.ndarray): Array of shape (N,) of seed cluster IDs.
"""
# convert to int
print("Overwriting cluster IDs in memory")
seed_cluster_ids = seed_cluster_ids.astype(int)
seed_coordinates = seed_coordinates.astype(int)

Check warning on line 367 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L365-L367

Added lines #L365 - L367 were not covered by tests
if not self.coordinates:
valid = seed_cluster_ids > 0
self.coordinates.append(seed_coordinates[valid])
self.prev_cluster_ids.append(seed_cluster_ids[valid])
self.max_prev_cluster_id = np.max(seed_cluster_ids[valid])

Check warning on line 372 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L369-L372

Added lines #L369 - L372 were not covered by tests
else:
coords = self.coordinates[-1]
ids = self.prev_cluster_ids[-1]

Check warning on line 375 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L374-L375

Added lines #L374 - L375 were not covered by tests
for sc, sci in zip(seed_coordinates, seed_cluster_ids):
if sci <= 0:
continue

Check warning on line 378 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L378

Added line #L378 was not covered by tests
if coords.size > 0:
matches = np.all(coords == sc, axis=1)
idx = np.nonzero(matches)[0]

Check warning on line 381 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L380-L381

Added lines #L380 - L381 were not covered by tests
if idx.size > 0:
ids[idx[0]] = sci
self.max_prev_cluster_id = max(self.max_prev_cluster_id, sci)
continue
coords = np.vstack([coords, sc[np.newaxis, :]]) if coords.size > 0 else sc[np.newaxis, :]
ids = np.concatenate([ids, [sci]]) if ids.size > 0 else np.array([sci], dtype=int)
self.max_prev_cluster_id = max(self.max_prev_cluster_id, sci)
self.coordinates[-1] = coords
self.prev_cluster_ids[-1] = ids

Check warning on line 390 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L383-L390

Added lines #L383 - L390 were not covered by tests

@property
def all_coordinates(self):
"""Returns all coordinates in memory as one array."""
Expand Down Expand Up @@ -1183,8 +1218,6 @@
def _get_next_id(self) -> int:
"""Generate a new unique ID."""
self._memory.max_prev_cluster_id += 1
if self._memory.max_prev_cluster_id == 92:
pass
return self._memory.max_prev_cluster_id

def _apply_remove_small_clusters(self, linked_cluster_ids, original_cluster_ids):
Expand Down Expand Up @@ -1397,6 +1430,23 @@

return final_cluster_ids

def add_seed_cluster_to_memory(self, cluster_id: np.ndarray, coordinates: np.ndarray):
"""Adds a seed cluster to memory or overwrites an existing cluster.

Arguments:
cluster_id (np.ndarray): The ID of the cluster.
coordinates (np.ndarray): The coordinates of the cluster.
"""
self._memory.overwrite_cluster_ids(cluster_id, coordinates)
memory_coordinates = self._memory.coordinates

Check warning on line 1441 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1440-L1441

Added lines #L1440 - L1441 were not covered by tests
if len(memory_coordinates) > 1:
_memory_coordinates = np.concatenate(memory_coordinates)

Check warning on line 1443 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1443

Added line #L1443 was not covered by tests
elif len(memory_coordinates) == 1:
_memory_coordinates = memory_coordinates[0]

Check warning on line 1445 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1445

Added line #L1445 was not covered by tests
else:
raise ValueError("Memory coordinates are empty")
self._update_tree(_memory_coordinates)

Check warning on line 1448 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1447-L1448

Added lines #L1447 - L1448 were not covered by tests


class BaseTracker(ABC):
"""Abstract base class for tracker classes."""
Expand Down Expand Up @@ -1678,6 +1728,29 @@

return out_img

def add_seed_image(self, seed_image: np.ndarray) -> None:
"""Adds a seed image to the tracker. This is used to guide the tracking process\
with a specific starting label image.

Arguments:
seed_image (np.ndarray): The seed image to add. Should be a 2D array of integers.
"""
if seed_image.ndim != 2:
raise ValueError("Seed image must be 2D.")

Check warning on line 1739 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1739

Added line #L1739 was not covered by tests
if not np.issubdtype(seed_image.dtype, np.integer):
raise ValueError("Seed image must be a label image (integer type).")

Check warning on line 1741 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1741

Added line #L1741 was not covered by tests
if self.linker.frame_counter != -1:
warnings.warn(

Check warning on line 1743 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1743

Added line #L1743 was not covered by tests
"Seed image is intended to be used before any tracking is done. If\
used after first frame coordinates in memory will be overwriten if present."
)

seed_image = downscale_image(seed_image, self._downsample)

Check warning on line 1748 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1748

Added line #L1748 was not covered by tests

coords = self._image_to_coordinates(seed_image)

Check warning on line 1750 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1750

Added line #L1750 was not covered by tests

self.linker.add_seed_cluster_to_memory(coords[0], coords[1])

Check warning on line 1752 in arcos4py/tools/_detect_events.py

View check run for this annotation

Codecov / codecov/patch

arcos4py/tools/_detect_events.py#L1752

Added line #L1752 was not covered by tests

def track_iteration(self, x: np.ndarray) -> np.ndarray:
"""Tracks events in a single frame. Returns the tracked labels.

Expand Down