diff --git a/changelog.md b/changelog.md index 54e3a0b9..762168aa 100644 --- a/changelog.md +++ b/changelog.md @@ -13,6 +13,9 @@ ## Bug fixes +* Make sure that temporary internal node properties are not included in the visualization output +* Fixed bug where loading a graph with `from_gds` would result in an error if node or relationship properties were not present on every entity + ## Improvements diff --git a/python-wrapper/src/neo4j_viz/gds.py b/python-wrapper/src/neo4j_viz/gds.py index 338de75c..4aba8ec4 100644 --- a/python-wrapper/src/neo4j_viz/gds.py +++ b/python-wrapper/src/neo4j_viz/gds.py @@ -14,11 +14,11 @@ def _fetch_node_dfs( - gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str] + gds: GraphDataScience, G: Graph, node_properties_by_label: dict[str, list[str]], node_labels: list[str] ) -> dict[str, pd.DataFrame]: return { lbl: gds.graph.nodeProperties.stream( - G, node_properties=node_properties, node_labels=[lbl], separate_property_columns=True + G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], separate_property_columns=True ) for lbl in node_labels } @@ -79,24 +79,31 @@ def from_gds( """ node_properties_from_gds = G.node_properties() assert isinstance(node_properties_from_gds, pd.Series) - actual_node_properties = list(chain.from_iterable(node_properties_from_gds.to_dict().values())) + actual_node_properties = node_properties_from_gds.to_dict() + all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values())) - if size_property is not None and size_property not in actual_node_properties: - raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'") + if size_property is not None: + if size_property not in all_actual_node_properties: + raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'") if additional_node_properties is None: - additional_node_properties = actual_node_properties + node_properties_by_label = {k: 
set(v) for k, v in actual_node_properties.items()} else: for prop in additional_node_properties: - if prop not in actual_node_properties: + if prop not in all_actual_node_properties: raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'") - node_properties = set() - if additional_node_properties is not None: - node_properties.update(additional_node_properties) + node_properties_by_label = {} + for label, props in actual_node_properties.items(): + node_properties_by_label[label] = { + prop for prop in actual_node_properties[label] if prop in additional_node_properties + } + if size_property is not None: - node_properties.add(size_property) - node_properties = list(node_properties) + for label, props in node_properties_by_label.items(): + props.add(size_property) + + node_properties_by_label = {k: list(v) for k, v in node_properties_by_label.items()} node_count = G.node_count() if node_count > max_node_count: @@ -112,13 +119,18 @@ def from_gds( property_name = None try: # Since GDS does not allow us to only fetch node IDs, we add the degree property - # as a temporary property to ensure that we have at least one property to fetch - if len(actual_node_properties) == 0: + # as a temporary property to ensure that we have at least one property for each label to fetch + if sum([len(props) == 0 for props in node_properties_by_label.values()]) > 0: property_name = f"neo4j-viz_property_{uuid4()}" gds.degree.mutate(G_fetched, mutateProperty=property_name) - node_properties = [property_name] + for props in node_properties_by_label.values(): + props.append(property_name) + + node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties_by_label, G_fetched.node_labels()) + if property_name is not None: + for df in node_dfs.values(): + df.drop(columns=[property_name], inplace=True) - node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels()) rel_df = _fetch_rel_df(gds, G_fetched) finally: if G_fetched.name() != G.name(): @@ -127,35 
+139,35 @@ def from_gds( gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name]) for df in node_dfs.values(): - df.rename(columns={"nodeId": "id"}, inplace=True) if property_name is not None and property_name in df.columns: df.drop(columns=[property_name], inplace=True) - rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True) node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates() if size_property is not None: - if "size" in actual_node_properties and size_property != "size": + if "size" in all_actual_node_properties and size_property != "size": node_props_df.rename(columns={"size": "__size"}, inplace=True) node_props_df.rename(columns={size_property: "size"}, inplace=True) for lbl, df in node_dfs.items(): - if "labels" in actual_node_properties: + if "labels" in all_actual_node_properties: df.rename(columns={"labels": "__labels"}, inplace=True) df["labels"] = lbl - node_labels_df = pd.concat([df[["id", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0) - node_labels_df = node_labels_df.groupby("id").agg({"labels": list}) + node_labels_df = pd.concat([df[["nodeId", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0) + node_labels_df = node_labels_df.groupby("nodeId").agg({"labels": list}) - node_df = node_props_df.merge(node_labels_df, on="id") + node_df = node_props_df.merge(node_labels_df, on="nodeId") - if "caption" not in actual_node_properties: + if "caption" not in all_actual_node_properties: node_df["caption"] = node_df["labels"].astype(str) if "caption" not in rel_df.columns: rel_df["caption"] = rel_df["relationshipType"] try: - return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}) + return _from_dfs( + node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True + ) except ValueError as e: err_msg = str(e) if "column" in err_msg: diff 
--git a/python-wrapper/src/neo4j_viz/pandas.py b/python-wrapper/src/neo4j_viz/pandas.py index 15e29c0e..b07d9c39 100644 --- a/python-wrapper/src/neo4j_viz/pandas.py +++ b/python-wrapper/src/neo4j_viz/pandas.py @@ -31,8 +31,9 @@ def _from_dfs( rel_dfs: DFS_TYPE, node_radius_min_max: Optional[tuple[float, float]] = (3, 60), rename_properties: Optional[dict[str, str]] = None, + dropna: bool = False, ) -> VisualizationGraph: - relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties) + relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties, dropna=dropna) if node_dfs is None: has_size = False @@ -42,7 +43,7 @@ def _from_dfs( node_ids.add(rel.target) nodes = [Node(id=id) for id in node_ids] else: - nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties) + nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties, dropna=dropna) VG = VisualizationGraph(nodes=nodes, relationships=relationships) @@ -52,7 +53,9 @@ def _from_dfs( return VG -def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> tuple[list[Node], bool]: +def _parse_nodes( + node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False +) -> tuple[list[Node], bool]: if isinstance(node_dfs, DataFrame): node_dfs_iter: Iterable[DataFrame] = [node_dfs] elif node_dfs is None: @@ -67,6 +70,8 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]] for node_df in node_dfs_iter: has_size &= "size" in node_df.columns for _, row in node_df.iterrows(): + if dropna: + row = row.dropna(inplace=False) top_level = {} properties = {} for key, value in row.to_dict().items(): @@ -85,7 +90,9 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]] return nodes, has_size -def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> list[Relationship]: +def _parse_relationships( + rel_dfs: DFS_TYPE, 
rename_properties: Optional[dict[str, str]], dropna: bool = False +) -> list[Relationship]: all_rel_field_aliases = Relationship.all_validation_aliases() if isinstance(rel_dfs, DataFrame): @@ -96,6 +103,8 @@ def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str for rel_df in rel_dfs_iter: for _, row in rel_df.iterrows(): + if dropna: + row = row.dropna(inplace=False) top_level = {} properties = {} for key, value in row.to_dict().items(): @@ -138,4 +147,4 @@ def from_dfs( To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range. """ - return _from_dfs(node_dfs, rel_dfs, node_radius_min_max) + return _from_dfs(node_dfs, rel_dfs, node_radius_min_max, dropna=False) diff --git a/python-wrapper/tests/conftest.py b/python-wrapper/tests/conftest.py index 081a7007..3f7aa6c7 100644 --- a/python-wrapper/tests/conftest.py +++ b/python-wrapper/tests/conftest.py @@ -43,7 +43,9 @@ def aura_ds_instance() -> Generator[Any, None, None]: # setting as environment variables to run notebooks with this connection os.environ["NEO4J_URI"] = dbms_connection_info.uri + assert isinstance(dbms_connection_info.username, str) os.environ["NEO4J_USER"] = dbms_connection_info.username + assert isinstance(dbms_connection_info.password, str) os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password yield dbms_connection_info diff --git a/python-wrapper/tests/gds_helper.py b/python-wrapper/tests/gds_helper.py index e5fa270d..e5a0d3dc 100644 --- a/python-wrapper/tests/gds_helper.py +++ b/python-wrapper/tests/gds_helper.py @@ -62,8 +62,6 @@ def create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]: if wait_result.error: raise Exception(f"Error while waiting for instance to be running: {wait_result.error}") - wait_result.connection_url - return instance_details.id, DbmsConnectionInfo( uri=wait_result.connection_url, username="neo4j", diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py 
index fda1caf3..75f87471 100644 --- a/python-wrapper/tests/test_gds.py +++ b/python-wrapper/tests/test_gds.py @@ -276,7 +276,71 @@ def test_from_gds_sample(gds: Any) -> None: ): VG = from_gds(gds, G) + # Make sure internal temporary properties are not present + assert set(VG.nodes[0].properties.keys()) == {"labels"} + assert len(VG.nodes) >= 9_500 assert len(VG.nodes) <= 10_500 assert len(VG.relationships) >= 9_500 assert len(VG.relationships) <= 10_500 + + +@pytest.mark.requires_neo4j_and_gds +def test_from_gds_hetero(gds: Any) -> None: + from neo4j_viz.gds import from_gds + + A_nodes = pd.DataFrame( + { + "nodeId": [0, 1], + "labels": ["A", "A"], + "component": [1, 2], + } + ) + B_nodes = pd.DataFrame( + { + "nodeId": [2, 3], + "labels": ["B", "B"], + # No 'component' property + } + ) + rels = pd.DataFrame( + { + "sourceNodeId": [0, 1], + "targetNodeId": [2, 3], + "weight": [0.5, 1.5], + "relationshipType": ["REL", "REL2"], + } + ) + + with gds.graph.construct("flo", [A_nodes, B_nodes], rels) as G: + VG = from_gds( + gds, + G, + ) + + assert len(VG.nodes) == 4 + assert sorted(VG.nodes, key=lambda x: x.id) == [ + Node(id=0, caption="['A']", properties=dict(labels=["A"], component=float(1))), + Node(id=1, caption="['A']", properties=dict(labels=["A"], component=float(2))), + Node(id=2, caption="['B']", properties=dict(labels=["B"])), + Node(id=3, caption="['B']", properties=dict(labels=["B"])), + ] + + assert len(VG.relationships) == 2 + vg_rels = sorted( + [ + ( + e.source, + e.target, + e.caption, + e.properties["relationshipType"], + e.properties["weight"], + ) + for e in VG.relationships + ], + key=lambda x: x[0], + ) + assert vg_rels == [ + (0, 2, "REL", "REL", 0.5), + (1, 3, "REL2", "REL2", 1.5), + ]