diff --git a/changelog.md b/changelog.md index 547eb770..eed2d720 100644 --- a/changelog.md +++ b/changelog.md @@ -4,7 +4,8 @@ ## Breaking changes * The `from_gds` method now fetches all node properties of a given GDS projection by default, instead of none. -* The `from_gds` now adds node labels as captions for nodes. +* The `from_gds` method now adds node labels as captions for nodes. +* The `from_gds` method now samples large graphs before fetching them by default, but this can be overridden. ## New features diff --git a/docs/source/integration.rst b/docs/source/integration.rst index 11f59013..d99091d9 100644 --- a/docs/source/integration.rst +++ b/docs/source/integration.rst @@ -105,6 +105,14 @@ The ``from_gds`` method takes two mandatory positional parameters: * An initialized ``GraphDataScience`` object for the connection to the GDS instance, and * A ``Graph`` representing the projection that one wants to import. +The optional ``max_node_count`` parameter can be used to limit the number of nodes that are imported from the +projection. +By default, it is set to 10,000, meaning that if the projection has more than 10,000 nodes, ``from_gds`` will sample +from it using random walk with restarts, to get a smaller graph that can be visualized. +If you want to have more control over the sampling, such as choosing a specific start node for the sample, you can call +a `sampling `_ +method yourself and pass the resulting projection to ``from_gds``. + We can also provide an optional ``size_property`` parameter, which should refer to a node property of the projection, and will be used to determine the sizes of the nodes in the visualization. 
diff --git a/python-wrapper/src/neo4j_viz/gds.py b/python-wrapper/src/neo4j_viz/gds.py index e1ec024b..e952e605 100644 --- a/python-wrapper/src/neo4j_viz/gds.py +++ b/python-wrapper/src/neo4j_viz/gds.py @@ -2,15 +2,17 @@ from itertools import chain from typing import Optional +from uuid import uuid4 import pandas as pd from graphdatascience import Graph, GraphDataScience +from pandas import Series from .pandas import _from_dfs from .visualization_graph import VisualizationGraph -def _node_dfs( +def _fetch_node_dfs( gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str] ) -> dict[str, pd.DataFrame]: return { @@ -21,17 +23,17 @@ def _node_dfs( } -def _rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame: +def _fetch_rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame: relationship_properties = G.relationship_properties() + assert isinstance(relationship_properties, Series) - if len(relationship_properties) > 0: - if isinstance(relationship_properties, pd.Series): - relationship_properties_per_type = relationship_properties.tolist() - property_set: set[str] = set() - for props in relationship_properties_per_type: - if props: - property_set.update(props) + relationship_properties_per_type = relationship_properties.tolist() + property_set: set[str] = set() + for props in relationship_properties_per_type: + if props: + property_set.update(props) + if len(property_set) > 0: return gds.graph.relationshipProperties.stream( G, relationship_properties=list(property_set), separate_property_columns=True ) @@ -45,6 +47,7 @@ def from_gds( size_property: Optional[str] = None, additional_node_properties: Optional[list[str]] = None, node_radius_min_max: Optional[tuple[float, float]] = (3, 60), + max_node_count: int = 10_000, ) -> VisualizationGraph: """ Create a VisualizationGraph from a GraphDataScience object and a Graph object. 
@@ -68,6 +71,9 @@ def from_gds( node_radius_min_max : tuple[float, float], optional Minimum and maximum node radius, by default (3, 60). To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range. + max_node_count : int, optional + The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts + if its node count exceeds this number. """ node_properties_from_gds = G.node_properties() assert isinstance(node_properties_from_gds, pd.Series) @@ -86,14 +92,40 @@ def from_gds( node_properties = set() if additional_node_properties is not None: node_properties.update(additional_node_properties) - if size_property is not None: node_properties.add(size_property) - node_properties = list(node_properties) - node_dfs = _node_dfs(gds, G, node_properties, G.node_labels()) + + node_count = G.node_count() + if node_count > max_node_count: + sampling_ratio = float(max_node_count) / node_count + sample_name = f"neo4j-viz_sample_{uuid4()}" + G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True) + else: + G_fetched = G + + property_name = None + try: + # Since GDS does not allow us to only fetch node IDs, we add the degree property + # as a temporary property to ensure that we have at least one property to fetch + if len(actual_node_properties) == 0: + property_name = f"neo4j-viz_property_{uuid4()}" + gds.degree.mutate(G_fetched, mutateProperty=property_name) + node_properties = [property_name] + + node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels()) + rel_df = _fetch_rel_df(gds, G_fetched) + finally: + if G_fetched.name() != G.name(): + G_fetched.drop() + elif property_name is not None: + gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name]) + for df in node_dfs.values(): df.rename(columns={"nodeId": "id"}, inplace=True) + if property_name is not None and property_name in df.columns: + 
df.drop(columns=[property_name], inplace=True) + rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True) node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates() if size_property is not None: @@ -114,9 +146,6 @@ def from_gds( if "caption" not in actual_node_properties: node_df["caption"] = node_df["labels"].astype(str) - rel_df = _rel_df(gds, G) - rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True) - try: return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}) except ValueError as e: diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py index ff5da5e3..abec75b9 100644 --- a/python-wrapper/tests/test_gds.py +++ b/python-wrapper/tests/test_gds.py @@ -170,9 +170,10 @@ def test_from_gds_mocked(mocker: MockerFixture) -> None: lambda x: pd.Series({lbl: node_properties for lbl in nodes.keys()}), ) mocker.patch("graphdatascience.Graph.node_labels", lambda x: list(nodes.keys())) + mocker.patch("graphdatascience.Graph.node_count", lambda x: sum(len(df) for df in nodes.values())) mocker.patch("graphdatascience.GraphDataScience.__init__", lambda x: None) - mocker.patch("neo4j_viz.gds._node_dfs", return_value=nodes) - mocker.patch("neo4j_viz.gds._rel_df", return_value=rels) + mocker.patch("neo4j_viz.gds._fetch_node_dfs", return_value=nodes) + mocker.patch("neo4j_viz.gds._fetch_rel_df", return_value=rels) gds = GraphDataScience() # type: ignore[call-arg] G = Graph() # type: ignore[call-arg] @@ -244,3 +245,16 @@ def test_from_gds_node_errors(gds: Any) -> None: additional_node_properties=["component", "size"], node_radius_min_max=None, ) + + +@pytest.mark.requires_neo4j_and_gds +def test_from_gds_sample(gds: Any) -> None: + from neo4j_viz.gds import from_gds + + with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G: + VG = from_gds(gds, G) + + assert len(VG.nodes) >= 9_500 + 
assert len(VG.nodes) <= 10_500 + assert len(VG.relationships) >= 9_500 + assert len(VG.relationships) <= 10_500