diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 580eb336b2..f92bb53785 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -239,7 +239,8 @@ class MockSession(MagicMock): def __init__(self, *args, **kwargs): super(MockSession, self).__init__(*args, **kwargs) self.cluster = MagicMock() - self.cluster.executor = ThreadPoolExecutor(max_workers=2, initializer=self.executor_init) + self.connection_created = Event() + self.cluster.executor = ThreadPoolExecutor(max_workers=2) self.cluster.signal_connection_failure = lambda *args, **kwargs: False self.cluster.connection_factory = self.mock_connection_factory self.connection_counter = 0 @@ -259,23 +260,30 @@ def mock_connection_factory(self, *args, **kwargs): partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port="", shard_aware_port_ssl="")) self.connection_counter += 1 + self.connection_created.set() return connection - def executor_init(self, *args): - time.sleep(0.5) - LOGGER.info("Future start: %s", args) - - for attempt_num in range(20): - LOGGER.info("Testing fast shutdown %d / 20 times", attempt_num + 1) + for attempt_num in range(3): + LOGGER.info("Testing fast shutdown %d / 3 times", attempt_num + 1) host = MagicMock() host.endpoint = "1.2.3.4" - session = self.make_session() + session = MockSession() pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) LOGGER.info("Initialized pool %s", pool) + + # Wait for initial connection to be created (with timeout) + if not session.connection_created.wait(timeout=2.0): + pytest.fail("Initial connection failed to be created within 2 seconds") + LOGGER.info("Connections: %s", pool._connections) - time.sleep(0.5) + + # Shutdown the pool pool.shutdown() - time.sleep(3) - session.cluster.executor.shutdown() + + # Verify pool is shut down + assert pool.is_shutdown, "Pool should be marked as shutdown" + + # Cleanup executor with proper wait + session.cluster.executor.shutdown(wait=True)