Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
## Breaking changes

* The `from_gds` method now fetches all node properties of a given GDS projection by default, instead of none.
* The `from_gds` now adds node labels as captions for nodes.
* The `from_gds` method now adds node labels as captions for nodes.
* The `from_gds` method now samples large graphs before fetching them by default, but this can be overridden.


## New features
Expand Down
8 changes: 8 additions & 0 deletions docs/source/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ The ``from_gds`` method takes two mandatory positional parameters:
* An initialized ``GraphDataScience`` object for the connection to the GDS instance, and
* A ``Graph`` representing the projection that one wants to import.

The optional ``max_node_count`` parameter can be used to limit the number of nodes that are imported from the
projection.
By default, it is set to 10.000, meaning that if the projection has more than 10.000 nodes, ``from_gds`` will sample
from it using random walk with restarts, to get a smaller graph that can be visualized.
If you want to have more control of the sampling, such as choosing a specific start node for the sample, you can call
a `sampling <https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/sampling/>`_
method yourself and passing the resulting projection to ``from_gds``.

We can also provide an optional ``size_property`` parameter, which should refer to a node property of the projection,
and will be used to determine the sizes of the nodes in the visualization.

Expand Down
59 changes: 44 additions & 15 deletions python-wrapper/src/neo4j_viz/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

from itertools import chain
from typing import Optional
from uuid import uuid4

import pandas as pd
from graphdatascience import Graph, GraphDataScience
from pandas import Series

from .pandas import _from_dfs
from .visualization_graph import VisualizationGraph


def _node_dfs(
def _fetch_node_dfs(
gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str]
) -> dict[str, pd.DataFrame]:
return {
Expand All @@ -21,17 +23,17 @@ def _node_dfs(
}


def _rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
def _fetch_rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
relationship_properties = G.relationship_properties()
assert isinstance(relationship_properties, Series)

if len(relationship_properties) > 0:
if isinstance(relationship_properties, pd.Series):
relationship_properties_per_type = relationship_properties.tolist()
property_set: set[str] = set()
for props in relationship_properties_per_type:
if props:
property_set.update(props)
relationship_properties_per_type = relationship_properties.tolist()
property_set: set[str] = set()
for props in relationship_properties_per_type:
if props:
property_set.update(props)

if len(property_set) > 0:
return gds.graph.relationshipProperties.stream(
G, relationship_properties=list(property_set), separate_property_columns=True
)
Expand All @@ -45,6 +47,7 @@ def from_gds(
size_property: Optional[str] = None,
additional_node_properties: Optional[list[str]] = None,
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
max_node_count: int = 10_000,
) -> VisualizationGraph:
"""
Create a VisualizationGraph from a GraphDataScience object and a Graph object.
Expand All @@ -68,6 +71,9 @@ def from_gds(
node_radius_min_max : tuple[float, float], optional
Minimum and maximum node radius, by default (3, 60).
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
max_node_count : int, optional
The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts
if its node count exceeds this number.
"""
node_properties_from_gds = G.node_properties()
assert isinstance(node_properties_from_gds, pd.Series)
Expand All @@ -86,14 +92,40 @@ def from_gds(
node_properties = set()
if additional_node_properties is not None:
node_properties.update(additional_node_properties)

if size_property is not None:
node_properties.add(size_property)

node_properties = list(node_properties)
node_dfs = _node_dfs(gds, G, node_properties, G.node_labels())

node_count = G.node_count()
if node_count > max_node_count:
sampling_ratio = float(max_node_count) / node_count
sample_name = f"neo4j-viz_sample_{uuid4()}"
G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)
else:
G_fetched = G

property_name = None
try:
# Since GDS does not allow us to only fetch node IDs, we add the degree property
# as a temporary property to ensure that we have at least one property to fetch
if len(actual_node_properties) == 0:
property_name = f"neo4j-viz_property_{uuid4()}"
gds.degree.mutate(G_fetched, mutateProperty=property_name)
node_properties = [property_name]

node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels())
rel_df = _fetch_rel_df(gds, G_fetched)
finally:
if G_fetched.name() != G.name():
G_fetched.drop()
elif property_name is not None:
gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name])

for df in node_dfs.values():
df.rename(columns={"nodeId": "id"}, inplace=True)
if property_name is not None and property_name in df.columns:
df.drop(columns=[property_name], inplace=True)
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)

node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates()
if size_property is not None:
Expand All @@ -114,9 +146,6 @@ def from_gds(
if "caption" not in actual_node_properties:
node_df["caption"] = node_df["labels"].astype(str)

rel_df = _rel_df(gds, G)
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)

try:
return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"})
except ValueError as e:
Expand Down
18 changes: 16 additions & 2 deletions python-wrapper/tests/test_gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,10 @@ def test_from_gds_mocked(mocker: MockerFixture) -> None:
lambda x: pd.Series({lbl: node_properties for lbl in nodes.keys()}),
)
mocker.patch("graphdatascience.Graph.node_labels", lambda x: list(nodes.keys()))
mocker.patch("graphdatascience.Graph.node_count", lambda x: sum(len(df) for df in nodes.values()))
mocker.patch("graphdatascience.GraphDataScience.__init__", lambda x: None)
mocker.patch("neo4j_viz.gds._node_dfs", return_value=nodes)
mocker.patch("neo4j_viz.gds._rel_df", return_value=rels)
mocker.patch("neo4j_viz.gds._fetch_node_dfs", return_value=nodes)
mocker.patch("neo4j_viz.gds._fetch_rel_df", return_value=rels)

gds = GraphDataScience() # type: ignore[call-arg]
G = Graph() # type: ignore[call-arg]
Expand Down Expand Up @@ -244,3 +245,16 @@ def test_from_gds_node_errors(gds: Any) -> None:
additional_node_properties=["component", "size"],
node_radius_min_max=None,
)


@pytest.mark.requires_neo4j_and_gds
def test_from_gds_sample(gds: Any) -> None:
from neo4j_viz.gds import from_gds

with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G:
VG = from_gds(gds, G)

assert len(VG.nodes) >= 9_500
assert len(VG.nodes) <= 10_500
assert len(VG.relationships) >= 9_500
assert len(VG.relationships) <= 10_500