diff --git a/hazelcast/internal/asyncio_reactor.py b/hazelcast/internal/asyncio_reactor.py index 38c2060cf2..f5c7a1ad33 100644 --- a/hazelcast/internal/asyncio_reactor.py +++ b/hazelcast/internal/asyncio_reactor.py @@ -7,9 +7,10 @@ import ssl import time from errno import errorcode -from asyncio import AbstractEventLoop, transports +from asyncio import AbstractEventLoop, transports, CancelledError from hazelcast.config import Config, SSLProtocol +from hazelcast.errors import TargetDisconnectedError from hazelcast.internal.asyncio_connection import Connection from hazelcast.core import Address @@ -69,10 +70,9 @@ def __init__( self._preconn_buffers: list = [] self._create_task: asyncio.Task | None = None self._close_task: asyncio.Task | None = None - self._connect_timer_task: asyncio.Task | None = None + self._connect_task: asyncio.Task | None = None self._connected = False self._receive_buffer_size = _BUFFER_SIZE - self._sock = None @classmethod def create_and_connect( @@ -97,86 +97,36 @@ def _create_protocol(self): return HazelcastProtocol(self) async def _create_connection(self, config, address): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setblocking(False) - sock.settimeout(0) - self._set_socket_options(sock, config) server_hostname = None ssl_context = None if config.ssl_enabled: server_hostname = address.host ssl_context = self._create_ssl_context(config) - try: - self.connect(sock, (address.host, address.port)) - except socket.error as e: - self._inner_close() - raise e - - self._sock = sock - - res = await self._loop.create_connection( - self._create_protocol, - ssl=ssl_context, - server_hostname=server_hostname, - sock=sock, - ) - self._connected = True - - try: - sock.getpeername() - except OSError as err: - if err.errno not in (errno.ENOTCONN, errno.EINVAL): - raise - self._connected = False - - sock, self._proto = res - sock = sock.get_extra_info("socket") - sockname = sock.getsockname() - host, port = sockname[0], sockname[1] - self.local_address = Address(host, port) - self._connect_timer_task = None - if not self._connected: - self._connect_timer_task = self._loop.create_task( - self._connect_retry_cb(0.01, self._sock, (address.host, address.port)) + async def inner(): + if not self.live: + return + res = await self._loop.create_connection( + self._create_protocol, + host=address.host, + port=address.port, + ssl=ssl_context, + server_hostname=server_hostname, ) - - async def _connect_retry_cb(self, timeout, sock, address): - await asyncio.sleep(timeout) - if self._connected and self._close_task: - self._close_task.cancel() - return + self._connected = True + sock, self._proto = res + sock = sock.get_extra_info("socket") + sockname = sock.getsockname() + host, port = sockname[0], sockname[1] + self.local_address = Address(host, port) + self._connect_task = None + self.handle_connect() + + self._connect_task = asyncio.create_task(inner()) try: - self.connect(sock, address) - except Exception: - # close task will handle closing the connection - return - if not self._connected: - self._connect_timer_task = self._loop.create_task( - self._connect_retry_cb(timeout, sock, address) - ) - elif self._close_task: - self._close_task.cancel() - - def connect(self, sock, address): - self._connected = False - err = sock.connect_ex(address) - if ( - err in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK) - or err == errno.EINVAL - and os.name == "nt" - ): - return - if err in (0, errno.EISCONN): - self.handle_connect_event(sock) - else: - raise OSError(err, errorcode[err]) - - def handle_connect_event(self, sock): - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - raise OSError(err, _strerror(err)) - self.handle_connect() + await self._connect_task + except CancelledError: + raise TargetDisconnectedError("connect_task") def handle_connect(self): self._connected = True @@ -194,8 +144,8 @@ def handle_connect(self): async def _close_timer_cb(self, timeout): await asyncio.sleep(timeout) if not self._connected: - if self._connect_timer_task: - self._connect_timer_task.cancel() + if self._connect_task: + self._connect_task.cancel() await self.close_connection(None, IOError("Connection timed out")) def _write(self, buf): @@ -210,12 +160,6 @@ def _inner_close(self): if self._proto: self._proto.close() self._connected = False - if self._sock: - try: - self._sock.close() - except OSError as why: - if why.errno not in (errno.ENOTCONN, errno.EBADF): - raise def _update_read_time(self, time): self.last_read_time = time