Skip to content

Commit 344170d

Browse files
committed
Add more params to from_gql_create
1 parent 45f5fdb commit 344170d

3 files changed

Lines changed: 119 additions & 5 deletions

File tree

python-wrapper/src/neo4j_viz/gql_create.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,13 @@ def _get_snippet(q: str, idx: int, context: int = 15) -> str:
9090
return q[start:end].replace("\n", " ")
9191

9292

93-
def from_gql_create(query: str) -> VisualizationGraph:
93+
def from_gql_create(
94+
query: str,
95+
size_property: Optional[str] = None,
96+
node_caption: Optional[str] = "labels",
97+
relationship_caption: Optional[str] = "type",
98+
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
99+
) -> VisualizationGraph:
94100
"""
95101
Parse a GQL CREATE query and return a VisualizationGraph object representing the graph it creates.
96102
@@ -107,6 +113,15 @@ def from_gql_create(query: str) -> VisualizationGraph:
107113
----------
108114
query : str
109115
The GQL CREATE query to parse
116+
size_property : str, optional
117+
Property to use for node size, by default None.
118+
node_caption : str, optional
119+
Property to use as the node caption, by default the node labels will be used.
120+
relationship_caption : str, optional
121+
Property to use as the relationship caption, by default the relationship type will be used.
122+
node_radius_min_max : tuple[float, float], optional
123+
Minimum and maximum node radius, by default (3, 60).
124+
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
110125
"""
111126

112127
query = query.strip()
@@ -295,4 +310,28 @@ def parse_labels_and_props(
295310
snippet = part[:30]
296311
raise ValueError(f"Invalid element in CREATE near: `{snippet}`.")
297312

298-
return VisualizationGraph(nodes=nodes, relationships=relationships)
313+
if size_property is not None:
314+
for node in nodes:
315+
node.size = node.properties.get(size_property)
316+
317+
if node_caption is not None:
318+
for node in nodes:
319+
if node_caption == "labels":
320+
if len(node.properties["labels"]) > 0:
321+
node.caption = ":".join([label for label in node.properties["labels"]])
322+
else:
323+
node.caption = str(node.properties.get(node_caption))
324+
325+
if relationship_caption is not None:
326+
for rel in relationships:
327+
if relationship_caption == "type":
328+
rel.caption = rel.properties["type"]
329+
else:
330+
rel.caption = str(rel.properties.get(relationship_caption))
331+
332+
VG = VisualizationGraph(nodes=nodes, relationships=relationships)
333+
334+
if (node_radius_min_max is not None) and (size_property is not None):
335+
VG.resize_nodes(node_radius_min_max=node_radius_min_max)
336+
337+
return VG

python-wrapper/src/neo4j_viz/neo4j.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def from_neo4j(
5656

5757
VG = VisualizationGraph(nodes, relationships)
5858

59-
if node_radius_min_max and size_property is not None:
59+
if (node_radius_min_max is not None) and (size_property is not None):
6060
VG.resize_nodes(node_radius_min_max=node_radius_min_max)
6161

6262
return VG

python-wrapper/tests/test_gql_create.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from neo4j_viz.gql_create import from_gql_create
66

77

8-
def test_from_gql_create() -> None:
8+
def test_from_gql_create_syntax() -> None:
99
query = """
1010
CREATE
1111
(a:User {name: 'Alice', age: 23, labels: ['Happy'], "id": 42}),
@@ -47,7 +47,7 @@ def test_from_gql_create() -> None:
4747
{"top_level": {}, "properties": {"name": "Fawad", "age": 78, "labels": ["Person", "User"]}},
4848
]
4949

50-
VG = from_gql_create(query)
50+
VG = from_gql_create(query, node_caption=None, relationship_caption=None)
5151

5252
assert len(VG.nodes) == len(expected_node_dicts)
5353
for i, exp_node in enumerate(expected_node_dicts):
@@ -80,6 +80,81 @@ def test_from_gql_create() -> None:
8080
assert created_rel.properties == exp_rel["properties"]
8181

8282

83+
def test_from_gql_create_captions() -> None:
84+
query = """
85+
CREATE
86+
(a:User {name: 'Alice', age: 23}),
87+
(b:User:person {name: "Bridget", age: 34, "caption": "Bridget"}),
88+
(a)-[:LINK {weight: 0.5}]->(b);
89+
"""
90+
expected_node_dicts: list[dict[str, dict[str, Any]]] = [
91+
{
92+
"top_level": {"caption": "User"},
93+
"properties": {"name": "Alice", "age": 23, "labels": ["User"]},
94+
},
95+
{
96+
"top_level": {"caption": "User:person"},
97+
"properties": {"name": "Bridget", "age": 34, "labels": ["User", "person"]},
98+
},
99+
]
100+
101+
VG = from_gql_create(query)
102+
103+
assert len(VG.nodes) == len(expected_node_dicts)
104+
for i, exp_node in enumerate(expected_node_dicts):
105+
created_node = VG.nodes[i]
106+
107+
assert created_node.model_dump(exclude_none=True, exclude={"properties", "id"}) == exp_node["top_level"]
108+
assert created_node.properties == exp_node["properties"]
109+
110+
expected_relationships_dicts: list[dict[str, Any]] = [
111+
{
112+
"source_idx": 0,
113+
"target_idx": 1,
114+
"top_level": {"caption": "LINK"},
115+
"properties": {"weight": 0.5, "type": "LINK"},
116+
},
117+
]
118+
119+
assert len(VG.relationships) == len(expected_relationships_dicts)
120+
for i, exp_rel in enumerate(expected_relationships_dicts):
121+
created_rel = VG.relationships[i]
122+
assert created_rel.source == VG.nodes[exp_rel["source_idx"]].id
123+
assert created_rel.target == VG.nodes[exp_rel["target_idx"]].id
124+
assert (
125+
created_rel.model_dump(exclude_none=True, exclude={"properties", "id", "source", "target"})
126+
== exp_rel["top_level"]
127+
)
128+
assert created_rel.properties == exp_rel["properties"]
129+
130+
131+
def test_from_gql_create_sizes() -> None:
132+
query = """
133+
CREATE
134+
(a:User {name: 'Alice', age: 23}),
135+
(b:User:person {name: "Bridget", age: 34, "caption": "Bridget"});
136+
"""
137+
expected_node_dicts: list[dict[str, dict[str, Any]]] = [
138+
{
139+
"top_level": {"size": 3.0},
140+
"properties": {"name": "Alice", "age": 23, "labels": ["User"]},
141+
},
142+
{
143+
"top_level": {"caption": "Bridget", "size": 60.0},
144+
"properties": {"name": "Bridget", "age": 34, "labels": ["User", "person"]},
145+
},
146+
]
147+
148+
VG = from_gql_create(query, size_property="age", node_caption=None, relationship_caption=None)
149+
150+
assert len(VG.nodes) == len(expected_node_dicts)
151+
for i, exp_node in enumerate(expected_node_dicts):
152+
created_node = VG.nodes[i]
153+
154+
assert created_node.model_dump(exclude_none=True, exclude={"properties", "id"}) == exp_node["top_level"]
155+
assert created_node.properties == exp_node["properties"]
156+
157+
83158
def test_unbalanced_parentheses_snippet() -> None:
84159
query = "CREATE (a:User, (b:User })"
85160
with pytest.raises(ValueError, match=r"Unbalanced parentheses near: `.*\(b:User.*"):

0 commit comments

Comments
 (0)