From c871ecc8c83190c1254390ca14b5f45c1e5d21af Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 15:50:10 +0900 Subject: [PATCH 1/5] added update_node signal and connected them to spatial filters --- src/tracksdata/graph/_base_graph.py | 5 +- src/tracksdata/graph/_graph_view.py | 38 ++++-- src/tracksdata/graph/_rustworkx_graph.py | 40 ++++-- src/tracksdata/graph/_sql_graph.py | 28 +++- .../graph/filters/_spatial_filter.py | 121 ++++++++++++++---- .../filters/_test/test_spatial_filter.py | 49 +++++++ 6 files changed, 229 insertions(+), 52 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5b3708ad..edc9ceba 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -47,8 +47,9 @@ class BaseGraph(abc.ABC): Base class for a graph backend. """ - node_added = Signal(int) - node_removed = Signal(int) + node_added = Signal(int, object) + node_removed = Signal(int, object) + node_updated = Signal(int, object, object) def __init__(self) -> None: self._cache = {} diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index b9f82ead..431c3119 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -400,8 +400,9 @@ def add_node( else: self._out_of_sync = True - self._root.node_added.emit_fast(parent_node_id) - self.node_added.emit_fast(parent_node_id) + new_attrs = dict(attrs) + self._root.node_added.emit(parent_node_id, dict(new_attrs)) + self.node_added.emit(parent_node_id, dict(new_attrs)) return parent_node_id @@ -417,12 +418,12 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None self._out_of_sync = True if is_signal_on(self._root.node_added): - for node_id in parent_node_ids: - self._root.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True): + self._root.node_added.emit(node_id, dict(node_attrs)) if is_signal_on(self.node_added): - for node_id in parent_node_ids: - self.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True): + self.node_added.emit(node_id, dict(node_attrs)) return parent_node_ids @@ -446,9 +447,11 @@ def remove_node(self, node_id: int) -> None: if node_id not in self._external_to_local: raise ValueError(f"Node {node_id} does not exist in the graph.") + old_attrs = self.nodes[node_id].to_dict() + # Remove from root graph first, because removing bounding box requires node attrs self._root.remove_node(node_id) - self.node_removed.emit_fast(node_id) + self.node_removed.emit(node_id, old_attrs) if self.sync: # Get the local node ID and remove from local graph @@ -652,6 +655,10 @@ def update_node_attrs( ) -> None: if node_ids is None: node_ids = self.node_ids() + else: + node_ids = list(node_ids) + + old_attrs_by_id = {node_id: self._root.nodes[node_id].to_dict() for node_id in node_ids} self._root.update_node_attrs( node_ids=node_ids, @@ -660,13 +667,22 @@ def update_node_attrs( # because attributes are passed by reference, we need don't need if both are rustworkx graphs if not self._is_root_rx_graph: if self.sync: - super().update_node_attrs( - node_ids=self._map_to_local(node_ids), - attrs=attrs, - ) + with self.node_updated.blocked(): + super().update_node_attrs( + node_ids=self._map_to_local(node_ids), + attrs=attrs, + ) else: self._out_of_sync = True + if is_signal_on(self.node_updated): + for node_id in node_ids: + self.node_updated.emit( + node_id, + old_attrs_by_id[node_id], + self._root.nodes[node_id].to_dict(), + ) + def update_edge_attrs( self, *, diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index ef4a3f4f..438f46d3 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -492,7 +492,7 @@ def add_node( node_id = self.rx_graph.add_node(attrs) self._time_to_nodes.setdefault(attrs["t"], []).append(node_id) - self.node_added.emit_fast(node_id) + self.node_added.emit(node_id, dict(attrs)) return node_id def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None = None) -> list[int]: @@ -523,8 +523,8 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None # checking if it has connections to reduce overhead if is_signal_on(self.node_added): - for node_id in node_indices: - self.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(node_indices, nodes, strict=True): + self.node_added.emit(node_id, dict(node_attrs)) return node_indices @@ -548,7 +548,8 @@ def remove_node(self, node_id: int) -> None: if node_id not in self.rx_graph.node_indices(): raise ValueError(f"Node {node_id} does not exist in the graph.") - self.node_removed.emit_fast(node_id) + old_attrs = dict(self.rx_graph[node_id]) + self.node_removed.emit(node_id, old_attrs) # Get the time value before removing the node t = self.rx_graph[node_id]["t"] @@ -1217,6 +1218,8 @@ def update_node_attrs( if node_ids is None: node_ids = self.node_ids() + old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids} + for key, value in attrs.items(): if key not in self.node_attr_keys(): raise ValueError(f"Node attribute key '{key}' not found in graph. Expected '{self.node_attr_keys()}'") @@ -1231,6 +1234,10 @@ def update_node_attrs( for node_id, v in zip(node_ids, value, strict=False): self._graph[node_id][key] = v + if is_signal_on(self.node_updated): + for node_id in node_ids: + self.node_updated.emit(node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) + def update_edge_attrs( self, *, @@ -1612,7 +1619,7 @@ def add_node( self._next_external_id = max(self._next_external_id, index + 1) # Add mapping using mixin self._add_id_mapping(node_id, index) - self.node_added.emit_fast(index) + self.node_added.emit(index, dict(attrs)) return index def bulk_add_nodes( @@ -1658,8 +1665,8 @@ def bulk_add_nodes( self._add_id_mappings(list(zip(graph_ids, indices, strict=True))) if is_signal_on(self.node_added): - for index in indices: - self.node_added.emit_fast(index) + for index, node_attrs in zip(indices, nodes, strict=True): + self.node_added.emit(index, dict(node_attrs)) return indices @@ -1937,8 +1944,20 @@ def update_node_attrs( node_ids : Sequence[int] | None The node ids to update. """ - node_ids = self._get_local_ids() if node_ids is None else self._map_to_local(node_ids) - super().update_node_attrs(attrs=attrs, node_ids=node_ids) + external_node_ids = self.node_ids() if node_ids is None else list(node_ids) + old_attrs_by_id = {node_id: dict(self._graph[self._map_to_local(node_id)]) for node_id in external_node_ids} + local_node_ids = self._map_to_local(external_node_ids) + + with self.node_updated.blocked(): + super().update_node_attrs(attrs=attrs, node_ids=local_node_ids) + + if is_signal_on(self.node_updated): + for node_id in external_node_ids: + self.node_updated.emit( + node_id, + old_attrs_by_id[node_id], + dict(self._graph[self._map_to_local(node_id)]), + ) def remove_node(self, node_id: int) -> None: """ @@ -1958,8 +1977,9 @@ def remove_node(self, node_id: int) -> None: raise ValueError(f"Node {node_id} does not exist in the graph.") local_node_id = self._map_to_local(node_id) + old_attrs = dict(self._graph[local_node_id]) - self.node_removed.emit_fast(node_id) + self.node_removed.emit(node_id, old_attrs) with self.node_removed.blocked(): super().remove_node(local_node_id) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 985cbdc9..b05ae18c 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -714,7 +714,7 @@ def add_node( if index is None: self._max_id_per_time[time] = node_id - self.node_added.emit_fast(node_id) + self.node_added.emit(node_id, dict(attrs)) return node_id @@ -781,8 +781,9 @@ def bulk_add_nodes( self._chunked_sa_operation(Session.bulk_insert_mappings, self.Node, nodes) if is_signal_on(self.node_added): - for node_id in node_ids: - self.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(node_ids, nodes, strict=True): + new_attrs = {key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID} + self.node_added.emit(node_id, new_attrs) return node_ids @@ -804,13 +805,13 @@ def remove_node(self, node_id: int) -> None: ValueError If the node_id does not exist in the graph. """ - self.node_removed.emit_fast(node_id) - with Session(self._engine) as session: # Check if the node exists node = session.query(self.Node).filter(self.Node.node_id == node_id).first() if node is None: raise ValueError(f"Node {node_id} does not exist in the graph.") + old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()} + self.node_removed.emit(node_id, old_attrs) # Remove all edges where this node is source or target session.query(self.Edge).filter( @@ -1755,7 +1756,24 @@ def update_node_attrs( if "t" in attrs: raise ValueError("Node attribute 't' cannot be updated.") + updated_node_ids = self.node_ids() if node_ids is None else list(node_ids) + if len(updated_node_ids) == 0: + return + attr_keys = self.node_attr_keys() + old_df = self.filter(node_ids=updated_node_ids).node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]) + old_attrs_by_id = { + row[DEFAULT_ATTR_KEYS.NODE_ID]: {key: row[key] for key in attr_keys} for row in old_df.rows(named=True) + } + self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs) + new_df = self.filter(node_ids=updated_node_ids).node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]) + new_attrs_by_id = { + row[DEFAULT_ATTR_KEYS.NODE_ID]: {key: row[key] for key in attr_keys} for row in new_df.rows(named=True) + } + + if is_signal_on(self.node_updated): + for node_id in updated_node_ids: + self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) def update_edge_attrs( self, diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index 52ce080e..870c97b2 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -38,6 +38,7 @@ def __init__( start_time = time.time() self._attr_keys = df.columns + self._ndims = len(self._attr_keys) if df.is_empty(): self._node_rtree = None @@ -45,7 +46,6 @@ def __init__( indices = np.ascontiguousarray(indices.to_numpy(), dtype=np.int64).copy() node_pos = np.ascontiguousarray(df.to_numpy(), dtype=np.float32) - self._ndims = node_pos.shape[1] self._node_rtree = PointRTree( item_dtype="int64", coord_dtype="float32", @@ -151,12 +151,17 @@ def __init__( attr_keys = list(filter(lambda x: x in valid_keys, attr_keys)) self._graph = graph + self._attr_keys = attr_keys nodes_df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]) node_ids = nodes_df[DEFAULT_ATTR_KEYS.NODE_ID] self._df_filter = DataFrameSpatialFilter(indices=node_ids, df=nodes_df.select(attr_keys)) + self._graph.node_added.connect(self._add_node) + self._graph.node_removed.connect(self._remove_node) + self._graph.node_updated.connect(self._update_node) + def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": """ Query nodes within a spatial region of interest. @@ -195,6 +200,55 @@ def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": node_ids = self._df_filter[keys] return self._graph.filter(node_ids=node_ids) + def _attrs_to_point(self, attrs: dict[str, Any]) -> np.ndarray: + return np.ascontiguousarray([[attrs[key] for key in self._attr_keys]], dtype=np.float32) + + def _add_node( + self, + node_id: int, + new_attrs: dict[str, Any], + ) -> None: + from spatial_graph import PointRTree + + if self._df_filter._node_rtree is None: + if self._graph.num_nodes() == 0: + raise ValueError("Spatial filter is not initialized") + self._df_filter._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=len(self._attr_keys), + ) + self._df_filter._ndims = len(self._attr_keys) + + positions = self._attrs_to_point(new_attrs) + self._df_filter._node_rtree.insert_point_items( + np.atleast_1d(node_id).astype(np.int64), + positions, + ) + + def _remove_node( + self, + node_id: int, + old_attrs: dict[str, Any], + ) -> None: + if self._df_filter._node_rtree is None: + return + + positions = self._attrs_to_point(old_attrs) + self._df_filter._node_rtree.delete_items( + np.atleast_1d(node_id).astype(np.int64), + positions, + ) + + def _update_node( + self, + node_id: int, + old_attrs: dict[str, Any], + new_attrs: dict[str, Any], + ) -> None: + self._remove_node(node_id, old_attrs=old_attrs) + self._add_node(node_id, new_attrs=new_attrs) + class BBoxSpatialFilter: """ @@ -269,6 +323,7 @@ def __init__( # setup signal connections self._graph.node_added.connect(self._add_node) self._graph.node_removed.connect(self._remove_node) + self._graph.node_updated.connect(self._update_node) def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": """ @@ -358,7 +413,11 @@ def _attrs_to_bb_window(self, attrs: dict[str, Any]) -> tuple[np.ndarray, np.nda return positions_min, positions_max - def _add_node(self, node_id: int) -> None: + def _add_node( + self, + node_id: int, + new_attrs: dict[str, Any], + ) -> None: """ Add a node to the spatial filter. @@ -366,30 +425,30 @@ def _add_node(self, node_id: int) -> None: ---------- node_id : int The ID of the node to add. + new_attrs : dict[str, Any] + Current node attributes to insert into the spatial index. """ from spatial_graph import PointRTree if self._node_rtree is None: - if self._graph.num_nodes() > 0: - nodes_df = self._graph.node_attrs() - bboxes = self._bboxes_to_array(nodes_df[self._bbox_attr_key]) - num_dims = bboxes.shape[1] // 2 - - if self._frame_attr_key is None: - self._ndims = num_dims - else: - self._ndims = num_dims + 1 # +1 for the frame dimension - - self._node_rtree = PointRTree( - item_dtype="int64", - coord_dtype="float32", - dims=self._ndims, - ) - else: + if self._graph.num_nodes() == 0: raise ValueError("Spatial filter is not initialized") + bbox = new_attrs[self._bbox_attr_key] + if len(bbox) % 2 != 0: + raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {len(bbox)}") + num_dims = len(bbox) // 2 + if self._frame_attr_key is None: + self._ndims = num_dims + else: + self._ndims = num_dims + 1 # +1 for the frame dimension + + self._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=self._ndims, + ) - attrs = self._graph.nodes[node_id].to_dict() - positions_min, positions_max = self._attrs_to_bb_window(attrs) + positions_min, positions_max = self._attrs_to_bb_window(new_attrs) self._node_rtree.insert_bb_items( np.atleast_1d(node_id).astype(np.int64), @@ -397,7 +456,11 @@ def _add_node(self, node_id: int) -> None: positions_max, ) - def _remove_node(self, node_id: int) -> None: + def _remove_node( + self, + node_id: int, + old_attrs: dict[str, Any], + ) -> None: """ Remove a node from the spatial filter. @@ -405,12 +468,13 @@ def _remove_node(self, node_id: int) -> None: ---------- node_id : int The ID of the node to remove. + old_attrs : dict[str, Any] + Previous node attributes used to remove the exact indexed bbox. """ if self._node_rtree is None: - raise ValueError("Spatial filter is not initialized") + return - attrs = self._graph.nodes[node_id].to_dict() - positions_min, positions_max = self._attrs_to_bb_window(attrs) + positions_min, positions_max = self._attrs_to_bb_window(old_attrs) self._node_rtree.delete_items( np.atleast_1d(node_id).astype(np.int64), @@ -418,6 +482,15 @@ def _remove_node(self, node_id: int) -> None: positions_max, ) + def _update_node( + self, + node_id: int, + old_attrs: dict[str, Any], + new_attrs: dict[str, Any], + ) -> None: + self._remove_node(node_id, old_attrs=old_attrs) + self._add_node(node_id, new_attrs=new_attrs) + @staticmethod def _bboxes_to_array(bbox_series: pl.Series) -> np.ndarray: """ diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index d17ba94b..4ad3a87d 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -342,6 +342,55 @@ def test_add_and_remove_node(graph_backend: BaseGraph) -> None: assert graph.num_nodes() == 2 +def test_spatial_filter_add_update_and_remove_node(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("y", pl.Int64) + graph_backend.add_node_attr_key("x", pl.Int64) + + graph_backend.add_node({"t": 0, "y": 1, "x": 1}) + graph_backend.add_node({"t": 1, "y": 10, "x": 10}) + + for graph in [graph_backend, graph_backend.filter().subgraph()]: + spatial_filter = SpatialFilter(graph, attr_keys=[DEFAULT_ATTR_KEYS.T, "y", "x"]) + + assert spatial_filter[2:3, 6:9, 6:9].node_attrs().is_empty() + + new_node_id = graph.add_node({"t": 2, "y": 7, "x": 7}) + result_ids = spatial_filter[2:3, 6:9, 6:9].node_ids() + assert new_node_id in result_ids + + graph.update_node_attrs(attrs={"y": 20, "x": 20}, node_ids=[new_node_id]) + + assert spatial_filter[2:3, 6:9, 6:9].node_attrs().is_empty() + moved_ids = spatial_filter[2:3, 19:22, 19:22].node_ids() + assert new_node_id in moved_ids + + graph.remove_node(new_node_id) + assert spatial_filter[2:3, 19:22, 19:22].node_attrs().is_empty() + + +def test_bbox_spatial_filter_updates_node_position(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 4)) + moved_node_id = graph_backend.add_node({"t": 0, "bbox": np.asarray([0, 0, 2, 2])}) + graph_backend.add_node({"t": 1, "bbox": np.asarray([10, 10, 12, 12])}) + + for graph in [graph_backend, graph_backend.filter().subgraph()]: + graph.update_node_attrs( + attrs={"bbox": [np.asarray([0, 0, 2, 2])]}, + node_ids=[moved_node_id], + ) + + spatial_filter = BBoxSpatialFilter(graph, frame_attr_key="t", bbox_attr_key="bbox") + assert moved_node_id in spatial_filter[0:0.5, 0:3, 0:3].node_ids() + + graph.update_node_attrs( + attrs={"bbox": [np.asarray([20, 20, 22, 22])]}, + node_ids=[moved_node_id], + ) + + assert moved_node_id not in spatial_filter[0:0.5, 0:3, 0:3].node_ids() + assert moved_node_id in spatial_filter[0:0.5, 19:23, 19:23].node_ids() + + def test_bbox_spatial_filter_handles_list_dtype(graph_backend: BaseGraph) -> None: """Ensure bounding boxes stored as list dtype still work with the spatial filter.""" graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) From 153ca5a54c550f048217141944ce51d65f813790 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 17:09:14 +0900 Subject: [PATCH 2/5] initial impl --- src/tracksdata/array/_graph_array.py | 91 +++++++++++++- src/tracksdata/array/_nd_chunk_cache.py | 65 +++++++++- .../array/_test/test_graph_array.py | 119 ++++++++++++++++++ 3 files changed, 270 insertions(+), 5 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 018e7ae6..75d55da9 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from copy import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -200,6 +200,9 @@ def __init__( frame_attr_key=DEFAULT_ATTR_KEYS.T, bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX, ) + self.graph.node_added.connect(self._on_node_added) + self.graph.node_removed.connect(self._on_node_removed) + self.graph.node_updated.connect(self._on_node_updated) @property def shape(self) -> tuple[int, ...]: @@ -351,3 +354,89 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): mask: Mask mask.paint_buffer(buffer, value, offset=self._offset) + + def _offset_to_array(self, ndim: int) -> np.ndarray: + """Normalize `offset` to a vector for each spatial axis.""" + if np.isscalar(self._offset): + return np.full(ndim, int(self._offset), dtype=np.int64) + + offset = np.asarray(self._offset, dtype=np.int64).reshape(-1) + if len(offset) != ndim: + raise ValueError(f"`offset` must have length {ndim}, got {len(offset)}") + return offset + + def _bbox_to_slices(self, bbox: Any) -> tuple[slice, ...] | None: + """ + Convert a bbox to clipped spatial slices in array coordinates. + + Returns `None` when the bbox does not overlap the current array volume. + """ + bbox = np.asarray(bbox, dtype=np.int64).reshape(-1) + ndim = len(self.original_shape) - 1 + if len(bbox) != 2 * ndim: + raise ValueError(f"`bbox` must have length {2 * ndim}, got {len(bbox)}") + + offset = self._offset_to_array(ndim) + start = bbox[:ndim] + offset + stop = bbox[ndim:] + offset + + shape = np.asarray(self.original_shape[1:], dtype=np.int64) + start = np.clip(start, 0, shape) + stop = np.clip(stop, 0, shape) + + if np.any(stop <= start): + return None + + return tuple(slice(int(s), int(e)) for s, e in zip(start, stop, strict=True)) + + def _invalidate_from_attrs(self, attrs: object) -> None: + """ + Invalidate cache region touched by node attributes. + + Falls back to larger invalidation windows when metadata is incomplete. + """ + if not isinstance(attrs, dict): + self._cache.invalidate() + return + + time_value = attrs.get(DEFAULT_ATTR_KEYS.T) + if time_value is None: + self._cache.invalidate() + return + + try: + time = int(np.asarray(time_value).item()) + except (TypeError, ValueError): + self._cache.invalidate() + return + + if not (0 <= time < self.original_shape[0]): + return + + if DEFAULT_ATTR_KEYS.BBOX not in attrs: + self._cache.invalidate(time=time) + return + + try: + slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX]) + except (TypeError, ValueError): + self._cache.invalidate(time=time) + return + + if slices is None: + return + + self._cache.invalidate(time=time, volume_slicing=slices) + + def _on_node_added(self, node_id: int, new_attrs: object) -> None: + del node_id + self._invalidate_from_attrs(new_attrs) + + def _on_node_removed(self, node_id: int, old_attrs: object) -> None: + del node_id + self._invalidate_from_attrs(old_attrs) + + def _on_node_updated(self, node_id: int, old_attrs: object, new_attrs: object) -> None: + del node_id + self._invalidate_from_attrs(old_attrs) + self._invalidate_from_attrs(new_attrs) diff --git a/src/tracksdata/array/_nd_chunk_cache.py b/src/tracksdata/array/_nd_chunk_cache.py index 447dabe6..7a9a0715 100644 --- a/src/tracksdata/array/_nd_chunk_cache.py +++ b/src/tracksdata/array/_nd_chunk_cache.py @@ -114,6 +114,13 @@ def _chunk_bounds(self, slices: tuple[slice, ...]) -> tuple[tuple[int, int], ... """Return inclusive chunk-index bounds for every axis.""" return tuple((s.start // cs, (s.stop - 1) // cs) for s, cs in zip(slices, self.chunk_shape, strict=True)) + def _chunk_slice(self, chunk_idx: tuple[int, ...]) -> tuple[slice, ...]: + """Return absolute volume slices for a chunk index.""" + return tuple( + slice(ci * cs, min((ci + 1) * cs, fs)) + for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True) + ) + def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...]) -> np.ndarray: """ Retrieve data for `time` and arbitrary dimensional slices. @@ -146,13 +153,63 @@ def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...] continue # already filled # Absolute slice covering this chunk - chunk_slc = tuple( - slice(ci * cs, min((ci + 1) * cs, fs)) - for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True) - ) + chunk_slc = self._chunk_slice(chunk_idx) # Handle the case where chunk_slc exceeds volume_slices self.compute_func(time, chunk_slc, store_entry.buffer) store_entry.ready[chunk_idx] = True # Return view on the big buffer return store_entry.buffer[volume_slicing] + + def invalidate( + self, + *, + time: int | None = None, + volume_slicing: tuple[slice | int | Sequence[int], ...] | None = None, + ) -> None: + """ + Invalidate a cached region. + + Parameters + ---------- + time : int | None, optional + Time point to invalidate. If None, applies to all currently cached times. + volume_slicing : tuple[slice | int | Sequence[int], ...] | None, optional + Volume region to invalidate. If None, invalidates the full volume for the selected times. + """ + if time is None: + times = list(self._store.keys()) + elif time in self._store: + times = [time] + else: + return + + if volume_slicing is not None and len(volume_slicing) != self.ndim: + raise ValueError("Number of slices must equal dimensionality") + + region_slices = None + if volume_slicing is not None: + region_slices = tuple(_to_slice(slc) for slc in volume_slicing) + + for t in times: + store_entry = self._store[t] + + if region_slices is None: + store_entry.ready.fill(False) + store_entry.buffer.fill(0) + continue + + clipped_slices = [] + for slc, size in zip(region_slices, self.shape, strict=True): + start = max(0, slc.start) + stop = min(size, slc.stop) + if stop <= start: + return + clipped_slices.append(slice(start, stop)) + + bounds = self._chunk_bounds(tuple(clipped_slices)) + chunk_ranges = [range(lo, hi + 1) for lo, hi in bounds] + for chunk_idx in itertools.product(*chunk_ranges): + chunk_slc = self._chunk_slice(chunk_idx) + store_entry.ready[chunk_idx] = False + store_entry.buffer[chunk_slc] = 0 diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index f8dbc552..7362f89c 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -378,3 +378,122 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph with pytest.raises(ValueError, match="Attribute values for key 'label' must be scalar"): GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") + + +def _add_graph_array_node_attrs(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("label", dtype=pl.Int64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + + +def _make_square_mask(y: int, x: int, size: int = 2) -> Mask: + return Mask(np.ones((size, size), dtype=bool), bbox=np.array([y, x, y + size, x + size])) + + +def test_graph_array_view_invalidates_only_affected_chunk_on_add(graph_backend: BaseGraph) -> None: + _add_graph_array_node_attrs(graph_backend) + + first_mask = _make_square_mask(1, 1) + graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 1, + DEFAULT_ATTR_KEYS.MASK: first_mask, + DEFAULT_ATTR_KEYS.BBOX: first_mask.bbox, + } + ) + + array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4)) + + _ = np.asarray(array_view[0]) + np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool)) + + second_mask = _make_square_mask(5, 5) + graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 2, + DEFAULT_ATTR_KEYS.MASK: second_mask, + DEFAULT_ATTR_KEYS.BBOX: second_mask.bbox, + } + ) + + expected_ready = np.ones((2, 2), dtype=bool) + expected_ready[1, 1] = False + np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready) + + output = np.asarray(array_view[0]) + assert output[1, 1] == 1 + assert output[5, 5] == 2 + + +def test_graph_array_view_invalidates_old_and_new_chunks_on_update(graph_backend: BaseGraph) -> None: + _add_graph_array_node_attrs(graph_backend) + + mask = _make_square_mask(1, 1) + node_id = graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 1, + DEFAULT_ATTR_KEYS.MASK: mask, + DEFAULT_ATTR_KEYS.BBOX: mask.bbox, + } + ) + + array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4)) + _ = np.asarray(array_view[0]) + + moved_mask = _make_square_mask(5, 5) + graph_backend.update_node_attrs( + attrs={ + "label": [7], + DEFAULT_ATTR_KEYS.MASK: [moved_mask], + DEFAULT_ATTR_KEYS.BBOX: [moved_mask.bbox], + }, + node_ids=[node_id], + ) + + expected_ready = np.ones((2, 2), dtype=bool) + expected_ready[0, 0] = False + expected_ready[1, 1] = False + np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready) + + output = np.asarray(array_view[0]) + assert output[1, 1] == 0 + assert output[5, 5] == 7 + + +def test_graph_array_view_invalidates_chunk_on_remove(graph_backend: BaseGraph) -> None: + _add_graph_array_node_attrs(graph_backend) + + first_mask = _make_square_mask(1, 1) + graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 1, + DEFAULT_ATTR_KEYS.MASK: first_mask, + DEFAULT_ATTR_KEYS.BBOX: first_mask.bbox, + } + ) + second_mask = _make_square_mask(5, 5) + second_node = graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 2, + DEFAULT_ATTR_KEYS.MASK: second_mask, + DEFAULT_ATTR_KEYS.BBOX: second_mask.bbox, + } + ) + + array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4)) + _ = np.asarray(array_view[0]) + + graph_backend.remove_node(second_node) + + expected_ready = np.ones((2, 2), dtype=bool) + expected_ready[1, 1] = False + np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready) + + output = np.asarray(array_view[0]) + assert output[1, 1] == 1 + assert output[5, 5] == 0 From 44241660b6499a7f6adcde63d5fce64002fc85eb Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 17:21:37 +0900 Subject: [PATCH 3/5] fixed failing benchmarks --- src/tracksdata/graph/_mapped_graph_mixin.py | 5 ++-- src/tracksdata/graph/_rustworkx_graph.py | 28 ++++++++++++++------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/tracksdata/graph/_mapped_graph_mixin.py b/src/tracksdata/graph/_mapped_graph_mixin.py index 220aab3c..2889f3a0 100644 --- a/src/tracksdata/graph/_mapped_graph_mixin.py +++ b/src/tracksdata/graph/_mapped_graph_mixin.py @@ -6,6 +6,7 @@ """ from collections.abc import Sequence +from numbers import Integral from typing import Any, overload import bidict @@ -84,7 +85,7 @@ def _map_to_external(self, local_ids: int | Sequence[int] | None) -> int | list[ """ if local_ids is None: return None - if isinstance(local_ids, int): + if isinstance(local_ids, Integral): return self._local_to_external[local_ids] return [self._local_to_external[lid] for lid in local_ids] @@ -113,7 +114,7 @@ def _map_to_local(self, external_ids: int | Sequence[int] | None) -> int | list[ """ if external_ids is None: return None - if isinstance(external_ids, int): + if isinstance(external_ids, Integral): return self._external_to_local[external_ids] return [self._external_to_local[eid] for eid in external_ids] diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 438f46d3..0d8e768d 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -1218,7 +1218,8 @@ def update_node_attrs( if node_ids is None: node_ids = self.node_ids() - old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids} + emit_node_updated = is_signal_on(self.node_updated) + old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids} if emit_node_updated else None for key, value in attrs.items(): if key not in self.node_attr_keys(): @@ -1234,7 +1235,7 @@ def update_node_attrs( for node_id, v in zip(node_ids, value, strict=False): self._graph[node_id][key] = v - if is_signal_on(self.node_updated): + if emit_node_updated and old_attrs_by_id is not None: for node_id in node_ids: self.node_updated.emit(node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) @@ -1944,19 +1945,28 @@ def update_node_attrs( node_ids : Sequence[int] | None The node ids to update. """ - external_node_ids = self.node_ids() if node_ids is None else list(node_ids) - old_attrs_by_id = {node_id: dict(self._graph[self._map_to_local(node_id)]) for node_id in external_node_ids} + external_node_ids = self.node_ids() if node_ids is None else [int(node_id) for node_id in node_ids] local_node_ids = self._map_to_local(external_node_ids) + emit_node_updated = is_signal_on(self.node_updated) + old_attrs_by_id = ( + { + external_node_id: dict(self._graph[local_node_id]) + for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True) + } + if emit_node_updated + else None + ) + with self.node_updated.blocked(): super().update_node_attrs(attrs=attrs, node_ids=local_node_ids) - if is_signal_on(self.node_updated): - for node_id in external_node_ids: + if emit_node_updated and old_attrs_by_id is not None: + for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True): self.node_updated.emit( - node_id, - old_attrs_by_id[node_id], - dict(self._graph[self._map_to_local(node_id)]), + external_node_id, + old_attrs_by_id[external_node_id], + dict(self._graph[local_node_id]), ) def remove_node(self, node_id: int) -> None: From c09d9340b729aada5a5d790550050896b733ca01 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 18:43:25 +0900 Subject: [PATCH 4/5] bm timeout longer --- .github/workflows/benchmarks.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 58371763..52c87120 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -60,6 +60,7 @@ jobs: --machine github-actions \ --python 3.12 \ --factor 1.5 \ + --attribute timeout=180 \ --show-stderr || status=$? if [ "$status" -eq 2 ]; then echo "asv: benchmark run failed (exit 2). Failing CI." >&2 From 38c075f811afd0abe5f93b7ad5559aa9340fc1f6 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 22:46:00 +0900 Subject: [PATCH 5/5] updated --- src/tracksdata/array/_graph_array.py | 45 ++++++++++------------------ src/tracksdata/graph/_base_graph.py | 6 ++-- 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 75d55da9..6a2c9de5 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -355,7 +355,7 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda mask: Mask mask.paint_buffer(buffer, value, offset=self._offset) - def _offset_to_array(self, ndim: int) -> np.ndarray: + def _offset_as_array(self, ndim: int) -> np.ndarray: """Normalize `offset` to a vector for each spatial axis.""" if np.isscalar(self._offset): return np.full(ndim, int(self._offset), dtype=np.int64) @@ -376,7 +376,7 @@ def _bbox_to_slices(self, bbox: Any) -> tuple[slice, ...] | None: if len(bbox) != 2 * ndim: raise ValueError(f"`bbox` must have length {2 * ndim}, got {len(bbox)}") - offset = self._offset_to_array(ndim) + offset = self._offset_as_array(ndim) start = bbox[:ndim] + offset stop = bbox[ndim:] + offset @@ -389,54 +389,41 @@ def _bbox_to_slices(self, bbox: Any) -> tuple[slice, ...] | None: return tuple(slice(int(s), int(e)) for s, e in zip(start, stop, strict=True)) - def _invalidate_from_attrs(self, attrs: object) -> None: + def _invalidate_from_attrs(self, attrs: dict) -> None: """ Invalidate cache region touched by node attributes. Falls back to larger invalidation windows when metadata is incomplete. """ - if not isinstance(attrs, dict): - self._cache.invalidate() - return time_value = attrs.get(DEFAULT_ATTR_KEYS.T) if time_value is None: - self._cache.invalidate() - return + raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.T}' key for cache invalidation.") + if DEFAULT_ATTR_KEYS.BBOX not in attrs: + raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.BBOX}' key for cache invalidation.") try: time = int(np.asarray(time_value).item()) - except (TypeError, ValueError): - self._cache.invalidate() - return - + except (TypeError, ValueError) as e: + raise ValueError( + f"Time attribute value must be a scalar integer, got {time_value} of type {type(time_value)}" + ) from e if not (0 <= time < self.original_shape[0]): return - if DEFAULT_ATTR_KEYS.BBOX not in attrs: - self._cache.invalidate(time=time) - return - - try: - slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX]) - except (TypeError, ValueError): - self._cache.invalidate(time=time) - return - - if slices is None: - return - - self._cache.invalidate(time=time, volume_slicing=slices) + slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX]) + if slices is not None: + self._cache.invalidate(time=time, volume_slicing=slices) - def _on_node_added(self, node_id: int, new_attrs: object) -> None: + def _on_node_added(self, node_id: int, new_attrs: dict) -> None: del node_id self._invalidate_from_attrs(new_attrs) - def _on_node_removed(self, node_id: int, old_attrs: object) -> None: + def _on_node_removed(self, node_id: int, old_attrs: dict) -> None: del node_id self._invalidate_from_attrs(old_attrs) - def _on_node_updated(self, node_id: int, old_attrs: object, new_attrs: object) -> None: + def _on_node_updated(self, node_id: int, old_attrs: dict, new_attrs: dict) -> None: del node_id self._invalidate_from_attrs(old_attrs) self._invalidate_from_attrs(new_attrs) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index edc9ceba..c96768e3 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -47,9 +47,9 @@ class BaseGraph(abc.ABC): Base class for a graph backend. """ - node_added = Signal(int, object) - node_removed = Signal(int, object) - node_updated = Signal(int, object, object) + node_added = Signal(int, dict) + node_removed = Signal(int, dict) + node_updated = Signal(int, dict, dict) def __init__(self) -> None: self._cache = {}