Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .codex/environments/environment.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# THIS IS AUTOGENERATED. DO NOT EDIT MANUALLY
version = 1
name = "tracksdata"

[setup]
script = '''
uv venv
uv pip install -e .[spatial,test,docs]
source .venv/bin/activate
'''
17 changes: 17 additions & 0 deletions src/tracksdata/_test/test_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,23 @@ def test_attr_expr_method_delegation() -> None:
assert result.to_list() == expected.to_list()


def test_attr_expr_struct_field_method_delegation() -> None:
    """Accessing a struct field via ``.struct.field`` keeps the NodeAttr subtype."""
    frame = pl.DataFrame(
        {"s": [{"x": 1}, {"x": 2}, {"x": 3}]},
        schema={"s": pl.Struct({"x": pl.Int64})},
    )
    struct_expr = NodeAttr("s").struct.field("x")
    values = struct_expr.evaluate(frame)
    assert isinstance(struct_expr, NodeAttr)
    assert values.to_list() == [1, 2, 3]


def test_attr_comparison_struct_field() -> None:
    """Comparisons on struct fields record the root column and nested field path."""
    frame = pl.DataFrame(
        {"s": [{"x": 1}, {"x": 2}, {"x": 1}]},
        schema={"s": pl.Struct({"x": pl.Int64})},
    )
    comparison = NodeAttr("s").struct.field("x") == 1
    mask = comparison.to_attr().evaluate(frame)
    assert comparison.column == "s"
    assert comparison.attr.field_path == ("x",)
    assert mask.to_list() == [True, False, True]


def test_attr_expr_complex_expression() -> None:
df = pl.DataFrame({"iou": [0.5, 0.7, 0.9], "distance": [10, 20, 30]})
expr = (1 - Attr("iou")) * Attr("distance")
Expand Down
90 changes: 84 additions & 6 deletions src/tracksdata/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr
raise ValueError(f"Comparison operators are not supported for multiple columns. Found {columns}.")

self.attr = attr
self.column = columns[0]
self.column = attr.root_column if attr.root_column is not None else columns[0]
self.op = op

# casting numpy scalars to python scalars
Expand All @@ -144,14 +144,18 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr
self.other = other

def __repr__(self) -> str:
    """Human-readable form, e.g. ``NodeAttr(s.x) == 1`` for nested struct fields."""
    # Joining a single-element list reproduces the plain-column case.
    label = ".".join([str(self.column), *self.attr.field_path])
    symbol = _OPS_MATH_SYMBOLS[self.op]
    return f"{type(self.attr).__name__}({label}) {symbol} {self.other}"

def to_attr(self) -> "Attr":
    """
    Convert this comparison back into an [Attr][tracksdata.attrs.Attr] object,
    which can then be evaluated against a DataFrame.
    """
    # Re-apply the comparison operator on the underlying expression so any
    # nested struct-field access in `self.attr` is preserved.
    combined = self.op(self.attr.expr, self.other)
    return Attr(combined)

def __getattr__(self, attr: str) -> Any:
    # Delegate unknown attribute access to the Attr form of this comparison,
    # so AttrComparison transparently exposes the full Attr API.
    return getattr(self.to_attr(), attr)
Expand Down Expand Up @@ -198,6 +202,31 @@ def __ge__(self, other: ExprInput) -> "Attr": ...
def __rge__(self, other: ExprInput) -> "Attr": ...


class _StructNamespace:
    """Proxy for the polars ``struct`` namespace that returns Attr objects.

    ``field`` additionally records the accessed field name so that backend
    filters can reconstruct the nested path from the root column.
    """

    def __init__(self, attr: "Attr") -> None:
        self._attr = attr
        self._namespace = attr.expr.struct

    def field(self, name: str) -> "Attr":
        """Access struct field ``name``, extending the tracked field path."""
        wrapped = self._attr._wrap(
            self._namespace.field(name), preserve_field_path=True
        )
        if isinstance(wrapped, Attr):
            wrapped._append_field_path(name)
        return wrapped

    def __getattr__(self, name: str) -> Any:
        """Forward other namespace members, wrapping callable results as Attr."""
        member = getattr(self._namespace, name)
        if not callable(member):
            return member

        @functools.wraps(member)
        def _proxy(*args, **kwargs):
            return self._attr._wrap(member(*args, **kwargs))

        return _proxy


class Attr:
"""
A class to compose an attribute expression for attribute filtering or value evaluation.
Expand All @@ -222,30 +251,43 @@ class Attr:
def __init__(self, value: ExprInput) -> None:
    """Build an attribute expression from a column name, another Attr,
    a comparison, a polars expression, or a literal value."""
    # Expressions multiplied by +inf / -inf are tracked separately.
    self._inf_exprs = []
    self._neg_inf_exprs = []
    # Path tracking for backend filters:
    # - root_column: top-level column used to store the value.
    # - field_path: nested struct path from that root column.
    self._root_column: str | None = None
    self._field_path: tuple[str, ...] = ()

    if isinstance(value, AttrComparison):
        # Normalize a comparison into its Attr form, then reuse the
        # Attr-copy branch below.
        value = value.to_attr()

    if isinstance(value, str):
        self.expr = pl.col(value)
        self._root_column = value
    elif isinstance(value, Attr):
        self.expr = value.expr
        # Carry over infinity tracking and path information.
        self._inf_exprs = value.inf_exprs
        self._neg_inf_exprs = value.neg_inf_exprs
        self._root_column = value.root_column
        self._field_path = value.field_path
    elif isinstance(value, Expr):
        self.expr = value
    else:
        self.expr = pl.lit(value)

def _wrap(self, expr: ExprInput, *, preserve_field_path: bool = False) -> Union["Attr", Any]:
    """Wrap a polars expression in a fresh instance of this Attr subtype.

    Non-``Expr`` values are returned untouched.  Infinity tracking is always
    carried over; the root column / struct field path only when
    ``preserve_field_path`` is requested.
    """
    if not isinstance(expr, Expr):
        return expr

    wrapped = type(self)(expr)
    wrapped._inf_exprs = list(self._inf_exprs)
    wrapped._neg_inf_exprs = list(self._neg_inf_exprs)
    if preserve_field_path:
        wrapped._root_column = self._root_column
        wrapped._field_path = self._field_path
    return wrapped

Expand Down Expand Up @@ -377,6 +419,33 @@ def evaluate(self, df: DataFrame) -> Series:
def columns(self) -> list[str]:
return list(dict.fromkeys(self.expr_columns + self.inf_columns + self.neg_inf_columns))

@property
def root_column(self) -> str | None:
    """
    Top-level column name from which this expression originates.

    ``None`` when the expression was not built from a single named column
    (e.g. literals or raw polars expressions).

    Examples
    --------
    `Attr("t").root_column == "t"`
    `NodeAttr("measurements").struct.field("score").root_column == "measurements"`
    """
    return self._root_column

@property
def field_path(self) -> tuple[str, ...]:
    """
    Nested struct-field path relative to [root_column][tracksdata.attrs.Attr.root_column].

    Empty tuple means no nested access.  Each ``.struct.field(...)`` access
    appends one entry to this path.

    Examples
    --------
    `Attr("t").field_path == ()`
    `NodeAttr("measurements").struct.field("score").field_path == ("score",)`
    `NodeAttr("meta").struct.field("det").struct.field("conf").field_path == ("det", "conf")`
    """
    return self._field_path

@property
def inf_exprs(self) -> list["Attr"]:
"""Get the expressions multiplied by positive infinity."""
Expand Down Expand Up @@ -464,6 +533,9 @@ def __getattr__(self, attr: str) -> Any:
if attr.startswith("_"):
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

if attr == "struct":
return _StructNamespace(self)

# To auto generate operator methods such as `.log()``
expr_attr = getattr(self.expr, attr)
if callable(expr_attr):
Expand All @@ -475,6 +547,12 @@ def _wrapped(*args, **kwargs):
return _wrapped
return expr_attr

def _append_field_path(self, field_name: str) -> None:
    """Record one more struct-field access on this expression.

    Without a known root column the path cannot be resolved by backends,
    so it is cleared instead of extended.
    """
    if self._root_column is not None:
        self._field_path = self._field_path + (field_name,)
    else:
        self._field_path = ()

def __repr__(self) -> str:
    """Show the wrapped polars expression."""
    return "Attr({})".format(self.expr)

Expand Down Expand Up @@ -733,4 +811,4 @@ def polars_reduce_attr_comps(
# Return True for all rows by using the first column as a reference
raise ValueError("No attribute comparisons provided.")

return pl.reduce(reduce_op, [attr_comp.op(df[str(attr_comp.column)], attr_comp.other) for attr_comp in attr_comps])
return pl.reduce(reduce_op, [attr_comp.op(attr_comp.attr.expr, attr_comp.other) for attr_comp in attr_comps])
16 changes: 8 additions & 8 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,6 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
current_edge_attr_schemas = graph._edge_attr_schemas()
for k, v in other._edge_attr_schemas().items():
if k not in current_edge_attr_schemas:
print(f"Adding edge attribute key: {k} with dtype: {v.dtype} and default value: {v.default_value}")
graph.add_edge_attr_key(k, v.dtype, v.default_value)

edge_attrs = edge_attrs.with_columns(
Expand Down Expand Up @@ -1927,13 +1926,6 @@ def _private_metadata(self) -> MetadataView:
is_public=False,
)

def _private_metadata_for_copy(self) -> dict[str, Any]:
"""
Return private metadata entries that should be propagated by `from_other` or `to_geff`.
Backends can override this to exclude backend-specific private metadata.
"""
return dict(self._private_metadata)

@classmethod
def _is_private_metadata_key(cls, key: str) -> bool:
return key.startswith(cls._PRIVATE_METADATA_PREFIX)
Expand Down Expand Up @@ -1962,6 +1954,14 @@ def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True)
self._validate_metadata_key(key, is_public=is_public)
self._remove_metadata(key)

def _private_metadata_for_copy(self) -> dict[str, Any]:
    """
    Snapshot of private metadata entries to be propagated by `from_other`.

    Backends can override this to exclude backend-specific private metadata.
    """
    # Materialize the view into a plain dict so callers can mutate freely.
    return {**self._private_metadata}

@abc.abstractmethod
def _metadata(self) -> dict[str, Any]:
"""
Expand Down
38 changes: 27 additions & 11 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,30 @@ def _create_filter_func(
) -> Callable[[dict[str, Any]], bool]:
LOG.info(f"Creating filter function for {attr_comps}")

def _extract_field_path(value: Any, field_path: tuple[str, ...]) -> Any:
for field in field_path:
if value is None:
return None

if isinstance(value, dict):
value = value.get(field, None)
continue

try:
value = value[field]
except (KeyError, IndexError, TypeError):
try:
value = getattr(value, field)
except AttributeError:
return None

return value

def _filter(attrs: dict[str, Any]) -> bool:
    # A node/edge passes only when every comparison holds (logical AND).
    for attr_op in attr_comps:
        # Fall back to the schema default when the attribute is absent.
        value = attrs.get(attr_op.column, schema[attr_op.column].default_value)
        if attr_op.attr.field_path:
            # Nested struct comparison: drill into the stored value first.
            value = _extract_field_path(value, attr_op.attr.field_path)
        if not attr_op.op(value, attr_op.other):
            return False
    return True
Expand Down Expand Up @@ -343,7 +364,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None:
self._time_to_nodes: dict[int, list[int]] = {}
self.__node_attr_schemas: dict[str, AttrSchema] = {}
self.__edge_attr_schemas: dict[str, AttrSchema] = {}
self._overlaps: list[list[int, 2]] = []
self._overlaps: list[list[int]] = []

# Add default node attributes with inferred schemas
self.__node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema(
Expand Down Expand Up @@ -1159,16 +1180,11 @@ def edge_attrs(

edge_map = rx_graph.edge_index_map()
if len(edge_map) == 0:
return pl.DataFrame(
{
key: []
for key in [
*attr_keys,
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
DEFAULT_ATTR_KEYS.EDGE_TARGET,
]
}
)
empty_columns = {}
for key in [*attr_keys, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]:
schema = self._edge_attr_schemas()[key]
empty_columns[key] = pl.Series(name=key, values=[], dtype=schema.dtype)
return pl.DataFrame(empty_columns)

source, target, data = zip(*edge_map.values(), strict=False)

Expand Down
Loading
Loading