Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions tests/test_disable_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,41 @@ def test_sync_driver_discovery_disabled_mock(
driver.stop()


def test_sync_driver_discovery_disabled_stop_cleans_connection(driver_config_disabled_discovery, mock_connection):
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
ready_connection = mock_connection.return_value

driver.wait(timeout=5)
driver.stop()

ready_connection.close.assert_called_once()


def test_sync_driver_discovery_disabled_retries_initial_connection(
driver_config_disabled_discovery, mock_discovery_resolver
):
ready_connection = unittest.mock.MagicMock()
ready_connection.endpoint = "localhost:2136"
ready_connection.node_id = "mock_node_id"

with unittest.mock.patch(
"ydb.connection.Connection.ready_factory",
side_effect=[None, ready_connection],
) as mock_factory:
with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
try:
driver.wait(timeout=5, fail_fast=True)

assert mock_factory.call_count == 2
mock_discovery_class.assert_not_called()
assert not mock_discovery_resolver.called
finally:
driver.stop()

ready_connection.close.assert_called_once()


def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection):
"""Test that when disable_discovery=False, the discovery thread is started (mock)."""
with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
Expand Down Expand Up @@ -238,6 +273,41 @@ async def test_aio_driver_discovery_disabled_mock(
teardown_async_mocks(mocks)


@pytest.mark.asyncio
async def test_aio_driver_discovery_disabled_retries_initial_connection(driver_config_disabled_discovery):
ready_attempts = []
closed_connections = []

class FakeConnection:
def __init__(self, endpoint, driver_config, endpoint_options=None):
self.endpoint = endpoint
self.node_id = None
self._cleanup_callbacks = []

def add_cleanup_callback(self, callback):
self._cleanup_callbacks.append(callback)

async def connection_ready(self, ready_timeout=10):
ready_attempts.append(self)
if len(ready_attempts) == 1:
raise ydb.issues.ConnectionLost("transient failure")

async def close(self, grace=30):
closed_connections.append(self)
for callback in self._cleanup_callbacks:
callback(self)

with unittest.mock.patch("ydb.aio.pool.Connection", FakeConnection):
driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)
try:
await driver.wait(timeout=5, fail_fast=True)

assert len(ready_attempts) == 2
assert closed_connections == [ready_attempts[0]]
finally:
await driver.stop()

Comment thread
vgvoleg marked this conversation as resolved.

@pytest.mark.asyncio
async def test_aio_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_aio_connection):
"""Test that when disable_discovery=False, the discovery is created (mock)."""
Expand Down
8 changes: 6 additions & 2 deletions ydb/aio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,12 @@ async def destroy(self, grace: float = 0) -> None:
:param grace:
:return: None
"""
if hasattr(self, "_channel") and hasattr(self._channel, "close"):
await self._channel.close(grace)
channel = getattr(self, "_channel", None)
if channel is not None and hasattr(channel, "close"):
await channel.close(grace)

self._stub_instances.clear()
self._channel = None

def add_cleanup_callback(self, callback: Callable[["Connection"], None]) -> None:
self._cleanup_callbacks.append(callback)
Expand Down
36 changes: 31 additions & 5 deletions ydb/aio/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,36 @@ def __init__(self, driver_config: "DriverConfig") -> None:
self._store = ConnectionsCache(driver_config.use_all_nodes)
self._grpc_init = Connection(self._driver_config.endpoint, self._driver_config)
self._stopped = False
self._stopping = False
self._discovery: Optional[Discovery] = None
self._discovery_task: "asyncio.Task[None]"

if driver_config.disable_discovery:
# If discovery is disabled, just add the initial endpoint to the store
async def init_connection() -> None:
ready_connection = Connection(self._driver_config.endpoint, self._driver_config)
await ready_connection.connection_ready(
ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10)
)
self._store.add(ready_connection)
ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10)
while not self._stopping:
ready_connection = Connection(self._driver_config.endpoint, self._driver_config)
try:
await ready_connection.connection_ready(ready_timeout=ready_timeout)
except asyncio.CancelledError:
try:
await ready_connection.close()
except Exception:
logger.debug("Failed to close cancelled initial connection", exc_info=True)
raise
except Exception:
logger.debug("Initial connection attempt failed", exc_info=True)
try:
await ready_connection.close()
except Exception:
logger.debug("Failed to close unsuccessful initial connection", exc_info=True)
if not self._stopping:
await asyncio.sleep(1)
continue

self._store.add(ready_connection)
return

# Create and schedule the task to initialize the connection
self._discovery_task = asyncio.get_event_loop().create_task(init_connection())
Expand All @@ -268,13 +287,20 @@ async def init_connection() -> None:
self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())

async def stop(self, timeout: int = 10) -> None: # type: ignore[override] # async override of sync method
self._stopping = True
if self._discovery:
self._discovery.stop()
await self._grpc_init.close()
try:
await asyncio.wait_for(self._discovery_task, timeout=timeout)
except asyncio.TimeoutError:
self._discovery_task.cancel()
try:
await self._discovery_task
except asyncio.CancelledError:
pass
if self._discovery is None:
await self._store.cleanup()
self._stopped = True

def _on_disconnected(self, connection: Connection) -> Callable[[], Any]:
Expand Down
8 changes: 6 additions & 2 deletions ydb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,12 @@ def close(self):
self.destroy()

def destroy(self):
if hasattr(self, "_channel") and hasattr(self._channel, "close"):
self._channel.close()
channel = getattr(self, "_channel", None)
if channel is not None and hasattr(channel, "close"):
channel.close()

self._stub_instances.clear()
self._channel = None

def ready_future(self):
"""
Expand Down
41 changes: 32 additions & 9 deletions ydb/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,23 +401,41 @@ def __init__(self, driver_config: "DriverConfig") -> None:
self._store = ConnectionsCache(driver_config.use_all_nodes, driver_config.tracer)
self.tracer = driver_config.tracer
self._grpc_init = connection_impl.Connection(self._driver_config.endpoint, self._driver_config)
self._stopped = False
self._stop_guard = threading.Lock()
self._stop_event = threading.Event()
self._init_thread: Optional[threading.Thread] = None

if driver_config.disable_discovery:
# If discovery is disabled, just add the initial endpoint to the store
ready_connection = connection_impl.Connection.ready_factory(
self._driver_config.endpoint,
self._driver_config,
ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10),
)
self._store.add(ready_connection)
# If discovery is disabled, establish the initial connection in a
# background thread, retrying until it succeeds or the pool is stopped.
# Doing this off the constructor keeps wait(timeout) as the blocking
# point and lets stop() interrupt the retry loop.
self._discovery_thread = None
self._init_thread = threading.Thread(
name="ydb_driver_initial_connection",
target=self._init_connection,
daemon=True,
)
self._init_thread.start()
else:
# Start discovery thread as usual
self._discovery_thread = Discovery(self._store, self._driver_config)
self._discovery_thread.start()

self._stopped = False
self._stop_guard = threading.Lock()
def _init_connection(self) -> None:
ready_timeout = getattr(self._driver_config, "discovery_request_timeout", 10)
while not self._stopped:
ready_connection = connection_impl.Connection.ready_factory(
self._driver_config.endpoint,
self._driver_config,
ready_timeout=ready_timeout,
)
if self._store.add(ready_connection):
return

logger.debug("Initial connection attempt failed")
self._stop_event.wait(1)

def stop(self, timeout: int = 10) -> None:
"""
Expand All @@ -431,11 +449,16 @@ def stop(self, timeout: int = 10) -> None:
return

self._stopped = True
self._stop_event.set()
if self._discovery_thread:
self._discovery_thread.stop()
self._grpc_init.close()
if self._discovery_thread:
self._discovery_thread.join(timeout)
if self._init_thread:
self._init_thread.join(timeout)
if self._discovery_thread is None:
self._store.cleanup()
Comment thread
UgnineSirdis marked this conversation as resolved.

def async_wait(self, fail_fast: bool = False) -> "futures.Future[None]":
"""
Expand Down
Loading