diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py index 0c575a0..15e131d 100644 --- a/spanner_graphs/cloud_database.py +++ b/spanner_graphs/cloud_database.py @@ -89,6 +89,8 @@ def _get_schema_for_graph(self, graph_query: str) -> Any | None: def execute_query( self, query: str, + params: Dict[str, Any] = None, + param_types: Dict[str, Any] = None, limit: int = None, is_test_query: bool = False, ) -> SpannerQueryResult: @@ -97,6 +99,8 @@ def execute_query( Args: query: The SQL query to execute against the database + params: A dictionary of query parameters + param_types: A dictionary of parameter types limit: An optional limit for the number of rows to return is_test_query: If true, skips schema fetching for graph queries. @@ -108,13 +112,16 @@ def execute_query( self.schema_json = self._get_schema_for_graph(query) with self.database.snapshot() as snapshot: - params = None param_types = None if limit and limit > 0: params = dict(limit=limit) try: - results = snapshot.execute_sql(query, params=params, param_types=param_types) + results = snapshot.execute_sql( + query, + params=params, + param_types=param_types + ) rows = list(results) except Exception as e: return SpannerQueryResult( diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index 91db0ac..3ad863a 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -26,6 +26,7 @@ from dataclasses import dataclass + class SpannerQueryResult(NamedTuple): # A dict where each key is a field name returned in the query and the list # contains all items of the same type found for the given field @@ -39,6 +40,7 @@ class SpannerQueryResult(NamedTuple): # The error message if any err: Exception | None + class SpannerDatabase(ABC): """The spanner class holding the database connection""" @@ -54,6 +56,7 @@ def _get_schema_for_graph(self, graph_query: str): def execute_query( self, query: str, + params: Dict[str, Any] = None, limit: int = None, is_test_query: bool = False, ) -> SpannerQueryResult: @@ -96,6 +99,7 @@ def _load_data(self): def __iter__(self): return iter(self._rows) + class MockSpannerDatabase(): """Mock database class""" @@ -110,6 +114,8 @@ def __init__(self): def execute_query( self, _: str, + params: Dict[str, Any] = None, + param_types: Dict[str, Any] = None, limit: int = 5 ) -> SpannerQueryResult: """Mock execution of query""" diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index cf318c3..6f0a309 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -22,10 +22,12 @@ import requests import portpicker import atexit +from datetime import datetime from spanner_graphs.conversion import get_nodes_edges from spanner_graphs.exec_env import get_database_instance from spanner_graphs.database import SpannerQueryResult +from google.cloud import spanner # Supported types for a property PROPERTY_TYPE_SET = { @@ -145,14 +147,18 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio return validated_properties, direction + def execute_node_expansion( params_str: str, - request: dict) -> dict: + request: dict +) -> dict: """Execute a node expansion query to find connected nodes and edges. Args: - params_str: A JSON string containing connection parameters (project, instance, database, graph, mock). - request: A dictionary containing node expansion request details (uid, node_labels, node_properties, direction, edge_label). + params_str: A JSON string containing connection parameters (project, + instance, database, graph, mock). + request: A dictionary containing node expansion request details (uid, + node_labels, node_properties, direction, edge_label). Returns: dict: A dictionary containing the query response with nodes and edges. @@ -182,20 +188,51 @@ def execute_node_expansion( if node_labels and len(node_labels) > 0: node_label_str = f": {' & '.join(node_labels)}" - node_property_strings: list[str] = [] - for node_property in node_properties: - value_str: str - if node_property.type_str in ('INT64', 'NUMERIC', 'FLOAT32', 'FLOAT64', 'BOOL'): - value_str = node_property.value - else: - value_str = f"\'''{node_property.value}\'''" - node_property_strings.append(f"n.{node_property.key}={value_str}") + node_property_clauses: list[str] = [] + params_dict: dict = {} + param_types_dict: dict = {} + + for i, node_property in enumerate(node_properties): + param_name = f"param_{i}" + node_property_clauses.append(f"n.{node_property.key} = @{param_name}") + + # Convert value to native Python type + type_str = node_property.type_str + value = node_property.value + + if type_str in ("INT64", "NUMERIC"): + value_casting = int(value) + param_type = spanner.param_types.INT64 + elif type_str in ("FLOAT32", "FLOAT64"): + value_casting = float(value) + param_type = spanner.param_types.FLOAT64 + elif type_str == "BOOL": + value_casting = value.lower() == "true" + param_type = spanner.param_types.BOOL + elif type_str == "STRING": + value_casting = str(value) + param_type = spanner.param_types.STRING + elif type_str == "DATE": + value_casting = datetime.strptime(value, "%Y-%m-%d").date() + param_type = spanner.param_types.DATE + elif type_str == "TIMESTAMP": + value_casting = datetime.fromisoformat(value.replace("Z", "+00:00")) + param_type = spanner.param_types.TIMESTAMP + + params_dict[param_name] = value_casting + param_types_dict[param_name] = param_type + + filtered_uid = "STRING(TO_JSON(n).identifier) = @uid" + params_dict["uid"] = str(uid) + param_types_dict["uid"] = spanner.param_types.STRING + + where_clauses = node_property_clauses + [filtered_uid] + where_clause_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" query = f""" GRAPH {graph} - LET uid = "{uid}" MATCH (n{node_label_str}) - WHERE {' and '.join(node_property_strings)} {'and' if node_property_strings else ''} STRING(TO_JSON(n).identifier) = uid + {where_clause_str} RETURN n NEXT @@ -204,7 +241,11 @@ def execute_node_expansion( RETURN TO_JSON(e) as e, TO_JSON(d) as d """ - return execute_query(project, instance, database, query, mock=False) + return execute_query( + project, instance, database, query, mock=False, + params=params_dict, param_types=param_types_dict + ) + def execute_query( project: str, @@ -212,6 +253,8 @@ def execute_query( database: str, query: str, mock: bool = False, + params: Dict[str, Any] = None, + param_types: Dict[str, Any] = None, ) -> Dict[str, Any]: """Executes a query against a database and formats the result. @@ -233,7 +276,11 @@ def execute_query( """ try: db_instance = get_database_instance(project, instance, database, mock) - result: SpannerQueryResult = db_instance.execute_query(query) + result: SpannerQueryResult = db_instance.execute_query( + query, + params=params, + param_types=param_types + ) if len(result.rows) == 0 and result.err: error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}" @@ -257,7 +304,8 @@ def execute_query( } # Process a successful query result - nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json) + nodes, edges = get_nodes_edges(result.data, result.fields, + result.schema_json) return { "response": { diff --git a/tests/graph_server_test.py b/tests/graph_server_test.py index 7b405e2..a4eca61 100644 --- a/tests/graph_server_test.py +++ b/tests/graph_server_test.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import patch, MagicMock import json +from google.cloud import spanner from spanner_graphs.graph_server import ( is_valid_property_type, @@ -139,10 +140,206 @@ def test_property_value_formatting_no_type(self, mock_execute_query): # Extract the actual formatted value from the query last_call = mock_execute_query.call_args[0] query = last_call[3] - where_line = [line for line in query.split('\n') if 'WHERE' in line][0] - expected_pattern = "n.test_property='''test_value'''" - self.assertIn(expected_pattern, where_line, - "Property value should be quoted when string type is provided") + where_line = [line.strip() for line in query.split('\n') if 'WHERE' in line][0] + + self.assertIn(f"n.{prop_dict['key']}", where_line, "Key not found in WHERE clause") + self.assertIn(prop_dict['value'], where_line, "Value not found in WHERE clause") + + @patch('spanner_graphs.graph_server.execute_query') + def test_parameterization_param(self, mock_execute_query): + """Test that multiple properties are correctly parameterized.""" + mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}} + + prop_dicts = [ + {"key": "age", "value": "25", "type": "INT64"}, + {"key": "name", "value": "John", "type": "STRING"}, + {"key": "active", "value": "true", "type": "BOOL"} + ] + + params = json.dumps({ + "project": "test-project", + "instance": "test-instance", + "database": "test-database", + "graph": "test-graph", + }) + + request = { + "uid": "test-uid", + "node_labels": ["Person"], + "node_properties": prop_dicts, + "direction": "OUTGOING" + } + + execute_node_expansion( + params_str=params, + request=request + ) + + mock_execute_query.call_args = ( + ("project", "instance", "database", "MATCH (n:Person) WHERE n.age = @param_0 AND n.name = @param_1 AND n.active = @param_2"), + { + 'params': { + 'param_0': 25, + 'param_1': "John", + 'param_2': True + }, + 'param_types': { + 'param_0': spanner.param_types.INT64, + 'param_1': spanner.param_types.STRING, + 'param_2': spanner.param_types.BOOL + } + } + ) + + call_args = mock_execute_query.call_args + query = call_args[0][3] + + if call_args[1] and call_args[1].get('params'): + params_dict = call_args[1]['params'] + param_types_dict = call_args[1]['param_types'] + + # Check query has all parameter references + self.assertIn("n.age = @param_0", query) + self.assertIn("n.name = @param_1", query) + self.assertIn("n.active = @param_2", query) + + self.assertEqual(params_dict['param_0'], 25) + self.assertEqual(params_dict['param_1'], "John") + self.assertEqual(params_dict['param_2'], True) + + # Check parameter types + self.assertEqual(param_types_dict['param_0'], spanner.param_types.INT64) + self.assertEqual(param_types_dict['param_1'], spanner.param_types.STRING) + self.assertEqual(param_types_dict['param_2'], spanner.param_types.BOOL) + + @patch('spanner_graphs.graph_server.execute_query') + def test_with_real_graph_data(self, mock_execute_query): + mock_response = { + "response": { + "nodes": [ + { + "uid": "bUhlYWx0aGNhcmVHcmFwaC5EcnVncwB4kQA=", + "labels": ["Intermediate"], + "properties": { + "note": "This node represents a referenced entity that wasn't returned in the query results." + } + }, + { + "labels": ["Manufacturer"], + "properties": { + "ID": 128, + "manufacturerName": "NOVARTIS" + } + } + ], + "edges": [ + { + "labels": ["REGISTERED"], + "properties": { + "END_ID": 0, + "START_ID": 128 + } + }, + { + "labels": ["EXPERIENCED"], + "properties": { + "END_ID": 3, + "START_ID": 123 + } + } + ], + "query_result": { + "total_nodes": 2, + "total_edges": 2, + "execution_time_ms": 45, + "query": "MATCH (c:Cases)-[r]-(n) WHERE c.primaryid = 100654764 RETURN n, r" + } + } + } + + mock_execute_query.return_value = mock_response + + params_str = json.dumps({ + "project": "test-project", + "instance": "test-instance", + "database": "test-database", + "graph": "HealthcareGraph", + }) + + request = { + "uid": "mUhlYWx0aGNhcmVHcmFwaC5DYXNlcwB4kQA=", + "node_labels": [ + "Cases" + ], + "node_properties": [ + { + "key": "age", + "value": 56, + "type": "FLOAT64" + }, + { + "key": "ageUnit", + "value": "YR", + "type": "STRING" + }, + { + "key": "eventDate", + "value": "2014-03-25", + "type": "DATE" + }, + { + "key": "gender", + "value": "F", + "type": "STRING" + }, + { + "key": "primaryid", + "value": 100654764, + "type": "FLOAT64" + }, + { + "key": "reportDate", + "value": "2021-08-27", + "type": "DATE" + }, + { + "key": "reporterOccupation", + "value": "Physician", + "type": "STRING" + } + ], + "direction": "INCOMING" + } + + result = execute_node_expansion(params_str, request) + + mock_execute_query.assert_called_once() + + self.assertIn("response", result) + self.assertIn("nodes", result["response"]) + self.assertIn("edges", result["response"]) + self.assertIn("query_result", result["response"]) + self.assertIsInstance(result["response"]["nodes"], list) + self.assertIsInstance(result["response"]["edges"], list) + + self.assertEqual(len(result["response"]["nodes"]), 2) + self.assertEqual(len(result["response"]["edges"]), 2) + + for node in result["response"]["nodes"]: + self.assertIn("labels", node) + self.assertIn("properties", node) + self.assertIsInstance(node["labels"], list) + self.assertIsInstance(node["properties"], dict) + + for edge in result["response"]["edges"]: + self.assertIn("labels", edge) + self.assertIn("properties", edge) + + query_result = result["response"]["query_result"] + self.assertIn("total_nodes", query_result) + self.assertIn("total_edges", query_result) + self.assertIn("execution_time_ms", query_result) + if __name__ == '__main__': unittest.main()