diff --git a/python-wrapper/src/neo4j_viz/pandas.py b/python-wrapper/src/neo4j_viz/pandas.py index 5c9a0d28..9197b2c5 100644 --- a/python-wrapper/src/neo4j_viz/pandas.py +++ b/python-wrapper/src/neo4j_viz/pandas.py @@ -1,64 +1,60 @@ from __future__ import annotations -from typing import Any, Optional +from collections.abc import Iterable +from typing import Optional, Union -import pandas as pd from pandas import DataFrame from .node import Node -from .node_size import verify_radii from .relationship import Relationship from .visualization_graph import VisualizationGraph +DFS_TYPE = Union[DataFrame, Iterable[DataFrame]] + def from_dfs( - node_df: DataFrame, rel_df: DataFrame, node_radius_min_max: Optional[tuple[float, float]] = (3, 60) + node_dfs: DFS_TYPE, rel_dfs: DFS_TYPE, node_radius_min_max: Optional[tuple[float, float]] = (3, 60) ) -> VisualizationGraph: """ Create a VisualizationGraph from two pandas DataFrames. Parameters ---------- - node_df : DataFrame - DataFrame containing node data. - rel_df : DataFrame - DataFrame containing relationship data. + node_dfs: Union[DataFrame, Iterable[DataFrame]] + DataFrame or iterable of DataFrames containing node data. + rel_dfs: Union[DataFrame, Iterable[DataFrame]] + DataFrame or iterable of DataFrames containing relationship data. node_radius_min_max : tuple[float, float], optional Minimum and maximum node radius. To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range. """ - if node_radius_min_max is not None and "size" in node_df.columns: - verify_radii(node_radius_min_max) - node_df["size"] = _scale_node_size( - node_df["size"], min_size=node_radius_min_max[0], max_size=node_radius_min_max[1] - ) + if isinstance(node_dfs, DataFrame): + node_dfs_iter: Iterable[DataFrame] = [node_dfs] + else: + node_dfs_iter = node_dfs + has_size = True nodes = [] - for _, row in node_df.iterrows(): - node = Node(**row.to_dict()) - nodes.append(node) - - relationships = [] - for _, row in rel_df.iterrows(): - rel = Relationship(**row.to_dict()) - relationships.append(rel) + for node_df in node_dfs_iter: + has_size &= "size" in node_df.columns + for _, row in node_df.iterrows(): + node = Node(**row.to_dict()) + nodes.append(node) - return VisualizationGraph(nodes=nodes, relationships=relationships) + if isinstance(rel_dfs, DataFrame): + rel_dfs_iter: Iterable[DataFrame] = [rel_dfs] + else: + rel_dfs_iter = rel_dfs + relationships = [] + for rel_df in rel_dfs_iter: + for _, row in rel_df.iterrows(): + rel = Relationship(**row.to_dict()) + relationships.append(rel) -def _scale_node_size(sizes: pd.Series[Any], min_size: float, max_size: float) -> pd.Series[Any]: - old_min_size = sizes.min() - old_max_size = sizes.max() - old_size_range = old_max_size - old_min_size - if abs(old_size_range) < 1e-6: - default_size = min_size + (max_size - min_size) / 2.0 - return pd.Series([default_size] * len(sizes)) - - normalized_sizes: pd.Series[Any] = (sizes - old_min_size) / old_size_range - - new_size_range = max_size - min_size + VG = VisualizationGraph(nodes=nodes, relationships=relationships) - range_scaled_sizes = normalized_sizes * new_size_range - scaled_sizes = range_scaled_sizes + min_size + if node_radius_min_max is not None and has_size: + VG.resize_nodes(node_radius_min_max=node_radius_min_max) - return scaled_sizes + return VG diff --git a/python-wrapper/tests/test_pandas.py b/python-wrapper/tests/test_pandas.py index db2f7bc8..d901b65e 100644 --- a/python-wrapper/tests/test_pandas.py +++ b/python-wrapper/tests/test_pandas.py @@ -1,4 +1,3 @@ -import pandas as pd from pandas import DataFrame from pydantic_extra_types.color import Color @@ -46,25 +45,62 @@ def test_from_df() -> None: assert VG.relationships[1].caption == "REL2" -def test_node_scaling() -> None: - from neo4j_viz.pandas import _scale_node_size - - sizes = pd.Series([0, 2, 3, 4, 10]) - min_size = 3 - max_size = 6 - - scaled_sizes = _scale_node_size(sizes, min_size, max_size) +def test_from_dfs() -> None: + nodes = [ + DataFrame( + { + "id": [0], + "caption": ["A"], + "size": [1337], + "color": "#FF0000", + } + ), + DataFrame( + { + "id": [1], + "caption": ["B"], + "size": [42], + "color": "#FF0000", + } + ), + ] + + relationships = [ + DataFrame( + { + "source": [0], + "target": [1], + "caption": ["REL"], + } + ), + DataFrame( + { + "source": [1], + "target": [0], + "caption": ["REL2"], + } + ), + ] + VG = from_dfs(nodes, relationships, node_radius_min_max=(42, 1337)) - assert scaled_sizes.equals(pd.Series([3.0, 3.6, 3.9, 4.2, 6.0])) + assert len(VG.nodes) == 2 + assert VG.nodes[0].id == 0 + assert VG.nodes[0].caption == "A" + assert VG.nodes[0].size == 1337 + assert VG.nodes[0].color == Color("#ff0000") -def test_node_scaling_constant() -> None: - from neo4j_viz.pandas import _scale_node_size + assert VG.nodes[1].id == 1 + assert VG.nodes[1].caption == "B" + assert VG.nodes[1].size == 42 + assert VG.nodes[0].color == Color("#ff0000") - sizes = pd.Series([2, 2, 2, 2, 2]) - min_size = 3 - max_size = 6 + assert len(VG.relationships) == 2 - scaled_sizes = _scale_node_size(sizes, min_size, max_size) + assert VG.relationships[0].source == 0 + assert VG.relationships[0].target == 1 + assert VG.relationships[0].caption == "REL" - assert scaled_sizes.equals(pd.Series([min_size + (max_size - min_size) / 2.0] * len(sizes))) + assert VG.relationships[1].source == 1 + assert VG.relationships[1].target == 0 + assert VG.relationships[1].caption == "REL2"