Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 63bac62

Browse files
committed
more fixes
1 parent 8976025 commit 63bac62

6 files changed

Lines changed: 362 additions & 232 deletions

File tree

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
SpannerServicer,
4242
start_mock_server,
4343
)
44+
from tests._helpers import is_multiplexed_enabled
4445

4546

4647
# Creates an aborted status with the smallest possible retry delay.
@@ -228,3 +229,109 @@ def database(self) -> Database:
228229
enable_interceptors_in_tests=True,
229230
)
230231
return self._database
232+
233+
def assert_requests_sequence(
234+
self,
235+
requests,
236+
expected_types,
237+
transaction_type,
238+
allow_multiple_batch_create=True,
239+
):
240+
"""Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries.
241+
242+
Args:
243+
requests: List of requests from spanner_service.requests
244+
expected_types: List of expected request types (excluding session creation requests)
245+
transaction_type: TransactionType enum value to check multiplexed session status
246+
allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest
247+
"""
248+
from google.cloud.spanner_v1 import (
249+
BatchCreateSessionsRequest,
250+
CreateSessionRequest,
251+
)
252+
253+
mux_enabled = is_multiplexed_enabled(transaction_type)
254+
idx = 0
255+
# Skip all leading BatchCreateSessionsRequest (for retries)
256+
if allow_multiple_batch_create:
257+
while idx < len(requests) and isinstance(
258+
requests[idx], BatchCreateSessionsRequest
259+
):
260+
idx += 1
261+
# For multiplexed, optionally skip a CreateSessionRequest
262+
if (
263+
mux_enabled
264+
and idx < len(requests)
265+
and isinstance(requests[idx], CreateSessionRequest)
266+
):
267+
idx += 1
268+
else:
269+
if mux_enabled:
270+
self.assertTrue(
271+
isinstance(requests[idx], BatchCreateSessionsRequest),
272+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
273+
)
274+
idx += 1
275+
self.assertTrue(
276+
isinstance(requests[idx], CreateSessionRequest),
277+
f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}",
278+
)
279+
idx += 1
280+
else:
281+
self.assertTrue(
282+
isinstance(requests[idx], BatchCreateSessionsRequest),
283+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
284+
)
285+
idx += 1
286+
# Check the rest of the expected request types
287+
for expected_type in expected_types:
288+
self.assertTrue(
289+
isinstance(requests[idx], expected_type),
290+
f"Expected {expected_type} at index {idx}, got {type(requests[idx])}",
291+
)
292+
idx += 1
293+
self.assertEqual(
294+
idx, len(requests), f"Expected {idx} requests, got {len(requests)}"
295+
)
296+
297+
def adjust_request_id_sequence(self, expected_segments, requests, transaction_type):
298+
"""Adjust expected request ID sequence numbers based on actual session creation requests.
299+
300+
Args:
301+
expected_segments: List of expected (method, (sequence_numbers)) tuples
302+
requests: List of actual requests from spanner_service.requests
303+
transaction_type: TransactionType enum value to check multiplexed session status
304+
305+
Returns:
306+
List of adjusted expected segments with corrected sequence numbers
307+
"""
308+
from google.cloud.spanner_v1 import (
309+
BatchCreateSessionsRequest,
310+
CreateSessionRequest,
311+
ExecuteSqlRequest,
312+
BeginTransactionRequest,
313+
)
314+
315+
# Count session creation requests that come before the first non-session request
316+
session_requests_before = 0
317+
for req in requests:
318+
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
319+
session_requests_before += 1
320+
elif isinstance(req, (ExecuteSqlRequest, BeginTransactionRequest)):
321+
break
322+
323+
# For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession)
324+
# For non-multiplexed, we expect 1 session request (BatchCreateSessions)
325+
mux_enabled = is_multiplexed_enabled(transaction_type)
326+
expected_session_requests = 2 if mux_enabled else 1
327+
extra_session_requests = session_requests_before - expected_session_requests
328+
329+
# Adjust sequence numbers based on extra session requests
330+
adjusted_segments = []
331+
for method, seq_nums in expected_segments:
332+
# Adjust the sequence number (5th element in the tuple)
333+
adjusted_seq_nums = list(seq_nums)
334+
adjusted_seq_nums[4] += extra_session_requests
335+
adjusted_segments.append((method, tuple(adjusted_seq_nums)))
336+
337+
return adjusted_segments

tests/mockserver_tests/test_aborted_transaction.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import random
1515

1616
from google.cloud.spanner_v1 import (
17-
BatchCreateSessionsRequest,
1817
BeginTransactionRequest,
1918
CommitRequest,
2019
ExecuteSqlRequest,
@@ -32,6 +31,7 @@
3231
)
3332
from google.api_core import exceptions
3433
from test_utils import retry
34+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
3535

3636
retry_maybe_aborted_txn = retry.RetryErrors(
3737
exceptions.Aborted, max_tries=5, delay=0, backoff=1
@@ -46,29 +46,28 @@ def test_run_in_transaction_commit_aborted(self):
4646
# time that the transaction tries to commit. It will then be retried
4747
# and succeed.
4848
self.database.run_in_transaction(_insert_mutations)
49-
50-
# Verify that the transaction was retried.
5149
requests = self.spanner_service.requests
52-
self.assertEqual(5, len(requests), msg=requests)
53-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
54-
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
55-
self.assertTrue(isinstance(requests[2], CommitRequest))
56-
# The transaction is aborted and retried.
57-
self.assertTrue(isinstance(requests[3], BeginTransactionRequest))
58-
self.assertTrue(isinstance(requests[4], CommitRequest))
50+
self.assert_requests_sequence(
51+
requests,
52+
[
53+
BeginTransactionRequest,
54+
CommitRequest,
55+
BeginTransactionRequest,
56+
CommitRequest,
57+
],
58+
TransactionType.READ_WRITE,
59+
)
5960

6061
def test_run_in_transaction_update_aborted(self):
6162
add_update_count("update my_table set my_col=1 where id=2", 1)
6263
add_error(SpannerServicer.ExecuteSql.__name__, aborted_status())
6364
self.database.run_in_transaction(_execute_update)
64-
65-
# Verify that the transaction was retried.
6665
requests = self.spanner_service.requests
67-
self.assertEqual(4, len(requests), msg=requests)
68-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
69-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
70-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
71-
self.assertTrue(isinstance(requests[3], CommitRequest))
66+
self.assert_requests_sequence(
67+
requests,
68+
[ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
69+
TransactionType.READ_WRITE,
70+
)
7271

7372
def test_run_in_transaction_query_aborted(self):
7473
add_single_result(
@@ -79,28 +78,24 @@ def test_run_in_transaction_query_aborted(self):
7978
)
8079
add_error(SpannerServicer.ExecuteStreamingSql.__name__, aborted_status())
8180
self.database.run_in_transaction(_execute_query)
82-
83-
# Verify that the transaction was retried.
8481
requests = self.spanner_service.requests
85-
self.assertEqual(4, len(requests), msg=requests)
86-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
87-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
88-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
89-
self.assertTrue(isinstance(requests[3], CommitRequest))
82+
self.assert_requests_sequence(
83+
requests,
84+
[ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
85+
TransactionType.READ_WRITE,
86+
)
9087

9188
def test_run_in_transaction_batch_dml_aborted(self):
9289
add_update_count("update my_table set my_col=1 where id=1", 1)
9390
add_update_count("update my_table set my_col=1 where id=2", 1)
9491
add_error(SpannerServicer.ExecuteBatchDml.__name__, aborted_status())
9592
self.database.run_in_transaction(_execute_batch_dml)
96-
97-
# Verify that the transaction was retried.
9893
requests = self.spanner_service.requests
99-
self.assertEqual(4, len(requests), msg=requests)
100-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
101-
self.assertTrue(isinstance(requests[1], ExecuteBatchDmlRequest))
102-
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
103-
self.assertTrue(isinstance(requests[3], CommitRequest))
94+
self.assert_requests_sequence(
95+
requests,
96+
[ExecuteBatchDmlRequest, ExecuteBatchDmlRequest, CommitRequest],
97+
TransactionType.READ_WRITE,
98+
)
10499

105100
def test_batch_commit_aborted(self):
106101
# Add an Aborted error for the Commit method on the mock server.
@@ -117,14 +112,12 @@ def test_batch_commit_aborted(self):
117112
(5, "David", "Lomond"),
118113
],
119114
)
120-
121-
# Verify that the transaction was retried.
122115
requests = self.spanner_service.requests
123-
self.assertEqual(3, len(requests), msg=requests)
124-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
125-
self.assertTrue(isinstance(requests[1], CommitRequest))
126-
# The transaction is aborted and retried.
127-
self.assertTrue(isinstance(requests[2], CommitRequest))
116+
self.assert_requests_sequence(
117+
requests,
118+
[CommitRequest, CommitRequest],
119+
TransactionType.READ_WRITE,
120+
)
128121

129122
@retry_maybe_aborted_txn
130123
def test_retry_helper(self):

tests/mockserver_tests/test_basics.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from google.cloud.spanner_dbapi import Connection
1717
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
1818
from google.cloud.spanner_v1 import (
19-
BatchCreateSessionsRequest,
2019
BeginTransactionRequest,
2120
ExecuteBatchDmlRequest,
2221
ExecuteSqlRequest,
@@ -25,6 +24,7 @@
2524
)
2625
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
2726
from google.cloud.spanner_v1.transaction import Transaction
27+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
2828

2929
from tests.mockserver_tests.mock_server_test_base import (
3030
MockServerTestBase,
@@ -36,6 +36,7 @@
3636
unavailable_status,
3737
add_execute_streaming_sql_results,
3838
)
39+
from tests._helpers import is_multiplexed_enabled
3940

4041

4142
class TestBasics(MockServerTestBase):
@@ -49,9 +50,11 @@ def test_select1(self):
4950
self.assertEqual(1, row[0])
5051
self.assertEqual(1, len(result_list))
5152
requests = self.spanner_service.requests
52-
self.assertEqual(2, len(requests), msg=requests)
53-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
54-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
53+
self.assert_requests_sequence(
54+
requests,
55+
[ExecuteSqlRequest],
56+
TransactionType.READ_ONLY,
57+
)
5558

5659
def test_create_table(self):
5760
database_admin_api = self.client.database_admin_api
@@ -84,13 +87,31 @@ def test_dbapi_partitioned_dml(self):
8487
# with no parameters.
8588
cursor.execute(sql, [])
8689
self.assertEqual(100, cursor.rowcount)
87-
8890
requests = self.spanner_service.requests
89-
self.assertEqual(3, len(requests), msg=requests)
90-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
91-
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
92-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
93-
begin_request: BeginTransactionRequest = requests[1]
91+
self.assert_requests_sequence(
92+
requests,
93+
[BeginTransactionRequest, ExecuteSqlRequest],
94+
TransactionType.PARTITIONED,
95+
allow_multiple_batch_create=True,
96+
)
97+
# Find the first BeginTransactionRequest after session creation
98+
idx = 0
99+
from google.cloud.spanner_v1 import (
100+
BatchCreateSessionsRequest,
101+
CreateSessionRequest,
102+
)
103+
104+
while idx < len(requests) and isinstance(
105+
requests[idx], BatchCreateSessionsRequest
106+
):
107+
idx += 1
108+
if (
109+
is_multiplexed_enabled(TransactionType.PARTITIONED)
110+
and idx < len(requests)
111+
and isinstance(requests[idx], CreateSessionRequest)
112+
):
113+
idx += 1
114+
begin_request: BeginTransactionRequest = requests[idx]
94115
self.assertEqual(
95116
TransactionOptions(dict(partitioned_dml={})), begin_request.options
96117
)
@@ -106,11 +127,12 @@ def test_batch_create_sessions_unavailable(self):
106127
self.assertEqual(1, row[0])
107128
self.assertEqual(1, len(result_list))
108129
requests = self.spanner_service.requests
109-
self.assertEqual(3, len(requests), msg=requests)
110-
# The BatchCreateSessions call should be retried.
111-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
112-
self.assertTrue(isinstance(requests[1], BatchCreateSessionsRequest))
113-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
130+
self.assert_requests_sequence(
131+
requests,
132+
[ExecuteSqlRequest],
133+
TransactionType.READ_ONLY,
134+
allow_multiple_batch_create=True,
135+
)
114136

115137
def test_execute_streaming_sql_unavailable(self):
116138
add_select1_result()
@@ -125,11 +147,11 @@ def test_execute_streaming_sql_unavailable(self):
125147
self.assertEqual(1, row[0])
126148
self.assertEqual(1, len(result_list))
127149
requests = self.spanner_service.requests
128-
self.assertEqual(3, len(requests), msg=requests)
129-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
130-
# The ExecuteStreamingSql call should be retried.
131-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
132-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
150+
self.assert_requests_sequence(
151+
requests,
152+
[ExecuteSqlRequest, ExecuteSqlRequest],
153+
TransactionType.READ_ONLY,
154+
)
133155

134156
def test_last_statement_update(self):
135157
sql = "update my_table set my_col=1 where id=2"
@@ -199,9 +221,11 @@ def test_execute_streaming_sql_last_field(self):
199221
count += 1
200222
self.assertEqual(3, len(result_list))
201223
requests = self.spanner_service.requests
202-
self.assertEqual(2, len(requests), msg=requests)
203-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
204-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
224+
self.assert_requests_sequence(
225+
requests,
226+
[ExecuteSqlRequest],
227+
TransactionType.READ_ONLY,
228+
)
205229

206230

207231
def _execute_query(transaction: Transaction, sql: str):

0 commit comments

Comments
 (0)