From 1988838bccad4c191efbdfa195c86add9ef4fa48 Mon Sep 17 00:00:00 2001 From: sagnghos Date: Tue, 3 Mar 2026 09:11:24 +0000 Subject: [PATCH] feat: add support for experimental host Spanner endpoints --- spanner_graphs/cloud_database.py | 83 +++++++++++++----- spanner_graphs/database.py | 86 ++++++++++++------- spanner_graphs/exec_env.py | 25 ++++-- spanner_graphs/graph_server.py | 140 +++++++++++++++++++------------ spanner_graphs/magics.py | 111 +++++++++++++++--------- 5 files changed, 294 insertions(+), 151 deletions(-) diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py index 00dd047..1bed575 100644 --- a/spanner_graphs/cloud_database.py +++ b/spanner_graphs/cloud_database.py @@ -27,23 +27,56 @@ import logging import pydata_google_auth -from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo +from spanner_graphs.database import ( + SpannerDatabase, + MockSpannerDatabase, + SpannerQueryResult, + SpannerFieldInfo, +) + def _get_default_credentials_with_project(): return pydata_google_auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], use_local_webserver=False) + scopes=["https://www.googleapis.com/auth/cloud-platform"], + use_local_webserver=False, + ) + def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]: - """Converts a list of StructType.Field to a list of SpannerFieldInfo.""" - return [SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name) for field in fields] + """Converts a list of StructType.Field to a list of SpannerFieldInfo.""" + return [ + SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name) + for field in fields + ] + class CloudSpannerDatabase(SpannerDatabase): """Concrete implementation for Spanner database on the cloud.""" - def __init__(self, project_id: str, instance_id: str, - database_id: str) -> None: - credentials, _ = _get_default_credentials_with_project() - self.client = spanner.Client( - project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id)) + + def __init__( + self, + project_id: str, + instance_id: str, + database_id: str, + experimental_host: str | None = None, + ca_certificate: str | None = None, + ) -> None: + from google.auth.credentials import AnonymousCredentials + + if experimental_host: + self.client = spanner.Client( + project=project_id, + credentials=AnonymousCredentials(), + experimental_host=experimental_host, + ca_certificate=ca_certificate, + ) + else: + credentials, _ = _get_default_credentials_with_project() + self.client = spanner.Client( + project=project_id, + credentials=credentials, + client_options=ClientOptions(quota_project_id=project_id), + ) self.instance = self.client.instance(instance_id) logger = logging.getLogger("spanner_graphs") logger.setLevel(logging.CRITICAL) @@ -51,15 +84,19 @@ def __init__(self, project_id: str, instance_id: str, self.schema_json: Any | None = None def __repr__(self) -> str: - return (f"") + return ( + f"" + ) def _extract_graph_name(self, query: str) -> str: words = query.strip().split() if len(words) < 3: - raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)") + raise ValueError( + "invalid query: must contain at least (GRAPH, graph_name and query)" + ) if words[0].upper() != "GRAPH": raise ValueError("invalid query: GRAPH must be the first word") @@ -81,7 +118,9 @@ def _get_schema_for_graph(self, graph_query: str) -> Any | None: params = {"graph_name": graph_name} param_type = {"graph_name": spanner.param_types.STRING} - result = snapshot.execute_sql(schema_query, params=params, param_types=param_type) + result = snapshot.execute_sql( + schema_query, params=params, param_types=param_type + ) schema_rows = list(result) if schema_rows: @@ -117,15 +156,13 @@ def execute_query( params = dict(limit=limit) try: - results = snapshot.execute_sql(query, params=params, param_types=param_types) + results = snapshot.execute_sql( + query, params=params, param_types=param_types + ) rows = list(results) except Exception as e: return SpannerQueryResult( - data={}, - fields=[], - rows=[], - schema_json=self.schema_json, - err=e + data={}, fields=[], rows=[], schema_json=self.schema_json, err=e ) fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields) @@ -137,7 +174,7 @@ def execute_query( fields=fields, rows=rows, schema_json=self.schema_json, - err=None + err=None, ) for row_data in rows: @@ -152,5 +189,5 @@ def execute_query( fields=fields, rows=rows, schema_json=self.schema_json, - err=None + err=None, ) diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index 63d94d4..1f643bf 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -27,11 +27,15 @@ from dataclasses import dataclass from enum import Enum, auto + class SpannerEnv(Enum): """Defines the types of Spanner environments the application can connect to.""" + CLOUD = auto() INFRA = auto() MOCK = auto() + EXPERIMENTAL_HOST = auto() + @dataclass class DatabaseSelector: @@ -47,32 +51,61 @@ class DatabaseSelector: instance: The Spanner instance. database: The Spanner database. infra_db_path: The path for an internal infrastructure database. + experimental_host: The Spanner experimental host endpoint. + ca_certificate: CA certificate path for the experimental host endpoint. + """ + env: SpannerEnv project: str | None = None instance: str | None = None database: str | None = None infra_db_path: str | None = None + experimental_host: str | None = None + ca_certificate: str | None = None + @classmethod - def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector': + def cloud(cls, project: str, instance: str, database: str) -> "DatabaseSelector": """Creates a selector for a Google Cloud Spanner database.""" if not project or not instance or not database: - raise ValueError("project, instance, and database are required for Cloud Spanner") - return cls(env=SpannerEnv.CLOUD, project=project, instance=instance, database=database) + raise ValueError( + "project, instance, and database are required for Cloud Spanner" + ) + return cls( + env=SpannerEnv.CLOUD, project=project, instance=instance, database=database + ) @classmethod - def infra(cls, infra_db_path: str) -> 'DatabaseSelector': + def infra(cls, infra_db_path: str) -> "DatabaseSelector": """Creates a selector for an internal infrastructure Spanner database.""" if not infra_db_path: raise ValueError("infra_db_path is required for Infra Spanner") return cls(env=SpannerEnv.INFRA, infra_db_path=infra_db_path) @classmethod - def mock(cls) -> 'DatabaseSelector': + def mock(cls) -> "DatabaseSelector": """Creates a selector for a mock Spanner database.""" return cls(env=SpannerEnv.MOCK) + @classmethod + def experimental_host( + cls, experimental_host: str, database: str, ca_certificate: str | None = None, + ) -> "DatabaseSelector": + """Creates a selector for a Google Experimental Host Spanner database.""" + if not database: + raise ValueError( + "database is required for Experimental Host Spanner Endpoint" + ) + return cls( + env=SpannerEnv.EXPERIMENTAL_HOST, + project="default", + instance="default", + database=database, + experimental_host=experimental_host, + ca_certificate=ca_certificate, + ) + def get_key(self) -> str: if self.env == SpannerEnv.CLOUD: return f"cloud_{self.project}_{self.instance}_{self.database}" @@ -80,9 +113,12 @@ def get_key(self) -> str: return f"infra_{self.infra_db_path}" elif self.env == SpannerEnv.MOCK: return "mock" + elif self.env == SpannerEnv.EXPERIMENTAL_HOST: + return f"experimental_host_{self.database}" else: raise ValueError("Unknown Spanner environment") + class SpannerQueryResult(NamedTuple): # A dict where each key is a field name returned in the query and the list # contains all items of the same type found for the given field @@ -96,6 +132,7 @@ class SpannerQueryResult(NamedTuple): # The error message if any err: Exception | None + class SpannerDatabase(ABC): """The spanner class holding the database connection""" @@ -116,6 +153,7 @@ def execute_query( ) -> SpannerQueryResult: pass + # Represents the name and type of a field in a Spanner query result. (Implementation-agnostic) @dataclass class SpannerFieldInfo: @@ -136,8 +174,7 @@ def _load_data(self): csv_reader = csv.reader(csvfile) headers = next(csv_reader) self.fields = [ - SpannerFieldInfo(name=header, typename="JSON") - for header in headers + SpannerFieldInfo(name=header, typename="JSON") for header in headers ] for row in csv_reader: @@ -153,22 +190,17 @@ def _load_data(self): def __iter__(self): return iter(self._rows) -class MockSpannerDatabase(): + +class MockSpannerDatabase: """Mock database class""" def __init__(self): dirname = os.path.dirname(__file__) - self.graph_csv_path = os.path.join( - dirname, "graph_mock_data.csv") - self.schema_json_path = os.path.join( - dirname, "graph_mock_schema.json") + self.graph_csv_path = os.path.join(dirname, "graph_mock_data.csv") + self.schema_json_path = os.path.join(dirname, "graph_mock_schema.json") self.schema_json: dict = {} - def execute_query( - self, - _: str, - limit: int = 5 - ) -> SpannerQueryResult: + def execute_query(self, _: str, limit: int = 5) -> SpannerQueryResult: """Mock execution of query""" # Fetch the schema @@ -182,12 +214,12 @@ def execute_query( if len(fields) == 0: return SpannerQueryResult( - data=data, - fields=fields, - rows=rows, - schema_json=self.schema_json, - err=None - ) + data=data, + fields=fields, + rows=rows, + schema_json=self.schema_json, + err=None, + ) for i, row in enumerate(results): if limit is not None and i >= limit: @@ -196,9 +228,5 @@ def execute_query( data[field.name].append(value) return SpannerQueryResult( - data=data, - fields=fields, - rows=rows, - schema_json=self.schema_json, - err=None - ) + data=data, fields=fields, rows=rows, schema_json=self.schema_json, err=None + ) diff --git a/spanner_graphs/exec_env.py b/spanner_graphs/exec_env.py index 93ed825..38108cd 100644 --- a/spanner_graphs/exec_env.py +++ b/spanner_graphs/exec_env.py @@ -1,4 +1,3 @@ - # Copyright 2024 Google LLC # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,6 +28,7 @@ # Global dict of database instances created in a single session database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {} + def get_database_instance( selector: DatabaseSelector, ) -> Union[SpannerDatabase, MockSpannerDatabase]: @@ -59,9 +59,7 @@ def get_database_instance( elif selector.env == SpannerEnv.CLOUD: try: - cloud_db_module = importlib.import_module( - "spanner_graphs.cloud_database" - ) + cloud_db_module = importlib.import_module("spanner_graphs.cloud_database") CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase") db = CloudSpannerDatabase( selector.project, selector.instance, selector.database @@ -72,15 +70,28 @@ def get_database_instance( ) elif selector.env == SpannerEnv.INFRA: try: - infra_db_module = importlib.import_module( - "spanner_graphs.infra_database" - ) + infra_db_module = importlib.import_module("spanner_graphs.infra_database") InfraSpannerDatabase = getattr(infra_db_module, "InfraSpannerDatabase") db = InfraSpannerDatabase(selector.infra_db_path) except ImportError: raise RuntimeError( "Infra Spanner support is not available in this environment." ) + elif selector.env == SpannerEnv.EXPERIMENTAL_HOST: + try: + cloud_db_module = importlib.import_module("spanner_graphs.cloud_database") + CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase") + db = CloudSpannerDatabase( + selector.project, + selector.instance, + selector.database, + selector.experimental_host, + selector.ca_certificate, + ) + except ImportError: + raise RuntimeError( + "Spanner experimental host support is not available in this environment." + ) else: raise ValueError(f"Unsupported Spanner environment: {selector.env}") diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index 6324207..9a8427f 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -29,18 +29,19 @@ # Supported types for a property PROPERTY_TYPE_SET = { - 'BOOL', - 'BYTES', - 'DATE', - 'ENUM', - 'INT64', - 'NUMERIC', - 'FLOAT32', - 'FLOAT64', - 'STRING', - 'TIMESTAMP' + "BOOL", + "BYTES", + "DATE", + "ENUM", + "INT64", + "NUMERIC", + "FLOAT32", + "FLOAT64", + "STRING", + "TIMESTAMP", } + class NodePropertyForDataExploration: def __init__(self, key: str, value: Union[str, int, float, bool], type_str: str): self.key = key @@ -58,16 +59,24 @@ def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector: Picks the correct DB selector based on the environment the server is running in. """ try: - env = SpannerEnv[selector_dict['env'].split('.')[-1]] + env = SpannerEnv[selector_dict["env"].split(".")[-1]] if env == SpannerEnv.CLOUD: - return DatabaseSelector.cloud(selector_dict['project'], selector_dict['instance'], selector_dict['database']) + return DatabaseSelector.cloud( + selector_dict["project"], + selector_dict["instance"], + selector_dict["database"], + ) elif env == SpannerEnv.INFRA: - return DatabaseSelector.infra(selector_dict['infra_db_path']) + return DatabaseSelector.infra(selector_dict["infra_db_path"]) elif env == SpannerEnv.MOCK: return DatabaseSelector.mock() + elif env == SpannerEnv.EXPERIMENTAL_HOST: + return DatabaseSelector.experimental_host( + selector_dict["experimental_host"], selector_dict["database"], selector_dict["ca_certificate"] + ) raise ValueError(f"Invalid env in selector dict: {selector_dict}") except Exception as e: - print (f"Unexpected error when fetching selector: {e}") + print(f"Unexpected error when fetching selector: {e}") def is_valid_property_type(property_type: str) -> bool: @@ -91,12 +100,17 @@ def is_valid_property_type(property_type: str) -> bool: # Check if the type is valid if property_type not in PROPERTY_TYPE_SET: - valid_types = ', '.join(sorted(PROPERTY_TYPE_SET)) - raise ValueError(f"Invalid property type: {property_type}. Allowed types are: {valid_types}") + valid_types = ", ".join(sorted(PROPERTY_TYPE_SET)) + raise ValueError( + f"Invalid property type: {property_type}. Allowed types are: {valid_types}" + ) return True -def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection): + +def validate_node_expansion_request( + data, +) -> (list[NodePropertyForDataExploration], EdgeDirection): required_fields = ["uid", "node_labels", "direction"] missing_fields = [field for field in required_fields if data.get(field) is None] @@ -124,49 +138,65 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio raise ValueError(f"Property at index {idx} must be an object") if not all(field in prop for field in ["key", "value", "type"]): - raise ValueError(f"Property at index {idx} is missing required fields (key, value, type)") + raise ValueError( + f"Property at index {idx} is missing required fields (key, value, type)" + ) try: prop_type_str = prop["type"] if isinstance(prop_type_str, str): # This must be True. If not, an execption would be thrown. - assert(is_valid_property_type(prop_type_str)) + assert is_valid_property_type(prop_type_str) value = prop["value"] - if prop_type_str in ('INT64', 'NUMERIC'): - if not (isinstance(value, int) or (isinstance(value, str) and value.isdigit())): - raise ValueError(f"Property '{prop['key']}' value must be a number for type {prop_type_str}") - elif prop_type_str in ('FLOAT32', 'FLOAT64'): + if prop_type_str in ("INT64", "NUMERIC"): + if not ( + isinstance(value, int) + or (isinstance(value, str) and value.isdigit()) + ): + raise ValueError( + f"Property '{prop['key']}' value must be a number for type {prop_type_str}" + ) + elif prop_type_str in ("FLOAT32", "FLOAT64"): try: float(value) except (ValueError, TypeError): raise ValueError( - f"Property '{prop['key']}' value must be a valid float for type {prop_type_str}") - elif prop_type_str == 'BOOL': - if not isinstance(value, bool) and not (isinstance(value, str) and value.lower() in ["true", "false"]): - raise ValueError(f"Property '{prop['key']}' value must be a boolean for type {prop_type_str}") - - validated_properties.append(NodePropertyForDataExploration( - key=prop["key"], - value=prop["value"], - type_str=prop_type_str - )) + f"Property '{prop['key']}' value must be a valid float for type {prop_type_str}" + ) + elif prop_type_str == "BOOL": + if not isinstance(value, bool) and not ( + isinstance(value, str) and value.lower() in ["true", "false"] + ): + raise ValueError( + f"Property '{prop['key']}' value must be a boolean for type {prop_type_str}" + ) + + validated_properties.append( + NodePropertyForDataExploration( + key=prop["key"], value=prop["value"], type_str=prop_type_str + ) + ) else: raise ValueError(f"Property type at index {idx} must be a string") except ValueError as e: - raise ValueError(f"Invalid property type in property at index {idx}: {str(e)}") + raise ValueError( + f"Invalid property type in property at index {idx}: {str(e)}" + ) try: direction = EdgeDirection(data.get("direction")) except ValueError: - raise ValueError(f"Invalid direction: must be INCOMING or OUTGOING, got \"{data.get('direction')}\"") + raise ValueError( + f"Invalid direction: must be INCOMING or OUTGOING, got \"{data.get('direction')}\"" + ) return validated_properties, direction + def execute_node_expansion( - selector_dict: Dict[str, Any], - graph: str, - request: dict) -> dict: + selector_dict: Dict[str, Any], graph: str, request: dict +) -> dict: """Execute a node expansion query to find connected nodes and edges. Args: @@ -199,10 +229,10 @@ def execute_node_expansion( node_property_strings: list[str] = [] for node_property in node_properties: value_str: str - if node_property.type_str in ('INT64', 'NUMERIC', 'FLOAT32', 'FLOAT64', 'BOOL'): + if node_property.type_str in ("INT64", "NUMERIC", "FLOAT32", "FLOAT64", "BOOL"): value_str = node_property.value else: - value_str = f"\'''{node_property.value}\'''" + value_str = f"'''{node_property.value}'''" node_property_strings.append(f"n.{node_property.key}={value_str}") query = f""" @@ -220,6 +250,7 @@ def execute_node_expansion( return execute_query(selector_dict, query) + def execute_query( selector_dict: Dict[str, Any], query: str, @@ -242,7 +273,9 @@ def execute_query( result: SpannerQueryResult = db_instance.execute_query(query) if len(result.rows) == 0 and result.err: - error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}" + error_message = ( + f"Query error: \n{getattr(result.err, 'message', str(result.err))}" + ) if result.schema_json: # Prepend a helpful message if the schema is available error_message = ( @@ -280,14 +313,14 @@ def execute_query( class GraphServer: port = portpicker.pick_unused_port() - host = 'http://localhost' + host = "http://localhost" url = f"{host}:{port}" endpoints = { "get_ping": "/get_ping", "post_ping": "/post_ping", "post_query": "/post_query", - "post_node_expansion": '/post_node_expansion', + "post_node_expansion": "/post_node_expansion", } _server = None @@ -349,6 +382,7 @@ def post_ping(data): print(f"Request failed with status code {response.status_code}") return False + class GraphServerHandler(http.server.SimpleHTTPRequestHandler): def log_message(self, format, *args): pass @@ -362,7 +396,7 @@ def do_json_response(self, data): self.wfile.write(json.dumps(data).encode()) def do_message_response(self, message): - self.do_json_response({'message': message}) + self.do_json_response({"message": message}) def do_data_response(self, data): self.do_json_response(data) @@ -370,7 +404,7 @@ def do_data_response(self, data): def do_error_response(self, message): if isinstance(message, Exception): message = str(message) - self.do_json_response({'error': message}) + self.do_json_response({"error": message}) def parse_post_data(self): content_length = int(self.headers["Content-Length"]) @@ -387,10 +421,7 @@ def handle_post_ping(self): def handle_post_query(self): data = self.parse_post_data() params = json.loads(data["params"]) - response = execute_query( - selector_dict=params["selector"], - query=data["query"] - ) + response = execute_query(selector_dict=params["selector"], query=data["query"]) self.do_data_response(response) def handle_post_node_expansion(self): @@ -401,11 +432,11 @@ def handle_post_node_expansion(self): graph = params.get("graph") request_data = data.get("request") - self.do_data_response(execute_node_expansion( - selector_dict=selector_dict, - graph=graph, - request=request_data - )) + self.do_data_response( + execute_node_expansion( + selector_dict=selector_dict, graph=graph, request=request_data + ) + ) except Exception as e: self.do_error_response(e) return @@ -424,4 +455,5 @@ def do_POST(self): elif self.path == GraphServer.endpoints["post_node_expansion"]: self.handle_post_node_expansion() + atexit.register(GraphServer.stop_server) diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index 1741e0d..5eb347d 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -37,36 +37,41 @@ from spanner_graphs.database import DatabaseSelector from spanner_graphs.exec_env import get_database_instance from spanner_graphs.graph_server import ( - GraphServer, execute_query, execute_node_expansion, - validate_node_expansion_request + GraphServer, + execute_query, + execute_node_expansion, + validate_node_expansion_request, ) from spanner_graphs.graph_visualization import generate_visualization_html singleton_server_thread: Thread = None + def _load_file(path: list[str]) -> str: - file_path = os.path.sep.join(path) - if not os.path.exists(file_path): - raise FileNotFoundError(f"Template file not found: {file_path}") + file_path = os.path.sep.join(path) + if not os.path.exists(file_path): + raise FileNotFoundError(f"Template file not found: {file_path}") + + with open(file_path, "r") as file: + content = file.read() - with open(file_path, 'r') as file: - content = file.read() + return content - return content def _load_image(path: list[str]) -> str: file_path = os.path.sep.join(path) if not os.path.exists(file_path): print("image does not exist") - return '' + return "" - if file_path.lower().endswith('.svg'): - with open(file_path, 'r') as file: + if file_path.lower().endswith(".svg"): + with open(file_path, "r") as file: svg = file.read() - return base64.b64encode(svg.encode('utf-8')).decode('utf-8') + return base64.b64encode(svg.encode("utf-8")).decode("utf-8") else: - with open(file_path, 'rb') as file: - return base64.b64decode(file.read()).decode('utf-8') + with open(file_path, "rb") as file: + return base64.b64decode(file.read()).decode("utf-8") + def _parse_element_display(element_rep: str) -> dict[str, str]: """Helper function to parse element display fields into a dict.""" @@ -78,14 +83,17 @@ def _parse_element_display(element_rep: str) -> dict[str, str]: } return res + def is_colab() -> bool: """Check if code is running in Google Colab""" try: import google.colab + return True except ImportError: return False + def receive_query_request(query: str, params: str): params_dict = json.loads(params) selector_dict = params_dict.get("selector") @@ -96,6 +104,7 @@ def receive_query_request(query: str, params: str): except Exception as e: return JSON({"error": str(e)}) + def receive_node_expansion_request(request: dict, params_str: str): """Handle node expansion requests in Google Colab environment @@ -122,10 +131,15 @@ def receive_node_expansion_request(request: dict, params_str: str): if not selector_dict: return JSON({"error": "Missing selector in params"}) - return JSON(execute_node_expansion(selector_dict=selector_dict, graph=graph, request=request)) + return JSON( + execute_node_expansion( + selector_dict=selector_dict, graph=graph, request=request + ) + ) except BaseException as e: return JSON({"error": str(e)}) + def custom_json_serializer(o): """A JSON serializer that handles dataclasses and enums.""" if is_dataclass(o): @@ -134,6 +148,7 @@ def custom_json_serializer(o): return f"{o.__class__.__name__}.{o.name}" raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") + @magics_class class NetworkVisualizationMagics(Magics): """Network visualizer with Networkx""" @@ -148,8 +163,11 @@ def __init__(self, shell): if is_colab(): from google.colab import output - output.register_callback('graph_visualization.Query', receive_query_request) - output.register_callback('graph_visualization.NodeExpansion', receive_node_expansion_request) + + output.register_callback("graph_visualization.Query", receive_query_request) + output.register_callback( + "graph_visualization.NodeExpansion", receive_node_expansion_request + ) else: global singleton_server_thread alive = singleton_server_thread and singleton_server_thread.is_alive() @@ -160,22 +178,20 @@ def visualize(self): """Helper function to create and display the visualization""" # Extract the graph name from the query (if present) graph = "" - if 'GRAPH ' in self.cell.upper(): - match = re.search(r'GRAPH\s+(\w+)', self.cell, re.IGNORECASE) + if "GRAPH " in self.cell.upper(): + match = re.search(r"GRAPH\s+(\w+)", self.cell, re.IGNORECASE) if match: graph = match.group(1) # Pack the selector and graph into the params to be sent to the GraphServer - params = { - "selector": self.selector, - "graph": graph - } + params = {"selector": self.selector, "graph": graph} # Generate the HTML content html_content = generate_visualization_html( query=self.cell, port=GraphServer.port, - params=json.dumps(params, default=custom_json_serializer)) + params=json.dumps(params, default=custom_json_serializer), + ) display(HTML(html_content)) @@ -184,19 +200,29 @@ def spanner_graph(self, line: str, cell: str): """spanner_graph function""" parser = argparse.ArgumentParser( - description="Visualize network from Spanner database", - exit_on_error=False) + description="Visualize network from Spanner database", exit_on_error=False + ) parser.add_argument("--project", help="GCP project ID") - parser.add_argument("--instance", - help="Spanner instance ID") - parser.add_argument("--database", - help="Spanner database ID") - parser.add_argument("--mock", - action="store_true", - help="Use mock database") - parser.add_argument("--infra_db_path", - action="store_true", - help="Connect to internal Infra Spanner") + parser.add_argument("--instance", help="Spanner instance ID") + parser.add_argument("--database", help="Spanner database ID") + parser.add_argument("--mock", action="store_true", help="Use mock database") + parser.add_argument( + "--infra_db_path", + action="store_true", + help="Connect to internal Infra Spanner", + ) + parser.add_argument( + "--experimental_host", + type=str, + required=False, + help="Spanner experimental host endpoint", + ) + parser.add_argument( + "--ca_certificate", + type=str, + required=False, + help="CA certificate path for the experimental host", + ) try: args = parser.parse_args(line.split()) @@ -205,12 +231,18 @@ def spanner_graph(self, line: str, cell: str): selector = DatabaseSelector.mock() elif args.infra_db_path: selector = DatabaseSelector.infra(infra_db_path=args.database) + elif args.experimental_host: + selector = DatabaseSelector.experimental_host( + experimental_host=args.experimental_host, database=args.database, ca_certificate=args.ca_certificate + ) else: if not (args.project and args.instance): raise ValueError( "Please provide `--project` and `--instance` for Cloud Spanner." ) - selector = DatabaseSelector.cloud(args.project, args.instance, args.database) + selector = DatabaseSelector.cloud( + args.project, args.instance, args.database + ) if not args.mock and (not cell or not cell.strip()): print("Error: Query is required.") @@ -224,10 +256,13 @@ def spanner_graph(self, line: str, cell: str): self.visualize() except BaseException as e: print(f"Error: {e}") - print(" %%spanner_graph --project --instance --database ") + print( + " %%spanner_graph --project --instance --database " + ) print(" %%spanner_graph --mock") print(" Graph query here...") + def load_ipython_extension(ipython): """Registration function""" ipython.register_magics(NetworkVisualizationMagics)