Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 60 additions & 23 deletions spanner_graphs/cloud_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,76 @@
import logging
import pydata_google_auth

from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo
from spanner_graphs.database import (
SpannerDatabase,
MockSpannerDatabase,
SpannerQueryResult,
SpannerFieldInfo,
)


def _get_default_credentials_with_project():
    """Obtain default Google credentials for the cloud-platform scope.

    Uses pydata_google_auth with ``use_local_webserver=False`` so that, when
    no cached credentials are available, authentication falls back to a
    console-based flow instead of spawning a local redirect server.

    Returns:
        A ``(credentials, project_id)`` tuple as returned by
        ``pydata_google_auth.default``.
    """
    return pydata_google_auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
        use_local_webserver=False,
    )


def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldInfo]:
    """Converts a list of StructType.Field to a list of SpannerFieldInfo.

    Args:
        fields: Field metadata objects from a Spanner result set.

    Returns:
        One SpannerFieldInfo per input field, carrying the field's name and
        the name of its Spanner ``TypeCode``.
    """
    return [
        SpannerFieldInfo(name=field.name, typename=TypeCode(field.type_.code).name)
        for field in fields
    ]


class CloudSpannerDatabase(SpannerDatabase):
"""Concrete implementation for Spanner database on the cloud."""
def __init__(self, project_id: str, instance_id: str,
database_id: str) -> None:
credentials, _ = _get_default_credentials_with_project()
self.client = spanner.Client(
project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id))

def __init__(
self,
project_id: str,
instance_id: str,
database_id: str,
experimental_host: str | None = None,
ca_certificate: str | None = None,
) -> None:
from google.auth.credentials import AnonymousCredentials

if experimental_host:
self.client = spanner.Client(
project=project_id,
credentials=AnonymousCredentials(),
experimental_host=experimental_host,
ca_certificate=ca_certificate,
)
else:
credentials, _ = _get_default_credentials_with_project()
self.client = spanner.Client(
project=project_id,
credentials=credentials,
client_options=ClientOptions(quota_project_id=project_id),
)
self.instance = self.client.instance(instance_id)
logger = logging.getLogger("spanner_graphs")
logger.setLevel(logging.CRITICAL)
self.database = self.instance.database(database_id, logger=logger)
self.schema_json: Any | None = None

def __repr__(self) -> str:
return (f"<CloudSpannerDatabase["
f"project:{self.client.project_name},"
f"instance:{self.instance.name},"
f"db:{self.database.name}]>")
return (
f"<CloudSpannerDatabase["
f"project:{self.client.project_name},"
f"instance:{self.instance.name},"
f"db:{self.database.name}]>"
)

def _extract_graph_name(self, query: str) -> str:
words = query.strip().split()
if len(words) < 3:
raise ValueError("invalid query: must contain at least (GRAPH, graph_name and query)")
raise ValueError(
"invalid query: must contain at least (GRAPH, graph_name and query)"
)

if words[0].upper() != "GRAPH":
raise ValueError("invalid query: GRAPH must be the first word")
Expand All @@ -81,7 +118,9 @@ def _get_schema_for_graph(self, graph_query: str) -> Any | None:
params = {"graph_name": graph_name}
param_type = {"graph_name": spanner.param_types.STRING}

result = snapshot.execute_sql(schema_query, params=params, param_types=param_type)
result = snapshot.execute_sql(
schema_query, params=params, param_types=param_type
)
schema_rows = list(result)

if schema_rows:
Expand Down Expand Up @@ -117,15 +156,13 @@ def execute_query(
params = dict(limit=limit)

try:
results = snapshot.execute_sql(query, params=params, param_types=param_types)
results = snapshot.execute_sql(
query, params=params, param_types=param_types
)
rows = list(results)
except Exception as e:
return SpannerQueryResult(
data={},
fields=[],
rows=[],
schema_json=self.schema_json,
err=e
data={}, fields=[], rows=[], schema_json=self.schema_json, err=e
)

fields: List[SpannerFieldInfo] = get_as_field_info_list(results.fields)
Expand All @@ -137,7 +174,7 @@ def execute_query(
fields=fields,
rows=rows,
schema_json=self.schema_json,
err=None
err=None,
)

for row_data in rows:
Expand All @@ -152,5 +189,5 @@ def execute_query(
fields=fields,
rows=rows,
schema_json=self.schema_json,
err=None
err=None,
)
86 changes: 57 additions & 29 deletions spanner_graphs/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
from dataclasses import dataclass
from enum import Enum, auto


class SpannerEnv(Enum):
    """Defines the types of Spanner environments the application can connect to."""

    CLOUD = auto()  # Google Cloud Spanner, selected by project/instance/database.
    INFRA = auto()  # Internal infrastructure Spanner, selected by a database path.
    MOCK = auto()  # Local mock database backed by bundled CSV/JSON fixtures.
    EXPERIMENTAL_HOST = auto()  # Spanner experimental host endpoint.


@dataclass
class DatabaseSelector:
Expand All @@ -47,42 +51,74 @@ class DatabaseSelector:
instance: The Spanner instance.
database: The Spanner database.
infra_db_path: The path for an internal infrastructure database.
experimental_host: The Spanner experimental host endpoint.
ca_certificate: CA certificate path for the experimental host endpoint.

"""

env: SpannerEnv
project: str | None = None
instance: str | None = None
database: str | None = None
infra_db_path: str | None = None
experimental_host: str | None = None
ca_certificate: str | None = None


@classmethod
def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector':
def cloud(cls, project: str, instance: str, database: str) -> "DatabaseSelector":
"""Creates a selector for a Google Cloud Spanner database."""
if not project or not instance or not database:
raise ValueError("project, instance, and database are required for Cloud Spanner")
return cls(env=SpannerEnv.CLOUD, project=project, instance=instance, database=database)
raise ValueError(
"project, instance, and database are required for Cloud Spanner"
)
return cls(
env=SpannerEnv.CLOUD, project=project, instance=instance, database=database
)

@classmethod
def infra(cls, infra_db_path: str) -> 'DatabaseSelector':
def infra(cls, infra_db_path: str) -> "DatabaseSelector":
"""Creates a selector for an internal infrastructure Spanner database."""
if not infra_db_path:
raise ValueError("infra_db_path is required for Infra Spanner")
return cls(env=SpannerEnv.INFRA, infra_db_path=infra_db_path)

@classmethod
def mock(cls) -> 'DatabaseSelector':
def mock(cls) -> "DatabaseSelector":
"""Creates a selector for a mock Spanner database."""
return cls(env=SpannerEnv.MOCK)

@classmethod
def experimental_host(
cls, experimental_host: str, database: str, ca_certificate: str | None = None,
) -> "DatabaseSelector":
"""Creates a selector for a Google Experimental Host Spanner database."""
if not database:
raise ValueError(
"database is required for Experimental Host Spanner Endpoint"
)
return cls(
env=SpannerEnv.EXPERIMENTAL_HOST,
project="default",
instance="default",
database=database,
experimental_host=experimental_host,
ca_certificate=ca_certificate,
)

def get_key(self) -> str:
if self.env == SpannerEnv.CLOUD:
return f"cloud_{self.project}_{self.instance}_{self.database}"
elif self.env == SpannerEnv.INFRA:
return f"infra_{self.infra_db_path}"
elif self.env == SpannerEnv.MOCK:
return "mock"
elif self.env == SpannerEnv.EXPERIMENTAL_HOST:
return f"experimental_host_{self.database}"
else:
raise ValueError("Unknown Spanner environment")


class SpannerQueryResult(NamedTuple):
# A dict where each key is a field name returned in the query and the list
# contains all items of the same type found for the given field
Expand All @@ -96,6 +132,7 @@ class SpannerQueryResult(NamedTuple):
# The error message if any
err: Exception | None


class SpannerDatabase(ABC):
"""The spanner class holding the database connection"""

Expand All @@ -116,6 +153,7 @@ def execute_query(
) -> SpannerQueryResult:
pass


# Represents the name and type of a field in a Spanner query result. (Implementation-agnostic)
@dataclass
class SpannerFieldInfo:
Expand All @@ -136,8 +174,7 @@ def _load_data(self):
csv_reader = csv.reader(csvfile)
headers = next(csv_reader)
self.fields = [
SpannerFieldInfo(name=header, typename="JSON")
for header in headers
SpannerFieldInfo(name=header, typename="JSON") for header in headers
]

for row in csv_reader:
Expand All @@ -153,22 +190,17 @@ def _load_data(self):
    def __iter__(self):
        # Yield the stored result rows; presumably populated by _load_data() —
        # confirm against the collapsed portion of this class.
        return iter(self._rows)

class MockSpannerDatabase:
    """Mock database class"""

    def __init__(self):
        """Locate the bundled mock data files next to this module."""
        dirname = os.path.dirname(__file__)
        # CSV with mock graph rows and JSON with the mock graph schema.
        self.graph_csv_path = os.path.join(dirname, "graph_mock_data.csv")
        self.schema_json_path = os.path.join(dirname, "graph_mock_schema.json")
        # Schema cache; starts empty and is filled when a query executes.
        self.schema_json: dict = {}

def execute_query(
self,
_: str,
limit: int = 5
) -> SpannerQueryResult:
def execute_query(self, _: str, limit: int = 5) -> SpannerQueryResult:
"""Mock execution of query"""

# Fetch the schema
Expand All @@ -182,12 +214,12 @@ def execute_query(

if len(fields) == 0:
return SpannerQueryResult(
data=data,
fields=fields,
rows=rows,
schema_json=self.schema_json,
err=None
)
data=data,
fields=fields,
rows=rows,
schema_json=self.schema_json,
err=None,
)

for i, row in enumerate(results):
if limit is not None and i >= limit:
Expand All @@ -196,9 +228,5 @@ def execute_query(
data[field.name].append(value)

return SpannerQueryResult(
data=data,
fields=fields,
rows=rows,
schema_json=self.schema_json,
err=None
)
data=data, fields=fields, rows=rows, schema_json=self.schema_json, err=None
)
25 changes: 18 additions & 7 deletions spanner_graphs/exec_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -29,6 +28,7 @@
# Global dict of database instances created in a single session
database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {}


def get_database_instance(
selector: DatabaseSelector,
) -> Union[SpannerDatabase, MockSpannerDatabase]:
Expand Down Expand Up @@ -59,9 +59,7 @@ def get_database_instance(

elif selector.env == SpannerEnv.CLOUD:
try:
cloud_db_module = importlib.import_module(
"spanner_graphs.cloud_database"
)
cloud_db_module = importlib.import_module("spanner_graphs.cloud_database")
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
db = CloudSpannerDatabase(
selector.project, selector.instance, selector.database
Expand All @@ -72,15 +70,28 @@ def get_database_instance(
)
elif selector.env == SpannerEnv.INFRA:
try:
infra_db_module = importlib.import_module(
"spanner_graphs.infra_database"
)
infra_db_module = importlib.import_module("spanner_graphs.infra_database")
InfraSpannerDatabase = getattr(infra_db_module, "InfraSpannerDatabase")
db = InfraSpannerDatabase(selector.infra_db_path)
except ImportError:
raise RuntimeError(
"Infra Spanner support is not available in this environment."
)
elif selector.env == SpannerEnv.EXPERIMENTAL_HOST:
try:
cloud_db_module = importlib.import_module("spanner_graphs.cloud_database")
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
db = CloudSpannerDatabase(
selector.project,
selector.instance,
selector.database,
selector.experimental_host,
selector.ca_certificate,
)
except ImportError:
raise RuntimeError(
"Spanner experimental host support is not available in this environment."
)
else:
raise ValueError(f"Unsupported Spanner environment: {selector.env}")

Expand Down
Loading