diff --git a/python-wrapper/src/neo4j_viz/node.py b/python-wrapper/src/neo4j_viz/node.py index b322c262..8c7cad0f 100644 --- a/python-wrapper/src/neo4j_viz/node.py +++ b/python-wrapper/src/neo4j_viz/node.py @@ -5,8 +5,11 @@ from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator from pydantic_extra_types.color import Color, ColorType +from .node_size import RealNumber from .options import CaptionAlignment +NodeIdType = Union[str, int] + class Node(BaseModel, extra="allow"): """ @@ -14,7 +17,7 @@ class Node(BaseModel, extra="allow"): All options available in the NVL library (see https://neo4j.com/docs/nvl/current/base-library/#_nodes) """ - id: Union[str, int] = Field( + id: NodeIdType = Field( validation_alias=AliasChoices("id", "nodeId", "node_id"), description="Unique identifier for the node" ) caption: Optional[str] = Field(None, description="The caption of the node") @@ -28,7 +31,7 @@ class Node(BaseModel, extra="allow"): serialization_alias="captionSize", description="The size of the caption text. 
RealNumber = Union[int, float]


def verify_radii(node_radius_min_max: tuple[RealNumber, RealNumber]) -> None:
    """
    Validate a ``(minimum, maximum)`` node radius range.

    Parameters
    ----------
    node_radius_min_max:
        Tuple holding the minimum and the maximum node radius, in that order.

    Raises
    ------
    ValueError
        If the argument is not a tuple of exactly two values, if either value is not a real number
        (booleans are rejected), if either value is NaN or infinite, if either value is negative, or
        if the minimum is larger than the maximum.
    """
    # Function-scope import so the check does not depend on the module's import header.
    import math

    if not isinstance(node_radius_min_max, tuple) or len(node_radius_min_max) != 2:
        raise ValueError(f"`node_radius_min_max` must be a tuple of two values, but was {node_radius_min_max}")

    min_size, max_size = node_radius_min_max
    labelled = (("Minimum", min_size), ("Maximum", max_size))

    # Each kind of check runs over both values before the next kind starts, so the error that
    # surfaces for doubly-invalid input stays the same as before (type errors win over range errors).
    for label, value in labelled:
        # `bool` is a subclass of `int`; `(False, True)` is never a meaningful radius range.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"{label} node size must be a real number, but was of type {type(value)}")

    for label, value in labelled:
        # NaN compares False against everything and would slip through the checks below.
        # Only floats can be non-finite; skipping ints avoids overflow on very large integers.
        if isinstance(value, float) and not math.isfinite(value):
            raise ValueError(f"{label} node size must be a finite number, but was {value}")

    for label, value in labelled:
        if value < 0:
            raise ValueError(f"{label} node size must be non-negative, but was {value}")

    if min_size > max_size:
        raise ValueError(
            f"Minimum node size must be less than or equal to maximum node size, but was {min_size} > {max_size}"
        )
def resize_nodes(
    self,
    sizes: Optional[dict[NodeIdType, RealNumber]] = None,
    node_radius_min_max: Optional[tuple[RealNumber, RealNumber]] = (3, 60),
) -> None:
    """
    Resize the nodes in the graph.

    Parameters
    ----------
    sizes:
        A dictionary mapping from node ID to the new size of the node.
        If a node ID is not in the dictionary, the size of the node is not changed.
    node_radius_min_max:
        Minimum and maximum node size radius as a tuple. To avoid tiny or huge nodes in the visualization, the
        node sizes are scaled to fit in the given range. If None, the sizes are used as is.

    Raises
    ------
    ValueError
        If both `sizes` and `node_radius_min_max` are None, if a size given for a node of the graph is not a
        non-negative real number, or if `node_radius_min_max` is rejected by `verify_radii`.

    Notes
    -----
    Nodes that have no size, neither in `sizes` nor already set on the node, are left untouched. If no node
    has a size at all there is nothing to scale and the call is a no-op.
    """
    if sizes is None and node_radius_min_max is None:
        raise ValueError("At least one of `sizes` and `node_radius_min_max` must be given")

    requested = {} if sizes is None else sizes

    # Gather and verify all node size values we have to work with. Nothing is written to the nodes
    # until every value has been validated, so a bad entry never leaves the graph half-updated.
    all_sizes: dict[NodeIdType, RealNumber] = {}
    for node in self.nodes:
        size = requested.get(node.id)

        if size is None:
            # No new size requested: fall back to the node's current size, if it has one.
            if node.size is not None:
                all_sizes[node.id] = node.size
            continue

        if not isinstance(size, (int, float)):
            raise ValueError(f"Size for node '{node.id}' must be a real number, but was {size}")

        if size < 0:
            raise ValueError(f"Size for node '{node.id}' must be non-negative, but was {size}")

        all_sizes[node.id] = size

    if node_radius_min_max is None:
        final_sizes = all_sizes
    else:
        verify_radii(node_radius_min_max)

        if not all_sizes:
            # No node carries a size, so min()/max() below would fail on an empty sequence.
            return

        unscaled_min_size = min(all_sizes.values())
        unscaled_max_size = max(all_sizes.values())
        unscaled_size_range = float(unscaled_max_size - unscaled_min_size)

        new_min_size, new_max_size = node_radius_min_max
        new_size_range = new_max_size - new_min_size

        if abs(unscaled_size_range) < 1e-6:
            # All sizes are (nearly) equal: linear scaling would divide by ~0, so use the midpoint.
            default_node_size = new_min_size + new_size_range / 2.0
            final_sizes = {node_id: default_node_size for node_id in all_sizes}
        else:
            final_sizes = {
                node_id: new_min_size + new_size_range * ((value - unscaled_min_size) / unscaled_size_range)
                for node_id, value in all_sizes.items()
            }

    for node in self.nodes:
        if node.id in final_sizes:
            node.size = final_sizes[node.id]
# Added to tests/test_pandas.py
def test_node_scaling_constant() -> None:
    """Identical input sizes all map to the midpoint of the requested range."""
    from neo4j_viz.pandas import _scale_node_size

    lower, upper = 3, 6
    constant_sizes = pd.Series([2, 2, 2, 2, 2])
    expected = pd.Series([lower + (upper - lower) / 2.0] * len(constant_sizes))

    assert _scale_node_size(constant_sizes, lower, upper).equals(expected)


# tests/test_sizes.py
def test_verify_radii() -> None:
    """Every malformed radius range is rejected with its specific message; a valid one passes."""
    rejected = [
        (3, "`node_radius_min_max` must be a tuple of two values, but was 3"),
        ((1, 2, 3), re.escape("`node_radius_min_max` must be a tuple of two values, but was (1, 2, 3)")),
        (("1", 2), "Minimum node size must be a real number, but was of type "),
        ((1, "2"), "Maximum node size must be a real number, but was of type "),
        ((-1, 2), "Minimum node size must be non-negative, but was -1"),
        ((1, -2), "Maximum node size must be non-negative, but was -2"),
        ((2, 1), "Minimum node size must be less than or equal to maximum node size, but was 2 > 1"),
    ]
    for radii, pattern in rejected:
        with pytest.raises(ValueError, match=pattern):
            verify_radii(radii)  # type: ignore

    # This should not raise an exception
    verify_radii((1, 2))


def test_resize_nodes_no_scaling() -> None:
    """Without a radius range the given sizes are applied verbatim; negative sizes are rejected."""
    graph = VisualizationGraph(
        nodes=[
            Node(id=42),
            Node(id="1337", size=10),
        ],
        relationships=[],
    )

    first_update: dict[NodeIdType, RealNumber] = {"1337": 20}
    graph.resize_nodes(first_update, None)
    assert graph.nodes[0].size is None
    assert graph.nodes[1].size == 20

    second_update: dict[NodeIdType, RealNumber] = {42: 8.1, "1337": 3}
    graph.resize_nodes(second_update, None)
    assert graph.nodes[0].size == 8.1
    assert graph.nodes[1].size == 3

    negative_update: dict[NodeIdType, RealNumber] = {42: -4.2}
    with pytest.raises(ValueError, match="Size for node '42' must be non-negative, but was -4.2"):
        graph.resize_nodes(negative_update, None)


def test_resize_nodes_with_scaling_constant() -> None:
    """A single sized node has min == max, so it lands on the midpoint of the range."""
    graph = VisualizationGraph(
        nodes=[
            Node(id=42),
            Node(id="1337", size=10),
        ],
        relationships=[],
    )

    update: dict[NodeIdType, RealNumber] = {"1337": 20}
    graph.resize_nodes(update, (3, 60))

    unsized, resized = graph.nodes
    assert unsized.size is None
    # Should just be the default since min == max in VG (only one node)
    assert resized.size == 3 + (60 - 3) / 2.0


def test_resize_nodes_with_scaling_all_sizes_provided() -> None:
    """New sizes for every node are scaled linearly onto the requested range."""
    graph = VisualizationGraph(
        nodes=[
            Node(id=42, size=10),
            Node(id=43, size=10),
            Node(id="1337", size=15),
        ],
        relationships=[],
    )

    update: dict[NodeIdType, RealNumber] = {42: 18, 43: 19, "1337": 20}
    graph.resize_nodes(update, (3, 60))

    smallest, middle, largest = graph.nodes
    assert smallest.size == 3
    assert middle.size == 3 + (60 - 3) / 2.0
    assert largest.size == 60


def test_resize_nodes_with_scaling_some_sizes_provided() -> None:
    """Nodes missing from `sizes` take part in the scaling with their existing size."""
    graph = VisualizationGraph(
        nodes=[
            Node(id=42, size=10),
            Node(id="1337", size=15),
        ],
        relationships=[],
    )

    update: dict[NodeIdType, RealNumber] = {"1337": 1}
    graph.resize_nodes(update, (3, 60))

    kept, shrunk = graph.nodes
    assert kept.size == 60
    assert shrunk.size == 3


def test_resize_nodes_with_scaling_only() -> None:
    """With no `sizes` given, the existing node sizes are rescaled onto the range."""
    graph = VisualizationGraph(
        nodes=[
            Node(id=42, size=10),
            Node(id=43, size=10),
            Node(id="1337", size=15),
        ],
        relationships=[],
    )

    graph.resize_nodes(node_radius_min_max=(3, 60))

    first, second, third = graph.nodes
    assert first.size == 3
    assert second.size == 3
    assert third.size == 60


def test_resize_nodes_no_args_failure() -> None:
    """Passing neither sizes nor a radius range is an error."""
    empty_graph = VisualizationGraph(nodes=[], relationships=[])

    with pytest.raises(ValueError, match="At least one of `sizes` and `node_radius_min_max` must be given"):
        empty_graph.resize_nodes(node_radius_min_max=None)