78 changes: 77 additions & 1 deletion src/tracksdata/array/_graph_array.py
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from copy import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np

@@ -200,6 +200,9 @@ def __init__(
frame_attr_key=DEFAULT_ATTR_KEYS.T,
bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX,
)
self.graph.node_added.connect(self._on_node_added)
self.graph.node_removed.connect(self._on_node_removed)
self.graph.node_updated.connect(self._on_node_updated)

@property
def shape(self) -> tuple[int, ...]:
@@ -351,3 +354,76 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda
for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True):
mask: Mask
mask.paint_buffer(buffer, value, offset=self._offset)

def _offset_as_array(self, ndim: int) -> np.ndarray:
"""Normalize `offset` to a vector for each spatial axis."""
if np.isscalar(self._offset):
return np.full(ndim, int(self._offset), dtype=np.int64)

offset = np.asarray(self._offset, dtype=np.int64).reshape(-1)
if len(offset) != ndim:
raise ValueError(f"`offset` must have length {ndim}, got {len(offset)}")
return offset

def _bbox_to_slices(self, bbox: Any) -> tuple[slice, ...] | None:
"""
Convert a bbox to clipped spatial slices in array coordinates.

Returns `None` when the bbox does not overlap the current array volume.
"""
bbox = np.asarray(bbox, dtype=np.int64).reshape(-1)
ndim = len(self.original_shape) - 1
if len(bbox) != 2 * ndim:
raise ValueError(f"`bbox` must have length {2 * ndim}, got {len(bbox)}")

offset = self._offset_as_array(ndim)
start = bbox[:ndim] + offset
stop = bbox[ndim:] + offset

shape = np.asarray(self.original_shape[1:], dtype=np.int64)
start = np.clip(start, 0, shape)
stop = np.clip(stop, 0, shape)

if np.any(stop <= start):
return None

return tuple(slice(int(s), int(e)) for s, e in zip(start, stop, strict=True))

def _invalidate_from_attrs(self, attrs: dict) -> None:
"""
Invalidate the cached region touched by a node's attributes.

Raises if the required time/bbox keys are missing; time points outside
the array are ignored, as are bboxes that do not overlap the volume.
"""

time_value = attrs.get(DEFAULT_ATTR_KEYS.T)
if time_value is None:
raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.T}' key for cache invalidation.")
if DEFAULT_ATTR_KEYS.BBOX not in attrs:
raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.BBOX}' key for cache invalidation.")

try:
time = int(np.asarray(time_value).item())
except (TypeError, ValueError) as e:
raise ValueError(
f"Time attribute value must be a scalar integer, got {time_value} of type {type(time_value)}"
) from e
if not (0 <= time < self.original_shape[0]):
return

slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX])
if slices is not None:
self._cache.invalidate(time=time, volume_slicing=slices)

def _on_node_added(self, node_id: int, new_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(new_attrs)

def _on_node_removed(self, node_id: int, old_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(old_attrs)

def _on_node_updated(self, node_id: int, old_attrs: dict, new_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(old_attrs)
self._invalidate_from_attrs(new_attrs)
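
The invalidation path above stands or falls on the bbox-to-slices conversion, so a self-contained sketch of that arithmetic may help (a minimal re-derivation with illustrative names, not the module's API; the bbox layout and clipping mirror `_bbox_to_slices` above):

import numpy as np

def bbox_to_slices(bbox, shape, offset=0):
    # bbox packs mins then maxes: (min_0, ..., min_n, max_0, ..., max_n).
    bbox = np.asarray(bbox, dtype=np.int64)
    ndim = len(shape)
    start = np.clip(bbox[:ndim] + offset, 0, shape)
    stop = np.clip(bbox[ndim:] + offset, 0, shape)
    if np.any(stop <= start):
        return None  # bbox lies entirely outside the cached volume
    return tuple(slice(int(s), int(e)) for s, e in zip(start, stop))

# A bbox hanging off the edge of an 8x8 frame is clipped to the valid region:
print(bbox_to_slices([6, 6, 10, 10], (8, 8)))  # (slice(6, 8), slice(6, 8))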
65 changes: 61 additions & 4 deletions src/tracksdata/array/_nd_chunk_cache.py
@@ -114,6 +114,13 @@ def _chunk_bounds(self, slices: tuple[slice, ...]) -> tuple[tuple[int, int], ...
"""Return inclusive chunk-index bounds for every axis."""
return tuple((s.start // cs, (s.stop - 1) // cs) for s, cs in zip(slices, self.chunk_shape, strict=True))

def _chunk_slice(self, chunk_idx: tuple[int, ...]) -> tuple[slice, ...]:
"""Return absolute volume slices for a chunk index."""
return tuple(
slice(ci * cs, min((ci + 1) * cs, fs))
for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True)
)

def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...]) -> np.ndarray:
"""
Retrieve data for `time` and arbitrary dimensional slices.
@@ -146,13 +153,63 @@ def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...]
continue # already filled

# Absolute slice covering this chunk
chunk_slc = tuple(
slice(ci * cs, min((ci + 1) * cs, fs))
for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True)
)
chunk_slc = self._chunk_slice(chunk_idx)
# compute_func fills the entire chunk, which may extend beyond volume_slicing
self.compute_func(time, chunk_slc, store_entry.buffer)
store_entry.ready[chunk_idx] = True

# Return view on the big buffer
return store_entry.buffer[volume_slicing]

def invalidate(
self,
*,
time: int | None = None,
volume_slicing: tuple[slice | int | Sequence[int], ...] | None = None,
) -> None:
"""
Invalidate a cached region.

Parameters
----------
time : int | None, optional
Time point to invalidate. If None, applies to all currently cached times.
volume_slicing : tuple[slice | int | Sequence[int], ...] | None, optional
Volume region to invalidate. If None, invalidates the full volume for the selected times.
"""
if time is None:
times = list(self._store.keys())
elif time in self._store:
times = [time]
else:
return

if volume_slicing is not None and len(volume_slicing) != self.ndim:
raise ValueError("Number of slices must equal dimensionality")

region_slices = None
if volume_slicing is not None:
region_slices = tuple(_to_slice(slc) for slc in volume_slicing)

for t in times:
store_entry = self._store[t]

if region_slices is None:
store_entry.ready.fill(False)
store_entry.buffer.fill(0)
continue

clipped_slices = []
for slc, size in zip(region_slices, self.shape, strict=True):
start = max(0, slc.start)
stop = min(size, slc.stop)
if stop <= start:
return
clipped_slices.append(slice(start, stop))

bounds = self._chunk_bounds(tuple(clipped_slices))
chunk_ranges = [range(lo, hi + 1) for lo, hi in bounds]
for chunk_idx in itertools.product(*chunk_ranges):
chunk_slc = self._chunk_slice(chunk_idx)
store_entry.ready[chunk_idx] = False
store_entry.buffer[chunk_slc] = 0
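
The new invalidate walks exactly the chunks whose index range overlaps the clipped region. A standalone sketch of that enumeration, using the same arithmetic as `_chunk_bounds` and `_chunk_slice` (function name is illustrative):

import itertools

def chunks_overlapping(region, chunk_shape):
    # Inclusive chunk-index bounds per axis, mirroring _chunk_bounds above.
    bounds = [(s.start // cs, (s.stop - 1) // cs)
              for s, cs in zip(region, chunk_shape)]
    return list(itertools.product(*(range(lo, hi + 1) for lo, hi in bounds)))

# Invalidating rows/cols 5..7 of an 8x8 volume chunked 4x4 touches one chunk:
print(chunks_overlapping((slice(5, 7), slice(5, 7)), (4, 4)))  # [(1, 1)]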
119 changes: 119 additions & 0 deletions src/tracksdata/array/_test/test_graph_array.py
@@ -378,3 +378,122 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph

with pytest.raises(ValueError, match="Attribute values for key 'label' must be scalar"):
GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label")


def _add_graph_array_node_attrs(graph_backend: BaseGraph) -> None:
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))


def _make_square_mask(y: int, x: int, size: int = 2) -> Mask:
return Mask(np.ones((size, size), dtype=bool), bbox=np.array([y, x, y + size, x + size]))


def test_graph_array_view_invalidates_only_affected_chunk_on_add(graph_backend: BaseGraph) -> None:
_add_graph_array_node_attrs(graph_backend)

first_mask = _make_square_mask(1, 1)
graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: first_mask,
DEFAULT_ATTR_KEYS.BBOX: first_mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))

_ = np.asarray(array_view[0])
np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

second_mask = _make_square_mask(5, 5)
graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 2,
DEFAULT_ATTR_KEYS.MASK: second_mask,
DEFAULT_ATTR_KEYS.BBOX: second_mask.bbox,
}
)

expected_ready = np.ones((2, 2), dtype=bool)
expected_ready[1, 1] = False
np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready)

output = np.asarray(array_view[0])
assert output[1, 1] == 1
assert output[5, 5] == 2


def test_graph_array_view_invalidates_old_and_new_chunks_on_update(graph_backend: BaseGraph) -> None:
_add_graph_array_node_attrs(graph_backend)

mask = _make_square_mask(1, 1)
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])

moved_mask = _make_square_mask(5, 5)
graph_backend.update_node_attrs(
attrs={
"label": [7],
DEFAULT_ATTR_KEYS.MASK: [moved_mask],
DEFAULT_ATTR_KEYS.BBOX: [moved_mask.bbox],
},
node_ids=[node_id],
)

expected_ready = np.ones((2, 2), dtype=bool)
expected_ready[0, 0] = False
expected_ready[1, 1] = False
np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready)

output = np.asarray(array_view[0])
assert output[1, 1] == 0
assert output[5, 5] == 7


def test_graph_array_view_invalidates_chunk_on_remove(graph_backend: BaseGraph) -> None:
_add_graph_array_node_attrs(graph_backend)

first_mask = _make_square_mask(1, 1)
graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: first_mask,
DEFAULT_ATTR_KEYS.BBOX: first_mask.bbox,
}
)
second_mask = _make_square_mask(5, 5)
second_node = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 2,
DEFAULT_ATTR_KEYS.MASK: second_mask,
DEFAULT_ATTR_KEYS.BBOX: second_mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])

graph_backend.remove_node(second_node)

expected_ready = np.ones((2, 2), dtype=bool)
expected_ready[1, 1] = False
np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready)

output = np.asarray(array_view[0])
assert output[1, 1] == 1
assert output[5, 5] == 0
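
The expected_ready grids in these tests follow directly from the chunk geometry: with shape (2, 8, 8) and chunk_shape (4, 4), each frame carries a 2x2 ready grid, and a pixel maps to its owning chunk by integer division. A quick sketch of the mapping the assertions rely on (helper name is illustrative):

def owning_chunk(coord, chunk_shape=(4, 4)):
    # Integer division maps an absolute pixel coordinate to its chunk index.
    return tuple(c // cs for c, cs in zip(coord, chunk_shape))

assert owning_chunk((1, 1)) == (0, 0)  # first mask -> top-left chunk
assert owning_chunk((5, 5)) == (1, 1)  # second mask -> bottom-right chunk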
7 changes: 3 additions & 4 deletions src/tracksdata/graph/_base_graph.py
@@ -110,10 +110,9 @@ class BaseGraph(abc.ABC):
"""

_PRIVATE_METADATA_PREFIX = "__private_"

node_added = Signal(int, object)
node_removed = Signal(int, object)
node_updated = Signal(int, object, object)
node_added = Signal(int, dict)
node_removed = Signal(int, dict)
node_updated = Signal(int, dict, dict)

def __init__(self) -> None:
self._cache = {}
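The tightened signatures make the payload type explicit. Assuming these are psygnal-style class-level signals (the connect/emit usage elsewhere in the diff matches that API), a minimal sketch of the wiring:

from psygnal import Signal

class MiniGraph:
    # Class-level declaration, mirroring BaseGraph's node_added = Signal(int, dict)
    node_added = Signal(int, dict)

    def add_node(self, attrs: dict) -> int:
        node_id = 0  # placeholder id for the sketch
        self.node_added.emit(node_id, attrs)
        return node_id

g = MiniGraph()
g.node_added.connect(lambda node_id, attrs: print("added", node_id, attrs))
g.add_node({"t": 0, "label": 1})  # prints: added 0 {'t': 0, 'label': 1}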
1 change: 1 addition & 0 deletions src/tracksdata/graph/_rustworkx_graph.py
@@ -1993,6 +1993,7 @@ def remove_node(self, node_id: int) -> None:
raise ValueError(f"Node {node_id} does not exist in the graph.")

local_node_id = self._map_to_local(node_id)
old_attrs = dict(self._graph[local_node_id])
Member commented:
Would this work?

Suggested change:
-        old_attrs = dict(self._graph[local_node_id])
+        old_attrs = self._graph[local_node_id]


if is_signal_on(self.node_removed):
old_attrs = dict(self._graph[local_node_id])
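Whether the suggested change is safe hinges on aliasing: dropping dict(...) emits the live payload mapping, so any handler that stores it would observe later mutations (if rustworkx keeps the payload alive after removal). A small illustration of the difference, using plain dicts:

live = {"t": 0, "label": 1}

alias = live           # suggestion: emit the payload itself
snapshot = dict(live)  # current diff: emit a shallow copy

live["label"] = 2
print(alias["label"], snapshot["label"])  # 2 1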
10 changes: 10 additions & 0 deletions src/tracksdata/graph/_sql_graph.py
@@ -815,6 +815,8 @@ def remove_node(self, node_id: int) -> None:
node = session.query(self.Node).filter(self.Node.node_id == node_id).first()
if node is None:
raise ValueError(f"Node {node_id} does not exist in the graph.")
old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()}
self.node_removed.emit(node_id, old_attrs)

if is_signal_on(self.node_removed):
old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()}
@@ -1830,6 +1832,14 @@ def update_node_attrs(
old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_df.rows(named=True)}

self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs)
new_df = self.filter(node_ids=updated_node_ids).node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys])
new_attrs_by_id = {
row[DEFAULT_ATTR_KEYS.NODE_ID]: {key: row[key] for key in attr_keys} for row in new_df.rows(named=True)
}
Member commented on lines +1836 to +1838:
Would this work? row already contains only node_id and *attr_keys from the line above.

Suggested change:
-        new_attrs_by_id = {
-            row[DEFAULT_ATTR_KEYS.NODE_ID]: {key: row[key] for key in attr_keys} for row in new_df.rows(named=True)
-        }
+        new_attrs_by_id = {
+            row.pop(DEFAULT_ATTR_KEYS.NODE_ID): row for row in new_df.rows(named=True)
+        }

Member commented on lines +1836 to +1838:
Suggested change:
-        new_attrs_by_id = {
-            row[DEFAULT_ATTR_KEYS.NODE_ID]: {key: row[key] for key in attr_keys} for row in new_df.rows(named=True)
-        }
+        new_attrs_by_id = new_df.rows_by_key(key=DEFAULT_ATTR_KEYS.NODE_ID, named=True, unique=True)

It seems polars has a function for this: rows_by_key.
This would be the preferred usage, in my opinion.

I think I have this pattern in our codebase. I'll check later if they can all be refactored to use this.
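
For reference, a quick sketch of the rows_by_key behavior the suggestion relies on (output shape assumed from the polars docs: with named=True and unique=True each key maps to a single row dict, and the key column itself is excluded by default):

import polars as pl

df = pl.DataFrame({"node_id": [1, 2], "label": [7, 9], "t": [0, 0]})

attrs_by_id = df.rows_by_key(key="node_id", named=True, unique=True)
print(attrs_by_id)  # {1: {'label': 7, 't': 0}, 2: {'label': 9, 't': 0}}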


if is_signal_on(self.node_updated):
for node_id in updated_node_ids:
self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id])

if is_signal_on(self.node_updated):
new_df = self.filter(node_ids=updated_node_ids).node_attrs(