Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

## New features

* Allow passing a `neo4j.Driver` instance as input to `from_neo4j`, in which case the driver will be used internally to fetch the graph data using a simple query


## Bug fixes

Expand Down
9 changes: 6 additions & 3 deletions docs/source/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,16 @@ Once you have installed the additional dependency, you can use the :doc:`from_ne
to import query results from Neo4j.

The ``from_neo4j`` method takes one mandatory positional parameter:

* A ``result`` representing the query result either in form of `neo4j.graph.Graph` or `neo4j.Result`.
A ``data`` argument representing either a query result in the shape of a ``neo4j.graph.Graph`` or ``neo4j.Result``, or a
``neo4j.Driver`` in which case a simple default query will be executed internally to retrieve the graph data.

We can also provide an optional ``size_property`` parameter, which should refer to a node property,
and will be used to determine the sizes of the nodes in the visualization.

The ``node_caption`` and ``relationship_caption`` parameters are also optional, and indicate the node and relationship properties to use for the captions of each element in the visualization.
The ``node_caption`` and ``relationship_caption`` parameters are also optional, and indicate the node and relationship
properties to use for the captions of each element in the visualization.
By default, the captions will be set to the node labels and relationship types, but you can specify any property that
exists on these entities.

The last optional property, ``node_radius_min_max``, can be used (and is used by default) to scale the node sizes for
the visualization.
Expand Down
4 changes: 4 additions & 0 deletions python-wrapper/src/neo4j_viz/gds.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from itertools import chain
from typing import Optional
from uuid import uuid4
Expand Down Expand Up @@ -99,6 +100,9 @@ def from_gds(

node_count = G.node_count()
if node_count > max_node_count:
warnings.warn(
f"The '{G.name()}' projection's node count ({G.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed"
)
sampling_ratio = float(max_node_count) / node_count
sample_name = f"neo4j-viz_sample_{uuid4()}"
G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)
Expand Down
41 changes: 31 additions & 10 deletions python-wrapper/src/neo4j_viz/neo4j.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import warnings
from typing import Optional, Union

import neo4j.graph
from neo4j import Result
from neo4j import Driver, Result, RoutingControl
from pydantic import BaseModel, ValidationError

from neo4j_viz.node import Node
Expand All @@ -20,14 +21,15 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) ->


def from_neo4j(
result: Union[neo4j.graph.Graph, Result],
data: Union[neo4j.graph.Graph, Result, Driver],
size_property: Optional[str] = None,
node_caption: Optional[str] = "labels",
relationship_caption: Optional[str] = "type",
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
row_limit: int = 10_000,
) -> VisualizationGraph:
"""
Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object.
Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result` or Neo4j `Driver`.

All node and relationship properties will be included in the visualization graph.
If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as
Expand All @@ -36,8 +38,9 @@ def from_neo4j(

Parameters
----------
result : Union[neo4j.graph.Graph, Result]
Query result either in shape of a Graph or result.
data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.Driver]
Either a query result in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in
which case a simple default query will be executed internally to retrieve the graph data.
size_property : str, optional
Property to use for node size, by default None.
node_caption : str, optional
Expand All @@ -47,14 +50,32 @@ def from_neo4j(
node_radius_min_max : tuple[float, float], optional
Minimum and maximum node radius, by default (3, 60).
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
row_limit : int, optional
Maximum number of rows to return from the query, by default 10_000.
This is only used if a `neo4j.Driver` is passed as the `data` argument, otherwise the limit is ignored.
"""

if isinstance(result, Result):
graph = result.graph()
elif isinstance(result, neo4j.graph.Graph):
graph = result
if isinstance(data, Result):
graph = data.graph()
elif isinstance(data, neo4j.graph.Graph):
graph = data
elif isinstance(data, Driver):
rel_count = data.execute_query(
"MATCH ()-[r]->() RETURN count(r) as count",
routing_=RoutingControl.READ,
result_transformer_=Result.single,
).get("count") # type: ignore[union-attr]
if rel_count > row_limit:
warnings.warn(
f"Database relationship count ({rel_count}) exceeds `row_limit` ({row_limit}), so limiting will be applied. Increase the `row_limit` if needed"
)
graph = data.execute_query(
f"MATCH (n)-[r]->(m) RETURN n,r,m LIMIT {row_limit}",
routing_=RoutingControl.READ,
result_transformer_=Result.graph,
)
else:
raise ValueError(f"Invalid input type `{type(result)}`. Expected `neo4j.Graph` or `neo4j.Result`")
raise ValueError(f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`")

all_node_field_aliases = Node.all_validation_aliases()
all_rel_field_aliases = Relationship.all_validation_aliases()
Expand Down
9 changes: 8 additions & 1 deletion python-wrapper/tests/test_gds.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any

import pandas as pd
Expand Down Expand Up @@ -267,7 +268,13 @@ def test_from_gds_sample(gds: Any) -> None:
from neo4j_viz.gds import from_gds

with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G:
VG = from_gds(gds, G)
with pytest.warns(
UserWarning,
match=re.escape(
"The 'hello' projection's node count (11000) exceeds `max_node_count` (10000), so subsampling will be applied. Increase `max_node_count` if needed"
),
):
VG = from_gds(gds, G)

assert len(VG.nodes) >= 9_500
assert len(VG.nodes) <= 10_500
Expand Down
66 changes: 65 additions & 1 deletion python-wrapper/tests/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re
from typing import Generator

import neo4j
import pytest
from neo4j import Session
from neo4j import Driver, Session

from neo4j_viz.neo4j import from_neo4j
from neo4j_viz.node import Node
Expand Down Expand Up @@ -201,3 +202,66 @@ def test_from_neo4j_rel_error(neo4j_session: Session) -> None:
match="Error for relationship property 'caption_align' with provided input 'banana'. Reason: Input should be 'top', 'center' or 'bottom'",
):
from_neo4j(graph)


@pytest.mark.requires_neo4j_and_gds
def test_from_neo4j_graph_driver(neo4j_session: Session, neo4j_driver: Driver) -> None:
    """
    Check that `from_neo4j` accepts a `neo4j.Driver` and builds the expected visualization graph.

    The reference graph is fetched through the session so that the element ids of the
    stored nodes can be compared against the ids produced by `from_neo4j`.
    """

    def _name_of(db_node: neo4j.graph.Node) -> str:
        # Sort key: the `name` property of a node returned by the reference query
        return dict(db_node.items())["name"]

    reference_graph = neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph()

    # Note that this test requires an empty Neo4j database, as it just fetches everything
    VG = from_neo4j(neo4j_driver)

    nodes_by_name: list[neo4j.graph.Node] = sorted(reference_graph.nodes, key=_name_of)
    element_ids: list[str] = [db_node.element_id for db_node in nodes_by_name]
    first_id, second_id = element_ids[0], element_ids[1]

    alice = Node(
        id=first_id,
        caption="_CI_A",
        properties={
            "labels": ["_CI_A"],
            "name": "Alice",
            "height": 20,
            "id": 42,
            "_id": 1337,
            "caption": "hello",
        },
    )
    bob = Node(
        id=second_id,
        caption="_CI_A:_CI_B",
        size=11,
        properties={
            "labels": ["_CI_A", "_CI_B"],
            "name": "Bob",
            "height": 10,
            "id": 84,
            "__labels": [1, 2],
        },
    )

    assert len(VG.nodes) == 2
    actual_nodes = sorted(VG.nodes, key=lambda vg_node: vg_node.properties["name"])
    assert actual_nodes == [alice, bob]

    assert len(VG.relationships) == 2
    # Relationships are compared as (source, target, caption) triples, ordered by caption;
    # a missing or empty caption sorts as "foo"
    rel_triples = [(rel.source, rel.target, rel.caption) for rel in VG.relationships]
    rel_triples = sorted(rel_triples, key=lambda triple: triple[2] or "foo")
    assert rel_triples == [
        (first_id, second_id, "KNOWS"),
        (second_id, first_id, "RELATED"),
    ]


@pytest.mark.requires_neo4j_and_gds
def test_from_neo4j_graph_row_limit_warning(neo4j_session: Session, neo4j_driver: Driver) -> None:
    """
    Check that `from_neo4j` warns and limits the result when called with a driver and `row_limit=1`.

    The call is expected to emit a `UserWarning` with the exact message below and to return
    a graph containing a single relationship.
    """
    # The query result is not used afterwards.
    # NOTE(review): presumably this only checks that the test data is queryable - confirm or remove
    neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph()

    # The message is escaped because `match` is interpreted as a regular expression
    expected_message = re.escape(
        "Database relationship count (2) exceeds `row_limit` (1), so limiting will be applied. Increase the `row_limit` if needed"
    )

    with pytest.warns(UserWarning, match=expected_message):
        VG = from_neo4j(neo4j_driver, row_limit=1)

    assert len(VG.relationships) == 1