Skip to content
16 changes: 8 additions & 8 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,6 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
current_edge_attr_schemas = graph._edge_attr_schemas()
for k, v in other._edge_attr_schemas().items():
if k not in current_edge_attr_schemas:
print(f"Adding edge attribute key: {k} with dtype: {v.dtype} and default value: {v.default_value}")
graph.add_edge_attr_key(k, v.dtype, v.default_value)

edge_attrs = edge_attrs.with_columns(
Expand Down Expand Up @@ -1927,13 +1926,6 @@ def _private_metadata(self) -> MetadataView:
is_public=False,
)

def _private_metadata_for_copy(self) -> dict[str, Any]:
"""
Return private metadata entries that should be propagated by `from_other` or `to_geff`.
Backends can override this to exclude backend-specific private metadata.
"""
return dict(self._private_metadata)

@classmethod
def _is_private_metadata_key(cls, key: str) -> bool:
return key.startswith(cls._PRIVATE_METADATA_PREFIX)
Expand Down Expand Up @@ -1962,6 +1954,14 @@ def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True)
self._validate_metadata_key(key, is_public=is_public)
self._remove_metadata(key)

def _private_metadata_for_copy(self) -> dict[str, Any]:
"""
Return private metadata entries that should be propagated by `from_other`.

Backends can override this to exclude backend-specific private metadata.
"""
return dict(self._private_metadata)

@abc.abstractmethod
def _metadata(self) -> dict[str, Any]:
"""
Expand Down
17 changes: 6 additions & 11 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None:
self._time_to_nodes: dict[int, list[int]] = {}
self.__node_attr_schemas: dict[str, AttrSchema] = {}
self.__edge_attr_schemas: dict[str, AttrSchema] = {}
self._overlaps: list[list[int, 2]] = []
self._overlaps: list[list[int]] = []

# Add default node attributes with inferred schemas
self.__node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema(
Expand Down Expand Up @@ -1159,16 +1159,11 @@ def edge_attrs(

edge_map = rx_graph.edge_index_map()
if len(edge_map) == 0:
return pl.DataFrame(
{
key: []
for key in [
*attr_keys,
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
DEFAULT_ATTR_KEYS.EDGE_TARGET,
]
}
)
empty_columns = {}
for key in [*attr_keys, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]:
schema = self._edge_attr_schemas()[key]
empty_columns[key] = pl.Series(name=key, values=[], dtype=schema.dtype)
return pl.DataFrame(empty_columns)

source, target, data = zip(*edge_map.values(), strict=False)

Expand Down
Loading
Loading