Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 27 additions & 83 deletions hazelcast/internal/asyncio_reactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import ssl
import time
from errno import errorcode
from asyncio import AbstractEventLoop, transports
from asyncio import AbstractEventLoop, transports, CancelledError

from hazelcast.config import Config, SSLProtocol
from hazelcast.errors import TargetDisconnectedError
from hazelcast.internal.asyncio_connection import Connection
from hazelcast.core import Address

Expand Down Expand Up @@ -69,10 +70,9 @@ def __init__(
self._preconn_buffers: list = []
self._create_task: asyncio.Task | None = None
self._close_task: asyncio.Task | None = None
self._connect_timer_task: asyncio.Task | None = None
self._connect_task: asyncio.Task | None = None
self._connected = False
self._receive_buffer_size = _BUFFER_SIZE
self._sock = None

@classmethod
def create_and_connect(
Expand All @@ -97,86 +97,36 @@ def _create_protocol(self):
return HazelcastProtocol(self)

async def _create_connection(self, config, address):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setblocking(False)
sock.settimeout(0)
self._set_socket_options(sock, config)
server_hostname = None
ssl_context = None
if config.ssl_enabled:
server_hostname = address.host
ssl_context = self._create_ssl_context(config)

try:
self.connect(sock, (address.host, address.port))
except socket.error as e:
self._inner_close()
raise e

self._sock = sock

res = await self._loop.create_connection(
self._create_protocol,
ssl=ssl_context,
server_hostname=server_hostname,
sock=sock,
)
self._connected = True

try:
sock.getpeername()
except OSError as err:
if err.errno not in (errno.ENOTCONN, errno.EINVAL):
raise
self._connected = False

sock, self._proto = res
sock = sock.get_extra_info("socket")
sockname = sock.getsockname()
host, port = sockname[0], sockname[1]
self.local_address = Address(host, port)
self._connect_timer_task = None
if not self._connected:
self._connect_timer_task = self._loop.create_task(
self._connect_retry_cb(0.01, self._sock, (address.host, address.port))
async def inner():
if not self.live:
return
res = await self._loop.create_connection(
self._create_protocol,
host=address.host,
port=address.port,
ssl=ssl_context,
server_hostname=server_hostname,
)

async def _connect_retry_cb(self, timeout, sock, address):
await asyncio.sleep(timeout)
if self._connected and self._close_task:
self._close_task.cancel()
return
self._connected = True
sock, self._proto = res
sock = sock.get_extra_info("socket")
sockname = sock.getsockname()
host, port = sockname[0], sockname[1]
self.local_address = Address(host, port)
self._connect_task = None
self.handle_connect()

self._connect_task = asyncio.create_task(inner())
try:
self.connect(sock, address)
except Exception:
# close task will handle closing the connection
return
if not self._connected:
self._connect_timer_task = self._loop.create_task(
self._connect_retry_cb(timeout, sock, address)
)
elif self._close_task:
self._close_task.cancel()

def connect(self, sock, address):
self._connected = False
err = sock.connect_ex(address)
if (
err in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK)
or err == errno.EINVAL
and os.name == "nt"
):
return
if err in (0, errno.EISCONN):
self.handle_connect_event(sock)
else:
raise OSError(err, errorcode[err])

def handle_connect_event(self, sock):
err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
raise OSError(err, _strerror(err))
self.handle_connect()
await self._connect_task
except CancelledError:
raise TargetDisconnectedError("connect_task")

def handle_connect(self):
self._connected = True
Expand All @@ -194,8 +144,8 @@ def handle_connect(self):
async def _close_timer_cb(self, timeout):
await asyncio.sleep(timeout)
if not self._connected:
if self._connect_timer_task:
self._connect_timer_task.cancel()
if self._connect_task:
self._connect_task.cancel()
await self.close_connection(None, IOError("Connection timed out"))

def _write(self, buf):
Expand All @@ -210,12 +160,6 @@ def _inner_close(self):
if self._proto:
self._proto.close()
self._connected = False
if self._sock:
try:
self._sock.close()
except OSError as why:
if why.errno not in (errno.ENOTCONN, errno.EBADF):
raise

def _update_read_time(self, time):
self.last_read_time = time
Expand Down
Loading