diff --git a/examples/snowpark-example.ipynb b/examples/snowpark-example.ipynb index d87f51aa..45aebed6 100644 --- a/examples/snowpark-example.ipynb +++ b/examples/snowpark-example.ipynb @@ -171,7 +171,7 @@ "## Fetching the data\n", "\n", "Next we fetch our tables from Snowflake and convert them to pandas DataFrames.\n", - "Additionally, we rename the most of the table columns so that they are named according to the `neo4j-viz` API." + "Additionally, we rename some of the table columns so that they are named according to the `neo4j-viz` API." ] }, { @@ -181,16 +181,8 @@ "metadata": {}, "outputs": [], "source": [ - "products_df = (\n", - " session.table(\"products\")\n", - " .to_pandas()\n", - " .rename(columns={\"ID\": \"id\", \"NAME\": \"caption\"})\n", - ")\n", - "parents_df = (\n", - " session.table(\"parents\")\n", - " .to_pandas()\n", - " .rename(columns={\"SOURCE\": \"source\", \"TARGET\": \"target\", \"TYPE\": \"caption\"})\n", - ")" + "products_df = session.table(\"products\").to_pandas().rename(columns={\"NAME\": \"caption\"})\n", + "parents_df = session.table(\"parents\").to_pandas().rename(columns={\"TYPE\": \"caption\"})" ] }, { diff --git a/python-wrapper/src/neo4j_viz/case_insensitive_model.py b/python-wrapper/src/neo4j_viz/case_insensitive_model.py new file mode 100644 index 00000000..56ef122c --- /dev/null +++ b/python-wrapper/src/neo4j_viz/case_insensitive_model.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, model_validator +from pydantic.alias_generators import to_snake + + +class CaseInsensitiveModel(BaseModel): + @model_validator(mode="before") + def _snake_property_keys(cls, values: Any) -> Any: + def _snake(value: Any) -> Any: + if isinstance(value, dict): + return {to_snake(k): _snake(v) for k, v in value.items()} + return value + + return _snake(values) diff --git a/python-wrapper/src/neo4j_viz/node.py b/python-wrapper/src/neo4j_viz/node.py index 8718a497..629345ab 100644 --- a/python-wrapper/src/neo4j_viz/node.py +++ b/python-wrapper/src/neo4j_viz/node.py @@ -2,24 +2,26 @@ from typing import Any, Optional, Union -from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator +from pydantic import AliasChoices, Field, field_serializer, field_validator from pydantic_extra_types.color import Color, ColorType +from .case_insensitive_model import CaseInsensitiveModel from .node_size import RealNumber from .options import CaptionAlignment NodeIdType = Union[str, int] -class Node(BaseModel, extra="allow"): +class Node(CaseInsensitiveModel, extra="allow"): """ A node in a graph to visualize. All options available in the NVL library (see https://neo4j.com/docs/nvl/current/base-library/#_nodes) + All field names are case-insensitive. """ #: Unique identifier for the node id: NodeIdType = Field( - validation_alias=AliasChoices("id", "nodeId", "node_id"), description="Unique identifier for the node" + validation_alias=AliasChoices("id", "nodeid", "node_id"), description="Unique identifier for the node" ) #: The caption of the node caption: Optional[str] = Field(None, description="The caption of the node") diff --git a/python-wrapper/src/neo4j_viz/relationship.py b/python-wrapper/src/neo4j_viz/relationship.py index b5f4c640..faa4185b 100644 --- a/python-wrapper/src/neo4j_viz/relationship.py +++ b/python-wrapper/src/neo4j_viz/relationship.py @@ -3,16 +3,18 @@ from typing import Any, Optional, Union from uuid import uuid4 -from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator +from pydantic import AliasChoices, Field, field_serializer, field_validator from pydantic_extra_types.color import Color, ColorType +from .case_insensitive_model import CaseInsensitiveModel from .options import CaptionAlignment -class Relationship(BaseModel, extra="allow"): +class Relationship(CaseInsensitiveModel, extra="allow"): """ A relationship in a graph to visualize. All options available in the NVL library (see https://neo4j.com/docs/nvl/current/base-library/#_relationships) + All field names are case-insensitive. """ #: Unique identifier for the relationship @@ -22,13 +24,13 @@ class Relationship(BaseModel, extra="allow"): #: Node ID where the relationship points from source: Union[str, int] = Field( serialization_alias="from", - validation_alias=AliasChoices("source", "sourceNodeId", "source_node_id", "from"), + validation_alias=AliasChoices("source", "sourcenodeid", "source_node_id", "from"), description="Node ID where the relationship points from", ) #: Node ID where the relationship points to target: Union[str, int] = Field( serialization_alias="to", - validation_alias=AliasChoices("target", "targetNodeId", "target_node_id", "to"), + validation_alias=AliasChoices("target", "targetnodeid", "target_node_id", "to"), description="Node ID where the relationship points to", ) #: The caption of the relationship diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py index c6610a4a..3aff5607 100644 --- a/python-wrapper/tests/test_gds.py +++ b/python-wrapper/tests/test_gds.py @@ -117,7 +117,7 @@ def test_from_gds_mocked(mocker: MockerFixture) -> None: ] assert len(VG.relationships) == 3 - vg_rels = sorted([(e.source, e.target, e.relationshipType) for e in VG.relationships], key=lambda x: x[0]) # type: ignore[attr-defined] + vg_rels = sorted([(e.source, e.target, e.relationship_type) for e in VG.relationships], key=lambda x: x[0]) # type: ignore[attr-defined] assert vg_rels == [ (0, 1, "REL"), (1, 2, "REL2"), diff --git a/python-wrapper/tests/test_node.py b/python-wrapper/tests/test_node.py index faa32e5b..59b1eaad 100644 --- a/python-wrapper/tests/test_node.py +++ b/python-wrapper/tests/test_node.py @@ -59,11 +59,11 @@ def test_node_with_additional_fields() -> None: assert node.to_dict() == { "id": "1", - "componentId": 2, + "component_id": 2, } -@pytest.mark.parametrize("alias", ["id", "nodeId", "node_id"]) +@pytest.mark.parametrize("alias", ["id", "nodeId", "node_id", "ID"]) def test_id_aliases(alias: str) -> None: node = Node(**{alias: 1}) diff --git a/python-wrapper/tests/test_relationship.py b/python-wrapper/tests/test_relationship.py index 171c7af8..039c1b0c 100644 --- a/python-wrapper/tests/test_relationship.py +++ b/python-wrapper/tests/test_relationship.py @@ -47,12 +47,17 @@ def test_rels_additional_fields() -> None: ) rel_dict = rel.to_dict() - assert {"id", "from", "to", "componentId"} == set(rel_dict.keys()) - assert rel.componentId == 2 # type: ignore[attr-defined] + assert {"id", "from", "to", "component_id"} == set(rel_dict.keys()) + assert rel.component_id == 2 # type: ignore[attr-defined] -@pytest.mark.parametrize("src_alias", ["source", "sourceNodeId", "source_node_id", "from"]) -@pytest.mark.parametrize("trg_alias", ["target", "targetNodeId", "target_node_id", "to"]) +@pytest.mark.parametrize( + "src_alias", + ["source", "sourceNodeId", "source_node_id", "from", "SOURCE", "SOURCE_NODE_ID", "SOURCENODEID", "FROM"], +) +@pytest.mark.parametrize( + "trg_alias", ["target", "targetNodeId", "target_node_id", "to", "TARGET", "TARGET_NODE_ID", "TARGETNODEID", "TO"] +) def test_aliases(src_alias: str, trg_alias: str) -> None: rel = Relationship( **{