diff --git a/tests/test_disable_discovery.py b/tests/test_disable_discovery.py
index 17e49c72e..4c30b3a61 100644
--- a/tests/test_disable_discovery.py
+++ b/tests/test_disable_discovery.py
@@ -142,6 +142,41 @@ def test_sync_driver_discovery_disabled_mock(
         driver.stop()
 
 
+def test_sync_driver_discovery_disabled_stop_cleans_connection(driver_config_disabled_discovery, mock_connection):
+    driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
+    ready_connection = mock_connection.return_value
+
+    driver.wait(timeout=5)
+    driver.stop()
+
+    ready_connection.close.assert_called_once()
+
+
+def test_sync_driver_discovery_disabled_retries_initial_connection(
+    driver_config_disabled_discovery, mock_discovery_resolver
+):
+    ready_connection = unittest.mock.MagicMock()
+    ready_connection.endpoint = "localhost:2136"
+    ready_connection.node_id = "mock_node_id"
+
+    with unittest.mock.patch(
+        "ydb.connection.Connection.ready_factory",
+        side_effect=[None, ready_connection],
+    ) as mock_factory:
+        with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
+            driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
+            try:
+                driver.wait(timeout=5, fail_fast=True)
+
+                assert mock_factory.call_count == 2
+                mock_discovery_class.assert_not_called()
+                assert not mock_discovery_resolver.called
+            finally:
+                driver.stop()
+
+    ready_connection.close.assert_called_once()
+
+
 def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection):
     """Test that when disable_discovery=False, the discovery thread is started (mock)."""
     with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
@@ -238,6 +273,41 @@ async def test_aio_driver_discovery_disabled_mock(
         teardown_async_mocks(mocks)
 
 
+@pytest.mark.asyncio
+async def test_aio_driver_discovery_disabled_retries_initial_connection(driver_config_disabled_discovery):
+    ready_attempts = []
+    closed_connections = []
+
+    class FakeConnection:
+        def __init__(self, endpoint, driver_config, endpoint_options=None):
+            self.endpoint = endpoint
+            self.node_id = None
+            self._cleanup_callbacks = []
+
+        def add_cleanup_callback(self, callback):
+            self._cleanup_callbacks.append(callback)
+
+        async def connection_ready(self, ready_timeout=10):
+            ready_attempts.append(self)
+            if len(ready_attempts) == 1:
+                raise ydb.issues.ConnectionLost("transient failure")
+
+        async def close(self, grace=30):
+            closed_connections.append(self)
+            for callback in self._cleanup_callbacks:
+                callback(self)
+
+    with unittest.mock.patch("ydb.aio.pool.Connection", FakeConnection):
+        driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)
+        try:
+            await driver.wait(timeout=5, fail_fast=True)
+
+            assert len(ready_attempts) == 2
+            assert closed_connections == [ready_attempts[0]]
+        finally:
+            await driver.stop()
+
+
 @pytest.mark.asyncio
 async def test_aio_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_aio_connection):
     """Test that when disable_discovery=False, the discovery is created (mock)."""
diff --git a/ydb/aio/connection.py b/ydb/aio/connection.py
index e5e57e3bf..dd1342d3f 100644
--- a/ydb/aio/connection.py
+++ b/ydb/aio/connection.py
@@ -258,8 +258,16 @@ async def destroy(self, grace: float = 0) -> None:
         :param grace:
         :return: None
         """
-        if hasattr(self, "_channel") and hasattr(self._channel, "close"):
-            await self._channel.close(grace)
+        channel = getattr(self, "_channel", None)
+        if channel is not None and hasattr(channel, "close"):
+            await channel.close(grace)
+
+        # Look the stub cache up as defensively as _channel above: if one
+        # attribute can be missing here, the other can be too.
+        stub_instances = getattr(self, "_stub_instances", None)
+        if stub_instances is not None:
+            stub_instances.clear()
+        self._channel = None
 
     def add_cleanup_callback(self, callback: Callable[["Connection"], None]) -> None:
         self._cleanup_callbacks.append(callback)
diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py
index 5eb51b5c9..4d952086c 100644
--- a/ydb/aio/pool.py
+++ b/ydb/aio/pool.py
@@ -248,17 +248,36 @@ def __init__(self, driver_config: "DriverConfig") -> None:
         self._store = ConnectionsCache(driver_config.use_all_nodes)
         self._grpc_init = Connection(self._driver_config.endpoint, self._driver_config)
         self._stopped = False
+        self._stopping = False
         self._discovery: Optional[Discovery] = None
         self._discovery_task: "asyncio.Task[None]"
 
         if driver_config.disable_discovery:
             # If discovery is disabled, just add the initial endpoint to the store
             async def init_connection() -> None:
-                ready_connection = Connection(self._driver_config.endpoint, self._driver_config)
-                await ready_connection.connection_ready(
-                    ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10)
-                )
-                self._store.add(ready_connection)
+                ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10)
+                while not self._stopping:
+                    ready_connection = Connection(self._driver_config.endpoint, self._driver_config)
+                    try:
+                        await ready_connection.connection_ready(ready_timeout=ready_timeout)
+                    except asyncio.CancelledError:
+                        try:
+                            await ready_connection.close()
+                        except Exception:
+                            logger.debug("Failed to close cancelled initial connection", exc_info=True)
+                        raise
+                    except Exception:
+                        logger.debug("Initial connection attempt failed", exc_info=True)
+                        try:
+                            await ready_connection.close()
+                        except Exception:
+                            logger.debug("Failed to close unsuccessful initial connection", exc_info=True)
+                        if not self._stopping:
+                            await asyncio.sleep(1)
+                        continue
+
+                    self._store.add(ready_connection)
+                    return
 
             # Create and schedule the task to initialize the connection
             self._discovery_task = asyncio.get_event_loop().create_task(init_connection())
@@ -268,6 +287,7 @@ async def init_connection() -> None:
             self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())
 
     async def stop(self, timeout: int = 10) -> None:  # type: ignore[override]  # async override of sync method
+        self._stopping = True
         if self._discovery:
             self._discovery.stop()
         await self._grpc_init.close()
@@ -275,6 +295,12 @@ async def stop(self, timeout: int = 10) -> None:  # type: ignore[override]  # as
             await asyncio.wait_for(self._discovery_task, timeout=timeout)
         except asyncio.TimeoutError:
             self._discovery_task.cancel()
+            try:
+                await self._discovery_task
+            except asyncio.CancelledError:
+                pass
+        if self._discovery is None:
+            await self._store.cleanup()
         self._stopped = True
 
     def _on_disconnected(self, connection: Connection) -> Callable[[], Any]:
diff --git a/ydb/connection.py b/ydb/connection.py
index d64438ef5..7e5ad3675 100644
--- a/ydb/connection.py
+++ b/ydb/connection.py
@@ -602,8 +602,16 @@ def close(self):
         self.destroy()
 
     def destroy(self):
-        if hasattr(self, "_channel") and hasattr(self._channel, "close"):
-            self._channel.close()
+        channel = getattr(self, "_channel", None)
+        if channel is not None and hasattr(channel, "close"):
+            channel.close()
+
+        # Look the stub cache up as defensively as _channel above: if one
+        # attribute can be missing here, the other can be too.
+        stub_instances = getattr(self, "_stub_instances", None)
+        if stub_instances is not None:
+            stub_instances.clear()
+        self._channel = None
 
     def ready_future(self):
         """
diff --git a/ydb/pool.py b/ydb/pool.py
index 31bfe8ba2..f1a4bdf94 100644
--- a/ydb/pool.py
+++ b/ydb/pool.py
@@ -401,23 +401,47 @@ def __init__(self, driver_config: "DriverConfig") -> None:
         self._store = ConnectionsCache(driver_config.use_all_nodes, driver_config.tracer)
         self.tracer = driver_config.tracer
         self._grpc_init = connection_impl.Connection(self._driver_config.endpoint, self._driver_config)
+        self._stopped = False
+        self._stop_guard = threading.Lock()
+        self._stop_event = threading.Event()
+        self._init_thread: Optional[threading.Thread] = None
 
         if driver_config.disable_discovery:
-            # If discovery is disabled, just add the initial endpoint to the store
-            ready_connection = connection_impl.Connection.ready_factory(
-                self._driver_config.endpoint,
-                self._driver_config,
-                ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10),
-            )
-            self._store.add(ready_connection)
+            # If discovery is disabled, establish the initial connection in a
+            # background thread, retrying until it succeeds or the pool is stopped.
+            # Doing this off the constructor keeps wait(timeout) as the blocking
+            # point and lets stop() interrupt the retry loop.
             self._discovery_thread = None
+            self._init_thread = threading.Thread(
+                name="ydb_driver_initial_connection",
+                target=self._init_connection,
+                daemon=True,
+            )
+            self._init_thread.start()
         else:
             # Start discovery thread as usual
             self._discovery_thread = Discovery(self._store, self._driver_config)
             self._discovery_thread.start()
 
-        self._stopped = False
-        self._stop_guard = threading.Lock()
+    def _init_connection(self) -> None:
+        ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10)
+        while not self._stopped:
+            ready_connection = connection_impl.Connection.ready_factory(
+                self._driver_config.endpoint,
+                self._driver_config,
+                ready_timeout=ready_timeout,
+            )
+            if ready_connection is not None and self._stopped:
+                # stop() was requested while ready_factory was blocking. Its
+                # join(timeout) may already have expired and cleaned the store,
+                # so close the connection instead of publishing it.
+                ready_connection.close()
+                return
+            if self._store.add(ready_connection):
+                return
+
+            logger.debug("Initial connection attempt failed")
+            self._stop_event.wait(1)
 
     def stop(self, timeout: int = 10) -> None:
         """
@@ -431,11 +455,16 @@ def stop(self, timeout: int = 10) -> None:
                 return
 
             self._stopped = True
+            self._stop_event.set()
             if self._discovery_thread:
                 self._discovery_thread.stop()
             self._grpc_init.close()
             if self._discovery_thread:
                 self._discovery_thread.join(timeout)
+            if self._init_thread:
+                self._init_thread.join(timeout)
+            if self._discovery_thread is None:
+                self._store.cleanup()
 
     def async_wait(self, fail_fast: bool = False) -> "futures.Future[None]":
         """