Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

## Bug fixes

* Make sure that temporary internal node properties are not included in the visualization output
* Fixed a bug where loading a graph with `from_gds` would result in an error if some node or relationship properties were not present on every entity


## Improvements

Expand Down
62 changes: 37 additions & 25 deletions python-wrapper/src/neo4j_viz/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@


def _fetch_node_dfs(
gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str]
gds: GraphDataScience, G: Graph, node_properties_by_label: dict[str, list[str]], node_labels: list[str]
) -> dict[str, pd.DataFrame]:
return {
lbl: gds.graph.nodeProperties.stream(
G, node_properties=node_properties, node_labels=[lbl], separate_property_columns=True
G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], separate_property_columns=True
)
for lbl in node_labels
}
Expand Down Expand Up @@ -79,24 +79,31 @@ def from_gds(
"""
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also allow additional_node_properties to be given as a dict: label -> properties, and similarly size_property as dict: label -> property?
It would make things even more heterogeneous-native I think. But something for another PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a tradeoff in how difficult the API looks to first-time users.
I think for such cases, you could always modify the visualization graph after the initial import.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok 👍

node_properties_from_gds = G.node_properties()
assert isinstance(node_properties_from_gds, pd.Series)
actual_node_properties = list(chain.from_iterable(node_properties_from_gds.to_dict().values()))
actual_node_properties = node_properties_from_gds.to_dict()
all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values()))

if size_property is not None and size_property not in actual_node_properties:
raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'")
if size_property is not None:
if size_property not in all_actual_node_properties:
raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'")

if additional_node_properties is None:
additional_node_properties = actual_node_properties
node_properties_by_label = {k: set(v) for k, v in actual_node_properties.items()}
else:
for prop in additional_node_properties:
if prop not in actual_node_properties:
if prop not in all_actual_node_properties:
raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'")

node_properties = set()
if additional_node_properties is not None:
node_properties.update(additional_node_properties)
node_properties_by_label = {}
for label, props in actual_node_properties.items():
node_properties_by_label[label] = {
prop for prop in actual_node_properties[label] if prop in additional_node_properties
}

if size_property is not None:
node_properties.add(size_property)
node_properties = list(node_properties)
for label, props in node_properties_by_label.items():
props.add(size_property)

node_properties_by_label = {k: list(v) for k, v in node_properties_by_label.items()}

node_count = G.node_count()
if node_count > max_node_count:
Expand All @@ -112,13 +119,18 @@ def from_gds(
property_name = None
try:
# Since GDS does not allow us to only fetch node IDs, we add the degree property
# as a temporary property to ensure that we have at least one property to fetch
if len(actual_node_properties) == 0:
# as a temporary property to ensure that we have at least one property for each label to fetch
if sum([len(props) == 0 for props in node_properties_by_label.values()]) > 0:
property_name = f"neo4j-viz_property_{uuid4()}"
gds.degree.mutate(G_fetched, mutateProperty=property_name)
node_properties = [property_name]
for props in node_properties_by_label.values():
props.append(property_name)

node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties_by_label, G_fetched.node_labels())
if property_name is not None:
for df in node_dfs.values():
df.drop(columns=[property_name], inplace=True)

node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels())
rel_df = _fetch_rel_df(gds, G_fetched)
finally:
if G_fetched.name() != G.name():
Expand All @@ -127,35 +139,35 @@ def from_gds(
gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name])

for df in node_dfs.values():
df.rename(columns={"nodeId": "id"}, inplace=True)
if property_name is not None and property_name in df.columns:
df.drop(columns=[property_name], inplace=True)
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)

node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates()
if size_property is not None:
if "size" in actual_node_properties and size_property != "size":
if "size" in all_actual_node_properties and size_property != "size":
node_props_df.rename(columns={"size": "__size"}, inplace=True)
node_props_df.rename(columns={size_property: "size"}, inplace=True)

for lbl, df in node_dfs.items():
if "labels" in actual_node_properties:
if "labels" in all_actual_node_properties:
df.rename(columns={"labels": "__labels"}, inplace=True)
df["labels"] = lbl

node_labels_df = pd.concat([df[["id", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0)
node_labels_df = node_labels_df.groupby("id").agg({"labels": list})
node_labels_df = pd.concat([df[["nodeId", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0)
node_labels_df = node_labels_df.groupby("nodeId").agg({"labels": list})

node_df = node_props_df.merge(node_labels_df, on="id")
node_df = node_props_df.merge(node_labels_df, on="nodeId")

if "caption" not in actual_node_properties:
if "caption" not in all_actual_node_properties:
node_df["caption"] = node_df["labels"].astype(str)

if "caption" not in rel_df.columns:
rel_df["caption"] = rel_df["relationshipType"]

try:
return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"})
return _from_dfs(
node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True
)
except ValueError as e:
err_msg = str(e)
if "column" in err_msg:
Expand Down
19 changes: 14 additions & 5 deletions python-wrapper/src/neo4j_viz/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def _from_dfs(
rel_dfs: DFS_TYPE,
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
rename_properties: Optional[dict[str, str]] = None,
dropna: bool = False,
) -> VisualizationGraph:
relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties)
relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties, dropna=dropna)

if node_dfs is None:
has_size = False
Expand All @@ -42,7 +43,7 @@ def _from_dfs(
node_ids.add(rel.target)
nodes = [Node(id=id) for id in node_ids]
else:
nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties)
nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties, dropna=dropna)

VG = VisualizationGraph(nodes=nodes, relationships=relationships)

Expand All @@ -52,7 +53,9 @@ def _from_dfs(
return VG


def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> tuple[list[Node], bool]:
def _parse_nodes(
node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False
) -> tuple[list[Node], bool]:
if isinstance(node_dfs, DataFrame):
node_dfs_iter: Iterable[DataFrame] = [node_dfs]
elif node_dfs is None:
Expand All @@ -67,6 +70,8 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]
for node_df in node_dfs_iter:
has_size &= "size" in node_df.columns
for _, row in node_df.iterrows():
if dropna:
row = row.dropna(inplace=False)
top_level = {}
properties = {}
for key, value in row.to_dict().items():
Expand All @@ -85,7 +90,9 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]
return nodes, has_size


def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> list[Relationship]:
def _parse_relationships(
rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False
) -> list[Relationship]:
all_rel_field_aliases = Relationship.all_validation_aliases()

if isinstance(rel_dfs, DataFrame):
Expand All @@ -96,6 +103,8 @@ def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str

for rel_df in rel_dfs_iter:
for _, row in rel_df.iterrows():
if dropna:
row = row.dropna(inplace=False)
top_level = {}
properties = {}
for key, value in row.to_dict().items():
Expand Down Expand Up @@ -138,4 +147,4 @@ def from_dfs(
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
"""

return _from_dfs(node_dfs, rel_dfs, node_radius_min_max)
return _from_dfs(node_dfs, rel_dfs, node_radius_min_max, dropna=False)
2 changes: 2 additions & 0 deletions python-wrapper/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def aura_ds_instance() -> Generator[Any, None, None]:

# setting as environment variables to run notebooks with this connection
os.environ["NEO4J_URI"] = dbms_connection_info.uri
assert isinstance(dbms_connection_info.username, str)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right latest gds release changed this.
There is dbms_connection_info.get_auth(), but thats not useful here to set an env variable

os.environ["NEO4J_USER"] = dbms_connection_info.username
assert isinstance(dbms_connection_info.password, str)
os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password
yield dbms_connection_info

Expand Down
2 changes: 0 additions & 2 deletions python-wrapper/tests/gds_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ def create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]:
if wait_result.error:
raise Exception(f"Error while waiting for instance to be running: {wait_result.error}")

wait_result.connection_url

return instance_details.id, DbmsConnectionInfo(
uri=wait_result.connection_url,
username="neo4j",
Expand Down
64 changes: 64 additions & 0 deletions python-wrapper/tests/test_gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,71 @@ def test_from_gds_sample(gds: Any) -> None:
):
VG = from_gds(gds, G)

# Make sure internal temporary properties are not present
assert set(VG.nodes[0].properties.keys()) == {"labels"}

assert len(VG.nodes) >= 9_500
assert len(VG.nodes) <= 10_500
assert len(VG.relationships) >= 9_500
assert len(VG.relationships) <= 10_500


@pytest.mark.requires_neo4j_and_gds
def test_from_gds_hetero(gds: Any) -> None:
    """
    Round-trip a heterogeneous in-memory GDS graph through `from_gds`.

    Two node labels are constructed where only label 'A' carries the
    'component' property; the test verifies that nodes of label 'B' simply
    omit the missing property (no NaN placeholder and no error), and that
    relationship properties survive the conversion intact.
    Requires a live Neo4j + GDS connection (provided by the `gds` fixture).
    """
    from neo4j_viz.gds import from_gds

    # Label 'A' nodes carry a 'component' property ...
    A_nodes = pd.DataFrame(
        {
            "nodeId": [0, 1],
            "labels": ["A", "A"],
            "component": [1, 2],
        }
    )
    # ... while label 'B' nodes do not, making the property set heterogeneous.
    B_nodes = pd.DataFrame(
        {
            "nodeId": [2, 3],
            "labels": ["B", "B"],
            # No 'component' property
        }
    )
    rels = pd.DataFrame(
        {
            "sourceNodeId": [0, 1],
            "targetNodeId": [2, 3],
            "weight": [0.5, 1.5],
            "relationshipType": ["REL", "REL2"],
        }
    )

    # Project the graph into GDS; the context manager drops it on exit.
    with gds.graph.construct("flo", [A_nodes, B_nodes], rels) as G:
        VG = from_gds(
            gds,
            G,
        )

    assert len(VG.nodes) == 4
    # 'component' appears only on the 'A' nodes; GDS streams numeric
    # properties as floats, hence the float(...) expectations.
    assert sorted(VG.nodes, key=lambda x: x.id) == [
        Node(id=0, caption="['A']", properties=dict(labels=["A"], component=float(1))),
        Node(id=1, caption="['A']", properties=dict(labels=["A"], component=float(2))),
        Node(id=2, caption="['B']", properties=dict(labels=["B"])),
        Node(id=3, caption="['B']", properties=dict(labels=["B"])),
    ]

    assert len(VG.relationships) == 2
    # Sort by source id so the comparison is deterministic regardless of
    # the order GDS streams relationships back in.
    vg_rels = sorted(
        [
            (
                e.source,
                e.target,
                e.caption,
                e.properties["relationshipType"],
                e.properties["weight"],
            )
            for e in VG.relationships
        ],
        key=lambda x: x[0],
    )
    # Captions default to the relationship type when none is supplied.
    assert vg_rels == [
        (0, 2, "REL", "REL", 0.5),
        (1, 3, "REL2", "REL2", 1.5),
    ]