diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 58371763..52c87120 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -60,6 +60,7 @@ jobs: --machine github-actions \ --python 3.12 \ --factor 1.5 \ + --attribute timeout=180 \ --show-stderr || status=$? if [ "$status" -eq 2 ]; then echo "asv: benchmark run failed (exit 2). Failing CI." >&2 diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5b3708ad..edc9ceba 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -47,8 +47,9 @@ class BaseGraph(abc.ABC): Base class for a graph backend. """ - node_added = Signal(int) - node_removed = Signal(int) + node_added = Signal(int, object) + node_removed = Signal(int, object) + node_updated = Signal(int, object, object) def __init__(self) -> None: self._cache = {} diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index b9f82ead..022429e9 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Any, Literal, overload +from typing import Any, Literal, cast, overload import bidict import polars as pl @@ -400,8 +400,10 @@ def add_node( else: self._out_of_sync = True - self._root.node_added.emit_fast(parent_node_id) - self.node_added.emit_fast(parent_node_id) + if is_signal_on(self._root.node_added): + self._root.node_added.emit(parent_node_id, attrs) + if is_signal_on(self.node_added): + self.node_added.emit(parent_node_id, attrs) return parent_node_id @@ -417,12 +419,12 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None self._out_of_sync = True if is_signal_on(self._root.node_added): - for node_id in parent_node_ids: - self._root.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True): + self._root.node_added.emit(node_id, node_attrs) if is_signal_on(self.node_added): - for node_id in parent_node_ids: - self.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True): + self.node_added.emit(node_id, node_attrs) return parent_node_ids @@ -446,9 +448,11 @@ def remove_node(self, node_id: int) -> None: if node_id not in self._external_to_local: raise ValueError(f"Node {node_id} does not exist in the graph.") + if is_signal_on(self.node_removed): + old_attrs = self.nodes[node_id].to_dict() + # Remove from root graph first, because removing bounding box requires node attrs self._root.remove_node(node_id) - self.node_removed.emit_fast(node_id) if self.sync: # Get the local node ID and remove from local graph @@ -474,6 +478,9 @@ def remove_node(self, node_id: int) -> None: else: self._out_of_sync = True + if is_signal_on(self.node_removed): + self.node_removed.emit(node_id, old_attrs) + def add_edge( self, source_id: int, @@ -652,6 +659,12 @@ def update_node_attrs( ) -> None: if node_ids is None: node_ids = self.node_ids() + else: + node_ids = list(node_ids) + + if is_signal_on(self.node_updated): + old_attrs_by_id = self._root.filter(node_ids=node_ids).node_attrs() + old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_attrs_by_id.to_dicts()} self._root.update_node_attrs( node_ids=node_ids, @@ -660,13 +673,23 @@ def update_node_attrs( # because attributes are passed by reference, we need don't need if both are rustworkx graphs if not self._is_root_rx_graph: if self.sync: - super().update_node_attrs( - node_ids=self._map_to_local(node_ids), - attrs=attrs, - ) + with self.node_updated.blocked(): + super().update_node_attrs( + node_ids=self._map_to_local(node_ids), + attrs=attrs, + ) else: self._out_of_sync = True + if is_signal_on(self.node_updated): + for node_id in node_ids: + old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy + self.node_updated.emit( + node_id, + old_attrs_by_id[node_id], + self._root.nodes[node_id].to_dict(), + ) + def update_edge_attrs( self, *, diff --git a/src/tracksdata/graph/_mapped_graph_mixin.py b/src/tracksdata/graph/_mapped_graph_mixin.py index 220aab3c..2889f3a0 100644 --- a/src/tracksdata/graph/_mapped_graph_mixin.py +++ b/src/tracksdata/graph/_mapped_graph_mixin.py @@ -6,6 +6,7 @@ """ from collections.abc import Sequence +from numbers import Integral from typing import Any, overload import bidict @@ -84,7 +85,7 @@ def _map_to_external(self, local_ids: int | Sequence[int] | None) -> int | list[ """ if local_ids is None: return None - if isinstance(local_ids, int): + if isinstance(local_ids, Integral): return self._local_to_external[local_ids] return [self._local_to_external[lid] for lid in local_ids] @@ -113,7 +114,7 @@ def _map_to_local(self, external_ids: int | Sequence[int] | None) -> int | list[ """ if external_ids is None: return None - if isinstance(external_ids, int): + if isinstance(external_ids, Integral): return self._external_to_local[external_ids] return [self._external_to_local[eid] for eid in external_ids] diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index ef4a3f4f..20a145b3 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -492,7 +492,8 @@ def add_node( node_id = self.rx_graph.add_node(attrs) self._time_to_nodes.setdefault(attrs["t"], []).append(node_id) - self.node_added.emit_fast(node_id) + if is_signal_on(self.node_added): + self.node_added.emit(node_id, attrs) return node_id def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None = None) -> list[int]: @@ -523,8 +524,8 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None # checking if it has connections to reduce overhead if is_signal_on(self.node_added): - for node_id in node_indices: - self.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(node_indices, nodes, strict=True): + self.node_added.emit(node_id, node_attrs) return node_indices @@ -548,7 +549,9 @@ def remove_node(self, node_id: int) -> None: if node_id not in self.rx_graph.node_indices(): raise ValueError(f"Node {node_id} does not exist in the graph.") - self.node_removed.emit_fast(node_id) + old_attrs = None + if is_signal_on(self.node_removed): + old_attrs = dict(self.rx_graph[node_id]) # Get the time value before removing the node t = self.rx_graph[node_id]["t"] @@ -566,6 +569,9 @@ def remove_node(self, node_id: int) -> None: if self._overlaps is not None: self._overlaps = [overlap for overlap in self._overlaps if node_id != overlap[0] and node_id != overlap[1]] + if is_signal_on(self.node_removed): + self.node_removed.emit(node_id, old_attrs) + def add_edge( self, source_id: int, @@ -1217,6 +1223,9 @@ def update_node_attrs( if node_ids is None: node_ids = self.node_ids() + if is_signal_on(self.node_updated): + old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids} + for key, value in attrs.items(): if key not in self.node_attr_keys(): raise ValueError(f"Node attribute key '{key}' not found in graph. Expected '{self.node_attr_keys()}'") @@ -1231,6 +1240,10 @@ def update_node_attrs( for node_id, v in zip(node_ids, value, strict=False): self._graph[node_id][key] = v + if is_signal_on(self.node_updated): + for node_id in node_ids: + self.node_updated.emit(node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) + def update_edge_attrs( self, *, @@ -1612,7 +1625,8 @@ def add_node( self._next_external_id = max(self._next_external_id, index + 1) # Add mapping using mixin self._add_id_mapping(node_id, index) - self.node_added.emit_fast(index) + if is_signal_on(self.node_added): + self.node_added.emit(index, attrs) return index def bulk_add_nodes( @@ -1658,8 +1672,8 @@ def bulk_add_nodes( self._add_id_mappings(list(zip(graph_ids, indices, strict=True))) if is_signal_on(self.node_added): - for index in indices: - self.node_added.emit_fast(index) + for index, node_attrs in zip(indices, nodes, strict=True): + self.node_added.emit(index, node_attrs) return indices @@ -1937,8 +1951,25 @@ def update_node_attrs( node_ids : Sequence[int] | None The node ids to update. """ - node_ids = self._get_local_ids() if node_ids is None else self._map_to_local(node_ids) - super().update_node_attrs(attrs=attrs, node_ids=node_ids) + external_node_ids = self.node_ids() if node_ids is None else node_ids + local_node_ids = self._map_to_local(external_node_ids) + + if is_signal_on(self.node_updated): + old_attrs_by_id = { + external_node_id: dict(self._graph[local_node_id]) + for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True) + } + + with self.node_updated.blocked(): + super().update_node_attrs(attrs=attrs, node_ids=local_node_ids) + + if is_signal_on(self.node_updated) and old_attrs_by_id is not None: + for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True): + self.node_updated.emit( + external_node_id, + old_attrs_by_id[external_node_id], + dict(self._graph[local_node_id]), + ) def remove_node(self, node_id: int) -> None: """ @@ -1959,11 +1990,15 @@ def remove_node(self, node_id: int) -> None: local_node_id = self._map_to_local(node_id) - self.node_removed.emit_fast(node_id) + if is_signal_on(self.node_removed): + old_attrs = dict(self._graph[local_node_id]) + with self.node_removed.blocked(): super().remove_node(local_node_id) self._remove_id_mapping(external_id=node_id) + if is_signal_on(self.node_removed): + self.node_removed.emit(node_id, old_attrs) def filter( self, diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 985cbdc9..d07ce386 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -714,7 +714,8 @@ def add_node( if index is None: self._max_id_per_time[time] = node_id - self.node_added.emit_fast(node_id) + if is_signal_on(self.node_added): + self.node_added.emit(node_id, attrs) return node_id @@ -781,8 +782,9 @@ def bulk_add_nodes( self._chunked_sa_operation(Session.bulk_insert_mappings, self.Node, nodes) if is_signal_on(self.node_added): - for node_id in node_ids: - self.node_added.emit_fast(node_id) + for node_id, node_attrs in zip(node_ids, nodes, strict=True): + new_attrs = {key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID} + self.node_added.emit(node_id, new_attrs) return node_ids @@ -804,14 +806,15 @@ def remove_node(self, node_id: int) -> None: ValueError If the node_id does not exist in the graph. """ - self.node_removed.emit_fast(node_id) - with Session(self._engine) as session: # Check if the node exists node = session.query(self.Node).filter(self.Node.node_id == node_id).first() if node is None: raise ValueError(f"Node {node_id} does not exist in the graph.") + if is_signal_on(self.node_removed): + old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()} + # Remove all edges where this node is source or target session.query(self.Edge).filter( sa.or_(self.Edge.source_id == node_id, self.Edge.target_id == node_id) @@ -825,6 +828,8 @@ def remove_node(self, node_id: int) -> None: # Remove the node itself session.delete(node) session.commit() + if is_signal_on(self.node_removed): + self.node_removed.emit(node_id, old_attrs) def add_edge( self, @@ -1755,8 +1760,27 @@ def update_node_attrs( if "t" in attrs: raise ValueError("Node attribute 't' cannot be updated.") + updated_node_ids = self.node_ids() if node_ids is None else list(node_ids) + if len(updated_node_ids) == 0: + return + + attr_keys = self.node_attr_keys() + if is_signal_on(self.node_updated): + old_df = self.filter(node_ids=updated_node_ids).node_attrs( + attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys] + ) + old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_df.rows(named=True)} + self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs) + if is_signal_on(self.node_updated): + new_df = self.filter(node_ids=updated_node_ids).node_attrs( + attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys] + ) + new_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in new_df.rows(named=True)} + for node_id in updated_node_ids: + self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) + def update_edge_attrs( self, *, diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index 52ce080e..e13e1711 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -38,6 +38,7 @@ def __init__( start_time = time.time() self._attr_keys = df.columns + self._ndims = len(self._attr_keys) if df.is_empty(): self._node_rtree = None @@ -45,7 +46,6 @@ def __init__( indices = np.ascontiguousarray(indices.to_numpy(), dtype=np.int64).copy() node_pos = np.ascontiguousarray(df.to_numpy(), dtype=np.float32) - self._ndims = node_pos.shape[1] self._node_rtree = PointRTree( item_dtype="int64", coord_dtype="float32", @@ -151,12 +151,17 @@ def __init__( attr_keys = list(filter(lambda x: x in valid_keys, attr_keys)) self._graph = graph + self._attr_keys = attr_keys nodes_df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]) node_ids = nodes_df[DEFAULT_ATTR_KEYS.NODE_ID] self._df_filter = DataFrameSpatialFilter(indices=node_ids, df=nodes_df.select(attr_keys)) + self._graph.node_added.connect(self._add_node) + self._graph.node_removed.connect(self._remove_node) + self._graph.node_updated.connect(self._update_node) + def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": """ Query nodes within a spatial region of interest. @@ -195,6 +200,54 @@ def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": node_ids = self._df_filter[keys] return self._graph.filter(node_ids=node_ids) + def _attrs_to_point(self, attrs: dict[str, Any]) -> np.ndarray: + return np.ascontiguousarray([[attrs[key] for key in self._attr_keys]], dtype=np.float32) + + def _add_node( + self, + node_id: int, + new_attrs: dict[str, Any], + ) -> None: + from spatial_graph import PointRTree + + if self._df_filter._node_rtree is None: + self._df_filter._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=len(self._attr_keys), + ) + self._df_filter._ndims = len(self._attr_keys) + + positions = self._attrs_to_point(new_attrs) + self._df_filter._node_rtree.insert_point_items( + np.atleast_1d(node_id).astype(np.int64), + positions, + ) + + def _remove_node( + self, + node_id: int, + old_attrs: dict[str, Any], + ) -> None: + # required by static type checking + if self._df_filter._node_rtree is None: + return + + positions = self._attrs_to_point(old_attrs) + self._df_filter._node_rtree.delete_items( + np.atleast_1d(node_id).astype(np.int64), + positions, + ) + + def _update_node( + self, + node_id: int, + old_attrs: dict[str, Any], + new_attrs: dict[str, Any], + ) -> None: + self._remove_node(node_id, old_attrs=old_attrs) + self._add_node(node_id, new_attrs=new_attrs) + class BBoxSpatialFilter: """ @@ -269,6 +322,7 @@ def __init__( # setup signal connections self._graph.node_added.connect(self._add_node) self._graph.node_removed.connect(self._remove_node) + self._graph.node_updated.connect(self._update_node) def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": """ @@ -358,7 +412,11 @@ def _attrs_to_bb_window(self, attrs: dict[str, Any]) -> tuple[np.ndarray, np.nda return positions_min, positions_max - def _add_node(self, node_id: int) -> None: + def _add_node( + self, + node_id: int, + new_attrs: dict[str, Any], + ) -> None: """ Add a node to the spatial filter. @@ -366,30 +424,28 @@ def _add_node(self, node_id: int) -> None: ---------- node_id : int The ID of the node to add. + new_attrs : dict[str, Any] + Current node attributes to insert into the spatial index. """ from spatial_graph import PointRTree if self._node_rtree is None: - if self._graph.num_nodes() > 0: - nodes_df = self._graph.node_attrs() - bboxes = self._bboxes_to_array(nodes_df[self._bbox_attr_key]) - num_dims = bboxes.shape[1] // 2 - - if self._frame_attr_key is None: - self._ndims = num_dims - else: - self._ndims = num_dims + 1 # +1 for the frame dimension - - self._node_rtree = PointRTree( - item_dtype="int64", - coord_dtype="float32", - dims=self._ndims, - ) + bbox = new_attrs[self._bbox_attr_key] + if len(bbox) % 2 != 0: + raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {len(bbox)}") + num_dims = len(bbox) // 2 + if self._frame_attr_key is None: + self._ndims = num_dims else: - raise ValueError("Spatial filter is not initialized") + self._ndims = num_dims + 1 # +1 for the frame dimension - attrs = self._graph.nodes[node_id].to_dict() - positions_min, positions_max = self._attrs_to_bb_window(attrs) + self._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=self._ndims, + ) + + positions_min, positions_max = self._attrs_to_bb_window(new_attrs) self._node_rtree.insert_bb_items( np.atleast_1d(node_id).astype(np.int64), @@ -397,7 +453,11 @@ def _add_node(self, node_id: int) -> None: positions_max, ) - def _remove_node(self, node_id: int) -> None: + def _remove_node( + self, + node_id: int, + old_attrs: dict[str, Any], + ) -> None: """ Remove a node from the spatial filter. @@ -405,12 +465,13 @@ def _remove_node(self, node_id: int) -> None: ---------- node_id : int The ID of the node to remove. + old_attrs : dict[str, Any] + Previous node attributes used to remove the exact indexed bbox. """ if self._node_rtree is None: - raise ValueError("Spatial filter is not initialized") + return - attrs = self._graph.nodes[node_id].to_dict() - positions_min, positions_max = self._attrs_to_bb_window(attrs) + positions_min, positions_max = self._attrs_to_bb_window(old_attrs) self._node_rtree.delete_items( np.atleast_1d(node_id).astype(np.int64), @@ -418,6 +479,15 @@ def _remove_node(self, node_id: int) -> None: positions_max, ) + def _update_node( + self, + node_id: int, + old_attrs: dict[str, Any], + new_attrs: dict[str, Any], + ) -> None: + self._remove_node(node_id, old_attrs=old_attrs) + self._add_node(node_id, new_attrs=new_attrs) + @staticmethod def _bboxes_to_array(bbox_series: pl.Series) -> np.ndarray: """ diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index d17ba94b..4ad3a87d 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -342,6 +342,55 @@ def test_add_and_remove_node(graph_backend: BaseGraph) -> None: assert graph.num_nodes() == 2 +def test_spatial_filter_add_update_and_remove_node(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("y", pl.Int64) + graph_backend.add_node_attr_key("x", pl.Int64) + + graph_backend.add_node({"t": 0, "y": 1, "x": 1}) + graph_backend.add_node({"t": 1, "y": 10, "x": 10}) + + for graph in [graph_backend, graph_backend.filter().subgraph()]: + spatial_filter = SpatialFilter(graph, attr_keys=[DEFAULT_ATTR_KEYS.T, "y", "x"]) + + assert spatial_filter[2:3, 6:9, 6:9].node_attrs().is_empty() + + new_node_id = graph.add_node({"t": 2, "y": 7, "x": 7}) + result_ids = spatial_filter[2:3, 6:9, 6:9].node_ids() + assert new_node_id in result_ids + + graph.update_node_attrs(attrs={"y": 20, "x": 20}, node_ids=[new_node_id]) + + assert spatial_filter[2:3, 6:9, 6:9].node_attrs().is_empty() + moved_ids = spatial_filter[2:3, 19:22, 19:22].node_ids() + assert new_node_id in moved_ids + + graph.remove_node(new_node_id) + assert spatial_filter[2:3, 19:22, 19:22].node_attrs().is_empty() + + +def test_bbox_spatial_filter_updates_node_position(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 4)) + moved_node_id = graph_backend.add_node({"t": 0, "bbox": np.asarray([0, 0, 2, 2])}) + graph_backend.add_node({"t": 1, "bbox": np.asarray([10, 10, 12, 12])}) + + for graph in [graph_backend, graph_backend.filter().subgraph()]: + graph.update_node_attrs( + attrs={"bbox": [np.asarray([0, 0, 2, 2])]}, + node_ids=[moved_node_id], + ) + + spatial_filter = BBoxSpatialFilter(graph, frame_attr_key="t", bbox_attr_key="bbox") + assert moved_node_id in spatial_filter[0:0.5, 0:3, 0:3].node_ids() + + graph.update_node_attrs( + attrs={"bbox": [np.asarray([20, 20, 22, 22])]}, + node_ids=[moved_node_id], + ) + + assert moved_node_id not in spatial_filter[0:0.5, 0:3, 0:3].node_ids() + assert moved_node_id in spatial_filter[0:0.5, 19:23, 19:23].node_ids() + + def test_bbox_spatial_filter_handles_list_dtype(graph_backend: BaseGraph) -> None: """Ensure bounding boxes stored as list dtype still work with the spatial filter.""" graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) diff --git a/src/tracksdata/utils/_signal.py b/src/tracksdata/utils/_signal.py index 61ba7f60..ee4c431e 100644 --- a/src/tracksdata/utils/_signal.py +++ b/src/tracksdata/utils/_signal.py @@ -1,6 +1,6 @@ -from psygnal import Signal +from psygnal import Signal, SignalInstance -def is_signal_on(sig: Signal) -> bool: +def is_signal_on(sig: Signal | SignalInstance) -> bool: """Check if a signal is connected and not blocked.""" return len(sig._slots) > 0 and not sig._is_blocked