"""Build a ``VisualizationGraph`` from the text of a GQL/Cypher ``CREATE`` query."""

import re
import uuid
from typing import TYPE_CHECKING, Any, Iterator, Optional

if TYPE_CHECKING:
    # Only needed for the return annotation of ``from_gql_create``; the runtime
    # import is done lazily inside that function.
    from neo4j_viz import VisualizationGraph

_QUOTES = ("'", '"')
_LITERALS = {"true": True, "false": False, "null": None}
_INT_RE = re.compile(r"-?\d+")
# Decimal notation (``1.5``, ``.5``) with optional exponent, or integer mantissa with exponent (``1e3``).
_FLOAT_RE = re.compile(r"-?(?:\d+\.\d+|\.\d+)(?:[eE][+-]?\d+)?|-?\d+[eE][+-]?\d+")
# A backslash followed by a quote or another backslash.
_ESCAPED_CHAR_RE = re.compile(r"\\(['\"\\])")
# DOTALL so that a property map may span several lines.
_PROP_MAP_RE = re.compile(r"\{(.*)\}", re.DOTALL)
_REL_HEAD_RE = re.compile(r"\s*-\s*\[")
_REL_BODY_RE = re.compile(r"\s*:(\w+)\s*(\{.*\})?\s*", re.DOTALL)
_REL_ARROW_RE = re.compile(r"\s*->\s*\(")


def _iter_unquoted(text: str) -> Iterator[tuple[int, str]]:
    """Yield ``(index, char)`` for every character of ``text`` outside string literals.

    Strings are delimited by matching single or double quotes; a backslash
    inside a string escapes the following character. The quote delimiters
    themselves are not yielded.
    """
    quote = None
    escaped = False
    for i, ch in enumerate(text):
        if quote is None:
            if ch in _QUOTES:
                quote = ch
            else:
                yield i, ch
        elif escaped:
            escaped = False
        elif ch == "\\":
            escaped = True
        elif ch == quote:
            quote = None


def _split_top_level(text: str) -> list[tuple[int, str]]:
    """Split ``text`` on commas that are outside strings and outside ``{}``/``[]`` nesting.

    Returns ``(start_offset, raw_segment)`` pairs. A trailing segment is only
    reported when it is non-empty, so ``"a, b,"`` yields two segments.
    """
    segments = []
    depth = 0
    start = 0
    for i, ch in _iter_unquoted(text):
        if ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
        elif ch == "," and depth == 0:
            segments.append((start, text[start:i]))
            start = i + 1
    if text[start:]:
        segments.append((start, text[start:]))
    return segments


def _split_pair(pair: str) -> Optional[tuple[str, str]]:
    """Split ``key: value`` at the first colon; return ``None`` when there is no colon.

    The key is stripped of whitespace and surrounding quotes, the value is
    returned unparsed.
    """
    key, colon, raw_value = pair.partition(":")
    if not colon:
        return None
    return key.strip().strip("'\""), raw_value


def _strip_with_offset(text: str, offset: int) -> tuple[int, str]:
    """Strip ``text`` and return it with ``offset`` advanced past the removed leading whitespace."""
    leading = len(text) - len(text.lstrip())
    return offset + leading, text.strip()


def _find_unquoted(text: str, target: str, start: int = 0) -> int:
    """Return the index of the first ``target`` at or after ``start`` that is outside a string, or -1."""
    for i, ch in _iter_unquoted(text[start:]):
        if ch == target:
            return start + i
    return -1


def _find_closing_bracket(text: str, start: int) -> int:
    """Return the index of the ``]`` matching the ``[`` at ``start`` (string- and nesting-aware), or -1."""
    depth = 0
    for i, ch in _iter_unquoted(text[start:]):
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
            if depth == 0:
                return start + i
    return -1


def _parse_value(value_str: str) -> Any:
    """
    Convert the text of a property value into a Python object.

    Maps (``{...}``) become dicts, lists (``[...]``) become lists, integer and
    float literals become ``int``/``float``, ``true``/``false``/``null``
    (case insensitive) become ``True``/``False``/``None``. Anything else is
    returned as a string: a value wrapped in one matching pair of quotes loses
    exactly that pair and has ``\\'``, ``\\"`` and ``\\\\`` unescaped; other
    escape sequences are kept verbatim.

    An empty value, or a map containing an entry without a colon, yields ``None``.
    """
    value_str = value_str.strip()
    if not value_str:
        return None

    # Parse object
    if value_str.startswith("{") and value_str.endswith("}"):
        result: dict[str, Any] = {}
        for _, segment in _split_top_level(value_str[1:-1].strip()):
            pair = _split_pair(segment)
            if pair is None:
                # No position information is available here, so a malformed
                # nested map is reported as a missing value rather than an error.
                return None
            key, raw_value = pair
            result[key] = _parse_value(raw_value)
        return result

    # Parse list
    if value_str.startswith("[") and value_str.endswith("]"):
        return [_parse_value(item) for _, item in _split_top_level(value_str[1:-1].strip())]

    # Parse int, float, boolean, null, or string
    if _INT_RE.fullmatch(value_str):
        return int(value_str)
    if _FLOAT_RE.fullmatch(value_str):
        return float(value_str)
    lowered = value_str.lower()
    if lowered in _LITERALS:
        return _LITERALS[lowered]
    if len(value_str) >= 2 and value_str[0] in _QUOTES and value_str[-1] == value_str[0]:
        # Remove only the delimiting pair so quotes that belong to the text survive.
        return _ESCAPED_CHAR_RE.sub(r"\1", value_str[1:-1])
    return value_str.strip("'\"")


def _get_snippet(q: str, idx: int, context: int = 15) -> str:
    """Return up to ``context`` characters on each side of ``q[idx]``, with newlines shown as spaces."""
    start = max(0, idx - context)
    end = min(len(q), idx + context)
    return q[start:end].replace("\n", " ")


def _parse_properties(prop_str: str, prop_start: int, query: str, props: dict[str, Any]) -> None:
    """
    Parse the inside of a property map into ``props`` (updated in place).

    ``prop_start`` is the offset of ``prop_str`` inside ``query``; it is only
    used to build the snippet of the error message.

    Raises
    ------
    ValueError
        If an entry does not contain a colon.
    """
    for offset, raw_pair in _split_top_level(prop_str):
        pair = _split_pair(raw_pair)
        if pair is None:
            snippet = _get_snippet(query, prop_start + offset)
            raise ValueError(f"Property syntax error near: `{snippet}`.")
        key, raw_value = pair
        props[key] = _parse_value(raw_value)


def _parse_labels_and_props(text: str, text_start: int, query: str) -> tuple[Optional[str], dict[str, Any]]:
    """
    Parse the inside of a node pattern, e.g. ``a:User&Person {name: 'x'}``.

    Returns the alias (``None`` when absent) and the property dict, which
    always carries the sorted labels under ``"__labels"``. ``text_start`` is the
    offset of ``text`` inside ``query`` and is used for error snippets.
    """
    prop_match = _PROP_MAP_RE.search(text)
    header = text[: prop_match.start()].strip() if prop_match else text

    alias_part, *label_parts = re.split(r"[:&]", header)
    alias = alias_part.strip() or None

    props: dict[str, Any] = {"__labels": sorted(label.strip() for label in label_parts)}
    if prop_match and prop_match.group(1):
        _parse_properties(prop_match.group(1), text_start + prop_match.start(1), query, props)
    return alias, props


def _split_elements(query: str) -> list[tuple[int, str]]:
    """
    Split the body of a CREATE clause into its comma-separated elements.

    A comma only separates elements when it is outside string literals, outside
    parentheses and outside the ``[...]``/``{...}`` of a relationship, so a
    relationship with several properties stays in one piece.

    Returns ``(offset_in_query, stripped_element)`` pairs.

    Raises
    ------
    ValueError
        If the parentheses outside string literals are not balanced.
    """
    elements = []
    paren_level = 0
    bracket_depth = 0
    start = 0
    for i, ch in _iter_unquoted(query):
        if ch == "(":
            paren_level += 1
        elif ch == ")":
            paren_level -= 1
            if paren_level < 0:
                snippet = _get_snippet(query, i)
                raise ValueError(f"Unbalanced parentheses near: `{snippet}`.")
        elif paren_level == 0:
            # Brackets are only tracked outside node patterns; inside parentheses
            # the paren level already protects every comma.
            if ch in "[{":
                bracket_depth += 1
            elif ch in "]}":
                bracket_depth = max(0, bracket_depth - 1)
            elif ch == "," and bracket_depth == 0:
                elements.append(_strip_with_offset(query[start:i], start))
                start = i + 1
    elements.append(_strip_with_offset(query[start:], start))
    if paren_level != 0:
        snippet = _get_snippet(query, len(query) - 1)
        raise ValueError(f"Unbalanced parentheses near: `{snippet}`.")
    return elements


def _match_relationship(part: str, left_end: int) -> Optional[tuple[str, Optional[str], int, int]]:
    """
    Match the ``-[:TYPE {props}]->(node)`` tail of a relationship element.

    ``left_end`` is the index of the ``)`` closing the left node. Returns
    ``(rel_type, prop_text, prop_offset, right_start)`` where ``prop_text`` is
    the text between the braces (``None`` without a property map),
    ``prop_offset`` its index in ``part`` and ``right_start`` the index of the
    ``(`` opening the right node. Returns ``None`` if ``part`` does not have
    that shape.
    """
    head = _REL_HEAD_RE.match(part, left_end + 1)
    if head is None:
        return None
    bracket_start = head.end() - 1
    bracket_end = _find_closing_bracket(part, bracket_start)
    if bracket_end == -1:
        return None
    arrow = _REL_ARROW_RE.match(part, bracket_end + 1)
    if arrow is None:
        return None
    right_start = arrow.end() - 1
    # The right node must be the last thing in the element.
    if _find_unquoted(part, ")", right_start) != len(part) - 1:
        return None
    body = _REL_BODY_RE.fullmatch(part, bracket_start + 1, bracket_end)
    if body is None:
        return None
    prop_map = body.group(2)
    prop_text = prop_map[1:-1] if prop_map else None
    return body.group(1), prop_text, body.start(2) + 1, right_start


def _parse_create_query(
    query: str,
) -> tuple[list[tuple[str, dict[str, Any]]], list[tuple[str, str, str, dict[str, Any]]]]:
    """
    Parse a CREATE query into plain node and relationship descriptions.

    Returns ``(nodes, relationships)`` where each node is ``(id, properties)``
    and each relationship is ``(id, source_id, target_id, properties)``. Ids are
    freshly generated UUID strings.

    Raises
    ------
    ValueError
        If the query does not start with CREATE, has unbalanced parentheses,
        contains an element that is neither a node nor a relationship, has a
        malformed property map, or a relationship refers to an alias that was
        not defined by an earlier node.
    """
    query = query.strip()
    # Case-insensitive check that 'CREATE' is the first non-whitespace token
    if not re.match(r"(?i)^create\b", query):
        raise ValueError("Query must begin with 'CREATE' (case insensitive).")
    query = re.sub(r"(?i)^create\s*", "", query, count=1).rstrip(";").strip()

    node_specs: list[tuple[str, dict[str, Any]]] = []
    rel_specs: list[tuple[str, str, str, dict[str, Any]]] = []
    alias_to_id: dict[str, str] = {}
    anonymous_count = 0

    for part_start, part in _split_elements(query):
        left_end = _find_unquoted(part, ")") if part.startswith("(") else -1
        if left_end == -1:
            raise ValueError(f"Invalid element in CREATE near: `{part[:30]}`.")
        left_start, left_text = _strip_with_offset(part[1:left_end], part_start + 1)

        # A node is an element whose first closing parenthesis ends the element.
        if left_end == len(part) - 1:
            alias, props = _parse_labels_and_props(left_text, left_start, query)
            if not alias:
                alias = f"_anon_{anonymous_count}"
                anonymous_count += 1
            if alias not in alias_to_id:
                alias_to_id[alias] = str(uuid.uuid4())
            node_specs.append((alias_to_id[alias], props))
            continue

        shape = _match_relationship(part, left_end)
        if shape is None:
            raise ValueError(f"Invalid element in CREATE near: `{part[:30]}`.")
        rel_type, prop_text, prop_offset, right_start = shape

        # Only the alias of an endpoint is used; it must refer to an earlier node.
        left_alias, _ = _parse_labels_and_props(left_text, left_start, query)
        if not left_alias or left_alias not in alias_to_id:
            snippet = _get_snippet(query, left_start)
            raise ValueError(f"Relationship references unknown node alias: '{left_alias}' near: `{snippet}`.")

        right_offset, right_text = _strip_with_offset(part[right_start + 1 : -1], part_start + right_start + 1)
        right_alias, _ = _parse_labels_and_props(right_text, right_offset, query)
        if not right_alias or right_alias not in alias_to_id:
            snippet = _get_snippet(query, right_offset)
            raise ValueError(f"Relationship references unknown node alias: '{right_alias}' near: `{snippet}`.")

        rel_props: dict[str, Any] = {"__type": rel_type}
        if prop_text:
            inner_start, inner = _strip_with_offset(prop_text, part_start + prop_offset)
            _parse_properties(inner, inner_start, query, rel_props)

        rel_specs.append((str(uuid.uuid4()), alias_to_id[left_alias], alias_to_id[right_alias], rel_props))

    return node_specs, rel_specs


def from_gql_create(query: str) -> "VisualizationGraph":
    """
    Parse a GQL CREATE query and return a VisualizationGraph object representing the graph it creates.

    Please note that this function is not a full GQL parser, it only handles CREATE queries that do not contain
    other clauses like MATCH, WHERE, RETURN, etc, or any Cypher function calls.
    It also does not handle all possible GQL syntax, but it should work for most common cases.

    Supported elements are node patterns ``(alias:Label1:Label2 {key: value})`` and single, left-to-right
    relationships ``(a)-[:TYPE {key: value}]->(b)`` between aliases defined by earlier nodes. Labels are stored
    in the node property ``__labels`` and the relationship type in the relationship property ``__type``.

    Parameters
    ----------
    query : str
        The GQL CREATE query to parse

    Returns
    -------
    VisualizationGraph
        A graph with one node per node pattern and one relationship per relationship pattern, each with a
        freshly generated UUID as id.

    Raises
    ------
    ValueError
        If the query does not begin with CREATE, has unbalanced parentheses, contains an element that is
        neither a node nor a relationship, has a malformed property map, or a relationship references an
        unknown node alias.
    """
    node_specs, rel_specs = _parse_create_query(query)

    # Imported at call time rather than at module import time. This module is
    # part of the ``neo4j_viz`` package, so importing from the package root
    # while the package may still be initialising risks a circular import
    # (assumption: depends on the layout of the package ``__init__``).
    from neo4j_viz import Node, Relationship, VisualizationGraph

    nodes = [Node(id=node_id, properties=props) for node_id, props in node_specs]
    relationships = [
        Relationship(id=rel_id, source=source_id, target=target_id, properties=props)
        for rel_id, source_id, target_id, props in rel_specs
    ]
    return VisualizationGraph(nodes=nodes, relationships=relationships)
rel.items()} - additional_rel_props = _rename_protected_props(additional_rel_props, protected_props) - - return Relationship( - **base_rel_props, - **additional_rel_props, - ) - - -def _rename_protected_props( - additional_props: dict[str, Any], - protected_props: Iterable[str], -) -> dict[str, Union[str, int, float]]: - for prop in protected_props: - if prop not in additional_props: - continue - - additional_props[f"__{prop}"] = additional_props.pop(prop) - - return additional_props diff --git a/python-wrapper/src/neo4j_viz/node.py b/python-wrapper/src/neo4j_viz/node.py index 8718a497..52eced11 100644 --- a/python-wrapper/src/neo4j_viz/node.py +++ b/python-wrapper/src/neo4j_viz/node.py @@ -45,6 +45,8 @@ class Node(BaseModel, extra="allow"): x: Optional[RealNumber] = Field(None, description="The x-coordinate of the node") #: The y-coordinate of the node y: Optional[RealNumber] = Field(None, description="The y-coordinate of the node") + #: The properties of the node + properties: dict[str, Any] = Field(default_factory=dict, description="The properties of the node") @field_serializer("color") def serialize_color(self, color: Color) -> str: diff --git a/python-wrapper/src/neo4j_viz/relationship.py b/python-wrapper/src/neo4j_viz/relationship.py index b5f4c640..10e938ec 100644 --- a/python-wrapper/src/neo4j_viz/relationship.py +++ b/python-wrapper/src/neo4j_viz/relationship.py @@ -43,6 +43,8 @@ class Relationship(BaseModel, extra="allow"): ) #: The color of the relationship. 
import pytest

from neo4j_viz.gql_create import from_gql_create


def test_from_gql_create() -> None:
    """A query mixing nodes, labels, nested values and relationships is parsed element by element."""
    query = """
    CREATE
      (a:User {name: 'Alice', age: 23}),
      (b:User:person {name: "Bridget", age: 34}),
      (wizardMan:User {name: 'Charles: The wizard, man', hello: true, height: NULL}),
      (d:User),

      (a)-[:LINK {weight: 0.5}]->(b),

      (e:User {age: 67, my_map: {key: 'value', key2: 3.14, key3: [1, 2, 3], key4: {a: 1, b: null}}}),
      (:User {age: 42, pets: ['cat', false, 'dog']}),

      (f:User&Person


          {name: 'Fawad', age: 78}),

      (a)-[:LINK {weight: 4}]->(wizardMan),
      (e)-[:LINK]->(d),
      (e)-[:OTHER_LINK {weight: -2}]->(f);
    """
    # Expected node properties, in the order the nodes appear in the query.
    expected_node_props = [
        {"name": "Alice", "age": 23, "__labels": ["User"]},
        {"name": "Bridget", "age": 34, "__labels": ["User", "person"]},
        {"name": "Charles: The wizard, man", "hello": True, "height": None, "__labels": ["User"]},
        {"__labels": ["User"]},
        {
            "age": 67,
            "my_map": {"key": "value", "key2": 3.14, "key3": [1, 2, 3], "key4": {"a": 1, "b": None}},
            "__labels": ["User"],
        },
        {"age": 42, "pets": ["cat", False, "dog"], "__labels": ["User"]},
        {"name": "Fawad", "age": 78, "__labels": ["Person", "User"]},
    ]

    graph = from_gql_create(query)

    assert len(graph.nodes) == len(expected_node_props)
    for node, expected_props in zip(graph.nodes, expected_node_props):
        assert node.properties == expected_props

    # (index of source node, index of target node, expected properties)
    expected_rels = [
        (0, 1, {"weight": 0.5, "__type": "LINK"}),
        (0, 2, {"weight": 4, "__type": "LINK"}),
        (4, 3, {"__type": "LINK"}),
        (4, 6, {"weight": -2, "__type": "OTHER_LINK"}),
    ]

    assert len(graph.relationships) == len(expected_rels)
    for rel, (source_idx, target_idx, expected_props) in zip(graph.relationships, expected_rels):
        assert rel.source == graph.nodes[source_idx].id
        assert rel.target == graph.nodes[target_idx].id
        assert rel.properties == expected_props


def test_unbalanced_parentheses_snippet() -> None:
    """An unclosed parenthesis is reported together with a snippet of the query."""
    pattern = r"Unbalanced parentheses near: `.*\(b:User.*"
    with pytest.raises(ValueError, match=pattern):
        from_gql_create("CREATE (a:User, (b:User })")


def test_node_property_syntax_error_snippet1() -> None:
    """A node property without a colon is a property syntax error."""
    pattern = r"Property syntax error near: `.*x, y.*"
    with pytest.raises(ValueError, match=pattern):
        from_gql_create("CREATE (a:User {x, y:4})")


def test_node_property_syntax_error_snippet2() -> None:
    """An empty entry between two commas is a property syntax error."""
    pattern = r"Property syntax error near: `.*x:5,, y.*"
    with pytest.raises(ValueError, match=pattern):
        from_gql_create("CREATE (a:User {x:5,, y:4})")


def test_invalid_element_in_create_snippet() -> None:
    """An element that is neither a node nor a relationship is rejected."""
    pattern = r"Invalid element in CREATE near: `\[not_a_node.*"
    with pytest.raises(ValueError, match=pattern):
        from_gql_create("CREATE [not_a_node]")


def test_rel_property_syntax_error_snippet() -> None:
    """A relationship property without a colon is reported with its surrounding text."""
    pattern = r"Property syntax error near: `\), \(a\)-\[:LINK {weight0.5}\]->\(b`."
    with pytest.raises(ValueError, match=pattern):
        from_gql_create("CREATE (a:User), (b:User), (a)-[:LINK {weight0.5}]->(b)")


def test_unknown_node_alias() -> None:
    """A relationship may only refer to aliases of previously created nodes."""
    pattern = r"Relationship references unknown node alias: 'a' near: `\(a\)-\[:LINK {weig`"
    with pytest.raises(ValueError, match=pattern):
        from_gql_create("CREATE (a)-[:LINK {weight0.5}]->(b)")


def test_no_create_keyword() -> None:
    """A query that does not start with CREATE is rejected."""
    pattern = r"Query must begin with 'CREATE' \(case insensitive\)."
    with pytest.raises(ValueError, match=pattern):
        from_gql_create("(a:User {y:4})")
- __labels=[1, 2], + properties=dict( + __labels=["_CI_A", "_CI_B"], + name="Bob", + height=10, + size=11, + labels=[1, 2], + ) ), ] assert len(VG.nodes) == 2 - assert sorted(VG.nodes, key=lambda x: x.name) == expected_nodes # type: ignore[attr-defined] + assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes # type: ignore[attr-defined] assert len(VG.relationships) == 2 vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo") @@ -119,29 +123,31 @@ def test_from_neo4j_graph_full(neo4j_session: Session) -> None: Node( id=node_ids[0], caption="Alice", - labels=["_CI_A"], - name="Alice", - height=20, size=60.0, - __id=42, - _id=1337, - __caption="hello", + properties=dict( + __labels=["_CI_A"], + name="Alice", + height=20, + _id=1337, + caption="hello", + ) ), Node( id=node_ids[1], caption="Bob", - labels=["_CI_A", "_CI_B"], - name="Bob", - height=10, size=3.0, - __id=84, - __size=11, - __labels=[1, 2], + properties=dict( + __labels=["_CI_A", "_CI_B"], + name="Bob", + height=10, + size=11, + labels=[1, 2], + ), ), ] assert len(VG.nodes) == 2 - assert sorted(VG.nodes, key=lambda x: x.name) == expected_nodes # type: ignore[attr-defined] + assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes # type: ignore[attr-defined] assert len(VG.relationships) == 2 vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo") diff --git a/python-wrapper/tests/test_node.py b/python-wrapper/tests/test_node.py index faa32e5b..a15031f6 100644 --- a/python-wrapper/tests/test_node.py +++ b/python-wrapper/tests/test_node.py @@ -26,6 +26,7 @@ def test_nodes_with_all_options() -> None: "pinned": True, "x": 1, "y": 10, + "properties": {}, } @@ -36,6 +37,7 @@ def test_nodes_minimal_node() -> None: assert node.to_dict() == { "id": "1", + "properties": {}, } @@ -48,6 +50,7 @@ def test_node_with_float_size() -> None: assert 
node.to_dict() == { "id": "1", "size": 10.2, + "properties": {}, } @@ -60,6 +63,7 @@ def test_node_with_additional_fields() -> None: assert node.to_dict() == { "id": "1", "componentId": 2, + "properties": {}, } @@ -69,6 +73,7 @@ def test_id_aliases(alias: str) -> None: assert node.to_dict() == { "id": "1", + "properties": {}, } diff --git a/python-wrapper/tests/test_relationship.py b/python-wrapper/tests/test_relationship.py index 171c7af8..a284109f 100644 --- a/python-wrapper/tests/test_relationship.py +++ b/python-wrapper/tests/test_relationship.py @@ -23,6 +23,7 @@ def test_rels_with_all_options() -> None: "captionAlign": "top", "captionSize": 12, "color": "#ff0000", + "properties": {}, } @@ -34,7 +35,7 @@ def test_rels_minimal_rel() -> None: rel_dict = rel.to_dict() - assert {"id", "from", "to"} == set(rel_dict.keys()) + assert {"id", "from", "to", "properties"} == set(rel_dict.keys()) assert rel_dict["from"] == "1" assert rel_dict["to"] == "2" @@ -43,12 +44,12 @@ def test_rels_additional_fields() -> None: rel = Relationship( source="1", target="2", - componentId=2, + properties=dict(componentId=2), ) rel_dict = rel.to_dict() - assert {"id", "from", "to", "componentId"} == set(rel_dict.keys()) - assert rel.componentId == 2 # type: ignore[attr-defined] + assert {"id", "from", "to", "properties"} == set(rel_dict.keys()) + assert rel.properties["componentId"] == 2 # type: ignore[attr-defined] @pytest.mark.parametrize("src_alias", ["source", "sourceNodeId", "source_node_id", "from"]) @@ -63,6 +64,6 @@ def test_aliases(src_alias: str, trg_alias: str) -> None: rel_dict = rel.to_dict() - assert {"id", "from", "to"} == set(rel_dict.keys()) + assert {"id", "from", "to", "properties"} == set(rel_dict.keys()) assert rel_dict["from"] == "1" assert rel_dict["to"] == "2"