Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions spanner_graphs/cloud_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def execute_query(
query: str,
limit: int = None,
is_test_query: bool = False,
params: dict = None,
param_types: dict = None,
) -> SpannerQueryResult:
"""
This method executes the provided `query`
Expand All @@ -99,6 +101,8 @@ def execute_query(
query: The SQL query to execute against the database
limit: An optional limit for the number of rows to return
is_test_query: If true, skips schema fetching for graph queries.
params: A dictionary of query parameters.
param_types: A dictionary of query parameter types.

Returns:
A `SpannerQueryResult`
Expand All @@ -108,10 +112,10 @@ def execute_query(
self.schema_json = self._get_schema_for_graph(query)

with self.database.snapshot() as snapshot:
params = None
param_types = None
if limit and limit > 0:
params = dict(limit=limit)
if params is None:
params = {}
params["limit"] = limit

try:
results = snapshot.execute_sql(query, params=params, param_types=param_types)
Expand Down
29 changes: 17 additions & 12 deletions spanner_graphs/graph_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import portpicker
import atexit

from google.cloud import spanner
from spanner_graphs.conversion import get_nodes_edges
from spanner_graphs.exec_env import get_database_instance
from spanner_graphs.database import SpannerQueryResult
Expand All @@ -32,7 +33,6 @@
'BOOL',
'BYTES',
'DATE',
'ENUM',
'INT64',
'NUMERIC',
'FLOAT32',
Expand Down Expand Up @@ -171,7 +171,6 @@ def execute_node_expansion(

edge = "e" if not edge_label else f"e:{edge_label}"

# Build the path pattern based on direction
path_pattern = (
f"(n)-[{edge}]->(d)"
if direction == EdgeDirection.OUTGOING
Expand All @@ -182,29 +181,35 @@ def execute_node_expansion(
if node_labels and len(node_labels) > 0:
node_label_str = f": {' & '.join(node_labels)}"

query_params = {"uid": uid}
param_types = {"uid": spanner.param_types.STRING}
node_property_strings: list[str] = []
for node_property in node_properties:
value_str: str
if node_property.type_str in ('INT64', 'NUMERIC', 'FLOAT32', 'FLOAT64', 'BOOL'):
value_str = node_property.value
else:
value_str = f"\'''{node_property.value}\'''"
node_property_strings.append(f"n.{node_property.key}={value_str}")

for i, node_property in enumerate(node_properties):
param_name = f"p{i}"
node_property_strings.append(f"n.{node_property.key} = @{param_name}")
query_params[param_name] = node_property.value
param_types[param_name] = getattr(spanner.param_types, node_property.type_str)

where_clause = " and ".join(node_property_strings)
if where_clause:
where_clause += " and "
where_clause += "STRING(TO_JSON(n).identifier) = @uid"

query = f"""
GRAPH {graph}
LET uid = "{uid}"
MATCH (n{node_label_str})
WHERE {' and '.join(node_property_strings)} {'and' if node_property_strings else ''} STRING(TO_JSON(n).identifier) = uid
WHERE {where_clause}
RETURN n

NEXT

MATCH {path_pattern}
WHERE STRING(TO_JSON(n).identifier) = @uid
RETURN TO_JSON(e) as e, TO_JSON(d) as d
"""

return execute_query(project, instance, database, query, mock=False)
return execute_query(project, instance, database, query, mock=False, params=query_params, param_types=param_types)

def execute_query(
project: str,
Expand Down
4 changes: 3 additions & 1 deletion tests/cloud_database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
class TestDatabase(unittest.TestCase):
"""Test cases for the CloudSpannerDatabase class"""

@patch("spanner_graphs.cloud_database._get_default_credentials_with_project")
@patch("spanner_graphs.cloud_database.spanner.Client")
def test_execute_query(self, mock_client: MagicMock) -> None:
def test_execute_query(self, mock_client: MagicMock, mock_creds: MagicMock) -> None:
"""Test that a query is executed correctly"""
mock_creds.return_value = (MagicMock(), "test-project")
mock_instance = MagicMock()
mock_database = MagicMock()
mock_snapshot = MagicMock()
Expand Down
108 changes: 75 additions & 33 deletions tests/graph_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def test_validate_property_type_valid_types(self):
'DATE',
'TIMESTAMP',
'BYTES',
'ENUM',
# Test case insensitivity
'int64',
'string',
Expand Down Expand Up @@ -54,25 +53,22 @@ def test_validate_property_type_invalid_types(self):
self.assertIn("Invalid property type", str(cm.exception))
self.assertIn("Allowed types are:", str(cm.exception))

@patch('spanner_graphs.graph_server.spanner')
@patch('spanner_graphs.graph_server.execute_query')
def test_property_value_formatting(self, mock_execute_query):
def test_property_value_formatting(self, mock_execute_query, mock_spanner):
"""Test that property values are correctly formatted based on their type."""
mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}}

test_cases = [
# Numeric types (unquoted)
("INT64", "123", "123"),
("NUMERIC", "123", "123"),
("FLOAT32", "123.45", "123.45"),
("FLOAT64", "123.45", "123.45"),
# Boolean (unquoted)
("BOOL", "true", "true"),
# String types (quoted)
("STRING", "hello", "'''hello'''"),
("DATE", "2024-03-14", "'''2024-03-14'''"),
("TIMESTAMP", "2024-03-14T12:00:00Z", "'''2024-03-14T12:00:00Z'''"),
("BYTES", "base64data", "'''base64data'''"),
("ENUM", "ENUM_VALUE", "'''ENUM_VALUE'''"),
("INT64", "123"),
("NUMERIC", "123"),
("FLOAT32", "123.45"),
("FLOAT64", "123.45"),
("BOOL", "true"),
("STRING", "hello"),
("DATE", "2024-03-14"),
("TIMESTAMP", "2024-03-14T12:00:00Z"),
("BYTES", "base64data"),
]

params = json.dumps({
Expand All @@ -82,9 +78,8 @@ def test_property_value_formatting(self, mock_execute_query):
"graph": "test-graph",
})

for type_str, value, expected_format in test_cases:
for type_str, value in test_cases:
with self.subTest(type=type_str, value=value):
# Create a property dictionary
prop_dict = {"key": "test_property", "value": value, "type": type_str}

request = {
Expand All @@ -99,22 +94,69 @@ def test_property_value_formatting(self, mock_execute_query):
request=request
)

# Extract the actual formatted value from the query
last_call = mock_execute_query.call_args[0] # Get the positional args
query = last_call[3] # The query is the 4th positional arg
last_call = mock_execute_query.call_args
query_params = last_call[1]['params']
param_types = last_call[1]['param_types']

# Find the WHERE clause in the query and extract the value
where_line = [line for line in query.split('\n') if 'WHERE' in line][0]
expected_pattern = f"n.test_property={expected_format}"
self.assertIn(expected_pattern, where_line,
f"Expected property value pattern {expected_pattern} not found in WHERE clause for type {type_str}")
query = last_call[0][3]
where_clause = [line.strip() for line in query.split('\n') if 'WHERE' in line][0]
expected_where = "WHERE n.test_property = @p0 and STRING(TO_JSON(n).identifier) = @uid"
self.assertEqual(where_clause, expected_where)

self.assertEqual(query_params['p0'], value)
self.assertEqual(param_types['p0'], getattr(mock_spanner.param_types, type_str))


@patch('spanner_graphs.graph_server.spanner')
@patch('spanner_graphs.graph_server.execute_query')
def test_property_value_formatting_multiple_properties(self, mock_execute_query, mock_spanner):
"""Test that multiple property values are correctly formatted."""
mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}}

props = [
{"key": "age", "value": 30, "type": "INT64"},
{"key": "name", "value": "John", "type": "STRING"},
]

params = json.dumps({
"project": "test-project",
"instance": "test-instance",
"database": "test-database",
"graph": "test-graph",
})

request = {
"uid": "test-uid",
"node_labels": ["Person"],
"node_properties": props,
"direction": "OUTGOING"
}

execute_node_expansion(
params_str=params,
request=request
)

last_call = mock_execute_query.call_args
query_params = last_call[1]['params']
param_types = last_call[1]['param_types']

query = last_call[0][3]
where_clause = [line.strip() for line in query.split('\n') if 'WHERE' in line][0]
expected_where = "WHERE n.age = @p0 and n.name = @p1 and STRING(TO_JSON(n).identifier) = @uid"
self.assertEqual(where_clause, expected_where)

self.assertEqual(query_params['p0'], 30)
self.assertEqual(param_types['p0'], mock_spanner.param_types.INT64)

self.assertEqual(query_params['p1'], "John")
self.assertEqual(param_types['p1'], mock_spanner.param_types.STRING)

@patch('spanner_graphs.graph_server.execute_query')
def test_property_value_formatting_no_type(self, mock_execute_query):
"""Test that property values are quoted when no type is provided."""
mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}}

# Create a property dictionary with string type (since null type is not allowed)
prop_dict = {"key": "test_property", "value": "test_value", "type": "STRING"}

params = json.dumps({
Expand All @@ -136,13 +178,13 @@ def test_property_value_formatting_no_type(self, mock_execute_query):
request=request
)

# Extract the actual formatted value from the query
last_call = mock_execute_query.call_args[0]
query = last_call[3]
where_line = [line for line in query.split('\n') if 'WHERE' in line][0]
expected_pattern = "n.test_property='''test_value'''"
self.assertIn(expected_pattern, where_line,
"Property value should be quoted when string type is provided")
last_call = mock_execute_query.call_args
query_params = last_call[1]['params']
param_types = last_call[1]['param_types']

self.assertIn('@p0', last_call[0][3])
self.assertEqual(query_params['p0'], "test_value")
self.assertIsNotNone(param_types['p0'])

if __name__ == '__main__':
unittest.main()