diff --git a/src_py/connection.py b/src_py/connection.py index 4fe6665..6022829 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -155,12 +155,84 @@ def _normalize_parameters_for_capi( normalized_params[key] = "".join(f"\\x{byte:02x}" for byte in binary) pattern = rf"(?i)(? str: + return rf"(?i)\bto_json\(\s*\${re.escape(key)}\s*\)" + + @staticmethod + def _is_json_serializable_parameter(value: Any) -> bool: + return value is None or isinstance(value, (bool, int, float, list, tuple, dict)) + + @classmethod + def _contains_unresolved_json_type(cls, value: Any) -> bool: + if value is None: + return True + if isinstance(value, (list, tuple)): + return len(value) == 0 or any( + cls._contains_unresolved_json_type(item) for item in value + ) + if isinstance(value, dict): + return len(value) == 0 or any( + cls._contains_unresolved_json_type(item) for item in value.values() + ) + return False + + @staticmethod + def _json_string_literal(value: str) -> str: + return "'" + value.replace("\\", "\\\\").replace("'", "\\u0027") + "'" + + def _normalize_parameters_for_pybind( + self, + query: str, + parameters: dict[str, Any], + ) -> tuple[str, dict[str, Any]]: + normalized_query = query + normalized_params = dict(parameters) + + for key, value in list(normalized_params.items()): + if not isinstance(key, str): + msg = f"Parameter name must be of type string but got {type(key)}" + raise RuntimeError(msg) # noqa: TRY004 + + pattern = self._to_json_parameter_pattern(key) + if re.search(pattern, normalized_query) is None: + continue + if isinstance(value, str): + json.loads(value) + json_value = value + elif self._is_json_serializable_parameter( + value + ) and self._contains_unresolved_json_type(value): + json_value = json.dumps(value, allow_nan=False) + else: + continue + json_expr = f"CAST({self._json_string_literal(json_value)} AS JSON)" + normalized_query = re.sub( + pattern, + lambda _, json_expr=json_expr: json_expr, + normalized_query, + ) + if re.search(rf"\${re.escape(key)}\b", normalized_query) is None: + normalized_params.pop(key, None) return normalized_query, normalized_params @@ -378,6 +450,7 @@ def _execute_with_pybind( if len(parameters) == 0: return py_connection.query(query) + query, parameters = self._normalize_parameters_for_pybind(query, parameters) prepared = py_connection.prepare(query, parameters) return py_connection.execute(prepared, parameters) diff --git a/test/test_json.py b/test/test_json.py index 2e57175..376bba3 100644 --- a/test/test_json.py +++ b/test/test_json.py @@ -60,3 +60,39 @@ def test_to_json_string_param_roundtrip(conn_db_empty: ConnDB) -> None: response_data = json.loads(response.rows_as_dict().get_all()[0]["meta"]) assert response_data == data + + +def test_to_json_python_param_with_empty_nested_list(conn_db_empty: ConnDB) -> None: + conn, _ = conn_db_empty + conn.execute(""" + CREATE NODE TABLE User (id SERIAL PRIMARY KEY, meta JSON); + """) + + data = {"tags": []} + + response = conn.execute( + """ + CREATE (n:User {meta: to_json($meta)}) + RETURN n.id as id, cast(n.meta AS STRING) as meta; + """, + parameters={"meta": data}, + ) + + response_data = json.loads(response.rows_as_dict().get_all()[0]["meta"]) + assert response_data == data + + +def test_to_json_python_param_with_homogeneous_list_uses_typed_binding( + conn_db_empty: ConnDB, +) -> None: + conn, _ = conn_db_empty + query = "CREATE (n:User {meta: to_json($meta)})" + parameters = {"meta": {"tags": [1, 2, 3]}} + + normalized_query, normalized_parameters = conn._normalize_parameters_for_pybind( + query, + parameters, + ) + + assert normalized_query == query + assert normalized_parameters == parameters