@@ -550,6 +550,8 @@ def __init__(
550550 self .lock_disconnect = asyncio .Lock ()
551551 self .lock_propagate = asyncio .Lock ()
552552
553+ self ._background_tasks = set ()
554+
553555 # Initialize the node's connection and event tracking structures
554556 self .edges = []
555557 self .ping_events = {}
@@ -587,14 +589,14 @@ def __init__(
587589
588590 # If the run flag is set to True, create and start the main event loop task for the node
589591 if run :
590- asyncio . create_task (self .run ())
592+ self . track_task (self .run ())
591593
592594 # Request an SSL certificate for secure communication if specified
593595 if request_ssl_certificate :
594596 loop = asyncio .get_event_loop ()
595597 loop .call_later (
596598 1 ,
597- lambda : asyncio . create_task (
599+ lambda : self . track_task (
598600 self .request_ssl_certificate (request_ssl_certificate )
599601 ),
600602 )
@@ -746,6 +748,10 @@ async def stop(self) -> None:
746748 except asyncio .TimeoutError :
747749 logger_main .warning ("Timeout waiting for server to close." )
748750
751+ for task in self ._background_tasks :
752+ task .cancel ()
753+ await asyncio .gather (* self ._background_tasks , return_exceptions = True )
754+
749755 async def _connect_to_peer (
750756 self ,
751757 node : "ChaskiNode" ,
@@ -819,7 +825,8 @@ async def _connect_to_peer(
819825
820826 # Log new connection
821827 logger_main .debug (f"{ self .name } : New connection with { edge .address } ." )
822- asyncio .create_task (self ._reader_loop (edge ))
828+ self .track_task (self ._reader_loop (edge ))
829+
823830 await self .handshake (edge , response = True )
824831 return edge
825832
@@ -1158,7 +1165,7 @@ async def _connected(
11581165 f"{ self .name } : Accepted connection from { writer .get_extra_info ('peername' )} ."
11591166 )
11601167 logger_main .debug (f"{ self .name } : New connection with { edge .address } ." )
1161- asyncio . create_task (self ._reader_loop (edge ))
1168+ self . track_task (self ._reader_loop (edge ))
11621169
11631170 # # If there are no edges (connections) yet, designate this node as the root node
11641171 # if not self.edges:
@@ -1380,7 +1387,7 @@ async def _start_tcp_server(self) -> None:
13801387 # Logging the server address and starting keep-alive task
13811388 addr = self .server .sockets [0 ].getsockname ()
13821389 logger_main .debug (f"{ self .name } : Serving at address { addr } ." )
1383- self ._keep_alive_task = asyncio . create_task (self ._keep_alive ())
1390+ self ._keep_alive_task = self . track_task (self ._keep_alive ())
13841391
13851392 # Start serving TCP connections forever
13861393 async with self .server :
@@ -2592,7 +2599,13 @@ async def request_ssl_certificate(self, ca_address: str) -> None:
25922599
25932600 # Restart the node by stopping the current event loop and then creating a new event loop task to run the node.
25942601 await self .stop ()
2595- asyncio . create_task (self .run ())
2602+ self . track_task (self .run ())
25962603
25972604 if not (self .ssl_context_client and self .ssl_context_server ):
25982605 raise Exception ("Failed to create SSL contexts" )
2606+
2607+ def track_task (self , coro ):
2608+ task = asyncio .create_task (coro )
2609+ self ._background_tasks .add (task )
2610+ task .add_done_callback (self ._background_tasks .discard )
2611+ return task
0 commit comments