From 05defacea3fcd642d11e091e88614f463bdd25fa Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Sun, 17 May 2026 17:48:08 -0700 Subject: [PATCH] Handle unresolved JSON parameter types Serialize to_json parameters only when Python JSON values contain unresolved type shapes such as empty lists, empty dicts, or nulls. This avoids LIST failures while keeping homogeneous arrays on the normal typed binding path. --- src_py/connection.py | 79 ++++++++++++++++++++++++++++++++++++++++++-- test/test_json.py | 36 ++++++++++++++++++++ 2 files changed, 112 insertions(+), 3 deletions(-) diff --git a/src_py/connection.py b/src_py/connection.py index 4fe6665..6022829 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -155,12 +155,84 @@ def _normalize_parameters_for_capi( normalized_params[key] = "".join(f"\\x{byte:02x}" for byte in binary) pattern = rf"(?i)(? str: + return rf"(?i)\bto_json\(\s*\${re.escape(key)}\s*\)" + + @staticmethod + def _is_json_serializable_parameter(value: Any) -> bool: + return value is None or isinstance(value, (bool, int, float, list, tuple, dict)) + + @classmethod + def _contains_unresolved_json_type(cls, value: Any) -> bool: + if value is None: + return True + if isinstance(value, (list, tuple)): + return len(value) == 0 or any( + cls._contains_unresolved_json_type(item) for item in value + ) + if isinstance(value, dict): + return len(value) == 0 or any( + cls._contains_unresolved_json_type(item) for item in value.values() + ) + return False + + @staticmethod + def _json_string_literal(value: str) -> str: + return "'" + value.replace("\\", "\\\\").replace("'", "\\u0027") + "'" + + def _normalize_parameters_for_pybind( + self, + query: str, + parameters: dict[str, Any], + ) -> tuple[str, dict[str, Any]]: + normalized_query = query + normalized_params = dict(parameters) + + for key, value in list(normalized_params.items()): + if not isinstance(key, str): + msg = f"Parameter name must be of type string but got {type(key)}" + raise RuntimeError(msg) # noqa: TRY004 + + pattern = self._to_json_parameter_pattern(key) + if re.search(pattern, normalized_query) is None: + continue + if isinstance(value, str): + json.loads(value) + json_value = value + elif self._is_json_serializable_parameter( + value + ) and self._contains_unresolved_json_type(value): + json_value = json.dumps(value, allow_nan=False) + else: + continue + json_expr = f"CAST({self._json_string_literal(json_value)} AS JSON)" + normalized_query = re.sub( + pattern, + lambda _, json_expr=json_expr: json_expr, + normalized_query, + ) + if re.search(rf"\${re.escape(key)}\b", normalized_query) is None: + normalized_params.pop(key, None) return normalized_query, normalized_params @@ -378,6 +450,7 @@ def _execute_with_pybind( if len(parameters) == 0: return py_connection.query(query) + query, parameters = self._normalize_parameters_for_pybind(query, parameters) prepared = py_connection.prepare(query, parameters) return py_connection.execute(prepared, parameters) diff --git a/test/test_json.py b/test/test_json.py index 2e57175..376bba3 100644 --- a/test/test_json.py +++ b/test/test_json.py @@ -60,3 +60,39 @@ def test_to_json_string_param_roundtrip(conn_db_empty: ConnDB) -> None: response_data = json.loads(response.rows_as_dict().get_all()[0]["meta"]) assert response_data == data + + +def test_to_json_python_param_with_empty_nested_list(conn_db_empty: ConnDB) -> None: + conn, _ = conn_db_empty + conn.execute(""" + CREATE NODE TABLE User (id SERIAL PRIMARY KEY, meta JSON); + """) + + data = {"tags": []} + + response = conn.execute( + """ + CREATE (n:User {meta: to_json($meta)}) + RETURN n.id as id, cast(n.meta AS STRING) as meta; + """, + parameters={"meta": data}, + ) + + response_data = json.loads(response.rows_as_dict().get_all()[0]["meta"]) + assert response_data == data + + +def test_to_json_python_param_with_homogeneous_list_uses_typed_binding( + conn_db_empty: ConnDB, +) -> None: + conn, _ = conn_db_empty + query = "CREATE (n:User {meta: to_json($meta)})" + parameters = {"meta": {"tags": [1, 2, 3]}} + + normalized_query, normalized_parameters = conn._normalize_parameters_for_pybind( + query, + parameters, + ) + + assert normalized_query == query + assert normalized_parameters == parameters