diff --git a/.codex/environments/environment.toml b/.codex/environments/environment.toml new file mode 100644 index 00000000..1324ca94 --- /dev/null +++ b/.codex/environments/environment.toml @@ -0,0 +1,10 @@ +# THIS IS AUTOGENERATED. DO NOT EDIT MANUALLY +version = 1 +name = "tracksdata" + +[setup] +script = ''' +uv venv +uv pip install -e .[spatial,test,docs] +source .venv/bin/activate +''' diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index 8b29bf7d..cc403723 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -115,6 +115,23 @@ def test_attr_expr_method_delegation() -> None: assert result.to_list() == expected.to_list() +def test_attr_expr_struct_field_method_delegation() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 3}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + expr = NodeAttr("s").struct.field("x") + result = expr.evaluate(df) + assert isinstance(expr, NodeAttr) + assert result.to_list() == [1, 2, 3] + + +def test_attr_comparison_struct_field() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 1}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + comp = NodeAttr("s").struct.field("x") == 1 + result = comp.to_attr().evaluate(df) + assert comp.column == "s" + assert comp.attr.field_path == ("x",) + assert result.to_list() == [True, False, True] + + def test_attr_expr_complex_expression() -> None: df = pl.DataFrame({"iou": [0.5, 0.7, 0.9], "distance": [10, 20, 30]}) expr = (1 - Attr("iou")) * Attr("distance") diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 60f82db8..be223211 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -129,7 +129,7 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr raise ValueError(f"Comparison operators are not supported for multiple columns. Found {columns}.") self.attr = attr - self.column = columns[0] + self.column = attr.root_column if attr.root_column is not None else columns[0] self.op = op # casting numpy scalars to python scalars @@ -144,14 +144,18 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr self.other = other def __repr__(self) -> str: - return f"{type(self.attr).__name__}({self.column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" + if self.attr.field_path: + column = ".".join([str(self.column), *self.attr.field_path]) + else: + column = str(self.column) + return f"{type(self.attr).__name__}({column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" def to_attr(self) -> "Attr": """ Transform the comparison back to an [Attr][tracksdata.attrs.Attr] object. This is useful for evaluating the expression on a DataFrame. """ - return Attr(self.op(pl.col(self.column), self.other)) + return Attr(self.op(self.attr.expr, self.other)) def __getattr__(self, attr: str) -> Any: return getattr(self.to_attr(), attr) @@ -198,6 +202,31 @@ def __ge__(self, other: ExprInput) -> "Attr": ... def __rge__(self, other: ExprInput) -> "Attr": ... +class _StructNamespace: + """Wrapper around polars struct namespace that preserves Attr semantics.""" + + def __init__(self, attr: "Attr") -> None: + self._attr = attr + self._namespace = attr.expr.struct + + def field(self, name: str) -> "Attr": + out = self._attr._wrap(self._namespace.field(name), preserve_field_path=True) + if isinstance(out, Attr): + out._append_field_path(name) + return out + + def __getattr__(self, name: str) -> Any: + namespace_attr = getattr(self._namespace, name) + if callable(namespace_attr): + + @functools.wraps(namespace_attr) + def _wrapped(*args, **kwargs): + return self._attr._wrap(namespace_attr(*args, **kwargs)) + + return _wrapped + return namespace_attr + + class Attr: """ A class to compose an attribute expression for attribute filtering or value evaluation. @@ -222,30 +251,43 @@ class Attr: def __init__(self, value: ExprInput) -> None: self._inf_exprs = [] # expressions multiplied by +inf self._neg_inf_exprs = [] # expressions multiplied by -inf + # Path-tracking for backend filters: + # - root_column: top-level column used to store the value. + # - field_path: nested struct path from that root column. + self._root_column: str | None = None + self._field_path: tuple[str, ...] = () if isinstance(value, str): self.expr = pl.col(value) + self._root_column = value elif isinstance(value, Attr): self.expr = value.expr # Copy infinity tracking from the other AttrExpr self._inf_exprs = value.inf_exprs self._neg_inf_exprs = value.neg_inf_exprs + self._root_column = value.root_column + self._field_path = value.field_path elif isinstance(value, AttrComparison): attr = value.to_attr() self.expr = attr.expr self._inf_exprs = attr.inf_exprs self._neg_inf_exprs = attr.neg_inf_exprs + self._root_column = attr.root_column + self._field_path = attr.field_path elif isinstance(value, Expr): self.expr = value else: self.expr = pl.lit(value) - def _wrap(self, expr: ExprInput) -> Union["Attr", Any]: + def _wrap(self, expr: ExprInput, *, preserve_field_path: bool = False) -> Union["Attr", Any]: if isinstance(expr, Expr): - result = Attr(expr) + result = type(self)(expr) # Propagate infinity tracking result._inf_exprs = self._inf_exprs.copy() result._neg_inf_exprs = self._neg_inf_exprs.copy() + if preserve_field_path: + result._root_column = self._root_column + result._field_path = self._field_path return result return expr @@ -377,6 +419,33 @@ def evaluate(self, df: DataFrame) -> Series: def columns(self) -> list[str]: return list(dict.fromkeys(self.expr_columns + self.inf_columns + self.neg_inf_columns)) + @property + def root_column(self) -> str | None: + """ + Top-level column name from which this expression originates. + + Examples + -------- + `Attr("t").root_column == "t"` + `NodeAttr("measurements").struct.field("score").root_column == "measurements"` + """ + return self._root_column + + @property + def field_path(self) -> tuple[str, ...]: + """ + Nested struct-field path relative to [root_column][tracksdata.attrs.Attr.root_column]. + + Empty tuple means no nested access. + + Examples + -------- + `Attr("t").field_path == ()` + `NodeAttr("measurements").struct.field("score").field_path == ("score",)` + `NodeAttr("meta").struct.field("det").struct.field("conf").field_path == ("det", "conf")` + """ + return self._field_path + @property def inf_exprs(self) -> list["Attr"]: """Get the expressions multiplied by positive infinity.""" @@ -464,6 +533,9 @@ def __getattr__(self, attr: str) -> Any: if attr.startswith("_"): raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + if attr == "struct": + return _StructNamespace(self) + # To auto generate operator methods such as `.log()`` expr_attr = getattr(self.expr, attr) if callable(expr_attr): @@ -475,6 +547,12 @@ def _wrapped(*args, **kwargs): return _wrapped return expr_attr + def _append_field_path(self, field_name: str) -> None: + if self._root_column is None: + self._field_path = () + else: + self._field_path = (*self._field_path, field_name) + def __repr__(self) -> str: return f"Attr({self.expr})" @@ -733,4 +811,4 @@ def polars_reduce_attr_comps( # Return True for all rows by using the first column as a reference raise ValueError("No attribute comparisons provided.") - return pl.reduce(reduce_op, [attr_comp.op(df[str(attr_comp.column)], attr_comp.other) for attr_comp in attr_comps]) + return pl.reduce(reduce_op, [attr_comp.op(attr_comp.attr.expr, attr_comp.other) for attr_comp in attr_comps]) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 634aacb1..1e4d4364 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1282,7 +1282,6 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: current_edge_attr_schemas = graph._edge_attr_schemas() for k, v in other._edge_attr_schemas().items(): if k not in current_edge_attr_schemas: - print(f"Adding edge attribute key: {k} with dtype: {v.dtype} and default value: {v.default_value}") graph.add_edge_attr_key(k, v.dtype, v.default_value) edge_attrs = edge_attrs.with_columns( @@ -1927,13 +1926,6 @@ def _private_metadata(self) -> MetadataView: is_public=False, ) - def _private_metadata_for_copy(self) -> dict[str, Any]: - """ - Return private metadata entries that should be propagated by `from_other` or `to_geff`. - Backends can override this to exclude backend-specific private metadata. - """ - return dict(self._private_metadata) - @classmethod def _is_private_metadata_key(cls, key: str) -> bool: return key.startswith(cls._PRIVATE_METADATA_PREFIX) @@ -1962,6 +1954,14 @@ def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) + def _private_metadata_for_copy(self) -> dict[str, Any]: + """ + Return private metadata entries that should be propagated by `from_other`. + + Backends can override this to exclude backend-specific private metadata. + """ + return dict(self._private_metadata) + @abc.abstractmethod def _metadata(self) -> dict[str, Any]: """ diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index e8c1cac1..91e76f2a 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -74,9 +74,30 @@ def _create_filter_func( ) -> Callable[[dict[str, Any]], bool]: LOG.info(f"Creating filter function for {attr_comps}") + def _extract_field_path(value: Any, field_path: tuple[str, ...]) -> Any: + for field in field_path: + if value is None: + return None + + if isinstance(value, dict): + value = value.get(field, None) + continue + + try: + value = value[field] + except (KeyError, IndexError, TypeError): + try: + value = getattr(value, field) + except AttributeError: + return None + + return value + def _filter(attrs: dict[str, Any]) -> bool: for attr_op in attr_comps: value = attrs.get(attr_op.column, schema[attr_op.column].default_value) + if attr_op.attr.field_path: + value = _extract_field_path(value, attr_op.attr.field_path) if not attr_op.op(value, attr_op.other): return False return True @@ -343,7 +364,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: self._time_to_nodes: dict[int, list[int]] = {} self.__node_attr_schemas: dict[str, AttrSchema] = {} self.__edge_attr_schemas: dict[str, AttrSchema] = {} - self._overlaps: list[list[int, 2]] = [] + self._overlaps: list[list[int]] = [] # Add default node attributes with inferred schemas self.__node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema( @@ -1159,16 +1180,11 @@ def edge_attrs( edge_map = rx_graph.edge_index_map() if len(edge_map) == 0: - return pl.DataFrame( - { - key: [] - for key in [ - *attr_keys, - DEFAULT_ATTR_KEYS.EDGE_SOURCE, - DEFAULT_ATTR_KEYS.EDGE_TARGET, - ] - } - ) + empty_columns = {} + for key in [*attr_keys, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + schema = self._edge_attr_schemas()[key] + empty_columns[key] = pl.Series(name=key, values=[], dtype=schema.dtype) + return pl.DataFrame(empty_columns) source, target, data = zip(*edge_map.values(), strict=False) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index e3eb9a40..6a287edd 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -21,8 +21,10 @@ from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns from tracksdata.utils._dtypes import ( AttrSchema, + deserialize_attr_schema, polars_dtype_to_sqlalchemy_type, process_attr_key_args, + serialize_attr_schema, sqlalchemy_type_to_polars_dtype, ) from tracksdata.utils._logging import LOG @@ -54,10 +56,100 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +def _coerce_json_field_expr(lhs: Any, dtype: pl.DataType | None) -> Any: + if dtype is None: + return lhs + + dtype_base = dtype.base_type() + + if dtype_base == pl.Boolean: + if hasattr(lhs, "as_boolean"): + return lhs.as_boolean() + return sa.cast(lhs, sa.Boolean) + if dtype_base in {pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64}: + if hasattr(lhs, "as_integer"): + return lhs.as_integer() + return sa.cast(lhs, sa.BigInteger) + if dtype_base in {pl.Float16, pl.Float32, pl.Float64}: + if hasattr(lhs, "as_float"): + return lhs.as_float() + return sa.cast(lhs, sa.Float) + if dtype_base in {pl.String, pl.Utf8}: + if hasattr(lhs, "as_string"): + return lhs.as_string() + return sa.cast(lhs, sa.String) + return lhs + + +def _field_dtype_from_schema( + attr_filter: AttrComparison, + attr_schemas: dict[str, AttrSchema] | None, +) -> pl.DataType | None: + if attr_schemas is None: + return None + + schema = attr_schemas.get(str(attr_filter.column)) + if schema is None: + return None + + dtype = schema.dtype + for field in attr_filter.attr.field_path: + if not isinstance(dtype, pl.Struct): + return None + + dtype = dtype.to_schema().get(field) + if dtype is None: + return None + + return dtype + + +def _resolve_attr_filter_column( + table: type[DeclarativeBase], + attr_filter: AttrComparison, + attr_schemas: dict[str, AttrSchema] | None = None, +) -> Any: + lhs = getattr(table, str(attr_filter.column)) + + if not attr_filter.attr.field_path: + return lhs + + for field in attr_filter.attr.field_path: + lhs = lhs[field] + + field_dtype = _field_dtype_from_schema(attr_filter, attr_schemas) + return _coerce_json_field_expr(lhs, field_dtype) + + +def _json_decode_safe_dtype(dtype: pl.DataType) -> pl.DataType: + """ + Return a JSON-decodable dtype by replacing fixed-size arrays with lists recursively. + """ + if isinstance(dtype, pl.Array): + return pl.List(_json_decode_safe_dtype(dtype.inner)) + + if isinstance(dtype, pl.List): + return pl.List(_json_decode_safe_dtype(dtype.inner)) + + if isinstance(dtype, pl.Struct): + return pl.Struct({key: _json_decode_safe_dtype(inner) for key, inner in dtype.to_schema().items()}) + + return dtype + + +def _struct_json_decode_expr(column: str, target_dtype: pl.Struct) -> pl.Expr: + decode_dtype = _json_decode_safe_dtype(target_dtype) + decoded_expr = pl.when(pl.col(column).is_null()).then(None).otherwise(pl.col(column).str.json_decode(decode_dtype)) + if decode_dtype != target_dtype: + decoded_expr = decoded_expr.cast(target_dtype) + return decoded_expr.alias(column) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], attr_filters: list[AttrComparison], + attr_schemas: dict[str, AttrSchema] | None = None, ) -> sa.Select: """ Filter a query by a list of attribute filters. @@ -70,6 +162,8 @@ def _filter_query( The table to filter. attr_filters : list[AttrComparison] The attribute filters to apply. + attr_schemas : dict[str, AttrSchema] | None, optional + Attribute schema map used to resolve nested struct field dtypes. Returns ------- @@ -78,7 +172,13 @@ def _filter_query( """ LOG.info("Filter query:\n%s", attr_filters) query = query.filter( - *[attr_filter.op(getattr(table, str(attr_filter.column)), attr_filter.other) for attr_filter in attr_filters] + *[ + attr_filter.op( + _resolve_attr_filter_column(table, attr_filter, attr_schemas=attr_schemas), + attr_filter.other, + ) + for attr_filter in attr_filters + ] ) return query @@ -102,6 +202,8 @@ def __init__( self._node_query: sa.Select = sa.select(self._graph.Node) self._edge_query: sa.Select = sa.select(self._graph.Edge) node_filtered = False + node_attr_schemas = self._graph._node_attr_schemas() + edge_attr_schemas = self._graph._edge_attr_schemas() if node_ids is not None: if hasattr(node_ids, "tolist"): @@ -121,7 +223,12 @@ def __init__( if self._node_attr_comps: node_filtered = True # filtering nodes by attributes - self._node_query = _filter_query(self._node_query, self._graph.Node, self._node_attr_comps) + self._node_query = _filter_query( + self._node_query, + self._graph.Node, + self._node_attr_comps, + attr_schemas=node_attr_schemas, + ) # if both node and edge attributes are filtered # we need to select subset of edges that belong to the filtered nodes @@ -137,17 +244,32 @@ def __init__( SourceNode, self._graph.Edge.source_id == SourceNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, SourceNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + SourceNode, + self._node_attr_comps, + attr_schemas=node_attr_schemas, + ) if self._include_sources or include_none: self._edge_query = self._edge_query.join( TargetNode, self._graph.Edge.target_id == TargetNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, TargetNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + TargetNode, + self._node_attr_comps, + attr_schemas=node_attr_schemas, + ) if self._edge_attr_comps: - self._edge_query = _filter_query(self._edge_query, self._graph.Edge, self._edge_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + self._graph.Edge, + self._edge_attr_comps, + attr_schemas=edge_attr_schemas, + ) # we haven't filtered the nodes by attributes # so we only return the nodes that are in the edges @@ -445,6 +567,8 @@ class SQLGraph(BaseGraph): """ node_id_time_multiplier: int = 1_000_000_000 + _PRIVATE_SQL_NODE_SCHEMA_STORE_KEY = "__private_sql_node_attr_schema_store" + _PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY = "__private_sql_edge_attr_schema_store" Base: type[DeclarativeBase] Node: type[DeclarativeBase] Edge: type[DeclarativeBase] @@ -473,17 +597,12 @@ def __init__( # Create unique classes for this instance self._define_schema(overwrite=overwrite) - self.__node_attr_schemas: dict[str, AttrSchema] = {} - self.__edge_attr_schemas: dict[str, AttrSchema] = {} if overwrite: self.Base.metadata.drop_all(self._engine) self.Base.metadata.create_all(self._engine) - # Initialize schemas from existing table columns - self._init_schemas_from_tables() - self._max_id_per_time = {} self._update_max_id_per_time() @@ -555,68 +674,151 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata - def _init_schemas_from_tables(self) -> None: - """ - Initialize AttrSchema objects from existing database table columns. - This is used when loading an existing graph from the database. - """ - # Initialize node schemas from Node table columns - for column_name in self.Node.__table__.columns.keys(): - if column_name not in self.__node_attr_schemas: - column = self.Node.__table__.columns[column_name] - # Infer polars dtype from SQLAlchemy type - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) - # AttrSchema.__post_init__ will infer the default_value - self.__node_attr_schemas[column_name] = AttrSchema( - key=column_name, - dtype=pl_dtype, - ) + @staticmethod + def _default_node_attr_schemas() -> dict[str, AttrSchema]: + return { + DEFAULT_ATTR_KEYS.T: AttrSchema(key=DEFAULT_ATTR_KEYS.T, dtype=pl.Int32), + DEFAULT_ATTR_KEYS.NODE_ID: AttrSchema(key=DEFAULT_ATTR_KEYS.NODE_ID, dtype=pl.Int64), + } - # Initialize edge schemas from Edge table columns - for column_name in self.Edge.__table__.columns.keys(): - # Skip internal edge columns - if column_name not in self.__edge_attr_schemas: - column = self.Edge.__table__.columns[column_name] - # Infer polars dtype from SQLAlchemy type - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) - # AttrSchema.__post_init__ will infer the default_value - self.__edge_attr_schemas[column_name] = AttrSchema( + @staticmethod + def _default_edge_attr_schemas() -> dict[str, AttrSchema]: + return { + DEFAULT_ATTR_KEYS.EDGE_ID: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_ID, dtype=pl.Int32), + DEFAULT_ATTR_KEYS.EDGE_SOURCE: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_SOURCE, dtype=pl.Int64), + DEFAULT_ATTR_KEYS.EDGE_TARGET: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_TARGET, dtype=pl.Int64), + } + + def _attr_schemas_from_metadata( + self, + *, + table_class: type[DeclarativeBase], + metadata_key: str, + default_schemas: dict[str, AttrSchema], + preferred_order: Sequence[str], + ) -> dict[str, AttrSchema]: + encoded_schemas = self._private_metadata.get(metadata_key, {}) + schemas = default_schemas.copy() + schemas.update( + {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} + ) + + # Legacy databases may not have schema metadata for all columns. + for column_name, column in table_class.__table__.columns.items(): + if column_name not in schemas: + schemas[column_name] = AttrSchema( key=column_name, - dtype=pl_dtype, + dtype=sqlalchemy_type_to_polars_dtype(column.type), ) + ordered_keys = [key for key in preferred_order if key in schemas] + ordered_keys.extend(key for key in table_class.__table__.columns.keys() if key not in ordered_keys) + ordered_keys.extend(key for key in schemas if key not in ordered_keys) + return {key: schemas[key] for key in ordered_keys} + + def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[str, AttrSchema]: + if table_class.__tablename__ == self.Node.__tablename__: + return self._node_attr_schemas() + return self._edge_attr_schemas() + + @staticmethod + def _is_pickled_sql_type(column_type: TypeEngine) -> bool: + return isinstance(column_type, sa.PickleType | sa.LargeBinary) + + @staticmethod + def _is_json_sql_type(column_type: TypeEngine) -> bool: + return isinstance(column_type, sa.JSON) + + @property + def __node_attr_schemas(self) -> dict[str, AttrSchema]: + return self._attr_schemas_from_metadata( + table_class=self.Node, + metadata_key=self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, + default_schemas=self._default_node_attr_schemas(), + preferred_order=[DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID], + ) + + @__node_attr_schemas.setter + def __node_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + merged_schemas = self._default_node_attr_schemas() + merged_schemas.update(schemas) + schemas = merged_schemas + encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._private_metadata[self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY] = encoded_schemas + + @property + def __edge_attr_schemas(self) -> dict[str, AttrSchema]: + return self._attr_schemas_from_metadata( + table_class=self.Edge, + metadata_key=self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, + default_schemas=self._default_edge_attr_schemas(), + preferred_order=[ + DEFAULT_ATTR_KEYS.EDGE_ID, + DEFAULT_ATTR_KEYS.EDGE_SOURCE, + DEFAULT_ATTR_KEYS.EDGE_TARGET, + ], + ) + + @__edge_attr_schemas.setter + def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + merged_schemas = self._default_edge_attr_schemas() + merged_schemas.update(schemas) + schemas = merged_schemas + encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas + def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: if isinstance(column.type, sa.LargeBinary): column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: - # Get the appropriate schema dict based on table class - if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas() - else: - schemas = self._edge_attr_schemas() + schemas = self._attr_schemas_for_table(table_class) - # Return schema overrides for special types that need explicit casting + # Return schema overrides for columns safely represented in SQL. + # Pickled columns are unpickled and casted in a second pass. return { key: schema.dtype for key, schema in schemas.items() - if not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + if ( + key in table_class.__table__.columns + and not self._is_pickled_sql_type(table_class.__table__.columns[key].type) + and not isinstance(schema.dtype, pl.Struct) + ) } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: - # Get the appropriate schema dict based on table class - if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas() - else: - schemas = self._edge_attr_schemas() - - # Cast array columns (stored as blobs in database) - df = df.with_columns( - pl.Series(key, df[key].to_list(), dtype=schema.dtype) - for key, schema in schemas.items() - if isinstance(schema.dtype, pl.Array) and key in df.columns - ) + schemas = self._attr_schemas_for_table(table_class) + + decode_exprs: list[pl.Expr] = [] + casts: list[pl.Series] = [] + for key, schema in schemas.items(): + if key not in df.columns or key not in table_class.__table__.columns: + continue + + column_type = table_class.__table__.columns[key].type + source_dtype = df.schema[key] + + if isinstance(schema.dtype, pl.Struct) and self._is_json_sql_type(column_type): + if source_dtype == pl.String: + decode_exprs.append(_struct_json_decode_expr(key, schema.dtype)) + elif source_dtype != schema.dtype: + casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + continue + + if not self._is_pickled_sql_type(column_type): + continue + + try: + casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + except Exception: + # Keep original dtype when values cannot be casted to the target schema. + continue + + if decode_exprs: + df = df.with_columns(decode_exprs) + if casts: + df = df.with_columns(casts) return df def _update_max_id_per_time(self) -> None: @@ -1303,6 +1505,8 @@ def node_attrs( # indices are included by default and must be removed if attr_keys is not None: nodes_df = nodes_df.select([pl.col(c) for c in attr_keys]) + else: + nodes_df = nodes_df.select([pl.col(c) for c in self._node_attr_schemas() if c in nodes_df.columns]) if unpack: nodes_df = unpack_array_attrs(nodes_df) @@ -1345,6 +1549,8 @@ def edge_attrs( if unpack: edges_df = unpack_array_attrs(edges_df) + elif attr_keys is None: + edges_df = edges_df.select([pl.col(c) for c in self._edge_attr_schemas() if c in edges_df.columns]) return edges_df @@ -1589,6 +1795,9 @@ def _add_new_column( sa_column = sa.Column(schema.key, sa_type, default=default_value) str_dialect_type = sa_column.type.compile(dialect=self._engine.dialect) + identifier_preparer = self._engine.dialect.identifier_preparer + quoted_table_name = identifier_preparer.format_table(table_class.__table__) + quoted_column_name = identifier_preparer.quote(sa_column.name) # Properly quote default values based on type if isinstance(default_value, str): @@ -1599,8 +1808,8 @@ def _add_new_column( quoted_default = str(default_value) add_column_stmt = sa.DDL( - f"ALTER TABLE {table_class.__table__} ADD " - f"COLUMN {sa_column.name} {str_dialect_type} " + f"ALTER TABLE {quoted_table_name} ADD " + f"COLUMN {quoted_column_name} {str_dialect_type} " f"DEFAULT {quoted_default}", ) LOG.info("add %s column statement:\n'%s'", table_class.__table__, add_column_stmt) @@ -1615,7 +1824,10 @@ def _add_new_column( table_class.__table__.append_column(sa_column) def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: - drop_column_stmt = sa.DDL(f"ALTER TABLE {table_class.__table__} DROP COLUMN {key}") + identifier_preparer = self._engine.dialect.identifier_preparer + quoted_table_name = identifier_preparer.format_table(table_class.__table__) + quoted_column_name = identifier_preparer.quote(key) + drop_column_stmt = sa.DDL(f"ALTER TABLE {quoted_table_name} DROP COLUMN {quoted_column_name}") LOG.info("drop %s column statement:\n'%s'", table_class.__table__, drop_column_stmt) with Session(self._engine) as session: @@ -1631,14 +1843,14 @@ def add_node_attr_key( dtype: pl.DataType | None = None, default_value: Any = None, ) -> None: + node_schemas = self.__node_attr_schemas # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__node_attr_schemas) - - # Store schema - self.__node_attr_schemas[schema.key] = schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, node_schemas) # Add column to database self._add_new_column(self.Node, schema) + node_schemas[schema.key] = schema + self.__node_attr_schemas = node_schemas def remove_node_attr_key(self, key: str) -> None: if key not in self.node_attr_keys(): @@ -1647,8 +1859,10 @@ def remove_node_attr_key(self, key: str) -> None: if key in (DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T): raise ValueError(f"Cannot remove required node attribute key {key}") + node_schemas = self.__node_attr_schemas self._drop_column(self.Node, key) - self.__node_attr_schemas.pop(key, None) + node_schemas.pop(key, None) + self.__node_attr_schemas = node_schemas def add_edge_attr_key( self, @@ -1656,21 +1870,23 @@ def add_edge_attr_key( dtype: pl.DataType | None = None, default_value: Any = None, ) -> None: + edge_schemas = self.__edge_attr_schemas # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__edge_attr_schemas) - - # Store schema - self.__edge_attr_schemas[schema.key] = schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, edge_schemas) # Add column to database self._add_new_column(self.Edge, schema) + edge_schemas[schema.key] = schema + self.__edge_attr_schemas = edge_schemas def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key {key} does not exist") + edge_schemas = self.__edge_attr_schemas self._drop_column(self.Edge, key) - self.__edge_attr_schemas.pop(key, None) + edge_schemas.pop(key, None) + self.__edge_attr_schemas = edge_schemas def num_edges(self) -> int: with Session(self._engine) as session: @@ -2081,6 +2297,12 @@ def _metadata(self) -> dict[str, Any]: result = session.query(self.Metadata).all() return {row.key: row.value for row in result} + def _private_metadata_for_copy(self) -> dict[str, Any]: + private_metadata = super()._private_metadata_for_copy() + private_metadata.pop(self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, None) + private_metadata.pop(self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, None) + return private_metadata + def _update_metadata(self, **kwargs) -> None: with Session(self._engine) as session: for key, value in kwargs.items(): diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 1f943ed7..9d477821 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1,3 +1,4 @@ +import datetime as dt from pathlib import Path from typing import Any @@ -224,6 +225,20 @@ def test_filter_nodes_by_membership(graph_backend: BaseGraph) -> None: assert set(np_members) == {node_b} +def test_filter_nodes_by_struct_field(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "name": pl.String})) + + node_a = graph_backend.add_node({"t": 0, "measurements": {"score": 1, "name": "A"}}) + node_b = graph_backend.add_node({"t": 1, "measurements": {"score": 2, "name": "B"}}) + node_c = graph_backend.add_node({"t": 2, "measurements": {"score": 1, "name": "C"}}) + + score_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("score") == 1).node_ids() + assert set(score_nodes) == {node_a, node_c} + + name_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("name") == "B").node_ids() + assert set(name_nodes) == {node_b} + + def test_time_points(graph_backend: BaseGraph) -> None: """Test retrieving time points.""" graph_backend.add_node({"t": 0}) @@ -1437,6 +1452,108 @@ def test_from_other_with_edges( assert new_overlaps == source_overlaps +@pytest.mark.parametrize( + ("target_cls", "target_kwargs"), + [ + pytest.param(RustWorkXGraph, {}, id="rustworkx"), + pytest.param( + SQLGraph, + { + "drivername": "sqlite", + "database": ":memory:", + "engine_kwargs": {"connect_args": {"check_same_thread": False}}, + }, + id="sql", + ), + pytest.param(IndexedRXGraph, {}, id="indexed"), + ], +) +def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], target_kwargs: dict[str, Any]) -> None: + """Test that from_other preserves node and edge attribute schemas across backends.""" + graph = RustWorkXGraph() + for dtype in [ + pl.Float16, + pl.Float32, + pl.Float64, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Date, + pl.Datetime, + pl.Boolean, + pl.Array(pl.Float32, 3), + pl.List(pl.Int32), + pl.Struct({"a": pl.Int8, "b": pl.Array(pl.String, 2)}), + pl.String, + pl.Object, + ]: + graph.add_node_attr_key(f"attr_{dtype}", dtype=dtype) + graph.add_node( + { + "t": 0, + "attr_Float16": np.float16(1.5), + "attr_Float32": np.float32(2.5), + "attr_Float64": np.float64(3.5), + "attr_Int8": np.int8(4), + "attr_Int16": np.int16(5), + "attr_Int32": np.int32(6), + "attr_Int64": np.int64(7), + "attr_UInt8": np.uint8(8), + "attr_UInt16": np.uint16(9), + "attr_UInt32": np.uint32(10), + "attr_UInt64": np.uint64(11), + "attr_Date": pl.date(2024, 1, 1), + "attr_Datetime": dt.datetime(2024, 1, 1, 12, 0, 0), + "attr_Boolean": True, + "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), + "attr_List(Int32)": [1, 2, 3], + "attr_Struct({'a': Int8, 'b': Array(String, shape=(2,))})": { + "a": 1, + "b": np.array(["x", "y"], dtype=object), + }, + "attr_String": "test", + "attr_Object": {"key": "value"}, + } + ) + graph2 = target_cls.from_other(graph, **target_kwargs) + + assert graph2.num_nodes() == graph.num_nodes() + assert set(graph2.node_attr_keys()) == set(graph.node_attr_keys()) + + assert graph2._node_attr_schemas() == graph._node_attr_schemas() + assert graph2._edge_attr_schemas() == graph._edge_attr_schemas() + assert graph2.node_attrs().schema == graph.node_attrs().schema + assert graph2.edge_attrs().schema == graph.edge_attrs().schema + + graph3 = RustWorkXGraph.from_other(graph2) + assert graph3._node_attr_schemas() == graph._node_attr_schemas() + assert graph3._edge_attr_schemas() == graph._edge_attr_schemas() + assert graph3.node_attrs().schema == graph.node_attrs().schema + assert graph3.edge_attrs().schema == graph.edge_attrs().schema + + +@pytest.mark.xfail(reason="This is because of the lack of support of shape-less pl.Array in write_ipc of polars.") +def test_from_other_with_array_no_shape(): + """Test that from_other raises an error when trying to copy array attributes without shape information.""" + graph = RustWorkXGraph() + graph.add_node_attr_key("array_attr", pl.Array) + graph.add_node({"t": 0, "array_attr": np.array([1.0, 2.0, 3.0], dtype=np.float32)}) + + # This should raise an error because the schema does not include shape information + graph2 = SQLGraph.from_other( + graph, drivername="sqlite", database=":memory:", engine_kwargs={"connect_args": {"check_same_thread": False}} + ) + assert graph2.num_nodes() == graph.num_nodes() + assert set(graph2.node_attr_keys()) == set(graph.node_attr_keys()) + assert graph2._node_attr_schemas() == graph._node_attr_schemas() + assert graph2.node_attrs().schema == graph.node_attrs().schema + + @pytest.mark.parametrize( ("target_cls", "target_kwargs"), [ @@ -1603,6 +1720,24 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: np.testing.assert_array_equal(stored_mask.mask, mask_data) +def test_sql_graph_struct_dtype_survives_reload(tmp_path: Path) -> None: + db_path = tmp_path / "struct_graph.db" + graph = SQLGraph("sqlite", str(db_path)) + graph.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "label": pl.String})) + + node_id = graph.add_node({"t": 0, "measurements": {"score": 7, "label": "A"}}) + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + + df = reloaded.node_attrs(attr_keys=["measurements"]) + assert df.schema["measurements"] == pl.Struct({"score": pl.Int64, "label": pl.String}) + assert df["measurements"].to_list() == [{"score": 7, "label": "A"}] + + ids = reloaded.filter(NodeAttr("measurements").struct.field("score") == 7).node_ids() + assert ids == [node_id] + + def test_sql_graph_max_id_restored_per_timepoint(tmp_path: Path) -> None: """Reloading a SQLGraph should respect existing max IDs per time point.""" db_path = tmp_path / "id_restore.db" @@ -1619,6 +1754,72 @@ def test_sql_graph_max_id_restored_per_timepoint(tmp_path: Path) -> None: assert next_id == first_id + 1 +def test_sql_graph_schema_defaults_survive_reload(tmp_path: Path) -> None: + """Reloading a SQLGraph should preserve dtype and default schema metadata.""" + db_path = tmp_path / "schema_defaults.db" + graph = SQLGraph("sqlite", str(db_path)) + + node_array_default = np.array([1.0, 2.0, 3.0], dtype=np.float32) + node_object_default = {"nested": [1, 2, 3]} + edge_score_default = 0.25 + + graph.add_node_attr_key("node_array_default", pl.Array(pl.Float32, 3), node_array_default) + graph.add_node_attr_key("node_object_default", pl.Object, node_object_default) + graph.add_edge_attr_key("edge_score_default", pl.Float32, edge_score_default) + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + + node_schemas = reloaded._node_attr_schemas() + edge_schemas = reloaded._edge_attr_schemas() + np.testing.assert_array_equal(node_schemas["node_array_default"].default_value, node_array_default) + assert node_schemas["node_array_default"].dtype == pl.Array(pl.Float32, 3) + assert node_schemas["node_object_default"].default_value == node_object_default + assert node_schemas["node_object_default"].dtype == pl.Object + assert edge_schemas["edge_score_default"].default_value == edge_score_default + assert edge_schemas["edge_score_default"].dtype == pl.Float32 + + +def test_sql_schema_metadata_not_copied_to_in_memory_graphs() -> None: + """SQL-private schema metadata should not leak into in-memory backends via from_other.""" + sql_graph = SQLGraph("sqlite", ":memory:") + sql_graph.add_node_attr_key("node_array_default", pl.Array(pl.Float32, 3), np.array([1.0, 2.0, 3.0], np.float32)) + sql_graph.add_node_attr_key("node_object_default", pl.Object, {"payload": [1, 2, 3]}) + sql_graph.add_edge_attr_key("edge_score_default", pl.Float32, 0.25) + + n1 = sql_graph.add_node( + { + "t": 0, + "node_array_default": np.array([1.0, 1.0, 1.0], dtype=np.float32), + "node_object_default": {"payload": [10]}, + } + ) + n2 = sql_graph.add_node( + { + "t": 1, + "node_array_default": np.array([2.0, 2.0, 2.0], dtype=np.float32), + "node_object_default": {"payload": [20]}, + } + ) + sql_graph.add_edge(n1, n2, {"edge_score_default": 0.75}) + + assert SQLGraph._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY in sql_graph._private_metadata + assert SQLGraph._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY in sql_graph._private_metadata + + rx_graph = RustWorkXGraph.from_other(sql_graph) + assert SQLGraph._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY not in rx_graph._metadata() + assert SQLGraph._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY not in rx_graph._metadata() + + sql_graph_roundtrip = SQLGraph.from_other( + rx_graph, + drivername="sqlite", + database=":memory:", + engine_kwargs={"connect_args": {"check_same_thread": False}}, + ) + assert sql_graph_roundtrip._node_attr_schemas() == sql_graph._node_attr_schemas() + assert sql_graph_roundtrip._edge_attr_schemas() == sql_graph._edge_attr_schemas() + + def test_compute_overlaps_invalid_threshold(graph_backend: BaseGraph) -> None: """Test compute_overlaps with invalid threshold values.""" with pytest.raises(ValueError, match=r"iou_threshold must be between 0.0 and 1\.0"): diff --git a/src/tracksdata/solvers/_ilp_solver.py b/src/tracksdata/solvers/_ilp_solver.py index 6485eaf7..3f6676d5 100644 --- a/src/tracksdata/solvers/_ilp_solver.py +++ b/src/tracksdata/solvers/_ilp_solver.py @@ -175,6 +175,9 @@ def _evaluate_expr( expr: Attr, df: pl.DataFrame, ) -> list[float]: + if df.is_empty(): + return [] + if len(expr.expr_columns) == 0: return [expr.evaluate(df).item()] * len(df) else: @@ -388,7 +391,11 @@ def solve( node_attr_keys.extend(self.merge_weight_expr.columns) nodes_df = graph.node_attrs(attr_keys=node_attr_keys) - edges_df = graph.edge_attrs(attr_keys=self.edge_weight_expr.columns) + # When no edges exist, avoid requesting edge weight columns that may not + # be registered in the backend schema yet. _solve() handles this as a + # regular "no edges" ValueError. + edge_attr_keys = [] if graph.num_edges() == 0 else self.edge_weight_expr.columns + edges_df = graph.edge_attrs(attr_keys=edge_attr_keys) self._add_objective_and_variables(nodes_df, edges_df) self._add_continuous_flow_constraints(nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_list(), edges_df) diff --git a/src/tracksdata/solvers/_nearest_neighbors_solver.py b/src/tracksdata/solvers/_nearest_neighbors_solver.py index 34011dee..21915290 100644 --- a/src/tracksdata/solvers/_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_nearest_neighbors_solver.py @@ -235,7 +235,8 @@ def solve( The graph view of the solution if `return_solution` is True, otherwise None. """ # get edges and sort them by weight - edges_df = graph.edge_attrs(attr_keys=self.edge_weight_expr.columns) + edge_attr_keys = [] if graph.num_edges() == 0 else self.edge_weight_expr.columns + edges_df = graph.edge_attrs(attr_keys=edge_attr_keys) if len(edges_df) == 0: raise ValueError("No edges found in the graph, there is nothing to solve.") diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 8e671487..dc3d6dd0 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import io from dataclasses import dataclass from typing import Any @@ -202,6 +204,37 @@ def copy(self) -> AttrSchema: """ return AttrSchema(key=self.key, dtype=self.dtype, default_value=self.default_value) + def __eq__(self, other: object) -> bool: + if not isinstance(other, AttrSchema): + return NotImplemented + return ( + self.key == other.key + and self.dtype == other.dtype + and _values_equal(self.default_value, other.default_value) + ) + + +def _values_equal(left: Any, right: Any) -> bool: + if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + return bool(np.array_equal(left, right)) + if isinstance(left, dict) and isinstance(right, dict): + if left.keys() != right.keys(): + return False + return all(_values_equal(left[k], right[k]) for k in left) + if isinstance(left, list | tuple) and isinstance(right, list | tuple): + if len(left) != len(right): + return False + return all(_values_equal(lv, rv) for lv, rv in zip(left, right, strict=True)) + + try: + value = left == right + except Exception: + return False + + if isinstance(value, np.ndarray): + return bool(np.all(value)) + return bool(value) + def process_attr_key_args( key_or_schema: str | AttrSchema, @@ -383,6 +416,10 @@ def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: >>> polars_dtype_to_sqlalchemy_type(pl.Boolean) """ + # Handle struct types as JSON for backend-level field filtering. + if isinstance(dtype, pl.Struct): + return sa.JSON() + # Handle sequence types - use PickleType for storage if isinstance(dtype, pl.Array | pl.List): return sa.PickleType() @@ -407,6 +444,7 @@ def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: (sa.Float, pl.Float64), (sa.Text, pl.String), # Must come before String (sa.String, pl.String), + (sa.JSON, pl.Object), (sa.PickleType, pl.Object), # Must come before LargeBinary (sa.LargeBinary, pl.Object), ] @@ -445,6 +483,99 @@ def sqlalchemy_type_to_polars_dtype(sa_type: TypeEngine) -> pl.DataType: return pl.Object +def _normalize_default_for_dtype(default_value: Any, dtype: pl.DataType) -> Any: + if isinstance(dtype, pl.Array | pl.List) and isinstance(default_value, np.ndarray): + return default_value.tolist() + return default_value + + +def _normalize_deserialized_default(default_value: Any, dtype: pl.DataType) -> Any: + if isinstance(dtype, pl.Array): + if isinstance(default_value, pl.Series): + default_value = default_value.to_list() + numpy_dtype = polars_dtype_to_numpy_dtype(dtype.inner, allow_sequence=True) + return np.asarray(default_value, dtype=numpy_dtype).reshape(dtype.shape) + + if isinstance(dtype, pl.List): + if isinstance(default_value, pl.Series): + return default_value.to_list() + if isinstance(default_value, np.ndarray): + return default_value.tolist() + + return default_value + + +_ATTR_SCHEMA_VALUE_COL = "__attr_schema_value__" +_ATTR_SCHEMA_FALLBACK_COL = "__attr_schema_fallback__" + + +def serialize_attr_schema(schema: AttrSchema) -> str: + """ + Serialize an AttrSchema into a base64-encoded Arrow IPC payload. + + The primary format stores schema.default_value in the first row of a + single dummy column whose dtype is schema.dtype. This keeps dtype and + default value in one Arrow IPC payload. + """ + normalized_default = _normalize_default_for_dtype(schema.default_value, schema.dtype) + df = pl.DataFrame( + { + _ATTR_SCHEMA_VALUE_COL: pl.Series( + _ATTR_SCHEMA_VALUE_COL, + values=[normalized_default], + dtype=schema.dtype, + ), + } + ) + + buffer = io.BytesIO() + try: + df.write_ipc(buffer) + except Exception: + # Some dtypes (e.g. pl.Object) cannot roundtrip through Arrow IPC schema. + # Store pickled (dtype, default) in the first row of a binary dummy column. + fallback_payload = dumps((schema.dtype, schema.default_value)) + fallback_df = pl.DataFrame( + { + _ATTR_SCHEMA_FALLBACK_COL: pl.Series( + _ATTR_SCHEMA_FALLBACK_COL, + values=[fallback_payload], + dtype=pl.Binary, + ), + } + ) + buffer = io.BytesIO() + fallback_df.write_ipc(buffer) + + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +def deserialize_attr_schema(encoded_schema: str, *, key: str) -> AttrSchema: + """ + Deserialize an AttrSchema previously encoded by `serialize_attr_schema`. + """ + data = base64.b64decode(encoded_schema) + buffer = io.BytesIO(data) + restored_df = pl.read_ipc(buffer) + + if _ATTR_SCHEMA_VALUE_COL in restored_df.columns: + dtype = restored_df.schema[_ATTR_SCHEMA_VALUE_COL] + default_value = restored_df[_ATTR_SCHEMA_VALUE_COL][0] + elif _ATTR_SCHEMA_FALLBACK_COL in restored_df.columns: + fallback_payload = restored_df[_ATTR_SCHEMA_FALLBACK_COL][0] + if fallback_payload is None: + raise ValueError("Fallback schema payload is missing.") + dtype, default_value = loads(fallback_payload) + else: + raise ValueError("Unrecognized attr schema payload format.") + + if not pl.datatypes.is_polars_dtype(dtype): + raise TypeError(f"Decoded value is not a polars dtype: {type(dtype)}") + + default_value = _normalize_deserialized_default(default_value, dtype) + return AttrSchema(key=key, dtype=dtype, default_value=default_value) + + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. diff --git a/src/tracksdata/utils/_test/test_dtype_serialization.py b/src/tracksdata/utils/_test/test_dtype_serialization.py new file mode 100644 index 00000000..1f406224 --- /dev/null +++ b/src/tracksdata/utils/_test/test_dtype_serialization.py @@ -0,0 +1,83 @@ +import base64 +import binascii +import io + +import numpy as np +import polars as pl +import pytest + +from tracksdata.utils._dtypes import ( + AttrSchema, + deserialize_attr_schema, + serialize_attr_schema, +) + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Int64, + pl.Float32, + pl.Boolean, + pl.String, + pl.List(pl.Int16), + pl.Array(pl.Float64, 4), + pl.Array(pl.Int32, (2, 3)), + pl.Struct({"x": pl.Int64, "y": pl.List(pl.String)}), + pl.Datetime("us", "UTC"), + ], +) +def test_serialize_deserialize_attr_schema_dtype_roundtrip(dtype: pl.DataType) -> None: + schema = AttrSchema(key="dummy", dtype=dtype) + encoded = serialize_attr_schema(schema) + + assert isinstance(encoded, str) + assert encoded + assert base64.b64decode(encoded) + + restored = deserialize_attr_schema(encoded, key=schema.key) + + assert restored == schema + + +def test_deserialize_attr_schema_invalid_base64_raises() -> None: + with pytest.raises(binascii.Error): + deserialize_attr_schema("not-base64", key="dummy") + + +def test_deserialize_attr_schema_non_ipc_payload_raises() -> None: + encoded = base64.b64encode(b"not-arrow-ipc").decode("utf-8") + + with pytest.raises((OSError, pl.exceptions.PolarsError)): + deserialize_attr_schema(encoded, key="dummy") + + +@pytest.mark.parametrize( + "schema", + [ + AttrSchema(key="score", dtype=pl.Float64, default_value=1.25), + AttrSchema( + key="vector", + dtype=pl.Array(pl.Float32, 3), + default_value=np.array([1.0, 2.0, 3.0], dtype=np.float32), + ), + AttrSchema(key="payload", dtype=pl.Object, default_value={"nested": [1, 2, 3]}), + ], +) +def test_serialize_deserialize_attr_schema_roundtrip(schema: AttrSchema) -> None: + encoded = serialize_attr_schema(schema) + restored = deserialize_attr_schema(encoded, key=schema.key) + assert restored == schema + + +def test_serialize_attr_schema_stores_default_in_dummy_row() -> None: + schema = AttrSchema(key="score", dtype=pl.Float64, default_value=1.25) + encoded = serialize_attr_schema(schema) + + payload = base64.b64decode(encoded) + df = pl.read_ipc(io.BytesIO(payload)) + + assert "__attr_schema_value__" in df.columns + assert df.schema["__attr_schema_value__"] == pl.Float64 + assert df["__attr_schema_value__"][0] == 1.25 + assert "__attr_schema_dtype_pickle__" not in df.columns