diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py index 0c575a0..be57a11 100644 --- a/spanner_graphs/cloud_database.py +++ b/spanner_graphs/cloud_database.py @@ -91,6 +91,8 @@ def execute_query( query: str, limit: int = None, is_test_query: bool = False, + params: dict = None, + param_types: dict = None, ) -> SpannerQueryResult: """ This method executes the provided `query` @@ -99,6 +101,8 @@ def execute_query( query: The SQL query to execute against the database limit: An optional limit for the number of rows to return is_test_query: If true, skips schema fetching for graph queries. + params: A dictionary of query parameters. + param_types: A dictionary of query parameter types. Returns: A `SpannerQueryResult` @@ -108,10 +112,10 @@ def execute_query( self.schema_json = self._get_schema_for_graph(query) with self.database.snapshot() as snapshot: - params = None - param_types = None if limit and limit > 0: - params = dict(limit=limit) + if params is None: + params = {} + params["limit"] = limit try: results = snapshot.execute_sql(query, params=params, param_types=param_types) diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index cf318c3..662c337 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -23,6 +23,7 @@ import portpicker import atexit +from google.cloud import spanner from spanner_graphs.conversion import get_nodes_edges from spanner_graphs.exec_env import get_database_instance from spanner_graphs.database import SpannerQueryResult @@ -32,7 +33,6 @@ 'BOOL', 'BYTES', 'DATE', - 'ENUM', 'INT64', 'NUMERIC', 'FLOAT32', @@ -171,7 +171,6 @@ def execute_node_expansion( edge = "e" if not edge_label else f"e:{edge_label}" - # Build the path pattern based on direction path_pattern = ( f"(n)-[{edge}]->(d)" if direction == EdgeDirection.OUTGOING @@ -182,29 +181,35 @@ def execute_node_expansion( if node_labels and len(node_labels) > 0: node_label_str = f": {' & '.join(node_labels)}" + query_params = {"uid": uid} + param_types = {"uid": spanner.param_types.STRING} node_property_strings: list[str] = [] - for node_property in node_properties: - value_str: str - if node_property.type_str in ('INT64', 'NUMERIC', 'FLOAT32', 'FLOAT64', 'BOOL'): - value_str = node_property.value - else: - value_str = f"\'''{node_property.value}\'''" - node_property_strings.append(f"n.{node_property.key}={value_str}") + + for i, node_property in enumerate(node_properties): + param_name = f"p{i}" + node_property_strings.append(f"n.{node_property.key} = @{param_name}") + query_params[param_name] = node_property.value + param_types[param_name] = getattr(spanner.param_types, node_property.type_str) + + where_clause = " and ".join(node_property_strings) + if where_clause: + where_clause += " and " + where_clause += "STRING(TO_JSON(n).identifier) = @uid" query = f""" GRAPH {graph} - LET uid = "{uid}" MATCH (n{node_label_str}) - WHERE {' and '.join(node_property_strings)} {'and' if node_property_strings else ''} STRING(TO_JSON(n).identifier) = uid + WHERE {where_clause} RETURN n NEXT MATCH {path_pattern} + WHERE STRING(TO_JSON(n).identifier) = @uid RETURN TO_JSON(e) as e, TO_JSON(d) as d """ - return execute_query(project, instance, database, query, mock=False) + return execute_query(project, instance, database, query, mock=False, params=query_params, param_types=param_types) def execute_query( project: str, diff --git a/tests/cloud_database_test.py b/tests/cloud_database_test.py index dcb6c92..ebaff5d 100644 --- a/tests/cloud_database_test.py +++ b/tests/cloud_database_test.py @@ -28,9 +28,11 @@ class TestDatabase(unittest.TestCase): """Test cases for the CloudSpannerDatabase class""" + @patch("spanner_graphs.cloud_database._get_default_credentials_with_project") @patch("spanner_graphs.cloud_database.spanner.Client") - def test_execute_query(self, mock_client: MagicMock) -> None: + def test_execute_query(self, mock_client: MagicMock, mock_creds: MagicMock) -> None: """Test that a query is executed correctly""" + mock_creds.return_value = (MagicMock(), "test-project") mock_instance = MagicMock() mock_database = MagicMock() mock_snapshot = MagicMock() diff --git a/tests/graph_server_test.py b/tests/graph_server_test.py index 7b405e2..c6cd4aa 100644 --- a/tests/graph_server_test.py +++ b/tests/graph_server_test.py @@ -20,7 +20,6 @@ def test_validate_property_type_valid_types(self): 'DATE', 'TIMESTAMP', 'BYTES', - 'ENUM', # Test case insensitivity 'int64', 'string', @@ -54,25 +53,22 @@ def test_validate_property_type_invalid_types(self): self.assertIn("Invalid property type", str(cm.exception)) self.assertIn("Allowed types are:", str(cm.exception)) + @patch('spanner_graphs.graph_server.spanner') @patch('spanner_graphs.graph_server.execute_query') - def test_property_value_formatting(self, mock_execute_query): + def test_property_value_formatting(self, mock_execute_query, mock_spanner): """Test that property values are correctly formatted based on their type.""" mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}} test_cases = [ - # Numeric types (unquoted) - ("INT64", "123", "123"), - ("NUMERIC", "123", "123"), - ("FLOAT32", "123.45", "123.45"), - ("FLOAT64", "123.45", "123.45"), - # Boolean (unquoted) - ("BOOL", "true", "true"), - # String types (quoted) - ("STRING", "hello", "'''hello'''"), - ("DATE", "2024-03-14", "'''2024-03-14'''"), - ("TIMESTAMP", "2024-03-14T12:00:00Z", "'''2024-03-14T12:00:00Z'''"), - ("BYTES", "base64data", "'''base64data'''"), - ("ENUM", "ENUM_VALUE", "'''ENUM_VALUE'''"), + ("INT64", "123"), + ("NUMERIC", "123"), + ("FLOAT32", "123.45"), + ("FLOAT64", "123.45"), + ("BOOL", "true"), + ("STRING", "hello"), + ("DATE", "2024-03-14"), + ("TIMESTAMP", "2024-03-14T12:00:00Z"), + ("BYTES", "base64data"), ] params = json.dumps({ @@ -82,9 +78,8 @@ def test_property_value_formatting(self, mock_execute_query): "graph": "test-graph", }) - for type_str, value, expected_format in test_cases: + for type_str, value in test_cases: with self.subTest(type=type_str, value=value): - # Create a property dictionary prop_dict = {"key": "test_property", "value": value, "type": type_str} request = { @@ -99,22 +94,69 @@ def test_property_value_formatting(self, mock_execute_query): request=request ) - # Extract the actual formatted value from the query - last_call = mock_execute_query.call_args[0] # Get the positional args - query = last_call[3] # The query is the 4th positional arg + last_call = mock_execute_query.call_args + query_params = last_call[1]['params'] + param_types = last_call[1]['param_types'] - # Find the WHERE clause in the query and extract the value - where_line = [line for line in query.split('\n') if 'WHERE' in line][0] - expected_pattern = f"n.test_property={expected_format}" - self.assertIn(expected_pattern, where_line, - f"Expected property value pattern {expected_pattern} not found in WHERE clause for type {type_str}") + query = last_call[0][3] + where_clause = [line.strip() for line in query.split('\n') if 'WHERE' in line][0] + expected_where = "WHERE n.test_property = @p0 and STRING(TO_JSON(n).identifier) = @uid" + self.assertEqual(where_clause, expected_where) + + self.assertEqual(query_params['p0'], value) + self.assertEqual(param_types['p0'], getattr(mock_spanner.param_types, type_str)) + + + @patch('spanner_graphs.graph_server.spanner') + @patch('spanner_graphs.graph_server.execute_query') + def test_property_value_formatting_multiple_properties(self, mock_execute_query, mock_spanner): + """Test that multiple property values are correctly formatted.""" + mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}} + + props = [ + {"key": "age", "value": 30, "type": "INT64"}, + {"key": "name", "value": "John", "type": "STRING"}, + ] + + params = json.dumps({ + "project": "test-project", + "instance": "test-instance", + "database": "test-database", + "graph": "test-graph", + }) + + request = { + "uid": "test-uid", + "node_labels": ["Person"], + "node_properties": props, + "direction": "OUTGOING" + } + + execute_node_expansion( + params_str=params, + request=request + ) + + last_call = mock_execute_query.call_args + query_params = last_call[1]['params'] + param_types = last_call[1]['param_types'] + + query = last_call[0][3] + where_clause = [line.strip() for line in query.split('\n') if 'WHERE' in line][0] + expected_where = "WHERE n.age = @p0 and n.name = @p1 and STRING(TO_JSON(n).identifier) = @uid" + self.assertEqual(where_clause, expected_where) + + self.assertEqual(query_params['p0'], 30) + self.assertEqual(param_types['p0'], mock_spanner.param_types.INT64) + + self.assertEqual(query_params['p1'], "John") + self.assertEqual(param_types['p1'], mock_spanner.param_types.STRING) @patch('spanner_graphs.graph_server.execute_query') def test_property_value_formatting_no_type(self, mock_execute_query): """Test that property values are quoted when no type is provided.""" mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}} - # Create a property dictionary with string type (since null type is not allowed) prop_dict = {"key": "test_property", "value": "test_value", "type": "STRING"} params = json.dumps({ @@ -136,13 +178,13 @@ def test_property_value_formatting_no_type(self, mock_execute_query): request=request ) - # Extract the actual formatted value from the query - last_call = mock_execute_query.call_args[0] - query = last_call[3] - where_line = [line for line in query.split('\n') if 'WHERE' in line][0] - expected_pattern = "n.test_property='''test_value'''" - self.assertIn(expected_pattern, where_line, - "Property value should be quoted when string type is provided") + last_call = mock_execute_query.call_args + query_params = last_call[1]['params'] + param_types = last_call[1]['param_types'] + + self.assertIn('@p0', last_call[0][3]) + self.assertEqual(query_params['p0'], "test_value") + self.assertIsNotNone(param_types['p0']) if __name__ == '__main__': unittest.main()