diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 80418986..6e88828d 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from copy import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -200,6 +200,9 @@ def __init__( frame_attr_key=DEFAULT_ATTR_KEYS.T, bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX, ) + self.graph.node_added.connect(self._on_node_added) + self.graph.node_removed.connect(self._on_node_removed) + self.graph.node_updated.connect(self._on_node_updated) @property def shape(self) -> tuple[int, ...]: @@ -351,3 +354,76 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): mask: Mask mask.paint_buffer(buffer, value, offset=self._offset) + + def _offset_as_array(self, ndim: int) -> np.ndarray: + """Normalize `offset` to a vector for each spatial axis.""" + if np.isscalar(self._offset): + return np.full(ndim, int(self._offset), dtype=np.int64) + + offset = np.asarray(self._offset, dtype=np.int64).reshape(-1) + if len(offset) != ndim: + raise ValueError(f"`offset` must have length {ndim}, got {len(offset)}") + return offset + + def _bbox_to_slices(self, bbox: Any) -> tuple[slice, ...] | None: + """ + Convert a bbox to clipped spatial slices in array coordinates. + + Returns `None` when the bbox does not overlap the current array volume. 
+ """ + bbox = np.asarray(bbox, dtype=np.int64).reshape(-1) + ndim = len(self.original_shape) - 1 + if len(bbox) != 2 * ndim: + raise ValueError(f"`bbox` must have length {2 * ndim}, got {len(bbox)}") + + offset = self._offset_as_array(ndim) + start = bbox[:ndim] + offset + stop = bbox[ndim:] + offset + + shape = np.asarray(self.original_shape[1:], dtype=np.int64) + start = np.clip(start, 0, shape) + stop = np.clip(stop, 0, shape) + + if np.any(stop <= start): + return None + + return tuple(slice(int(s), int(e)) for s, e in zip(start, stop, strict=True)) + + def _invalidate_from_attrs(self, attrs: dict) -> None: + """ + Invalidate cache region touched by node attributes. + + Falls back to larger invalidation windows when metadata is incomplete. + """ + + time_value = attrs.get(DEFAULT_ATTR_KEYS.T) + if time_value is None: + raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.T}' key for cache invalidation.") + if DEFAULT_ATTR_KEYS.BBOX not in attrs: + raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.BBOX}' key for cache invalidation.") + + try: + time = int(np.asarray(time_value).item()) + except (TypeError, ValueError) as e: + raise ValueError( + f"Time attribute value must be a scalar integer, got {time_value} of type {type(time_value)}" + ) from e + if not (0 <= time < self.original_shape[0]): + return + + slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX]) + if slices is not None: + self._cache.invalidate(time=time, volume_slicing=slices) + + def _on_node_added(self, node_id: int, new_attrs: dict) -> None: + del node_id + self._invalidate_from_attrs(new_attrs) + + def _on_node_removed(self, node_id: int, old_attrs: dict) -> None: + del node_id + self._invalidate_from_attrs(old_attrs) + + def _on_node_updated(self, node_id: int, old_attrs: dict, new_attrs: dict) -> None: + del node_id + self._invalidate_from_attrs(old_attrs) + self._invalidate_from_attrs(new_attrs) diff --git a/src/tracksdata/array/_nd_chunk_cache.py 
b/src/tracksdata/array/_nd_chunk_cache.py index 447dabe6..7a9a0715 100644 --- a/src/tracksdata/array/_nd_chunk_cache.py +++ b/src/tracksdata/array/_nd_chunk_cache.py @@ -114,6 +114,13 @@ def _chunk_bounds(self, slices: tuple[slice, ...]) -> tuple[tuple[int, int], ... """Return inclusive chunk-index bounds for every axis.""" return tuple((s.start // cs, (s.stop - 1) // cs) for s, cs in zip(slices, self.chunk_shape, strict=True)) + def _chunk_slice(self, chunk_idx: tuple[int, ...]) -> tuple[slice, ...]: + """Return absolute volume slices for a chunk index.""" + return tuple( + slice(ci * cs, min((ci + 1) * cs, fs)) + for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True) + ) + def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...]) -> np.ndarray: """ Retrieve data for `time` and arbitrary dimensional slices. @@ -146,13 +153,63 @@ def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...] continue # already filled # Absolute slice covering this chunk - chunk_slc = tuple( - slice(ci * cs, min((ci + 1) * cs, fs)) - for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True) - ) + chunk_slc = self._chunk_slice(chunk_idx) # Handle the case where chunk_slc exceeds volume_slices self.compute_func(time, chunk_slc, store_entry.buffer) store_entry.ready[chunk_idx] = True # Return view on the big buffer return store_entry.buffer[volume_slicing] + + def invalidate( + self, + *, + time: int | None = None, + volume_slicing: tuple[slice | int | Sequence[int], ...] | None = None, + ) -> None: + """ + Invalidate a cached region. + + Parameters + ---------- + time : int | None, optional + Time point to invalidate. If None, applies to all currently cached times. + volume_slicing : tuple[slice | int | Sequence[int], ...] | None, optional + Volume region to invalidate. If None, invalidates the full volume for the selected times. 
+ """ + if time is None: + times = list(self._store.keys()) + elif time in self._store: + times = [time] + else: + return + + if volume_slicing is not None and len(volume_slicing) != self.ndim: + raise ValueError("Number of slices must equal dimensionality") + + region_slices = None + if volume_slicing is not None: + region_slices = tuple(_to_slice(slc) for slc in volume_slicing) + + for t in times: + store_entry = self._store[t] + + if region_slices is None: + store_entry.ready.fill(False) + store_entry.buffer.fill(0) + continue + + clipped_slices = [] + for slc, size in zip(region_slices, self.shape, strict=True): + start = max(0, slc.start) + stop = min(size, slc.stop) + if stop <= start: + return + clipped_slices.append(slice(start, stop)) + + bounds = self._chunk_bounds(tuple(clipped_slices)) + chunk_ranges = [range(lo, hi + 1) for lo, hi in bounds] + for chunk_idx in itertools.product(*chunk_ranges): + chunk_slc = self._chunk_slice(chunk_idx) + store_entry.ready[chunk_idx] = False + store_entry.buffer[chunk_slc] = 0 diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index f8dbc552..7362f89c 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -378,3 +378,122 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph with pytest.raises(ValueError, match="Attribute values for key 'label' must be scalar"): GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") + + +def _add_graph_array_node_attrs(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("label", dtype=pl.Int64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + + +def _make_square_mask(y: int, x: int, size: int = 2) -> Mask: + return Mask(np.ones((size, size), dtype=bool), bbox=np.array([y, x, y + size, x + size])) + + 
def test_graph_array_view_invalidates_only_affected_chunk_on_add(graph_backend: BaseGraph) -> None:
    """Adding a node should mark only the chunk under its mask as stale."""
    _add_graph_array_node_attrs(graph_backend)

    seed_mask = _make_square_mask(1, 1)
    graph_backend.add_node(
        {
            DEFAULT_ATTR_KEYS.T: 0,
            "label": 1,
            DEFAULT_ATTR_KEYS.MASK: seed_mask,
            DEFAULT_ATTR_KEYS.BBOX: seed_mask.bbox,
        }
    )

    view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))

    # Materialize time point 0 so every chunk is cached.
    _ = np.asarray(view[0])
    np.testing.assert_array_equal(view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

    late_mask = _make_square_mask(5, 5)
    graph_backend.add_node(
        {
            DEFAULT_ATTR_KEYS.T: 0,
            "label": 2,
            DEFAULT_ATTR_KEYS.MASK: late_mask,
            DEFAULT_ATTR_KEYS.BBOX: late_mask.bbox,
        }
    )

    # Only the bottom-right chunk (the one covering (5, 5)) goes stale.
    np.testing.assert_array_equal(
        view._cache._store[0].ready,
        np.array([[True, True], [True, False]]),
    )

    refreshed = np.asarray(view[0])
    assert refreshed[1, 1] == 1
    assert refreshed[5, 5] == 2


def test_graph_array_view_invalidates_old_and_new_chunks_on_update(graph_backend: BaseGraph) -> None:
    """Moving a node's mask should invalidate both its old and new chunks."""
    _add_graph_array_node_attrs(graph_backend)

    initial_mask = _make_square_mask(1, 1)
    node_id = graph_backend.add_node(
        {
            DEFAULT_ATTR_KEYS.T: 0,
            "label": 1,
            DEFAULT_ATTR_KEYS.MASK: initial_mask,
            DEFAULT_ATTR_KEYS.BBOX: initial_mask.bbox,
        }
    )

    view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
    _ = np.asarray(view[0])

    relocated_mask = _make_square_mask(5, 5)
    graph_backend.update_node_attrs(
        attrs={
            "label": [7],
            DEFAULT_ATTR_KEYS.MASK: [relocated_mask],
            DEFAULT_ATTR_KEYS.BBOX: [relocated_mask.bbox],
        },
        node_ids=[node_id],
    )

    # Both the vacated top-left chunk and the newly painted bottom-right
    # chunk must be recomputed.
    np.testing.assert_array_equal(
        view._cache._store[0].ready,
        np.array([[False, True], [True, False]]),
    )

    refreshed = np.asarray(view[0])
    assert refreshed[1, 1] == 0
    assert refreshed[5, 5] == 7


def test_graph_array_view_invalidates_chunk_on_remove(graph_backend: BaseGraph) -> None:
    """Removing a node should invalidate only the chunk it painted."""
    _add_graph_array_node_attrs(graph_backend)

    kept_mask = _make_square_mask(1, 1)
    graph_backend.add_node(
        {
            DEFAULT_ATTR_KEYS.T: 0,
            "label": 1,
            DEFAULT_ATTR_KEYS.MASK: kept_mask,
            DEFAULT_ATTR_KEYS.BBOX: kept_mask.bbox,
        }
    )
    doomed_mask = _make_square_mask(5, 5)
    doomed_node = graph_backend.add_node(
        {
            DEFAULT_ATTR_KEYS.T: 0,
            "label": 2,
            DEFAULT_ATTR_KEYS.MASK: doomed_mask,
            DEFAULT_ATTR_KEYS.BBOX: doomed_mask.bbox,
        }
    )

    view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
    _ = np.asarray(view[0])

    graph_backend.remove_node(doomed_node)

    # Only the chunk that held the removed node's mask goes stale.
    np.testing.assert_array_equal(
        view._cache._store[0].ready,
        np.array([[True, True], [True, False]]),
    )

    refreshed = np.asarray(view[0])
    assert refreshed[1, 1] == 1
    assert refreshed[5, 5] == 0
old_attrs = dict(self._graph[local_node_id]) if is_signal_on(self.node_removed): old_attrs = dict(self._graph[local_node_id]) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index e3eb9a40..5b9d871f 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -815,6 +815,8 @@ def remove_node(self, node_id: int) -> None: node = session.query(self.Node).filter(self.Node.node_id == node_id).first() if node is None: raise ValueError(f"Node {node_id} does not exist in the graph.") + old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()} + self.node_removed.emit(node_id, old_attrs) if is_signal_on(self.node_removed): old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()} @@ -1830,6 +1832,14 @@ def update_node_attrs( old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_df.rows(named=True)} self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs) + new_df = self.filter(node_ids=updated_node_ids).node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]) + new_attrs_by_id = { + row[DEFAULT_ATTR_KEYS.NODE_ID]: {key: row[key] for key in attr_keys} for row in new_df.rows(named=True) + } + + if is_signal_on(self.node_updated): + for node_id in updated_node_ids: + self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) if is_signal_on(self.node_updated): new_df = self.filter(node_ids=updated_node_ids).node_attrs(