Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions tests/topics/test_topic_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,49 @@ async def callee(tx: ydb.aio.QueryTxContext):
msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT)
assert msg.data.decode() == "123"

async def test_tx_commit_after_reconnect_does_not_commit_stale_offsets(
    self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer
):
    """Check that committing a tx after a forced reader reconnect fails loudly.

    Scenario: a batch is received inside a transaction, the reader stream is
    then broken on purpose so the reconnector replaces it, and only afterwards
    the transaction is committed. The commit must raise ``ydb.Error``, the
    update-offsets call must never be issued, the tx-to-batches map must be
    empty, and the same message must be delivered again.
    """
    async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
        async with ydb.aio.QuerySessionPool(driver) as pool:
            async with pool.checkout() as session:
                tx = session.transaction()
                await tx.begin()

                first_batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT)
                assert first_batch.messages[0].data.decode() == "123"

                reconnector = reader._reconnector
                stale_stream = reconnector._stream_reader

                # Spy on the real call (wraps=...) so it can be proven it was never reached.
                with mock.patch.object(
                    reconnector,
                    "_do_commit_batches_with_tx_call",
                    wraps=reconnector._do_commit_batches_with_tx_call,
                ) as commit_call_spy:
                    # Break the stream between receive_batch_with_tx() and commit so that
                    # the received batch refers to a partition session that is gone.
                    stale_stream._set_first_error(ydb.issues.ConnectionLost("forced reconnect"))

                    # Poll (up to 100 * 0.05s) until a different, started stream is in place.
                    for _ in range(100):
                        await asyncio.sleep(0.05)
                        fresh_stream = reconnector._stream_reader
                        if fresh_stream is None or fresh_stream is stale_stream:
                            continue
                        if fresh_stream._started:
                            break
                    assert reconnector._stream_reader is not stale_stream

                    # The stale batch must make the commit fail instead of silently
                    # sending a gapped UpdateOffsetsInTransaction for the dead session.
                    with pytest.raises(ydb.Error):
                        await tx.commit()

                    commit_call_spy.assert_not_called()

                # NOTE(review): nesting level of the assertions below was reconstructed
                # from a whitespace-stripped diff - confirm against the repository.
                assert len(reader._reconnector._tx_to_batches_map) == 0

                # The consumer offset must not have moved: the message is delivered again.
                msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT)
                assert msg.data.decode() == "123"


class TestTopicTransactionalReaderSync:
def test_commit(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer):
Expand Down
32 changes: 31 additions & 1 deletion ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,44 @@ def _init_tx(self, tx: "BaseQueryTxContext"):
tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop)
tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop)

def _batch_partition_session_expired(self, batch: datatypes.PublicBatch) -> bool:
# A batch is expired if the reader reconnected after it was received: its partition
# session no longer belongs to the current stream. Mirrors the guard in
# ReaderStream.commit() for the non-transactional commit path.
stream = self._stream_reader
partition_session = batch._partition_session
return (
stream is None
or partition_session.reader_stream_id != stream._id
or partition_session.id not in stream._partition_sessions
)

async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"):
tx_id = tx.tx_id
if tx_id is None:
raise TopicReaderError("Transaction ID is None")

batches = self._tx_to_batches_map[tx_id]

if any(self._batch_partition_session_expired(batch) for batch in batches):
# The reader reconnected between receive_batch_with_tx() and tx.commit(), so
# these offsets belong to a partition session that no longer exists. Committing
# them would send a stale/gapped range (server "Gap", issue_code 2011) while the
# client believes the commit succeeded. Fail the tx instead (retriable) without
# sending the request; the AFTER_COMMIT handler then reconnects to reset the
# read-ahead state, and the pool re-reads from the committed offset.
err = issues.ClientInternalError(
"Topic reader partition session expired before tx commit; "
"offsets were not committed, the transaction will be retried"
)
tx._set_external_error(err)
del self._tx_to_batches_map[tx_id]
return

grouped_batches: Dict[str, Dict[int, typing.List[datatypes.PublicBatch]]] = defaultdict(
lambda: defaultdict(list)
)
for batch in self._tx_to_batches_map[tx_id]:
for batch in batches:
grouped_batches[batch._partition_session.topic_path][batch._partition_session.partition_id].append(batch)

consumer = self._settings.consumer
Expand Down
Loading