From 32bf67d1af6da8c8278ea35038e628b5b587b2ff Mon Sep 17 00:00:00 2001 From: Vasily Gerasimov Date: Wed, 3 Jun 2026 18:16:43 +0200 Subject: [PATCH 1/3] Fixes --- tests/test_disable_discovery.py | 44 +++++++++++++++++++++++++++++++++ ydb/aio/connection.py | 8 ++++-- ydb/aio/pool.py | 36 +++++++++++++++++++++++---- ydb/connection.py | 8 ++++-- ydb/pool.py | 2 ++ 5 files changed, 89 insertions(+), 9 deletions(-) diff --git a/tests/test_disable_discovery.py b/tests/test_disable_discovery.py index 17e49c72e..2f3efd26a 100644 --- a/tests/test_disable_discovery.py +++ b/tests/test_disable_discovery.py @@ -142,6 +142,15 @@ def test_sync_driver_discovery_disabled_mock( driver.stop() +def test_sync_driver_discovery_disabled_stop_cleans_connection(driver_config_disabled_discovery, mock_connection): + driver = ydb.Driver(driver_config=driver_config_disabled_discovery) + ready_connection = mock_connection.return_value + + driver.stop() + + ready_connection.close.assert_called_once() + + def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection): """Test that when disable_discovery=False, the discovery thread is started (mock).""" with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class: @@ -238,6 +247,41 @@ async def test_aio_driver_discovery_disabled_mock( teardown_async_mocks(mocks) +@pytest.mark.asyncio +async def test_aio_driver_discovery_disabled_retries_initial_connection(driver_config_disabled_discovery): + ready_attempts = [] + closed_connections = [] + + class FakeConnection: + def __init__(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = None + self._cleanup_callbacks = [] + + def add_cleanup_callback(self, callback): + self._cleanup_callbacks.append(callback) + + async def connection_ready(self, ready_timeout=10): + ready_attempts.append(self) + if len(ready_attempts) == 1: + raise ydb.issues.ConnectionLost("transient failure") + + async def close(self, grace=30): + 
closed_connections.append(self) + for callback in self._cleanup_callbacks: + callback(self) + + with unittest.mock.patch("ydb.aio.pool.Connection", FakeConnection): + driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery) + try: + await driver.wait(timeout=5, fail_fast=True) + + assert len(ready_attempts) == 2 + assert closed_connections == [ready_attempts[0]] + finally: + await driver.stop() + + @pytest.mark.asyncio async def test_aio_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_aio_connection): """Test that when disable_discovery=False, the discovery is created (mock).""" diff --git a/ydb/aio/connection.py b/ydb/aio/connection.py index e5e57e3bf..dd1342d3f 100644 --- a/ydb/aio/connection.py +++ b/ydb/aio/connection.py @@ -258,8 +258,12 @@ async def destroy(self, grace: float = 0) -> None: :param grace: :return: None """ - if hasattr(self, "_channel") and hasattr(self._channel, "close"): - await self._channel.close(grace) + channel = getattr(self, "_channel", None) + if channel is not None and hasattr(channel, "close"): + await channel.close(grace) + + self._stub_instances.clear() + self._channel = None def add_cleanup_callback(self, callback: Callable[["Connection"], None]) -> None: self._cleanup_callbacks.append(callback) diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 5eb51b5c9..4d952086c 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -248,17 +248,36 @@ def __init__(self, driver_config: "DriverConfig") -> None: self._store = ConnectionsCache(driver_config.use_all_nodes) self._grpc_init = Connection(self._driver_config.endpoint, self._driver_config) self._stopped = False + self._stopping = False self._discovery: Optional[Discovery] = None self._discovery_task: "asyncio.Task[None]" if driver_config.disable_discovery: # If discovery is disabled, just add the initial endpoint to the store async def init_connection() -> None: - ready_connection = Connection(self._driver_config.endpoint, self._driver_config) - await 
ready_connection.connection_ready( - ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10) - ) - self._store.add(ready_connection) + ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10) + while not self._stopping: + ready_connection = Connection(self._driver_config.endpoint, self._driver_config) + try: + await ready_connection.connection_ready(ready_timeout=ready_timeout) + except asyncio.CancelledError: + try: + await ready_connection.close() + except Exception: + logger.debug("Failed to close cancelled initial connection", exc_info=True) + raise + except Exception: + logger.debug("Initial connection attempt failed", exc_info=True) + try: + await ready_connection.close() + except Exception: + logger.debug("Failed to close unsuccessful initial connection", exc_info=True) + if not self._stopping: + await asyncio.sleep(1) + continue + + self._store.add(ready_connection) + return # Create and schedule the task to initialize the connection self._discovery_task = asyncio.get_event_loop().create_task(init_connection()) @@ -268,6 +287,7 @@ async def init_connection() -> None: self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run()) async def stop(self, timeout: int = 10) -> None: # type: ignore[override] # async override of sync method + self._stopping = True if self._discovery: self._discovery.stop() await self._grpc_init.close() @@ -275,6 +295,12 @@ async def stop(self, timeout: int = 10) -> None: # type: ignore[override] # as await asyncio.wait_for(self._discovery_task, timeout=timeout) except asyncio.TimeoutError: self._discovery_task.cancel() + try: + await self._discovery_task + except asyncio.CancelledError: + pass + if self._discovery is None: + await self._store.cleanup() self._stopped = True def _on_disconnected(self, connection: Connection) -> Callable[[], Any]: diff --git a/ydb/connection.py b/ydb/connection.py index d64438ef5..7e5ad3675 100644 --- a/ydb/connection.py +++ 
b/ydb/connection.py @@ -602,8 +602,12 @@ def close(self): self.destroy() def destroy(self): - if hasattr(self, "_channel") and hasattr(self._channel, "close"): - self._channel.close() + channel = getattr(self, "_channel", None) + if channel is not None and hasattr(channel, "close"): + channel.close() + + self._stub_instances.clear() + self._channel = None def ready_future(self): """ diff --git a/ydb/pool.py b/ydb/pool.py index 31bfe8ba2..4e88a0826 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -436,6 +436,8 @@ def stop(self, timeout: int = 10) -> None: self._grpc_init.close() if self._discovery_thread: self._discovery_thread.join(timeout) + else: + self._store.cleanup() def async_wait(self, fail_fast: bool = False) -> "futures.Future[None]": """ From 7dd31356dda5a18eef0899343c80819423110e51 Mon Sep 17 00:00:00 2001 From: Vasily Gerasimov Date: Thu, 4 Jun 2026 17:18:04 +0200 Subject: [PATCH 2/3] Implement in synchronous driver --- tests/test_disable_discovery.py | 27 +++++++++++++++++++++++++++ ydb/pool.py | 19 +++++++++++++------ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/tests/test_disable_discovery.py b/tests/test_disable_discovery.py index 2f3efd26a..c7e2daae0 100644 --- a/tests/test_disable_discovery.py +++ b/tests/test_disable_discovery.py @@ -151,6 +151,33 @@ def test_sync_driver_discovery_disabled_stop_cleans_connection(driver_config_dis ready_connection.close.assert_called_once() +def test_sync_driver_discovery_disabled_retries_initial_connection( + driver_config_disabled_discovery, mock_discovery_resolver +): + ready_connection = unittest.mock.MagicMock() + ready_connection.endpoint = "localhost:2136" + ready_connection.node_id = "mock_node_id" + + with unittest.mock.patch( + "ydb.connection.Connection.ready_factory", + side_effect=[None, ready_connection], + ) as mock_factory: + with unittest.mock.patch("ydb.pool.time.sleep") as mock_sleep: + with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class: + driver = 
ydb.Driver(driver_config=driver_config_disabled_discovery) + try: + driver.wait(timeout=1, fail_fast=True) + + assert mock_factory.call_count == 2 + mock_sleep.assert_called_once_with(1) + mock_discovery_class.assert_not_called() + assert not mock_discovery_resolver.called + finally: + driver.stop() + + ready_connection.close.assert_called_once() + + def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection): """Test that when disable_discovery=False, the discovery thread is started (mock).""" with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class: diff --git a/ydb/pool.py b/ydb/pool.py index 4e88a0826..f57dc6397 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -4,6 +4,7 @@ import abc import threading import logging +import time from concurrent import futures import collections import random @@ -404,12 +405,18 @@ def __init__(self, driver_config: "DriverConfig") -> None: if driver_config.disable_discovery: # If discovery is disabled, just add the initial endpoint to the store - ready_connection = connection_impl.Connection.ready_factory( - self._driver_config.endpoint, - self._driver_config, - ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10), - ) - self._store.add(ready_connection) + ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10) + while True: + ready_connection = connection_impl.Connection.ready_factory( + self._driver_config.endpoint, + self._driver_config, + ready_timeout=ready_timeout, + ) + if self._store.add(ready_connection): + break + + logger.debug("Initial connection attempt failed") + time.sleep(1) self._discovery_thread = None else: # Start discovery thread as usual From 2f704bd4f39bb1b114588c0b359ba9a6b50a5187 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 4 Jun 2026 18:38:41 +0300 Subject: [PATCH 3/3] Retry initial connection in background thread for sync driver --- tests/test_disable_discovery.py | 23 ++++++++-------- ydb/pool.py | 48 
+++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/tests/test_disable_discovery.py b/tests/test_disable_discovery.py index c7e2daae0..4c30b3a61 100644 --- a/tests/test_disable_discovery.py +++ b/tests/test_disable_discovery.py @@ -146,6 +146,7 @@ def test_sync_driver_discovery_disabled_stop_cleans_connection(driver_config_dis driver = ydb.Driver(driver_config=driver_config_disabled_discovery) ready_connection = mock_connection.return_value + driver.wait(timeout=5) driver.stop() ready_connection.close.assert_called_once() @@ -162,18 +163,16 @@ def test_sync_driver_discovery_disabled_retries_initial_connection( "ydb.connection.Connection.ready_factory", side_effect=[None, ready_connection], ) as mock_factory: - with unittest.mock.patch("ydb.pool.time.sleep") as mock_sleep: - with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class: - driver = ydb.Driver(driver_config=driver_config_disabled_discovery) - try: - driver.wait(timeout=1, fail_fast=True) - - assert mock_factory.call_count == 2 - mock_sleep.assert_called_once_with(1) - mock_discovery_class.assert_not_called() - assert not mock_discovery_resolver.called - finally: - driver.stop() + with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class: + driver = ydb.Driver(driver_config=driver_config_disabled_discovery) + try: + driver.wait(timeout=5, fail_fast=True) + + assert mock_factory.call_count == 2 + mock_discovery_class.assert_not_called() + assert not mock_discovery_resolver.called + finally: + driver.stop() ready_connection.close.assert_called_once() diff --git a/ydb/pool.py b/ydb/pool.py index f57dc6397..f1a4bdf94 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -4,7 +4,6 @@ import abc import threading import logging -import time from concurrent import futures import collections import random @@ -402,29 +401,41 @@ def __init__(self, driver_config: "DriverConfig") -> None: self._store = ConnectionsCache(driver_config.use_all_nodes, 
driver_config.tracer) self.tracer = driver_config.tracer self._grpc_init = connection_impl.Connection(self._driver_config.endpoint, self._driver_config) + self._stopped = False + self._stop_guard = threading.Lock() + self._stop_event = threading.Event() + self._init_thread: Optional[threading.Thread] = None if driver_config.disable_discovery: - # If discovery is disabled, just add the initial endpoint to the store - ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10) - while True: - ready_connection = connection_impl.Connection.ready_factory( - self._driver_config.endpoint, - self._driver_config, - ready_timeout=ready_timeout, - ) - if self._store.add(ready_connection): - break - - logger.debug("Initial connection attempt failed") - time.sleep(1) + # If discovery is disabled, establish the initial connection in a + # background thread, retrying until it succeeds or the pool is stopped. + # Doing this off the constructor keeps wait(timeout) as the blocking + # point and lets stop() interrupt the retry loop. 
self._discovery_thread = None + self._init_thread = threading.Thread( + name="ydb_driver_initial_connection", + target=self._init_connection, + daemon=True, + ) + self._init_thread.start() else: # Start discovery thread as usual self._discovery_thread = Discovery(self._store, self._driver_config) self._discovery_thread.start() - self._stopped = False - self._stop_guard = threading.Lock() + def _init_connection(self) -> None: + ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10) + while not self._stopped: + ready_connection = connection_impl.Connection.ready_factory( + self._driver_config.endpoint, + self._driver_config, + ready_timeout=ready_timeout, + ) + if self._store.add(ready_connection): + return + + logger.debug("Initial connection attempt failed") + self._stop_event.wait(1) def stop(self, timeout: int = 10) -> None: """ @@ -438,12 +449,15 @@ def stop(self, timeout: int = 10) -> None: return self._stopped = True + self._stop_event.set() if self._discovery_thread: self._discovery_thread.stop() self._grpc_init.close() if self._discovery_thread: self._discovery_thread.join(timeout) - else: + if self._init_thread: + self._init_thread.join(timeout) + if self._discovery_thread is None: self._store.cleanup() def async_wait(self, fail_fast: bool = False) -> "futures.Future[None]":