Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,16 @@ def _contains_unresolved_json_type(cls, value: Any) -> bool:
if value is None:
return True
if isinstance(value, (list, tuple)):
return len(value) == 0 or any(
cls._contains_unresolved_json_type(item) for item in value
if len(value) == 0:
return True
has_nested_child = any(
isinstance(item, (list, tuple, dict)) for item in value
)
has_scalar_child = any(
not isinstance(item, (list, tuple, dict)) for item in value
)
return any(cls._contains_unresolved_json_type(item) for item in value) or (
has_nested_child and has_scalar_child
)
if isinstance(value, dict):
return len(value) == 0 or any(
Expand Down
64 changes: 64 additions & 0 deletions test/test_json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import builtins
import json
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -84,6 +85,69 @@ def test_to_json_python_param_with_empty_nested_list(conn_db_empty: ConnDB) -> N
assert response_data == data


def test_to_json_python_param_with_mixed_nested_list(conn_db_empty: ConnDB) -> None:
conn, _ = conn_db_empty
conn.execute("""
INSTALL json;
LOAD json;
CREATE NODE TABLE User (id SERIAL PRIMARY KEY, meta JSON);
""")

data = {
"@context": [
"entry1",
"entry2",
{"key": "value"},
],
}

response = conn.execute(
"""
CREATE (n:User {meta: to_json($meta)})
RETURN n.id as id, cast(n.meta AS STRING) as meta;
""",
parameters={"meta": data},
)

response_data = json.loads(response.rows_as_dict().get_all()[0]["meta"])
assert response_data == data


def test_to_json_mixed_nested_list_normalization_does_not_import_numpy(
conn_db_empty: ConnDB,
monkeypatch,
) -> None:
conn, _ = conn_db_empty
query = "CREATE (n:User {meta: to_json($meta)})"
data = {"@context": ["entry1", "entry2", {"key": "value"}]}
parameters = {"meta": data}

real_import = builtins.__import__

def guarded_import(name, *args, **kwargs):
if name == "numpy" or name.startswith("numpy."):
msg = "JSON parameter normalization should not import NumPy"
raise AssertionError(msg)
return real_import(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", guarded_import)

normalized_query, normalized_parameters = conn._normalize_parameters_for_capi(
query,
parameters,
)
assert normalized_query == "CREATE (n:User {meta: $meta})"
assert normalized_parameters["meta"].value == json.dumps(data, allow_nan=False)

normalized_query, normalized_parameters = conn._normalize_parameters_for_pybind(
query,
parameters,
)
assert normalized_query.startswith("CREATE (n:User {meta: CAST(")
assert normalized_query.endswith(" AS JSON)})")
assert normalized_parameters == {}


def test_to_json_python_param_with_homogeneous_list_uses_typed_binding(
conn_db_empty: ConnDB,
) -> None:
Expand Down
Loading