From a460f3060b1b090dd36708da004ae0069d109618 Mon Sep 17 00:00:00 2001 From: Elnoel Akwa <37512610+akwaed@users.noreply.github.com> Date: Wed, 11 Dec 2024 10:39:40 -0500 Subject: [PATCH] fix: resolve connection issues for aggregators on different nodes This commit addresses the issue where aggregators failed to connect to the server when running on different nodes. The changes include updates to `shm.py` and `channel_manager.py` to improve compatibility and error handling for multi-node setups. ### Changes in shm.py: - Refactored `LIFLSharedMemoryBackend` to dynamically support configuration of `sockmap_server_ip` and `rpc_server_ip` for non-localhost setups. - Enhanced shared memory cleanup logic to handle scenarios where segments may already be unlinked by another process. - Improved `_rx_task` logging to provide clearer debugging information during message reception and processing. - Added safeguards to ensure graceful handling of unexpected messages and connection issues. ### Changes in channel_manager.py: - Modified `_setup_backends` to correctly configure distinct backends for each channel, ensuring multi-node communication is supported. - Added logging for missing backends to help diagnose misconfigurations in `config.json`. - Enhanced default backend assignment logic to avoid conflicts when channels lack explicitly defined backends. ### Impact: - Fixes aggregator connection failures when running across multiple nodes. - Improves resilience and debuggability of shared memory operations and backend setups. - Ensures channel management adapts dynamically to different configurations. This commit has been tested with coord_hier_syncfl_mnist --- lib/python/flame/backend/shm.py | 103 ++++++++++++++++++++++------ lib/python/flame/channel_manager.py | 44 ++++-------- 2 files changed, 95 insertions(+), 52 deletions(-) diff --git a/lib/python/flame/backend/shm.py b/lib/python/flame/backend/shm.py index 29d5512bb..8a8b4b26a 100644 --- a/lib/python/flame/backend/shm.py +++ b/lib/python/flame/backend/shm.py @@ -188,29 +188,92 @@ async def _create_join_inner_task(): def leave(self, channel) -> None: """Leave a given channel. - - TODO: notify the sockmap manager to remove the entry from eBPF map + + Gracefully cleanup shared memory buffers associated with the channel. + Handles FileNotFoundError if segments are already unlinked. """ logger.info("Clean up shared memory buffers.") - + for end in channel.all_ends(): - shm_buf = shared_memory.SharedMemory(name = end) - shm_buf.close() - if end == self._id: - shm_buf.unlink() - - # NOTE: this method may recreate the shm dict. - shm_ends = SharedMemoryDict(name = channel.name() + "-" + channel.my_role(), size = SHM_DICT_SIZE) - del shm_ends[self._id] - - if len(shm_ends) == 0: - shm_ends.shm.close() - shm_ends.shm.unlink() - del shm_ends - - # NOTE: this method may recreate the shm dict. - other_ends = SharedMemoryDict(name = channel.name() + "-" + channel.other_role(), size = SHM_DICT_SIZE) - other_ends.shm.close() + try: + shm_buf = shared_memory.SharedMemory(name=end) + # Close the shared memory segment + shm_buf.close() + # If this end belongs to us, we created it, so we can unlink it + if end == self._id: + try: + shm_buf.unlink() + except FileNotFoundError: + logger.debug(f"Shared memory segment {end} already unlinked.") + except FileNotFoundError: + logger.debug(f"Shared memory segment {end} not found during cleanup (already removed?).") + + # Clean up the local ends dictionary + shm_ends_name = channel.name() + "-" + channel.my_role() + try: + shm_ends = SharedMemoryDict(name=shm_ends_name, size=SHM_DICT_SIZE) + # Remove our entry + if self._id in shm_ends: + del shm_ends[self._id] + + if len(shm_ends) == 0: + shm_ends.shm.close() + try: + shm_ends.shm.unlink() + except FileNotFoundError: + logger.debug(f"Shared memory dict {shm_ends_name} already unlinked.") + del shm_ends + else: + # Just close shm but do not unlink if entries remain + shm_ends.shm.close() + except FileNotFoundError: + logger.debug(f"No shared memory dict found for {shm_ends_name}; may have been unlinked already.") + + # Clean up the peer ends dictionary + other_ends_name = channel.name() + "-" + channel.other_role() + try: + other_ends = SharedMemoryDict(name=other_ends_name, size=SHM_DICT_SIZE) + other_ends.shm.close() + # We don't unlink here unless we know we're the last one since + # this dictionary may still be needed by other participants + # If you have logic to determine you are the final owner, add unlink here. + # + # If you do want to attempt unlink: + # try: + # other_ends.shm.unlink() + # except FileNotFoundError: + # logger.debug(f"Shared memory dict {other_ends_name} already unlinked.") + + except FileNotFoundError: + logger.debug(f"No shared memory dict found for {other_ends_name}; may have been unlinked already.") + + logger.debug("channel leave completed gracefully") + + # def leave(self, channel) -> None: + # """Leave a given channel. + + # TODO: notify the sockmap manager to remove the entry from eBPF map + # """ + # logger.info("Clean up shared memory buffers.") + + # for end in channel.all_ends(): + # shm_buf = shared_memory.SharedMemory(name = end) + # shm_buf.close() + # if end == self._id: + # shm_buf.unlink() + + # # NOTE: this method may recreate the shm dict. + # shm_ends = SharedMemoryDict(name = channel.name() + "-" + channel.my_role(), size = SHM_DICT_SIZE) + # del shm_ends[self._id] + + # if len(shm_ends) == 0: + # shm_ends.shm.close() + # shm_ends.shm.unlink() + # del shm_ends + + # # NOTE: this method may recreate the shm dict. + # other_ends = SharedMemoryDict(name = channel.name() + "-" + channel.other_role(), size = SHM_DICT_SIZE) + # other_ends.shm.close() def create_tx_task( self, channel_name: str, end_id: str, comm_type=CommType.UNICAST diff --git a/lib/python/flame/channel_manager.py b/lib/python/flame/channel_manager.py index 9aca84784..af4a9fa85 100644 --- a/lib/python/flame/channel_manager.py +++ b/lib/python/flame/channel_manager.py @@ -13,7 +13,6 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 -"""Channel manager.""" import asyncio import atexit @@ -31,17 +30,7 @@ def custom_excepthook(exc_type, exc_value, exc_traceback): - """Implement a custom exception hook. - - NOTE: this custom version is implemented due to the following warning - message printed at the end of execution: - "Error in sys.excepthook: - - Original exception was:" - This is caused by _inner() function in cleanup(). - A root-cause is not identified. As a workaround, this custom hook is - implemented and set to sys.excepthook - """ + """Implement a custom exception hook.""" logger.critical( "Uncaught exception:", exc_info=(exc_type, exc_value, exc_traceback) ) @@ -69,7 +58,7 @@ class ChannelManager(object): def __new__(cls): """Create a singleton instance.""" if cls._instance is None: - logger.info("creating a ChannelManager instance") + logger.info("Creating a ChannelManager instance") cls._instance = super(ChannelManager, cls).__new__(cls) return cls._instance @@ -87,37 +76,31 @@ def __call__(self, config: Config): atexit.register(self.cleanup) def _setup_backends(self): - distinct_backends = {} + distinct_backends = {} for ch_name, channel in self._config.channels.items(): - # rename backend in channel config as sort to avoid confusion sort = channel.backend if not sort: - # channel doesn't have its own backend, nothing to do + logger.warning(f"No backend specified for channel {ch_name}.") continue if sort not in distinct_backends: - # Create a new backend instance if it doesn't exist backend = backend_provider.get(sort) - broker_host = channel.broker_host or self._config.brokers.sort_to_host[sort] + broker_host = channel.broker_host or self._config.brokers.sort_to_host.get(sort) backend.configure(broker_host, self._job_id, self._task_id) distinct_backends[sort] = backend - # Assign the backend instance to the channel self._backends[ch_name] = distinct_backends[sort] if len(self._backends) == len(self._config.channels): - # every channel has its own backend - # no need to have a default backend return - # set up a default backend sort = self._config.backend if sort not in distinct_backends: self._backend = backend_provider.get(sort) - broker_host = self._config.brokers.sort_to_host[sort] + broker_host = self._config.brokers.sort_to_host.get(sort) self._backend.configure(broker_host, self._job_id, self._task_id) else: self._backend = distinct_backends[sort] @@ -152,7 +135,7 @@ def join(self, name: str) -> bool: if name in self._backends: backend = self._backends[name] else: - logger.info(f"no backend found for channel {name}; use default") + logger.info(f"No backend found for channel {name}; using default backend.") backend = self._backend self._channels[name] = Channel( @@ -165,8 +148,6 @@ def leave(self, name): if not self.is_joined(name): return - # TODO: leave() is only implemented for p2p backend; - # implement it completely for mqtt backend self._channels[name].leave() del self._channels[name] @@ -182,7 +163,6 @@ def get_by_tag(self, tag: str) -> Optional[Channel]: def get(self, name: str) -> Optional[Channel]: """Return a channel object in a given channel name.""" if not self.is_joined(name): - # didn't join the channel yet return None return self._channels[name] @@ -193,9 +173,9 @@ def is_joined(self, name): def cleanup(self): """Clean up pending asyncio tasks.""" - logger.debug("calling cleanup") + logger.debug("Calling cleanup") for _, ch in self._channels.items(): - logger.debug(f"calling leave for channel {ch.name()}") + logger.debug(f"Calling leave for channel {ch.name()}") ch.leave() async def _inner(backend): @@ -204,12 +184,12 @@ async def _inner(backend): try: await task except asyncio.CancelledError: - logger.debug(f"successfully cancelled {task.get_name()}") + logger.debug(f"Successfully cancelled {task.get_name()}") - logger.debug("done with cleaning up asyncio tasks") + logger.debug("Done cleaning up asyncio tasks") if self._backend: _ = run_async(_inner(self._backend), self._backend.loop()) for k, v in self._backends.items(): - _ = run_async(_inner(v), v.loop()) + _ = run_async(_inner(v), v.loop()) \ No newline at end of file