diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 43034d4f..4c9b538e 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -217,6 +217,7 @@ class ReaderReconnector: _stream_reader: Optional["ReaderStream"] _first_error: asyncio.Future[YdbError] _tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]] + _closed: bool def __init__( self, @@ -233,6 +234,7 @@ def __init__( self._state_changed = asyncio.Event() self._stream_reader = None + self._closed = False self._background_tasks.add(asyncio.create_task(self._connection_loop())) self._first_error = asyncio.get_running_loop().create_future() @@ -241,6 +243,8 @@ def __init__( async def _connection_loop(self): attempt = 0 while True: + if self._closed: + return try: logger.debug("reader %s connect attempt %s", self._id, attempt) self._stream_reader = await ReaderStream.create(self._id, self._driver, self._settings) @@ -266,8 +270,12 @@ async def _connection_loop(self): # noinspection PyBroadException try: await self._stream_reader.close(flush=False) - except BaseException: - # supress any error on close stream reader + except asyncio.CancelledError: + # propagate cancellation (e.g. from reader.close()) so the loop stops + # instead of swallowing it and reconnecting into a zombie stream + raise + except Exception: + # suppress any error on close stream reader pass async def wait_message(self): @@ -431,8 +439,16 @@ def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.Co async def close(self, flush: bool): logger.debug("reader reconnector %s close", self._id) + # Mark closed so the connection loop won't start a new stream, then close the + # current stream with the requested flush before cancelling the loop. On a normal + # close this flushes pending commits; cancelling the loop first would let it close + # the stream with flush=False instead and skip the flush. + self._closed = True if self._stream_reader: await self._stream_reader.close(flush) + # Wake any pending wait_message() waiter (e.g. a concurrent receive) so it doesn't + # hang if the loop was reconnecting when close() cancelled it. + self._set_first_error(TopicReaderStreamClosedError()) for task in self._background_tasks: task.cancel() diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index cb7ce408..388c50f8 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1563,6 +1563,50 @@ async def stream_create( reader_stream_mock_with_error.wait_error.assert_any_await() reader_stream_mock_with_error.wait_messages.assert_any_await() + async def test_close_during_reconnect_does_not_hang(self): + # The connection loop must stop on reader.close() even while it is closing the old + # stream during a reconnect, instead of swallowing the cancellation and bringing up + # a new (zombie) stream while close() hangs forever. + finally_close_started = asyncio.Event() + held = {"done": False} + + async def wait_error(): + raise issues.Unavailable("trigger reconnect") + + async def slow_close(flush=False): + if not held["done"]: + held["done"] = True + finally_close_started.set() + await asyncio.Event().wait() # parked until reader.close() cancels the loop + + stream1 = mock.Mock(ReaderStream) + stream1._id = 1 + stream1.wait_error = mock.AsyncMock(side_effect=wait_error) + stream1.close = mock.AsyncMock(side_effect=slow_close) + + async def wait_forever(): + await asyncio.Future() + + stream2 = mock.Mock(ReaderStream) + stream2._id = 2 + stream2.wait_error = mock.AsyncMock(side_effect=wait_forever) + stream2.close = mock.AsyncMock() + + create_calls = 0 + + async def stream_create(reader_reconnector_id, driver, settings): + nonlocal create_calls + create_calls += 1 + return stream1 if create_calls == 1 else stream2 + + with mock.patch.object(ReaderStream, "create", stream_create): + reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", "")) + await asyncio.wait_for(finally_close_started.wait(), timeout=2) + await asyncio.wait_for(reconnector.close(flush=False), timeout=5) + + # The loop stopped on close instead of reconnecting into a second (zombie) stream. + assert create_calls == 1 + async def test_wait_error_returns_on_cancelled_error_from_receive(self, default_reader_settings): receive_call = 0