-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathgql_create.py
More file actions
384 lines (333 loc) · 14.3 KB
/
gql_create.py
File metadata and controls
384 lines (333 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import re
import uuid
from typing import Any, Optional
from pydantic import BaseModel, ValidationError
from neo4j_viz import Node, Relationship, VisualizationGraph
def _parse_value(value_str: str) -> Any:
value_str = value_str.strip()
if not value_str:
return None
# Parse map
if value_str.startswith("{") and value_str.endswith("}"):
inner = value_str[1:-1].strip()
result = {}
depth = 0
in_string = None
start_idx = 0
for i, ch in enumerate(inner):
if in_string is None:
if ch in ["'", '"']:
in_string = ch
elif ch in ["{", "["]:
depth += 1
elif ch in ["}", "]"]:
depth -= 1
elif ch == "," and depth == 0:
segment = inner[start_idx:i].strip()
if ":" not in segment:
return None
k, v = segment.split(":", 1)
k = k.strip().strip("'\"")
result[k] = _parse_value(v)
start_idx = i + 1
else:
if ch == in_string:
in_string = None
if inner[start_idx:]:
segment = inner[start_idx:].strip()
if ":" not in segment:
return None
k, v = segment.split(":", 1)
k = k.strip().strip("'\"")
result[k] = _parse_value(v)
return result
# Parse list
if value_str.startswith("[") and value_str.endswith("]"):
inner = value_str[1:-1].strip()
items = []
depth = 0
in_string = None
start_idx = 0
for i, ch in enumerate(inner):
if in_string is None:
if ch in ["'", '"']:
in_string = ch
elif ch in ["{", "["]:
depth += 1
elif ch in ["}", "]"]:
depth -= 1
elif ch == "," and depth == 0:
items.append(_parse_value(inner[start_idx:i]))
start_idx = i + 1
else:
if ch == in_string:
in_string = None
if inner[start_idx:]:
items.append(_parse_value(inner[start_idx:]))
return items
# Parse boolean, float, int, or string
if re.match(r"^-?\d+$", value_str):
return int(value_str)
if re.match(r"^-?\d+\.\d+$", value_str):
return float(value_str)
if value_str.lower() == "true":
return True
if value_str.lower() == "false":
return False
if value_str.lower() == "null":
return None
return value_str.strip("'\"")
def _parse_prop_str(
query: str, prop_str: str, prop_start: int, top_level_keys: set[str]
) -> tuple[dict[str, Any], dict[str, Any]]:
top_level: dict[str, Any] = {}
props: dict[str, Any] = {}
depth = 0
in_string = None
start_idx = 0
for i, ch in enumerate(prop_str):
if in_string is None:
if ch in ["'", '"']:
in_string = ch
elif ch in ["{", "["]:
depth += 1
elif ch in ["}", "]"]:
depth -= 1
elif ch == "," and depth == 0:
pair = prop_str[start_idx:i].strip()
if ":" not in pair:
snippet = _get_snippet(query, prop_start + start_idx)
raise ValueError(f"Property syntax error near: `{snippet}`.")
k, v = pair.split(":", 1)
k = k.strip().strip("'\"")
if k in top_level_keys:
top_level[k] = _parse_value(v)
else:
props[k] = _parse_value(v)
start_idx = i + 1
else:
if ch == in_string:
in_string = None
if prop_str[start_idx:]:
pair = prop_str[start_idx:].strip()
if ":" not in pair:
snippet = _get_snippet(query, prop_start + start_idx)
raise ValueError(f"Property syntax error near: `{snippet}`.")
k, v = pair.split(":", 1)
k = k.strip().strip("'\"")
if k in top_level_keys:
top_level[k] = _parse_value(v)
else:
props[k] = _parse_value(v)
return top_level, props
def _parse_labels_and_props(
query: str, s: str, top_level_keys: set[str]
) -> tuple[Optional[str], dict[str, Any], dict[str, Any]]:
prop_match = re.search(r"\{(.*)\}", s)
prop_str = ""
if prop_match:
prop_str = prop_match.group(1)
prop_start = query.index(prop_str, query.index(s))
s = s[: prop_match.start()].strip()
alias_labels = re.split(r"[:&]", s)
raw_alias = alias_labels[0].strip()
final_alias = raw_alias if raw_alias else None
if prop_str:
top_level, props = _parse_prop_str(query, prop_str, prop_start, top_level_keys)
else:
top_level = {}
props = {}
label_list = [lbl.strip() for lbl in alias_labels[1:]]
if "labels" in props:
props["__labels"] = props["labels"]
props["labels"] = sorted(label_list)
return final_alias, top_level, props
def _get_snippet(q: str, idx: int, context: int = 15) -> str:
start = max(0, idx - context)
end = min(len(q), idx + context)
return q[start:end].replace("\n", " ")
def from_gql_create(
    query: str,
    size_property: Optional[str] = None,
    node_caption: Optional[str] = "labels",
    relationship_caption: Optional[str] = "type",
    node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
) -> VisualizationGraph:
    """
    Parse a GQL CREATE query and return a VisualizationGraph object representing the graph it creates.

    All node and relationship properties will be included in the visualization graph.
    If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as
    top level fields of the respective objects. Otherwise, they will be included in the `properties` dictionary.
    Additionally, a "labels" property will be added for nodes and a "type" property for relationships. A user
    supplied relationship property called "type" is kept under the key "__type".

    Please note that this function is not a full GQL parser, it only handles CREATE queries that do not contain
    other clauses like MATCH, WHERE, RETURN, etc, or any Cypher function calls.
    It also does not handle all possible GQL syntax, but it should work for most common cases. Only comma separated
    node patterns `(...)` and single-hop relationship patterns of the form `(...)-[:TYPE {...}]->(...)` are
    recognised.
    For more complex cases, we recommend using a Neo4j database and the `from_neo4j` method.

    Parameters
    ----------
    query : str
        The GQL CREATE query to parse
    size_property : str, optional
        Property to use for node size, by default None.
    node_caption : str, optional
        Property to use as the node caption, by default the node labels will be used.
    relationship_caption : str, optional
        Property to use as the relationship caption, by default the relationship type will be used.
    node_radius_min_max : tuple[float, float], optional
        Minimum and maximum node radius, by default (3, 60).
        To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
        Only applied when `size_property` is also given.

    Returns
    -------
    VisualizationGraph
        Graph holding the parsed nodes and relationships.

    Raises
    ------
    ValueError
        If the query does not start with CREATE, has unbalanced parentheses or square brackets, contains an
        element that is neither a node nor a supported relationship pattern, has a malformed property map,
        or if a property value is rejected by `Node` / `Relationship` validation.
    """
    query = query.strip()
    if not re.match(r"(?i)^create\b", query):
        raise ValueError("Query must begin with 'CREATE' (case insensitive).")
    query = re.sub(r"(?i)^create\s*", "", query, count=1).rstrip(";").strip()

    # Split the query into its comma separated patterns, ignoring commas that
    # sit inside parentheses or square brackets.
    parts: list[str] = []
    paren_level = 0
    bracket_level = 0
    current: list[str] = []
    for i, char in enumerate(query):
        if char == "(":
            paren_level += 1
        elif char == ")":
            paren_level -= 1
            if paren_level < 0:
                snippet = _get_snippet(query, i)
                raise ValueError(f"Unbalanced parentheses near: `{snippet}`.")
        if char == "[":
            bracket_level += 1
        elif char == "]":
            bracket_level -= 1
            if bracket_level < 0:
                snippet = _get_snippet(query, i)
                raise ValueError(f"Unbalanced square brackets near: `{snippet}`.")
        if char == "," and paren_level == 0 and bracket_level == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(char)
    parts.append("".join(current).strip())

    if paren_level != 0:
        snippet = _get_snippet(query, len(query) - 1)
        raise ValueError(f"Unbalanced parentheses near: `{snippet}`.")
    if bracket_level != 0:
        snippet = _get_snippet(query, len(query) - 1)
        raise ValueError(f"Unbalanced square brackets near: `{snippet}`.")

    node_pattern = re.compile(r"^\(([^)]*)\)$")
    rel_pattern = re.compile(r"^\(([^)]*)\)-\s*\[\s*:(\w+)\s*(\{[^}]*\})?\s*\]->\(([^)]*)\)$")

    node_top_level_keys = Node.all_validation_aliases(exempted_fields=["id", "size", "caption"])
    rel_top_level_keys = Relationship.all_validation_aliases(exempted_fields=["id", "source", "target", "caption"])

    def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) -> None:
        # Re-raise the first validation error as a ValueError that names the
        # user-facing property rather than the internal field.
        for err in e.errors():
            loc = err["loc"][0]
            if (loc == "size") and size_property is not None:
                loc = size_property
            if loc == "caption":
                if (entity_type == Node) and (node_caption is not None):
                    loc = node_caption
                elif (entity_type == Relationship) and (relationship_caption is not None):
                    loc = relationship_caption
            raise ValueError(
                f"Error for {entity_type.__name__.lower()} property '{loc}' with provided input '{err['input']}'. Reason: {err['msg']}"
            ) from e

    nodes: list[Node] = []
    relationships: list[Relationship] = []
    alias_to_id: dict[str, str] = {}
    anonymous_count = 0

    def _ensure_node(pattern_text: str, always_append: bool) -> str:
        # Parse one node pattern, give it an id if its alias is new, and
        # return the (possibly generated) alias. A node object is appended
        # when the alias is new, or unconditionally if `always_append`.
        nonlocal anonymous_count
        alias, top_level, props = _parse_labels_and_props(query, pattern_text, node_top_level_keys)
        if not alias:
            alias = f"_anon_{anonymous_count}"
            anonymous_count += 1
        is_new = alias not in alias_to_id
        if is_new:
            alias_to_id[alias] = str(uuid.uuid4())
        if is_new or always_append:
            try:
                nodes.append(Node(id=alias_to_id[alias], **top_level, properties=props))
            except ValidationError as e:
                _parse_validation_error(e, Node)
        return alias

    for part in parts:
        node_m = node_pattern.match(part)
        if node_m:
            _ensure_node(node_m.group(1).strip(), always_append=True)
            continue

        rel_m = rel_pattern.match(part)
        if rel_m:
            # Endpoints whose alias is already known reuse the existing node;
            # properties repeated in the relationship pattern are ignored.
            left_alias = _ensure_node(rel_m.group(1).strip(), always_append=False)
            right_alias = _ensure_node(rel_m.group(4).strip(), always_append=False)

            rel_id = str(uuid.uuid4())
            rel_type = rel_m.group(2).replace(":", "").strip()
            rel_props_str = rel_m.group(3) or ""
            if rel_props_str:
                inner_str = rel_props_str.strip("{}").strip()
                prop_start = query.index(inner_str)
                top_level, props = _parse_prop_str(query, inner_str, prop_start, rel_top_level_keys)
            else:
                top_level = {}
                props = {}
            # Keep a user property named "type" before overwriting it.
            if "type" in props:
                props["__type"] = props["type"]
            props["type"] = rel_type

            try:
                relationships.append(
                    Relationship(
                        id=rel_id,
                        source=alias_to_id[left_alias],
                        target=alias_to_id[right_alias],
                        **top_level,
                        properties=props,
                    )
                )
            except ValidationError as e:
                _parse_validation_error(e, Relationship)
            continue

        snippet = part[:30]
        raise ValueError(f"Invalid element in CREATE near: `{snippet}`.")

    if size_property is not None:
        try:
            for node in nodes:
                node.size = node.properties.get(size_property)
        except ValidationError as e:
            _parse_validation_error(e, Node)

    if node_caption is not None:
        for node in nodes:
            if node_caption == "labels":
                # Nodes without labels keep whatever caption they already have.
                if node.properties["labels"]:
                    node.caption = ":".join(node.properties["labels"])
            else:
                node.caption = str(node.properties.get(node_caption))

    if relationship_caption is not None:
        for rel in relationships:
            if relationship_caption == "type":
                rel.caption = rel.properties["type"]
            else:
                rel.caption = str(rel.properties.get(relationship_caption))

    VG = VisualizationGraph(nodes=nodes, relationships=relationships)
    if (node_radius_min_max is not None) and (size_property is not None):
        VG.resize_nodes(node_radius_min_max=node_radius_min_max)
    return VG