Skip to content
1 change: 1 addition & 0 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ jobs:
--machine github-actions \
--python 3.12 \
--factor 1.5 \
--attribute timeout=180 \
--show-stderr || status=$?
if [ "$status" -eq 2 ]; then
echo "asv: benchmark run failed (exit 2). Failing CI." >&2
Expand Down
5 changes: 3 additions & 2 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ class BaseGraph(abc.ABC):
Base class for a graph backend.
"""

node_added = Signal(int)
node_removed = Signal(int)
node_added = Signal(int, object)
node_removed = Signal(int, object)
node_updated = Signal(int, object, object)

def __init__(self) -> None:
self._cache = {}
Expand Down
47 changes: 35 additions & 12 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Any, Literal, overload
from typing import Any, Literal, cast, overload

import bidict
import polars as pl
Expand Down Expand Up @@ -400,8 +400,10 @@ def add_node(
else:
self._out_of_sync = True

self._root.node_added.emit_fast(parent_node_id)
self.node_added.emit_fast(parent_node_id)
if is_signal_on(self._root.node_added):
self._root.node_added.emit(parent_node_id, attrs)
if is_signal_on(self.node_added):
self.node_added.emit(parent_node_id, attrs)

return parent_node_id

Expand All @@ -417,12 +419,12 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
self._out_of_sync = True

if is_signal_on(self._root.node_added):
for node_id in parent_node_ids:
self._root.node_added.emit_fast(node_id)
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
self._root.node_added.emit(node_id, node_attrs)

if is_signal_on(self.node_added):
for node_id in parent_node_ids:
self.node_added.emit_fast(node_id)
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
self.node_added.emit(node_id, node_attrs)

return parent_node_ids

Expand All @@ -446,9 +448,11 @@ def remove_node(self, node_id: int) -> None:
if node_id not in self._external_to_local:
raise ValueError(f"Node {node_id} does not exist in the graph.")

if is_signal_on(self.node_removed):
old_attrs = self.nodes[node_id].to_dict()

# Remove from root graph first, because removing bounding box requires node attrs
self._root.remove_node(node_id)
self.node_removed.emit_fast(node_id)

if self.sync:
# Get the local node ID and remove from local graph
Expand All @@ -474,6 +478,9 @@ def remove_node(self, node_id: int) -> None:
else:
self._out_of_sync = True

if is_signal_on(self.node_removed):
self.node_removed.emit(node_id, old_attrs)

def add_edge(
self,
source_id: int,
Expand Down Expand Up @@ -652,6 +659,12 @@ def update_node_attrs(
) -> None:
if node_ids is None:
node_ids = self.node_ids()
else:
node_ids = list(node_ids)

if is_signal_on(self.node_updated):
old_attrs_by_id = self._root.filter(node_ids=node_ids).node_attrs()
old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_attrs_by_id.to_dicts()}

self._root.update_node_attrs(
node_ids=node_ids,
Expand All @@ -660,13 +673,23 @@ def update_node_attrs(
# because attributes are passed by reference, we don't need this if both are rustworkx graphs
if not self._is_root_rx_graph:
if self.sync:
super().update_node_attrs(
node_ids=self._map_to_local(node_ids),
attrs=attrs,
)
with self.node_updated.blocked():
super().update_node_attrs(
node_ids=self._map_to_local(node_ids),
attrs=attrs,
)
else:
self._out_of_sync = True

if is_signal_on(self.node_updated):
for node_id in node_ids:
old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy
self.node_updated.emit(
node_id,
old_attrs_by_id[node_id],
self._root.nodes[node_id].to_dict(),
)

def update_edge_attrs(
self,
*,
Expand Down
5 changes: 3 additions & 2 deletions src/tracksdata/graph/_mapped_graph_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from collections.abc import Sequence
from numbers import Integral
from typing import Any, overload

import bidict
Expand Down Expand Up @@ -84,7 +85,7 @@ def _map_to_external(self, local_ids: int | Sequence[int] | None) -> int | list[
"""
if local_ids is None:
return None
if isinstance(local_ids, int):
if isinstance(local_ids, Integral):
return self._local_to_external[local_ids]
return [self._local_to_external[lid] for lid in local_ids]

Expand Down Expand Up @@ -113,7 +114,7 @@ def _map_to_local(self, external_ids: int | Sequence[int] | None) -> int | list[
"""
if external_ids is None:
return None
if isinstance(external_ids, int):
if isinstance(external_ids, Integral):
return self._external_to_local[external_ids]
return [self._external_to_local[eid] for eid in external_ids]

Expand Down
55 changes: 45 additions & 10 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ def add_node(

node_id = self.rx_graph.add_node(attrs)
self._time_to_nodes.setdefault(attrs["t"], []).append(node_id)
self.node_added.emit_fast(node_id)
if is_signal_on(self.node_added):
self.node_added.emit(node_id, attrs)
return node_id

def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None = None) -> list[int]:
Expand Down Expand Up @@ -523,8 +524,8 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None

# checking if it has connections to reduce overhead
if is_signal_on(self.node_added):
for node_id in node_indices:
self.node_added.emit_fast(node_id)
for node_id, node_attrs in zip(node_indices, nodes, strict=True):
self.node_added.emit(node_id, node_attrs)

return node_indices

Expand All @@ -548,7 +549,9 @@ def remove_node(self, node_id: int) -> None:
if node_id not in self.rx_graph.node_indices():
raise ValueError(f"Node {node_id} does not exist in the graph.")

self.node_removed.emit_fast(node_id)
old_attrs = None
if is_signal_on(self.node_removed):
old_attrs = dict(self.rx_graph[node_id])

# Get the time value before removing the node
t = self.rx_graph[node_id]["t"]
Expand All @@ -566,6 +569,9 @@ def remove_node(self, node_id: int) -> None:
if self._overlaps is not None:
self._overlaps = [overlap for overlap in self._overlaps if node_id != overlap[0] and node_id != overlap[1]]

if is_signal_on(self.node_removed):
self.node_removed.emit(node_id, old_attrs)

def add_edge(
self,
source_id: int,
Expand Down Expand Up @@ -1217,6 +1223,9 @@ def update_node_attrs(
if node_ids is None:
node_ids = self.node_ids()

if is_signal_on(self.node_updated):
old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids}

for key, value in attrs.items():
if key not in self.node_attr_keys():
raise ValueError(f"Node attribute key '{key}' not found in graph. Expected '{self.node_attr_keys()}'")
Expand All @@ -1231,6 +1240,10 @@ def update_node_attrs(
for node_id, v in zip(node_ids, value, strict=False):
self._graph[node_id][key] = v

if is_signal_on(self.node_updated):
for node_id in node_ids:
self.node_updated.emit(node_id, old_attrs_by_id[node_id], dict(self._graph[node_id]))

def update_edge_attrs(
self,
*,
Expand Down Expand Up @@ -1612,7 +1625,8 @@ def add_node(
self._next_external_id = max(self._next_external_id, index + 1)
# Add mapping using mixin
self._add_id_mapping(node_id, index)
self.node_added.emit_fast(index)
if is_signal_on(self.node_added):
self.node_added.emit(index, attrs)
return index

def bulk_add_nodes(
Expand Down Expand Up @@ -1658,8 +1672,8 @@ def bulk_add_nodes(
self._add_id_mappings(list(zip(graph_ids, indices, strict=True)))

if is_signal_on(self.node_added):
for index in indices:
self.node_added.emit_fast(index)
for index, node_attrs in zip(indices, nodes, strict=True):
self.node_added.emit(index, node_attrs)

return indices

Expand Down Expand Up @@ -1937,8 +1951,25 @@ def update_node_attrs(
node_ids : Sequence[int] | None
The node ids to update.
"""
node_ids = self._get_local_ids() if node_ids is None else self._map_to_local(node_ids)
super().update_node_attrs(attrs=attrs, node_ids=node_ids)
external_node_ids = self.node_ids() if node_ids is None else node_ids
local_node_ids = self._map_to_local(external_node_ids)

if is_signal_on(self.node_updated):
old_attrs_by_id = {
external_node_id: dict(self._graph[local_node_id])
for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True)
}

with self.node_updated.blocked():
super().update_node_attrs(attrs=attrs, node_ids=local_node_ids)

if is_signal_on(self.node_updated) and old_attrs_by_id is not None:
for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True):
self.node_updated.emit(
external_node_id,
old_attrs_by_id[external_node_id],
dict(self._graph[local_node_id]),
)

def remove_node(self, node_id: int) -> None:
"""
Expand All @@ -1959,11 +1990,15 @@ def remove_node(self, node_id: int) -> None:

local_node_id = self._map_to_local(node_id)

self.node_removed.emit_fast(node_id)
if is_signal_on(self.node_removed):
old_attrs = dict(self._graph[local_node_id])

with self.node_removed.blocked():
super().remove_node(local_node_id)

self._remove_id_mapping(external_id=node_id)
if is_signal_on(self.node_removed):
self.node_removed.emit(node_id, old_attrs)

def filter(
self,
Expand Down
34 changes: 29 additions & 5 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,8 @@ def add_node(
if index is None:
self._max_id_per_time[time] = node_id

self.node_added.emit_fast(node_id)
if is_signal_on(self.node_added):
self.node_added.emit(node_id, attrs)

return node_id

Expand Down Expand Up @@ -781,8 +782,9 @@ def bulk_add_nodes(
self._chunked_sa_operation(Session.bulk_insert_mappings, self.Node, nodes)

if is_signal_on(self.node_added):
for node_id in node_ids:
self.node_added.emit_fast(node_id)
for node_id, node_attrs in zip(node_ids, nodes, strict=True):
new_attrs = {key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID}
self.node_added.emit(node_id, new_attrs)

return node_ids

Expand All @@ -804,14 +806,15 @@ def remove_node(self, node_id: int) -> None:
ValueError
If the node_id does not exist in the graph.
"""
self.node_removed.emit_fast(node_id)

with Session(self._engine) as session:
# Check if the node exists
node = session.query(self.Node).filter(self.Node.node_id == node_id).first()
if node is None:
raise ValueError(f"Node {node_id} does not exist in the graph.")

if is_signal_on(self.node_removed):
old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()}

# Remove all edges where this node is source or target
session.query(self.Edge).filter(
sa.or_(self.Edge.source_id == node_id, self.Edge.target_id == node_id)
Expand All @@ -825,6 +828,8 @@ def remove_node(self, node_id: int) -> None:
# Remove the node itself
session.delete(node)
session.commit()
if is_signal_on(self.node_removed):
self.node_removed.emit(node_id, old_attrs)

def add_edge(
self,
Expand Down Expand Up @@ -1755,8 +1760,27 @@ def update_node_attrs(
if "t" in attrs:
raise ValueError("Node attribute 't' cannot be updated.")

updated_node_ids = self.node_ids() if node_ids is None else list(node_ids)
if len(updated_node_ids) == 0:
return

attr_keys = self.node_attr_keys()
if is_signal_on(self.node_updated):
old_df = self.filter(node_ids=updated_node_ids).node_attrs(
attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]
)
old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_df.rows(named=True)}

self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs)

if is_signal_on(self.node_updated):
new_df = self.filter(node_ids=updated_node_ids).node_attrs(
attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]
)
new_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in new_df.rows(named=True)}
for node_id in updated_node_ids:
self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id])

def update_edge_attrs(
self,
*,
Expand Down
Loading
Loading