Skip to content

Commit fa9f435

Browse files
committed
Working on Actions and Testing
1 parent 0a93ee0 commit fa9f435

1 file changed

Lines changed: 19 additions & 6 deletions

File tree

chaski/node.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,8 @@ def __init__(
550550
self.lock_disconnect = asyncio.Lock()
551551
self.lock_propagate = asyncio.Lock()
552552

553+
self._background_tasks = set()
554+
553555
# Initialize the node's connection and event tracking structures
554556
self.edges = []
555557
self.ping_events = {}
@@ -587,14 +589,14 @@ def __init__(
587589

588590
# If the run flag is set to True, create and start the main event loop task for the node
589591
if run:
590-
asyncio.create_task(self.run())
592+
self.track_task(self.run())
591593

592594
# Request an SSL certificate for secure communication if specified
593595
if request_ssl_certificate:
594596
loop = asyncio.get_event_loop()
595597
loop.call_later(
596598
1,
597-
lambda: asyncio.create_task(
599+
lambda: self.track_task(
598600
self.request_ssl_certificate(request_ssl_certificate)
599601
),
600602
)
@@ -746,6 +748,10 @@ async def stop(self) -> None:
746748
except asyncio.TimeoutError:
747749
logger_main.warning("Timeout waiting for server to close.")
748750

751+
for task in self._background_tasks:
752+
task.cancel()
753+
await asyncio.gather(*self._background_tasks, return_exceptions=True)
754+
749755
async def _connect_to_peer(
750756
self,
751757
node: "ChaskiNode",
@@ -819,7 +825,8 @@ async def _connect_to_peer(
819825

820826
# Log new connection
821827
logger_main.debug(f"{self.name}: New connection with {edge.address}.")
822-
asyncio.create_task(self._reader_loop(edge))
828+
self.track_task(self._reader_loop(edge))
829+
823830
await self.handshake(edge, response=True)
824831
return edge
825832

@@ -1158,7 +1165,7 @@ async def _connected(
11581165
f"{self.name}: Accepted connection from {writer.get_extra_info('peername')}."
11591166
)
11601167
logger_main.debug(f"{self.name}: New connection with {edge.address}.")
1161-
asyncio.create_task(self._reader_loop(edge))
1168+
self.track_task(self._reader_loop(edge))
11621169

11631170
# # If there are no edges (connections) yet, designate this node as the root node
11641171
# if not self.edges:
@@ -1380,7 +1387,7 @@ async def _start_tcp_server(self) -> None:
13801387
# Logging the server address and starting keep-alive task
13811388
addr = self.server.sockets[0].getsockname()
13821389
logger_main.debug(f"{self.name}: Serving at address {addr}.")
1383-
self._keep_alive_task = asyncio.create_task(self._keep_alive())
1390+
self._keep_alive_task = self.track_task(self._keep_alive())
13841391

13851392
# Start serving TCP connections forever
13861393
async with self.server:
@@ -2592,7 +2599,13 @@ async def request_ssl_certificate(self, ca_address: str) -> None:
25922599

25932600
# Restart the node by stopping the current event loop and then creating a new event loop task to run the node.
25942601
await self.stop()
2595-
asyncio.create_task(self.run())
2602+
self.track_task(self.run())
25962603

25972604
if not (self.ssl_context_client and self.ssl_context_server):
25982605
raise Exception("Failed to create SSL contexts")
2606+
2607+
def track_task(self, coro):
2608+
task = asyncio.create_task(coro)
2609+
self._background_tasks.add(task)
2610+
task.add_done_callback(self._background_tasks.discard)
2611+
return task

0 commit comments

Comments
 (0)