Skip to content

Commit fe5edd8

Browse files
committed
fix: allow hyphens in identifiers, fix test fixtures, apply pyink formatting
- Allow hyphens in SQL identifier regex (GCP resource names use them) - Update PostgreSQL test fixtures to use proper Vertex AI endpoint URIs - Apply pyink formatting to pass pre-commit checks - Fix isort import ordering in validation tests
1 parent a0fa50f commit fe5edd8

3 files changed

Lines changed: 43 additions & 33 deletions

File tree

src/google/adk/tools/spanner/search_tool.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,25 @@
3232
from .settings import EXACT_NEAREST_NEIGHBORS
3333
from .settings import SpannerToolSettings
3434

35-
# Pattern for valid SQL identifiers: alphanumeric, underscores, dots (for
36-
# schema-qualified names), and backtick/double-quote quoting.
35+
# Pattern for valid SQL identifiers: alphanumeric, underscores, hyphens,
36+
# dots (for schema-qualified names), and backtick/double-quote quoting.
3737
_SAFE_IDENTIFIER_RE = re.compile(
38-
r'^(?:[A-Za-z_][A-Za-z0-9_]*' # unquoted identifier
39-
r'(?:\.[A-Za-z_][A-Za-z0-9_]*)*' # optional schema.table
40-
r'|`[^`]+`' # backtick-quoted
41-
r'|"[^"]+")$' # double-quote-quoted
38+
r"^(?:[A-Za-z_][A-Za-z0-9_-]*" # unquoted identifier
39+
r"(?:\.[A-Za-z_][A-Za-z0-9_-]*)*" # optional schema.table
40+
r"|`[^`]+`" # backtick-quoted
41+
r'|"[^"]+")$' # double-quote-quoted
4242
)
4343

4444
# Patterns that should never appear in an additional_filter value when
4545
# the filter is populated by the LLM at runtime.
4646
_FILTER_DENY_PATTERNS = re.compile(
47-
r';\s*' # statement separator
48-
r'|--' # single-line comment
49-
r'|/\*' # block comment start
50-
r'|\*/' # block comment end
51-
r'|\bUNION\b' # UNION-based injection
52-
r'|\bINTO\b\s+\bOUTFILE\b' # INTO OUTFILE
53-
, re.IGNORECASE
47+
r";\s*" # statement separator
48+
r"|--" # single-line comment
49+
r"|/\*" # block comment start
50+
r"|\*/" # block comment end
51+
r"|\bUNION\b" # UNION-based injection
52+
r"|\bINTO\b\s+\bOUTFILE\b", # INTO OUTFILE
53+
re.IGNORECASE,
5454
)
5555

5656

@@ -71,7 +71,7 @@ def _validate_identifier(value: str, param_name: str) -> str:
7171
raise ValueError(
7272
f"Invalid SQL identifier for {param_name}: {value!r}. "
7373
"Identifiers must contain only alphanumeric characters, underscores, "
74-
"and dots, or be quoted with backticks or double quotes."
74+
"hyphens, and dots, or be quoted with backticks or double quotes."
7575
)
7676
return value.strip()
7777

@@ -165,13 +165,13 @@ def _generate_postgresql_for_embedding_query(
165165
) -> str:
166166
# Validate endpoint format: projects/.../locations/.../publishers/.../models/...
167167
if not re.match(
168-
r'^projects/[\w-]+/locations/[\w-]+/publishers/[\w-]+/models/[\w.-]+$',
168+
r"^projects/[\w-]+/locations/[\w-]+/publishers/[\w-]+/models/[\w.-]+$",
169169
vertex_ai_embedding_model_endpoint,
170170
):
171171
raise ValueError(
172-
f"Invalid Vertex AI endpoint format: "
172+
"Invalid Vertex AI endpoint format: "
173173
f"{vertex_ai_embedding_model_endpoint!r}. Expected format: "
174-
f"projects/$project/locations/$location/publishers/google/models/$model"
174+
"projects/$project/locations/$location/publishers/google/models/$model"
175175
)
176176
instances_json = f"""
177177
'instances',

tests/unittests/tools/spanner/test_search_tool.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,7 @@ async def test_similarity_search_postgresql_knn_success(
263263
embedding_column_to_search="embedding_col",
264264
columns=["col1"],
265265
embedding_options={
266-
"spanner_postgresql_vertex_ai_embedding_model_endpoint": (
267-
"test_endpoint"
268-
)
266+
"spanner_postgresql_vertex_ai_embedding_model_endpoint": "projects/test-project/locations/us-central1/publishers/google/models/text-embedding-005"
269267
},
270268
credentials=mock_credentials,
271269
)
@@ -301,9 +299,7 @@ async def test_similarity_search_postgresql_ann_unsupported(
301299
embedding_column_to_search="embedding_col",
302300
columns=["col1"],
303301
embedding_options={
304-
"spanner_postgresql_vertex_ai_embedding_model_endpoint": (
305-
"test_endpoint"
306-
)
302+
"spanner_postgresql_vertex_ai_embedding_model_endpoint": "projects/test-project/locations/us-central1/publishers/google/models/text-embedding-005"
307303
},
308304
credentials=mock_credentials,
309305
search_options={

tests/unittests/tools/spanner/test_spanner_sql_validation.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
via LLM-populated tool parameters).
2020
"""
2121

22-
import pytest
23-
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
24-
25-
from google.adk.tools.spanner.search_tool import _generate_sql_for_knn
2622
from google.adk.tools.spanner.search_tool import _generate_sql_for_ann
23+
from google.adk.tools.spanner.search_tool import _generate_sql_for_knn
2724
from google.adk.tools.spanner.search_tool import _validate_additional_filter
2825
from google.adk.tools.spanner.search_tool import _validate_column_list
2926
from google.adk.tools.spanner.search_tool import _validate_identifier
27+
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
28+
import pytest
3029

3130

3231
class TestValidateIdentifier:
@@ -36,7 +35,10 @@ def test_simple_identifier(self):
3635
assert _validate_identifier("documents", "test") == "documents"
3736

3837
def test_schema_qualified_identifier(self):
39-
assert _validate_identifier("my_schema.my_table", "test") == "my_schema.my_table"
38+
assert (
39+
_validate_identifier("my_schema.my_table", "test")
40+
== "my_schema.my_table"
41+
)
4042

4143
def test_identifier_with_underscores(self):
4244
assert _validate_identifier("embedding_col_1", "test") == "embedding_col_1"
@@ -56,7 +58,8 @@ def test_rejects_join_injection(self):
5658
def test_rejects_subquery_in_column(self):
5759
with pytest.raises(ValueError, match="Invalid SQL identifier"):
5860
_validate_identifier(
59-
"(SELECT STRING_AGG(table_name, ',') FROM INFORMATION_SCHEMA.TABLES) AS schema_dump",
61+
"(SELECT STRING_AGG(table_name, ',') FROM INFORMATION_SCHEMA.TABLES)"
62+
" AS schema_dump",
6063
"columns",
6164
)
6265

@@ -84,7 +87,10 @@ def test_rejects_subquery_column(self):
8487
with pytest.raises(ValueError, match="Invalid SQL identifier"):
8588
_validate_column_list(
8689
[
87-
"(SELECT STRING_AGG(table_name, ',') FROM INFORMATION_SCHEMA.TABLES) AS dump",
90+
(
91+
"(SELECT STRING_AGG(table_name, ',') FROM"
92+
" INFORMATION_SCHEMA.TABLES) AS dump"
93+
),
8894
"content",
8995
],
9096
"columns",
@@ -98,7 +104,10 @@ def test_none_filter(self):
98104
assert _validate_additional_filter(None) is None
99105

100106
def test_simple_filter(self):
101-
assert _validate_additional_filter("price_in_cents < 100000") == "price_in_cents < 100000"
107+
assert (
108+
_validate_additional_filter("price_in_cents < 100000")
109+
== "price_in_cents < 100000"
110+
)
102111

103112
def test_rejects_union(self):
104113
with pytest.raises(ValueError, match="UNION"):
@@ -142,7 +151,9 @@ def test_rejects_union_in_filter(self):
142151
table_name="documents",
143152
embedding_column_to_search="embedding",
144153
columns=["content"],
145-
additional_filter="1=1 UNION ALL SELECT password, 0.0 FROM admin_credentials",
154+
additional_filter=(
155+
"1=1 UNION ALL SELECT password, 0.0 FROM admin_credentials"
156+
),
146157
distance_type="COSINE",
147158
top_k=10,
148159
)
@@ -166,7 +177,10 @@ def test_rejects_subquery_in_columns(self):
166177
table_name="documents",
167178
embedding_column_to_search="embedding",
168179
columns=[
169-
"(SELECT STRING_AGG(table_name, ',') FROM INFORMATION_SCHEMA.TABLES) AS schema_dump",
180+
(
181+
"(SELECT STRING_AGG(table_name, ',') FROM"
182+
" INFORMATION_SCHEMA.TABLES) AS schema_dump"
183+
),
170184
],
171185
additional_filter=None,
172186
distance_type="COSINE",

0 commit comments

Comments
 (0)