From 39635ea11a32529f9b3e7140b95e0e01d78c5d13 Mon Sep 17 00:00:00 2001 From: Sailesh Mukil Date: Fri, 6 Jun 2025 16:42:16 +0000 Subject: [PATCH 1/3] Standardize return type for execute_query() functions - Also removes cloud-spanner specific fields from the return type. Specifically `StructType.Field` is removed from the return type. Removing this tightly coupled logic is required to allow new DB implementations. --- spanner_graphs/database.py | 46 +++++++++++------ spanner_graphs/graph_server.py | 91 +++++++++++++++++++++------------- tests/conversion_test.py | 4 +- 3 files changed, 90 insertions(+), 51 deletions(-) diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index b879d73..d3e06fc 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -18,7 +18,7 @@ """ from __future__ import annotations -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, NamedTuple import json import os import csv @@ -29,6 +29,19 @@ from google.cloud.spanner_v1.types import StructType, TypeCode, Type import pydata_google_auth +class SpannerQueryResult(NamedTuple): + # A dict where each key is a field name returned in the query and the list + # contains all items of the same type found for the given field + data: Dict[str, List[Any]] + # A list representing the fields in the result set. + fields: List[Any] + # A list of rows as returned by the query execution. + rows: List[Any] + # An optional field to return the schema as JSON + schema_json: Any | None + # The error message if any + error: Exception | None + def _get_default_credentials_with_project(): return pydata_google_auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) @@ -87,7 +100,7 @@ def execute_query( query: str, limit: int = None, is_test_query: bool = False, - ): + ) -> SpannerQueryResult: """ This method executes the provided `query` @@ -96,13 +109,7 @@ def execute_query( limit: An optional limit for the number of rows to return Returns: - A tuple containing: - - Dict[str, List[Any]]: A dict where each key is a field name - returned in the query and the list contains all items of the same - type found for the given field. - - A list of StructType.Fields representing the fields in the result set. - - A list of rows as returned by the query execution. - - The error message if any. + A SpannerQueryResult tuple """ self.schema_json = None if not is_test_query: @@ -131,9 +138,14 @@ def execute_query( data[field.name].append(json.loads(value.serialize())) else: data[field.name].append(value) + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) - return data, fields, rows, self.schema_json, None - class MockSpannerResult: def __init__(self, file_path: str): @@ -180,7 +192,7 @@ def execute_query( self, _: str, limit: int = 5 - ) -> Tuple[Dict[str, List[Any]], List[StructType.Field], List, str]: + ) -> SpannerQueryResult: """Mock execution of query""" # Before the actual query we fetch the schema as well @@ -201,7 +213,13 @@ def execute_query( for field, value in zip(fields, row): data[field.name].append(value) - return data, fields, rows, self.schema_json, None + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = { @@ -221,4 +239,4 @@ def get_database_instance(project: str, instance: str, database: str, mock = Fal db = SpannerDatabase(project, instance, database) database_instances[key] = db - return db \ No newline at end of file + return db diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index 7860c49..fe34d36 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -17,7 +17,7 @@ import json import threading from enum import Enum -from typing import Union +from typing import Union, Dict, Any import requests import portpicker @@ -206,49 +206,70 @@ def execute_node_expansion( return execute_query(project, instance, database, query, mock=False) -def execute_query(project: str, instance: str, database: str, query: str, mock = False): - database = get_database_instance(project, instance, database, mock) +def execute_query( + project: str, + instance: str, + database: str, + query: str, + mock: bool = False, +) -> Dict[str, Any]: + """Executes a query against a database and formats the result. + Connects to a database, runs the query, and processes the resulting object. + On success, it formats the data into nodes and edges for graph visualization. + If the query fails, it returns a detailed error message, optionally + including the database schema to aid in debugging. + + Args: + project: The cloud project ID. + instance: The database instance name. + database: The database name. + query: The query string to execute. + mock: If True, use a mock database instance for testing. Defaults to False. + + Returns: + A dictionary containing either the structured 'response' with nodes, + edges, and other data, or an 'error' key with a descriptive message. + """ try: - query_result, fields, rows, schema_json, err = database.execute_query(query) - if len(rows) == 0 and err : # if query returned an error - if schema_json: # if the schema exists - return { - "response": { - "schema": schema_json, - "query_result": query_result, - "nodes": [], - "edges": [], - "rows": [] - }, - "error": f"We've detected an error in your query. To help you troubleshoot, the graph schema's layout is shown above." + "\n\n" + f"Query error: \n{getattr(err, 'message', str(err))}" - } - if not schema_json: # if the schema does not exist - return { - "response": { - "schema": schema_json, - "query_result": query_result, - "nodes": [], - "edges": [], - "rows": [] - }, - "error": f"Query error: \n{getattr(err, 'message', str(err))}" - } - nodes, edges = get_nodes_edges(query_result, fields, schema_json) - + db_instance = get_database_instance(project, instance, database, mock) + result: SpannerQueryResult = db_instance.execute_query(query) + + if len(result.rows) == 0 and result.err: + error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}" + if result.schema_json: + # Prepend a helpful message if the schema is available + error_message = ( + "We've detected an error in your query. To help you troubleshoot, " + "the graph schema's layout is shown above.\n\n" + error_message + ) + + # Consolidate the repetitive error response into a single return + return { + "response": { + "schema": result.schema_json, + "query_result": result.data, + "nodes": [], + "edges": [], + "rows": [], + }, + "error": error_message, + } + + # Process a successful query result + nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json) + return { "response": { "nodes": [node.to_json() for node in nodes], "edges": [edge.to_json() for edge in edges], - "schema": schema_json, - "rows": rows, - "query_result": query_result + "schema": result.schema_json, + "rows": result.rows, + "query_result": result.data, } } except Exception as e: - return { - "error": getattr(e, "message", str(e)) - } + return {"error": getattr(e, "message", str(e))} class GraphServer: diff --git a/tests/conversion_test.py b/tests/conversion_test.py index e53c56c..b014d3e 100644 --- a/tests/conversion_test.py +++ b/tests/conversion_test.py @@ -38,10 +38,10 @@ def test_get_nodes_edges(self) -> None: """ # Get data from mock database mock_db = MockSpannerDatabase() - data, fields, _, schema_json = mock_db.execute_query("") + query_result = mock_db.execute_query("") # Convert data to nodes and edges - nodes, edges = get_nodes_edges(data, fields) + nodes, edges = get_nodes_edges(query_result.data, query_result.fields) # Verify we got some nodes and edges self.assertTrue(len(nodes) > 0, "Should have at least one node") From 402b9d3aefc66c61df07a39618ce3566592e55e1 Mon Sep 17 00:00:00 2001 From: Sailesh Mukil Date: Mon, 9 Jun 2025 18:48:07 +0000 Subject: [PATCH 2/3] Abstract away "database" implementations and remove strong coupling of DB implementation with APIs 1. Abstracts SpannerDatabase with clear APIs 2. Introduces CloudSpannerDatabase as an implementation of SpannerDatabase 3. Removes further tight coupling with the cloud spanner client by adding a SpannerFieldInfo dataclass to replace usage of StructType.Field --- spanner_graphs/cloud_database.py | 143 ++++++++++++++ spanner_graphs/conversion.py | 11 +- spanner_graphs/database.py | 182 ++++++------------ spanner_graphs/graph_server.py | 20 +- spanner_graphs/magics.py | 4 +- ...atabase_test.py => cloud_database_test.py} | 14 +- tests/conversion_test.py | 52 ++--- 7 files changed, 251 insertions(+), 175 deletions(-) create mode 100644 spanner_graphs/cloud_database.py rename tests/{database_test.py => cloud_database_test.py} (81%) diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py new file mode 100644 index 0000000..3ce3d26 --- /dev/null +++ b/spanner_graphs/cloud_database.py @@ -0,0 +1,143 @@ +# Copyright 2024 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains the cloud-specific implementation for talking to a Spanner database. +""" + +from __future__ import annotations +import json +from typing import Any, Dict, List, Tuple + +from google.cloud import spanner +from google.cloud.spanner_v1 import JsonObject +from google.api_core.client_options import ClientOptions +from google.cloud.spanner_v1.types import StructType, Type +import pydata_google_auth + +from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo, get_as_field_info_list + +def _get_default_credentials_with_project(): + return pydata_google_auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) + +class CloudSpannerDatabase(SpannerDatabase): + """Concrete implementation for Spanner database on the cloud.""" + def __init__(self, project_id: str, instance_id: str, + database_id: str) -> None: + credentials, _ = _get_default_credentials_with_project() + self.client = spanner.Client( + project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id)) + self.instance = self.client.instance(instance_id) + self.database = self.instance.database(database_id) + self.schema_json: Any | None = None + + def __repr__(self) -> str: + return (f"") + + def _extract_graph_name(self, query: str) -> str: + words = query.strip().split() + if len(words) < 3: + raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)") + + if words[0].upper() != "GRAPH": + raise ValueError("invalid query: GRAPH must be the first word") + + return words[1] + + def _get_schema_for_graph(self, graph_query: str) -> Any | None: + try: + graph_name = self._extract_graph_name(graph_query) + except ValueError: + return None + + with self.database.snapshot() as snapshot: + schema_query = """ + SELECT property_graph_name, property_graph_metadata_json + FROM information_schema.property_graphs + WHERE property_graph_name = @graph_name + """ + params = {"graph_name": graph_name} + param_type = {"graph_name": spanner.param_types.STRING} + + result = snapshot.execute_sql(schema_query, params=params, param_types=param_type) + schema_rows = list(result) + + if schema_rows: + return schema_rows[0][1] + else: + return None + + def execute_query( + self, + query: str, + limit: int = None, + is_test_query: bool = False, + ) -> SpannerQueryResult: + """ + This method executes the provided `query` + + Args: + query: The SQL query to execute against the database + limit: An optional limit for the number of rows to return + is_test_query: If true, skips schema fetching for graph queries. + + Returns: + A `SpannerQueryResult` + """ + self.schema_json = None + if not is_test_query: + self.schema_json = self._get_schema_for_graph(query) + + with self.database.snapshot() as snapshot: + params = None + param_types = None + if limit and limit > 0: + params = dict(limit=limit) + + try: + results = snapshot.execute_sql(query, params=params, param_types=param_types) + rows = list(results) + except Exception as e: + return {}, [], [], self.schema_json, e + + fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields) + data = {field.name: [] for field in fields} + + if len(fields) == 0: + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) + + for row_data in rows: + for field, value in zip(fields, row_data): + if isinstance(value, JsonObject): + data[field.name].append(json.loads(value.serialize())) + else: + data[field.name].append(value) + + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) diff --git a/spanner_graphs/conversion.py b/spanner_graphs/conversion.py index 21eef39..52d385d 100644 --- a/spanner_graphs/conversion.py +++ b/spanner_graphs/conversion.py @@ -23,10 +23,11 @@ from google.cloud.spanner_v1.types import TypeCode, StructType +from spanner_graphs.database import SpannerFieldInfo from spanner_graphs.graph_entities import Node, Edge from spanner_graphs.schema_manager import SchemaManager -def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]: +def get_nodes_edges(data: Dict[str, List[Any]], fields: List[SpannerFieldInfo], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]: schema_manager = SchemaManager(schema_json) nodes: List[Node] = [] edges: List[Edge] = [] @@ -37,15 +38,15 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], for field in fields: column_name = field.name column_data = data[column_name] - + # Only process JSON and Array of JSON types - if field.type_.code not in [TypeCode.JSON, TypeCode.ARRAY]: + if field.typename not in ["JSON", "ARRAY"]: continue # Process each value in the column for value in column_data: items_to_process = [] - + # Handle both single JSON and arrays of JSON if isinstance(value, list): items_to_process.extend(value) @@ -92,4 +93,4 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], nodes.append(Node.make_intermediate(identifier)) node_identifiers.add(identifier) - return nodes, edges \ No newline at end of file + return nodes, edges diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index d3e06fc..2b2a285 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -18,11 +18,13 @@ """ from __future__ import annotations +from abc import ABC, abstractmethod from typing import Any, Dict, List, Tuple, NamedTuple import json import os import csv +from dataclasses import dataclass from google.cloud import spanner from google.cloud.spanner_v1 import JsonObject from google.api_core.client_options import ClientOptions @@ -34,7 +36,7 @@ class SpannerQueryResult(NamedTuple): # contains all items of the same type found for the given field data: Dict[str, List[Any]] # A list representing the fields in the result set. - fields: List[Any] + fields: List[SpannerFieldInfo] # A list of rows as returned by the query execution. rows: List[Any] # An optional field to return the schema as JSON @@ -42,115 +44,60 @@ class SpannerQueryResult(NamedTuple): # The error message if any error: Exception | None -def _get_default_credentials_with_project(): - return pydata_google_auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) - -class SpannerDatabase: +class SpannerDatabase(ABC): """The spanner class holding the database connection""" - def __init__(self, project_id: str, instance_id: str, - database_id: str) -> None: - credentials, _ = _get_default_credentials_with_project() - self.client = spanner.Client( - project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id)) - self.instance = self.client.instance(instance_id) - self.database = self.instance.database(database_id) - - def __repr__(self) -> str: - return (f" str: - words = query.strip().split() - if len(words) < 3: - raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)") - - if words[0].upper() != "GRAPH": - raise ValueError("invalid query: GRAPH must be the first word") - - return words[1] + pass + @abstractmethod def _get_schema_for_graph(self, graph_query: str): - try: - graph_name = self._extract_graph_name(graph_query) - except ValueError as e: - return None - - with self.database.snapshot() as snapshot: - schema_query = """ - SELECT property_graph_name, property_graph_metadata_json - FROM information_schema.property_graphs - WHERE property_graph_name = @graph_name - """ - params = {"graph_name": graph_name} - param_type = {"graph_name": spanner.param_types.STRING} - - result = snapshot.execute_sql(schema_query, params=params, param_types=param_type) - schema_rows = list(result) - - if schema_rows: - return schema_rows[0][1] - else: - return None + pass + @abstractmethod def execute_query( self, query: str, limit: int = None, is_test_query: bool = False, ) -> SpannerQueryResult: - """ - This method executes the provided `query` - - Args: - query: The SQL query to execute against the database - limit: An optional limit for the number of rows to return - - Returns: - A SpannerQueryResult tuple - """ - self.schema_json = None - if not is_test_query: - self.schema_json = self._get_schema_for_graph(query) - - with self.database.snapshot() as snapshot: - params = None - if limit and limit > 0: - params = dict(limit=limit) - try : - results = snapshot.execute_sql(query, params=params) - rows = list(results) - except Exception as e: - return {},[],[], self.schema_json, e - fields: List[StructType.Field] = results.fields - - data = {field.name: [] for field in fields} - - if len(fields) == 0: - return data, fields, rows - - for row in rows: - for field, value in zip(fields, row): - if isinstance(value, JsonObject): - # Handle JSON objects by properly deserializing them back into Python objects - data[field.name].append(json.loads(value.serialize())) - else: - data[field.name].append(value) - return SpannerQueryResult( - data=data, - fields=fields, - rows=rows, - schema_json=self.schema_json, - error=None - ) + pass + +# Represents the name and type of a field in a Spanner query result. (Implementation-agnostic) +@dataclass +class SpannerFieldInfo: + name: str + typename: str + +def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]: + """Converts a list of StructType.Field to a list of SpannerFieldInfo.""" + return [SpannerFieldInfo(name=field.name, typename=field.type_.code.name) for field in fields] + + +# Global dict of database instances created in a single session +database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = {} + +def get_database_instance(project: str, instance: str, database: str, mock = False) -> SpannerDatabase: + if mock: + return MockSpannerDatabase() + + key = f"{project}_{instance}_{database}" + db = database_instances.get(key) + + # Currently, we only create and return CloudSpannerDatabase instances. In the future, different + # implementations could be introduced. + if not db: + db = CloudSpannerDatabase(project, instance, database) + database_instances[key] = db + + return db class MockSpannerResult: def __init__(self, file_path: str): self.file_path = file_path - self.fields: List[StructType] = [] + self.fields: List[SpannerFieldInfo] = [] self._rows: List[List[Any]] = [] self._load_data() @@ -159,7 +106,7 @@ def _load_data(self): csv_reader = csv.reader(csvfile) headers = next(csv_reader) self.fields = [ - StructType.Field(name=header, type_=Type(code=TypeCode.JSON)) + SpannerFieldInfo(name=header, typename="JSON") for header in headers ] @@ -176,8 +123,7 @@ def _load_data(self): def __iter__(self): return iter(self._rows) - -class MockSpannerDatabase: +class MockSpannerDatabase(): """Mock database class""" def __init__(self): @@ -195,17 +141,23 @@ def execute_query( ) -> SpannerQueryResult: """Mock execution of query""" - # Before the actual query we fetch the schema as well + # Fetch the schema with open(self.schema_json_path, "r", encoding="utf-8") as js: self.schema_json = json.load(js) results = MockSpannerResult(self.graph_csv_path) - fields: List[StructType.Field] = results.fields + fields: List[SpannerFieldInfo] = results.fields rows = list(results) data = {field.name: [] for field in fields} if len(fields) == 0: - return data, fields, rows + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) for i, row in enumerate(results): if limit is not None and i >= limit: @@ -214,29 +166,9 @@ def execute_query( data[field.name].append(value) return SpannerQueryResult( - data=data, - fields=fields, - rows=rows, - schema_json=self.schema_json, - error=None - ) - - -database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = { - # "project_instance_database": SpannerDatabase -} - - -def get_database_instance(project: str, instance: str, database: str, mock = False): - if mock: - return MockSpannerDatabase() - - key = f"{project}_{instance}_{database}" - - db = database_instances.get(key, None) - if not db: - # Now create and insert it. - db = SpannerDatabase(project, instance, database) - database_instances[key] = db - - return db + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index fe34d36..6c7e347 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -56,27 +56,27 @@ class EdgeDirection(Enum): def validate_property_type(property_type: str) -> TypeCode: """ Validates and converts a property type string to a Spanner TypeCode. - + Args: property_type: The property type string from the request - + Returns: The corresponding TypeCode enum value - + Raises: ValueError: If the property type is invalid """ if not property_type: raise ValueError("Property type must be provided") - + # Convert to uppercase for case-insensitive comparison property_type = property_type.upper() - + # Check if the type is valid if property_type not in PROPERTY_TYPE_MAP: valid_types = ', '.join(sorted(PROPERTY_TYPE_MAP.keys())) raise ValueError(f"Invalid property type: {property_type}. Allowed types are: {valid_types}") - + return PROPERTY_TYPE_MAP[property_type] def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection): @@ -149,11 +149,11 @@ def execute_node_expansion( params_str: str, request: dict) -> dict: """Execute a node expansion query to find connected nodes and edges. - + Args: params_str: A JSON string containing connection parameters (project, instance, database, graph, mock). request: A dictionary containing node expansion request details (uid, node_labels, node_properties, direction, edge_label). - + Returns: dict: A dictionary containing the query response with nodes and edges. """ @@ -381,7 +381,7 @@ def handle_post_ping(self): def handle_post_query(self): data = self.parse_post_data() params = json.loads(data["params"]) - response = execute_query( + response = execute_query( project=params["project"], instance=params["instance"], database=params["database"], @@ -392,7 +392,7 @@ def handle_post_query(self): def handle_post_node_expansion(self): """Handle POST requests for node expansion. - + Expects a JSON payload with: - params: A JSON string containing connection parameters (project, instance, database, graph) - request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label) diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index 93c3d50..b21eced 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -94,7 +94,7 @@ def receive_query_request(query: str, params: str): def receive_node_expansion_request(request: dict, params_str: str): """Handle node expansion requests in Google Colab environment - + Args: request: A dictionary containing node expansion details including: - uid: str - Unique identifier of the node to expand @@ -108,7 +108,7 @@ def receive_node_expansion_request(request: dict, params_str: str): - database: str - Spanner database ID - graph: str - Graph name - mock: bool - Whether to use mock data - + Returns: JSON: A JSON-serialized response containing either: - The query results with nodes and edges diff --git a/tests/database_test.py b/tests/cloud_database_test.py similarity index 81% rename from tests/database_test.py rename to tests/cloud_database_test.py index ac39a76..dcb6c92 100644 --- a/tests/database_test.py +++ b/tests/cloud_database_test.py @@ -22,13 +22,13 @@ from google.cloud.spanner_v1.types import Type, TypeCode, StructType -from spanner_graphs.database import SpannerDatabase - +from spanner_graphs.cloud_database import CloudSpannerDatabase +from spanner_graphs.database import SpannerFieldInfo class TestDatabase(unittest.TestCase): - """Test cases for the SpannerDatabase class""" + """Test cases for the CloudSpannerDatabase class""" - @patch("spanner_graphs.database.spanner.Client") + @patch("spanner_graphs.cloud_database.spanner.Client") def test_execute_query(self, mock_client: MagicMock) -> None: """Test that a query is executed correctly""" mock_instance = MagicMock() @@ -51,11 +51,11 @@ def test_execute_query(self, mock_client: MagicMock) -> None: StructType.Field(name="field3", type_=Type(code=TypeCode.JSON)), ] - db = SpannerDatabase("test_project", "test_instance", "test_database") + db = CloudSpannerDatabase("test_project", "test_instance", "test_database") result = db.execute_query("SELECT * FROM test", is_test_query=True) - self.assertEqual(result[0]["field1"], ['{"key": "value1"}']) - self.assertEqual(result[1][0].name, "field1") + self.assertEqual(result.data["field1"], ['{"key": "value1"}']) + self.assertEqual(result.fields[0].name, "field1") if __name__ == "__main__": diff --git a/tests/conversion_test.py b/tests/conversion_test.py index b014d3e..8ec9e7f 100644 --- a/tests/conversion_test.py +++ b/tests/conversion_test.py @@ -23,7 +23,7 @@ from google.cloud.spanner_v1.types import StructType, Type, TypeCode from spanner_graphs.conversion import get_nodes_edges -from spanner_graphs.database import MockSpannerDatabase +from spanner_graphs.database import SpannerFieldInfo, MockSpannerDatabase class TestConversion(unittest.TestCase): @@ -72,7 +72,7 @@ def test_get_nodes_edges(self) -> None: self.assertTrue(hasattr(edge, 'destination'), "Edge should have a destination") self.assertIsInstance(edge.labels, list, "Edge labels should be a list") self.assertIsInstance(edge.properties, dict, "Edge properties should be a dict") - + # Verify edge endpoints exist in nodes source_exists = any(node.identifier == edge.source for node in nodes) dest_exists = any(node.identifier == edge.destination for node in nodes) @@ -94,32 +94,32 @@ def test_get_nodes_edges_with_missing_nodes(self) -> None: }), json.dumps({ "kind": "node", - "identifier": "node1", + "identifier": "node1", "labels": ["Device"], "properties": {"name": "Router"} }) # Note: node2 is intentionally missing ] } - + # Create a mock field for the column - field = StructType.Field( + field = SpannerFieldInfo( name="column1", - type_=Type(code=TypeCode.JSON) + typename="JSON" ) - + # Convert data to nodes and edges nodes, edges = get_nodes_edges(data, [field]) - + # Verify we got the expected number of nodes and edges self.assertEqual(len(edges), 1, "Should have one edge") self.assertEqual(len(nodes), 2, "Should have two nodes (one real, one intermediate)") - + # Verify node identifiers node_ids = {node.identifier for node in nodes} self.assertIn("node1", node_ids, "Original node should exist") self.assertIn("node2", node_ids, "Missing node should be created as intermediate") - + # Find the intermediate node intermediate_node = next((node for node in nodes if node.identifier == "node2"), None) self.assertIsNotNone(intermediate_node, "Intermediate node should exist") @@ -150,33 +150,33 @@ def test_get_nodes_edges_with_multiple_references(self) -> None: }), json.dumps({ "kind": "node", - "identifier": "node1", + "identifier": "node1", "labels": ["Device"], "properties": {"name": "Router"} }), json.dumps({ "kind": "node", - "identifier": "node2", + "identifier": "node2", "labels": ["Device"], "properties": {"name": "Switch"} }) # Note: missing_node is intentionally missing ] } - + # Create a mock field for the column - field = StructType.Field( + field = SpannerFieldInfo( name="column1", - type_=Type(code=TypeCode.JSON) + typename="JSON" ) - + # Convert data to nodes and edges nodes, edges = get_nodes_edges(data, [field]) - + # Verify we got the expected number of nodes and edges self.assertEqual(len(edges), 2, "Should have two edges") self.assertEqual(len(nodes), 3, "Should have three nodes (two real, one intermediate)") - + # Count intermediate nodes intermediate_nodes = [node for node in nodes if node.intermediate] self.assertEqual(len(intermediate_nodes), 1, "Should create only one intermediate node") @@ -197,32 +197,32 @@ def test_get_nodes_edges_with_complete_data(self) -> None: }), json.dumps({ "kind": "node", - "identifier": "node1", + "identifier": "node1", "labels": ["Device"], "properties": {"name": "Router"} }), json.dumps({ "kind": "node", - "identifier": "node2", + "identifier": "node2", "labels": ["Device"], "properties": {"name": "Switch"} }) ] } - + # Create a mock field for the column - field = StructType.Field( + field = SpannerFieldInfo( name="column1", - type_=Type(code=TypeCode.JSON) + typename="JSON" ) - + # Convert data to nodes and edges nodes, edges = get_nodes_edges(data, [field]) - + # Verify we got the expected number of nodes and edges self.assertEqual(len(edges), 1, "Should have one edge") self.assertEqual(len(nodes), 2, "Should have exactly two nodes (no intermediates)") - + # Verify no intermediate nodes exist intermediate_nodes = [node for node in nodes if node.intermediate] self.assertEqual(len(intermediate_nodes), 0, "Should not create any intermediate nodes") From 652a0733043ef145c39307681735fab0722df4a8 Mon Sep 17 00:00:00 2001 From: Sailesh Mukil Gangatharan Date: Wed, 11 Jun 2025 11:26:44 -0700 Subject: [PATCH 3/3] Introduce exec_env.py to maintain global state + minor bug fix 1. The global database_instances is moved to exec_env.py to avoid circular imports. 2. SpannerFiledInfo.typename populated with the correct name now 3. Remove all cloud spanner refs from database.py --- spanner_graphs/cloud_database.py | 8 +++++-- spanner_graphs/database.py | 27 --------------------- spanner_graphs/exec_env.py | 41 ++++++++++++++++++++++++++++++++ spanner_graphs/graph_server.py | 2 +- spanner_graphs/magics.py | 2 +- 5 files changed, 49 insertions(+), 31 deletions(-) create mode 100644 spanner_graphs/exec_env.py diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py index 3ce3d26..630ed1c 100644 --- a/spanner_graphs/cloud_database.py +++ b/spanner_graphs/cloud_database.py @@ -23,15 +23,19 @@ from google.cloud import spanner from google.cloud.spanner_v1 import JsonObject from google.api_core.client_options import ClientOptions -from google.cloud.spanner_v1.types import StructType, Type +from google.cloud.spanner_v1.types import StructType, Type, TypeCode import pydata_google_auth -from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo, get_as_field_info_list +from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo def _get_default_credentials_with_project(): return pydata_google_auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) +def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]: + """Converts a list of StructType.Field to a list of SpannerFieldInfo.""" + return [SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name) for field in fields] + class CloudSpannerDatabase(SpannerDatabase): """Concrete implementation for Spanner database on the cloud.""" def __init__(self, project_id: str, instance_id: str, diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index 2b2a285..ace5584 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -25,11 +25,6 @@ import csv from dataclasses import dataclass -from google.cloud import spanner -from google.cloud.spanner_v1 import JsonObject -from google.api_core.client_options import ClientOptions -from google.cloud.spanner_v1.types import StructType, TypeCode, Type -import pydata_google_auth class SpannerQueryResult(NamedTuple): # A dict where each key is a field name returned in the query and the list @@ -70,28 +65,6 @@ class SpannerFieldInfo: name: str typename: str -def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]: - """Converts a list of StructType.Field to a list of SpannerFieldInfo.""" - return [SpannerFieldInfo(name=field.name, typename=field.type_.code.name) for field in fields] - - -# Global dict of database instances created in a single session -database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = {} - -def get_database_instance(project: str, instance: str, database: str, mock = False) -> SpannerDatabase: - if mock: - return MockSpannerDatabase() - - key = f"{project}_{instance}_{database}" - db = database_instances.get(key) - - # Currently, we only create and return CloudSpannerDatabase instances. In the future, different - # implementations could be introduced. - if not db: - db = CloudSpannerDatabase(project, instance, database) - database_instances[key] = db - - return db class MockSpannerResult: diff --git a/spanner_graphs/exec_env.py b/spanner_graphs/exec_env.py new file mode 100644 index 0000000..4a60efe --- /dev/null +++ b/spanner_graphs/exec_env.py @@ -0,0 +1,41 @@ + +# Copyright 2024 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module maintains state for the execution environment of a session +""" +from typing import Dict, Union + +from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase +from spanner_graphs.cloud_database import CloudSpannerDatabase + +# Global dict of database instances created in a single session +database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {} + +def get_database_instance(project: str, instance: str, database: str, mock = False): + if mock: + return MockSpannerDatabase() + + key = f"{project}_{instance}_{database}" + db = database_instances.get(key) + + # Currently, we only create and return CloudSpannerDatabase instances. In the future, different + # implementations could be introduced. + if not db: + db = CloudSpannerDatabase(project, instance, database) + database_instances[key] = db + + return db + diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index 6c7e347..1ff7062 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -25,7 +25,7 @@ from google.cloud.spanner_v1 import TypeCode from spanner_graphs.conversion import get_nodes_edges -from spanner_graphs.database import get_database_instance +from spanner_graphs.exec_env import get_database_instance # Mapping of string types from frontend to Spanner TypeCode enum values diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index b21eced..d3f6760 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -33,7 +33,7 @@ from ipywidgets import interact from jinja2 import Template -from spanner_graphs.database import get_database_instance +from spanner_graphs.exec_env import get_database_instance from spanner_graphs.graph_server import ( GraphServer, execute_query, execute_node_expansion, validate_node_expansion_request