Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions examples/snowpark-example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@
"## Fetching the data\n",
"\n",
"Next we fetch our tables from Snowflake and convert them to pandas DataFrames.\n",
"Additionally, we rename the most of the table columns so that they are named according to the `neo4j-viz` API."
"Additionally, we rename some of the table columns so that they are named according to the `neo4j-viz` API."
]
},
{
Expand All @@ -181,16 +181,8 @@
"metadata": {},
"outputs": [],
"source": [
"products_df = (\n",
" session.table(\"products\")\n",
" .to_pandas()\n",
" .rename(columns={\"ID\": \"id\", \"NAME\": \"caption\"})\n",
")\n",
"parents_df = (\n",
" session.table(\"parents\")\n",
" .to_pandas()\n",
" .rename(columns={\"SOURCE\": \"source\", \"TARGET\": \"target\", \"TYPE\": \"caption\"})\n",
")"
"products_df = session.table(\"products\").to_pandas().rename(columns={\"NAME\": \"caption\"})\n",
"parents_df = session.table(\"parents\").to_pandas().rename(columns={\"TYPE\": \"caption\"})"
]
},
{
Expand Down
17 changes: 17 additions & 0 deletions python-wrapper/src/neo4j_viz/case_insensitive_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from typing import Any

from pydantic import BaseModel, model_validator
from pydantic.alias_generators import to_snake


class CaseInsensitiveModel(BaseModel):
@model_validator(mode="before")
def _snake_property_keys(cls, values: Any) -> Any:
def _snake(value: Any) -> Any:
if isinstance(value, dict):
return {to_snake(k): _snake(v) for k, v in value.items()}
return value

return _snake(values)
8 changes: 5 additions & 3 deletions python-wrapper/src/neo4j_viz/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,26 @@

from typing import Any, Optional, Union

from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator
from pydantic import AliasChoices, Field, field_serializer, field_validator
from pydantic_extra_types.color import Color, ColorType

from .case_insensitive_model import CaseInsensitiveModel
from .node_size import RealNumber
from .options import CaptionAlignment

NodeIdType = Union[str, int]


class Node(BaseModel, extra="allow"):
class Node(CaseInsensitiveModel, extra="allow"):
"""
A node in a graph to visualize.
All options available in the NVL library (see https://neo4j.com/docs/nvl/current/base-library/#_nodes)
All field names are case-insensitive.
"""

#: Unique identifier for the node
id: NodeIdType = Field(
validation_alias=AliasChoices("id", "nodeId", "node_id"), description="Unique identifier for the node"
validation_alias=AliasChoices("id", "nodeid", "node_id"), description="Unique identifier for the node"
)
#: The caption of the node
caption: Optional[str] = Field(None, description="The caption of the node")
Expand Down
10 changes: 6 additions & 4 deletions python-wrapper/src/neo4j_viz/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
from typing import Any, Optional, Union
from uuid import uuid4

from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator
from pydantic import AliasChoices, Field, field_serializer, field_validator
from pydantic_extra_types.color import Color, ColorType

from .case_insensitive_model import CaseInsensitiveModel
from .options import CaptionAlignment


class Relationship(BaseModel, extra="allow"):
class Relationship(CaseInsensitiveModel, extra="allow"):
"""
A relationship in a graph to visualize.
All options available in the NVL library (see https://neo4j.com/docs/nvl/current/base-library/#_relationships)
All field names are case-insensitive.
"""

#: Unique identifier for the relationship
Expand All @@ -22,13 +24,13 @@ class Relationship(BaseModel, extra="allow"):
#: Node ID where the relationship points from
source: Union[str, int] = Field(
serialization_alias="from",
validation_alias=AliasChoices("source", "sourceNodeId", "source_node_id", "from"),
validation_alias=AliasChoices("source", "sourcenodeid", "source_node_id", "from"),
description="Node ID where the relationship points from",
)
#: Node ID where the relationship points to
target: Union[str, int] = Field(
serialization_alias="to",
validation_alias=AliasChoices("target", "targetNodeId", "target_node_id", "to"),
validation_alias=AliasChoices("target", "targetnodeid", "target_node_id", "to"),
description="Node ID where the relationship points to",
)
#: The caption of the relationship
Expand Down
2 changes: 1 addition & 1 deletion python-wrapper/tests/test_gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_from_gds_mocked(mocker: MockerFixture) -> None:
]

assert len(VG.relationships) == 3
vg_rels = sorted([(e.source, e.target, e.relationshipType) for e in VG.relationships], key=lambda x: x[0]) # type: ignore[attr-defined]
vg_rels = sorted([(e.source, e.target, e.relationship_type) for e in VG.relationships], key=lambda x: x[0]) # type: ignore[attr-defined]
assert vg_rels == [
(0, 1, "REL"),
(1, 2, "REL2"),
Expand Down
4 changes: 2 additions & 2 deletions python-wrapper/tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def test_node_with_additional_fields() -> None:

assert node.to_dict() == {
"id": "1",
"componentId": 2,
"component_id": 2,
}


@pytest.mark.parametrize("alias", ["id", "nodeId", "node_id"])
@pytest.mark.parametrize("alias", ["id", "nodeId", "node_id", "ID"])
def test_id_aliases(alias: str) -> None:
node = Node(**{alias: 1})

Expand Down
13 changes: 9 additions & 4 deletions python-wrapper/tests/test_relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ def test_rels_additional_fields() -> None:
)

rel_dict = rel.to_dict()
assert {"id", "from", "to", "componentId"} == set(rel_dict.keys())
assert rel.componentId == 2 # type: ignore[attr-defined]
assert {"id", "from", "to", "component_id"} == set(rel_dict.keys())
assert rel.component_id == 2 # type: ignore[attr-defined]


@pytest.mark.parametrize("src_alias", ["source", "sourceNodeId", "source_node_id", "from"])
@pytest.mark.parametrize("trg_alias", ["target", "targetNodeId", "target_node_id", "to"])
@pytest.mark.parametrize(
"src_alias",
["source", "sourceNodeId", "source_node_id", "from", "SOURCE", "SOURCE_NODE_ID", "SOURCENODEID", "FROM"],
)
@pytest.mark.parametrize(
"trg_alias", ["target", "targetNodeId", "target_node_id", "to", "TARGET", "TARGET_NODE_ID", "TARGETNODEID", "TO"]
)
def test_aliases(src_alias: str, trg_alias: str) -> None:
rel = Relationship(
**{
Expand Down