Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions spanner_graphs/cloud_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def _get_schema_for_graph(self, graph_query: str) -> Any | None:
def execute_query(
self,
query: str,
params: Dict[str, Any] = None,
param_types: Dict[str, Any] = None,
limit: int = None,
is_test_query: bool = False,
) -> SpannerQueryResult:
Expand All @@ -97,6 +99,8 @@ def execute_query(

Args:
query: The SQL query to execute against the database
params: A dictionary of query parameters
param_types: A dictionary of parameter types
limit: An optional limit for the number of rows to return
is_test_query: If true, skips schema fetching for graph queries.

Expand All @@ -108,13 +112,16 @@ def execute_query(
self.schema_json = self._get_schema_for_graph(query)

with self.database.snapshot() as snapshot:
params = None
param_types = None
if limit and limit > 0:
params = dict(limit=limit)

try:
results = snapshot.execute_sql(query, params=params, param_types=param_types)
results = snapshot.execute_sql(
query,
params=params,
param_types=param_types
)
rows = list(results)
except Exception as e:
return SpannerQueryResult(
Expand Down
6 changes: 6 additions & 0 deletions spanner_graphs/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from dataclasses import dataclass


class SpannerQueryResult(NamedTuple):
# A dict where each key is a field name returned in the query and the list
# contains all items of the same type found for the given field
Expand All @@ -39,6 +40,7 @@ class SpannerQueryResult(NamedTuple):
# The error message if any
err: Exception | None


class SpannerDatabase(ABC):
"""The spanner class holding the database connection"""

Expand All @@ -54,6 +56,7 @@ def _get_schema_for_graph(self, graph_query: str):
def execute_query(
self,
query: str,
params: Dict[str, Any] = None,
limit: int = None,
is_test_query: bool = False,
) -> SpannerQueryResult:
Expand Down Expand Up @@ -96,6 +99,7 @@ def _load_data(self):
def __iter__(self):
return iter(self._rows)


class MockSpannerDatabase():
"""Mock database class"""

Expand All @@ -110,6 +114,8 @@ def __init__(self):
def execute_query(
self,
_: str,
params: Dict[str, Any] = None,
param_types: Dict[str, Any] = None,
limit: int = 5
) -> SpannerQueryResult:
"""Mock execution of query"""
Expand Down
80 changes: 64 additions & 16 deletions spanner_graphs/graph_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import requests
import portpicker
import atexit
from datetime import datetime

from spanner_graphs.conversion import get_nodes_edges
from spanner_graphs.exec_env import get_database_instance
from spanner_graphs.database import SpannerQueryResult
from google.cloud import spanner

# Supported types for a property
PROPERTY_TYPE_SET = {
Expand Down Expand Up @@ -145,14 +147,18 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio

return validated_properties, direction


def execute_node_expansion(
params_str: str,
request: dict) -> dict:
request: dict
) -> dict:
"""Execute a node expansion query to find connected nodes and edges.

Args:
params_str: A JSON string containing connection parameters (project, instance, database, graph, mock).
request: A dictionary containing node expansion request details (uid, node_labels, node_properties, direction, edge_label).
params_str: A JSON string containing connection parameters (project,
instance, database, graph, mock).
request: A dictionary containing node expansion request details (uid,
node_labels, node_properties, direction, edge_label).

Returns:
dict: A dictionary containing the query response with nodes and edges.
Expand Down Expand Up @@ -182,20 +188,51 @@ def execute_node_expansion(
if node_labels and len(node_labels) > 0:
node_label_str = f": {' & '.join(node_labels)}"

node_property_strings: list[str] = []
for node_property in node_properties:
value_str: str
if node_property.type_str in ('INT64', 'NUMERIC', 'FLOAT32', 'FLOAT64', 'BOOL'):
value_str = node_property.value
else:
value_str = f"\'''{node_property.value}\'''"
node_property_strings.append(f"n.{node_property.key}={value_str}")
node_property_clauses: list[str] = []
params_dict: dict = {}
param_types_dict: dict = {}

for i, node_property in enumerate(node_properties):
param_name = f"param_{i}"
node_property_clauses.append(f"n.{node_property.key} = @{param_name}")

# Convert value to native Python type
type_str = node_property.type_str
value = node_property.value

if type_str in ("INT64", "NUMERIC"):
value_casting = int(value)
param_type = spanner.param_types.INT64
elif type_str in ("FLOAT32", "FLOAT64"):
value_casting = float(value)
param_type = spanner.param_types.FLOAT64
elif type_str == "BOOL":
value_casting = value.lower() == "true"
param_type = spanner.param_types.BOOL
elif type_str == "STRING":
value_casting = str(value)
param_type = spanner.param_types.STRING
elif type_str == "DATE":
value_casting = datetime.strptime(value, "%Y-%m-%d").date()
param_type = spanner.param_types.DATE
elif type_str == "TIMESTAMP":
value_casting = datetime.fromisoformat(value.replace("Z", "+00:00"))
param_type = spanner.param_types.TIMESTAMP

params_dict[param_name] = value_casting
param_types_dict[param_name] = param_type

filtered_uid = "STRING(TO_JSON(n).identifier) = @uid"
params_dict["uid"] = str(uid)
param_types_dict["uid"] = spanner.param_types.STRING

where_clauses = node_property_clauses + [filtered_uid]
where_clause_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

query = f"""
GRAPH {graph}
LET uid = "{uid}"
MATCH (n{node_label_str})
WHERE {' and '.join(node_property_strings)} {'and' if node_property_strings else ''} STRING(TO_JSON(n).identifier) = uid
{where_clause_str}
RETURN n

NEXT
Expand All @@ -204,14 +241,20 @@ def execute_node_expansion(
RETURN TO_JSON(e) as e, TO_JSON(d) as d
"""

return execute_query(project, instance, database, query, mock=False)
return execute_query(
project, instance, database, query, mock=False,
params=params_dict, param_types=param_types_dict
)


def execute_query(
project: str,
instance: str,
database: str,
query: str,
mock: bool = False,
params: Dict[str, Any] = None,
param_types: Dict[str, Any] = None,
) -> Dict[str, Any]:
"""Executes a query against a database and formats the result.

Expand All @@ -233,7 +276,11 @@ def execute_query(
"""
try:
db_instance = get_database_instance(project, instance, database, mock)
result: SpannerQueryResult = db_instance.execute_query(query)
result: SpannerQueryResult = db_instance.execute_query(
query,
params=params,
param_types=param_types
)

if len(result.rows) == 0 and result.err:
error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}"
Expand All @@ -257,7 +304,8 @@ def execute_query(
}

# Process a successful query result
nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json)
nodes, edges = get_nodes_edges(result.data, result.fields,
result.schema_json)

return {
"response": {
Expand Down
Loading