diff --git a/arcos4py/tools/_detect_events.py b/arcos4py/tools/_detect_events.py index 14f25de..7f03504 100644 --- a/arcos4py/tools/_detect_events.py +++ b/arcos4py/tools/_detect_events.py @@ -354,6 +354,41 @@ def reset(self): self.coordinates = [] self.prev_cluster_ids = [] + def overwrite_cluster_ids(self, seed_coordinates: np.ndarray, seed_cluster_ids: np.ndarray): + """Adds seed clusters to the memory. + + Arguments: + seed_coordinates (np.ndarray): Array of shape (N, D) of seed coordinates. + seed_cluster_ids (np.ndarray): Array of shape (N,) of seed cluster IDs. + """ + # convert to int + print("Overwriting cluster IDs in memory") + seed_cluster_ids = seed_cluster_ids.astype(int) + seed_coordinates = seed_coordinates.astype(int) + if not self.coordinates: + valid = seed_cluster_ids > 0 + self.coordinates.append(seed_coordinates[valid]) + self.prev_cluster_ids.append(seed_cluster_ids[valid]) + self.max_prev_cluster_id = np.max(seed_cluster_ids[valid]) + else: + coords = self.coordinates[-1] + ids = self.prev_cluster_ids[-1] + for sc, sci in zip(seed_coordinates, seed_cluster_ids): + if sci <= 0: + continue + if coords.size > 0: + matches = np.all(coords == sc, axis=1) + idx = np.nonzero(matches)[0] + if idx.size > 0: + ids[idx[0]] = sci + self.max_prev_cluster_id = max(self.max_prev_cluster_id, sci) + continue + coords = np.vstack([coords, sc[np.newaxis, :]]) if coords.size > 0 else sc[np.newaxis, :] + ids = np.concatenate([ids, [sci]]) if ids.size > 0 else np.array([sci], dtype=int) + self.max_prev_cluster_id = max(self.max_prev_cluster_id, sci) + self.coordinates[-1] = coords + self.prev_cluster_ids[-1] = ids + @property def all_coordinates(self): """Returns all coordinates in memory as one array.""" @@ -1183,8 +1218,6 @@ def _update_tree(self, coords): def _get_next_id(self) -> int: """Generate a new unique ID.""" self._memory.max_prev_cluster_id += 1 - if self._memory.max_prev_cluster_id == 92: - pass return self._memory.max_prev_cluster_id def _apply_remove_small_clusters(self, linked_cluster_ids, original_cluster_ids): @@ -1397,6 +1430,23 @@ def _apply_stable_merges_splits(self, linked_cluster_ids, original_cluster_ids): return final_cluster_ids + def add_seed_cluster_to_memory(self, cluster_id: np.ndarray, coordinates: np.ndarray): + """Adds a seed cluster to memory or overwrites an existing cluster. + + Arguments: + cluster_id (np.ndarray): The ID of the cluster. + coordinates (np.ndarray): The coordinates of the cluster. + """ + self._memory.overwrite_cluster_ids(cluster_id, coordinates) + memory_coordinates = self._memory.coordinates + if len(memory_coordinates) > 1: + _memory_coordinates = np.concatenate(memory_coordinates) + elif len(memory_coordinates) == 1: + _memory_coordinates = memory_coordinates[0] + else: + raise ValueError("Memory coordinates are empty") + self._update_tree(_memory_coordinates) + class BaseTracker(ABC): """Abstract base class for tracker classes.""" @@ -1678,6 +1728,29 @@ def _coordinates_to_image(self, x, position_data, tracked_events): return out_img + def add_seed_image(self, seed_image: np.ndarray) -> None: + """Adds a seed image to the tracker. This is used to guide the tracking process\ + with a specific starting label image. + + Arguments: + seed_image (np.ndarray): The seed image to add. Should be a 2D array of integers. + """ + if seed_image.ndim != 2: + raise ValueError("Seed image must be 2D.") + if not np.issubdtype(seed_image.dtype, np.integer): + raise ValueError("Seed image must be a label image (integer type).") + if self.linker.frame_counter != -1: + warnings.warn( + "Seed image is intended to be used before any tracking is done. If\ + used after first frame coordinates in memory will be overwriten if present." + ) + + seed_image = downscale_image(seed_image, self._downsample) + + coords = self._image_to_coordinates(seed_image) + + self.linker.add_seed_cluster_to_memory(coords[0], coords[1]) + def track_iteration(self, x: np.ndarray) -> np.ndarray: """Tracks events in a single frame. Returns the tracked labels.