diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py new file mode 100644 index 0000000..630ed1c --- /dev/null +++ b/spanner_graphs/cloud_database.py @@ -0,0 +1,147 @@ +# Copyright 2024 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains the cloud-specific implementation for talking to a Spanner database. +""" + +from __future__ import annotations +import json +from typing import Any, Dict, List, Tuple + +from google.cloud import spanner +from google.cloud.spanner_v1 import JsonObject +from google.api_core.client_options import ClientOptions +from google.cloud.spanner_v1.types import StructType, Type, TypeCode +import pydata_google_auth + +from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo + +def _get_default_credentials_with_project(): + return pydata_google_auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) + +def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]: + """Converts a list of StructType.Field to a list of SpannerFieldInfo.""" + return [SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name) for field in fields] + +class CloudSpannerDatabase(SpannerDatabase): + """Concrete implementation for Spanner database on the cloud.""" + def __init__(self, project_id: str, instance_id: str, + database_id: str) -> None: + credentials, _ = 
_get_default_credentials_with_project() + self.client = spanner.Client( + project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id)) + self.instance = self.client.instance(instance_id) + self.database = self.instance.database(database_id) + self.schema_json: Any | None = None + + def __repr__(self) -> str: + return (f"") + + def _extract_graph_name(self, query: str) -> str: + words = query.strip().split() + if len(words) < 3: + raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)") + + if words[0].upper() != "GRAPH": + raise ValueError("invalid query: GRAPH must be the first word") + + return words[1] + + def _get_schema_for_graph(self, graph_query: str) -> Any | None: + try: + graph_name = self._extract_graph_name(graph_query) + except ValueError: + return None + + with self.database.snapshot() as snapshot: + schema_query = """ + SELECT property_graph_name, property_graph_metadata_json + FROM information_schema.property_graphs + WHERE property_graph_name = @graph_name + """ + params = {"graph_name": graph_name} + param_type = {"graph_name": spanner.param_types.STRING} + + result = snapshot.execute_sql(schema_query, params=params, param_types=param_type) + schema_rows = list(result) + + if schema_rows: + return schema_rows[0][1] + else: + return None + + def execute_query( + self, + query: str, + limit: int = None, + is_test_query: bool = False, + ) -> SpannerQueryResult: + """ + This method executes the provided `query` + + Args: + query: The SQL query to execute against the database + limit: An optional limit for the number of rows to return + is_test_query: If true, skips schema fetching for graph queries. 
+ + Returns: + A `SpannerQueryResult` + """ + self.schema_json = None + if not is_test_query: + self.schema_json = self._get_schema_for_graph(query) + + with self.database.snapshot() as snapshot: + params = None + param_types = None + if limit and limit > 0: + params = dict(limit=limit) + + try: + results = snapshot.execute_sql(query, params=params, param_types=param_types) + rows = list(results) + except Exception as e: + return SpannerQueryResult(data={}, fields=[], rows=[], schema_json=self.schema_json, error=e) + + fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields) + data = {field.name: [] for field in fields} + + if len(fields) == 0: + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) + + for row_data in rows: + for field, value in zip(fields, row_data): + if isinstance(value, JsonObject): + data[field.name].append(json.loads(value.serialize())) + else: + data[field.name].append(value) + + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) diff --git a/spanner_graphs/conversion.py b/spanner_graphs/conversion.py index 21eef39..52d385d 100644 --- a/spanner_graphs/conversion.py +++ b/spanner_graphs/conversion.py @@ -23,10 +23,11 @@ from google.cloud.spanner_v1.types import TypeCode, StructType +from spanner_graphs.database import SpannerFieldInfo from spanner_graphs.graph_entities import Node, Edge from spanner_graphs.schema_manager import SchemaManager -def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]: +def get_nodes_edges(data: Dict[str, List[Any]], fields: List[SpannerFieldInfo], schema_json: dict = None) -> Tuple[List[Node], List[Edge]]: schema_manager = SchemaManager(schema_json) nodes: List[Node] = [] edges: List[Edge] = [] @@ -37,15 +38,15 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], for field in fields: column_name = field.name 
column_data = data[column_name] - + # Only process JSON and Array of JSON types - if field.type_.code not in [TypeCode.JSON, TypeCode.ARRAY]: + if field.typename not in ["JSON", "ARRAY"]: continue # Process each value in the column for value in column_data: items_to_process = [] - + # Handle both single JSON and arrays of JSON if isinstance(value, list): items_to_process.extend(value) @@ -92,4 +93,4 @@ def get_nodes_edges(data: Dict[str, List[Any]], fields: List[StructType.Field], nodes.append(Node.make_intermediate(identifier)) node_identifiers.add(identifier) - return nodes, edges \ No newline at end of file + return nodes, edges diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index b879d73..ace5584 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -18,127 +18,59 @@ """ from __future__ import annotations -from typing import Any, Dict, List, Tuple +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple, NamedTuple import json import os import csv -from google.cloud import spanner -from google.cloud.spanner_v1 import JsonObject -from google.api_core.client_options import ClientOptions -from google.cloud.spanner_v1.types import StructType, TypeCode, Type -import pydata_google_auth - -def _get_default_credentials_with_project(): - return pydata_google_auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) - -class SpannerDatabase: +from dataclasses import dataclass + +class SpannerQueryResult(NamedTuple): + # A dict where each key is a field name returned in the query and the list + # contains all items of the same type found for the given field + data: Dict[str, List[Any]] + # A list representing the fields in the result set. + fields: List[SpannerFieldInfo] + # A list of rows as returned by the query execution. 
+ rows: List[Any] + # An optional field to return the schema as JSON + schema_json: Any | None + # The error message if any + error: Exception | None + +class SpannerDatabase(ABC): """The spanner class holding the database connection""" - def __init__(self, project_id: str, instance_id: str, - database_id: str) -> None: - credentials, _ = _get_default_credentials_with_project() - self.client = spanner.Client( - project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id)) - self.instance = self.client.instance(instance_id) - self.database = self.instance.database(database_id) - - def __repr__(self) -> str: - return (f" str: - words = query.strip().split() - if len(words) < 3: - raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)") - - if words[0].upper() != "GRAPH": - raise ValueError("invalid query: GRAPH must be the first word") - - return words[1] + pass + @abstractmethod def _get_schema_for_graph(self, graph_query: str): - try: - graph_name = self._extract_graph_name(graph_query) - except ValueError as e: - return None - - with self.database.snapshot() as snapshot: - schema_query = """ - SELECT property_graph_name, property_graph_metadata_json - FROM information_schema.property_graphs - WHERE property_graph_name = @graph_name - """ - params = {"graph_name": graph_name} - param_type = {"graph_name": spanner.param_types.STRING} - - result = snapshot.execute_sql(schema_query, params=params, param_types=param_type) - schema_rows = list(result) - - if schema_rows: - return schema_rows[0][1] - else: - return None + pass + @abstractmethod def execute_query( self, query: str, limit: int = None, is_test_query: bool = False, - ): - """ - This method executes the provided `query` - - Args: - query: The SQL query to execute against the database - limit: An optional limit for the number of rows to return - - Returns: - A tuple containing: - - Dict[str, List[Any]]: A dict where each key is a field 
name - returned in the query and the list contains all items of the same - type found for the given field. - - A list of StructType.Fields representing the fields in the result set. - - A list of rows as returned by the query execution. - - The error message if any. - """ - self.schema_json = None - if not is_test_query: - self.schema_json = self._get_schema_for_graph(query) - - with self.database.snapshot() as snapshot: - params = None - if limit and limit > 0: - params = dict(limit=limit) - try : - results = snapshot.execute_sql(query, params=params) - rows = list(results) - except Exception as e: - return {},[],[], self.schema_json, e - fields: List[StructType.Field] = results.fields - - data = {field.name: [] for field in fields} - - if len(fields) == 0: - return data, fields, rows - - for row in rows: - for field, value in zip(fields, row): - if isinstance(value, JsonObject): - # Handle JSON objects by properly deserializing them back into Python objects - data[field.name].append(json.loads(value.serialize())) - else: - data[field.name].append(value) - - return data, fields, rows, self.schema_json, None - + ) -> SpannerQueryResult: + pass + +# Represents the name and type of a field in a Spanner query result. 
(Implementation-agnostic) +@dataclass +class SpannerFieldInfo: + name: str + typename: str + + class MockSpannerResult: def __init__(self, file_path: str): self.file_path = file_path - self.fields: List[StructType] = [] + self.fields: List[SpannerFieldInfo] = [] self._rows: List[List[Any]] = [] self._load_data() @@ -147,7 +79,7 @@ def _load_data(self): csv_reader = csv.reader(csvfile) headers = next(csv_reader) self.fields = [ - StructType.Field(name=header, type_=Type(code=TypeCode.JSON)) + SpannerFieldInfo(name=header, typename="JSON") for header in headers ] @@ -164,8 +96,7 @@ def _load_data(self): def __iter__(self): return iter(self._rows) - -class MockSpannerDatabase: +class MockSpannerDatabase(): """Mock database class""" def __init__(self): @@ -180,20 +111,26 @@ def execute_query( self, _: str, limit: int = 5 - ) -> Tuple[Dict[str, List[Any]], List[StructType.Field], List, str]: + ) -> SpannerQueryResult: """Mock execution of query""" - # Before the actual query we fetch the schema as well + # Fetch the schema with open(self.schema_json_path, "r", encoding="utf-8") as js: self.schema_json = json.load(js) results = MockSpannerResult(self.graph_csv_path) - fields: List[StructType.Field] = results.fields + fields: List[SpannerFieldInfo] = results.fields rows = list(results) data = {field.name: [] for field in fields} if len(fields) == 0: - return data, fields, rows + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) for i, row in enumerate(results): if limit is not None and i >= limit: @@ -201,24 +138,10 @@ def execute_query( for field, value in zip(fields, row): data[field.name].append(value) - return data, fields, rows, self.schema_json, None - - -database_instances: dict[str, SpannerDatabase | MockSpannerDatabase] = { - # "project_instance_database": SpannerDatabase -} - - -def get_database_instance(project: str, instance: str, database: str, mock = False): - if mock: - return 
MockSpannerDatabase() - - key = f"{project}_{instance}_{database}" - - db = database_instances.get(key, None) - if not db: - # Now create and insert it. - db = SpannerDatabase(project, instance, database) - database_instances[key] = db - - return db \ No newline at end of file + return SpannerQueryResult( + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + error=None + ) diff --git a/spanner_graphs/exec_env.py b/spanner_graphs/exec_env.py new file mode 100644 index 0000000..4a60efe --- /dev/null +++ b/spanner_graphs/exec_env.py @@ -0,0 +1,41 @@ + +# Copyright 2024 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module maintains state for the execution environment of a session +""" +from typing import Dict, Union + +from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase +from spanner_graphs.cloud_database import CloudSpannerDatabase + +# Global dict of database instances created in a single session +database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {} + +def get_database_instance(project: str, instance: str, database: str, mock = False): + if mock: + return MockSpannerDatabase() + + key = f"{project}_{instance}_{database}" + db = database_instances.get(key) + + # Currently, we only create and return CloudSpannerDatabase instances. In the future, different + # implementations could be introduced. 
+ if not db: + db = CloudSpannerDatabase(project, instance, database) + database_instances[key] = db + + return db + diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index 7860c49..1ff7062 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -17,7 +17,7 @@ import json import threading from enum import Enum -from typing import Union +from typing import Union, Dict, Any import requests import portpicker @@ -25,7 +25,7 @@ from google.cloud.spanner_v1 import TypeCode from spanner_graphs.conversion import get_nodes_edges -from spanner_graphs.database import get_database_instance +from spanner_graphs.exec_env import get_database_instance # Mapping of string types from frontend to Spanner TypeCode enum values @@ -56,27 +56,27 @@ class EdgeDirection(Enum): def validate_property_type(property_type: str) -> TypeCode: """ Validates and converts a property type string to a Spanner TypeCode. - + Args: property_type: The property type string from the request - + Returns: The corresponding TypeCode enum value - + Raises: ValueError: If the property type is invalid """ if not property_type: raise ValueError("Property type must be provided") - + # Convert to uppercase for case-insensitive comparison property_type = property_type.upper() - + # Check if the type is valid if property_type not in PROPERTY_TYPE_MAP: valid_types = ', '.join(sorted(PROPERTY_TYPE_MAP.keys())) raise ValueError(f"Invalid property type: {property_type}. Allowed types are: {valid_types}") - + return PROPERTY_TYPE_MAP[property_type] def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection): @@ -149,11 +149,11 @@ def execute_node_expansion( params_str: str, request: dict) -> dict: """Execute a node expansion query to find connected nodes and edges. - + Args: params_str: A JSON string containing connection parameters (project, instance, database, graph, mock). 
request: A dictionary containing node expansion request details (uid, node_labels, node_properties, direction, edge_label). - + Returns: dict: A dictionary containing the query response with nodes and edges. """ @@ -206,49 +206,70 @@ def execute_node_expansion( return execute_query(project, instance, database, query, mock=False) -def execute_query(project: str, instance: str, database: str, query: str, mock = False): - database = get_database_instance(project, instance, database, mock) +def execute_query( + project: str, + instance: str, + database: str, + query: str, + mock: bool = False, +) -> Dict[str, Any]: + """Executes a query against a database and formats the result. + Connects to a database, runs the query, and processes the resulting object. + On success, it formats the data into nodes and edges for graph visualization. + If the query fails, it returns a detailed error message, optionally + including the database schema to aid in debugging. + + Args: + project: The cloud project ID. + instance: The database instance name. + database: The database name. + query: The query string to execute. + mock: If True, use a mock database instance for testing. Defaults to False. + + Returns: + A dictionary containing either the structured 'response' with nodes, + edges, and other data, or an 'error' key with a descriptive message. + """ try: - query_result, fields, rows, schema_json, err = database.execute_query(query) - if len(rows) == 0 and err : # if query returned an error - if schema_json: # if the schema exists - return { - "response": { - "schema": schema_json, - "query_result": query_result, - "nodes": [], - "edges": [], - "rows": [] - }, - "error": f"We've detected an error in your query. To help you troubleshoot, the graph schema's layout is shown above." 
+ "\n\n" + f"Query error: \n{getattr(err, 'message', str(err))}" - } - if not schema_json: # if the schema does not exist - return { - "response": { - "schema": schema_json, - "query_result": query_result, - "nodes": [], - "edges": [], - "rows": [] - }, - "error": f"Query error: \n{getattr(err, 'message', str(err))}" - } - nodes, edges = get_nodes_edges(query_result, fields, schema_json) - + db_instance = get_database_instance(project, instance, database, mock) + result = db_instance.execute_query(query) + + if len(result.rows) == 0 and result.error: + error_message = f"Query error: \n{getattr(result.error, 'message', str(result.error))}" + if result.schema_json: + # Prepend a helpful message if the schema is available + error_message = ( + "We've detected an error in your query. To help you troubleshoot, " + "the graph schema's layout is shown above.\n\n" + error_message + ) + + # Consolidate the repetitive error response into a single return + return { + "response": { + "schema": result.schema_json, + "query_result": result.data, + "nodes": [], + "edges": [], + "rows": [], + }, + "error": error_message, + } + + # Process a successful query result + nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json) + return { "response": { "nodes": [node.to_json() for node in nodes], "edges": [edge.to_json() for edge in edges], - "schema": schema_json, - "rows": rows, - "query_result": query_result + "schema": result.schema_json, + "rows": result.rows, + "query_result": result.data, } } except Exception as e: - return { - "error": getattr(e, "message", str(e)) - } + return {"error": getattr(e, "message", str(e))} class GraphServer: @@ -360,7 +381,7 @@ def handle_post_ping(self): def handle_post_query(self): data = self.parse_post_data() params = json.loads(data["params"]) - response = execute_query( + response = execute_query( project=params["project"], instance=params["instance"], database=params["database"], @@ -371,7 +392,7 @@ def 
handle_post_query(self): def handle_post_node_expansion(self): """Handle POST requests for node expansion. - + Expects a JSON payload with: - params: A JSON string containing connection parameters (project, instance, database, graph) - request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label) diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index 93c3d50..d3f6760 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -33,7 +33,7 @@ from ipywidgets import interact from jinja2 import Template -from spanner_graphs.database import get_database_instance +from spanner_graphs.exec_env import get_database_instance from spanner_graphs.graph_server import ( GraphServer, execute_query, execute_node_expansion, validate_node_expansion_request @@ -94,7 +94,7 @@ def receive_query_request(query: str, params: str): def receive_node_expansion_request(request: dict, params_str: str): """Handle node expansion requests in Google Colab environment - + Args: request: A dictionary containing node expansion details including: - uid: str - Unique identifier of the node to expand @@ -108,7 +108,7 @@ def receive_node_expansion_request(request: dict, params_str: str): - database: str - Spanner database ID - graph: str - Graph name - mock: bool - Whether to use mock data - + Returns: JSON: A JSON-serialized response containing either: - The query results with nodes and edges diff --git a/tests/database_test.py b/tests/cloud_database_test.py similarity index 81% rename from tests/database_test.py rename to tests/cloud_database_test.py index ac39a76..dcb6c92 100644 --- a/tests/database_test.py +++ b/tests/cloud_database_test.py @@ -22,13 +22,13 @@ from google.cloud.spanner_v1.types import Type, TypeCode, StructType -from spanner_graphs.database import SpannerDatabase - +from spanner_graphs.cloud_database import CloudSpannerDatabase +from spanner_graphs.database import SpannerFieldInfo class TestDatabase(unittest.TestCase): 
- """Test cases for the SpannerDatabase class""" + """Test cases for the CloudSpannerDatabase class""" - @patch("spanner_graphs.database.spanner.Client") + @patch("spanner_graphs.cloud_database.spanner.Client") def test_execute_query(self, mock_client: MagicMock) -> None: """Test that a query is executed correctly""" mock_instance = MagicMock() @@ -51,11 +51,11 @@ def test_execute_query(self, mock_client: MagicMock) -> None: StructType.Field(name="field3", type_=Type(code=TypeCode.JSON)), ] - db = SpannerDatabase("test_project", "test_instance", "test_database") + db = CloudSpannerDatabase("test_project", "test_instance", "test_database") result = db.execute_query("SELECT * FROM test", is_test_query=True) - self.assertEqual(result[0]["field1"], ['{"key": "value1"}']) - self.assertEqual(result[1][0].name, "field1") + self.assertEqual(result.data["field1"], ['{"key": "value1"}']) + self.assertEqual(result.fields[0].name, "field1") if __name__ == "__main__": diff --git a/tests/conversion_test.py b/tests/conversion_test.py index e53c56c..8ec9e7f 100644 --- a/tests/conversion_test.py +++ b/tests/conversion_test.py @@ -23,7 +23,7 @@ from google.cloud.spanner_v1.types import StructType, Type, TypeCode from spanner_graphs.conversion import get_nodes_edges -from spanner_graphs.database import MockSpannerDatabase +from spanner_graphs.database import SpannerFieldInfo, MockSpannerDatabase class TestConversion(unittest.TestCase): @@ -38,10 +38,10 @@ def test_get_nodes_edges(self) -> None: """ # Get data from mock database mock_db = MockSpannerDatabase() - data, fields, _, schema_json = mock_db.execute_query("") + query_result = mock_db.execute_query("") # Convert data to nodes and edges - nodes, edges = get_nodes_edges(data, fields) + nodes, edges = get_nodes_edges(query_result.data, query_result.fields) # Verify we got some nodes and edges self.assertTrue(len(nodes) > 0, "Should have at least one node") @@ -72,7 +72,7 @@ def test_get_nodes_edges(self) -> None: 
self.assertTrue(hasattr(edge, 'destination'), "Edge should have a destination") self.assertIsInstance(edge.labels, list, "Edge labels should be a list") self.assertIsInstance(edge.properties, dict, "Edge properties should be a dict") - + # Verify edge endpoints exist in nodes source_exists = any(node.identifier == edge.source for node in nodes) dest_exists = any(node.identifier == edge.destination for node in nodes) @@ -94,32 +94,32 @@ def test_get_nodes_edges_with_missing_nodes(self) -> None: }), json.dumps({ "kind": "node", - "identifier": "node1", + "identifier": "node1", "labels": ["Device"], "properties": {"name": "Router"} }) # Note: node2 is intentionally missing ] } - + # Create a mock field for the column - field = StructType.Field( + field = SpannerFieldInfo( name="column1", - type_=Type(code=TypeCode.JSON) + typename="JSON" ) - + # Convert data to nodes and edges nodes, edges = get_nodes_edges(data, [field]) - + # Verify we got the expected number of nodes and edges self.assertEqual(len(edges), 1, "Should have one edge") self.assertEqual(len(nodes), 2, "Should have two nodes (one real, one intermediate)") - + # Verify node identifiers node_ids = {node.identifier for node in nodes} self.assertIn("node1", node_ids, "Original node should exist") self.assertIn("node2", node_ids, "Missing node should be created as intermediate") - + # Find the intermediate node intermediate_node = next((node for node in nodes if node.identifier == "node2"), None) self.assertIsNotNone(intermediate_node, "Intermediate node should exist") @@ -150,33 +150,33 @@ def test_get_nodes_edges_with_multiple_references(self) -> None: }), json.dumps({ "kind": "node", - "identifier": "node1", + "identifier": "node1", "labels": ["Device"], "properties": {"name": "Router"} }), json.dumps({ "kind": "node", - "identifier": "node2", + "identifier": "node2", "labels": ["Device"], "properties": {"name": "Switch"} }) # Note: missing_node is intentionally missing ] } - + # Create a mock field for 
the column - field = StructType.Field( + field = SpannerFieldInfo( name="column1", - type_=Type(code=TypeCode.JSON) + typename="JSON" ) - + # Convert data to nodes and edges nodes, edges = get_nodes_edges(data, [field]) - + # Verify we got the expected number of nodes and edges self.assertEqual(len(edges), 2, "Should have two edges") self.assertEqual(len(nodes), 3, "Should have three nodes (two real, one intermediate)") - + # Count intermediate nodes intermediate_nodes = [node for node in nodes if node.intermediate] self.assertEqual(len(intermediate_nodes), 1, "Should create only one intermediate node") @@ -197,32 +197,32 @@ def test_get_nodes_edges_with_complete_data(self) -> None: }), json.dumps({ "kind": "node", - "identifier": "node1", + "identifier": "node1", "labels": ["Device"], "properties": {"name": "Router"} }), json.dumps({ "kind": "node", - "identifier": "node2", + "identifier": "node2", "labels": ["Device"], "properties": {"name": "Switch"} }) ] } - + # Create a mock field for the column - field = StructType.Field( + field = SpannerFieldInfo( name="column1", - type_=Type(code=TypeCode.JSON) + typename="JSON" ) - + # Convert data to nodes and edges nodes, edges = get_nodes_edges(data, [field]) - + # Verify we got the expected number of nodes and edges self.assertEqual(len(edges), 1, "Should have one edge") self.assertEqual(len(nodes), 2, "Should have exactly two nodes (no intermediates)") - + # Verify no intermediate nodes exist intermediate_nodes = [node for node in nodes if node.intermediate] self.assertEqual(len(intermediate_nodes), 0, "Should not create any intermediate nodes")